✨ update custom vocabulary
kaylode committed May 4, 2023
1 parent 6bdc0d3 commit ed1d3f2
Showing 1 changed file with 127 additions and 27 deletions.
154 changes: 127 additions & 27 deletions theseus/nlp/base/preprocessors/vocabulary.py
@@ -12,11 +12,13 @@ def __init__(
max_size=None,
min_freq=None,
max_freq=None,
special_tokens={},
special_tokens=None,
replace=False,
pkl_path=None,
unk_word="<unk>",
pad_word="<pad>",
sos_word="<sos>",
eos_word="<eos>",
):

self.pkl_path = pkl_path
@@ -27,13 +29,30 @@ def __init__(
self.max_size = max_size
self.unk_word = unk_word
self.pad_word = pad_word
self.sos_word = sos_word
self.eos_word = eos_word

self.init_vocab()
if self.pkl_path is not None:
with open(self.pkl_path, "rb") as f:
vocab = pickle.load(f)
self.word2idx = vocab.word2idx
self.idx2word = vocab.idx2word
self.load_pickle(self.pkl_path)

def load_pickle(self, vocab_path):
with open(vocab_path, "rb") as f:
vocab = pickle.load(f)
self.word2idx = vocab.word2idx
self.idx2word = vocab.idx2word
self.frequency = vocab.frequency
self.special_tokens = vocab.special_tokens
self.replace = vocab.replace
self.min_freq = vocab.min_freq
self.max_freq = vocab.max_freq
self.max_size = vocab.max_size
self.unk_word = vocab.unk_word
self.pad_word = vocab.pad_word
self.sos_word = vocab.sos_word
self.eos_word = vocab.eos_word
self.vocab_size = vocab.vocab_size

LOGGER.text(
"Vocabulary successfully loaded from vocab.pkl file!",
level=LoggerObserver.INFO,
@@ -54,7 +73,6 @@ def save_vocab(self, save_path):

def build_vocab(self, list_tokens):
"""Populate the dictionaries for converting tokens to integers (and vice-versa)."""

for tok in list_tokens:
if not tok in self.frequency:
self.frequency[tok] = 0
@@ -77,17 +95,18 @@ def build_vocab(self, list_tokens):
if self.max_size is not None:
list_tokens = list_tokens[: self.max_size]

self.add_special_tokens()
for tok in list_tokens:
self.add_word(tok)

self.add_special_tokens()

def init_vocab(self):
"""Initialize the dictionaries for converting tokens to integers (and vice-versa)."""
self.word2idx = {}
self.idx2word = {}
self.frequency = {}
self.idx = 0
self.vocab_size = 0
if self.special_tokens is None:
self.special_tokens = {}

def add_word(self, word, index=None):
"""Add a token to the vocabulary."""
@@ -98,18 +117,18 @@ def add_word(self, word, index=None):
assert isinstance(index, int), "Index must be type int"

if index is None:
index = self.idx
index = self.vocab_size

if not word in self.word2idx.keys() and not index in self.idx2word.keys():
self.word2idx[word] = self.idx
self.idx2word[self.idx] = word
self.idx += 1
self.word2idx[word] = self.vocab_size
self.idx2word[self.vocab_size] = word
self.vocab_size += 1
elif not word in self.word2idx.keys() and index in self.idx2word.keys():
if self.replace:
old_word = self.idx2word[index]
self.word2idx[old_word] = self.idx
self.idx2word[self.idx] = old_word
self.idx += 1
self.word2idx[old_word] = self.vocab_size
self.idx2word[self.vocab_size] = old_word
self.vocab_size += 1

self.word2idx[word] = index
self.idx2word[index] = word
@@ -140,38 +159,119 @@ def add_word(self, word, index=None):
raise ValueError()

def add_special_tokens(self):
if self.unk_word not in self.special_tokens.keys():
self.special_tokens.update({self.unk_word: self.idx})
self.idx += 1
if self.sos_word not in self.special_tokens.keys():
self.add_word(self.sos_word)
self.special_tokens.update({self.sos_word: self.word2idx[self.sos_word]})

if self.eos_word not in self.special_tokens.keys():
self.add_word(self.eos_word)
self.special_tokens.update({self.eos_word: self.word2idx[self.eos_word]})

if self.pad_word not in self.special_tokens.keys():
self.special_tokens.update({self.pad_word: self.idx})
self.idx += 1
self.add_word(self.pad_word)
self.special_tokens.update({self.pad_word: self.word2idx[self.pad_word]})

for token, index in self.special_tokens.items():
self.add_word(token, index)
if self.unk_word not in self.special_tokens.keys():
self.add_word(self.unk_word)
self.special_tokens.update({self.unk_word: self.word2idx[self.unk_word]})

def get_pad_token_id(self):
return self.word2idx[self.pad_word]

def get_unk_token_id(self):
return self.word2idx[self.unk_word]

def encode_tokens(self, lists_of_tokens):
def get_sos_token_id(self):
return self.word2idx[self.sos_word]

def get_eos_token_id(self):
return self.word2idx[self.eos_word]

def encode_tokens(self, lists_of_tokens, **kwargs):
"""
Batch of list of tokens
"""

add_special_tokens = kwargs.get("add_special_tokens", False)
max_length = kwargs.get("max_length", None)
return_token_type_ids = kwargs.get("return_token_type_ids", False)
truncation = kwargs.get("truncation", False)

if return_token_type_ids:
token_type_idss = []

if max_length is None:
max_length = max([len(x) for x in lists_of_tokens])

encoded_list = []
for token_list in lists_of_tokens:
batch = []
if add_special_tokens:
batch = [self.__call__(self.sos_word)]
else:
batch = []
for token in token_list:
batch.append(self.__call__(token))

if add_special_tokens:
batch.append(self.__call__(self.eos_word))

if max_length is not None:
if len(batch) > max_length:
if truncation:
if add_special_tokens:
batch = batch[: max_length - 2]
batch.append(self.__call__(self.eos_word))
else:
batch = batch[:max_length]
else:
LOGGER.text(
f"Sequence is longer than max_length. Please use truncation=True",
level=LoggerObserver.ERROR,
)
raise ValueError()
if len(batch) < max_length and add_special_tokens:
batch += [self.__call__(self.pad_word)] * (max_length - len(batch))

if return_token_type_ids:
token_type_ids = [
0 if batch[tk] != self.__call__(self.pad_word) else 1
for tk in range(len(batch))
]
token_type_idss.append(token_type_ids)

encoded_list.append(batch)
return encoded_list

if return_token_type_ids:
return {"input_ids": encoded_list, "token_type_ids": token_type_idss}
else:
return {
"input_ids": encoded_list,
}

def decode_tokens(self, list_of_ids):
"""
Batch of list of ids
"""
decoded_list = []
for ids in list_of_ids:
batch = [
self.itos(idx)
for idx in ids
if idx not in [self.get_pad_token_id(), self.get_sos_token_id(), self.get_eos_token_id()]
]
decoded_list.append(batch)
return decoded_list

def encode_texts(self, text, **kwargs):
if isinstance(text, str):
text = [text]

tokenized_texts = [s.split(kwargs.get("delimiter", " ")) for s in text]
return self.encode_tokens(tokenized_texts, **kwargs)

def itos(self, idx):
if not idx in self.idx2word:
return self.idx2word[self.unk_word]
return self.idx2word[self.__call__(self.unk_word)]
return self.idx2word[idx]

def __call__(self, word):
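For reference, a minimal usage sketch of the updated vocabulary API. It assumes the class is named Vocabulary and is importable from theseus.nlp.base.preprocessors.vocabulary, and that save_vocab pickles the vocabulary object itself (as load_pickle expects); none of these details appear in the visible hunks, so treat them as assumptions rather than confirmed API.

# Assumed class name and import path; not shown in this diff.
from theseus.nlp.base.preprocessors.vocabulary import Vocabulary

# Build a vocabulary from a flat list of tokens; build_vocab registers the
# special tokens (<sos>, <eos>, <pad>, <unk>) first via add_special_tokens().
vocab = Vocabulary()
vocab.build_vocab("the cat sat on the mat".split())

# Encode raw text. The keyword arguments introduced in this commit control
# special tokens, truncation, padding to max_length, and token_type_ids output.
encoded = vocab.encode_texts(
    ["the cat sat", "on the mat"],
    add_special_tokens=True,
    max_length=8,
    truncation=True,
    return_token_type_ids=True,
)
print(encoded["input_ids"])       # padded id sequences
print(encoded["token_type_ids"])  # 0 for real tokens, 1 for padding positions

# Decode id sequences back to tokens; special tokens are dropped.
print(vocab.decode_tokens(encoded["input_ids"]))

# Persist and restore; passing pkl_path to the constructor triggers load_pickle().
vocab.save_vocab("vocab.pkl")  # assumed to pickle the Vocabulary object
restored = Vocabulary(pkl_path="vocab.pkl")
print(restored.get_pad_token_id(), restored.get_sos_token_id(), restored.get_eos_token_id())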
