-
Notifications
You must be signed in to change notification settings - Fork 61
/
loss.py
27 lines (20 loc) · 971 Bytes
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from torch.nn import CrossEntropyLoss
from torch.nn.modules import loss
from utils.TripletLoss import TripletLoss
class Loss(loss._Loss):
    """Combined triplet + cross-entropy training objective.

    ``forward`` averages ``TripletLoss`` over ``outputs[1:4]`` and
    ``CrossEntropyLoss`` over ``outputs[4:]``, then returns
    ``triplet_term + ce_weight * cross_entropy_term``. ``outputs[0]`` is not
    used by this loss.

    Args:
        margin: Margin passed to ``TripletLoss`` (default 1.2, the value
            previously hard-coded).
        ce_weight: Multiplier applied to the averaged cross-entropy term
            (default 2, the value previously hard-coded).
    """

    def __init__(self, margin=1.2, ce_weight=2):
        super().__init__()
        # Build the criteria once here rather than on every forward() call.
        self.cross_entropy_loss = CrossEntropyLoss()
        self.triplet_loss = TripletLoss(margin=margin)
        self.ce_weight = ce_weight

    def forward(self, outputs, labels):
        """Compute the weighted sum of the triplet and cross-entropy terms.

        Args:
            outputs: Sequence of model outputs. Entries ``[1:4]`` are fed to
                ``TripletLoss`` (presumably feature embeddings -- TODO
                confirm against the model); entries ``[4:]`` are fed to
                ``CrossEntropyLoss`` and so must be class logits.
            labels: Target class indices shared by every term.

        Returns:
            Scalar tensor ``triplet_term + ce_weight * cross_entropy_term``.

        Raises:
            ValueError: if ``outputs`` has fewer than 5 entries. In that case
                one of the slices would be empty, and averaging it would
                otherwise fail with an opaque ``ZeroDivisionError``.
        """
        if len(outputs) < 5:
            raise ValueError(
                'Loss expects at least 5 outputs (triplet inputs at [1:4], '
                'logits at [4:]), got %d' % len(outputs))

        triplet_terms = [self.triplet_loss(out, labels) for out in outputs[1:4]]
        triplet_term = sum(triplet_terms) / len(triplet_terms)

        ce_terms = [self.cross_entropy_loss(out, labels) for out in outputs[4:]]
        ce_term = sum(ce_terms) / len(ce_terms)

        loss_sum = triplet_term + self.ce_weight * ce_term

        # In-place progress line: '\r' rewinds the cursor and end=' ' suppresses
        # the newline, so successive calls overwrite the same console line.
        print('\rtotal loss:%.2f Triplet_Loss:%.2f CrossEntropy_Loss:%.2f' % (
            loss_sum.item(),
            triplet_term.item(),
            ce_term.item()),
            end=' ')
        return loss_sum