A PyTorch implementation of the paper: Synthesizer: Rethinking Self-Attention in Transformer Models - Yi Tay, Dara Bahri, Donald Metzler, Da-Cheng Juan, Zhe Zhao, Che Zheng
The paper primarily proposes two efficient variants of the scaled dot-product attention used in regular Transformers.
The snapshot from the paper below perfectly illustrates the difference.
This repository currently contains implementations of the following variants (a minimal sketch of the core idea follows the list):
- Vanilla Attention (regular scaled dot product attention based Transformer)
- Dense Attention
- Factorized Dense Attention
- Random Attention
- Factorized Random Attention
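To make the difference concrete, here is a minimal, self-contained PyTorch sketch of the three core ideas. It is an illustration under simplified assumptions (single head, no masking, illustrative module and parameter names), not the modules used in this repository: vanilla attention derives its weights from query-key dot products, the Dense Synthesizer predicts the weights from each token's representation with a small feed-forward network, and the Random Synthesizer replaces them with a learned (optionally fixed) parameter matrix.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class VanillaAttention(nn.Module):
    """Standard scaled dot-product attention: weights come from Q @ K^T."""
    def __init__(self, d_model):
        super().__init__()
        self.q = nn.Linear(d_model, d_model)
        self.k = nn.Linear(d_model, d_model)
        self.v = nn.Linear(d_model, d_model)
        self.scale = d_model ** -0.5

    def forward(self, x):                        # x: (batch, seq_len, d_model)
        q, k, v = self.q(x), self.k(x), self.v(x)
        attn = F.softmax(q @ k.transpose(-2, -1) * self.scale, dim=-1)
        return attn @ v

class DenseSynthesizerAttention(nn.Module):
    """Dense Synthesizer: attention weights are predicted per token by a
    two-layer feed-forward network; no token-token dot products."""
    def __init__(self, d_model, max_len):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_model)
        self.w2 = nn.Linear(d_model, max_len)    # one logit per (token, position)
        self.v = nn.Linear(d_model, d_model)

    def forward(self, x):                        # x: (batch, seq_len, d_model)
        seq_len = x.size(1)
        logits = self.w2(F.relu(self.w1(x)))[:, :, :seq_len]
        attn = F.softmax(logits, dim=-1)
        return attn @ self.v(x)

class RandomSynthesizerAttention(nn.Module):
    """Random Synthesizer: the attention matrix is a (learned or fixed)
    parameter shared across all inputs."""
    def __init__(self, d_model, max_len, fixed=False):
        super().__init__()
        self.attn_logits = nn.Parameter(torch.randn(max_len, max_len),
                                        requires_grad=not fixed)
        self.v = nn.Linear(d_model, d_model)

    def forward(self, x):                        # x: (batch, seq_len, d_model)
        seq_len = x.size(1)
        attn = F.softmax(self.attn_logits[:seq_len, :seq_len], dim=-1)
        return attn @ self.v(x)                  # attn broadcasts over the batch
```

The factorized variants listed above follow the same pattern, but decompose the predicted logits (Dense) or the parameter matrix (Random) into low-rank factors to reduce the parameter count.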
An example of training on the WMT'16 Multimodal Translation task (http://www.statmt.org/wmt16/multimodal-task.html):
python3 -m venv synth-env
source synth-env/bin/activate
pip install -r requirements.txt
cd synth/
# conda install -c conda-forge spacy
python -m spacy download en
python -m spacy download de
python preprocess.py -lang_src de -lang_trg en -share_vocab -save_data m30k_deen_shr.pkl
python train.py -data_pkl m30k_deen_shr.pkl -log log_dense_1 -embs_share_weight -proj_share_weight -label_smoothing -save_model trained_dense_1 -b 8 -warmup 128000 -n_head 2 -n_layers 2 -attn_type dense -epoch 25
python translate.py -data_pkl m30k_deen_shr.pkl -model trained.chkpt -output prediction.txt
The following graphs compare the performance of the Synthesizer variants (Dense, Random) and the vanilla Transformer.
Due to limited compute (a single Nvidia RTX 2060 Super), I have only tested a configuration of 2 heads, 2 layers, and a batch size of 8. However, that is enough to estimate the relative performance.
In line with the findings of the paper, Dense attention seems to perform comparably to vanilla attention on the machine translation task. Surprisingly, even fixed Random attention performs well.
Train time per epoch was 0.9 min (Random) < 1.15 min (Dense) < 1.2 min (vanilla).
Results can be viewed in this notebook after training and storing the weights of the 3 variants.
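If you prefer to skip the notebook, a rough sketch like the one below can overlay the validation curves from the training logs. It assumes the -log prefixes of the three runs (log_dense_1 from the example above, plus hypothetical log_random_1 and log_vanilla_1 runs) and that train.py writes CSV logs with an epoch,loss,ppl,accuracy header, as the upstream attention-is-all-you-need-pytorch trainer does; adjust the names and columns to your setup.

```python
# Hypothetical helper: overlay validation perplexity for the three runs.
# Log names and format are assumptions (see note above), not guaranteed by this repo.
import csv
import matplotlib.pyplot as plt

RUNS = {                        # -log value passed to train.py -> curve label
    'log_vanilla_1': 'vanilla',
    'log_dense_1': 'dense',
    'log_random_1': 'random',
}

def read_valid_log(prefix):
    """Read '<prefix>.valid.log', assumed to be an 'epoch,loss,ppl,accuracy' CSV."""
    epochs, ppls = [], []
    with open(f'{prefix}.valid.log') as f:
        for row in csv.DictReader(f):
            epochs.append(int(row['epoch']))
            ppls.append(float(row['ppl']))
    return epochs, ppls

for prefix, label in RUNS.items():
    epochs, ppls = read_valid_log(prefix)
    plt.plot(epochs, ppls, label=label)

plt.xlabel('epoch')
plt.ylabel('validation perplexity')
plt.legend()
plt.savefig('synthesizer_vs_vanilla.png')
```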
- Debugging and testing of the factorized versions of the Synthesizer.
- A proper inference pipeline.
- Further systematic comparative monitoring, such as training/inference time.
- Implementing other attention variants proposed in the paper, such as CNN-based attention.
- Testing synthesizer on other downstream tasks.
- The general transformer backbone is heavily borrowed from the amazing repository attention-is-all-you-need-pytorch by Yu-Hsiang Huang.
- The byte pair encoding parts are borrowed from subword-nmt.
- The project structure, some scripts and the dataset preprocessing steps are heavily borrowed from OpenNMT/OpenNMT-py.