CVPR2024 - Decomposing Disease Descriptions for Enhanced Pathology Detection: A Multi-Aspect Vision-Language Pre-training Framework
Welcome to the official implementation of "Decomposing Disease Descriptions for Enhanced Pathology Detection: A Multi-Aspect Vision-Language Pre-training Framework", accepted at CVPR2024 🎉
This work leverages an LLM 🤖 to decompose disease descriptions into a set of visual aspects. Our multi-aspect vision-language pre-training framework, dubbed MAVL, achieves state-of-the-art performance across 7 datasets in zero-shot and low-shot fine-tuning settings for disease classification and segmentation.
If you find our work useful, please cite our paper.
@inproceedings{phan2024decomposing,
title={Decomposing Disease Descriptions for Enhanced Pathology Detection: A Multi-Aspect Vision-Language Pre-training Framework},
author={Phan, Vu Minh Hieu and Xie, Yutong and Qi, Yuankai and Liu, Lingqiao and Liu, Liyang and Zhang, Bowen and Liao, Zhibin and Wu, Qi and To, Minh-Son and Verjans, Johan W},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={11492--11501},
year={2024}
}
Method | CheXpert AUC | CheXpert F1 | CheXpert ACC | ChestXray-14 AUC | ChestXray-14 F1 | ChestXray-14 ACC | PadChest-seen AUC | PadChest-seen F1 | PadChest-seen ACC | RSNA Pneumonia AUC | RSNA Pneumonia F1 | RSNA Pneumonia ACC | SIIM-ACR AUC | SIIM-ACR F1 | SIIM-ACR ACC |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
ConVIRT | 52.10 | 35.61 | 57.43 | 53.15 | 12.38 | 57.88 | 63.72 | 14.56 | 73.47 | 79.21 | 55.67 | 75.08 | 64.25 | 42.87 | 53.42 |
GLoRIA | 54.84 | 37.86 | 60.70 | 55.92 | 14.20 | 59.47 | 64.09 | 14.83 | 73.86 | 70.37 | 48.19 | 70.54 | 54.71 | 40.39 | 47.15 |
BioViL | 60.01 | 42.10 | 66.13 | 57.82 | 15.64 | 61.33 | 60.35 | 10.63 | 70.48 | 84.12 | 54.59 | 74.43 | 70.28 | 46.45 | 68.22 |
BioViL-T | 70.93 | 47.21 | 69.96 | 60.43 | 17.29 | 62.12 | 65.78 | 15.37 | 77.52 | 86.03 | 62.56 | 80.04 | 75.56 | 60.18 | 73.72 |
CheXzero | 87.90 | 61.90 | 81.17 | 66.99 | 21.99 | 65.38 | 73.24 | 19.53 | 83.49 | 85.13 | 61.49 | 78.34 | 84.60 | 65.97 | 77.34 |
MedKLIP | 87.97 | 63.67 | 84.32 | 72.33 | 24.18 | 79.40 | 77.87 | 26.63 | 92.44 | 85.94 | 62.57 | 79.97 | 89.79 | 72.73 | 83.99 |
MAVL (Proposed) | 90.13 | 65.47 | 86.44 | 73.57 | 26.25 | 82.77 | 78.79 | 28.48 | 92.56 | 86.31 | 65.26 | 81.28 | 92.04 | 77.95 | 87.14 |
To get started, install the gdown library:
pip install -U --no-cache-dir gdown --pre
Then, run bash download.sh to download the necessary files.
The MIMIC-CXR2 dataset must be downloaded separately from PhysioNet.
We have pushed a Docker image with the necessary environment. You can create a container directly from it:
docker pull stevephan46/mavl:latest
docker run --runtime=nvidia --name mavl -it -v /your/data/root/folder:/data --shm-size=4g stevephan46/mavl:latest
You may need to reinstall opencv-python inside the container, as it conflicts with the Docker environment:
pip install opencv-python==4.2.0.32
If you prefer a manual installation over Docker, run the following:
pip install -r requirements.txt
pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu113
pip install opencv-python==4.2.0.32
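To confirm that the pinned versions from the manual installation are what is actually installed, a small check using only the standard library's importlib.metadata can help. This is a hypothetical helper, not part of the MAVL repository; it treats an expected version without a local label (such as +cu113) as matching any build of that version.

```python
from importlib import metadata

# Versions pinned in the manual-installation commands above.
# Hypothetical helper, not part of the MAVL repository.
EXPECTED = {
    "torch": "1.12.0+cu113",
    "torchvision": "0.13.0+cu113",
    "torchaudio": "0.12.0",
    "opencv-python": "4.2.0.32",
}


def version_matches(want, have):
    """Compare versions; if `want` has no local label (+cu113), ignore the one on `have`."""
    if have is None:
        return False
    if "+" not in want:
        have = have.split("+", 1)[0]
    return have == want


def check_versions(expected=EXPECTED):
    """Return {package: (expected, installed or None)} for every mismatch."""
    mismatches = {}
    for pkg, want in expected.items():
        try:
            have = metadata.version(pkg)
        except metadata.PackageNotFoundError:
            have = None  # package is not installed at all
        if not version_matches(want, have):
            mismatches[pkg] = (want, have)
    return mismatches


if __name__ == "__main__":
    for pkg, (want, have) in check_versions().items():
        print(f"{pkg}: expected {want}, found {have or 'not installed'}")
```

An empty printout means every pinned package matches.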
The script that generates the visual aspects of each disease with an LLM (GPT) can be found here.
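The generation script itself is not reproduced here. As a minimal sketch of the idea (asking an LLM for a per-aspect description of a disease and parsing its answer), the helpers below build a prompt and validate a JSON reply. The aspect names, the prompt wording, and the JSON reply format are all assumptions for illustration, and the actual GPT API call is deliberately left out.

```python
import json

# HYPOTHETICAL aspect names for illustration; the aspects MAVL actually uses
# are defined by the authors' generation script.
ASPECTS = ("opacity", "shape", "border", "location")


def build_prompt(disease, aspects=ASPECTS):
    """Ask for one short visual description of `disease` per aspect, as JSON."""
    keys = ", ".join(f'"{a}"' for a in aspects)
    return (
        f"Describe the visual appearance of {disease} on a chest X-ray. "
        f"Reply with a JSON object whose keys are exactly: {keys}. "
        "Each value should be one short phrase."
    )


def parse_aspects(reply, aspects=ASPECTS):
    """Parse the LLM's JSON reply into {aspect: description}; raise if any aspect is missing."""
    data = json.loads(reply)
    missing = [a for a in aspects if a not in data]
    if missing:
        raise ValueError(f"reply is missing aspects: {missing}")
    return {a: str(data[a]).strip() for a in aspects}


# A hand-written placeholder reply standing in for a real LLM response.
reply = '{"opacity": "increased", "shape": "patchy", "border": "ill-defined", "location": "lower lobes"}'
print(parse_aspects(reply)["border"])  # ill-defined
```

Validating the reply against a fixed key set keeps every disease described along the same aspects, which is what makes the per-aspect descriptions comparable across diseases.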
Our pre-training code is given in Pretrain.
- Run download.sh to download the necessary files.
- Modify the paths in the config file configs/MAVL_resnet.yaml, and run python train_MAVL.py to pre-train.
- For multi-GPU pre-training with Accelerate, run:
accelerate launch --multi_gpu --num_processes=4 --num_machines=1 --num_cpu_threads_per_process=8 train_MAVL.py --root /data/2019.MIMIC-CXR-JPG/2.0.0 --config configs/MAVL_resnet.yaml --bs 124 --num_workers 8
Note: The results reported in our paper were obtained by pre-training on 4 x A100 GPUs for 60 epochs. We provide the checkpoints here. We found that the checkpoint from a later stage (checkpoint_full_46.pth) yields higher zero-shot classification accuracy, while the checkpoint from an earlier stage (checkpoint_full_40.pth) yields more stable accuracy on visual grounding.
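Following the note above, a hypothetical helper could pick the recommended checkpoint for each task and load it. The task keys, the assumption that weights sit under a "model" key in the checkpoint, and the removal of a "module." prefix (common when saving DataParallel/DDP models in PyTorch) are all assumptions not confirmed by this README; load_state_dict requires PyTorch.

```python
from pathlib import Path

# Recommendations from the note above (4 x A100, 60-epoch schedule).
# HYPOTHETICAL task keys chosen for illustration.
RECOMMENDED_CKPT = {
    "zero_shot_classification": "checkpoint_full_46.pth",
    "zero_shot_grounding": "checkpoint_full_40.pth",
}


def pick_checkpoint(task, ckpt_dir):
    """Return the path of the recommended checkpoint for `task` inside `ckpt_dir`."""
    try:
        name = RECOMMENDED_CKPT[task]
    except KeyError:
        raise ValueError(f"unknown task {task!r}; expected one of {sorted(RECOMMENDED_CKPT)}") from None
    return Path(ckpt_dir) / name


def strip_module_prefix(state_dict, prefix="module."):
    """Remove a DataParallel/DDP-style prefix from state-dict keys, if present."""
    return {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()}


def load_state_dict(path):
    """Load a checkpoint on CPU. ASSUMPTION: weights are stored under a 'model' key."""
    import torch  # imported lazily so the helpers above work without PyTorch

    ckpt = torch.load(path, map_location="cpu")
    state = ckpt.get("model", ckpt) if isinstance(ckpt, dict) else ckpt
    return strip_module_prefix(state)
```

For example, load_state_dict(pick_checkpoint("zero_shot_grounding", "/data/ckpts")) would load checkpoint_full_40.pth from a directory of your choosing.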
We also ran a lighter pre-training schedule on 2 x A100 GPUs for 40 epochs with mixed-precision training, which achieves similar zero-shot classification results. The checkpoint for this setup is also available here.
accelerate launch --multi_gpu --num_processes=2 --num_machines=1 --num_cpu_threads_per_process=8 --mixed_precision=fp16 train_MAVL.py --root /data/2019.MIMIC-CXR-JPG/2.0.0 --config configs/MAVL_short.yaml --bs 124 --num_workers 8
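The two launch commands above differ only in the number of processes, the config file, and the fp16 flag. A hypothetical Python helper that rebuilds either command as an argument list (for example, to script several runs) might look like the sketch below; it uses only the flags shown in those commands and assumes the same MIMIC-CXR-JPG root path.

```python
import shlex

# Dataset root used in both commands above (inside the Docker mount /data).
MIMIC_ROOT = "/data/2019.MIMIC-CXR-JPG/2.0.0"


def accelerate_cmd(num_processes, config, mixed_precision=None, bs=124, num_workers=8):
    """Build the `accelerate launch ... train_MAVL.py` argv used for pre-training."""
    cmd = [
        "accelerate", "launch", "--multi_gpu",
        f"--num_processes={num_processes}",
        "--num_machines=1",
        "--num_cpu_threads_per_process=8",
    ]
    if mixed_precision:  # e.g. "fp16" for the lighter schedule
        cmd.append(f"--mixed_precision={mixed_precision}")
    cmd += [
        "train_MAVL.py",
        "--root", MIMIC_ROOT,
        "--config", config,
        "--bs", str(bs),
        "--num_workers", str(num_workers),
    ]
    return cmd


SCHEDULES = {
    "full": accelerate_cmd(4, "configs/MAVL_resnet.yaml"),  # 4 x A100, 60 epochs
    "short": accelerate_cmd(2, "configs/MAVL_short.yaml", mixed_precision="fp16"),  # 2 x A100, 40 epochs
}

print(shlex.join(SCHEDULES["short"]))
```

To launch from Python, pass a list to subprocess.run(SCHEDULES["short"], check=True) from the directory that contains train_MAVL.py. Keeping the command as a list avoids shell-quoting issues.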
Links to download the downstream datasets:
- CheXpert.
- ChestXray-14.
- PadChest.
- RSNA - Download images from initial annotations.
- SIIM.
- COVIDx-CXR-2 - The official Kaggle link is down. A publicly available expanded version, COVIDx-CXR4, is released here; it includes COVIDx-CXR-2 as a subset. Please use our dataset CSV splits to reproduce the results on the COVIDx-CXR-2 subset.
- Covid Rural - The official link provides raw DICOM data. We use the preprocessed data provided here.
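To check that every image referenced by our COVIDx-CXR-2 split CSVs is present in a local copy of COVIDx-CXR4, a small sketch like the following can help. The column name "filename" and the image-directory layout are hypothetical; inspect the header of the provided split CSVs and adjust.

```python
import csv
from pathlib import Path


def resolve_split(split_csv, image_dir, column="filename"):
    """Map each image listed in `split_csv` to its path under `image_dir`.

    ASSUMPTION: `column` is a hypothetical name for the CSV column holding
    image file names (relative to `image_dir`).
    Returns (found_paths, missing_names).
    """
    image_dir = Path(image_dir)
    found, missing = [], []
    with open(split_csv, newline="") as f:
        for row in csv.DictReader(f):
            name = row[column]
            path = image_dir / name
            if path.is_file():
                found.append(path)
            else:
                missing.append(name)  # listed in the split but absent locally
    return found, missing
```

A non-empty missing list usually means the image folder or the column name does not match the split files.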
Use this link to download the MAVL checkpoints. They can be used for all zero-shot and fine-tuning tasks.
- Zero-Shot Classification: We give examples in Sample_Zero-Shot_Classification. Modify the path, and test our model by running python test.py --config configs/dataset_name_mavl.yaml
- Zero-Shot Grounding: We give examples in Sample_Zero-Shot_Grounding. Modify the path, and test our model by running python test.py
- Finetuning: We provide segmentation and classification fine-tuning code in Sample_Finetuning_SIIMACR. Modify the path, and fine-tune our model by running python I1_classification/train_res_ft.py --config configs/dataset_name_mavl.yaml or python I2_segementation/train_res_ft.py --config configs/dataset_name_mavl.yaml
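The classification results in the table are reported as AUC, F1 and ACC (in percent). As a hedged illustration of how such per-class metrics can be computed from a model's scores, the dependency-free sketch below computes ROC AUC as the probability that a positive outranks a negative (ties count half) and reports F1 and accuracy at the threshold that maximizes F1. That threshold rule is an assumption; the repository's test.py defines the actual evaluation protocol.

```python
def roc_auc(labels, scores):
    """ROC AUC as P(score of a positive > score of a negative); ties count 0.5."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative sample")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def best_f1_threshold(labels, scores):
    """Try each distinct score as threshold (predict 1 if score >= t).

    Returns (threshold, f1, accuracy) for the threshold with the highest F1.
    """
    best = None
    for t in sorted(set(scores)):
        preds = [1 if s >= t else 0 for s in scores]
        tp = sum(1 for y, p in zip(labels, preds) if y == 1 and p == 1)
        fp = sum(1 for y, p in zip(labels, preds) if y == 0 and p == 1)
        fn = sum(1 for y, p in zip(labels, preds) if y == 1 and p == 0)
        f1 = 2 * tp / (2 * tp + fp + fn) if tp else 0.0
        acc = sum(1 for y, p in zip(labels, preds) if y == p) / len(labels)
        if best is None or f1 > best[1]:
            best = (t, f1, acc)
    return best


labels = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]
print(roc_auc(labels, scores))            # 0.75
print(best_f1_threshold(labels, scores))  # (0.35, 0.8, 0.75)
```

The pairwise AUC is quadratic in the number of samples, which is fine for illustration; for full test sets a rank-based or library implementation is preferable.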
Our code is built upon https://github.com/MediaBrain-SJTU/MedKLIP. We thank the authors for open-sourcing their code.
Feel free to reach out if you have any questions or need further assistance!