-
Notifications
You must be signed in to change notification settings - Fork 5
/
models.py
204 lines (154 loc) · 8.03 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
# coding: utf-8
from __future__ import division
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
import data
def _get_shape(i, o, keepdims):
if (i == 1 or o == 1) and not keepdims:
return [max(i,o),]
else:
return [i, o]
def _slice(tensor, size, i):
"""Gets slice of columns of the tensor"""
return tensor[:, i*size:(i+1)*size]
def weights_Glorot(i, o, name, rng, is_logistic_sigmoid=False, keepdims=False):
    """Create a trainable weight with Glorot/Xavier uniform initialisation.

    Values are drawn from U(-d, d) with d = sqrt(6 / (i + o)). The bound is
    multiplied by 4 when the weight feeds a logistic sigmoid.
    Reference: http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf

    Args:
        i: fan-in (number of input units).
        o: fan-out (number of output units).
        name: name given to the created ``tf.Variable``.
        rng: NumPy random source exposing ``uniform(low, high, size)``, such
            as ``np.random`` or a ``RandomState``. If ``None``, TensorFlow's
            global generator is used instead.
        is_logistic_sigmoid: scale the bound by 4 for sigmoid units.
        keepdims: keep a singleton axis rather than producing a vector shape
            (see ``_get_shape``).

    Returns:
        A float32 ``tf.Variable`` of shape ``_get_shape(i, o, keepdims)``.
    """
    d = np.sqrt(6. / (i + o))
    if is_logistic_sigmoid:
        d *= 4.
    shape = _get_shape(i, o, keepdims)
    if rng is None:
        initial = tf.random.uniform(shape, -d, d)
    else:
        # Draw from the caller's RNG so that seeding (or restoring) it
        # reproduces the initial weights; cast to TF's default float32.
        initial = rng.uniform(low=-d, high=d, size=shape).astype(np.float32)
    return tf.Variable(initial, name=name)
def load(file_path, x, p=None):
    """Rebuild a model from a pickle produced by ``save``.

    Steps:
    1. Look up the model class named by ``state["type"]`` in this module.
    2. Restore the global NumPy RNG state if one was saved.
    3. Construct the model.
    4. Overwrite its parameters, in ``params`` order, with the saved values.

    WARNING: ``pickle.load`` can execute arbitrary code; only load files from
    trusted sources.

    Args:
        file_path: path of the pickled state dictionary.
        x: input tensor forwarded to the model constructor.
        p: unused; kept for backward compatibility with existing callers.

    Returns:
        ``(net, (learning_rate, validation_ppl_history, epoch, rng))``, where
        ``rng`` is the global ``np.random`` module.

    Raises:
        ValueError: if the stored type is not a class defined in this module,
            or if the number of saved parameters differs from the model's.
    """
    import pickle

    with open(file_path, 'rb') as f:
        state = pickle.load(f)

    # Resolve the class from this module's own namespace instead of
    # re-importing it under the hard-coded name ``models`` (which fails when
    # the file is imported under a package path).
    model_cls = globals().get(state["type"])
    if not isinstance(model_cls, type):
        raise ValueError("unknown model type %r in %s" % (state["type"], file_path))

    rng = np.random
    # ``save`` defaults ``random_state`` to None; set_state(None) would raise,
    # so only restore when a state was actually stored.
    if state.get("random_state") is not None:
        rng.set_state(state["random_state"])

    net = model_cls(
        rng=rng,
        x=x,
        n_hidden=state["n_hidden"]
    )

    saved_params = state["params"]
    # zip() would silently truncate and leave some weights uninitialised.
    if len(saved_params) != len(net.params):
        raise ValueError(
            "checkpoint %s has %d parameters but %s expects %d"
            % (file_path, len(saved_params), state["type"], len(net.params)))
    for net_param, state_param in zip(net.params, saved_params):
        net_param.assign(state_param)

    return net, (state["learning_rate"], state["validation_ppl_history"], state["epoch"], rng)
