When More is Less: Incorporating Additional Datasets Can Hurt Performance By Introducing Spurious Correlations
Code for the paper "When More is Less: Incorporating Additional Datasets Can Hurt Performance By Introducing Spurious Correlations" by Rhys Compton, Lily Zhang, Aahlad Puli, and Rajesh Ranganath
The training harness is heavily based on the excellent ClinicalDG repo, which is in turn a modified version of DomainBed.
Run the following commands to clone this repo and create the Conda environment:
See DataSources.md for detailed instructions.
Experiments can be run using the same procedure as for the DomainBed framework, with a few additional adjustable data hyperparameters, which should be passed in as a JSON-formatted dictionary.
For example, to train a single model:
```bash
python -m clinicaldg.scripts.train \
    --algorithm ERM \
    --dataset eICUSubsampleUnobs \
    --es_method val \
    --hparams '{"eicu_architecture": "GRU", "eicu_subsample_g1_mean": 0.5, "eicu_subsample_g2_mean": 0.05}' \
    --output_dir /path/to/output
```
A detailed list of hparams available for each dataset can be found here.
We provide the bash scripts used for our main experiments in the `bash_scripts` directory. You will likely need to customize them, along with the launcher, to your compute environment.
This codebase heavily utilises Weights & Biases (W&B) to run experiments, both for tracking and recording results and for launching experiments via the Sweeps feature (alongside Slurm arrays).
The process for this is as follows:
- Define your sweep hyperparameters in a `.yaml` file (e.g., this YAML file for image size experimentation); an illustrative sweep config is sketched just after this list
- Start the sweep: `wandb sweep <yaml filename>`
- Create your Slurm array script (e.g., `sweep.sbatch`) -- the key parameter is `--array=`, which should be set to `0-<num parameter configs>`
- Start the Slurm array job: `sbatch sweeps/sweep.sbatch`
- Sit back and watch GPUs go brrrr
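For reference, the same kind of sweep can also be defined programmatically. The snippet below is only an illustrative sketch of a grid sweep over image size using the standard W&B Python API; the project name and parameter names are placeholders, not necessarily those used in the repo's `.yaml` files.

```python
import wandb

# Hypothetical equivalent of a sweep .yaml file: a grid sweep over image size.
# Keys follow the standard W&B sweep schema; the values are placeholders.
sweep_config = {
    "program": "clinicaldg/scripts/train.py",
    "method": "grid",
    "parameters": {
        "img_size": {"values": [224, 512]},
        "algorithm": {"values": ["ERM"]},
    },
}

# Registers the sweep and returns its ID (same effect as `wandb sweep <yaml>`).
sweep_id = wandb.sweep(sweep=sweep_config, project="more-is-less")
print(sweep_id)
```

In a setup like this, each Slurm array task would then run `wandb agent <sweep_id>` to pull one parameter configuration from the sweep.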
The following steps walk through the process for loading datasets and applying hospital-label balancing:
- Example wrapper script: `sweeps/sweep.sbatch` (this contains paths to the dataset files / loading them via `singularity`)
- Main entrypoint: `clinicaldg/scripts/train.py`
- Corresponding W&B `.yaml` file for hospital-label balancing: `sweeps/2d_nurd_fix.yaml`
- The key parameter is `--balance_resample "label_notest,under"`, which does label balancing with no target test hospital (i.e., label-balances the two hospitals to each other), via undersampling. `"hospital_label,under"` would be a more appropriate name, in hindsight. (A rough sketch of this balancing idea follows this list.)
- The `--test_env` flag is not used but is just provided as a placeholder
- This file gives a range of example invocations of the script; e.g., the following will train a model for Pneumonia prediction on MIMIC and CXP, balancing them such that `P(Y = 1 | Hospital = MIMIC) == P(Y = 1 | Hospital = CXP)`:

  ```bash
  python ./clinicaldg/scripts/train.py --max_steps 20001 --train_envs MIMIC,CXP --test_env MIMIC --balance_resample "label_notest,under" --binary_label Pneumonia
  ```
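Before digging into the code path, here is a minimal, hypothetical sketch of what "label balancing via undersampling" means in this context: equalising `P(Y = 1 | Hospital)` across the per-hospital DataFrames. The function name, column names, and exact strategy below are illustrative assumptions, not the repo's API; the actual implementation lives in `clinicaldg/datasets.py` (pointers below).

```python
import pandas as pd

def undersample_to_common_rate(dfs, label_col="Pneumonia", seed=0):
    """Toy illustration of hospital-label balancing: undersample positives in
    each hospital's DataFrame so every hospital ends up with (roughly) the
    same P(label == 1). This is NOT the repo's implementation."""
    target = min(df[label_col].mean() for df in dfs)  # lowest positive rate
    balanced = []
    for df in dfs:
        pos, neg = df[df[label_col] == 1], df[df[label_col] == 0]
        # Keep all negatives; keep just enough positives to hit the target rate.
        n_pos = int(round(target * len(neg) / (1 - target)))
        keep = pd.concat([neg, pos.sample(n=min(n_pos, len(pos)), random_state=seed)])
        balanced.append(keep.sample(frac=1, random_state=seed))  # shuffle rows
    return balanced

# Hypothetical usage with per-hospital DataFrames of file paths + labels:
# mimic_bal, cxp_bal = undersample_to_common_rate([mimic_df, cxp_df])
```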
The codebase is very generalized (it is used for both the eICU task and CXR classification), so there is a lot of code to sift through. The path to dataset loading (and ultimately balancing) starts in `train.py` and goes through:

- `dataset = ds_class(hparams, args)`, which invokes...
- `CXRBase.__init__()` (in `clinicaldg/datasets.py`)
- Lines 411 to 447 in `clinicaldg/datasets.py`, which do the actual hospital-label balancing
If you want to use this data within another codebase, you could save the train/val/test DataFrames to CSV files after processing is finished, i.e., at line 496; these contain the file paths / labels and can be easily loaded into another Python project.
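As a sketch (the `train_df` / `val_df` / `test_df` names and output paths are assumptions about what is in scope at that point, not the repo's exact variables):

```python
import pandas as pd

def export_splits(splits, out_dir="."):
    """Dump each processed split (image paths + labels) to CSV so another
    project can reuse them without this codebase."""
    for name, df in splits.items():
        df.to_csv(f"{out_dir}/cxr_{name}_split.csv", index=False)

# Hypothetical usage at that point in datasets.py, assuming the splits are
# held in DataFrames named train_df / val_df / test_df:
# export_splits({"train": train_df, "val": val_df, "test": test_df})

# ...and in the other codebase:
# train_df = pd.read_csv("./cxr_train_split.csv")
```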
The PyTorch dataset class can be found in `clinicaldg/cxr/data.py`, along with the `get_dataset()` function.
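If you go the exported-CSV route instead of reusing the repo's classes, a minimal stand-in `Dataset` might look like the following; the column names (`path`, `Pneumonia`) are assumptions, not necessarily what the repo writes out.

```python
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset

class CXRCsvDataset(Dataset):
    """Minimal stand-in for the repo's dataset class: reads (image path, label)
    rows from an exported CSV. Column names are assumptions."""
    def __init__(self, csv_path, transform=None):
        self.df = pd.read_csv(csv_path)
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img = Image.open(row["path"]).convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img, torch.tensor(row["Pneumonia"], dtype=torch.float32)
```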
The other key class is the `InfiniteDataLoader` used in `train.py` (line 303). One of these is created for each hospital we're training on, and equal-sized batches are sampled from each at every iteration.
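As a rough sketch of that idea (not the repo's or DomainBed's actual implementation, and with hypothetical variable names), an "infinite" loader simply restarts its underlying iterator so training can be driven by a step count rather than by epochs:

```python
from torch.utils.data import DataLoader

class InfiniteLoader:
    """Rough sketch of the idea behind the repo's InfiniteDataLoader: a loader
    that restarts itself instead of raising StopIteration, so training can run
    for a fixed number of steps rather than whole epochs."""
    def __init__(self, dataset, batch_size):
        self.loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        self._it = iter(self.loader)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self._it)
        except StopIteration:          # epoch finished: reshuffle and continue
            self._it = iter(self.loader)
            return next(self._it)

# One loader per training hospital; every step draws an equal-sized batch
# from each (hypothetical names):
# loaders = [InfiniteLoader(ds, batch_size=32) for ds in hospital_datasets]
# for step in range(max_steps):
#     minibatches = [next(loader) for loader in loaders]
```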