forked from lxgithub24/transformer
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
103 lines (83 loc) · 3.68 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# -*- coding: utf-8 -*-
#/usr/bin/python3
'''
Feb. 2019 by kyubyong park.
kbpark.linguist@gmail.com.
https://www.github.com/kyubyong/transformer
'''
import tensorflow as tf
from model import Transformer
from tqdm import tqdm
from data_load import get_batch
from utils import save_hparams, save_variable_specs, get_hypotheses, calc_bleu
import os
from hparams import Hparams
import math
import logging
# Show INFO-level progress messages for the whole run.
logging.basicConfig(level=logging.INFO)

logging.info("# hparams")
# Parse the command-line arguments with the parser exposed by Hparams,
# then store the resulting settings under the log directory.
hparams = Hparams()
hp = hparams.parser.parse_args()
save_hparams(hp, hp.logdir)
logging.info("# Prepare train/eval batches")
# Training data: length limits come from the hyperparameters, order is shuffled.
train_batches, num_train_batches, num_train_samples = get_batch(
    hp.train1, hp.train2,
    hp.maxlen1, hp.maxlen2,
    hp.vocab, hp.batch_size,
    shuffle=True)

# Length limit passed to get_batch for both sides of the eval set.
# NOTE(review): presumably chosen large enough that no eval sentence is
# filtered out -- confirm against get_batch.
EVAL_MAXLEN = 100000
# Evaluation data: fixed order so hypotheses line up with the references.
eval_batches, num_eval_batches, num_eval_samples = get_batch(
    hp.eval1, hp.eval2,
    EVAL_MAXLEN, EVAL_MAXLEN,
    hp.vocab, hp.batch_size,
    shuffle=False)

# Create an iterator of the correct shape and type. It is built from the
# training set's structure only; the two initializer ops below bind it to
# either the train or the eval batches, so `xs, ys` serve both phases.
# (Named `batch_iterator` rather than `iter` to avoid shadowing the builtin.)
batch_iterator = tf.data.Iterator.from_structure(train_batches.output_types,
                                                 train_batches.output_shapes)
xs, ys = batch_iterator.get_next()
train_init_op = batch_iterator.make_initializer(train_batches)
eval_init_op = batch_iterator.make_initializer(eval_batches)
logging.info("# Load model")
m = Transformer(hp)

# Training outputs: loss tensor, update op, step counter and summaries.
train_outputs = m.train(xs, ys)
loss, train_op, global_step, train_summaries = train_outputs

# Evaluation outputs, fed by the same `xs, ys` tensors as training.
eval_outputs = m.eval(xs, ys)
y_hat, eval_summaries = eval_outputs
# An alternative, `y_hat = m.infer(xs, ys)`, was left disabled here.
logging.info("# Session")
# max_to_keep equals the number of epochs; one checkpoint is saved at the end
# of every completed epoch in the loop below.
saver = tf.train.Saver(max_to_keep=hp.num_epochs)
with tf.Session() as sess:
    # Resume from the newest checkpoint in the log directory if there is one;
    # otherwise initialize all variables and record the variable specs.
    ckpt = tf.train.latest_checkpoint(hp.logdir)
    if ckpt is None:
        logging.info("Initializing from scratch")
        sess.run(tf.global_variables_initializer())
        save_variable_specs(os.path.join(hp.logdir, "specs"))
    else:
        saver.restore(sess, ckpt)

    summary_writer = tf.summary.FileWriter(hp.logdir, sess.graph)
    # try/finally guarantees the writer is closed even when training or
    # evaluation raises (including an interrupt).
    try:
        sess.run(train_init_op)
        total_steps = hp.num_epochs * num_train_batches
        # Start counting from the current global step (non-zero after a restore).
        _gs = sess.run(global_step)
        # NOTE(review): the upper bound is total_steps + 1, i.e. one iteration
        # more than total_steps - _gs. Presumably intentional so the final
        # epoch boundary is reached -- confirm before changing.
        for _ in tqdm(range(_gs, total_steps+1)):
            _, _gs, _summary = sess.run([train_op, global_step, train_summaries])
            epoch = math.ceil(_gs / num_train_batches)
            summary_writer.add_summary(_summary, _gs)

            # End of an epoch: evaluate, write translations, score, checkpoint.
            if _gs and _gs % num_train_batches == 0:
                logging.info("epoch {} is done".format(epoch))
                # Train loss, fetched in its own run before the iterator is
                # switched over to the eval batches.
                _loss = sess.run(loss)

                logging.info("# test evaluation")
                _, _eval_summaries = sess.run([eval_init_op, eval_summaries])
                summary_writer.add_summary(_eval_summaries, _gs)

                logging.info("# get hypotheses")
                hypotheses = get_hypotheses(num_eval_batches, num_eval_samples, sess, y_hat, m.idx2token)

                logging.info("# write results")
                # The same name (epoch number + train loss) is used for the
                # translation file and for the checkpoint prefix.
                model_output = "iwslt2016_E%02dL%.2f" % (epoch, _loss)
                # exist_ok avoids the check-then-create race of a separate
                # os.path.exists() test.
                os.makedirs(hp.evaldir, exist_ok=True)
                translation = os.path.join(hp.evaldir, model_output)
                with open(translation, 'w') as fout:
                    fout.write("\n".join(hypotheses))

                logging.info("# calc bleu score and append it to translation")
                calc_bleu(hp.eval3, translation)

                logging.info("# save models")
                ckpt_name = os.path.join(hp.logdir, model_output)
                saver.save(sess, ckpt_name, global_step=_gs)
                logging.info("after training of {} epochs, {} has been saved.".format(epoch, ckpt_name))

                logging.info("# fall back to train mode")
                # Re-bind the shared iterator to the training batches.
                sess.run(train_init_op)
    finally:
        summary_writer.close()

logging.info("Done")