Unofficial PyTorch implementation of "Meta Pseudo Labels"

kekmodel/MPL-pytorch

Meta Pseudo Labels

This is an unofficial PyTorch implementation of Meta Pseudo Labels. The official TensorFlow implementation is here.

Results

| | CIFAR-10-4K | SVHN-1K | ImageNet-10% |
|---|---|---|---|
| Paper (w/ finetune) | 96.11 ± 0.07 | 98.01 ± 0.07 | 73.89 |
| This code (w/o finetune) | 96.01 | - | - |
| This code (w/ finetune) | 96.08 | - | - |
| Acc. curve | w/o finetune, w/ finetune | - | - |
  • Retested in February 2022.

Usage

Train the model with 4,000 labeled examples from the CIFAR-10 dataset:

python main.py \
    --seed 2 \
    --name cifar10-4K.2 \
    --expand-labels \
    --dataset cifar10 \
    --num-classes 10 \
    --num-labeled 4000 \
    --total-steps 300000 \
    --eval-step 1000 \
    --randaug 2 16 \
    --batch-size 128 \
    --teacher_lr 0.05 \
    --student_lr 0.05 \
    --weight-decay 5e-4 \
    --ema 0.995 \
    --nesterov \
    --mu 7 \
    --label-smoothing 0.15 \
    --temperature 0.7 \
    --threshold 0.6 \
    --lambda-u 8 \
    --warmup-steps 5000 \
    --uda-steps 5000 \
    --student-wait-steps 3000 \
    --teacher-dropout 0.2 \
    --student-dropout 0.2 \
    --finetune-epochs 625 \
    --finetune-batch-size 512 \
    --finetune-lr 3e-5 \
    --finetune-weight-decay 0 \
    --finetune-momentum 0.9 \
    --amp
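
The `--temperature 0.7` and `--threshold 0.6` flags above suggest a UDA-style treatment of the teacher's predictions on unlabeled data: sharpen the softmax with a temperature below 1, then keep only confident examples. The following is a minimal pure-Python sketch of that idea, not this repository's code. The function name `sharpen_and_mask` and the example logits are hypothetical, and the exact way `main.py` applies these flags is an assumption.

```python
import math


def sharpen_and_mask(logits, temperature=0.7, threshold=0.6):
    """Return (soft_label, keep) for one example's teacher logits.

    Hypothetical helper for illustration; defaults mirror the
    --temperature and --threshold values in the command above.
    """
    # Dividing by a temperature < 1 sharpens the softmax distribution.
    scaled = [z / temperature for z in logits]
    # Subtract the max before exponentiating for numerical stability.
    peak = max(scaled)
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    soft_label = [e / total for e in exps]
    # Keep the example only if the most likely class is confident enough.
    keep = max(soft_label) >= threshold
    return soft_label, keep


# A peaked prediction passes the threshold; a nearly flat one does not.
confident, keep_a = sharpen_and_mask([2.0, 0.1, -1.0])
uncertain, keep_b = sharpen_and_mask([0.3, 0.2, 0.1])
print(keep_a, keep_b)
```

Under this reading, a lower temperature or a lower threshold lets more unlabeled examples contribute to the consistency loss, at the cost of noisier targets.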

Train the model with 10,000 labeled examples from the CIFAR-100 dataset using DistributedDataParallel:

python -m torch.distributed.launch --nproc_per_node 4 main.py \
    --seed 2 \
    --name cifar100-10K.2 \
    --dataset cifar100 \
    --num-classes 100 \
    --num-labeled 10000 \
    --expand-labels \
    --total-steps 300000 \
    --eval-step 1000 \
    --randaug 2 16 \
    --batch-size 128 \
    --teacher_lr 0.05 \
    --student_lr 0.05 \
    --weight-decay 5e-4 \
    --ema 0.995 \
    --nesterov \
    --mu 7 \
    --label-smoothing 0.15 \
    --temperature 0.7 \
    --threshold 0.6 \
    --lambda-u 8 \
    --warmup-steps 5000 \
    --uda-steps 5000 \
    --student-wait-steps 3000 \
    --teacher-dropout 0.2 \
    --student-dropout 0.2 \
    --finetune-epochs 250 \
    --finetune-batch-size 512 \
    --finetune-lr 3e-5 \
    --finetune-weight-decay 0 \
    --finetune-momentum 0.9 \
    --amp
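
To get a feel for how much data each step consumes with `--batch-size 128`, `--mu 7`, and `--nproc_per_node 4`, the arithmetic can be sketched as below. This rests on two assumptions that the README does not state: that `--mu` is the ratio of unlabeled to labeled examples per batch (as in UDA and FixMatch), and that `--batch-size` is applied per process rather than split across processes. The helper `batch_sizes` is hypothetical.

```python
def batch_sizes(batch_size, mu, nproc_per_node=1):
    """Per-process and global batch sizes per training step.

    Assumes --batch-size is per process and --mu is the
    unlabeled-to-labeled ratio; both are assumptions, not
    confirmed behavior of main.py.
    """
    labeled = batch_size
    unlabeled = batch_size * mu
    return {
        "labeled_per_process": labeled,
        "unlabeled_per_process": unlabeled,
        "labeled_global": labeled * nproc_per_node,
        "unlabeled_global": unlabeled * nproc_per_node,
    }


sizes = batch_sizes(batch_size=128, mu=7, nproc_per_node=4)
print(sizes)
```

If the script instead divides the batch across processes, the global figures would equal the per-process figures shown here, so it is worth checking `main.py` before scaling the learning rates.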

Monitoring training progress

Use TensorBoard:

tensorboard --logdir results

Alternatively, use wandb.

Requirements

  • python 3.6+
  • torch 1.7+
  • torchvision 0.8+
  • tensorboard
  • wandb
  • numpy
  • tqdm

Citations

@misc{jd2021mpl,
  author = {Jungdae Kim},
  title = {PyTorch implementation of Meta Pseudo Labels},
  year = {2021},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/kekmodel/MPL-pytorch}}
}
