A PyTorch implementation of the paper - "Synthesizer: Rethinking Self-Attention in Transformer Models"

10-zin/Synthesizer

Synthesizer

A PyTorch implementation of the paper "Synthesizer: Rethinking Self-Attention in Transformer Models" by Yi Tay, Dara Bahri, Donald Metzler, Da-Cheng Juan, Zhe Zhao, and Che Zheng.

The paper mainly proposes two efficient alternatives to the scaled dot-product attention used in regular Transformers: Dense and Random Synthesizers.

The snapshot from the paper below perfectly illustrates the difference.

Variants

This repository currently consists of the implementations for the following variants:

  1. Vanilla Attention (regular scaled dot product attention based Transformer)
  2. Dense Attention
  3. Factorized Dense Attention
  4. Random Attention
  5. Factorized Random Attention
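As a rough illustration of how the Dense and Random variants differ from vanilla attention, here is a minimal single-head sketch based on the paper's description. It is not the code in this repository: the class names `DenseSynthesizerHead` and `RandomSynthesizerHead`, the `max_len` parameter, and the choice of a single head are assumptions made for the example, and padding and causal masks are left out for brevity.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DenseSynthesizerHead(nn.Module):
    """Hypothetical single-head Dense Synthesizer.

    Attention logits are predicted from each token on its own,
    with no query-key dot product.
    """

    def __init__(self, d_model: int, max_len: int):
        super().__init__()
        # Two-layer feed-forward net mapping each token to max_len logits.
        self.w1 = nn.Linear(d_model, d_model)
        self.w2 = nn.Linear(d_model, max_len)
        self.value = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, d_model), with seq_len <= max_len
        seq_len = x.size(1)
        # Keep only the logits for positions that exist in this batch.
        logits = self.w2(F.relu(self.w1(x)))[:, :, :seq_len]
        attn = F.softmax(logits, dim=-1)  # (batch, seq_len, seq_len)
        return attn @ self.value(x)


class RandomSynthesizerHead(nn.Module):
    """Hypothetical single-head Random Synthesizer.

    Attention logits are a matrix shared by all inputs. With
    trainable=False the matrix stays at its random initialisation.
    """

    def __init__(self, d_model: int, max_len: int, trainable: bool = True):
        super().__init__()
        self.attn_logits = nn.Parameter(
            torch.randn(max_len, max_len), requires_grad=trainable
        )
        self.value = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        seq_len = x.size(1)
        attn = F.softmax(self.attn_logits[:seq_len, :seq_len], dim=-1)
        # The (seq_len, seq_len) matrix broadcasts over the batch dimension.
        return attn @ self.value(x)


if __name__ == "__main__":
    x = torch.randn(4, 10, 32)  # (batch, seq_len, d_model)
    print(DenseSynthesizerHead(32, max_len=16)(x).shape)
    print(RandomSynthesizerHead(32, max_len=16, trainable=False)(x).shape)
```

Both heads return a tensor with the same shape as the input. In the paper, the factorized variants shrink the synthesized matrix further, for example by writing the random matrix as a product of two `max_len x k` matrices.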

Usage

WMT'16 Multimodal Translation: de-en

An example of training for the WMT'16 Multimodal Translation task (http://www.statmt.org/wmt16/multimodal-task.html).

0) Create a virtual environment, install the requirements, and move to the synth directory.

python3 -m venv synth-env
source synth-env/bin/activate
pip install -r requirements.txt
cd synth/

1) Download the spaCy language models.

# conda install -c conda-forge spacy 
python -m spacy download en
python -m spacy download de

2) Preprocess the data with torchtext and spaCy.

python preprocess.py -lang_src de -lang_trg en -share_vocab -save_data m30k_deen_shr.pkl

3) Train the model

python train.py -data_pkl m30k_deen_shr.pkl -log log_dense_1 -embs_share_weight -proj_share_weight -label_smoothing -save_model trained_dense_1 -b 8 -warmup 128000 -n_head 2 -n_layers 2 -attn_type dense -epoch 25
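The `-warmup` flag in the command above suggests a warmup learning-rate schedule. `train.py` is not reproduced here, so the following is only a sketch under the assumption that training follows the schedule from the original Transformer paper ("Attention Is All You Need"). The function name `noam_lr` and the default `d_model=512` are hypothetical and not taken from this repository.

```python
def noam_lr(step: int, d_model: int = 512, warmup: int = 128000,
            scale: float = 1.0) -> float:
    """Learning rate at a given optimizer step (assumed schedule).

    The rate grows linearly for `warmup` steps, then decays
    proportionally to the inverse square root of the step.
    d_model=512 is an assumed default, not a value from the repository.
    """
    if step < 1:
        raise ValueError("step must be >= 1")
    return scale * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)


if __name__ == "__main__":
    for step in (1, 1000, 64000, 128000, 256000):
        print(step, noam_lr(step))
```

Under this assumption, the learning rate peaks exactly at step `warmup`, so a large warmup value combined with a small batch size means the rate keeps rising for a long stretch of training.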

4) Test the model

python translate.py -data_pkl m30k_deen_shr.pkl -model trained.chkpt -output prediction.txt

Comparisons

  • The following graphs compare the performance of the Synthesizer (dense, random) and the Transformer (vanilla).

  • Because of limited compute (one Nvidia RTX 2060 Super), I have only tested a configuration of 2 heads, 2 layers, and a batch size of 8. That is still enough to estimate the comparative performance.

  • In line with the findings of the paper, dense attention performs comparably to vanilla attention on the machine translation task. Surprisingly, even random attention (fixed) performs well.

  • Training time per epoch was 0.9 min (random) < 1.15 min (dense) < 1.2 min (vanilla).

The results can be viewed in this notebook, which was produced after training and storing the weights of the three variants.

Todo

  1. Debug and test the factorized versions of the Synthesizer.
  2. Build a proper inference pipeline.
  3. Add further systematic comparative monitoring, such as training and inference time.
  4. Implement other attention variants proposed in the paper, such as CNN-based attention.
  5. Test the Synthesizer on other downstream tasks.

Acknowledgement
