This repository contains the official PyTorch implementation of the following paper:
Calibration of Few-Shot Classification Tasks: Mitigating Misconfidence from Distribution Mismatch, by Sungnyun Kim and Se-Young Yun, IEEE Access, vol. 10, 2022, doi:10.1109/ACCESS.2022.3176090.
Paper: https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=9777704
Abstract: As many meta-learning algorithms improve performance in solving few-shot classification problems for practical applications, the accurate prediction of uncertainty is considered essential. In meta-training, the algorithm treats all generated tasks equally and updates the model to perform well on training tasks. During the training, some of the tasks may make it difficult for the model to infer the query examples from the support examples, especially when a large mismatch between the support set and the query set exists. The distribution mismatch causes the model to have incorrect confidence, which causes a calibration problem. In this study, we propose a novel meta-training method that measures the distribution mismatch and enables the model to predict with more precise confidence. Moreover, our method is algorithm-agnostic and can be readily expanded to include a range of meta-learning models. Through extensive experimentation, including dataset shift, we show that our training strategy prevents the model from becoming indiscriminately confident, and thereby helps the model to produce calibrated classification results without the loss of accuracy.
Our code works on torch>=1.5. Install the required Python packages via
pip install -r requirements.txt
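To confirm that the installed PyTorch satisfies the torch>=1.5 requirement, a small check along the following lines may help. This is a minimal sketch and not part of the repository; the helper name `meets_minimum` is hypothetical, and it compares only the major and minor components of the version string.

```python
import re


def meets_minimum(version, minimum=(1, 5)):
    """Return True if a version string such as '1.13.1+cu117' is >= minimum.

    `meets_minimum` is an illustrative helper, not a function from this repo.
    """
    # Read the leading "major.minor" pair; local tags like '+cu117' are ignored.
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    major, minor = int(match.group(1)), int(match.group(2))
    # Tuple comparison handles cases such as 1.13 vs. 1.5 correctly.
    return (major, minor) >= tuple(minimum)
```

With PyTorch installed, `meets_minimum(torch.__version__)` is expected to return True after `import torch`.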
Make sure that the dataset directories are correctly specified in path_configs.py, and that your dataset directories contain the correct files.
In ./filelists/miniImagenet/, there is a shell script to set up the dataset files.
bash download_miniImagenet.sh
In ./filelists/CUB/, there is a shell script to set up the dataset files.
bash download_CUB.sh
In ./filelists/tieredImagenet/, there is a Python file that contains instructions on how to download and place the dataset.
After following them, run the Python file.
python write_tieredImagenet_filelist.py
The main scripts are train.py (training) and test.py (evaluation). To see all CLI arguments, refer to utils.py.
The following commands reproduce the results.
- miniImageNet 5-way 5-shot TCMAML
python train.py --dataset miniImagenet
- miniImageNet 5-way 5-shot TCMAML (LS)
python train.py --dataset miniImagenet --linear-scaling
- miniImageNet 5-way 5-shot TCProtoNet
python train.py --dataset miniImagenet --method tcproto
- miniImageNet 5-way 1-shot TCMAML
python train.py --dataset miniImagenet --n-shot 1
- CUB 5-way 5-shot TCMAML
python train.py --dataset CUB
- miniImageNet -> CUB (dataset shift) 5-way 5-shot TCMAML
python train.py --dataset cross --model ResNet18
- miniImageNet 10-way 5-shot with Conv6 backbone TCMAML
python train.py --dataset miniImagenet --model Conv6 --train-n-way 10 --test-n-way 10
Other important arguments include --temp, --stop-epoch, and --corrupted-task.
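Because the commands above differ only in a few flags, they can be generated from a small helper when scripting several runs. The following is a sketch under stated assumptions: `build_command` is a hypothetical helper that is not part of this repository, and it only formats a command string from the options it is given. It does not validate flag names against utils.py.

```python
import shlex


def build_command(script="train.py", **options):
    """Assemble a command line for train.py or test.py from keyword options.

    Hypothetical helper for illustration. Keyword names use underscores
    (n_shot -> --n-shot). A value of True emits a bare flag; False or None
    omits the option entirely.
    """
    parts = ["python", script]
    for key, value in options.items():
        flag = "--" + key.replace("_", "-")
        if value is True:
            parts.append(flag)  # bare flag, e.g. --linear-scaling
        elif value is False or value is None:
            continue  # option not requested
        else:
            parts.extend([flag, str(value)])
    # Quote each token so the result is safe to paste into a shell.
    return " ".join(shlex.quote(part) for part in parts)


# Example: the 10-way Conv6 setting listed above.
print(build_command(dataset="miniImagenet", model="Conv6",
                    train_n_way=10, test_n_way=10))
```

Passing the same keyword options with `script="test.py"` yields the matching evaluation command.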
To evaluate and check the calibration results after training is done, run test.py with the same arguments.
For example, after training miniImageNet 5-way 1-shot TCProtoNet, run
python test.py --dataset miniImagenet --n-shot 1 --method tcproto
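For readers unfamiliar with calibration metrics, the sketch below computes the expected calibration error (ECE), a widely used measure of the gap between confidence and accuracy. Treating ECE as representative of what test.py reports is an assumption. The function name, the equal-width binning, and the default of 10 bins are illustrative choices and are not taken from this repository.

```python
def expected_calibration_error(confidences, correct, n_bins=10):
    """Bin-size-weighted mean of |accuracy - confidence| over equal-width bins.

    Illustrative sketch, not the repository's evaluation code.
    confidences: predicted probabilities of the chosen class, each in [0, 1].
    correct: 1/0 (or True/False) flags, one per prediction.
    """
    if len(confidences) != len(correct):
        raise ValueError("confidences and correct must have the same length")
    total = len(confidences)
    if total == 0:
        return 0.0
    bins = [[] for _ in range(n_bins)]
    for conf, hit in zip(confidences, correct):
        # Bins are [k/n_bins, (k+1)/n_bins); confidence 1.0 goes to the last bin.
        index = min(int(conf * n_bins), n_bins - 1)
        bins[index].append((conf, float(hit)))
    ece = 0.0
    for members in bins:
        if not members:
            continue  # empty bins contribute nothing
        avg_conf = sum(c for c, _ in members) / len(members)
        accuracy = sum(h for _, h in members) / len(members)
        ece += (len(members) / total) * abs(accuracy - avg_conf)
    return ece
```

A value near zero means that confidence tracks accuracy closely, while an indiscriminately confident model (high confidence, lower accuracy) yields a larger value.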
If you find this repo useful for your research, please consider citing our paper:
@article{kim2022calibration,
title={Calibration of Few-Shot Classification Tasks: Mitigating Misconfidence from Distribution Mismatch},
author={Kim, Sungnyun and Yun, Se-Young},
journal={IEEE Access},
year={2022},
publisher={IEEE}
}
Distributed under the MIT License.