This repository contains the JAX implementation of Consistency Models used for our experiments on CIFAR-10. It is based on yang-song/score_sde, which was released under the Apache-2.0 license. We modified that code to streamline diffusion model training and added implementations of consistency distillation, consistency training, and the sampling and editing algorithms described in the paper.
For the code and checkpoints used in the ImageNet-64, LSUN Bedroom-256, and LSUN Cat-256 experiments, see openai/consistency_models.
We have released checkpoints for the main models in the paper. Before using these models, please review the corresponding model card to understand the intended use and limitations of these models.
Here are the download links for each model checkpoint (CD stands for consistency distillation and CT for consistency training):
- EDM on CIFAR-10: edm_cifar10_ema
- CD on CIFAR-10 with l1 metric: cd-l1
- CD on CIFAR-10 with l2 metric: cd-l2
- CD on CIFAR-10 with LPIPS metric: cd-lpips
- CT on CIFAR-10 with adaptive schedules and LPIPS metric: ct-lpips
- Continuous-time CD on CIFAR-10 with l2 metric: cifar10-continuous-cd-l2
- Continuous-time CD on CIFAR-10 with l2 metric and stopgrad: cifar10-continuous-cd-l2-stopgrad
- Continuous-time CD on CIFAR-10 with LPIPS metric and stopgrad: cifar10-continuous-cd-lpips-stopgrad
- Continuous-time CT on CIFAR-10 with l2 metric: continuous-ct-l2
- Continuous-time CT on CIFAR-10 with LPIPS metric: continuous-ct-lpips
To install all packages in this codebase along with their dependencies, run
pip install -e .
Then manually install jaxlib by running
pip install https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.4.7+cuda11.cudnn82-cp39-cp39-manylinux2014_x86_64.whl
This command installs the jaxlib 0.4.7 wheel built for Python 3.9, CUDA 11, and cuDNN 8.2, which is the configuration the code assumes. For other Python, CUDA, or cuDNN versions, modify the wheel URL accordingly.
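To adapt the install command to another Python, CUDA, or cuDNN version, the wheel URL can be derived from the naming pattern of the command above. The following is a minimal sketch under stated assumptions: the helper `jaxlib_wheel_url` and its parameters are hypothetical and not part of this codebase, and it only reproduces the file-naming pattern. A wheel may not exist for every combination, so confirm that the file is available before installing.

```python
import sys

# Base location taken from the install command above.
BASE_URL = "https://storage.googleapis.com/jax-releases"


def jaxlib_wheel_url(jaxlib="0.4.7", cuda="11", cudnn="82", python_tag=None):
    """Build a jaxlib wheel URL that follows the pattern of the command above.

    Hypothetical helper: it does not check that the wheel actually exists.
    """
    if python_tag is None:
        # Derive the CPython tag (e.g. "cp39") from the running interpreter.
        python_tag = f"cp{sys.version_info.major}{sys.version_info.minor}"
    filename = (
        f"jaxlib-{jaxlib}+cuda{cuda}.cudnn{cudnn}"
        f"-{python_tag}-{python_tag}-manylinux2014_x86_64.whl"
    )
    return f"{BASE_URL}/cuda{cuda}/{filename}"


# With the defaults and Python 3.9 this reproduces the URL used above.
print("pip install " + jaxlib_wheel_url(python_tag="cp39"))
```

Calling `jaxlib_wheel_url()` without `python_tag` uses the version of the interpreter that runs the script.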
We provide examples of EDM training, consistency distillation, consistency training, single-step generation, and model evaluation in launch.sh.
We provide examples for multistep generation and zero-shot image editing in editing_multistep_sampling.ipynb.
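As orientation for the notebook, here is a minimal sketch of the multistep consistency sampling loop described in the paper: denoise pure noise in one step, then repeatedly add noise up to an intermediate level and denoise again. It is not the notebook's implementation. It uses plain Python lists instead of JAX arrays to stay dependency-free, the names `toy_consistency_fn` and `multistep_sample` are hypothetical, the noise levels 0.002 and 80.0 are illustrative assumptions, and the toy function stands in for a trained model.

```python
import math
import random

SIGMA_MIN = 0.002  # assumed smallest noise level (epsilon)
SIGMA_MAX = 80.0   # assumed largest noise level (T)


def toy_consistency_fn(x, t, mean=1.5):
    """Stand-in for a trained model (assumption for illustration only).

    If all data equal `mean`, probability-flow ODE trajectories are straight
    lines x(t) = mean + c * t, so mapping a point at time t back to time
    SIGMA_MIN rescales its offset from `mean` by SIGMA_MIN / t.
    """
    return [mean + (SIGMA_MIN / t) * (xi - mean) for xi in x]


def multistep_sample(consistency_fn, dim, time_points, rng):
    """Multistep sampling with a decreasing list of noise levels."""
    # Start from pure noise at the largest noise level and denoise once.
    x_T = [SIGMA_MAX * rng.gauss(0.0, 1.0) for _ in range(dim)]
    x = consistency_fn(x_T, SIGMA_MAX)
    # Alternate between re-noising to level tau and denoising again.
    for tau in time_points:
        std = math.sqrt(tau**2 - SIGMA_MIN**2)
        x_tau = [xi + std * rng.gauss(0.0, 1.0) for xi in x]
        x = consistency_fn(x_tau, tau)
    return x


rng = random.Random(0)
sample = multistep_sample(
    toy_consistency_fn, dim=4, time_points=[20.0, 5.0, 1.0], rng=rng
)
print(sample)
```

With an empty `time_points` list the loop is skipped and the function reduces to single-step generation. In practice `consistency_fn` would wrap a trained model; see editing_multistep_sampling.ipynb for the repository's own examples.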
If you find this method and/or code useful, please consider citing
@article{song2023consistency,
title={Consistency Models},
author={Song, Yang and Dhariwal, Prafulla and Chen, Mark and Sutskever, Ilya},
journal={arXiv preprint arXiv:2303.01469},
year={2023},
}
This repo builds on the earlier score_sde codebase. Please also consider citing
@inproceedings{
song2021scorebased,
title={Score-Based Generative Modeling through Stochastic Differential Equations},
author={Yang Song and Jascha Sohl-Dickstein and Diederik P Kingma and Abhishek Kumar and Stefano Ermon and Ben Poole},
booktitle={International Conference on Learning Representations},
year={2021},
url={https://openreview.net/forum?id=PxTIG12RRHS}
}