# -*- coding: utf-8 -*-
# @Time   : 2020/3/30 10:57
# @Author : Hui Wang

import torch
import torch.nn as nn

from modules import Encoder, LayerNorm


class S3RecModel(nn.Module):
    def __init__(self, args):
        super(S3RecModel, self).__init__()
        self.item_embeddings = nn.Embedding(args.item_size, args.hidden_size, padding_idx=0)
        self.attribute_embeddings = nn.Embedding(args.attribute_size, args.hidden_size, padding_idx=0)
        self.position_embeddings = nn.Embedding(args.max_seq_length, args.hidden_size)
        self.item_encoder = Encoder(args)
        self.LayerNorm = LayerNorm(args.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(args.hidden_dropout_prob)
        self.args = args

        # one dedicated dense layer per self-supervised objective:
        # AAP (associated attribute prediction), MIP (masked item prediction),
        # MAP (masked attribute prediction), SP (segment prediction)
        self.aap_norm = nn.Linear(args.hidden_size, args.hidden_size)
        self.mip_norm = nn.Linear(args.hidden_size, args.hidden_size)
        self.map_norm = nn.Linear(args.hidden_size, args.hidden_size)
        self.sp_norm = nn.Linear(args.hidden_size, args.hidden_size)
        self.criterion = nn.BCELoss(reduction='none')
        self.apply(self.init_weights)
    # AAP
    def associated_attribute_prediction(self, sequence_output, attribute_embedding):
        '''
        :param sequence_output: [B L H]
        :param attribute_embedding: [attribute_num H]
        :return: scores [B*L attribute_num]
        '''
        sequence_output = self.aap_norm(sequence_output)  # [B L H]
        sequence_output = sequence_output.view([-1, self.args.hidden_size, 1])  # [B*L H 1]
        # [attribute_num H] x [B*L H 1] -> [B*L attribute_num 1]
        score = torch.matmul(attribute_embedding, sequence_output)
        return torch.sigmoid(score.squeeze(-1))  # [B*L attribute_num]
    # MIP: score a (masked position, candidate item) pair; negative items are sampled
    def masked_item_prediction(self, sequence_output, target_item):
        '''
        :param sequence_output: [B L H]
        :param target_item: [B L H]
        :return: scores [B*L]
        '''
        sequence_output = self.mip_norm(sequence_output.view([-1, self.args.hidden_size]))  # [B*L H]
        target_item = target_item.view([-1, self.args.hidden_size])  # [B*L H]
        score = torch.mul(sequence_output, target_item)  # [B*L H]
        return torch.sigmoid(torch.sum(score, -1))  # [B*L]
    # MAP
    def masked_attribute_prediction(self, sequence_output, attribute_embedding):
        '''
        :param sequence_output: [B L H]
        :param attribute_embedding: [attribute_num H]
        :return: scores [B*L attribute_num]
        '''
        sequence_output = self.map_norm(sequence_output)  # [B L H]
        sequence_output = sequence_output.view([-1, self.args.hidden_size, 1])  # [B*L H 1]
        # [attribute_num H] x [B*L H 1] -> [B*L attribute_num 1]
        score = torch.matmul(attribute_embedding, sequence_output)
        return torch.sigmoid(score.squeeze(-1))  # [B*L attribute_num]
    # SP: score a (context, candidate segment) pair; negative segments are sampled
    def segment_prediction(self, context, segment):
        '''
        :param context: [B H]
        :param segment: [B H]
        :return: scores [B]
        '''
        context = self.sp_norm(context)
        score = torch.mul(context, segment)  # [B H]
        return torch.sigmoid(torch.sum(score, dim=-1))  # [B]
    def add_position_embedding(self, sequence):
        # sequence: [B L] item ids -> [B L H] item embedding + learned position embedding
        seq_length = sequence.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=sequence.device)
        position_ids = position_ids.unsqueeze(0).expand_as(sequence)
        item_embeddings = self.item_embeddings(sequence)
        position_embeddings = self.position_embeddings(position_ids)
        sequence_emb = item_embeddings + position_embeddings
        sequence_emb = self.LayerNorm(sequence_emb)
        sequence_emb = self.dropout(sequence_emb)
        return sequence_emb
    def pretrain(self, attributes, masked_item_sequence, pos_items, neg_items,
                 masked_segment_sequence, pos_segment, neg_segment):

        # Encode masked sequence
        sequence_emb = self.add_position_embedding(masked_item_sequence)
        sequence_mask = (masked_item_sequence == 0).float() * -1e8
        sequence_mask = torch.unsqueeze(torch.unsqueeze(sequence_mask, 1), 1)

        encoded_layers = self.item_encoder(sequence_emb,
                                           sequence_mask,
                                           output_all_encoded_layers=True)
        # [B L H]
        sequence_output = encoded_layers[-1]

        attribute_embeddings = self.attribute_embeddings.weight

        # AAP
        aap_score = self.associated_attribute_prediction(sequence_output, attribute_embeddings)
        aap_loss = self.criterion(aap_score, attributes.view(-1, self.args.attribute_size).float())
        # only compute loss at non-masked, non-padding positions
        aap_mask = (masked_item_sequence != self.args.mask_id).float() * \
                   (masked_item_sequence != 0).float()
        aap_loss = torch.sum(aap_loss * aap_mask.flatten().unsqueeze(-1))

        # MIP
        pos_item_embs = self.item_embeddings(pos_items)
        neg_item_embs = self.item_embeddings(neg_items)
        pos_score = self.masked_item_prediction(sequence_output, pos_item_embs)
        neg_score = self.masked_item_prediction(sequence_output, neg_item_embs)
        mip_distance = torch.sigmoid(pos_score - neg_score)
        mip_loss = self.criterion(mip_distance, torch.ones_like(mip_distance, dtype=torch.float32))
        mip_mask = (masked_item_sequence == self.args.mask_id).float()
        mip_loss = torch.sum(mip_loss * mip_mask.flatten())

        # MAP
        map_score = self.masked_attribute_prediction(sequence_output, attribute_embeddings)
        map_loss = self.criterion(map_score, attributes.view(-1, self.args.attribute_size).float())
        map_mask = (masked_item_sequence == self.args.mask_id).float()
        map_loss = torch.sum(map_loss * map_mask.flatten().unsqueeze(-1))

        # SP
        # segment context
        segment_context = self.add_position_embedding(masked_segment_sequence)
        segment_mask = (masked_segment_sequence == 0).float() * -1e8
        segment_mask = torch.unsqueeze(torch.unsqueeze(segment_mask, 1), 1)
        segment_encoded_layers = self.item_encoder(segment_context,
                                                   segment_mask,
                                                   output_all_encoded_layers=True)

        # take the hidden state at the last position as the context
        segment_context = segment_encoded_layers[-1][:, -1, :]  # [B H]

        # pos_segment
        pos_segment_emb = self.add_position_embedding(pos_segment)
        pos_segment_mask = (pos_segment == 0).float() * -1e8
        pos_segment_mask = torch.unsqueeze(torch.unsqueeze(pos_segment_mask, 1), 1)
        pos_segment_encoded_layers = self.item_encoder(pos_segment_emb,
                                                       pos_segment_mask,
                                                       output_all_encoded_layers=True)
        pos_segment_emb = pos_segment_encoded_layers[-1][:, -1, :]  # [B H]

        # neg_segment
        neg_segment_emb = self.add_position_embedding(neg_segment)
        neg_segment_mask = (neg_segment == 0).float() * -1e8
        neg_segment_mask = torch.unsqueeze(torch.unsqueeze(neg_segment_mask, 1), 1)
        neg_segment_encoded_layers = self.item_encoder(neg_segment_emb,
                                                       neg_segment_mask,
                                                       output_all_encoded_layers=True)
        neg_segment_emb = neg_segment_encoded_layers[-1][:, -1, :]  # [B H]

        pos_segment_score = self.segment_prediction(segment_context, pos_segment_emb)
        neg_segment_score = self.segment_prediction(segment_context, neg_segment_emb)

        sp_distance = torch.sigmoid(pos_segment_score - neg_segment_score)

        sp_loss = torch.sum(self.criterion(sp_distance,
                                           torch.ones_like(sp_distance, dtype=torch.float32)))

        return aap_loss, mip_loss, map_loss, sp_loss
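
    # A training loop outside this class typically combines the four losses with
    # per-task weights before backprop -- a sketch only; the weight names below
    # are hypothetical, not defined in this file:
    #   aap, mip, map_, sp = model.pretrain(...)
    #   joint_loss = w_aap * aap + w_mip * mip + w_map * map_ + w_sp * sp
    #   joint_loss.backward()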
    # Fine-tune
    # same as SASRec
    def finetune(self, input_ids):
        attention_mask = (input_ids > 0).long()
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)  # torch.int64

        max_len = attention_mask.size(-1)
        attn_shape = (1, max_len, max_len)
        # causal mask: each position may only attend to earlier positions
        subsequent_mask = torch.triu(torch.ones(attn_shape), diagonal=1)  # torch.uint8
        subsequent_mask = (subsequent_mask == 0).unsqueeze(1)
        subsequent_mask = subsequent_mask.long()

        if self.args.cuda_condition:
            subsequent_mask = subsequent_mask.cuda()

        extended_attention_mask = extended_attention_mask * subsequent_mask
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        sequence_emb = self.add_position_embedding(input_ids)

        item_encoded_layers = self.item_encoder(sequence_emb,
                                                extended_attention_mask,
                                                output_all_encoded_layers=True)
        sequence_output = item_encoded_layers[-1]
        return sequence_output
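
    # Callers usually rank candidate items by a dot product between the hidden
    # state at the last position and the item embedding table -- a sketch, not
    # code from this file (model is an S3RecModel instance):
    #   seq_out = model.finetune(input_ids)                       # [B L H]
    #   logits = torch.matmul(seq_out[:, -1, :],
    #                         model.item_embeddings.weight.transpose(0, 1))  # [B item_size]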
    def init_weights(self, module):
        """ Initialize the weights.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.args.initializer_range)
        elif isinstance(module, LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
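

if __name__ == "__main__":
    # Minimal smoke test -- a hedged sketch, not part of the original file.
    # Hyperparameter values are illustrative; the Encoder-specific fields
    # (num_hidden_layers, num_attention_heads, hidden_act,
    # attention_probs_dropout_prob) are assumed to be what modules.Encoder
    # reads from args.
    from argparse import Namespace

    args = Namespace(
        item_size=101, attribute_size=10, hidden_size=64, max_seq_length=50,
        hidden_dropout_prob=0.5, initializer_range=0.02, cuda_condition=False,
        mask_id=100,
        num_hidden_layers=2, num_attention_heads=2,
        attention_probs_dropout_prob=0.5, hidden_act='gelu',
    )
    model = S3RecModel(args)
    dummy_ids = torch.randint(1, args.item_size - 1, (4, args.max_seq_length))
    out = model.finetune(dummy_ids)
    print(out.shape)  # expected: torch.Size([4, 50, 64])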