This is the official code for the paper "Aggregated Attributions for Explanatory Analysis of 3D Segmentation Models," accepted in the first round at the IEEE/CVF Winter Conference on Applications of Computer Vision (WACV) 2025.
Analysis of 3D segmentation models, especially in the context of medical imaging, is often limited to segmentation performance metrics that overlook the crucial aspect of explainability and bias. Currently, effectively explaining these models with saliency maps is challenging due to the high dimensions of input images multiplied by the ever-growing number of segmented class labels. To this end, we introduce Agg$^2$Exp, a methodology for aggregating fine-grained voxel attributions of the segmentation model's predictions. Unlike classical explanation methods that primarily focus on local feature attribution, Agg$^2$Exp enables a more comprehensive global view of the importance of predicted segments in 3D images. Our benchmarking experiments show that gradient-based voxel attributions are more faithful to the model's predictions than perturbation-based explanations. As a concrete use case, we apply Agg$^2$Exp to discover knowledge acquired by the Swin UNEt TRansformer model trained on the TotalSegmentator v2 dataset for segmenting anatomical structures in computed tomography medical images. Agg$^2$Exp facilitates the explanatory analysis of large segmentation models beyond their predictive performance.
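At its core, Agg$^2$Exp reduces a dense voxel-attribution map to one score per segment. The snippet below is a minimal illustrative sketch of that idea, assuming mean absolute attribution as the aggregation function; the repository's actual aggregation code lives in `explanations/aggregate_explanations.py` and may differ:

```python
import numpy as np

def aggregate_attributions(attribution, segment_masks):
    """Reduce a voxel-level attribution map to one score per segment.

    attribution:   (D, H, W) array of voxel attributions for one target class
    segment_masks: dict mapping segment name -> (D, H, W) boolean mask
    Returns a dict mapping segment name -> mean absolute attribution inside it.
    """
    scores = {}
    for name, mask in segment_masks.items():
        voxels = np.abs(attribution[mask])
        scores[name] = float(voxels.mean()) if voxels.size else 0.0
    return scores

# Toy example: random attributions over an 8x8x8 volume with two "organ" masks.
rng = np.random.default_rng(0)
attr = rng.normal(size=(8, 8, 8))
masks = {
    "lung_left": np.zeros((8, 8, 8), dtype=bool),
    "heart": np.zeros((8, 8, 8), dtype=bool),
}
masks["lung_left"][:4] = True
masks["heart"][4:] = True
print(aggregate_attributions(attr, masks))
```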
The model used in our experiments was trained on selected and merged thorax classes from the TotalSegmentator v2 (TSV2) dataset.
Code used for training and inference of our Swin UNETR is available in the `model` folder, with the following files/folders:

- `model/data_loader.py` - data loading and preprocessing
- `model/inference.py` - inference script (see the sketch after this list)
- `model/models` - folder with model definitions
- `model/train.py` - training script
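The actual inference pipeline is in `model/inference.py`; the sketch below only illustrates the typical shape of such a pipeline using MONAI's `sliding_window_inference`. The placeholder `nn.Conv3d` network, `NUM_CLASSES`, the 96³ ROI size, and the overlap value are all assumptions, not the repository's configuration (in the repository, the trained Swin UNETR is built from `model/models`):

```python
import torch
import torch.nn as nn
from monai.inferers import sliding_window_inference

NUM_CLASSES = 15  # assumption: number of merged thorax classes incl. background

# Placeholder network standing in for the trained Swin UNETR checkpoint.
model = nn.Conv3d(1, NUM_CLASSES, kernel_size=3, padding=1)
model.eval()

# A preprocessed CT volume of shape (batch, channel, D, H, W).
volume = torch.randn(1, 1, 128, 128, 128)

with torch.no_grad():
    # Tile the volume into overlapping patches, run the model on each patch,
    # and stitch the per-patch logits back into a full-volume prediction.
    logits = sliding_window_inference(
        inputs=volume,
        roi_size=(96, 96, 96),  # assumed patch size
        sw_batch_size=1,
        predictor=model,
        overlap=0.5,
    )

segmentation = logits.argmax(dim=1)  # per-voxel class labels
print(segmentation.shape)  # torch.Size([1, 128, 128, 128])
```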
Code used for explanations, aggregations, and visualizations is available in the `explanations` folder.

Files/folders:

- `explanations/aggregate_explanations_custom_masks.py` - aggregates explanations for custom masks
- `explanations/aggregate_explanations.py` - aggregates explanations for all segmentation classes
- `explanations/analyse_tcsnet.R` - R code for visualizations
- `explanations/attribution_functions.py` - functions for generating attributions (see the sketch after this list)
- `explanations/attributions_evaluation` - folder with code for quantitative evaluation of explanations
- `explanations/generate_attributions_for_tsv2.py` - generates gradient-based attributions for TSV2
- `explanations/generate_b50_example_for_all_explanations.py` - creates example explanations for all methods
- `explanations/generate_example_explanations.py` - generates example explanations
- `explanations/generate_kernelshap.py` - generates attributions for KernelSHAP
- `explanations/generation_utils.py` - utility functions for generating explanations
- `explanations/get_tsv2_examples_with_all_organs.py` - finds files with all organs in TSV2
- `explanations/kernelshap_utils.py` - utility functions for KernelSHAP
- `explanations/sliding_window_gradient_inference.py` - modification of sliding window inference for gradient-based attributions
- `explanations/tsv2_train_outlier_analysis.ipynb` - notebook for outlier analysis
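Conceptually, generating a voxel attribution for a segmentation model requires reducing the dense output to a scalar target before backpropagating. As a hedged sketch of this idea (not the repository's exact method), the snippet below computes a vanilla-gradient attribution for one class on a single patch by summing that class's logits; `explanations/attribution_functions.py` implements the actual attribution methods, and `explanations/sliding_window_gradient_inference.py` extends this to whole volumes via sliding windows. The tiny `nn.Conv3d` model and `TARGET_CLASS` are placeholders:

```python
import torch
import torch.nn as nn

# Placeholder 3D segmentation network standing in for the trained Swin UNETR.
model = nn.Conv3d(in_channels=1, out_channels=3, kernel_size=3, padding=1)
model.eval()

TARGET_CLASS = 1  # hypothetical index of the organ class being explained

# One input patch; gradients w.r.t. its voxels form the attribution map.
patch = torch.randn(1, 1, 32, 32, 32, requires_grad=True)

# Scalar target: total logit mass of the target class over the patch.
logits = model(patch)
score = logits[:, TARGET_CLASS].sum()
score.backward()

# Voxel-level attribution with the same shape as the input patch.
attribution = patch.grad.detach().abs()
print(attribution.shape)  # torch.Size([1, 1, 32, 32, 32])
```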
Example aggregated explanations are available in the `data/tsv2_test_aggregated_sg_explanations_with_dice.csv` file.
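To browse these precomputed aggregations, a plain pandas read is enough (the column names are not documented here, so inspect them first):

```python
import pandas as pd

df = pd.read_csv("data/tsv2_test_aggregated_sg_explanations_with_dice.csv")
print(df.columns)  # inspect the available aggregation and Dice columns
print(df.head())
```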
The arXiv preprint can be found [here](https://arxiv.org/abs/2407.16653).
If you find this repository useful, please consider citing this paper:
```bibtex
@article{chrabaszcz2024agg2exp,
  title={Aggregated Attributions for Explanatory Analysis of 3D Segmentation Models},
  author={Maciej Chrabaszcz and Hubert Baniecki and Piotr Komorowski and Szymon Płotka and Przemyslaw Biecek},
  year={2024},
  eprint={2407.16653},
  archivePrefix={arXiv},
  primaryClass={cs.CV},
  url={https://arxiv.org/abs/2407.16653},
}
```
This work was financially supported by the Polish National Center for Research and Development grant number INFOSTRATEG-I/0022/2021-00. We thank Mateusz Krzyzinski and Paulina Tomaszewska for valuable feedback on the initial version of this work.