-
Notifications
You must be signed in to change notification settings - Fork 6
/
multiclass_loss.py
59 lines (44 loc) · 1.28 KB
/
multiclass_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import sys
import numpy as np
import torch
from torch import nn
from torch.autograd import Variable
from holder import *
from util import *
# Multiclass Loss
class MulticlassLoss(torch.nn.Module):
    """Summed negative log-likelihood loss for multiclass prediction.

    Besides computing the loss, the module accumulates accuracy statistics
    (``num_correct`` / ``num_ex``) across ``forward`` calls until
    ``begin_pass`` resets them.

    Args:
        opt: options object; ``forward`` reads ``opt.num_label`` (number of
            classes) and ``opt.gpuid`` (-1 keeps the criterion on CPU).
        shared: shared state object; ``forward`` reads ``shared.batch_l``
            (current batch size).
    """

    def __init__(self, opt, shared):
        super(MulticlassLoss, self).__init__()
        self.opt = opt
        self.shared = shared
        self.num_correct = 0
        self.num_ex = 0
        self.verbose = False  # not read anywhere in this class

    # NOTE: do not create the loss node globally; a fresh criterion is built
    # on every forward call.
    def forward(self, pred, gold):
        """Return the summed NLL loss of ``pred`` against ``gold`` and update stats.

        Args:
            pred: tensor of shape (batch_l, num_label) (asserted below). It is
                fed straight to NLLLoss, so it is assumed to hold
                log-probabilities — TODO confirm with the producing model.
            gold: gold class indices, presumably a LongTensor of shape
                (batch_l,) — verify against caller.

        Returns:
            Scalar loss tensor, summed (not averaged) over the batch.

        Raises:
            AssertionError: if ``pred`` does not have shape (batch_l, num_label).
        """
        log_p = pred
        batch_l = self.shared.batch_l
        expected_shape = (batch_l, self.opt.num_label)
        assert pred.shape == expected_shape, \
            'expected pred of shape {0}, got {1}'.format(
                expected_shape, tuple(pred.shape))

        # loss (for pytorch < 0.4.1, use size_average=False instead of reduction)
        crit = torch.nn.NLLLoss(reduction='sum')
        if self.opt.gpuid != -1:
            crit = crit.cuda()
        loss = crit(log_p, gold)

        # stats: compare predicted labels with gold on CPU, detached from the
        # autograd graph; int() keeps num_correct a plain Python int instead
        # of a numpy/torch scalar.
        pred_label = pick_label(log_p.detach().cpu())
        self.num_correct += int(np.equal(pred_label, gold.cpu()).sum())
        self.num_ex += batch_l
        return loss

    def _accuracy(self):
        """Return accuracy since the last begin_pass, or 0.0 if no examples were seen."""
        if self.num_ex == 0:
            return 0.0
        return float(self.num_correct) / self.num_ex

    def print_cur_stats(self):
        """Return a string with the current accuracy (nothing is printed here)."""
        return 'Acc {0:.3f} '.format(self._accuracy())

    def get_epoch_metric(self):
        """Return (scalar metric, list of extra metrics); both are the accuracy."""
        acc = self._accuracy()
        return acc, [acc]  # and any other scalar metrics

    def begin_pass(self):
        """Clear the accumulated accuracy statistics."""
        self.num_correct = 0
        self.num_ex = 0

    def end_pass(self):
        """No end-of-pass work is needed for this loss."""
        pass