Variational Information Pursuit for Interpretable Predictions

Aditya Chattopadhyay, Kwan Ho Ryan Chan, Benjamin D. Haeffele, Donald Geman, René Vidal
Mathematical Institute for Data Science, Johns Hopkins University
{achatto1, kchan49, bhaeffele, geman, rvidal}@jhu.edu

This is the official repository for Variational Information Pursuit for Interpretable Predictions (ICLR 2023). For our paper, please visit link.

Overview

teaser.png

There is growing interest in the machine learning community in developing predictive algorithms that are “interpretable by design”. To this end, recent work proposes to make interpretable decisions by sequentially asking interpretable queries about the data until a prediction can be made with high confidence based on the answers obtained (the history). To promote short query-answer chains, a greedy procedure called Information Pursuit (IP) is used, which adaptively chooses queries in order of information gain (see the figure above). Generative models are employed to learn the distribution of query-answers and labels, which is in turn used to estimate the most informative query. However, learning and inference with a full generative model of the data is often intractable for complex tasks. In this work, we propose Variational Information Pursuit (V-IP), a variational characterization of IP that bypasses the need to learn generative models. V-IP is based on finding a query selection strategy and a classifier that minimize the expected cross-entropy between true and predicted labels. We then demonstrate that the IP strategy is the optimal solution to this problem. Therefore, instead of learning generative models, we can use our optimal strategy to directly pick the most informative query given any history. We then develop a practical algorithm by defining a finite-dimensional parameterization of our strategy and classifier using deep networks and training them end-to-end using our objective. A pipeline of our framework is shown below.

pipeline
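
To make the greedy criterion described in the overview concrete, here is a minimal sketch of classical Information Pursuit on a tiny, made-up discrete problem. It is not code from this repository: the sample table, the function names, and the stopping threshold are all assumptions chosen for illustration. It computes information gain exactly by enumeration, which is precisely the step that becomes intractable on real data and that V-IP replaces with a learned querier network.

```python
import math

# Hypothetical toy problem: a binary label and three binary queries.
# Each entry is (label, answers); all entries are assumed equally likely.
SAMPLES = [
    (0, (0, 0, 0)),
    (0, (0, 1, 1)),
    (1, (1, 0, 1)),
    (1, (1, 1, 0)),
]


def label_entropy(samples):
    """Entropy (in bits) of the label among the given samples."""
    counts = {}
    for label, _ in samples:
        counts[label] = counts.get(label, 0) + 1
    total = sum(counts.values())
    return -sum(c / total * math.log2(c / total) for c in counts.values())


def information_gain(samples, q):
    """H(Y | history) - E_a[H(Y | history, q(X) = a)] for query index q."""
    h_after = 0.0
    for a in {answers[q] for _, answers in samples}:
        subset = [s for s in samples if s[1][q] == a]
        h_after += len(subset) / len(samples) * label_entropy(subset)
    return label_entropy(samples) - h_after


def information_pursuit(x_answers, samples=SAMPLES, eps=1e-9):
    """Greedily ask the most informative query until the label is certain."""
    history = []
    consistent = list(samples)  # samples that agree with the history so far
    remaining = set(range(len(x_answers)))
    while remaining and label_entropy(consistent) > eps:
        q = max(sorted(remaining), key=lambda j: information_gain(consistent, j))
        a = x_answers[q]
        history.append((q, a))
        consistent = [s for s in consistent if s[1][q] == a]
        remaining.remove(q)
    return history, consistent


history, consistent = information_pursuit((1, 0, 1))
print(history)           # query-answer chain that was asked
print(consistent[0][0])  # label implied by the remaining samples
```

In this toy table only the first query carries any information about the label, so the chain stops after a single question. On real data the conditional distributions are not available in closed form, which is the motivation for the variational approach.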

Requirements

Please check out requirements.txt for detailed requirements. Overall, our code uses basic operations and does not require the latest version of PyTorch or CUDA to work. We also use wandb to monitor training and testing performance. One may remove the lines related to wandb and switch to another package if desired.

Training MNIST

There are two stages of training: Initial Random Sampling (IRS) and Subsequent Biased Sampling (SBS).

To run IRS:

python3 main_mnist.py \
  --epochs 100 \
  --data mnist \
  --batch_size 128 \
  --max_queries 676 \
  --max_queries_test 21 \
  --lr 0.0001 \
  --tau_start 1.0 \
  --tau_end 0.2 \
  --sampling random \
  --seed 0 \
  --name mnist_random
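
The --tau_start and --tau_end flags suggest a softmax temperature that is annealed over training. The sketch below shows one common way to do this, assuming a linear per-epoch schedule. The function names are hypothetical, and the schedule actually used in main_mnist.py may differ.

```python
import math


def linear_tau(epoch, epochs, tau_start, tau_end):
    """Assumed linear schedule: tau_start at epoch 0, tau_end at the last epoch."""
    if epochs <= 1:
        return tau_end
    frac = epoch / (epochs - 1)
    return tau_start + frac * (tau_end - tau_start)


def softmax_with_temperature(scores, tau):
    """Softmax over query scores; a smaller tau gives a sharper distribution."""
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp((s - m) / tau) for s in scores]
    z = sum(exps)
    return [e / z for e in exps]


scores = [2.0, 1.0, 0.0]  # made-up scores for three candidate queries
for epoch in (0, 50, 99):
    tau = linear_tau(epoch, epochs=100, tau_start=1.0, tau_end=0.2)
    print(epoch, round(tau, 3), softmax_with_temperature(scores, tau))
```

Under this assumption, query selection starts out soft and becomes closer to a hard arg-max as training proceeds.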

To run SBS:

python3 main_mnist.py \
  --epochs 20 \
  --data mnist \
  --batch_size 128 \
  --max_queries 21 \
  --max_queries_test 21 \
  --lr 0.0001 \
  --tau_start 0.2 \
  --tau_end 0.2 \
  --sampling biased \
  --seed 0 \
  --ckpt_path <CKPT_PATH> \
  --name mnist_biased

where <CKPT_PATH> is the path to the pre-trained model using IRS.
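
As a rough illustration of how the two stages might differ, the sketch below builds a training history (a binary mask over the query set) in two ways: uniformly at random for IRS, and by repeatedly asking a querier for SBS. This is an interpretation of the stage names and the --sampling flag, not the repository's implementation. The function names and the stand-in querier are hypothetical.

```python
import random


def sample_random_history(num_queries, max_queries, rng):
    """IRS-style (assumed): reveal a random number of uniformly chosen queries."""
    k = rng.randint(0, min(max_queries, num_queries))
    mask = [0] * num_queries
    for q in rng.sample(range(num_queries), k):
        mask[q] = 1
    return mask


def sample_biased_history(num_queries, max_queries, querier, rng):
    """SBS-style (assumed): let the querier pick each next query in turn."""
    k = rng.randint(0, min(max_queries, num_queries))
    mask = [0] * num_queries
    for _ in range(k):
        mask[querier(mask)] = 1
    return mask


def first_unasked(mask):
    """Stand-in for a trained querier: returns the lowest-index unasked query."""
    return mask.index(0)


rng = random.Random(0)
irs_mask = sample_random_history(num_queries=676, max_queries=676, rng=rng)
sbs_mask = sample_biased_history(676, 21, first_unasked, rng)
print(sum(irs_mask), sum(sbs_mask))  # number of revealed queries in each history
```

Read this way, IRS exposes the classifier and querier to arbitrary histories, while SBS concentrates training on the short histories the querier itself tends to produce, which would be consistent with the smaller --max_queries value in the SBS command.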

Checkpoints

Checkpoints for the models used to obtain the results in our paper are listed in the table below. A Jupyter notebook named loading.ipynb, with checkpoint-loading instructions for each dataset, is located in pretrain/. One may put downloaded models in this directory.

| Dataset | OneDrive Link |
| --- | --- |
| MNIST | Link |
| KMNIST | Link |
| Fashion MNIST | Link |
| Huffington News | Link |
| Huffington News (cleaned) | Link |
| CUB-200 | Link |
| CUB-200 (concept) | Link |
| CIFAR10 | Link |
| SymCAT200 | Link |
| SymCAT300 | Link |
| SymCAT400 | Link |

License

This project is under the MIT License. See LICENSE for details.

Cite

If you find our work useful for your research, please cite:

@article{chattopadhyay2023variational,
  title={Variational Information Pursuit for Interpretable Predictions},
  author={Chattopadhyay, Aditya and Chan, Kwan Ho Ryan and Haeffele, Benjamin D and Geman, Donald and Vidal, Ren{\'e}},
  journal={arXiv preprint arXiv:2302.02876},
  year={2023}
}