Train and embed with pretokenized input? #2935

Open
shijy16 opened this issue Sep 13, 2024 · 3 comments


shijy16 commented Sep 13, 2024

Hi, I'm new to NLP and am currently trying to fine-tune jina for text similarity comparison.
I construct a dataset with columns sentence1, sentence2, and score, and I can easily train the model with SentenceTransformerTrainer and CoSENTLoss.
But I found the tokenization process time-consuming, because there are many duplicate sentences across the pairs. For example, given the two pairs [A, a, 1.0] and [A, B, 0.0], my code needs to tokenize A twice.
I've found the following code for embedding with tokenized inputs:

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel

def mean_pooling(model_output, attention_mask):
    # Average the token embeddings, ignoring padding positions via the attention mask
    token_embeddings = model_output[0]
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

sentences = [
    'How do I access the index while iterating over a sequence with a for loop?',
    '# Use the built-in enumerator\nfor idx, x in enumerate(xs):\n    print(idx, x)',
]

tokenizer = AutoTokenizer.from_pretrained('jinaai/jina-embeddings-v2-base-code')
model = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-code', trust_remote_code=True)

encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')

with torch.no_grad():
    model_output = model(**encoded_input)

embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
embeddings = F.normalize(embeddings, p=2, dim=1)
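
Since the embeddings are L2-normalized, the cosine similarity used for comparing the two sentences is just the dot product of their embedding vectors:

# Cosine similarity between the two sentences above
# (the embeddings are already L2-normalized, so a plain dot product suffices)
similarity = (embeddings[0] * embeddings[1]).sum()
print(similarity.item())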

But I am still not sure how to train with pre-tokenized inputs. I am not sure if this is the best place to ask my question; any help would be appreciated!


ir2718 commented Sep 14, 2024

Hi,

I think this would do the trick:

from transformers import AutoTokenizer
from datasets import Dataset
from sentence_transformers.models import Transformer, Pooling
from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainingArguments,
    SentenceTransformerTrainer,
)
import torch
from torch.nn.utils.rnn import pad_sequence

model_name = "bert-base-cased"

tokenizer = AutoTokenizer.from_pretrained(model_name)

sents = [
    ("Extremely Long Question zero" * 400, "Zeroth answer", 1),
    ("First question is longer so it requires padding", "First answer", 1),
    ("Second question", "Second answer", 1),
    ("Third question", "Third answer", 1),
    ("Fourth question", "Fourth answer", 1),
    ("Fifth question", "Fifth answer", 1),
    ("Sixth question", "Sixth answer", 1),
]

pretokenized = {
    "sentence1": [],
    "sentence2": [],
    "score": [],
}
for q, a, s in sents:
    tokd_q = tokenizer(q)
    tokd_a = tokenizer(a)
    pretokenized["sentence1"].append(tokd_q)
    pretokenized["sentence2"].append(tokd_a)
    pretokenized["score"].append(s)

train_dataset = Dataset.from_dict(pretokenized)

#################################################

class HomemadeTransformer(Transformer):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.pad_dict = {
            "token_type_ids": 0,
            "attention_mask": 0,
            "input_ids": tokenizer.pad_token_id
        }

    def tokenize(
        self, texts, padding=True
    ):
        # `texts` is a list of pre-tokenized dicts; pad them into batch tensors
        keys = texts[0].keys()
        tok_dict = {
            k: [] for k in set(keys)
        }
        for k in keys:
            for v in texts:
                tok_dict[k].append(torch.tensor(v[k]))

        for k in keys:
            tok_dict[k] = pad_sequence(
                tok_dict[k], 
                batch_first=True,
                padding_value=self.pad_dict[k]
            )
            
        # now we do truncation:
        # -> some input may reach or exceed the model's max length
        #
        #     1. if a sequence is exactly the max length, nothing changes
        #     2. if it's longer, the token at the truncation boundary is a
        #        regular token and should be replaced with [SEP]
        #
        max_len = min(tok_dict["input_ids"].size(-1), tokenizer.model_max_length)
        if max_len == tokenizer.model_max_length:
            for i in range(tok_dict["input_ids"].size(0)):
                if tok_dict["input_ids"][i][-1] != tokenizer.pad_token_id:
                    tok_dict["input_ids"][i][tokenizer.model_max_length-1] = tokenizer.sep_token_id

            tok_dict["input_ids"] = tok_dict["input_ids"][:, :tokenizer.model_max_length]
            tok_dict["token_type_ids"] = tok_dict["token_type_ids"][:, :tokenizer.model_max_length]
            tok_dict["attention_mask"] = tok_dict["attention_mask"][:, :tokenizer.model_max_length]

        return tok_dict
    
transformer = HomemadeTransformer(model_name)
pooling = Pooling(
    word_embedding_dimension=transformer.get_word_embedding_dimension(),
    pooling_mode_cls_token=False,
    pooling_mode_max_tokens=False,
    pooling_mode_mean_tokens=True,
)
st = SentenceTransformer(
    modules=[transformer, pooling],
)
##################################################

loss = MultipleNegativesRankingLoss(st)

args = SentenceTransformerTrainingArguments(
    output_dir=f"train_dir_{model_name}",
    num_train_epochs=200,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    logging_steps=5,
)
trainer = SentenceTransformerTrainer(
    model=st,
    args=args,
    train_dataset=train_dataset,
    loss=loss,
)
trainer.train()

Unfortunately, the _create_model_card function raises a ValueError that I wasn't able to circumvent, but training works as expected.
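
As a side note, since the original concern was tokenizing the same sentence twice, the pretokenization loop above can reuse a small cache so that each unique sentence is only tokenized once (just a sketch, using the same tokenizer and sents as above):

pretokenized = {"sentence1": [], "sentence2": [], "score": []}
cache = {}

def tokenize_cached(sentence):
    # Tokenize each unique sentence only once and reuse the cached result
    if sentence not in cache:
        cache[sentence] = tokenizer(sentence)
    return cache[sentence]

for q, a, s in sents:
    pretokenized["sentence1"].append(tokenize_cached(q))
    pretokenized["sentence2"].append(tokenize_cached(a))
    pretokenized["score"].append(s)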

Hope this helps.


shijy16 commented Sep 18, 2024

@ir2718
Thank you for your help! I tried to use your HomemadeTransformer to load my model, but there were some warnings when loading the model with it, such as some BERT parameters not being initialized. I am not sure whether this affects the training process.
So I ended up spending a few hours learning how to write my own PyTorch training code. The new training code gets rid of SentenceTransformer entirely, and the training process now seems okay.
Thanks again for your code! I've learned a lot about the usage of Transformer from it.
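
For reference, a stripped-down sketch of what such a plain PyTorch training loop can look like (not my exact code; it reuses the mean pooling from my first snippet, pads the pre-tokenized pairs per batch, and uses a simple cosine-similarity MSE objective instead of CoSENTLoss, with toy data):

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v2-base-code")
model = AutoModel.from_pretrained("jinaai/jina-embeddings-v2-base-code", trust_remote_code=True).to(device)

# Toy (sentence1, sentence2, score) pairs; pre-tokenize each sentence once up front
raw_pairs = [
    ("How do I access the index in a for loop?", "for i, x in enumerate(xs): print(i, x)", 1.0),
    ("How do I access the index in a for loop?", "import numpy as np", 0.0),
]
pairs = [(tokenizer(s1, truncation=True), tokenizer(s2, truncation=True), score)
         for s1, s2, score in raw_pairs]

def collate(batch):
    # Pad the pre-tokenized inputs per batch instead of re-tokenizing
    enc1, enc2, scores = zip(*batch)
    return (tokenizer.pad(list(enc1), return_tensors="pt"),
            tokenizer.pad(list(enc2), return_tensors="pt"),
            torch.tensor(scores, dtype=torch.float))

def embed(enc):
    # Mean pooling over non-padding tokens, then L2 normalization
    enc = {k: v.to(device) for k, v in enc.items()}
    out = model(**enc)
    mask = enc["attention_mask"].unsqueeze(-1).float()
    emb = (out[0] * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
    return F.normalize(emb, p=2, dim=1)

loader = DataLoader(pairs, batch_size=2, shuffle=True, collate_fn=collate)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

model.train()
for epoch in range(3):
    for enc1, enc2, scores in loader:
        sim = (embed(enc1) * embed(enc2)).sum(dim=1)   # cosine similarity
        loss = F.mse_loss(sim, scores.to(device))      # regress toward the gold score
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()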


ir2718 commented Sep 18, 2024

some BERT parameters not being initialized

I wasn't getting that warning, but I'm glad my comment helped in solving your problem.
