🔥 Cross-lingual safety generalization: This is the first work to demonstrate that preference tuning for toxicity mitigation can generalize cross-lingually in a zero-shot manner. We evaluate on 17 different languages and different LLMs (such as BLOOM, Llama3, and Aya-23), all of which show cross-lingual detoxification after English DPO preference tuning.
🔍 Mechanistic findings: We show that the dual multilinguality of toxic vectors (in MLP layers) explains the cross-lingual generalization. The toxic vectors in MLPs encode multilingual toxic concepts, and we can control the output toxicity level by controlling the activation levels of those vectors. We then show that English DPO reduces the activation levels of toxic vectors across languages.
- Create a conda environment with Python 3.11:

```bash
conda create --name xgdetox python=3.11
conda activate xgdetox
```

- Install poetry and the other dependencies with poetry (make sure you are in the project's root directory, where `pyproject.toml` is located):

```bash
pip install poetry
poetry install
```
- Training (Toxicity Pairwise Data): Download the `toxicity_pairwise.zip` data from here (source: *Mechanistically Understanding DPO: Toxicity*).
- Evaluation (RTP-LX): Follow the instructions from Microsoft to download the dataset of RTP-LX input prompts. It contains files named `RTP-LX/RTP_LX_{language}.json`. Our repo and experiments use the dataset released in Apr '24 (the May '24 release works too).
To perform DPO preference tuning (with or without LoRA), run the following command:
```bash
python3 xg/training/dpo.py \
    --data_dir /path/to/toxicity_pairwise/ \
    --model_name_or_path meta-llama/Llama-2-7b-hf \
    --output_dir /path/to/save/model_ckpt \
    --per_device_train_batch_size 4 \
    --wandb_run_name your_wandb_runname \
    --use_lora # remove this flag for full-model finetuning
```
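For reference, the objective being optimized is the standard pairwise DPO loss on (non-toxic, toxic) continuation pairs. Below is a minimal sketch, not the training code itself; the function name and the β default are ours, and `xg/training/dpo.py` is the reference implementation:

```python
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, beta=0.1):
    """Pairwise DPO loss over per-sequence log-probabilities.

    "Chosen" is the non-toxic continuation, "rejected" the toxic one;
    the reference model is a frozen copy of the model before tuning.
    """
    # Log-ratios of the policy against the frozen reference model.
    chosen_logratios = policy_chosen_logps - ref_chosen_logps
    rejected_logratios = policy_rejected_logps - ref_rejected_logps
    # Push the chosen log-ratio above the rejected one, scaled by beta.
    return -F.logsigmoid(beta * (chosen_logratios - rejected_logratios)).mean()
```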
After DPO training, you can directly use the model checkpoint from `/path/to/save/model_ckpt/final_checkpoint/`.
However, parameter-efficient training with LoRA saves only the adapter weights, so use the following command to merge the LoRA adapters into the base model and save the full model weights. This makes the generation stage with the vLLM library more straightforward (at the time we designed the code, vLLM had bugs loading LoRA weights, so passing a merged model is simpler than passing the base model plus LoRA weights).
```bash
python3 xg/training/merge_peft.py \
    --base_model_name meta-llama/Llama-2-7b-hf \
    --lora_adapter /path/to/save/model_ckpt/final_checkpoint \
    --output_dir /path/to/save/merge_final_checkpoint
```
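Conceptually, the merge is peft's `merge_and_unload()`: load the base model, attach the trained adapters, fold them into the base weights, and save a plain checkpoint. A rough sketch of the equivalent steps (the actual `merge_peft.py` may differ in details such as dtype handling):

```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

BASE = "meta-llama/Llama-2-7b-hf"
ADAPTER = "/path/to/save/model_ckpt/final_checkpoint"
OUT = "/path/to/save/merge_final_checkpoint"

# Load the base model and attach the trained LoRA adapters.
base = AutoModelForCausalLM.from_pretrained(BASE, torch_dtype=torch.bfloat16)
model = PeftModel.from_pretrained(base, ADAPTER)

# Fold the adapter weights into the base weights and save a plain
# HF checkpoint that vLLM can load directly.
merged = model.merge_and_unload()
merged.save_pretrained(OUT)
AutoTokenizer.from_pretrained(BASE).save_pretrained(OUT)
```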
We have uploaded our trained models to HuggingFace Hub:
We use the vLLM library to obtain model continuations. We recommend following vLLM's installation instructions before running the generation code below. Our code saves the vLLM generations to `/path/to/save/outputs/{MODEL_NAME}/output-rtp_lx_{LANG}.json`.
```bash
PROMPT_FILE=/path/to/RTP-LX/RTP_LX_ZH-Hans.json # change ZH-Hans to evaluate other languages

# for --model, pass /path/to/save/model_ckpt/final_checkpoint instead if you did full finetuning
python3 xg/generate/vllm_script_sample.py \
    --prompt_file $PROMPT_FILE \
    --model /path/to/save/merge_final_checkpoint \
    --output_dir /path/to/save/outputs
```
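Under the hood, the script is a standard vLLM generation loop over the RTP-LX prompts; a minimal sketch is below (the sampling parameters are illustrative, not necessarily those used by `vllm_script_sample.py`):

```python
from vllm import LLM, SamplingParams

# Load the (merged) checkpoint and sample one continuation per prompt.
llm = LLM(model="/path/to/save/merge_final_checkpoint")
params = SamplingParams(temperature=0.9, top_p=0.8, max_tokens=20)  # illustrative

prompts = ["..."]  # prompt texts loaded from RTP_LX_{LANG}.json
for out in llm.generate(prompts, params):
    print(out.outputs[0].text)
```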
- Toxicity: First run `xg/eval/perspective_api_eval.py` to save the toxicity scores from the Perspective API, then run `xg/eval/metric_toxicity.py` to aggregate the scores.
- Fluency: Run the `xg/eval/metric_perplexity.py` script to compute the median conditional perplexity with the `mT5-xl` model. It also saves the array of all perplexity scores.
- Diversity: Run the `xg/eval/metric_diversity.py` script.
```bash
MODEL_OUTPUTS_FOLDER=... # vllm generations folder (/path/to/save/outputs/{MODEL_NAME})

############### toxicity ###############
# call Perspective API (replace ... with YOUR API KEY)
LANGS=( ar cs de en es fr hi id it ja ko nl pl pt ru sv zh-hans )
for LANG in "${LANGS[@]}"
do
    echo "Processing $LANG"
    python3 xg/eval/perspective_api_eval.py \
        --api_key ... \
        --datapath "${MODEL_OUTPUTS_FOLDER}/output-rtp_lx_${LANG}.json" \
        --output_folder "${MODEL_OUTPUTS_FOLDER}/perspective_api_eval/" \
        --language $LANG
done

# aggregate toxicity scores
PERSPECTIVE_OUTPUTS_FOLDER=${MODEL_OUTPUTS_FOLDER}/perspective_api_eval
python3 xg/eval/metric_toxicity.py \
    --perspective_outputs_folder $PERSPECTIVE_OUTPUTS_FOLDER

############### fluency ###############
python3 xg/eval/metric_perplexity.py \
    --model_outputs_folder $MODEL_OUTPUTS_FOLDER

############### diversity ###############
python3 xg/eval/metric_diversity.py \
    --model_outputs_folder $MODEL_OUTPUTS_FOLDER
```
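For intuition, diversity is typically reported as distinct-n: the number of unique n-grams divided by the total number of n-grams across continuations. A minimal sketch with whitespace tokenization (an assumption; `xg/eval/metric_diversity.py` is the reference implementation):

```python
from collections import Counter

def distinct_n(continuations: list[str], n: int) -> float:
    """Unique n-grams / total n-grams across all continuations."""
    ngram_counts, total = Counter(), 0
    for text in continuations:
        tokens = text.split()  # whitespace tokenization for illustration
        for i in range(len(tokens) - n + 1):
            ngram_counts[tuple(tokens[i:i + n])] += 1
            total += 1
    return len(ngram_counts) / total if total else 0.0

# e.g. dist-2 over two toy continuations: 4 unique / 4 total = 1.0
print(distinct_n(["the cat sat", "the dog sat"], 2))
```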
Download the Jigsaw dataset from Kaggle.
To train a linear probe for binary toxic classification, follow these steps:

- Replace the `train_fp` variable with the path to the train split of the Jigsaw dataset.
- Run the provided script. All hyperparameters are pre-configured in the script file.

```bash
python scripts/run_train_probe.py
```
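Conceptually, the probe is a single linear layer trained with a logistic loss to separate toxic from non-toxic examples in representation space. A minimal self-contained sketch with stand-in features (in practice the features come from the model's hidden states on Jigsaw comments, and the script fixes the real hyperparameters):

```python
import torch
import torch.nn as nn

hidden_dim = 768  # assumption: the base model's hidden size

# Stand-in data: in practice, pooled hidden states of Jigsaw comments,
# with label 1 for toxic and 0 for non-toxic.
feats = torch.randn(1024, hidden_dim)
labels = torch.randint(0, 2, (1024,)).float()

probe = nn.Linear(hidden_dim, 1)
opt = torch.optim.Adam(probe.parameters(), lr=1e-3)
loss_fn = nn.BCEWithLogitsLoss()

for _ in range(100):  # illustrative iteration count
    opt.zero_grad()
    loss = loss_fn(probe(feats).squeeze(-1), labels)
    loss.backward()
    opt.step()

# The learned weight vector is the "probe vector" used in the analyses below.
probe_vector = probe.weight.detach().squeeze(0)  # shape: (hidden_dim,)
```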
We first identify the potential sources of toxicity by selecting the top 100 value vectors based on their cosine similarity with the probe vector. We then collect the corresponding neuron activations, averaged across the next 20 tokens generated from the English RTP-LX prompts. A value vector is retained if its corresponding neuron activation is positive during the forward pass. We found 36 value vectors meeting these criteria; they are stored here. We then project them onto the vocabulary space to interpret the tokens they promote when activated. More details can be found in this notebook.
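A minimal sketch of this selection and interpretation step, assuming a GPT-2-style model and the `probe_vector` from the probe sketch above (the notebook is the reference; the model choice and details here are illustrative):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative base model
tok = AutoTokenizer.from_pretrained("gpt2")

# Each row of an MLP output projection (HF Conv1D weight: (d_mlp, d_model))
# is a "value vector" written to the residual stream when its neuron fires.
value_vecs = torch.cat(
    [block.mlp.c_proj.weight for block in model.transformer.h], dim=0
)  # shape: (n_layers * d_mlp, d_model)

# Rank value vectors by cosine similarity with the toxicity probe vector.
sims = torch.nn.functional.cosine_similarity(value_vecs, probe_vector[None, :], dim=-1)
top_idx = sims.topk(100).indices

# Interpret a candidate vector by projecting it onto the vocabulary space.
logits = value_vecs[top_idx[0]] @ model.lm_head.weight.T
print(tok.convert_ids_to_tokens(logits.topk(10).indices.tolist()))
```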
To better understand these sub-updates, we directly intervene in their corresponding neuron activations and inspect the changes they induce. We provide a minimal experiment demonstrating how such interventions are conducted in this notebook. The same code can be used to quantitatively measure the effect of the changes we exert on the neuron activations across all prompts from different languages.
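For concreteness, such an intervention can be sketched as a PyTorch forward pre-hook that rescales chosen MLP neuron activations before the output projection (the layer and neuron indices below are hypothetical; the notebook contains the actual experiments):

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative base model

def make_scaling_hook(neuron_idx, scale=0.0):
    """Rescale selected neuron activations entering mlp.c_proj."""
    def hook(module, inputs):
        (acts,) = inputs  # shape: (batch, seq_len, d_mlp)
        acts[..., neuron_idx] = acts[..., neuron_idx] * scale
        return (acts,)
    return hook

# Hypothetical example: silence two candidate toxic neurons in layer 11.
proj = model.transformer.h[11].mlp.c_proj
handle = proj.register_forward_pre_hook(make_scaling_hook([770, 1024]))
# ... run generation and measure the change in output toxicity ...
handle.remove()
```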
This script can be used to collect neuron activations before and after preference tuning across different languages. We also provide the precomputed results here. To reproduce Figure 3 in the paper, see this notebook.
Data: Since the RTP-LX prompts are not aligned across languages (see Issue), we translate 200 prompts with the Google Translate API to obtain multiway-parallel RTP-LX prompts. They are stored at `assets/translated_pairwise_data`.
We first use `xg/retrieval/retrieval_acc_save.py` to save the per-layer representations for parallel sentence pairs in English and the `lang2` language. Then, we use `xg/retrieval/retrieval_acc_load.py` to load them and calculate the bilingual sentence retrieval accuracy between English and `lang2`.
LANG2="ar"
for i in "0 50" "50 100" "100 150" "150 200" # process in batches to avoid OOM
do
set -- $i # Convert the "tuple" into the param args $1 $2...
python3 xg/retrieval/retrieval_acc_save.py \
--lang2 $LANG2 \
--begin $1 \
--end $2 \
--model_name "ai-forever/mGPT"
done
python3 xg/retrieval/retrieval_acc_load.py \
--lang2 $LANG2
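For intuition, bilingual sentence retrieval accuracy at a given layer is the fraction of English sentences whose nearest `lang2` representation (by cosine similarity) is the aligned translation. A minimal sketch of that computation (the actual scripts also handle batching and per-layer saving):

```python
import torch
import torch.nn.functional as F

def retrieval_accuracy(en_reps: torch.Tensor, l2_reps: torch.Tensor) -> float:
    """en_reps, l2_reps: (n_sentences, d_model) representations at one layer,
    where row i of each tensor is the same sentence in the two languages."""
    # Pairwise cosine similarities between English and lang2 sentences.
    sims = F.cosine_similarity(en_reps[:, None, :], l2_reps[None, :, :], dim=-1)
    preds = sims.argmax(dim=1)            # nearest lang2 sentence per English one
    gold = torch.arange(en_reps.size(0))  # parallel data: index i matches index i
    return (preds == gold).float().mean().item()

# e.g. accuracy at one layer for 200 multiway-parallel RTP-LX prompts
print(retrieval_accuracy(torch.randn(200, 768), torch.randn(200, 768)))
```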
```bibtex
@article{li2024preference,
  title={Preference Tuning For Toxicity Mitigation Generalizes Across Languages},
  author={Li, Xiaochen and Yong, Zheng-Xin and Bach, Stephen H},
  journal={arXiv preprint arXiv:2406.16235},
  year={2024}
}
```