This is the code implementation for the paper: "Discffusion: Discriminative Diffusion Models as Few-shot Vision and Language Learners". The project is developed based on HuggingFace Diffusers.
Diffusion models, such as Stable Diffusion, have shown incredible performance on text-to-image generation. Since text-to-image generation often requires models to generate visual concepts with fine-grained details and attributes specified in text prompts, can we leverage the powerful representations learned by pre-trained diffusion models for discriminative tasks such as image-text matching? To answer this question, we propose a novel approach, Discriminative Stable Diffusion (Discffusion), which turns pre-trained text-to-image diffusion models into few-shot discriminative learners. Our approach mainly uses the cross-attention score of a Stable Diffusion model to capture the mutual influence between visual and textual information and fine-tune the model via efficient attention-based prompt learning to perform image-text matching. By comparing Discffusion with state-of-the-art methods on several benchmark datasets, we demonstrate the potential of using pre-trained diffusion models for discriminative tasks with superior results on few-shot image-text matching.
conda create -n dsd python=3.9
conda activate dsd
cd diffusers
pip install -e .
cd ..
pip install -r requirements.txt
We provide a jupyter notebook file to try the pipeline and visualize.
You can also launch the gradio demo to upload your own image to play by:
python playground.py
Download images from Visual Genome official websites. Put the Visual_Genome in the same path with Discffusion repo. Download comvg_train.csv from link and put it in the data repo. Or you can directly download com_vg.zip from link and unzip it and put it in the same path with Discffusion repo.
We use the RefCOCOg split. Download from RefCOCOg github repo https://github.com/lichengunc/refer. Put refcocog in the same path with Discffusion repo.
Download images from VQAv2 official websites. Put vqav2 in the same path with Discffusion repo. Download vqa_text_train.csv from link and put it in the data repo.
The repo strutures are:
├── DSD/
│ ├── data/
│ │ ├── comvg_train.csv
│ │ ├── vqa_text_train.csv
│ │ └── ...
├── refcocog/
│ ├── images/
│ │ ├── train2014/
│ │ │ ├── COCO_train2014_000000000009.jpg
│ │ │ └── ...
│ │ └── ...
│ ├── refs(google).p
│ └── instances.json
├── vqav2/
│ ├── images/
│ │ └── ...
│ ├── train2014/
│ │ └── ...
│ └── val2014/
│ └── ...
└── Visual_Genome/
├── VG_100K/
└── vg_concept/
└── 180.jpg
└── 411.jpg
└── 414.jpg
└── 1251.jpg
└── ...
ComVG | Refcocog | VQA |
---|---|---|
Download | Download | Download |
ComVG
accelerate config
accelerate launch dsd_infer.py --val_data ComVG_obj --batchsize 16 --sampling_time_steps 40 --output_dir downloaded_checkpoints
accelerate launch dsd_infer.py --val_data ComVG_verb --batchsize 16 --sampling_time_steps 40 --output_dir downloaded_checkpoints
accelerate launch dsd_infer.py --val_data ComVG_sub --batchsize 16 --sampling_steps 40 --output_dir downloaded_checkpoints
downloaded_checkpoints
is the folder path for your loaded ckpt, such as FOLDER_NAME/checkpoint-500000.
Refcocog
accelerate config
accelerate launch dsd_infer.py --val_data Refcocog --batchsize 16 --sampling_time_steps 10 --output_dir downloaded_checkpoints
VQAv2
accelerate config
accelerate launch dsd_infer.py --val_data vqa_binary --batchsize 16 --sampling_time_steps 200 --output_dir downloaded_checkpoints
accelerate launch dsd_infer.py --val_data vqa_other --batchsize 16 --sampling_time_steps 200 --output_dir downloaded_checkpoints
accelerate config
accelerate launch dsd_train.py --pretrained_model_name_or_path stabilityai/stable-diffusion-2-1-base --train_batch_size 1 --val_batch_size 4 --output_dir PATH --train_data TRAIN_DATA --val_data VAL_DATA --num_train_epochs EPOCH --learning_rate 1e-4
We set the accelerate config to use one GPU to do the training.
TRAIN_DATA
currently supports ComVG, Refcocog, vqa. You can add more in the custom_dataset
VAL_DATA
supports giving types, e.g. ComVG_obj/ComVG_verb/ComVG_sub, vqa_other/vqa_binary
Set --bias
if you want to train and inference using the cross-attention score from the diffusion model only.
For example:
accelerate config
accelerate launch dsd_train.py --pretrained_model_name_or_path stabilityai/stable-diffusion-2-1-base --train_batch_size 1 --val_batch_size 4 --bias --output_dir ./output --train_data ComVG --val_data ComVG_verb --num_train_epochs 1 --learning_rate 1e-4
accelerate config
accelerate launch dsd_infer.py --val_data ComVG_obj --bias --batchsize 16 --sampling_time_steps 30 --output_dir YOUR_SAVED_CKPTS
The code and dataset is built on Stable Diffusion, Diffusers, Diffusion-itm, Custom Diffusion, ComCLIP, Refer. We thank the authors for their model and code.
If you find it useful in your research or applications, please consider citing us, thanks!
@article{he2023discriminative,
title={Discriminative Diffusion Models as Few-shot Vision and Language Learners},
author={He, Xuehai and Feng, Weixi and Fu, Tsu-Jui and Jampani, Varun and Akula, Arjun and Narayana, Pradyumna and Basu, Sugato and Wang, William Yang and Wang, Xin Eric},
journal={arXiv preprint arXiv:2305.10722},
year={2023}
}