diff --git a/src/embed_text_package/embed_text.py b/src/embed_text_package/embed_text.py
index 403d6ee..de0c9c8 100644
--- a/src/embed_text_package/embed_text.py
+++ b/src/embed_text_package/embed_text.py
@@ -40,7 +40,7 @@ def load(self, model_name: str):
         :type which_model: str
         """
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-        self.model = AutoModel.from_pretrained(model_name).to("cuda")
+        self.model = AutoModel.from_pretrained(model_name, device_map="auto")
         self.which_model = model_name
 
     def unload(self):
@@ -84,45 +84,35 @@ def get_embeddings(self, dataloader: DataLoader, model_name: str, cols: list):
             col_emb = []
             tqdm_dataloader = tqdm(dataloader)
             for batch in tqdm_dataloader:
-                tqdm_dataloader.set_description(
-                    f"Embedding sentences in '{col}' on '{self.model.device}'"
+                model_inputs = self.tokenizer(
+                    batch[col],
+                    add_special_tokens=False,
+                    return_tensors="pt",
+                    padding=True,
                 )
-
-                for sentence in batch[col]:
-                    # 1) Get Tokens of sentence
-                    tokenized_sentence = self.tokenizer(
-                        sentence, add_special_tokens=False, return_tensors="pt"
-                    )
-
-                    # 2) Get Embeddings (hiddenstate of last input)
-                    # Generate model inputs on same device as self.model
-                    # att_mask is vector of ones: Attention on all tokens!
-
-                    tokenized_sentence = {
-                        k: v.to(self.model.device)
-                        for k, v in tokenized_sentence.items()
-                    }
-                    # >>> sequence_length
-
-                    # get embedding via forward function of main self.model.
-                    ###########################################################
-                    # NOTE: For performance reasons, one could implement
-                    # self.model.forward in vectorizedmanner.
-                    # If you want to do that, keep padding in mind!
-                    ###########################################################
-                    sentence_emb = (
-                        self.model.forward(**tokenized_sentence)
-                        .last_hidden_state[0][-1]
-                        .squeeze()
-                        .detach()
-                        .cpu()
-                        .tolist()
-                    )
-                    # >>> hidden_size
-
-                    # Now just handle list structure.
-                    col_emb.append(sentence_emb)
-                    # >>> dataset_length x hidden_size
+                model_inputs = {
+                    k: v.to(self.model.device) for k, v in model_inputs.items()
+                }
+                embeds = self.model(**model_inputs)
+
+                last_idxs = []
+                for i in range(embeds.last_hidden_state.size(0)):
+                    if self.tokenizer.pad_token_id is None:
+                        end_index = -1
+                    else:
+                        end_indexes = (
+                            model_inputs["input_ids"][i] != self.tokenizer.pad_token_id
+                        ).nonzero()
+                        end_index = end_indexes[-1].item() if len(end_indexes) else 0
+
+                    last_idxs.append(end_index)
+
+                embed_last_token = (
+                    embeds.last_hidden_state[list(range(len(last_idxs))), last_idxs]
+                    .cpu()
+                    .tolist()
+                )
+                col_emb.extend(embed_last_token)
             emb_dict[col] = col_emb
             # >>> num_cols x dataset_length x hidden_size
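For reference, here is a minimal, self-contained sketch of the pooling step on the `+` side: the patch tokenizes the whole batch with `padding=True`, runs one forward pass, and then picks, per row, the hidden state at the last non-pad position. The toy `input_ids`, `pad_token_id = 0`, and the fake `last_hidden_state` below are made-up illustration values (no real model or tokenizer is loaded), and the index selection assumes right padding, the default for most Hugging Face tokenizers.

```python
import torch

# Toy stand-ins for the patched pooling step; the token IDs and pad_token_id
# are made-up example values, not the output of a real tokenizer.
pad_token_id = 0
input_ids = torch.tensor(
    [
        [101, 2023, 2003, 102, 0],  # 4 real tokens, 1 pad
        [101, 102, 0, 0, 0],        # 2 real tokens, 3 pads
    ]
)
# Pretend last_hidden_state with shape (batch, seq_len, hidden_size) = (2, 5, 3).
last_hidden_state = torch.arange(2 * 5 * 3, dtype=torch.float).reshape(2, 5, 3)

# Same selection as the patch: index of the last non-pad token in each row.
last_idxs = []
for i in range(last_hidden_state.size(0)):
    non_pad = (input_ids[i] != pad_token_id).nonzero()
    last_idxs.append(non_pad[-1].item() if len(non_pad) else 0)

# Advanced indexing picks one hidden vector per row.
embed_last_token = last_hidden_state[list(range(len(last_idxs))), last_idxs]

print(last_idxs)               # [3, 1]
print(embed_last_token.shape)  # torch.Size([2, 3])
```

With right padding, the selected position is exactly where the old per-sentence code's `last_hidden_state[0][-1]` landed, so the batched version reproduces the previous last-token embedding while avoiding one forward pass per sentence; the `end_index = -1` fallback presumably covers tokenizers without a pad token, where no padding can have been applied and the final position is the last real token. (The `device_map="auto"` change in `load` hands device placement to `accelerate` instead of hard-coding `cuda`.)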