InferenceIWSLT.py
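# Checkpoint averaging ("ensembling") for a fairseq IWSLT14 de-en model:
# for each end epoch j in [start, end], average the weights of checkpoints
# j-n+1 ... j (n = 10), evaluate the averaged model with generate.py, and
# keep the best-scoring ensemble.
#
# Usage: python InferenceIWSLT.py <modelname> <start_epoch> <end_epoch>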
import sys
import torch
import os
modelname = sys.argv[1]
start = int(sys.argv[2])
end = int(sys.argv[3])
checkpointfolder = 'results/IWSLT/checkpoints/'
ensemblemodelpath = checkpointfolder + 'ensemblemodel.pt'
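
# All averaged models produced during the epoch sweep are written to the same
# scratch file (ensemblemodel.pt); only the best-scoring ensemble is re-saved
# under its own name at the end.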
for n in [10]:
    bestbleu = 0
    modelfolder = checkpointfolder + modelname + '/'
    bleufolder = 'results/IWSLT/BLEU/' + modelname + '/ensemble{}/'.format(n)
    os.system('mkdir -p {}'.format(bleufolder))
    bestepoch = start
    jlist = [j for j in range(start, end + 1)]
    for j in jlist:
        # Average the weights of checkpoints j-n+1 ... j into a single model.
        cpname = modelfolder + 'checkpoint{}.pt'.format(j)
        model = torch.load(cpname)
        for i in range(1, n):
            cpname2 = modelfolder + 'checkpoint{}.pt'.format(j - i)
            model2 = torch.load(cpname2)
            for param in model['model']:
                model['model'][param].add_(model2['model'][param])
            del model2
        for param in model['model']:
            model['model'][param].div_(float(n))
        torch.save(model, ensemblemodelpath)
        del model
        # Translate the test set with the averaged model and log the output.
        bleu = bleufolder + 'checkpoint{}_ensemble{}.out'.format(j, n)
        print('evaluating {}'.format(bleu))
        command = 'python generate.py data-bin/iwslt14.tokenized.de-en/ --path {} --beam 4 --batch-size 128 --remove-bpe --lenpen 0.3 --quiet | tee {}'.format(ensemblemodelpath, bleu)
        os.system(command)
        # Track the best BLEU score (token 13 of the comma-stripped last line
        # of the generate.py output) and the epoch that produced it.
        with open(bleu, 'r') as f:
            lines = f.read().splitlines()
            lastline = lines[-1].replace(',', '').split()
            if bestbleu < float(lastline[13]):
                bestbleu = float(lastline[13])
                bestepoch = j
    print('best bleu {} at epoch {}'.format(bestbleu, bestepoch))
    # Rebuild the best-scoring ensemble from its n checkpoints and save it
    # under a descriptive name.
    bestensemble = modelfolder + 'bestmodel_ensemble{}_epoch{}_{}.pt'.format(n, bestepoch - n + 1, bestepoch)
    cpname = modelfolder + 'checkpoint{}.pt'.format(bestepoch)
    model = torch.load(cpname)
    for i in range(1, n):
        cpname2 = modelfolder + 'checkpoint{}.pt'.format(bestepoch - i)
        model2 = torch.load(cpname2)
        for param in model['model']:
            model['model'][param].add_(model2['model'][param])
        del model2
    for param in model['model']:
        model['model'][param].div_(float(n))
    torch.save(model, bestensemble)
    del model
    # Re-evaluate the best ensemble (without --quiet) and score it with
    # compound-split BLEU.
    bleu = bleufolder + 'bestmodel_ensemble{}_epoch{}_{}.out'.format(n, bestepoch - n + 1, bestepoch)
    print('evaluating {}'.format(bleu))
    command = 'python generate.py data-bin/iwslt14.tokenized.de-en/ --path {} --beam 4 --batch-size 128 --remove-bpe --lenpen 0.3 | tee {}'.format(bestensemble, bleu)
    os.system(command)
    command = './compound_split_bleu.sh {}'.format(bleu)
    os.system(command)
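
# Example invocation (the run name is hypothetical; the script expects
# checkpoints checkpoint<start-n+1>.pt ... checkpoint<end>.pt, with n = 10,
# under results/IWSLT/checkpoints/<modelname>/):
#   python InferenceIWSLT.py my_iwslt_run 20 30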