This repository is the official PyTorch implementation of "To Smooth or Not? When Label Smoothing Meets Noisy Labels", accepted by ICML 2022 (Oral).
```python
import torch
import torch.nn.functional as F

def loss_gls(logits, labels, smooth_rate=0.1):
    # logits: model predictions before the softmax, with size [batch_size, num_classes]
    # labels: the (noisy) labels, with size [batch_size]
    # smooth_rate: positive for PLS, negative for NLS;
    # candidates adopted in the paper: [0.8, 0.6, 0.4, 0.2, 0.0, -0.2, -0.4, -0.6, -0.8, -1.0, -2.0, -4.0, -6.0, -8.0]
    confidence = 1. - smooth_rate
    logprobs = F.log_softmax(logits, dim=-1)
    # negative log-likelihood of the given label for each sample
    nll_loss = -logprobs.gather(dim=-1, index=labels.unsqueeze(1)).squeeze(1)
    # uniform smoothing term: average negative log-probability over all classes
    smooth_loss = -logprobs.mean(dim=-1)
    loss = confidence * nll_loss + smooth_rate * smooth_loss
    # average over the batch
    return loss.mean()
```
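As a quick sanity check, the plug-in can be called on random tensors (a minimal sketch; the shapes follow the comments above):

```python
logits = torch.randn(32, 10)               # batch of 32 samples, 10 classes
labels = torch.randint(0, 10, (32,))       # (noisy) integer labels
print(loss_gls(logits, labels, smooth_rate=0.4))   # PLS: positive smooth rate
print(loss_gls(logits, labels, smooth_rate=-0.4))  # NLS: negative smooth rate
```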
We recommend building a virtual environment and installing the required packages from `requirements.txt`.
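For example, with the standard `venv` module (a minimal sketch; the environment name `gls-env` is arbitrary):

```bash
python3 -m venv gls-env          # create the environment (name is arbitrary)
source gls-env/bin/activate      # activate it (Linux/macOS)
pip install -r requirements.txt  # install the repository's dependencies
```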
For Vanilla Loss and PLS, direct training works better when learning with symmetric noisy labels under noise rate 0.2. Run the command below to reproduce our results:

```bash
CUDA_VISIBLE_DEVICES=0 python3 main_GLS_direct_train.py --noise_type symmetric --noise_rate 0.2
```
When noise rates are large, warming up with CE loss helps PLS and NLS reach better performance. Run the command below to generate the warm-up model:

```bash
CUDA_VISIBLE_DEVICES=0 python3 main_warmup.py --noise_type symmetric --noise_rate 0.2
```
After warming up, proceed with GLS:

```bash
CUDA_VISIBLE_DEVICES=0 python3 main_GLS_load.py --noise_type symmetric --noise_rate 0.2
```
To reproduce results on CIFAR-N, you may refer to the "CIFAR-N GitHub Page" and modify its `loss.py` by referring to the `loss_gls` plug-in implementation above.
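As a rough illustration of such a plug-in (a hypothetical sketch; the actual class and method names in the CIFAR-N `loss.py` may differ):

```python
import torch

# Hypothetical nn.Module wrapper around loss_gls, in the style of the
# loss classes such a loss.py typically defines.
class GLSLoss(torch.nn.Module):
    def __init__(self, smooth_rate=0.1):
        super().__init__()
        self.smooth_rate = smooth_rate

    def forward(self, logits, labels):
        return loss_gls(logits, labels, smooth_rate=self.smooth_rate)
```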
In experiments, we formulate the overall training loss as `wa * Vanilla Loss + wb * GLS`; a sketch of this combination follows the argument list below.
- `--lr`: learning rate
- `--noise_rate`: the error rate of the symmetric noise model
- `--n_epoch`: number of epochs
- `--wa`: the weight of the Vanilla Loss term (default: 0)
- `--wb`: the weight of the GLS term (default: 1)
- `--smooth_rate`: the smooth rate in GLS
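A minimal sketch of how `wa`, `wb`, and `smooth_rate` combine (the function name `combined_loss` is hypothetical, not the repository's exact training code):

```python
import torch.nn.functional as F

def combined_loss(logits, labels, wa=0.0, wb=1.0, smooth_rate=0.1):
    vanilla = F.cross_entropy(logits, labels)    # Vanilla Loss: standard cross-entropy
    gls = loss_gls(logits, labels, smooth_rate)  # GLS term, as defined above
    return wa * vanilla + wb * gls               # wa * Vanilla Loss + wb * GLS
```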
If you use our code, please cite the following paper:
```bibtex
@inproceedings{Wei2022ToSO,
  title     = {To Smooth or Not? When Label Smoothing Meets Noisy Labels},
  author    = {Jiaheng Wei and Hangyu Liu and Tongliang Liu and Gang Niu and Yang Liu},
  booktitle = {ICML},
  year      = {2022}
}
```