HiGraph.py (forked from dqwang122/HeterSumGraph)
#!/usr/bin/python
# -*- coding: utf-8 -*-
# __author__="Danqing Wang"
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn
import dgl
# from module.GAT import GAT, GAT_ffn
from module.Encoder import sentEncoder
from module.GAT import WSWGAT
from module.PositionEmbedding import get_sinusoid_encoding_table


class HSumGraph(nn.Module):
    """Without sent2sent edges; residual connections are added."""

    def __init__(self, hps, embed):
        """
        :param hps: hyperparameters
        :param embed: word embedding layer
        """
super().__init__()
self._hps = hps
self._n_iter = hps.n_iter
self._embed = embed
self.embed_size = hps.word_emb_dim
# sent node feature
self._init_sn_param()
        self._TFembed = nn.Embedding(10, hps.feat_embed_size)  # tffrac is bucketed into 10 boxes
self.n_feature_proj = nn.Linear(hps.n_feature_size * 2, hps.hidden_size, bias=False)
# word -> sent
embed_size = hps.word_emb_dim
self.word2sent = WSWGAT(in_dim=embed_size,
out_dim=hps.hidden_size,
num_heads=hps.n_head,
attn_drop_out=hps.atten_dropout_prob,
ffn_inner_hidden_size=hps.ffn_inner_hidden_size,
ffn_drop_out=hps.ffn_dropout_prob,
feat_embed_size=hps.feat_embed_size,
layerType="W2S"
)
# sent -> word
self.sent2word = WSWGAT(in_dim=hps.hidden_size,
out_dim=embed_size,
                                num_heads=6,  # note: hardcoded, unlike word2sent's hps.n_head
attn_drop_out=hps.atten_dropout_prob,
ffn_inner_hidden_size=hps.ffn_inner_hidden_size,
ffn_drop_out=hps.ffn_dropout_prob,
feat_embed_size=hps.feat_embed_size,
layerType="S2W"
)
# node classification
self.n_feature = hps.hidden_size
self.wh = nn.Linear(self.n_feature, 2)

    def forward(self, graph):
        """
        :param graph: [batch_size] * DGLGraph (batched with dgl.batch)
            node:
                word: unit=0, dtype=0, id=(int)wordid in vocab
                sentence: unit=1, dtype=1, words=tensor, position=int, label=tensor
            edge:
                word2sent, sent2word: tffrac=int, dtype=0
        :return: result: [sentnum, 2]
        """
# word node init
word_feature = self.set_wnfeature(graph) # [wnode, embed_size]
sent_feature = self.n_feature_proj(self.set_snfeature(graph)) # [snode, n_feature_size]
# the start state
word_state = word_feature
sent_state = self.word2sent(graph, word_feature, sent_feature)
for i in range(self._n_iter):
# sent -> word
word_state = self.sent2word(graph, word_state, sent_state)
# word -> sent
sent_state = self.word2sent(graph, word_state, sent_state)
result = self.wh(sent_state)
return result
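
    # Shape walkthrough (illustrative sizes, not from the repo): with 500 word
    # nodes and 30 sentence nodes in the batch, word_state stays
    # [500, word_emb_dim] (sent2word maps back to embed_size) and sent_state
    # stays [30, hidden_size] (word2sent maps to hidden_size) across all
    # iterations, so self.wh yields result [30, 2].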

    def _init_sn_param(self):
self.sent_pos_embed = nn.Embedding.from_pretrained(
get_sinusoid_encoding_table(self._hps.doc_max_timesteps + 1, self.embed_size, padding_idx=0),
freeze=True)
self.cnn_proj = nn.Linear(self.embed_size, self._hps.n_feature_size)
self.lstm_hidden_state = self._hps.lstm_hidden_state
self.lstm = nn.LSTM(self.embed_size, self.lstm_hidden_state, num_layers=self._hps.lstm_layers, dropout=0.1,
batch_first=True, bidirectional=self._hps.bidirectional)
if self._hps.bidirectional:
self.lstm_proj = nn.Linear(self.lstm_hidden_state * 2, self._hps.n_feature_size)
else:
self.lstm_proj = nn.Linear(self.lstm_hidden_state, self._hps.n_feature_size)
self.ngram_enc = sentEncoder(self._hps, self._embed)

    def _sent_cnn_feature(self, graph, snode_id):
ngram_feature = self.ngram_enc.forward(graph.nodes[snode_id].data["words"]) # [snode, embed_size]
graph.nodes[snode_id].data["sent_embedding"] = ngram_feature
snode_pos = graph.nodes[snode_id].data["position"].view(-1) # [n_nodes]
position_embedding = self.sent_pos_embed(snode_pos)
cnn_feature = self.cnn_proj(ngram_feature + position_embedding)
return cnn_feature

    def _sent_lstm_feature(self, features, glen):
        pad_seq = rnn.pad_sequence(features, batch_first=True)
        # enforce_sorted=False: sentence counts per graph are not guaranteed to
        # arrive in decreasing order
        lstm_input = rnn.pack_padded_sequence(pad_seq, glen, batch_first=True, enforce_sorted=False)
lstm_output, _ = self.lstm(lstm_input)
unpacked, unpacked_len = rnn.pad_packed_sequence(lstm_output, batch_first=True)
lstm_embedding = [unpacked[i][:unpacked_len[i]] for i in range(len(unpacked))]
lstm_feature = self.lstm_proj(torch.cat(lstm_embedding, dim=0)) # [n_nodes, n_feature_size]
return lstm_feature

    def set_wnfeature(self, graph):
        wnode_id = graph.filter_nodes(lambda nodes: nodes.data["unit"] == 0)
        wsedge_id = graph.filter_edges(lambda edges: edges.data["dtype"] == 0)  # edges from words to supernodes (sent & doc)
wid = graph.nodes[wnode_id].data["id"] # [n_wnodes]
w_embed = self._embed(wid) # [n_wnodes, D]
graph.nodes[wnode_id].data["embed"] = w_embed
etf = graph.edges[wsedge_id].data["tffrac"]
graph.edges[wsedge_id].data["tfidfembed"] = self._TFembed(etf)
return w_embed
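
    # Note: besides returning the word embeddings, this method also stores the
    # bucketed tffrac embedding on every word<->supernode edge as "tfidfembed";
    # the WSWGAT layers (constructed with feat_embed_size above) presumably
    # read it during attention.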

    def set_snfeature(self, graph):
# node feature
snode_id = graph.filter_nodes(lambda nodes: nodes.data["dtype"] == 1)
cnn_feature = self._sent_cnn_feature(graph, snode_id)
features, glen = get_snode_feat(graph, feat="sent_embedding")
lstm_feature = self._sent_lstm_feature(features, glen)
node_feature = torch.cat([cnn_feature, lstm_feature], dim=1) # [n_nodes, n_feature_size * 2]
return node_feature


class HSumDocGraph(HSumGraph):
    """
    Without sent2sent edges; residual connections are added.
    Adds document nodes on top of HSumGraph.
    """

    def __init__(self, hps, embed):
super().__init__(hps, embed)
self.dn_feature_proj = nn.Linear(hps.hidden_size, hps.hidden_size, bias=False)
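        # self.wh is re-created with doubled input width because this model's
        # forward() scores each sentence on its own hidden state concatenated
        # with its document's hidden state.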
self.wh = nn.Linear(self.n_feature * 2, 2)

    def forward(self, graph):
        """
        :param graph: [batch_size] * DGLGraph (batched with dgl.batch)
            node:
                word: unit=0, dtype=0, id=(int)wordid in vocab
                sentence: unit=1, dtype=1, words=tensor, position=int, label=tensor
                document: unit=1, dtype=2
            edge:
                word2sent, sent2word: tffrac=int, dtype=0
                word2doc, doc2word: tffrac=int, dtype=0
                sent2doc: dtype=2
        :return: result: [sentnum, 2]
        """
snode_id = graph.filter_nodes(lambda nodes: nodes.data["dtype"] == 1)
dnode_id = graph.filter_nodes(lambda nodes: nodes.data["dtype"] == 2)
supernode_id = graph.filter_nodes(lambda nodes: nodes.data["unit"] == 1)
# word node init
word_feature = self.set_wnfeature(graph) # [wnode, embed_size]
sent_feature = self.n_feature_proj(self.set_snfeature(graph)) # [snode, n_feature_size]
# sent and doc node init
graph.nodes[snode_id].data["init_feature"] = sent_feature
doc_feature, snid2dnid = self.set_dnfeature(graph)
doc_feature = self.dn_feature_proj(doc_feature)
graph.nodes[dnode_id].data["init_feature"] = doc_feature
# the start state
word_state = word_feature
sent_state = graph.nodes[supernode_id].data["init_feature"]
sent_state = self.word2sent(graph, word_state, sent_state)
for i in range(self._n_iter):
# sent -> word
word_state = self.sent2word(graph, word_state, sent_state)
# word -> sent
sent_state = self.word2sent(graph, word_state, sent_state)
graph.nodes[supernode_id].data["hidden_state"] = sent_state
# extract sentence nodes
s_state_list = []
for snid in snode_id:
d_state = graph.nodes[snid2dnid[int(snid)]].data["hidden_state"]
s_state = graph.nodes[snid].data["hidden_state"]
s_state = torch.cat([s_state, d_state], dim=-1)
s_state_list.append(s_state)
s_state = torch.cat(s_state_list, dim=0)
result = self.wh(s_state)
return result

    def set_dnfeature(self, graph):
        """Initialize each doc node by mean-pooling the init features of its predecessor sentence nodes (dtype=1)."""
dnode_id = graph.filter_nodes(lambda nodes: nodes.data["dtype"] == 2)
node_feature_list = []
snid2dnid = {}
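        # snid2dnid will map each sentence node id to its owning document node
        # id, so forward() can pair sentence states with document states.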
for dnode in dnode_id:
            snodes = [nid for nid in graph.predecessors(dnode) if graph.nodes[nid].data["dtype"] == 1]
doc_feature = graph.nodes[snodes].data["init_feature"].mean(dim=0)
assert not torch.any(torch.isnan(doc_feature)), "doc_feature_element"
node_feature_list.append(doc_feature)
for s in snodes:
snid2dnid[int(s)] = dnode
node_feature = torch.stack(node_feature_list)
return node_feature, snid2dnid


def get_snode_feat(G, feat):
    """Collect per-graph sentence-node features and their lengths from a batched graph."""
    glist = dgl.unbatch(G)
feature = []
glen = []
for g in glist:
snode_id = g.filter_nodes(lambda nodes: nodes.data["dtype"] == 1)
feature.append(g.nodes[snode_id].data[feat])
glen.append(len(snode_id))
return feature, glen
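

if __name__ == "__main__":
    # Minimal smoke test (not part of the original repo): builds a toy graph
    # matching the schema documented in HSumGraph.forward, using the mutable
    # DGL 0.4-style API this file relies on. All node counts, word ids and
    # tffrac buckets below are made-up illustrative values.
    n_words, n_sents = 4, 2
    G = dgl.DGLGraph()
    # word nodes: unit=0, dtype=0, id=word id in vocab
    G.add_nodes(n_words, data={"unit": torch.zeros(n_words, dtype=torch.long),
                               "dtype": torch.zeros(n_words, dtype=torch.long),
                               "id": torch.tensor([2, 5, 7, 3])})
    # sentence nodes: unit=1, dtype=1, words=padded word ids, position, label
    G.add_nodes(n_sents, data={"unit": torch.ones(n_sents, dtype=torch.long),
                               "dtype": torch.ones(n_sents, dtype=torch.long),
                               "words": torch.tensor([[2, 5, 0], [7, 3, 0]]),
                               "position": torch.tensor([0, 1]),
                               "label": torch.tensor([1, 0])})
    # word<->sent edges carry a bucketed term-frequency feature (tffrac in [0, 9])
    src, dst = torch.tensor([0, 1, 2, 3]), torch.tensor([4, 4, 5, 5])
    edata = {"tffrac": torch.tensor([3, 1, 2, 0]),
             "dtype": torch.zeros(4, dtype=torch.long)}
    G.add_edges(src, dst, data=edata)
    G.add_edges(dst, src, data=edata)
    print(G.filter_nodes(lambda nodes: nodes.data["unit"] == 0))  # word node ids
    print(G.filter_edges(lambda edges: edges.data["dtype"] == 0).shape)  # all 8 edges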