v3.0.0 - Sentence Transformer Training Refactor; new similarity methods; hyperparameter optimization; 50+ datasets release
This release consists of a major refactor that overhauls the training approach (introducing multi-gpu training, bf16, loss logging, callbacks, and much more), adds convenient similarity and similarity_pairwise methods, adds extra keyword arguments, introduces Hyperparameter Optimization, and includes a massive reformatting and release of 50+ datasets for training embedding models. In total, this is the largest Sentence Transformers update since the project was first created.
Install this version with:
# Full installation:
pip install sentence-transformers[train]==3.0.0
# Inference only:
pip install sentence-transformers==3.0.0
Sentence Transformer training refactor (#2449)
The v3.0 release centers around this huge modernization of the training approach for SentenceTransformer models. Whereas training before v3.0 used to be all about InputExample, DataLoader and model.fit, the new training approach relies on 5 new components. You can learn more about these components in our Training and Finetuning Embedding Models with Sentence Transformers v3 blog post. Additionally, you can read the new Training Overview, check out the Training Examples, or read this summary:
- Dataset
A training Dataset or DatasetDict. This class is much better suited for sharing and efficient modification than lists or DataLoaders of InputExample instances. A Dataset can contain multiple text columns that will be fed, in order, to the corresponding loss function. So, if the loss expects (anchor, positive, negative) triplets, then your dataset should also have 3 columns. The names of these columns are irrelevant. If there is a "label" or "score" column, it is treated separately and used as the labels during training.
A DatasetDict can be used to train with multiple datasets at once, e.g.:
DatasetDict({
    multi_nli: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 392702
    })
    snli: Dataset({
        features: ['snli_premise', 'hypothesis', 'label'],
        num_rows: 549367
    })
    stsb: Dataset({
        features: ['sentence1', 'sentence2', 'label'],
        num_rows: 5749
    })
})
When a DatasetDict is used, the loss parameter to the SentenceTransformerTrainer must also be a dictionary with these dataset keys, e.g.:
{
    'multi_nli': SoftmaxLoss(...),
    'snli': SoftmaxLoss(...),
    'stsb': CosineSimilarityLoss(...),
}
- Loss Function
A loss function, or a dictionary of loss functions as described above. These loss functions do not require changes compared to before this PR.
- Training Arguments
A SentenceTransformerTrainingArguments instance, a subclass of the transformers TrainingArguments. This powerful class controls the specific details of the training.
- Evaluator
An optional SentenceEvaluator instance. Unlike before, models can now be evaluated on an evaluation dataset with some loss function, with a SentenceEvaluator instance, or both.
- Trainer
The new SentenceTransformerTrainer instance, based on the transformers Trainer. This instance is provided with a SentenceTransformer model, a SentenceTransformerTrainingArguments instance, a SentenceEvaluator, a training and evaluation Dataset/DatasetDict and a loss function/dict of loss functions. Most of these parameters are optional. Once provided, all you have to do is call trainer.train().
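The column rules above are easy to misread, so here is a minimal pure-Python sketch of them. It is not the library's code: split_row and missing_losses are hypothetical helpers, and plain dicts stand in for Dataset rows. It shows that text columns keep their order regardless of their names, that a "label" or "score" column is set aside, and that a loss dictionary must cover every DatasetDict key.

```python
from typing import Any, Optional

# Column names that are set aside as labels instead of being fed to the loss as text.
LABEL_COLUMNS = ("label", "score")


def split_row(row: dict[str, Any]) -> tuple[list[Any], Optional[Any]]:
    """Hypothetical helper: text columns in their original order, plus the label (or None)."""
    texts = [value for name, value in row.items() if name not in LABEL_COLUMNS]
    label = next((row[name] for name in LABEL_COLUMNS if name in row), None)
    return texts, label


def missing_losses(dataset_names: list[str], losses: dict[str, Any]) -> list[str]:
    """Hypothetical check: dataset keys in a DatasetDict that have no loss under the same key."""
    return sorted(set(dataset_names) - set(losses))


# Triplet row: three text columns for an (anchor, positive, negative) loss; names do not matter.
triplet = {"anchor": "A man is eating.", "positive": "A person eats food.", "negative": "A dog is asleep."}
print(split_row(triplet))

# Pair row with a "score" column: two text inputs plus a separate label.
pair = {"sentence1": "A cat sits.", "sentence2": "A cat is sitting.", "score": 0.9}
print(split_row(pair))

# Strings stand in for loss objects; the keys must mirror the DatasetDict keys.
losses = {"multi_nli": "SoftmaxLoss", "snli": "SoftmaxLoss"}
print(missing_losses(["multi_nli", "snli", "stsb"], losses))
```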
Some of the major features that are now implemented include:
- MultiGPU Training (Data Parallelism (DP) and Distributed Data Parallelism (DDP))
- bf16 training support
- Loss logging
- Evaluation datasets + evaluation loss
- Improved callback support (built-in via Weights and Biases, TensorBoard, CodeCarbon, etc., as well as custom callbacks)
- Gradient checkpointing
- Gradient accumulation
- Improved model card generation
- Warmup ratio
- Pushing to the Hugging Face Hub on every model checkpoint
- Resuming from a training checkpoint
- Hyperparameter Optimization
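Several of these features interact through simple arithmetic. In SentenceTransformerTrainingArguments they are set through fields inherited from the transformers TrainingArguments, such as per_device_train_batch_size, gradient_accumulation_steps, warmup_ratio and bf16. The sketch below shows how the per-device batch size, accumulation and number of DDP processes combine into the number of samples per optimizer step, and how a warmup ratio becomes a step count. The round-up in warmup_steps is our assumption about the transformers convention, not something stated in these notes.

```python
import math


def effective_batch_size(per_device_batch_size: int, gradient_accumulation_steps: int, num_devices: int) -> int:
    """Samples contributing to a single optimizer step across all devices."""
    return per_device_batch_size * gradient_accumulation_steps * num_devices


def warmup_steps(num_training_steps: int, warmup_ratio: float) -> int:
    """Learning-rate warmup length derived from a ratio of the total steps (rounded up, assumed)."""
    return math.ceil(num_training_steps * warmup_ratio)


# Hypothetical run: 16 samples per GPU, 4 accumulation steps, 2 GPUs with DDP.
print(effective_batch_size(16, 4, 2))  # 128
print(warmup_steps(1000, 0.1))  # 100
```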
This script is a minimal example (no evaluator, no training arguments) of training mpnet-base on a part of the all-nli dataset using MultipleNegativesRankingLoss:
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer
from sentence_transformers.losses import MultipleNegativesRankingLoss
# 1. Load a model to finetune
model = SentenceTransformer("microsoft/mpnet-base")
# 2. Load a dataset to finetune on
dataset = load_dataset("sentence-transformers/all-nli", "triplet")
train_dataset = dataset["train"].select(range(10_000))
eval_dataset = dataset["dev"].select(range(1_000))
# 3. Define a loss function
loss = MultipleNegativesRankingLoss(model)
# 4. Create a trainer & train
trainer = SentenceTransformerTrainer(
model=model,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
loss=loss,
)
trainer.train()
# 5. Save the trained model
model.save_pretrained("models/mpnet-base-all-nli")
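MultipleNegativesRankingLoss, used in the script above, treats every other positive in the batch, plus any extra negative column such as the third column of the all-nli "triplet" subset, as a negative for each anchor. The pure-Python sketch below illustrates that idea: scaled cosine similarities go into a cross-entropy whose target for anchor i is candidate i. It is a simplified stand-in for the real torch implementation, and the default scale of 20.0 should be treated as an assumption.

```python
import math
from typing import Optional

Vector = list[float]


def cosine(a: Vector, b: Vector) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def mnrl_loss(
    anchors: list[Vector],
    positives: list[Vector],
    negatives: Optional[list[Vector]] = None,
    scale: float = 20.0,
) -> float:
    """Cross-entropy where anchor i must pick positive i out of every candidate in the batch."""
    candidates = positives + (negatives or [])
    total = 0.0
    for i, anchor in enumerate(anchors):
        logits = [scale * cosine(anchor, candidate) for candidate in candidates]
        # Log-sum-exp with max subtraction for numerical stability.
        peak = max(logits)
        log_sum_exp = peak + math.log(sum(math.exp(z - peak) for z in logits))
        total += log_sum_exp - logits[i]
    return total / len(anchors)


# Correctly paired batch: each anchor's own positive is by far the most similar candidate.
aligned = mnrl_loss([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]])
# Mis-paired batch: each anchor's "positive" is actually the other anchor's match.
swapped = mnrl_loss([[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]])
print(f"aligned: {aligned:.6f}, swapped: {swapped:.6f}")
```

A hard negative that is close to the anchor makes the task harder and therefore raises the loss, which is why triplet data with mined negatives tends to be useful for this loss.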
Additionally, trained models now automatically produce extensive model cards. Each of the following models was trained using one of the scripts from the Training Examples, and the model cards were not edited manually whatsoever:
- tomaarsen/mpnet-base-all-nli-triplet
- tomaarsen/stsb-distilbert-base-mnrl-cl-multi
- tomaarsen/distilroberta-base-paraphrases-multi
Prior to the Sentence Transformer v3 release, all models would be trained using the SentenceTransformer.fit method. Rather than deprecating this method, starting from v3.0, this method will use the SentenceTransformerTrainer behind the scenes. This means that your old training code should still work, and should even benefit from new features such as multi-gpu training, loss logging, etc. That said, the new training approach is much more powerful, so it is recommended to write new training scripts using the new approach.
Many of the old training scripts were updated to use the new Trainer-based approach, but not all have been updated yet. We accept help via Pull Requests to assist in updating the scripts.
Similarity Score (#2615, #2490)
Sentence Transformers v3.0 introduces two new useful methods, similarity and similarity_pairwise, and one property, similarity_fn_name. These can be used to calculate the similarity between embeddings and to specify which similarity function should be used, for example:
>>> from sentence_transformers import SentenceTransformer
>>> model = SentenceTransformer("all-mpnet-base-v2")
>>> sentences = [
... "The weather is so nice!",
... "It's so sunny outside.",
... "He's driving to the movie theater.",
... "She's going to the cinema.",
... ]
>>> embeddings = model.encode(sentences, normalize_embeddings=True)
>>> model.similarity(embeddings, embeddings)
tensor([[1.0000, 0.7235, 0.0290, 0.1309],
[0.7235, 1.0000, 0.0613, 0.1129],
[0.0290, 0.0613, 1.0000, 0.5027],
[0.1309, 0.1129, 0.5027, 1.0000]])
>>> model.similarity_fn_name
'cosine'
>>> model.similarity_fn_name = "euclidean"
>>> model.similarity(embeddings, embeddings)
tensor([[-0.0000, -0.7437, -1.3935, -1.3184],
[-0.7437, -0.0000, -1.3702, -1.3320],
[-1.3935, -1.3702, -0.0000, -0.9973],
[-1.3184, -1.3320, -0.9973, -0.0000]])
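To see what each similarity_fn_name option computes, here is a pure-Python sketch of the four scores, applied to every pair as in the matrices above. Judging by the -0.0000 diagonal, we assume that the distance-based options return negated distances so that a larger score always means "more similar". The sketch also checks a handy identity: for normalized embeddings, Euclidean distance equals sqrt(2 - 2 * cos). For example, sqrt(2 - 2 * 0.7235) ≈ 0.7437, which matches the first rows of both matrices above. SIMILARITY_FNS and similarity_matrix are hypothetical names, not the library's internals.

```python
import math

Vector = list[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


SIMILARITY_FNS = {
    "cosine": lambda a, b: dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b))),
    "dot": dot,
    # Distances are negated so that, for every option, a larger score means "more similar".
    "euclidean": lambda a, b: -math.dist(a, b),
    "manhattan": lambda a, b: -sum(abs(x - y) for x, y in zip(a, b)),
}


def similarity_matrix(xs: list[Vector], ys: list[Vector], fn_name: str = "cosine") -> list[list[float]]:
    """Score every x against every y, giving a len(xs) x len(ys) matrix."""
    fn = SIMILARITY_FNS[fn_name]
    return [[fn(x, y) for y in ys] for x in xs]


# Two unit-length vectors: for normalized embeddings, ||a - b|| = sqrt(2 - 2 * cos(a, b)).
u, v = [0.6, 0.8], [1.0, 0.0]
cos_uv = similarity_matrix([u], [v], "cosine")[0][0]
euc_uv = similarity_matrix([u], [v], "euclidean")[0][0]
print(cos_uv, euc_uv, -math.sqrt(2 - 2 * cos_uv))
```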
Additionally, you can compute the similarity between pairs of embeddings, resulting in a 1-dimensional vector of similarities rather than a 2-dimensional matrix:
>>> model = SentenceTransformer("all-mpnet-base-v2")
>>> sentences = [
... "The weather is so nice!",
... "It's so sunny outside.",
... "He's driving to the movie theater.",
... "She's going to the cinema.",
... ]
>>> embeddings = model.encode(sentences, normalize_embeddings=True)
>>> model.similarity_pairwise(embeddings[::2], embeddings[1::2])
tensor([0.7235, 0.5027])
>>> model.similarity_fn_name
'cosine'
>>> model.similarity_fn_name = "euclidean"
>>> model.similarity_pairwise(embeddings[::2], embeddings[1::2])
tensor([-0.7437, -0.9973])
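The relationship between the two methods can be sketched in pure Python: the pairwise result is the diagonal of the full matrix, computed without the off-diagonal work (n scores instead of n squared). The functions below are hypothetical stand-ins that only mirror the method names and support cosine only. The toy embeddings are made up, and the slicing mirrors embeddings[::2] and embeddings[1::2] from the example.

```python
import math

Vector = list[float]


def cosine(a: Vector, b: Vector) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def similarity_matrix(xs: list[Vector], ys: list[Vector]) -> list[list[float]]:
    """len(xs) x len(ys) scores: every x against every y."""
    return [[cosine(x, y) for y in ys] for x in xs]


def similarity_pairwise(xs: list[Vector], ys: list[Vector]) -> list[float]:
    """len(xs) scores: xs[i] against ys[i] only."""
    if len(xs) != len(ys):
        raise ValueError("pairwise similarity needs equally many embeddings on both sides")
    return [cosine(x, y) for x, y in zip(xs, ys)]


# Toy stand-ins for four sentence embeddings, paired up as (0, 1) and (2, 3).
embeddings = [[1.0, 0.2], [0.9, 0.3], [0.1, 1.0], [0.2, 0.8]]
firsts, seconds = embeddings[::2], embeddings[1::2]
print(similarity_pairwise(firsts, seconds))
```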
The similarity_fn_name can now be specified via the SentenceTransformer constructor, like so:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer("sentence-transformers/multi-qa-mpnet-base-dot-v1", similarity_fn_name="dot")
Valid options include "cosine" (default), "dot", "euclidean" and "manhattan". The chosen similarity_fn_name will also be saved into the model configuration and loaded automatically. For example, the msmarco-distilbert-dot-v5 model was trained to work best with dot, so we've configured it to use that similarity_fn_name in its configuration:
>>> from sentence_transformers import SentenceTransformer
>>> model = SentenceTransformer("sentence-transformers/msmarco-distilbert-dot-v5")
>>> model.similarity_fn_name
'dot'
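Why store the function with the model? With unnormalized embeddings, dot product and cosine can rank the same documents differently, because dot product also rewards vector length while cosine ignores it. The toy illustration below uses made-up 2-D vectors, not real model outputs, and a hypothetical rank helper.

```python
import math
from typing import Callable

Vector = list[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def cosine(a: Vector, b: Vector) -> float:
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))


def rank(query: Vector, docs: dict[str, Vector], fn: Callable[[Vector, Vector], float]) -> list[str]:
    """Document names sorted from most to least similar under the given function."""
    return sorted(docs, key=lambda name: fn(query, docs[name]), reverse=True)


query = [1.0, 0.0]
docs = {
    "short_but_aligned": [0.9, 0.1],  # small norm, nearly the same direction as the query
    "long_but_off_axis": [3.0, 2.0],  # large norm, less aligned direction
}
print("dot:   ", rank(query, docs, dot))
print("cosine:", rank(query, docs, cosine))
```

A model trained with dot may therefore rely on vector length, which cosine discards; loading the function from the configuration avoids silently switching between the two.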
Big thanks to @ir2718 for helping set up this major feature.
Allow passing model_kwargs, tokenizer_kwargs, and config_kwargs to SentenceTransformer (#2578)
Those familiar with the internals of Sentence Transformers might know that internally, we call AutoModel.from_pretrained, AutoTokenizer.from_pretrained and AutoConfig.from_pretrained from transformers.
Each of these is rather powerful, and they are constantly improved with new features. For example, the AutoModel keyword arguments include:
- torch_dtype - this allows you to immediately load a model in bfloat16 or float16 (or "auto", i.e. whatever the model was stored in), which can speed up inference a lot.
- quantization_config
- attn_implementation - all models support "eager", but some also support the much faster "flash_attention_2" (Flash Attention 2) and "sdpa" (Scaled Dot Product Attention).
These options allow for speeding up model inference. Additionally, via AutoConfig you can update the model configuration, e.g. updating the dropout probability during training, and with AutoTokenizer you can disable the fast Rust-based tokenizer via use_fast=False if you're having issues with it.
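One concrete effect of loading with torch_dtype set to bfloat16 or float16 is memory: each weight takes 2 bytes instead of the 4 bytes used by float32, so the weights alone need half the memory. The arithmetic is sketched below with a hypothetical parameter count chosen only for illustration. Activations and other runtime overhead are not included.

```python
# Bytes used to store a single parameter in each dtype.
BYTES_PER_PARAM = {"float32": 4, "float16": 2, "bfloat16": 2}


def weight_memory_gib(num_params: int, dtype: str) -> float:
    """Memory taken by the weights alone (excludes activations, buffers, and framework overhead)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1024**3


num_params = 335_000_000  # hypothetical model size, for illustration only
for dtype in ("float32", "bfloat16"):
    print(f"{dtype}: {weight_memory_gib(num_params, dtype):.2f} GiB")
```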
Due to how useful these options can be, the following arguments are added to SentenceTransformer:
- model_kwargs for AutoModel.from_pretrained keyword arguments
- tokenizer_kwargs for AutoTokenizer.from_pretrained keyword arguments
- config_kwargs for AutoConfig.from_pretrained keyword arguments
You can use it like so:
from sentence_transformers import SentenceTransformer
import torch
model = SentenceTransformer(
"mixedbread-ai/mxbai-embed-large-v1",
model_kwargs={"torch_dtype": torch.bfloat16, "attn_implementation": "sdpa"},
config_kwargs={"hidden_dropout_prob": 0.3},
)
embeddings = model.encode(["He drove his yellow car to the beach.", "He played football with his friends."])
print(embeddings.shape)
Big thanks to @satyamk7054 for starting this work.
Hyperparameter Optimization (#2655)
Sentence Transformers v3.0 introduces Hyperparameter Optimization (HPO) by extending the transformers HPO support. We recommend reading the all-new Hyperparameter Optimization documentation for more details.
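The changelog entry for #2655 names the key design choice: the loss can be a Callable that accepts a model. During HPO, each trial starts from a freshly initialized model, so a loss object built around an earlier model would be stale. The toy random search below is a pure-Python sketch of that pattern with hypothetical ToyModel and ToyLoss classes and a made-up objective. In practice, the search itself is run through the transformers Trainer.hyperparameter_search method, which delegates to a backend such as Optuna.

```python
import math
import random
from typing import Callable, Optional


class ToyModel:
    """Hypothetical stand-in for a freshly initialized SentenceTransformer."""

    def __init__(self) -> None:
        self.learning_rate: Optional[float] = None


class ToyLoss:
    """Hypothetical loss that, like the real losses, keeps a reference to the model it trains."""

    def __init__(self, model: ToyModel) -> None:
        self.model = model


def train_and_evaluate(model: ToyModel, loss: ToyLoss, learning_rate: float) -> float:
    """Toy objective (lower is better) with its optimum at learning_rate = 3e-5."""
    if loss.model is not model:
        raise RuntimeError("loss is wrapping a stale model from another trial")
    model.learning_rate = learning_rate
    return (math.log10(learning_rate) - math.log10(3e-5)) ** 2


def random_search(
    model_init: Callable[[], ToyModel],
    loss_init: Callable[[ToyModel], ToyLoss],
    n_trials: int,
    seed: int = 0,
) -> tuple[float, float]:
    rng = random.Random(seed)
    best_lr, best_value = math.nan, math.inf
    for _ in range(n_trials):
        learning_rate = 10 ** rng.uniform(-6, -4)  # log-uniform search space
        model = model_init()  # every trial starts from a fresh model...
        loss = loss_init(model)  # ...so the loss must be rebuilt around it
        value = train_and_evaluate(model, loss, learning_rate)
        if value < best_value:
            best_lr, best_value = learning_rate, value
    return best_lr, best_value


best_lr, best_value = random_search(ToyModel, ToyLoss, n_trials=50)
print(f"best learning rate: {best_lr:.2e} (objective {best_value:.4f})")
```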
Datasets Release
Alongside Sentence Transformers v3.0, we reformat and release 50+ useful datasets in our Embedding Model Datasets Collection on Hugging Face. These can be used with at least one loss function in Sentence Transformers v3.0 out of the box. We recommend browsing through these to see if there are datasets akin to your use cases - training a model on them might just produce large gains on your task(s).
MSELoss extension (#2641)
The MSELoss now accepts multiple text columns for each label (where each label is a target/gold embedding), rather than only accepting one text column. This is extremely powerful for following the excellent Multilingual Models strategy to convert a monolingual model into a multilingual one. You can now conveniently train both English and (identical but translated) non-English texts to represent the same embedding (that was generated by a powerful English embedding model).
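A pure-Python sketch of the multi-column objective follows. It assumes, as with the default mean reduction of torch's MSE loss, that the gold embedding is reused for every text column and that the squared errors are averaged over all of them. The function name is hypothetical, and this is not the library's implementation.

```python
Vector = list[float]


def multi_column_mse(columns: list[list[Vector]], targets: list[Vector]) -> float:
    """Mean squared error between every student embedding and its row's teacher embedding.

    columns[c][i] is the student embedding of row i's text in column c (e.g. English, German);
    targets[i] is the teacher (gold) embedding for row i, shared by all columns.
    """
    squared_errors = []
    for column in columns:
        if len(column) != len(targets):
            raise ValueError("each column needs one embedding per target")
        for student, teacher in zip(column, targets):
            squared_errors.extend((s - t) ** 2 for s, t in zip(student, teacher))
    return sum(squared_errors) / len(squared_errors)


teacher = [[1.0, 0.0]]  # teacher embedding of the English sentence
english_student = [[1.0, 0.0]]  # student already matches the teacher on English
german_student = [[0.0, 0.0]]  # student still far off on the German translation
print(multi_column_mse([english_student, german_student], teacher))  # 0.25
```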
Add local_files_only argument to SentenceTransformer & CrossEncoder (#2603)
You can now initialize a SentenceTransformer and CrossEncoder with local_files_only. If True, it will not try to download a model from Hugging Face; it will only look in the local filesystem for the model or try to load it from a cache.
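In user code this is just SentenceTransformer(..., local_files_only=True). The lookup order it implies can be sketched as follows. This is a pure-Python illustration with a hypothetical resolve_model_dir function and a made-up flat cache layout; the real Hugging Face cache is structured differently and is managed by huggingface_hub.

```python
import tempfile
from pathlib import Path


def resolve_model_dir(name_or_path: str, cache_dir: Path, local_files_only: bool) -> Path:
    """Hypothetical lookup order: explicit local path, then cache, then (if allowed) download."""
    local = Path(name_or_path)
    if local.is_dir():
        return local
    # Hypothetical flat cache layout: "org/model" is stored as "org--model".
    cached = cache_dir / name_or_path.replace("/", "--")
    if cached.is_dir():
        return cached
    if local_files_only:
        raise FileNotFoundError(f"{name_or_path!r} not found locally and downloads are disabled")
    raise NotImplementedError("downloading is out of scope for this sketch")


with tempfile.TemporaryDirectory() as tmp:
    cache = Path(tmp)
    (cache / "my-org--my-model").mkdir()
    found = resolve_model_dir("my-org/my-model", cache, local_files_only=True)
    print(found.name)  # my-org--my-model
    try:
        resolve_model_dir("my-org/missing-model", cache, local_files_only=True)
        offline_error = None
    except FileNotFoundError as error:
        offline_error = error
    print(type(offline_error).__name__)
```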
Thanks @debanjum for this change.
All changes
- Minor grammar fix in GPL paragraph by @mauricesvp in #2604
- [feat] Add local_files_only argument to load model from cache by @debanjum in #2603
- Fix broken links by @mauricesvp in #2611
- Updated urls for msmarco dataset by @j-dominguez9 in #2609
- [v3] Training refactor - MultiGPU, loss logging, bf16, etc. by @tomaarsen in #2449
- [v3] Add similarity and similarity_pairwise methods to Sentence Transformers by @tomaarsen in #2615
- [v3] Fix various model card errors by @tomaarsen in #2616
- [v3] Fix trainer compute_loss when evaluating/predicting if the loss updated the inputs in-place by @tomaarsen in #2617
- [v3] Never return None in infer_datasets, could result in crash by @tomaarsen in #2620
- [v3] Trainer: Implement resume from checkpoint support by @tomaarsen in #2621
- Fall back to CPU device in case there are no PyTorch parameters by @maxfriedrich in #2614
- Add trust_remote_code to CrossEncoder.tokenizer by @michaelfeil in #2623
- [v3] Update example scripts to the new v3 training format by @tomaarsen in #2622
- Bug in DenoisingAutoEncoderLoss.py by @arun477 in #2619
- [v3] Remove "return_outputs" as it's not strictly necessary. Avoids OOM & speeds up training by @tomaarsen in #2633
- [v3] Fix crash from inferring the dataset_id from a local dataset by @tomaarsen in #2636
- Enable Sentence Transformer Inference with Intel Gaudi2 GPU Supported ( 'hpu' ) - Follow up for #2557 by @ZhengHongming888 in #2630
- [v3] Fix multilingual conversion script; extend MSELoss to multi-column by @tomaarsen in #2641
- [v3] Update evaluation scripts to use HF Datasets by @tomaarsen in #2642
- Use b1 quantization for USearch by @ashvardanian in #2644
- [v3] Fix resume_from_checkpoint by also updating the loss model by @tomaarsen in #2648
- [v3] Fix backwards pass on MSELoss due to in-place update by @tomaarsen in #2647
- [v3] Simplify load_from_checkpoint using load_state_dict by @tomaarsen in #2650
- [v3] Use torch.arange instead of torch.tensor(range(...)) by @tomaarsen in #2651
- [v3] Resolve inplace modification error in DDP by @tomaarsen in #2654
- [v3] Add hyperparameter optimization support by letting loss be a Callable that accepts a model by @tomaarsen in #2655
- [v3] Add tag hinting at the number of training samples by @tomaarsen in #2660
- Allow passing 'precision' when using 'encode_multi_process' to SentenceTransformer by @ariel-talent-fabric in #2659
- Allow passing model_args to ST by @satyamk7054 in #2578
- Fix smart_batching_collate Inefficiency by @PrithivirajDamodaran in #2556
- [v3] For the Cached losses; ignore gradients if grad is disabled (e.g. eval) by @tomaarsen in #2668
- [docs] Rewrite the https://sbert.net documentation by @tomaarsen in #2632
- [v3] Chore - include import sorting in ruff by @tomaarsen in #2672
- [v3] Prevent warning with 'model.fit' with transformers >= 4.41.0 due to evaluation_strategy by @tomaarsen in #2673
- [v3] Add various useful Sphinx packages (copy code, link to code, nicer tabs) by @tomaarsen in #2674
- [v3] Make the "primary_metric" for evaluators a bit more robust by @tomaarsen in #2675
- [v3] Set broadcast_buffers = False when training with DDP by @tomaarsen in #2663
- [v3] Warn about using DP instead of DDP + set dataloader_drop_last with DDP by @tomaarsen in #2677
- [v3] Add warning that Evaluators only run on 1 GPU when multi-GPU training by @tomaarsen in #2678
- [v3] Move training dependencies into a "train" extra by @tomaarsen in #2676
- [v3] Docs: update references to the API reference by @tomaarsen in #2679
- [v3] Add "dataset_size:" to the tag denoting the number of training samples by @tomaarsen in #2680
New Contributors
- @mauricesvp made their first contribution in #2604
- @debanjum made their first contribution in #2603
- @j-dominguez9 made their first contribution in #2609
- @michaelfeil made their first contribution in #2623
- @arun477 made their first contribution in #2619
- @ashvardanian made their first contribution in #2644
- @ariel-talent-fabric made their first contribution in #2659
- @satyamk7054 made their first contribution in #2578
- @PrithivirajDamodaran made their first contribution in #2556
A special shoutout to @Jakobhenningjensen, @smerrill, @b5y, @ScottishFold007, @pszemraj, @bwanglzu and @igorkurinnyi for experimenting with the v3.0 release before it was published, and to @matthewfranglen for the initial work on the training refactor back in October of 2022 in #1733.
cc @AlexJonesNLP as I know you are interested in this release!
Full Changelog: v2.7.0...v3.0.0