This repository provides the official PyTorch implementation of Brainomaly.
Brainomaly: Unsupervised Neurologic Disease Detection Utilizing Unannotated T1-weighted Brain MR Images
Md Mahfuzur Rahman Siddiquee¹,², Jay Shah¹,², Teresa Wu¹,², Catherine Chong²,³, Todd Schwedt²,³, Gina Dumkrieger³, Simona Nikolova³, and Baoxin Li¹,²
¹Arizona State University; ²ASU-Mayo Center for Innovative Imaging; ³Mayo Clinic
IEEE/CVF Winter Conference on Applications of Computer Vision (WACV), 2024
Paper (Preprint | Camera Ready) | Presentation Slides | Poster
Deep neural networks have revolutionized supervised learning by enabling accurate predictions through learning from large annotated datasets. However, acquiring large annotated medical imaging datasets is challenging, especially for rare diseases, because annotation requires considerable cost, time, and effort. In these scenarios, unsupervised disease detection methods, such as anomaly detection, can save significant human effort. A common approach to anomaly detection is to train a model on images from healthy subjects only, assuming that the model will then flag images from diseased subjects as outliers. In many real-world scenarios, however, unannotated datasets containing a mix of healthy and diseased individuals are available. Recent studies have shown that unsupervised disease/anomaly detection improves when such unannotated mixed datasets are used instead of datasets that include only healthy individuals. Yet these studies leave a major issue unaddressed: selecting the best model for inference from a set of trained models without annotated samples. To address this issue, we propose Brainomaly, a GAN-based image-to-image translation method for neurologic disease detection using unannotated T1-weighted brain MRIs of individuals with neurologic diseases and healthy subjects. Brainomaly is trained to remove the diseased regions from the input brain MRIs and generate MRIs of the corresponding healthy brains. Instead of generating the healthy images directly, Brainomaly generates an additive map in which each voxel indicates the amount of change required to make the input image look healthy. In addition, Brainomaly uses a pseudo-AUC metric for inference model selection, which further improves detection performance.
Brainomaly outperforms existing state-of-the-art methods by large margins on one publicly available dataset for Alzheimer's disease detection and one institutional dataset collected from Mayo Clinic for headache detection.
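The abstract describes Brainomaly as predicting an additive map rather than the healthy image itself. The NumPy sketch below illustrates that idea on a toy array. The [-1, 1] intensity range, the clipping step, and the mean-absolute-value anomaly score are assumptions made for illustration; they are not code taken from this repository.

```python
import numpy as np


def apply_additive_map(image: np.ndarray, additive_map: np.ndarray) -> np.ndarray:
    """Translate an image toward 'healthy' by adding a voxel-wise change map.

    Assumption: intensities are normalized to [-1, 1], and clipping keeps the
    result in range (the bounding step used by Brainomaly itself may differ).
    """
    if image.shape != additive_map.shape:
        raise ValueError("image and additive map must have the same shape")
    return np.clip(image + additive_map, -1.0, 1.0)


def anomaly_score(additive_map: np.ndarray) -> float:
    """Hypothetical score: mean absolute change needed to make the image look healthy."""
    return float(np.mean(np.abs(additive_map)))


# Toy 2x2 "slice": a healthy input needs (almost) no change, a diseased one does.
image = np.array([[0.2, -0.1], [0.5, 0.9]])
healthy_map = np.zeros_like(image)
diseased_map = np.array([[0.0, 0.0], [-0.5, 0.25]])

print(anomaly_score(healthy_map))   # 0.0
print(anomaly_score(diseased_map))  # 0.1875
print(apply_additive_map(image, diseased_map))
```

Predicting a change map instead of a full image means a healthy input only needs a near-zero output, so the magnitude of the map itself can serve as a natural per-image detection signal.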
$ git clone https://github.com/mahfuzmohammad/Brainomaly.git
$ cd Brainomaly/
$ conda create -n brainomaly python=3.9
$ conda activate brainomaly
$ conda install scikit-learn scikit-image -c anaconda
$ pip install tqdm pandas
$ pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116
$ pip install neptune-client
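After installation, one quick way to confirm that the environment is complete is to check that each package can be located from Python. The script below is a small convenience sketch, not part of the repository; the only assumption is the mapping from install names to import names (for example, scikit-learn is imported as sklearn and neptune-client as neptune).

```python
import importlib.util

# Install name -> import name for the packages installed above.
REQUIRED_MODULES = {
    "scikit-learn": "sklearn",
    "scikit-image": "skimage",
    "tqdm": "tqdm",
    "pandas": "pandas",
    "torch": "torch",
    "torchvision": "torchvision",
    "torchaudio": "torchaudio",
    "neptune-client": "neptune",
}


def missing_packages(required=REQUIRED_MODULES):
    """Return the install names whose top-level modules cannot be found."""
    # find_spec only locates the module; it does not import it.
    return [pkg for pkg, module in required.items()
            if importlib.util.find_spec(module) is None]


if __name__ == "__main__":
    missing = missing_packages()
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All required packages are importable.")
```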
We use Neptune.ai to log training and validation metrics. To use Neptune.ai, create an account, obtain an API token, and set the token in the logger.py file.
- Alzheimer's Disease Detection Dataset (ADNI) [Data Access Instruction]
- Data splits: data/AD_DS1.csv, data/AD_DS2.csv
- Headache Detection Dataset (Mayo Clinic) (Private)
The folder structure should be as follows (assuming your dataset name is MedicalData):
├─data/MedicalData # data root
│ ├─train # directory for training data
│ │ ├─pos # positive class (diseased) images for unannotated mixed set
│ │ │ ├─xxx.png
│ │ │ ├─ ......
│ │ ├─neg_mixed # negative class (healthy) images for unannotated mixed set
│ │ │ ├─yyy.png
│ │ │ ├─ ......
│ │ ├─neg # negative class (healthy) images for known healthy set
│ │ │ ├─zzz.png
│ │ │ ├─ ......
│ ├─test # directory for testing data
│ │ ├─pos
│ │ │ ├─aaa.png
│ │ │ ├─ ......
│ │ ├─neg
│ │ │ ├─bbb.png
│ │ │ ├─ ......
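Before training, it can help to verify that a dataset folder matches the tree above. The following is a minimal sketch using only the standard library; the check_layout function and the EXPECTED_DIRS tuple are hypothetical helpers, not scripts shipped with the repository.

```python
from pathlib import Path

# Sub-folders expected under data/<DatasetName>, following the tree above.
EXPECTED_DIRS = ("train/pos", "train/neg_mixed", "train/neg", "test/pos", "test/neg")


def check_layout(root):
    """Return {subfolder: number of .png files}; raise if a subfolder is missing."""
    root = Path(root)
    counts = {}
    for sub in EXPECTED_DIRS:
        folder = root / sub
        if not folder.is_dir():
            raise FileNotFoundError(f"missing folder: {folder}")
        counts[sub] = len(list(folder.glob("*.png")))
    return counts


# Example usage (assumes the dataset lives in data/MedicalData):
# for sub, n in check_layout("data/MedicalData").items():
#     print(f"{sub}: {n} png files")
```

Reporting per-folder counts, rather than only pass/fail, makes it easy to spot an empty neg_mixed or test folder before a long training run.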
If your dataset is from a 2D modality such as X-ray, place the images as PNG files directly in the corresponding folders. If it is from a 3D modality such as MRI, store each volume slice-by-slice as PNG files, naming the slices xxx__000.png, xxx__001.png, xxx__002.png, ..., where xxx is the patient ID and 000, 001, 002, ... are the slice numbers. For example, a 3D image pat_001.nii.gz with 100 slices should be stored as pat_001__000.png, pat_001__001.png, ..., pat_001__099.png.
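Because slices from the same scan share the xxx prefix, per-patient grouping can be recovered from filenames alone. The sketch below is illustrative and not taken from the repository: it builds and parses names in this convention and groups slices by patient. It assumes three-digit zero padding as in the examples above; volumes with 1,000 or more slices would produce wider indices, which still parse correctly here because ordering uses the integer slice number.

```python
from collections import defaultdict


def slice_filename(patient_id: str, slice_idx: int) -> str:
    """Build a slice name such as pat_001__000.png (3-digit zero padding)."""
    return f"{patient_id}__{slice_idx:03d}.png"


def parse_slice_filename(name: str):
    """Split 'pat_001__042.png' into ('pat_001', 42).

    rsplit on the last '__' keeps IDs that themselves contain underscores intact.
    """
    stem = name.removesuffix(".png")  # str.removesuffix requires Python 3.9+
    patient_id, slice_part = stem.rsplit("__", 1)
    return patient_id, int(slice_part)


def group_by_patient(filenames):
    """Map each patient ID to its slice filenames, ordered by slice number."""
    groups = defaultdict(list)
    for name in filenames:
        patient_id, idx = parse_slice_filename(name)
        groups[patient_id].append((idx, name))
    return {pid: [n for _, n in sorted(items)] for pid, items in groups.items()}


# A 100-slice volume from pat_001.nii.gz becomes pat_001__000.png ... pat_001__099.png.
names = [slice_filename("pat_001", i) for i in range(100)]
```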
Note for custom data: Please adjust the image size and cropping according to your data.
- Training (assuming your data folder is data/MedicalData):
bash train.sh MedicalData
- Inductive testing:
bash test_inductive.sh MedicalData 400000
- Transductive testing:
bash test_transductive.sh MedicalData 400000
- AUCp calculation:
bash test_aucp.sh MedicalData 400000
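The AUCp step reflects the model-selection idea from the abstract: without labels, checkpoints can be compared by how well their anomaly scores separate known-healthy images from the unannotated mixed set. The pure-Python sketch below assumes that reading of AUCp, treating known-healthy images (train/neg) as negatives and unannotated mixed images as positives. The function names and the {iteration: (healthy_scores, mixed_scores)} input format are hypothetical, and test_aucp.sh may compute the metric differently.

```python
def auc(neg_scores, pos_scores):
    """ROC AUC via the Mann-Whitney statistic: P(pos > neg), ties count 1/2."""
    if not neg_scores or not pos_scores:
        raise ValueError("need at least one score in each group")
    wins = 0.0
    for p in pos_scores:
        for n in neg_scores:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos_scores) * len(neg_scores))


def pseudo_auc(healthy_scores, mixed_scores):
    """Assumed AUCp: known-healthy images as negatives, mixed images as positives."""
    return auc(healthy_scores, mixed_scores)


def select_checkpoint(scores_by_iteration):
    """Return the iteration whose anomaly scores give the highest pseudo-AUC.

    scores_by_iteration maps iteration -> (healthy_scores, mixed_scores);
    this input format is hypothetical.
    """
    return max(scores_by_iteration,
               key=lambda it: pseudo_auc(*scores_by_iteration[it]))


# Toy example with made-up anomaly scores for two checkpoints.
scores = {
    200000: ([0.1, 0.2, 0.3], [0.2, 0.25, 0.9]),
    400000: ([0.1, 0.2, 0.3], [0.35, 0.4, 0.9]),
}
best = select_checkpoint(scores)  # 400000: every mixed score exceeds every healthy one
```

The nested loop is quadratic and only meant to make the definition explicit. For real score lists, sklearn.metrics.roc_auc_score (scikit-learn is installed above) gives the same value when healthy images are labeled 0 and mixed images 1.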
Coming soon.