Thesis Title - Advancing Medical Image Segmentation Through Multi-Task and Multi-Scale Contrastive Knowledge Distillation
This is my master’s thesis, where I investigate the feasibility of knowledge transfer between neural networks for medical image segmentation tasks, specifically focusing on the transfer from a larger multi-task “Teacher” network to a smaller “Student” network using a multi-scale contrastive learning approach.
Below are a few quantitative and qualitative results. KD(T1, S1) and KD(T1, S2) are the results obtained from our proposed method. More detailed results and ablation studies can be found in the thesis.
The overall architecture of our multi-task multi-scale contrastive knowledge distillation framework for segmentation.
Representation of Contrastive Pairs. A beginner’s guide to Contrastive Learning can be found here.
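To make the idea concrete, here is a minimal, hypothetical InfoNCE-style sketch of how positive and negative pairs enter a contrastive loss; the function name, shapes, and temperature are illustrative assumptions, not the exact formulation used in the thesis:

```python
import torch
import torch.nn.functional as F

def info_nce_loss(anchor, positive, negatives, temperature=0.1):
    """InfoNCE: pull the anchor towards its positive, push it away from negatives.

    anchor:    (D,)   embedding of the anchor sample
    positive:  (D,)   embedding of its positive pair
    negatives: (N, D) embeddings of the negative pairs
    """
    anchor = F.normalize(anchor, dim=0)
    positive = F.normalize(positive, dim=0)
    negatives = F.normalize(negatives, dim=1)

    pos_sim = (anchor @ positive) / temperature           # scalar similarity
    neg_sim = (negatives @ anchor) / temperature          # (N,) similarities
    logits = torch.cat([pos_sim.unsqueeze(0), neg_sim])   # positive at index 0
    # Cross-entropy with label 0 maximises the positive pair's similarity
    # relative to all negatives.
    return F.cross_entropy(logits.unsqueeze(0), torch.zeros(1, dtype=torch.long))
```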
Teacher-Student Framework for Knowledge Distillation. A beginner’s guide to Knowledge Distillation can be found here.
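As a beginner-level illustration of the teacher-student idea (not the thesis method itself, which distils intermediate features contrastively), classic soft-target distillation looks roughly like this:

```python
import torch.nn.functional as F

def soft_target_kd_loss(student_logits, teacher_logits, T=4.0):
    """Classic knowledge distillation: match the student's softened
    predictions to the teacher's. T > 1 smooths both distributions."""
    p_teacher = F.softmax(teacher_logits / T, dim=1)
    log_p_student = F.log_softmax(student_logits / T, dim=1)
    # The T^2 factor keeps gradient magnitudes comparable across temperatures.
    return F.kl_div(log_p_student, p_teacher, reduction="batchmean") * (T * T)
```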
We trained two teacher models, T1 and T2: a multi-task pre-trained U-Net and a multi-task TransUNet, respectively.
The student model, a simplified version of the teacher, is significantly smaller in scale and is trained on only 50% of the data used for the teacher.
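Schematically, a multi-task teacher of this kind pairs one shared encoder with a segmentation head and a reconstruction head. The sketch below is a hypothetical outline only, with the concrete encoder and decoder modules assumed to be supplied elsewhere:

```python
import torch.nn as nn

class MultiTaskSegModel(nn.Module):
    """Hypothetical shape of a multi-task teacher: one shared encoder,
    one decoder head for segmentation, one for image reconstruction."""
    def __init__(self, encoder, seg_decoder, recon_decoder):
        super().__init__()
        self.encoder = encoder
        self.seg_decoder = seg_decoder
        self.recon_decoder = recon_decoder

    def forward(self, x):
        feats = self.encoder(x)  # shared (possibly multi-scale) features
        return self.seg_decoder(feats), self.recon_decoder(feats)
```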
The CT spleen segmentation dataset from the Medical Segmentation Decathlon is used for all the experiments. Below are the links to the processed 2D images from the CT spleen dataset -
Additionally, other binary segmentation datasets that can be explored are -
- DRIVE (Digital Retinal Images for Vessel Extraction)
- RITE (Retinal Images vessel Tree Extraction)
- ISIC Dataset
- Brain Tumor Dataset
- 2D Brain Tumor Segmentation Dataset
- Colorectal Polyp Segmentation Dataset
Other multi-class segmentation datasets that can be explored are -
- Synapse Multi-Organ CT Dataset
- ACDC Dataset
- AMOS Multi-Modality Abdominal Multi-Organ Segmentation Challenge
- BraTS 2022
git clone https://github.com/RisabBiswas/MTMS-Med-Seg-KD
cd MTMS-Med-Seg-KD
There are two options - either download the NIfTI files and convert them to 2D slices using the conversion script, or use the processed spleen dataset, which can be downloaded from the link above.
The data is already split into training and testing datasets.
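If you prefer to write your own conversion instead of using the provided script, a minimal sketch with nibabel and Pillow might look like the following; the slicing axis, normalisation, and file naming here are assumptions rather than the repository's exact behaviour:

```python
import os
import nibabel as nib
import numpy as np
from PIL import Image

def nifti_to_slices(nifti_path, out_dir):
    """Split a 3D NIfTI CT volume into per-slice 8-bit PNGs along the last axis."""
    volume = nib.load(nifti_path).get_fdata()  # typically (H, W, num_slices)
    os.makedirs(out_dir, exist_ok=True)
    for i in range(volume.shape[-1]):
        slc = volume[..., i]
        # Min-max normalise each slice to 0-255 for PNG output.
        slc = (slc - slc.min()) / (slc.max() - slc.min() + 1e-8) * 255.0
        Image.fromarray(slc.astype(np.uint8)).save(
            os.path.join(out_dir, f"slice_{i:04d}.png"))
```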
> Input CT Volume of Spleen Dataset -
> Processed 2D Slices -
Training the multi-task teacher network (T1 or T2) is straightforward. Now that you have created the data folders, follow the commands below to train the T1 model.
cd "Multi-Task Teacher Network (T1)"
or,
cd "Multi-Task Teacher Network (T2)"
Run the training script -
python train.py
You can experiment with different weight values for the reconstruction loss. Additionally, all experiments use DiceBCE as the loss function; you can also try other loss functions, such as Dice loss.
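For reference, a common DiceBCE formulation for binary segmentation is sketched below; the exact smoothing and weighting in this repository may differ:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DiceBCELoss(nn.Module):
    """Sum of binary cross-entropy and soft Dice loss, a common choice
    for binary segmentation."""
    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits, targets):
        bce = F.binary_cross_entropy_with_logits(logits, targets)
        probs = torch.sigmoid(logits).reshape(-1)
        targets = targets.reshape(-1)
        intersection = (probs * targets).sum()
        dice = (2.0 * intersection + self.smooth) / (
            probs.sum() + targets.sum() + self.smooth)
        return bce + (1.0 - dice)
```

In the multi-task setting, the total teacher objective would then combine this with the weighted reconstruction term, e.g. `loss = dice_bce(seg_pred, mask) + w * recon_loss(recon_pred, image)`, where `w` is the reconstruction weight mentioned above.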
The pre-trained weights can also be downloaded from below -
- T1 - Will be uploaded soon!
- T2 - Will be uploaded soon!
Once the teacher network is trained, run inference with the command below -
python inference.py
Also, you can look at the metrics by running the following -
python metrics.py
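The metrics script reports standard segmentation measures; for reference, Dice and IoU for binary masks can be computed as in this minimal sketch (function and argument names are illustrative):

```python
import numpy as np

def dice_and_iou(pred_mask, gt_mask, eps=1e-8):
    """Dice coefficient and IoU for binary masks (arrays of 0s and 1s)."""
    pred, gt = pred_mask.astype(bool), gt_mask.astype(bool)
    intersection = np.logical_and(pred, gt).sum()
    dice = (2.0 * intersection + eps) / (pred.sum() + gt.sum() + eps)
    iou = (intersection + eps) / (np.logical_or(pred, gt).sum() + eps)
    return dice, iou
```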
Before performing knowledge distillation and analysing its effect on the student model, we first train the student model on its own to see its performance without any knowledge transfer from the teacher network.
cd "Student Network (S1)"
Run the training script -
python train.py
Run the inference script -
python inference.py
Also, you can look at the metrics by running the following -
python metrics.py
The pre-trained weights can also be downloaded from below -
- S1 - Will be uploaded soon!
- S2 - Will be uploaded soon!
The steps to train the student model with contrastive knowledge distillation are similar and straightforward -
cd "KD_Student Network (T1-S1)"
Run the training script -
python train_Student.py
Run the inference script -
python inference.py
Also, you can look at the metrics by running the following -
python metrics.py
The knowledge distillation is performed at various scales, which can be customised in the training code.
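Conceptually, distilling at multiple scales means pairing teacher and student feature maps at matching encoder resolutions and penalising their disagreement scale by scale. The sketch below is a simplified, hypothetical version using a cosine-alignment term; in the actual method the per-scale term is contrastive (see the pairs illustration above), and the projector modules and scale weights here are assumptions:

```python
import torch.nn.functional as F

def multi_scale_kd_loss(student_feats, teacher_feats, projectors, scale_weights):
    """Align student features to teacher features at several encoder scales.

    student_feats / teacher_feats: lists of (B, C, H, W) feature maps
    projectors: e.g. 1x1 convs mapping student channels to teacher channels
    scale_weights: per-scale loss weights, tunable in the training code
    """
    total = 0.0
    for s, t, proj, w in zip(student_feats, teacher_feats,
                             projectors, scale_weights):
        s = proj(s)  # match the teacher's channel count
        if s.shape[-2:] != t.shape[-2:]:
            s = F.interpolate(s, size=t.shape[-2:], mode="bilinear",
                              align_corners=False)
        # Cosine alignment over channels; gradients do not flow
        # into the (frozen) teacher.
        total = total + w * (1.0 - F.cosine_similarity(
            s.flatten(2), t.detach().flatten(2), dim=1).mean())
    return total
```

Adjusting `scale_weights` (or dropping scales from the lists) is the kind of per-scale customisation the training code exposes.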
Currently, the architecture has only been tested on binary segmentation tasks, and there is still room for further exploration, such as -
- Experimenting with multi-class segmentation tasks.
- Trying other contrastive losses.
I extend my heartfelt gratitude to my guru 🙏🏻 Dr. Chaitanya Kaul for his visionary guidance and unwavering support throughout my project. His mentorship has significantly shaped me as a researcher and a better individual. I am profoundly grateful for his invaluable contributions to my professional and personal growth.
If you want to read the thesis, you can find it here. And if you like the project, we would appreciate a citation to the original work:
@misc{biswas2024multitask,
title={Multi-Task Multi-Scale Contrastive Knowledge Distillation for Efficient Medical Image Segmentation},
author={Risab Biswas},
year={2024},
eprint={2406.03173},
archivePrefix={arXiv},
primaryClass={eess.IV}
}
If you have any questions, please feel free to reach out to Risab Biswas.
I appreciate your interest in my research. The code should not have any bugs, but if you find any, we are sorry about that. Do let us know in the issues section, and we will fix it ASAP! Cheers!