- Accepted at EMNLP 2023 (Main Track) | Paper | Slides |
- ASR Generative Error Correction by leveraging foundational Audio (Whisper) and Language (LLaMA) models.
- Fusing Whisper Encoder and LLaMA decoder
We introduce a novel cross-modal fusion technique designed for generative error correction in Automatic Speech Recognition. In an oversimplified sense, we leverage in-context learning to feed the n-best hypotheses produced by an acoustic model into a large language model and prompt it to predict the most accurate sentence, as shown below.
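As a rough illustration of the in-context learning idea, the n-best hypotheses can be formatted into an instruction-style prompt for the LLM. The wording and layout of the template below are hypothetical, not the exact prompt used in the paper:

```python
# Hypothetical sketch: formatting n-best ASR hypotheses into a prompt.
# The instruction text and numbering style are illustrative only.
def build_prompt(hypotheses):
    """Format n-best hypotheses into an instruction-style prompt for the LLM."""
    numbered = "\n".join(f"{i + 1}. {h}" for i, h in enumerate(hypotheses))
    return (
        "Below are the n-best transcription hypotheses from a speech "
        "recognition system. Report the most accurate transcription.\n\n"
        f"{numbered}\n\nBest transcription:"
    )

prompt = build_prompt([
    "whispering llama fuses audio and text",
    "whispering lama fuses audio in text",
    "whisper in llama fuses audio and text",
])
print(prompt)
```

The model is then expected to generate the corrected sentence after the final cue line.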
We propose a novel mechanism to fuse acoustic features from the audio input into the LLM, significantly enhancing performance (28.83% -> 37.66% WERR) by leveraging an audio foundation model as a feature extractor. We further design our system in a parameter-efficient manner, with only 7.97M trainable parameters, as shown below. Please refer to the paper for further information.
Clone the repo
git clone https://github.com/Srijith-rkr/Whispering-LLaMA
cd Whispering-LLaMA
And use the environment.yml file to install dependencies with Anaconda.
conda env create -f environment.yml
Or you can use the requirements.txt as
pip install -r requirements.txt
- To obtain the pre-trained Alpaca weights, please refer here. You can then use convert_hf_checkpoint.py to rename the state_dict to match the lit-llama implementation.
- Or you can use the Alpaca weights hosted at Hugging Face/Whispering-LLaMA. Refer to demo.py on how to use them.
You are all set! 🎉
We have uploaded our n-best hypotheses dataset, generated using Whisper-Tiny, to Hugging Face at PeacefulData. The hypotheses were generated from the M subset of the Hugging Face GigaSpeech dataset. You can map the hypotheses in our dataset to the audio clips in the GigaSpeech dataset using the 'ID' tag.
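The ID-based join can be sketched as below. The record fields and ID format are hypothetical placeholders; the actual dataset schemas may differ:

```python
# Hypothetical sketch: pairing hypothesis records with GigaSpeech audio clips
# via the shared 'ID' field. Field names and values are illustrative only.
hypotheses = [
    {"ID": "POD0000000001_S0000001", "hypothesis": ["first guess", "second guess"]},
]
gigaspeech = [
    {"ID": "POD0000000001_S0000001", "audio_path": "audio/POD0000000001_S0000001.wav"},
]

# Index the audio clips by ID, then attach each clip to its hypothesis record.
audio_by_id = {row["ID"]: row for row in gigaspeech}
paired = [
    {**rec, "audio_path": audio_by_id[rec["ID"]]["audio_path"]}
    for rec in hypotheses
    if rec["ID"] in audio_by_id
]
print(paired[0]["audio_path"])
```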
The model and tokenizer weights are hosted at Hugging Face/Whispering-LLaMA for easier setup. You can refer to demo.py on how to use them.
Please refer to:
- data_preparation to generate your custom n-best hypothesis dataset
- training/WL-M.py to train our best model on your dataset
- Inference/WL-M.py to run inference
Once you set up your dataset, you can train your models as
python training/WL-S.py --lr 1e-3 --d 1 --pretrained_path 'weights/alpaca.pth' --tokenizer_path 'weights/tokenizer.model' --data 'path to your dataset'
You can configure the following flags.
--lr: learning rate (1e-3 is recommended)
--d: Number of GPUs used with the DDP strategy (you can uncomment lines in the code to switch to DeepSpeed)
--pretrained_path: Path to the Alpaca model weights
--tokenizer_path: Path to the LLaMA tokenizer
--data: Path to your dataset
This implementation builds on:
- lit-llama for the training pipeline.
- stanford_alpaca for the pre-trained instruction-following language model.
- Whisper to obtain acoustic embeddings.
If you find this work related or useful for your research, please consider citing this paper. Thank you!
@inproceedings{radhakrishnan2023whispering,
title={Whispering LLaMA: A Cross-Modal Generative Error Correction Framework for Speech Recognition},
author={Radhakrishnan, Srijith and Yang, Chao-Han Huck and Khan, Sumeer Ahmad and Kumar, Rohit and Kiani, Narsis A. and Gomez-Cabrero, David and Tegner, Jesper N.},
booktitle={Proc. of EMNLP},
year={2023}
}