PyTorch implementation and pretrained models for XCiT. See XCiT: Cross-Covariance Image Transformers [arXiv] [Yannic Kilcher's video]
Our XCiT models have linear complexity with respect to the number of patches/tokens:
[Figure: Peak Memory (Inference) | Milliseconds/Image (Inference)]
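To illustrate where the linear complexity comes from, here is a minimal single-layer sketch of cross-covariance attention (XCA) following the paper's description: queries and keys are L2-normalized along the token dimension, and attention is computed between feature channels, giving a d×d map per head instead of an N×N one. This is a simplified illustration, not the repository's module; the class name `XCASketch`, the `last_attn` attribute, and the temperature initialization of 1.0 are assumptions made for this example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class XCASketch(nn.Module):
    """Simplified cross-covariance attention (illustrative, not the repo's code)."""

    def __init__(self, dim: int, num_heads: int = 8):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.num_heads = num_heads
        # Learnable per-head temperature (initial value of 1.0 is an assumption).
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
        self.last_attn = None  # kept only so the map's shape can be inspected

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        d = C // self.num_heads
        # (B, N, 3*C) -> (3, B, heads, d, N): channels first, tokens last.
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, d).permute(2, 0, 3, 4, 1)
        q, k, v = qkv[0], qkv[1], qkv[2]
        # L2-normalize each channel across the token dimension.
        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)
        # Channel-to-channel attention map: (B, heads, d, d), independent of N.
        attn = ((q @ k.transpose(-2, -1)) * self.temperature).softmax(dim=-1)
        self.last_attn = attn.detach()
        # (B, heads, d, d) @ (B, heads, d, N) -> (B, heads, d, N) -> (B, N, C)
        out = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C)
        return self.proj(out)


layer = XCASketch(dim=64, num_heads=8)
for n_tokens in (196, 784):
    y = layer(torch.randn(2, n_tokens, 64))
    # The output keeps the token count; the attention map stays 8 x 8 per head.
    print(tuple(y.shape), tuple(layer.last_attn.shape))
```

Quadrupling the number of tokens leaves the attention map unchanged; only the matrix products that touch every token grow, and they grow linearly. A full XCiT layer also includes a local patch interaction block and a feed-forward network, which are omitted here.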
XCiT scales to high-resolution inputs, both because of its cheaper compute requirements and because it adapts better to higher resolutions at test time (see Figure 3 in the paper).
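A back-of-the-envelope calculation makes the scaling with resolution concrete. For a square input of side S and patch size p there are N = (S/p)² tokens; token self-attention builds an N×N map per head, whereas XCA builds a d×d map, where d is the per-head channel dimension. The `head_dim` of 48 below is an illustrative assumption, not a value read from a specific XCiT configuration, and the helper names are hypothetical.

```python
def num_tokens(image_size: int, patch_size: int) -> int:
    """Patch tokens for a square image (the class token is not counted)."""
    if image_size % patch_size:
        raise ValueError("image_size must be a multiple of patch_size")
    return (image_size // patch_size) ** 2


def attention_map_entries(image_size: int, patch_size: int, head_dim: int = 48):
    """Entries in one head's attention map: (token self-attention, XCA).

    Self-attention builds an N x N map; XCA builds a d x d map. Computing the
    XCA map still touches every token, so its cost grows as O(N * d^2), i.e.
    linearly in N, versus O(N^2 * d) for token self-attention.
    """
    n = num_tokens(image_size, patch_size)
    return n * n, head_dim * head_dim


for size, patch in [(224, 16), (224, 8), (384, 16), (384, 8)]:
    sa, xca = attention_map_entries(size, patch)
    print(f"{size}px, p{patch}: N={num_tokens(size, patch)}, "
          f"self-attention={sa:,}, XCA={xca:,}")
```

For example, going from 224px to 384px with 8×8 patches grows the self-attention map from 614,656 to 5,308,416 entries per head, while the XCA map stays at 2,304 entries under the assumed head dimension.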
XCiT+DINO: High Res. Self-Attention Visualization 🦖
Our XCiT models trained with DINO self-supervision produce high-resolution attention maps.
[Video: xcit_dino.mp4]
Self-Attention visualization per head
Below, we show the attention maps for each of the 8 heads separately: every head specializes in different semantic aspects of the scene, in both the foreground and the background.
[Video: Multi_head.mp4]
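Per-head maps like these are typically built from the class token's attention over the patch tokens, as in DINO's visualization. Assuming you have already extracted such weights for one image as a tensor of shape (num_heads, num_patches) (extracting them from the model is not shown), a minimal sketch of turning them into per-head heatmaps is to reshape each head onto the patch grid and upsample by the patch size. The function name `attention_to_heatmaps` is hypothetical.

```python
import torch
import torch.nn.functional as F


def attention_to_heatmaps(cls_attn: torch.Tensor, grid_h: int, grid_w: int,
                          patch_size: int) -> torch.Tensor:
    """Turn (num_heads, num_patches) class-token attention into per-head heatmaps.

    Returns a tensor of shape (num_heads, grid_h * patch_size, grid_w * patch_size)
    with each head min-max normalized to [0, 1] for display.
    """
    num_heads, num_patches = cls_attn.shape
    assert num_patches == grid_h * grid_w, "attention length must match the patch grid"
    maps = cls_attn.reshape(num_heads, 1, grid_h, grid_w)
    # Nearest-neighbour upsampling paints each patch weight over its pixels.
    maps = F.interpolate(maps, scale_factor=patch_size, mode="nearest").squeeze(1)
    # Normalize each head independently so heads are visually comparable.
    flat = maps.flatten(1)
    lo = flat.min(dim=1, keepdim=True).values
    hi = flat.max(dim=1, keepdim=True).values
    flat = (flat - lo) / (hi - lo).clamp(min=1e-12)
    return flat.reshape_as(maps)


# Example: 8 heads over a 14 x 14 grid of 16 x 16 patches (a 224 x 224 image).
heatmaps = attention_to_heatmaps(torch.rand(8, 14 * 14), 14, 14, 16)
print(tuple(heatmaps.shape))
```

Each slice `heatmaps[i]` can then be overlaid on the input image, one panel per head.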
First, clone the repository:
git clone https://github.com/facebookresearch/XCiT.git
Then, install the required packages, including PyTorch 1.7.1, torchvision 0.8.2, and Timm 0.4.8:
pip install -r requirements.txt
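To confirm that an environment matches the versions listed above, a small check along these lines can help. It is a hypothetical helper, not part of the repository; it uses `importlib.metadata` (Python 3.8+) and assumes the usual PyPI distribution names `torch`, `torchvision`, and `timm`.

```python
from importlib import metadata
from typing import Callable, Dict, Optional, Tuple

# Versions stated in this README.
EXPECTED = {"torch": "1.7.1", "torchvision": "0.8.2", "timm": "0.4.8"}


def installed_version(dist: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return None


def version_mismatches(
    expected: Dict[str, str],
    lookup: Callable[[str], Optional[str]] = installed_version,
) -> Dict[str, Tuple[str, Optional[str]]]:
    """Map each package whose installed version differs to (expected, found)."""
    result = {}
    for dist, want in expected.items():
        found = lookup(dist)
        # Ignore local build tags such as "+cu110" when comparing.
        if found is None or found.split("+")[0] != want:
            result[dist] = (want, found)
    return result


if __name__ == "__main__":
    for dist, (want, found) in version_mismatches(EXPECTED).items():
        print(f"{dist}: expected {want}, found {found}")
```

An empty result means all three packages match; the `lookup` parameter exists so the comparison logic can be exercised without touching the real environment.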
Download and extract the ImageNet dataset, then set the --data-path argument to the extracted ImageNet path.
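This repository builds on DeiT, whose data loader expects the standard torchvision `ImageFolder` layout: `train/` and `val/` directories with one subfolder per class. Assuming XCiT follows the same convention (check the repository's data-loading code to confirm), a quick sanity check of the extracted path could look like this; `check_imagenet_layout` is a hypothetical helper.

```python
from pathlib import Path
from typing import Dict


def check_imagenet_layout(data_path: str) -> Dict[str, int]:
    """Count class subfolders in <data_path>/train and <data_path>/val.

    Assumes the torchvision ImageFolder convention (one directory per class).
    Raises FileNotFoundError if either split directory is missing.
    """
    root = Path(data_path)
    counts = {}
    for split in ("train", "val"):
        split_dir = root / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing split directory: {split_dir}")
        counts[split] = sum(1 for p in split_dir.iterdir() if p.is_dir())
    return counts
```

For a correctly extracted ImageNet-1k, `check_imagenet_layout("/path/to/imagenet")` should report 1000 classes for each split; a `val/` folder of loose images (not yet sorted into class folders) would report 0.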
For full details on all available arguments, run:
python main.py --help
For detection and segmentation downstream tasks, please check:
- COCO Object detection and Instance segmentation: XCiT Detection
- ADE20k Semantic segmentation: XCiT Semantic Segmentation
We provide weights for XCiT models pre-trained on ImageNet-1k.
§: distillation
Arch | params | 224 top-1 | 224 weights | 224 § top-1 | 224 § weights | 384 § top-1 | 384 § weights |
---|---|---|---|---|---|---|---|
xcit_nano_12_p16 | 3M | 69.9% | download | 72.2% | download | 75.4% | download |
xcit_tiny_12_p16 | 7M | 77.1% | download | 78.6% | download | 80.9% | download |
xcit_tiny_24_p16 | 12M | 79.4% | download | 80.4% | download | 82.6% | download |
xcit_small_12_p16 | 26M | 82.0% | download | 83.3% | download | 84.7% | download |
xcit_small_24_p16 | 48M | 82.6% | download | 83.9% | download | 85.1% | download |
xcit_medium_24_p16 | 84M | 82.7% | download | 84.3% | download | 85.4% | download |
xcit_large_24_p16 | 189M | 82.9% | download | 84.9% | download | 85.8% | download |
Arch | params | 224 top-1 | 224 weights | 224 § top-1 | 224 § weights | 384 § top-1 | 384 § weights |
---|---|---|---|---|---|---|---|
xcit_nano_12_p8 | 3M | 73.8% | download | 76.3% | download | 77.8% | download |
xcit_tiny_12_p8 | 7M | 79.7% | download | 81.2% | download | 82.4% | download |
xcit_tiny_24_p8 | 12M | 81.9% | download | 82.6% | download | 83.7% | download |
xcit_small_12_p8 | 26M | 83.4% | download | 84.2% | download | 85.1% | download |
xcit_small_24_p8 | 48M | 83.9% | download | 84.9% | download | 85.6% | download |
xcit_medium_24_p8 | 84M | 83.7% | download | 85.1% | download | 85.8% | download |
xcit_large_24_p8 | 189M | 84.4% | download | 85.4% | download | 86.0% | download |
Arch | params | k-NN top-1 | linear top-1 | download |
---|---|---|---|---|
xcit_small_12_p16 | 26M | 76.0% | 77.8% | backbone |
xcit_small_12_p8 | 26M | 77.1% | 79.2% | backbone |
xcit_medium_24_p16 | 84M | 76.4% | 78.8% | backbone |
xcit_medium_24_p8 | 84M | 77.9% | 80.3% | backbone |
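The only checkpoint URL spelled out in this README is https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p16_224.pth. The helper below extrapolates the pattern `<model>_<resolution>[_dist].pth` to the other rows of the first two tables; that naming, including the `_dist` suffix for the distilled (§) columns, is an assumption, so verify it against the actual download links. The DINO backbones in the last table are not covered. Loading uses `torch.hub.load_state_dict_from_url`; storing the weights under a `"model"` key is also an assumption borrowed from DeiT-style checkpoints.

```python
BASE_URL = "https://dl.fbaipublicfiles.com/xcit"


def checkpoint_url(model: str, resolution: int = 224, distilled: bool = False) -> str:
    """Build a checkpoint URL (the naming pattern is an assumption)."""
    if resolution not in (224, 384):
        raise ValueError("the tables list 224 and 384 checkpoints only")
    if resolution == 384 and not distilled:
        raise ValueError("the tables only list distilled (§) 384 checkpoints")
    suffix = "_dist" if distilled else ""
    return f"{BASE_URL}/{model}_{resolution}{suffix}.pth"


def load_state_dict(model: str, resolution: int = 224, distilled: bool = False):
    """Download a checkpoint and return its weights (needs torch and network access)."""
    import torch  # imported lazily so checkpoint_url works without torch installed

    ckpt = torch.hub.load_state_dict_from_url(
        checkpoint_url(model, resolution, distilled), map_location="cpu")
    # Assumed DeiT-style layout: weights under "model"; fall back to the raw dict.
    return ckpt.get("model", ckpt)
```

For instance, `checkpoint_url("xcit_small_12_p16")` reproduces the URL used in the evaluation example of this README.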
For single-node training, use the following command:
python -m torch.distributed.launch --nproc_per_node=[NUM_GPUS] --use_env main.py --model [MODEL_KEY] --batch-size [BATCH_SIZE] --drop-path [STOCHASTIC_DEPTH_RATIO] --output_dir [OUTPUT_PATH]
For example, the XCiT-S12/16 model can be trained with the following command:
python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --model xcit_small_12_p16 --batch-size 128 --drop-path 0.05 --output_dir /experiments/xcit_small_12_p16/ --epochs [NUM_EPOCHS]
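When launching several runs from a script, it can be convenient to assemble the single-node command programmatically. The hypothetical helper below only reproduces the flags already shown above and adds no new options; the resulting list can be passed to `subprocess.run(cmd, check=True)`.

```python
import shlex
from typing import List, Optional


def single_node_command(model: str, num_gpus: int, batch_size: int,
                        drop_path: float, output_dir: str,
                        epochs: Optional[int] = None) -> List[str]:
    """Assemble the torch.distributed.launch command shown in this README."""
    cmd = [
        "python", "-m", "torch.distributed.launch",
        f"--nproc_per_node={num_gpus}", "--use_env", "main.py",
        "--model", model,
        "--batch-size", str(batch_size),
        "--drop-path", str(drop_path),
        "--output_dir", output_dir,
    ]
    if epochs is not None:
        cmd += ["--epochs", str(epochs)]
    return cmd


cmd = single_node_command("xcit_small_12_p16", 8, 128, 0.05,
                          "/experiments/xcit_small_12_p16/", epochs=400)
print(shlex.join(cmd))  # shell-quoted form for logging
```

Building an argument list rather than a single string avoids shell-quoting problems when output paths contain spaces.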
For multi-node training via SLURM, you can alternatively use:
python run_with_submitit.py --partition [PARTITION_NAME] --nodes 2 --ngpus 8 --model xcit_small_12_p16 --batch-size 64 --drop-path 0.05 --job_dir /experiments/xcit_small_12_p16/ --epochs 400
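The two example configurations give the same global batch size, assuming `--batch-size` is the per-GPU (per-process) batch size as in the DeiT codebase this repository builds on: 1 node × 8 GPUs × 128 = 1024, and 2 nodes × 8 GPUs × 64 = 1024. A small sketch of this arithmetic, with hypothetical helper names, is useful when moving a recipe to a different number of GPUs:

```python
def global_batch_size(nodes: int, gpus_per_node: int, per_gpu_batch: int) -> int:
    """Effective batch size across all processes (per-GPU --batch-size assumed)."""
    return nodes * gpus_per_node * per_gpu_batch


def per_gpu_batch_for(target_global: int, nodes: int, gpus_per_node: int) -> int:
    """Per-GPU --batch-size that reproduces a target global batch size."""
    world_size = nodes * gpus_per_node
    if target_global % world_size:
        raise ValueError(f"{target_global} is not divisible by {world_size} GPUs")
    return target_global // world_size


print(global_batch_size(1, 8, 128))   # 1024, single-node example above
print(global_batch_size(2, 8, 64))    # 1024, SLURM example above
print(per_gpu_batch_for(1024, 4, 8))  # 32 per GPU on 4 nodes x 8 GPUs
```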
More details on the hyperparameters used to train the different models can be found in Table B.1 of the paper.
To evaluate an XCiT model using the checkpoints above or a model you trained, use the following command:
python main.py --eval --model <MODEL_KEY> --input-size <IMG_SIZE> [--full_crop] --pretrained <PATH/URL>
By default, we use the --full_crop flag, which evaluates the model with a crop ratio of 1.0 instead of 0.875, following CaiT.
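A crop ratio r is commonly implemented by resizing the shorter image side to input_size / r and then center-cropping to input_size, so with r = 1.0 the resize and crop sizes coincide and no border is discarded, while r = 0.875 keeps only about 77% of the resized area (0.875² ≈ 0.766). The sketch below assumes that convention; the function name is hypothetical, and the repository's exact transform may differ in details such as interpolation.

```python
def eval_resize_size(input_size: int, crop_ratio: float) -> int:
    """Shorter-side resize applied before a center crop of input_size pixels."""
    if not 0 < crop_ratio <= 1:
        raise ValueError("crop_ratio must be in (0, 1]")
    return int(input_size / crop_ratio)


for size in (224, 384):
    full = eval_resize_size(size, 1.0)      # --full_crop
    center = eval_resize_size(size, 0.875)  # conventional 0.875 crop ratio
    print(f"input {size}: full_crop resize {full}, 0.875 resize {center}")
```

With torchvision, this corresponds to `transforms.Resize(eval_resize_size(s, r))` followed by `transforms.CenterCrop(s)`; for 224 inputs, the 0.875 ratio gives the familiar resize-to-256, crop-to-224 recipe.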
For example, to evaluate the XCiT-S12/16 model on 224x224 images:
python main.py --eval --model xcit_small_12_p16 --input-size 224 --full_crop --pretrained https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p16_224.pth
This repository is built using the Timm library and the DeiT repository. The self-supervised training is based on the DINO repository.
This repository is released under the Apache 2.0 license as found in the LICENSE file.
We actively welcome your pull requests! Please see CONTRIBUTING.md and CODE_OF_CONDUCT.md for more info.
If you find this repository useful, please consider citing our work:
@article{el2021xcit,
title={XCiT: Cross-Covariance Image Transformers},
author={El-Nouby, Alaaeldin and Touvron, Hugo and Caron, Mathilde and Bojanowski, Piotr and Douze, Matthijs and Joulin, Armand and Laptev, Ivan and Neverova, Natalia and Synnaeve, Gabriel and Verbeek, Jakob and others},
journal={arXiv preprint arXiv:2106.09681},
year={2021}
}