luke.diff
diff --git a/luke/model.py b/luke/model.py
index 214e027..70ba1a6 100644
--- a/luke/model.py
+++ b/luke/model.py
@@ -5,12 +5,12 @@ from typing import Dict
 import torch
 import torch.nn.functional as F
 from torch import nn
+from torch.nn import LayerNorm as BertLayerNorm
 from transformers.modeling_bert import (
     BertConfig,
     BertEmbeddings,
     BertEncoder,
     BertIntermediate,
-    BertLayerNorm,
     BertOutput,
     BertPooler,
     BertSelfOutput,
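The import swap above works because BertLayerNorm in the older transformers releases this code targets was just an alias for torch.nn.LayerNorm (or an apex fused variant when available), and the alias was removed from transformers.modeling_bert in later versions, which is what breaks the original import. A minimal sketch of the drop-in behaviour, with an illustrative hidden size and epsilon:

import torch
from torch.nn import LayerNorm as BertLayerNorm

hidden_size = 768                                  # BERT-base width, illustrative
layer_norm = BertLayerNorm(hidden_size, eps=1e-12)

x = torch.randn(2, 16, hidden_size)                # (batch, seq_len, hidden)
print(layer_norm(x).shape)                         # torch.Size([2, 16, 768])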
@@ -102,6 +102,10 @@ class LukeModel(nn.Module):
     ):
         word_seq_size = word_ids.size(1)
 
+        if torch.cuda.is_available():
+            word_ids = word_ids.to("cuda:0")
+            word_segment_ids = word_segment_ids.to("cuda:0")
+
         embedding_output = self.embeddings(word_ids, word_segment_ids)
 
         attention_mask = self._compute_extended_attention_mask(word_attention_mask, entity_attention_mask)
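Tensor.to is out of place: it returns a copy on the target device, so the reassignments above are what actually move the inputs. A minimal sketch of the same pattern, using a CPU fallback instead of the hard-coded cuda:0:

import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

word_ids = torch.randint(0, 30000, (2, 16))        # (batch, seq_len); vocab size illustrative
word_segment_ids = torch.zeros_like(word_ids)

# Without the reassignment, the original tensors would stay on the CPU.
word_ids = word_ids.to(device)
word_segment_ids = word_segment_ids.to(device)
print(word_ids.device)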
diff --git a/luke/utils/entity_vocab.py b/luke/utils/entity_vocab.py
index 445fdd5..47f343b 100644
--- a/luke/utils/entity_vocab.py
+++ b/luke/utils/entity_vocab.py
@@ -17,6 +17,8 @@ from .interwiki_db import InterwikiDB
 PAD_TOKEN = "[PAD]"
 UNK_TOKEN = "[UNK]"
 MASK_TOKEN = "[MASK]"
+HEAD_TOKEN = "[HEAD]"
+TAIL_TOKEN = "[TAIL]"
 
 SPECIAL_TOKENS = {PAD_TOKEN, UNK_TOKEN, MASK_TOKEN}
 
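The hunk above defines HEAD_TOKEN and TAIL_TOKEN (presumably markers for the head and tail entities in relation classification) but leaves SPECIAL_TOKENS unchanged. If the new tokens are meant to be reserved like the existing ones, the set would need a follow-up along these lines; a sketch, not part of this patch:

SPECIAL_TOKENS = {PAD_TOKEN, UNK_TOKEN, MASK_TOKEN, HEAD_TOKEN, TAIL_TOKEN}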
diff --git a/luke/utils/model_utils.py b/luke/utils/model_utils.py
index 865190d..f187e47 100644
--- a/luke/utils/model_utils.py
+++ b/luke/utils/model_utils.py
@@ -8,7 +8,7 @@ from typing import Dict
 
 import click
 import torch
-from luke.model import LukeConfig
+from ..model import LukeConfig
 from .entity_vocab import EntityVocab
 from .word_tokenizer import AutoTokenizer
 
@@ -59,10 +59,11 @@ def create_model_archive(model_file: str, out_file: str, compress: str):
 
 
 class ModelArchive(object):
-    def __init__(self, state_dict: Dict[str, torch.Tensor], metadata: dict, entity_vocab: EntityVocab):
+    def __init__(self, state_dict: Dict[str, torch.Tensor], metadata: dict, entity_vocab: EntityVocab, archive_path: str = None):
         self.state_dict = state_dict
         self.metadata = metadata
         self.entity_vocab = entity_vocab
+        self.archive_path = archive_path
 
     @property
     def bert_model_name(self):
@@ -74,6 +75,9 @@ class ModelArchive(object):
 
     @property
     def tokenizer(self):
+        # First, try to load the tokenizer from a local path
+        if self.archive_path is not None:
+            return AutoTokenizer.from_pretrained(self.archive_path)
         return AutoTokenizer.from_pretrained(self.bert_model_name)
 
     @property
@@ -107,4 +111,4 @@ class ModelArchive(object):
         metadata = json.load(metadata_file)
         entity_vocab = EntityVocab(get_entity_vocab_file_path(path))
 
-        return ModelArchive(state_dict, metadata, entity_vocab)
+        return ModelArchive(state_dict, metadata, entity_vocab, path)
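Taken together, the model_utils.py changes let an archive loaded from disk resolve its tokenizer from files next to the weights before falling back to downloading by BERT model name. A hedged usage sketch; the archive path is illustrative, and ModelArchive.load is assumed to be the loader whose return statement appears in the final hunk:

from luke.utils.model_utils import ModelArchive

# load() now threads the archive location into the constructor, so the
# tokenizer property checks that directory before using bert_model_name.
archive = ModelArchive.load("models/luke_base")    # illustrative path
tokenizer = archive.tokenizer                      # local tokenizer files win if present
print(archive.bert_model_name)                     # the download fallback identifier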