Commit
Optimize implementation (#12)
martinakaduc authored Oct 6, 2024
1 parent 2f4a2d7 commit 00ff94e
Showing 1 changed file with 29 additions and 39 deletions.
68 changes: 29 additions & 39 deletions src/embed_text_package/embed_text.py
@@ -40,7 +40,7 @@ def load(self, model_name: str):
         :type which_model: str
         """
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-        self.model = AutoModel.from_pretrained(model_name).to("cuda")
+        self.model = AutoModel.from_pretrained(model_name, device_map="auto")
         self.which_model = model_name

     def unload(self):
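
For context on this first hunk: the removed line copied the fully loaded model onto a single CUDA device, while device_map="auto" hands placement to Hugging Face Accelerate, which spreads the weights across whatever GPUs (or the CPU) are available. A minimal sketch of the new loading path, using a hypothetical checkpoint name and assuming the accelerate package is installed:

    from transformers import AutoModel, AutoTokenizer

    model_name = "gpt2"  # hypothetical example checkpoint, not part of the package

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Old behaviour: load on CPU, then move every weight to one CUDA device.
    # model = AutoModel.from_pretrained(model_name).to("cuda")

    # New behaviour: let Accelerate pick the device placement automatically.
    model = AutoModel.from_pretrained(model_name, device_map="auto")

    # .device reports where the first parameters live; the batched code in the
    # second hunk uses it to move tokenized inputs onto the model's device.
    print(model.device)

One practical consequence is that .to("cuda") fails outright on a CPU-only machine, whereas device_map="auto" falls back to CPU placement.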
@@ -84,45 +84,35 @@ def get_embeddings(self, dataloader: DataLoader, model_name: str, cols: list):
             col_emb = []
             tqdm_dataloader = tqdm(dataloader)
             for batch in tqdm_dataloader:
-                tqdm_dataloader.set_description(
-                    f"Embedding sentences in '{col}' on '{self.model.device}'"
+                model_inputs = self.tokenizer(
+                    batch[col],
+                    add_special_tokens=False,
+                    return_tensors="pt",
+                    padding=True,
                 )

-                for sentence in batch[col]:
-                    # 1) Get Tokens of sentence
-                    tokenized_sentence = self.tokenizer(
-                        sentence, add_special_tokens=False, return_tensors="pt"
-                    )
-
-                    # 2) Get Embeddings (hidden state of last input)
-                    # Generate model inputs on same device as self.model
-                    # att_mask is vector of ones: Attention on all tokens!
-
-                    tokenized_sentence = {
-                        k: v.to(self.model.device)
-                        for k, v in tokenized_sentence.items()
-                    }
-                    # >>> sequence_length
-
-                    # get embedding via forward function of main self.model.
-                    ###########################################################
-                    # NOTE: For performance reasons, one could implement
-                    # self.model.forward in vectorized manner.
-                    # If you want to do that, keep padding in mind!
-                    ###########################################################
-                    sentence_emb = (
-                        self.model.forward(**tokenized_sentence)
-                        .last_hidden_state[0][-1]
-                        .squeeze()
-                        .detach()
-                        .cpu()
-                        .tolist()
-                    )
-                    # >>> hidden_size
-
-                    # Now just handle list structure.
-                    col_emb.append(sentence_emb)
-                    # >>> dataset_length x hidden_size
+                model_inputs = {
+                    k: v.to(self.model.device) for k, v in model_inputs.items()
+                }
+                embeds = self.model(**model_inputs)
+
+                last_idxs = []
+                for i in range(embeds.last_hidden_state.size(0)):
+                    if self.tokenizer.pad_token_id is None:
+                        end_index = -1
+                    else:
+                        end_indexes = (
+                            model_inputs["input_ids"][i] != self.tokenizer.pad_token_id
+                        ).nonzero()
+                        end_index = end_indexes[-1].item() if len(end_indexes) else 0
+
+                    last_idxs.append(end_index)
+
+                embed_last_token = (
+                    embeds.last_hidden_state[list(range(len(last_idxs))), last_idxs]
+                    .cpu()
+                    .tolist()
+                )
+                col_emb.extend(embed_last_token)

             emb_dict[col] = col_emb
             # >>> num_cols x dataset_length x hidden_size
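The core of the second hunk is the move from one forward pass per sentence to a single padded forward pass per batch, after which the loop picks, for each row, the hidden state of its last non-padding token. A self-contained sketch of that indexing step with synthetic tensors (the pad_token_id value and the right-padded batch below are made-up illustrations, not values from the package):

    import torch

    pad_token_id = 0
    # Two right-padded sequences of real lengths 3 and 5 (hypothetical token ids).
    input_ids = torch.tensor([[11, 12, 13, 0, 0],
                              [21, 22, 23, 24, 25]])
    last_hidden_state = torch.randn(2, 5, 8)  # (batch, seq_len, hidden_size)

    last_idxs = []
    for i in range(input_ids.size(0)):
        non_pad = (input_ids[i] != pad_token_id).nonzero()
        last_idxs.append(non_pad[-1].item() if len(non_pad) else 0)

    # Row-wise fancy indexing: one (hidden_size,) vector per sequence.
    embed_last_token = last_hidden_state[list(range(len(last_idxs))), last_idxs]

    print(last_idxs)               # [2, 4]
    print(embed_last_token.shape)  # torch.Size([2, 8])

This assumes right-padding, the default padding side for most Hugging Face tokenizers; with left-padding the final position of every row is already the last real token.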
