Source code for the research paper:
Task-level Distributionally Robust Optimization for Large Language Model-based Dense Retrieval, Guangyuan Ma, Yongliang Ma, Xing Wu, Zhenpeng Su, Ming Zhou and Songlin Hu.
This paper proposes a new task-level Distributionally Robust Optimization (tDRO) algorithm for Large Language Model-based Dense Retrieval (LLM-DR) fine-tuning, targeted at improving the universal domain generalization ability by end-to-end reweighting the data distribution of each task.
Please install Faiss-GPU by following their guidelines. Then set up the environment by cloning this repo and running the following command.
pip install -e .
Base Model | Uniform Sampling Baselines | tDRO: Dataset Selection Top-70% | tDRO: Sample Ratio Reweighting |
---|---|---|---|
Qwen1.5-0.5B | s0-baseline-Qwen1.5-0.5B | s2-tdro-Qwen1.5-0.5B-top70 | s2-tdro-Qwen1.5-0.5B-curr |
Qwen1.5-1.8B | s0-baseline-Qwen1.5-1.8B | s2-tdro-Qwen1.5-1.8B-top70 | s2-tdro-Qwen1.5-1.8B-curr |
Qwen1.5-4B | s0-baseline-Qwen1.5-4B | s2-tdro-Qwen1.5-4B-top70 | s2-tdro-Qwen1.5-4B-curr |
Qwen1.5-7B | s0-baseline-Qwen1.5-7B | s2-tdro-Qwen1.5-7B-top70 | s2-tdro-Qwen1.5-7B-curr |
Mistral-7B-v0.1 | s0-baseline-Mistral-7B-v0.1 | s2-tdro-Mistral-7B-v0.1-top70 | s2-tdro-Mistral-7B-v0.1-curr |
Llama-3-8B | s0-baseline-Llama-3-8B | s2-tdro-Llama-3-8B-top70 | s2-tdro-Llama-3-8B-curr |
Dataset Release: tdro-llm/finetune_data
A total of 25 heterogeneous retrieval fine-tuning datasets with Hard Negatives and Deduplication (with test sets) are used as the fine-tuning collections of our experiments. Please refer to the above Dataset Cards for details.
- Dataset sources: All datasets are built from open-source retrieval fine-tuning collections. Most of them (except several multilingual or Chinese datasets) originate from Sentence Transformers Training Data. Please find their references in the Reference column of the Dataset Cards.
- Language: 21 datasets are monolingual English datasets, 2 datasets (DuReader and T2Ranking) are monolingual Chinese datasets, and 2 datasets (MIRACL and Mr.Tydi) are multilingual datasets.
- Category and Symmetry: To ensure diversity across the heterogeneous collections, the fine-tuning data covers 13 categories and 2 symmetry types.
- Format: The format and source of the training triples are also listed in the Dataset Cards. They follow the basic format of (Query, Positive, Negatives).
- HN Mine: All datasets have been processed with Hard Negative (HN) mining. For the 4 multilingual or Chinese datasets (MIRACL, Mr.Tydi, DuReader and T2Ranking), we directly use the originally provided hard negatives. For the MS-MARCO Passage Ranking, NQ and Trivia datasets, we follow the data preparation scripts provided with bowdpr. For the AllNLI and Quora duplicates triplets, we directly use the negatives from Sentence Transformers Training Data. For the remaining monolingual English datasets, we use the bge-base-en-v1.5 retriever to mine hard negatives. Please follow data/HN_mine.md to reproduce our HN mining pipeline.
- Deduplication: To avoid test label leakage into the training collections, we deduplicate all training datasets with SimHash. Please refer to data/inspect_duplicates.py for the detailed deduplication implementation.
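As a rough illustration of the deduplication idea above, the following is a minimal, standard-library sketch of SimHash-based near-duplicate filtering. It is not the code in data/inspect_duplicates.py. The function names, the word-level tokenization, the MD5-based token hash, the choice to compare queries, and the `max_distance` threshold are all assumptions made for this example.

```python
import hashlib
import re

BITS = 64  # fingerprint width used in this sketch (assumption)


def simhash(text: str) -> int:
    """Compute a 64-bit SimHash fingerprint from lowercase word tokens."""
    counts = [0] * BITS
    for token in re.findall(r"\w+", text.lower()):
        digest = hashlib.md5(token.encode("utf-8")).digest()
        h = int.from_bytes(digest[:8], "big")  # first 64 bits of the hash
        for i in range(BITS):
            counts[i] += 1 if (h >> i) & 1 else -1
    fingerprint = 0
    for i, c in enumerate(counts):
        if c > 0:
            fingerprint |= 1 << i
    return fingerprint


def hamming_distance(a: int, b: int) -> int:
    """Number of differing bits between two fingerprints."""
    return bin(a ^ b).count("1")


def drop_test_duplicates(train_texts, test_texts, max_distance=3):
    """Keep only training texts that are not near-duplicates of any test text.

    max_distance is a hypothetical threshold, not a value from the repo.
    """
    test_prints = [simhash(t) for t in test_texts]
    kept = []
    for text in train_texts:
        fp = simhash(text)
        if all(hamming_distance(fp, t) > max_distance for t in test_prints):
            kept.append(text)
    return kept


train = ["what is the capital of france", "how do transformers work"]
test = ["What is the capital of France?"]
filtered = drop_test_duplicates(train, test)
```

Texts that differ only in casing or punctuation map to the same fingerprint here, so they are removed from the training side. The pairwise comparison is quadratic and only suitable for small collections. A production pipeline would typically index fingerprints by bit bands instead.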
The whole training procedure for tDRO involves 3 stages.
Script: s0_train_baseline_model.sh
First, train a baseline model (which also serves as the reference model) with uniform weight sampling. Please refer to the above script to reproduce this stage. By default, the model is saved in results/s0_train_baseline_model.
Note: Our LLM-based retrievers rely on last token pooling, which requires adding a </eos> at the end of tokenized texts. However, the Qwen1.5 and LLaMA3 base models can NOT add a </eos> correctly when tokenizer.add_eos_token==True.
We have modified the post_processor in the corresponding tokenizer.json to support adding </eos>. Please use the modified tokenizer.json from the following files:
- Qwen1.5:
scripts/qwen1.5-tokenizer.json
- LLaMA3:
scripts/llama3-8b-tokenizer.json
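To show why the trailing </eos> matters for last token pooling, here is a small pure-Python sketch. It is illustrative only and is not the repository's model code: real implementations operate on tensors, and the function names, the right-padding assumption, and the token ids below are all hypothetical.

```python
def append_eos(token_ids, eos_id):
    """Return token_ids with eos_id appended, unless it is already last."""
    if token_ids and token_ids[-1] == eos_id:
        return list(token_ids)
    return list(token_ids) + [eos_id]


def last_token_pool(hidden_states, attention_mask):
    """Pick the hidden state of the last real token in each sequence.

    hidden_states: one list of per-token vectors per sequence.
    attention_mask: one 0/1 list per sequence, 1 marks a real token.
    Right padding is assumed, so the last real token sits at sum(mask) - 1.
    """
    pooled = []
    for states, mask in zip(hidden_states, attention_mask):
        length = sum(mask)
        if length == 0:
            raise ValueError("sequence has no real tokens")
        pooled.append(states[length - 1])
    return pooled


# Hypothetical ids: 2 stands for the end-of-sequence token.
ids = append_eos([5, 6], eos_id=2)

# Two sequences of three positions each; the second has one padding slot.
hidden = [[[1.0], [2.0], [3.0]], [[4.0], [5.0], [0.0]]]
mask = [[1, 1, 1], [1, 1, 0]]
embeddings = last_token_pool(hidden, mask)
```

With the end-of-sequence token appended, the pooled vector always comes from the same special position. Without it, pooling would read the hidden state of whatever content token happens to end the text.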
Script: s1_tdro.sh
Task-level Distributionally Robust Optimization (tDRO) optimizes over heterogeneous training collections to find robust weights for contrastively fine-tuning a retriever. This stage requires a proxy model (which interleaves its own updates with weight updates) and a trained reference model (the same size as the proxy model). After tDRO optimization, several robust optimized weights are saved in the model output directory results/s1_tdro by default:
- curr_weights.json: Corresponding to tDRO: Sample Ratio Reweighting in the paper. The final weights at the last step of tDRO.
- topxx_weights.json: Corresponding to tDRO: Dataset Selection Top-xx% in the paper. The top xx% of tasks with uniform sampling weights. In our paper, we take top70_weights.json for optimal performance.
Additionally, the averaged weights mean_weights.json and the EMA-averaged weights ema_weights.json over all tDRO steps are also provided for reference. These two sets of weights are not used in our experiments.
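The relationship between the final weights and a Top-xx% selection can be sketched as follows. This is a minimal illustration, not the repository's code. It assumes that a weight file maps task names to floats, that the number of kept tasks is rounded up, and that the task names below are placeholders. None of these details are confirmed by this README.

```python
import math


def top_fraction_uniform(weights: dict, fraction: float) -> dict:
    """Keep the highest-weighted tasks and assign them equal sampling weight.

    The rounding rule (ceil) is an assumption made for this sketch.
    """
    k = max(1, math.ceil(len(weights) * fraction))
    ranked = sorted(weights, key=weights.get, reverse=True)
    return {task: 1.0 / k for task in ranked[:k]}


# Hypothetical final weights for four placeholder tasks.
curr_weights = {"task_a": 0.4, "task_b": 0.1, "task_c": 0.3, "task_d": 0.2}
top50_weights = top_fraction_uniform(curr_weights, 0.5)
```

In this toy example the two highest-weighted tasks are kept and each receives a weight of 0.5, while the remaining tasks are dropped from training.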
Script: s2_train_optimized_model.sh
Robust retriever models are trained with the same hyperparameter settings as the baseline model. The only difference is that the robust models adopt the tDRO-optimized weights from the above stage.
- curr_weights.json: To use the final weights at the last step of tDRO, please call bash s2_train_optimized_model.sh curr.
- top70_weights.json: To use the top 70% of tasks with uniform sampling weights, please call bash s2_train_optimized_model.sh top70.
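One way to picture how per-task weights change training is weighted sampling of the source task for each draw. The sketch below uses only the standard library. It is an assumption about the general mechanism rather than a description of this repository's data loader, and the function name, the task names, and the per-draw granularity are hypothetical.

```python
import random
from collections import Counter


def sample_tasks(weights: dict, num_draws: int, seed: int = 0) -> list:
    """Draw task names with probability proportional to their weights."""
    rng = random.Random(seed)
    tasks = list(weights)
    return rng.choices(tasks, weights=[weights[t] for t in tasks], k=num_draws)


# Hypothetical weights: task_c has been removed by a Top-xx% style selection.
weights = {"task_a": 0.75, "task_b": 0.25, "task_c": 0.0}
counts = Counter(sample_tasks(weights, num_draws=10000))
```

Tasks with zero weight are never drawn, and the remaining tasks appear roughly in proportion to their weights, which is the effect that uniform sampling over all tasks does not provide.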
Please refer to eval/README.md for more details.
If you encounter any bugs or have any questions, please feel free to email me or open an issue.
Contacts: Guangyuan Ma (maguangyuan@iie.ac.cn)
Our codebase is inspired by several excellent open-source projects. We want to give special thanks to bowdpr, Tevatron, COCO-DR, DoReMi, Mistral-E5, FlagEmbedding, and others.
If you are interested in our work, please consider citing our paper.
@article{ma2024tdro,
author = {Guangyuan Ma and
Yongliang Ma and
Xing Wu and
Zhenpeng Su and
Ming Zhou and
Songlin Hu},
title = {Task-level Distributionally Robust Optimization for Large Language
Model-based Dense Retrieval},
journal = {CoRR},
volume = {abs/2408.10613},
year = {2024},
url = {https://doi.org/10.48550/arXiv.2408.10613},
doi = {10.48550/ARXIV.2408.10613},
eprinttype = {arXiv},
eprint = {2408.10613},
timestamp = {Tue, 24 Sep 2024 17:36:32 +0200},
}
tDRO is licensed under the Apache License.