class GRUCell(layers.Layer):
    """One step of a gated recurrent unit.

    Notation follows "An Empirical Exploration of Recurrent Network
    Architectures". For input ``x_t`` and previous state ``h_tm1``:

        [r, z] = sigmoid(x_t W_x + h_tm1 W_h + b)
        h~     = tanh(x_t W_x_h + (h_tm1 * r) W_h_h + b_h)
        h_t    = z * h_tm1 + (1 - z) * h~
    """

    def __init__(self, rng, n_in, n_out, minibatch_size):
        super().__init__()
        self.n_in = n_in
        self.n_out = n_out

        # All-zero initial state with ``minibatch_size`` rows.
        self.h0 = tf.zeros([minibatch_size, n_out])

        # Reset (r) and update (z) gates share one weight matrix; their
        # pre-activations sit side by side in 2 * n_out columns.
        self.W_x = weights_Glorot(n_in, n_out * 2, 'W_x', rng)
        self.W_h = weights_Glorot(n_out, n_out * 2, 'W_h', rng)
        self.b = tf.Variable(tf.zeros([1, n_out * 2]))

        # Candidate-state parameters.
        self.W_x_h = weights_Glorot(n_in, n_out, 'W_x_h', rng)
        self.W_h_h = weights_Glorot(n_out, n_out, 'W_h_h', rng)
        self.b_h = tf.Variable(tf.zeros([1, n_out]))

        # Order matters: save()/load() serialise parameters positionally.
        self.params = [
            self.W_x, self.W_h, self.b,
            self.W_x_h, self.W_h_h, self.b_h,
        ]

    def call(self, inputs):
        """Advance one step; ``inputs`` is the pair ``(x_t, h_tm1)``."""
        x_t = inputs[0]
        h_tm1 = inputs[1]
        width = self.n_out

        gate_logits = (tf.matmul(x_t, self.W_x)
                       + tf.matmul(h_tm1, self.W_h)
                       + self.b)
        gates = tf.nn.sigmoid(gate_logits)
        reset = gates[:, :width]
        update = gates[:, width:2 * width]

        candidate = tf.nn.tanh(tf.matmul(x_t, self.W_x_h)
                               + tf.matmul(h_tm1 * reset, self.W_h_h)
                               + self.b_h)
        return update * h_tm1 + (1. - update) * candidate
class GRU(tf.keras.Model):
    """Bidirectional GRU encoder with attention and late fusion.

    Encoder: a forward and a backward GRUCell run over the embedded input
    word ids, and their states are concatenated into per-timestep contexts.

    Output model: a third GRUCell walks over those contexts, skipping the
    first timestep. At each step it attends over all contexts and fuses the
    attended context into its state, then projects to logits over the
    punctuation vocabulary.
    """

    def __init__(self, rng, x, n_hidden):
        super(GRU, self).__init__()
        # Batch size is taken from axis 1 of x (axis 0 is time) and fixed at
        # construction, because the cells allocate their initial states for it.
        self.minibatch_size = tf.shape(x)[1]
        self.n_hidden = n_hidden

        self.x_vocabulary = data.read_vocabulary(data.WORD_VOCAB_FILE)
        self.y_vocabulary = data.read_vocabulary(data.PUNCT_VOCAB_FILE)
        self.x_vocabulary_size = len(self.x_vocabulary)
        self.y_vocabulary_size = len(self.y_vocabulary)

        # input model
        self.We = weights_Glorot(self.x_vocabulary_size, n_hidden, 'We', rng)  # Share embeddings between forward and backward model
        self.GRU_f = GRUCell(rng=rng, n_in=n_hidden, n_out=n_hidden, minibatch_size=self.minibatch_size)
        self.GRU_b = GRUCell(rng=rng, n_in=n_hidden, n_out=n_hidden, minibatch_size=self.minibatch_size)

        # output model (its input is a concatenated forward+backward context)
        self.GRU = GRUCell(rng=rng, n_in=n_hidden*2, n_out=n_hidden, minibatch_size=self.minibatch_size)
        self.Wy = tf.Variable(tf.zeros([n_hidden, self.y_vocabulary_size]))
        self.by = tf.Variable(tf.zeros([1, self.y_vocabulary_size]))

        # attention model
        n_attention = n_hidden * 2  # to match concatenated forward and reverse model states
        self.Wa_h = weights_Glorot(n_hidden, n_attention, 'Wa_h', rng)  # output model previous hidden state to attention model weights
        self.Wa_c = weights_Glorot(n_attention, n_attention, 'Wa_c', rng)  # contexts to attention model weights
        self.ba = tf.Variable(tf.zeros([1, n_attention]))
        self.Wa_y = weights_Glorot(n_attention, 1, 'Wa_y', rng)  # gives weights to contexts; a vector of shape [n_attention]

        # Late fusion parameters
        self.Wf_h = tf.Variable(tf.zeros([n_hidden, n_hidden]))
        self.Wf_c = tf.Variable(tf.zeros([n_attention, n_hidden]))
        self.Wf_f = tf.Variable(tf.zeros([n_hidden, n_hidden]))
        self.bf = tf.Variable(tf.zeros([1, n_hidden]))

        # Order matters: save()/load() serialise parameters positionally.
        self.params = [self.We,
                       self.Wy, self.by,
                       self.Wa_h, self.Wa_c, self.ba, self.Wa_y,
                       self.Wf_h, self.Wf_c, self.Wf_f, self.bf]
        self.params += self.GRU.params + self.GRU_f.params + self.GRU_b.params
        print([p.shape for p in self.params])

    def call(self, inputs, training=None):
        """Return logits for every timestep after the first.

        Args:
            inputs: word-id tensor, with time along axis 0 (``tf.scan``
                iterates over that axis) and the minibatch along axis 1.
            training: unused; accepted for Keras API compatibility.

        Returns:
            Logits stacked over time, one entry per element of ``context[1:]``.
            They are also stored on ``self.y``; the fused hidden states are
            stored on ``self.last_hidden_states``.
        """

        # bi-directional recurrence
        def input_recurrence(initializer, elems):
            x_f_t, x_b_t = elems
            h_f_tm1, h_b_tm1 = initializer
            h_f_t = self.GRU_f(inputs=(tf.nn.embedding_lookup(self.We, x_f_t), h_f_tm1))
            h_b_t = self.GRU_b(inputs=(tf.nn.embedding_lookup(self.We, x_b_t), h_b_tm1))
            return [h_f_t, h_b_t]

        [h_f_t, h_b_t] = tf.scan(
            fn=input_recurrence,
            elems=[inputs, inputs[::-1]],  # forward and backward sequences
            initializer=[self.GRU_f.h0, self.GRU_b.h0]
        )

        # 0-axis is time steps, 1-axis is batch size and 2-axis is hidden layer size.
        # The backward states are reversed again so both directions line up in time.
        context = tf.concat([h_f_t, h_b_t[::-1]], axis=2)

        # context @ Wa_c for every timestep slice, plus bias. tensordot
        # contracts the last axis of ``context`` with the first axis of
        # ``Wa_c`` without materialising one copy of Wa_c per timestep.
        projected_context = tf.tensordot(context, self.Wa_c, axes=1) + self.ba

        def output_recurrence(initializer, elems):
            x_t = elems
            h_tm1, _, _ = initializer

            # Attention model, conditioned on the previous output-model state.
            h_a = tf.nn.tanh(projected_context + tf.matmul(h_tm1, self.Wa_h))
            # Score each (time, batch) context with the vector Wa_y, then
            # normalise over time (axis 0). softmax subtracts the max first,
            # so it cannot overflow the way a raw exp(...)/sum(...) can.
            alphas = tf.nn.softmax(tf.tensordot(h_a, self.Wa_y, axes=1), axis=0)
            weighted_context = tf.reduce_sum(context * alphas[:, :, None], axis=0)

            h_t = self.GRU(inputs=(x_t, h_tm1))

            # Late fusion
            lfc = tf.matmul(weighted_context, self.Wf_c)  # late fused context
            fw = tf.nn.sigmoid(tf.matmul(lfc, self.Wf_f) + tf.matmul(h_t, self.Wf_h) + self.bf)  # fusion weights
            hf_t = lfc * fw + h_t  # weighted fused context + hidden state

            z = tf.matmul(hf_t, self.Wy) + self.by
            y_t = z  # unnormalised logits; no softmax is applied here
            return [h_t, hf_t, y_t]

        [_, self.last_hidden_states, self.y] = tf.scan(
            fn=output_recurrence,
            elems=context[1:],  # ignore the 1st word in context, because there's no punctuation before that
            initializer=[self.GRU.h0, self.GRU.h0, tf.zeros([self.minibatch_size, self.y_vocabulary_size])]
        )
        return self.y
