import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
# this class largely follows the official sonnet implementation
# https://github.com/deepmind/sonnet/blob/master/sonnet/python/modules/relational_memory.py
class RelationalMemory(nn.Module):
    """
    Constructs a `RelationalMemory` object.
    Args:
      mem_slots: The total number of memory slots to use.
      head_size: The size of an attention head.
      input_size: The size of the input per step, i.e. the dimension of each input vector.
      num_tokens: The vocabulary size used for the token embedding and the output softmax.
      num_heads: The number of attention heads to use. Defaults to 1.
      num_blocks: Number of times to compute attention per time step. Defaults
        to 1.
      forget_bias: Bias to use for the forget gate, assuming we are using
        some form of gating. Defaults to 1.
      input_bias: Bias to use for the input gate, assuming we are using
        some form of gating. Defaults to 0.
      gate_style: Whether to use per-element gating ('unit'),
        per-memory-slot gating ('memory'), or no gating at all (None).
        Defaults to 'unit'.
      attention_mlp_layers: Number of layers to use in the post-attention
        MLP. Defaults to 2.
      key_size: Size of vector to use for key & query vectors in the attention
        computation. Defaults to None, in which case we use `head_size`.
      use_adaptive_softmax: Whether to compute the loss with adaptive softmax
        instead of the tied-embedding output layer. Defaults to False.
      cutoffs: Cutoffs for the adaptive softmax. Only used when
        `use_adaptive_softmax` is True.
    Raises:
      ValueError: gate_style not one of [None, 'memory', 'unit'].
      ValueError: num_blocks is < 1.
      ValueError: attention_mlp_layers is < 1.
    """

    def __init__(self, mem_slots, head_size, input_size, num_tokens, num_heads=1, num_blocks=1, forget_bias=1.,
                 input_bias=0.,
                 gate_style='unit', attention_mlp_layers=2, key_size=None, use_adaptive_softmax=False, cutoffs=None):
        super(RelationalMemory, self).__init__()

        ########## generic parameters for RMC ##########
        self.mem_slots = mem_slots
        self.head_size = head_size
        self.num_heads = num_heads
        self.mem_size = self.head_size * self.num_heads

        # a new fixed parameter needed for the PyTorch port of the RMC
        # the +1 accounts for the input concatenated to memory at each time step: we do self-attention
        # over the concatenated memory & input, so if mem_slots = 1, this value is 2
        self.mem_slots_plus_input = self.mem_slots + 1

        if num_blocks < 1:
            raise ValueError('num_blocks must be >= 1. Got: {}.'.format(num_blocks))
        self.num_blocks = num_blocks

        if gate_style not in ['unit', 'memory', None]:
            raise ValueError(
                'gate_style must be one of [\'unit\', \'memory\', None]. Got: '
                '{}.'.format(gate_style))
        self.gate_style = gate_style

        if attention_mlp_layers < 1:
            raise ValueError('attention_mlp_layers must be >= 1. Got: {}.'.format(
                attention_mlp_layers))
        self.attention_mlp_layers = attention_mlp_layers

        self.key_size = key_size if key_size else self.head_size

        ########## parameters for multihead attention ##########
        # value_size is the same as head_size
        self.value_size = self.head_size
        # total size for query-key-value
        self.qkv_size = 2 * self.key_size + self.value_size
        self.total_qkv_size = self.qkv_size * self.num_heads  # denoted as F

        # each head has a qkv-sized linear projector
        # using one big parameter is more efficient than a per-head list such as
        # self.qkv_projector = [nn.Parameter(torch.randn((self.qkv_size, self.qkv_size))) for _ in range(self.num_heads)]
        self.qkv_projector = nn.Linear(self.mem_size, self.total_qkv_size)
        self.qkv_layernorm = nn.LayerNorm([self.mem_slots_plus_input, self.total_qkv_size])

        # used for the attend_over_memory function
        # build independent linear layers; multiplying a one-element list would make every layer share the same weights
        self.attention_mlp = nn.ModuleList(
            [nn.Linear(self.mem_size, self.mem_size) for _ in range(self.attention_mlp_layers)])
        self.attended_memory_layernorm = nn.LayerNorm([self.mem_slots_plus_input, self.mem_size])
        self.attended_memory_layernorm2 = nn.LayerNorm([self.mem_slots_plus_input, self.mem_size])

        ########## parameters for initial embedded input projection ##########
        self.input_size = input_size
        self.input_projector = nn.Linear(self.input_size, self.mem_size)

        ########## parameters for gating ##########
        self.num_gates = 2 * self.calculate_gate_size()
        self.input_gate_projector = nn.Linear(self.mem_size, self.num_gates)
        self.memory_gate_projector = nn.Linear(self.mem_size, self.num_gates)
        # trainable scalar gate bias tensors
        self.forget_bias = nn.Parameter(torch.tensor(forget_bias, dtype=torch.float32))
        self.input_bias = nn.Parameter(torch.tensor(input_bias, dtype=torch.float32))

        ########## parameters for token-to-embed & output-to-token logit for softmax ##########
        self.dropout = nn.Dropout()
        self.num_tokens = num_tokens
        self.token_to_input_encoder = nn.Embedding(self.num_tokens, self.input_size)

        # two linear layers are needed to tie the weights with the embedding layer:
        # first match the "output" of the RMC to input_size, which is the embedding dim
        self.output_to_embed_decoder = nn.Linear(self.mem_slots * self.mem_size, self.input_size)
        self.use_adaptive_softmax = use_adaptive_softmax
        if not self.use_adaptive_softmax:
            # then this layer's weight can be tied to the embedding layer
            self.embed_to_logit_decoder = nn.Linear(self.input_size, self.num_tokens)
            # tie the embedding weights of the encoder & decoder
            self.embed_to_logit_decoder.weight = self.token_to_input_encoder.weight
            ########## loss function ##########
            self.criterion = nn.CrossEntropyLoss()
        else:
            # use adaptive softmax from the self.input_size logits, instead of the tied embedding weights above
            self.criterion_adaptive = nn.AdaptiveLogSoftmaxWithLoss(self.input_size, self.num_tokens,
                                                                    cutoffs=cutoffs)

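    # Shape note for the weight tying above (illustrative, not from the original comments):
    # nn.Embedding(num_tokens, input_size).weight and nn.Linear(input_size, num_tokens).weight are both
    # [num_tokens, input_size], so the parameter can be shared directly, e.g.:
    #
    # enc = nn.Embedding(10000, 128)   # weight: [10000, 128]
    # dec = nn.Linear(128, 10000)      # weight: [10000, 128]
    # dec.weight = enc.weight          # tied: encoder & decoder now share one parameter
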
    def repackage_hidden(self, h):
        """Wraps hidden states in new Tensors, to detach them from their history."""
        # needed for truncated BPTT, called at every batch forward pass
        if isinstance(h, torch.Tensor):
            return h.detach()
        else:
            return tuple(self.repackage_hidden(v) for v in h)

    def initial_state(self, batch_size, trainable=False):
        """
        Creates the initial memory.
        We should ensure each row of the memory is initialized to be unique,
        so initialize the matrix to be the identity. We then pad or truncate
        as necessary so that init_state is of size
        (batch_size, self.mem_slots, self.mem_size).
        Args:
          batch_size: The size of the batch.
          trainable: Whether the initial state is trainable. Currently unused in this port;
            the returned tensor is not a learned parameter.
        Returns:
          init_state: A truncated or padded matrix of size
            (batch_size, self.mem_slots, self.mem_size).
        """
        init_state = torch.stack([torch.eye(self.mem_slots) for _ in range(batch_size)])

        # pad the matrix with zeros
        if self.mem_size > self.mem_slots:
            difference = self.mem_size - self.mem_slots
            pad = torch.zeros((batch_size, self.mem_slots, difference))
            init_state = torch.cat([init_state, pad], -1)
        # truncation: take the first 'self.mem_size' components
        elif self.mem_size < self.mem_slots:
            init_state = init_state[:, :, :self.mem_size]

        return init_state

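    # A small shape sketch for initial_state (hypothetical sizes):
    # with mem_slots=4, head_size=3, num_heads=2 => mem_size=6, each batch element starts as the 4x4
    # identity padded with two zero columns to 4x6, so every memory row starts out unique.
    #
    # rmc = RelationalMemory(mem_slots=4, head_size=3, input_size=10, num_tokens=100, num_heads=2)
    # m0 = rmc.initial_state(batch_size=2)
    # assert m0.shape == (2, 4, 6)
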
    def multihead_attention(self, memory):
        """
        Perform multi-head attention from 'Attention is All You Need'.
        Implementation of the attention mechanism from
        https://arxiv.org/abs/1706.03762.
        Args:
          memory: Memory tensor to perform attention on.
        Returns:
          new_memory: New memory tensor.
        """
        # First, a simple linear projection is used to construct the queries, keys and values
        qkv = self.qkv_projector(memory)
        # apply layernorm over every dim except the batch dim
        qkv = self.qkv_layernorm(qkv)

        # mem_slots needs to be computed dynamically since the memory got concatenated with the input
        # example: if self.mem_slots = 10, then mem_slots here is 10 + 1 = 11 at every step of the forward pass
        # this is the same as self.mem_slots_plus_input, but defined to keep the sonnet implementation code style
        mem_slots = memory.shape[1]  # denoted as N

        # split the qkv into multiple heads H
        # [B, N, F] => [B, N, H, F/H]
        qkv_reshape = qkv.view(qkv.shape[0], mem_slots, self.num_heads, self.qkv_size)
        # [B, N, H, F/H] => [B, H, N, F/H]
        qkv_transpose = qkv_reshape.permute(0, 2, 1, 3)
        # [B, H, N, key_size], [B, H, N, key_size], [B, H, N, value_size]
        q, k, v = torch.split(qkv_transpose, [self.key_size, self.key_size, self.value_size], -1)

        # scale q by d_k, the dimensionality of the key vectors
        q *= (self.key_size ** -0.5)

        # make it [B, H, N, N]
        dot_product = torch.matmul(q, k.permute(0, 1, 3, 2))
        weights = F.softmax(dot_product, dim=-1)

        # output is [B, H, N, V]
        output = torch.matmul(weights, v)

        # [B, H, N, V] => [B, N, H, V] => [B, N, H*V]
        output_transpose = output.permute(0, 2, 1, 3).contiguous()
        new_memory = output_transpose.view((output_transpose.shape[0], output_transpose.shape[1], -1))

        return new_memory

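    # A worked shape example for multihead_attention (hypothetical sizes):
    # with num_heads H=2, key_size=4, value_size=head_size=4 => qkv_size=12, total_qkv_size F=24,
    # and a memory of [B=2, N=5, mem_size=8]:
    #   qkv            [2, 5, 24]
    #   qkv_transpose  [2, 2, 5, 12]
    #   q, k, v        [2, 2, 5, 4] each
    #   dot_product    [2, 2, 5, 5]
    #   new_memory     [2, 5, 8]
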
    @property
    def state_size(self):
        return [self.mem_slots, self.mem_size]

    @property
    def output_size(self):
        return self.mem_slots * self.mem_size

    def calculate_gate_size(self):
        """
        Calculate the gate size from the gate_style.
        Returns:
          The per-sample, per-head parameter size of each gate.
        """
        if self.gate_style == 'unit':
            return self.mem_size
        elif self.gate_style == 'memory':
            return 1
        else:  # self.gate_style is None
            return 0

    def create_gates(self, inputs, memory):
        """
        Create input and forget gates for this step using `inputs` and `memory`.
        Args:
          inputs: Tensor input.
          memory: The current state of memory.
        Returns:
          input_gate: An LSTM-like insert gate.
          forget_gate: An LSTM-like forget gate.
        """
        # We'll create the input and forget gates at once. Hence, calculate double
        # the gate size.

        # equation 8: since there is no output gate, h is just a tanh'ed m
        memory = torch.tanh(memory)

        # TODO: check that this input flattening is correct
        # sonnet uses the line below, but it seems to assume a time step of 1 in all cases
        # if inputs is (B, T, features) with T > 1, that flattening becomes incorrect
        # inputs = inputs.view(inputs.shape[0], -1)
        # fixed implementation
        if len(inputs.shape) == 3:
            if inputs.shape[1] > 1:
                raise ValueError(
                    "input seq length is larger than 1. create_gates is meant to be called once per step, "
                    "with an input seq length of 1")
            inputs = inputs.view(inputs.shape[0], -1)

            # matmul for equations 4 and 5
            # there is no output gate, so equation 6 is not implemented
            gate_inputs = self.input_gate_projector(inputs)
            gate_inputs = gate_inputs.unsqueeze(dim=1)
            gate_memory = self.memory_gate_projector(memory)
        else:
            raise ValueError(
                "create_gates expects a rank-3 input tensor (batch, seq, features), got rank {}".format(
                    len(inputs.shape)))

        # this completes equations 4 and 5
        gates = gate_memory + gate_inputs
        gates = torch.split(gates, split_size_or_sections=int(gates.shape[2] / 2), dim=2)
        input_gate, forget_gate = gates
        assert input_gate.shape[2] == forget_gate.shape[2]

        # to be used for equation 7
        input_gate = torch.sigmoid(input_gate + self.input_bias)
        forget_gate = torch.sigmoid(forget_gate + self.forget_bias)

        return input_gate, forget_gate

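    # A gate shape sketch (hypothetical sizes, gate_style='unit'):
    # with mem_size=6, num_gates = 2 * 6 = 12; for inputs of [B, 1, 6] and memory of [B, mem_slots, 6],
    # gate_inputs is [B, 1, 12] and gate_memory is [B, mem_slots, 12]; their broadcast sum splits into
    # input_gate and forget_gate of [B, mem_slots, 6] each, matching the memory shape for equation 7.
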
    def attend_over_memory(self, memory):
        """
        Perform multiheaded attention over `memory`.
        Args:
          memory: Current relational memory.
        Returns:
          The attended-over memory.
        """
        for _ in range(self.num_blocks):
            attended_memory = self.multihead_attention(memory)

            # Add a skip connection to the multiheaded attention's input.
            memory = self.attended_memory_layernorm(memory + attended_memory)

            # Add a skip connection to the attention_mlp's input.
            attention_mlp = memory
            for layer in self.attention_mlp:
                attention_mlp = layer(attention_mlp)
                attention_mlp = F.relu(attention_mlp)
            memory = self.attended_memory_layernorm2(memory + attention_mlp)

        return memory

    def forward_step(self, inputs, memory, treat_input_as_matrix=False):
        """
        Forward step of the relational memory core.
        Args:
          inputs: Tensor input.
          memory: Memory output from the previous time step.
          treat_input_as_matrix: Optional, whether to treat `inputs` as a sequence
            of matrices. Defaults to False, in which case the input is flattened
            into a vector.
        Returns:
          logit: This time step's output logit.
          next_memory: The next version of memory to use.
        """
        # first embed the tokens into vectors
        inputs_embed = self.dropout(self.token_to_input_encoder(inputs))

        if treat_input_as_matrix:
            # keep the (Batch, Seq, ...) dims (0, 1), flatten starting from dim 2
            inputs_embed = inputs_embed.view(inputs_embed.shape[0], inputs_embed.shape[1], -1)
            # apply the linear layer over dim 2
            inputs_reshape = self.input_projector(inputs_embed)
        else:
            # keep the (Batch, ...) dim (0), flatten starting from dim 1
            inputs_embed = inputs_embed.view(inputs_embed.shape[0], -1)
            # apply the linear layer over dim 1
            inputs_embed = self.input_projector(inputs_embed)
            # unsqueeze the time step to dim 1
            inputs_reshape = inputs_embed.unsqueeze(dim=1)

        memory_plus_input = torch.cat([memory, inputs_reshape], dim=1)
        next_memory = self.attend_over_memory(memory_plus_input)

        # cut the concatenated input vectors out of the original memory slots
        n = inputs_reshape.shape[1]
        next_memory = next_memory[:, :-n, :]

        if self.gate_style == 'unit' or self.gate_style == 'memory':
            # these gates are the sigmoid-applied ones for equation 7
            input_gate, forget_gate = self.create_gates(inputs_reshape, memory)
            # equation 7 calculation
            next_memory = input_gate * torch.tanh(next_memory)
            next_memory += forget_gate * memory

        output = next_memory.view(next_memory.shape[0], -1)

        # decode the output to a logit
        output_embed = self.output_to_embed_decoder(output)
        # TODO: this dropout is not mentioned in the paper; it is added to match the word-language-model use case
        output_embed = self.dropout(output_embed)

        if not self.use_adaptive_softmax:
            logit = self.embed_to_logit_decoder(output_embed)
        else:
            logit = output_embed

        return logit, next_memory

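    # A single-step usage sketch (hypothetical sizes): forward_step consumes token indices, not embeddings,
    # since it runs self.token_to_input_encoder internally.
    #
    # rmc = RelationalMemory(mem_slots=4, head_size=16, input_size=32, num_tokens=100)
    # memory = rmc.initial_state(batch_size=8)
    # tokens = torch.randint(0, 100, (8,))   # one token per batch element for this step
    # logit, memory = rmc.forward_step(tokens, memory)
    # assert logit.shape == (8, 100)         # [batch, vocab]
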
    def forward(self, inputs, memory, targets, require_logits=False):
        # At the start of each batch, we detach the hidden state from how it was previously produced.
        # If we didn't, the model would try backpropagating all the way to the start of the dataset.
        memory = self.repackage_hidden(memory)

        # for-loop implementation of the (entire) recurrent forward pass of the model
        # inputs is batch-first [batch, seq], and the output logit per step is [batch, vocab],
        # so the concatenated logits are [seq * batch, vocab]
        # targets are flattened [seq, batch] => [seq * batch], so the dimensions match
        logits = []
        # shape[1] is the sequence length T
        for idx_step in range(inputs.shape[1]):
            logit, memory = self.forward_step(inputs[:, idx_step], memory)
            logits.append(logit)
        # concat the outputs from a list (of length seq_length) of [batch, vocab] into [seq * batch, vocab]
        logits = torch.cat(logits)

        if targets is not None:
            if not self.use_adaptive_softmax:
                # calculate the loss inside this forward pass for more even VRAM usage under DataParallel
                loss = self.criterion(logits, targets)
            else:
                # calculate the loss using adaptive softmax
                _, loss = self.criterion_adaptive(logits, targets)
        else:
            loss = None

        # by default the forward pass returns only the loss, because returning logits causes uneven VRAM usage under DataParallel
        # logits are returned only for the sampling stage
        if not require_logits:
            return loss, memory
        else:
            return logits, loss, memory

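    # A minimal training-loop sketch using this forward pass (hypothetical names and sizes; `get_batch`,
    # `train_data`, `vocab_size` and the optimizer setup are assumptions, not part of this file):
    #
    # model = RelationalMemory(mem_slots=1, head_size=192, input_size=192, num_tokens=vocab_size, num_heads=4)
    # optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    # memory = model.initial_state(batch_size=batch_size)
    # for inputs, targets in get_batch(train_data):   # inputs: [batch, seq], targets: [seq * batch]
    #     optimizer.zero_grad()
    #     # forward() detaches `memory` internally (repackage_hidden), so gradients stop at the batch boundary
    #     loss, memory = model(inputs, memory, targets)
    #     loss.backward()
    #     optimizer.step()
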

# ########## DEBUG: unit test code ##########
# input_size = 44
# seq_length = 1
# batch_size = 32
# num_tokens = 66
# model = RelationalMemory(mem_slots=10, head_size=20, input_size=input_size, num_tokens=num_tokens, num_heads=8, num_blocks=1, forget_bias=1., input_bias=0.)
# model_memory = model.initial_state(batch_size=batch_size)
#
# # random token inputs: [batch, seq] indices into the vocabulary
# random_input = torch.randint(0, num_tokens, (batch_size, seq_length))
# # random targets, flattened to [seq * batch] for the cross-entropy loss
# random_targets = torch.randint(0, num_tokens, (seq_length * batch_size,))
#
# # take one forward step
# loss, next_memory = model(random_input, model_memory, random_targets)
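
# ########## a sampling sketch (an illustrative assumption, not part of the original file) ##########
# at sampling time, require_logits=True also returns the per-step logits, so the next token can be
# drawn from the softmax over the vocabulary (assumes the default, non-adaptive softmax path):
#
# model.eval()
# sample_memory = model.initial_state(batch_size=1)
# token = torch.randint(0, num_tokens, (1, 1))          # seed token, [batch=1, seq=1]
# generated = []
# for _ in range(20):
#     logits, _, sample_memory = model(token, sample_memory, targets=None, require_logits=True)
#     probs = F.softmax(logits[-1], dim=-1)
#     token = torch.multinomial(probs, num_samples=1).view(1, 1)
#     generated.append(token.item())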