-
Notifications
You must be signed in to change notification settings - Fork 3
/
models.py
35 lines (29 loc) · 1.17 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
from torch import nn
from utils import load_embeddings, normalize_embeddings
def _embedding_from(dico, vectors, emb_dim):
    """
    Create a sparse `nn.Embedding` with `len(dico)` rows of width `emb_dim`
    and copy `vectors` into its weight.
    """
    layer = nn.Embedding(len(dico), emb_dim, sparse=True)
    layer.weight.data.copy_(vectors)
    return layer


def build_model(params, with_dis):
    """
    Build all components of the model.

    Reads from `params`: `emb_dim`, `tgt_lang`, `normalize_embeddings` and,
    optionally, `map_id_init` (treated as True when the attribute is absent).
    Writes to `params`: `src_dico`, `src_mean` and, when `params.tgt_lang`
    is truthy, `tgt_dico` and `tgt_mean`.

    `with_dis` is accepted but not used anywhere in this function.

    Returns the tuple `(src_emb, tgt_emb, mapping)`; `tgt_emb` is None when
    `params.tgt_lang` is falsy.
    """
    # source side: the loaded dictionary is also stored on params
    source_dico, source_vectors = load_embeddings(params, source=True)
    params.src_dico = source_dico
    src_emb = _embedding_from(source_dico, source_vectors, params.emb_dim)

    # target side: only built when a target language is configured
    tgt_emb = None
    if params.tgt_lang:
        target_dico, target_vectors = load_embeddings(params, source=False)
        params.tgt_dico = target_dico
        tgt_emb = _embedding_from(target_dico, target_vectors, params.emb_dim)

    # square, bias-free linear mapping; starts as the identity matrix
    # unless params.map_id_init is set to a falsy value
    mapping = nn.Linear(params.emb_dim, params.emb_dim, bias=False)
    use_identity_init = getattr(params, 'map_id_init', True)
    if use_identity_init:
        mapping.weight.data.copy_(torch.eye(params.emb_dim))

    # normalization: the helper's return value is kept on params as the
    # per-language "mean" (presumably the helper also rewrites the weight
    # tensor in place -- confirm against utils.normalize_embeddings)
    norm_spec = params.normalize_embeddings
    params.src_mean = normalize_embeddings(src_emb.weight.data, norm_spec)
    if params.tgt_lang:
        params.tgt_mean = normalize_embeddings(tgt_emb.weight.data, norm_spec)

    return src_emb, tgt_emb, mapping