
cmmd-pytorch


(Unofficial) PyTorch implementation of CLIP Maximum Mean Discrepancy (CMMD) for evaluating image generation models, as proposed in Rethinking FID: Towards a Better Evaluation Metric for Image Generation. CMMD is put forward as a better metric than FID and aims to mitigate FID's longstanding issues.
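At its core, CMMD is the Maximum Mean Discrepancy between CLIP embeddings of the reference and generated image sets, computed with a Gaussian RBF kernel. The following is a minimal, self-contained sketch of that computation; the `sigma` and `scale` defaults are illustrative assumptions, not necessarily the values used by this repository.

```python
import torch


def mmd_rbf(x: torch.Tensor, y: torch.Tensor, sigma: float = 10.0, scale: float = 1000.0) -> torch.Tensor:
    """Biased MMD^2 estimate between two embedding sets using a Gaussian RBF kernel.

    x: (N, D) reference embeddings, y: (M, D) generated embeddings.
    `sigma` and `scale` are illustrative defaults, not necessarily what cmmd-pytorch uses.
    """
    gamma = 1.0 / (2.0 * sigma**2)
    k_xx = torch.exp(-gamma * torch.cdist(x, x).pow(2)).mean()  # within-reference kernel mean
    k_yy = torch.exp(-gamma * torch.cdist(y, y).pow(2)).mean()  # within-generated kernel mean
    k_xy = torch.exp(-gamma * torch.cdist(x, y).pow(2)).mean()  # cross-set kernel mean
    return scale * (k_xx + k_yy - 2.0 * k_xy)
```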

This implementation is a super simple PyTorch port of the original codebase. I have only focused on the JAX- and TensorFlow-specific bits and replaced them with PyTorch equivalents. Some differences:

  • The original codebase relies on scenic for computing CLIP embeddings. This repository uses transformers (a minimal sketch follows this list).
  • For data loading, the original codebase uses TensorFlow; this one uses a PyTorch Dataset and DataLoader.
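For illustration, here is a rough sketch of the transformers-based embedding step. The checkpoint name is an assumption (the CMMD paper uses a CLIP ViT-L/14@336px model), and in the actual implementation images are batched through a DataLoader rather than embedded in a single call.

```python
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

# Checkpoint is an assumption; check this repository's code for the exact model it loads.
CHECKPOINT = "openai/clip-vit-large-patch14-336"
model = CLIPModel.from_pretrained(CHECKPOINT).eval()
processor = CLIPProcessor.from_pretrained(CHECKPOINT)


@torch.no_grad()
def embed_images(paths: list[str]) -> torch.Tensor:
    """Return unit-normalized CLIP image embeddings of shape (N, D)."""
    images = [Image.open(p).convert("RGB") for p in paths]
    inputs = processor(images=images, return_tensors="pt")
    feats = model.get_image_features(**inputs)
    return feats / feats.norm(dim=-1, keepdim=True)
```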

Setup

First, install PyTorch following instructions from the official website.

Then install the dependencies:

pip install git+https://github.com/creative-graphic-design/cmmd-pytorch

After installation, you will be able to use the command cmmd-pytorch:

❯❯❯ cmmd-pytorch --help
usage: cmmd-pytorch [-h] [--batch-size BATCH_SIZE] [--max-count MAX_COUNT] [--ref-embed-file REF_EMBED_FILE] ref_dir eval_dir

positional arguments:
  ref_dir               Path to the directory containing reference images.
  eval_dir              Path to the directory containing images to be evaluated.

optional arguments:
  -h, --help            show this help message and exit
  --batch-size BATCH_SIZE
                        Batch size for embedding generation.
  --max-count MAX_COUNT
                        Maximum number of images to read from each directory.
  --ref-embed-file REF_EMBED_FILE
                        Path to the pre-computed embedding file for the reference images.

Running

cmmd-pytorch /path/to/reference/images /path/to/eval/images --batch-size=32 --max-count=30000

A working example command:

cmmd-pytorch reference_images generated_images --batch-size=1

It should output:

The CMMD value is:  7.696

This matches the output of the original codebase, which confirms the correctness of this implementation 🤗

Tip

GPU execution is supported when a GPU is available.

Results

Below, we report the CMMD metric for some popular pipelines on the COCO-30k dataset, as commonly used by the community. Like FID, lower CMMD values are better.

| Pipeline | Inference Steps | Resolution | CMMD |
| --- | --- | --- | --- |
| stabilityai/stable-diffusion-xl-base-1.0 | 30 | 1024x1024 | 0.696 |
| segmind/SSD-1B | 30 | 1024x1024 | 0.669 |
| stabilityai/sdxl-turbo | 1 | 512x512 | 0.548 |
| runwayml/stable-diffusion-v1-5 | 50 | 512x512 | 0.582 |
| PixArt-alpha/PixArt-XL-2-1024-MS | 20 | 1024x1024 | 1.140 |
| SPRIGHT-T2I/spright-t2i-sd2 | 50 | 768x768 | 0.512 |

Notes:

  • For SDXL Turbo, guidance_scale is set to 0 following the official guide in diffusers.
  • For all other pipelines, the default guidance_scale was used. Refer to the official pipeline documentation pages here for more details.

Caution

As per the CMMD authors, with models producing high-quality/high-resolution images, COCO images don't seem to be a good reference set (they are of pretty small resolution). This might help explain why SD v1.5 has a better CMMD than SDXL.

Obtaining CMMD for your pipelines

Refer to the generate_images.py script, which generates images from randomly sampled COCO-30k captions using diffusers; a rough sketch of such a generation loop is shown below.
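A minimal sketch of such a loop follows; the checkpoint, prompts, number of steps, and output layout are illustrative assumptions, and generate_images.py remains the authoritative script.

```python
from pathlib import Path

import torch
from diffusers import DiffusionPipeline

# Checkpoint and settings are illustrative; see generate_images.py for the real configuration.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

prompts = ["a photo of a cat sitting on a windowsill"]  # replace with sampled COCO-30k captions
out_dir = Path("generated_images")
out_dir.mkdir(exist_ok=True)

for i, prompt in enumerate(prompts):
    image = pipe(prompt, num_inference_steps=30).images[0]
    image.save(out_dir / f"{i:05d}.png")
```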

Once the images are generated, run:

cmmd-pytorch /path/to/reference/images /path/to/generated/images --batch-size=32 --max-count=30000

Reference images are COCO-30k images and can be downloaded from here.

Pre-computed embeddings for the COCO-30k images can be found here.

To use the pre-computed reference embeddings, run:

cmmd-pytorch None /path/to/generated/images --ref-embed-file=ref_embs.npy --batch-size=32 --max-count=30000
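If you want to produce your own reference-embedding file instead of downloading one, a hedged sketch is below. It reuses the hypothetical embed_images() helper sketched earlier, and it assumes --ref-embed-file expects a single (N, D) array saved with np.save; verify that against this repository's code.

```python
import glob

import numpy as np

# Assumes the embed_images() sketch from earlier in this README; the .npy layout is an assumption.
ref_paths = sorted(glob.glob("reference_images/*.png"))
ref_embs = embed_images(ref_paths).cpu().numpy()
np.save("ref_embs.npy", ref_embs)
```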

Acknowledgements

Thanks to Sadeep Jayasumana (first author of CMMD) for all the helpful discussions.
