seq2seq.py

from __future__ import annotations

import typing as t

import torch
import pydantic as pyd
from torch import nn
from transformers import PreTrainedTokenizer

from transformer.models.base import BaseLM
from transformer.modules.transformers import TransformerEncoderDecoder
from transformer.modules.embedding import InputEmbedding
from transformer.params import TransformerParams

__all__ = ["Seq2SeqLM"]


class Seq2SeqLM(BaseLM):
    @pyd.validate_call(config=dict(arbitrary_types_allowed=True))
    def __init__(
        self: t.Self,
        params: TransformerParams,
        input_tokenizer: PreTrainedTokenizer,
        output_tokenizer: PreTrainedTokenizer,
    ) -> None:
        super().__init__(params=params)
        self.input_tokenizer = input_tokenizer
        self.output_tokenizer = output_tokenizer
        self.model = nn.ModuleDict(
            {
                "input": nn.Sequential(
                    InputEmbedding(len(self.input_tokenizer), params.model_dim),
                    nn.Dropout(0.1),
                ),
                "output": nn.Sequential(
                    InputEmbedding(len(self.output_tokenizer), params.model_dim),
                    nn.Dropout(0.1),
                ),
                "encoder_decoder": TransformerEncoderDecoder(params),
            }
        )
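
    # Added commentary: separate "input" and "output" embedding tables are kept
    # because the source and target tokenizers (and hence vocabularies) differ;
    # the "output" table is reused in `forward` to un-embed decoder states back
    # into vocabulary-sized scores.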

    def forward(
        self: t.Self,
        input_ids: torch.LongTensor,
        output_ids: torch.LongTensor,
        input_masks: torch.LongTensor,
        output_masks: torch.LongTensor,
    ) -> torch.FloatTensor:
        # ids/masks shape: [batch_size, context_length]
        # create embeddings for the input and output tokens
        inputs = self.model["input"](input_ids)
        outputs = self.model["output"](output_ids)
        # inputs/outputs shape: [batch_size, context_length, model_dim]
        # pass inputs, outputs and their masks through the encoder-decoder
        hidden = self.model["encoder_decoder"](
            inputs=inputs,
            outputs=outputs,
            input_masks=input_masks,
            output_masks=output_masks,
        )
        # hidden shape: [batch_size, context_length, model_dim]
        # project back to the output vocabulary size, reusing the (weight-tied)
        # output embedding weight matrix, and return log-probabilities
        unemb = self.model["output"][0].unembed(hidden)
        # unemb/output shape: [batch_size, context_length, output_vocab_size]
        return nn.functional.log_softmax(unemb, dim=-1)

    def configure_optimizers(self: t.Self) -> torch.optim.Optimizer:
        return torch.optim.SGD(self.model.parameters(), lr=3e-4)

    def step(
        self: t.Self, batch: tuple[torch.LongTensor, ...], *, stage: str
    ) -> torch.FloatTensor:
        input_ids, output_ids, target_ids, input_masks, output_masks = batch
        # make predictions
        preds = self(input_ids, output_ids, input_masks, output_masks)
        # flatten to one long sequence and ignore padding in predictions/targets
        output_masks = output_masks.flatten().bool()
        preds = preds.flatten(end_dim=1)[output_masks]
        target_ids = target_ids.flatten()[output_masks]
        # calculate loss
        loss = nn.functional.nll_loss(preds, target_ids)
        self.log(f"{stage}_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss
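
    # Illustrative example (hypothetical shapes): with batch_size=2,
    # context_length=3 and output vocabulary size V, `preds` is [2, 3, V] and
    # `output_masks` might be [[1, 1, 0], [1, 1, 1]]; after flattening, only
    # the five non-padding positions are kept, so nll_loss averages over
    # exactly those five predictions.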

    def training_step(
        self: t.Self, batch: tuple[torch.LongTensor, ...]
    ) -> torch.FloatTensor:
        return self.step(batch, stage="train")

    def validation_step(
        self: t.Self, batch: tuple[torch.LongTensor, ...]
    ) -> torch.FloatTensor:
        return self.step(batch, stage="val")

    def test_step(
        self: t.Self, batch: tuple[torch.LongTensor, ...]
    ) -> torch.FloatTensor:
        return self.step(batch, stage="test")
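

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module). It assumes
# that BaseLM is a pytorch_lightning.LightningModule, that TransformerParams
# accepts a `model_dim` keyword (hypothetical), and that `train_dataset` yields
# the 5-tuple (input_ids, output_ids, target_ids, input_masks, output_masks)
# consumed by `step`; adjust names to the repository's real APIs.
#
# import pytorch_lightning as pl
# from torch.utils.data import DataLoader
# from transformers import AutoTokenizer
#
# input_tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
# output_tokenizer = AutoTokenizer.from_pretrained("bert-base-german-cased")
# model = Seq2SeqLM(
#     params=TransformerParams(model_dim=512),  # hypothetical constructor args
#     input_tokenizer=input_tokenizer,
#     output_tokenizer=output_tokenizer,
# )
# trainer = pl.Trainer(max_epochs=10)
# trainer.fit(model, DataLoader(train_dataset, batch_size=32))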