def cost(y_pred, y_true):
    """Total sparse softmax cross-entropy between logits and integer labels.

    Args:
        y_pred: unnormalised logits; the softmax is applied inside the loss.
        y_true: integer class ids, shaped like ``y_pred`` without its last axis.

    Returns:
        Scalar tensor holding the summed (not averaged) cross-entropy.
    """
    per_position = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=y_true, logits=y_pred)
    return tf.reduce_sum(per_position)
def save(model, file_path, learning_rate=None, validation_ppl_history=None, best_validation_ppl=None, epoch=None, random_state=None):
    """Pickle a model's parameters and training bookkeeping to ``file_path``.

    The stored dictionary contains:
    - the model's class name (``"type"``), so ``load`` can rebuild it;
    - ``n_hidden``;
    - the ordered ``params`` list;
    - every training-state argument, stored as given. This includes
      ``best_validation_ppl``, which earlier versions accepted but never wrote.

    Args:
        model: object exposing ``n_hidden`` and an ordered ``params`` list.
        file_path: destination path; an existing file is overwritten.
        learning_rate: stored as given.
        validation_ppl_history: stored as given.
        best_validation_ppl: stored as given.
        epoch: stored as given.
        random_state: stored as given; ``load`` restores it into ``np.random``
            when it is not None.
    """
    import pickle

    state = {
        "type": model.__class__.__name__,
        "n_hidden": model.n_hidden,
        # Order must match model.params; load() restores positionally.
        "params": list(model.params),
        "learning_rate": learning_rate,
        "validation_ppl_history": validation_ppl_history,
        "best_validation_ppl": best_validation_ppl,
        "epoch": epoch,
        "random_state": random_state,
    }
    print([p.shape for p in state["params"]])
    with open(file_path, 'wb') as f:
        pickle.dump(state, f, protocol=pickle.HIGHEST_PROTOCOL)