-
Notifications
You must be signed in to change notification settings - Fork 4
/
loss.py
63 lines (51 loc) · 1.91 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch
import torch.nn as nn
from torch.nn import functional as F
class Dice_Loss(nn.Module):
    """
    Calculates the Sørensen–Dice coefficient-based loss.

    Taken from
    https://github.com/SaoYan/IPMI2019-AttnMel/blob/master/loss.py#L28

    The loss is ``1 - Dice``, computed over *all* elements of the inputs at
    once (a single global score, not a per-sample or per-class average).

    Args:
        smooth (float): Constant added to both the numerator and the
            denominator of the Dice ratio. Defaults to ``0.0`` (no smoothing,
            identical to the original behaviour). A positive value avoids
            0/0 = NaN when both ``inputs`` and ``targets`` are all zeros.
    """

    def __init__(self, smooth=0.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, inputs, targets):
        """
        Dice(A, B) = (2 * |intersection(A, B)|) / (|A| + |B|)
        where |x| denotes the cardinality of the set x.

        Args:
            inputs (torch.Tensor): 1-hot encoded predictions
            targets (torch.Tensor): 1-hot encoded ground truth, same shape
                as (or broadcastable with) ``inputs``

        Returns:
            torch.Tensor: 0-d tensor holding ``1 - Dice``.
        """
        intersection = torch.mul(inputs, targets).sum()
        # Plain ``+`` replaces the deprecated ``torch.add(input, alpha, other)``
        # overload that the original code relied on.
        cardinality = (inputs + targets).sum()
        dice = (2 * intersection + self.smooth) / (cardinality + self.smooth)
        return 1 - dice
class MCC_Loss(nn.Module):
    """
    Calculates the proposed Matthews Correlation Coefficient-based loss.

    The confusion-matrix entries are computed from elementwise products and
    summed over *all* elements of the inputs, giving one global score; the
    returned loss is ``1 - MCC``.

    Args:
        eps (float): Constant added to the MCC denominator to avoid
            divide-by-zero errors. Defaults to ``1.0``, the value used by the
            original implementation.
    """

    def __init__(self, eps=1.0):
        super().__init__()
        self.eps = eps

    def forward(self, inputs, targets):
        """
        MCC = (TP.TN - FP.FN) / sqrt((TP+FP) . (TP+FN) . (TN+FP) . (TN+FN))
        where TP, TN, FP, and FN are elements in the confusion matrix.

        Args:
            inputs (torch.Tensor): 1-hot encoded predictions
            targets (torch.Tensor): 1-hot encoded ground truth

        Returns:
            torch.Tensor: 0-d tensor holding ``1 - MCC``.
        """
        tp = torch.sum(inputs * targets)
        tn = torch.sum((1 - inputs) * (1 - targets))
        fp = torch.sum(inputs * (1 - targets))
        fn = torch.sum((1 - inputs) * targets)

        numerator = tp * tn - fp * fn
        # Square-root two pairwise products instead of one four-way product:
        # mathematically identical, but intermediates stay around count**2
        # rather than count**4, so large inputs do not overflow to inf (which
        # would silently force the MCC to 0 or NaN).
        denominator = torch.sqrt((tp + fp) * (tp + fn)) * torch.sqrt(
            (tn + fp) * (tn + fn)
        )
        # Offset the denominator to avoid divide-by-zero errors.
        mcc = numerator / (denominator + self.eps)
        return 1 - mcc