
yang-ruixin/PyTorch-Image-Models-Multi-Label-Classification


PyTorch Image Models Multi Label Classification

Multi-label classification based on timm.

Update 2021/09/12

Multi-label classification with SimCLR is available. See my other repo, PyTorch Image Models With SimCLR.
You may get higher accuracy by training the model with the classification loss and the SimCLR loss at the same time.
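As a rough illustration of training with both losses, here is a framework-free sketch that adds a weighted NT-Xent loss (the SimCLR contrastive loss) to a classification loss. The function names, the `simclr_weight` of 1.0 and the temperature of 0.5 are illustrative assumptions, not the code in the SimCLR repo, which operates on batched PyTorch tensors.

```python
import math

def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm

def nt_xent(z1, z2, temperature=0.5):
    """NT-Xent loss: z1[i] and z2[i] embed two augmented views of image i."""
    z = z1 + z2                      # 2N embeddings in one list
    n = len(z1)
    total = 0.0
    for i in range(2 * n):
        pos = (i + n) % (2 * n)      # index of the other view of the same image
        # Similarities to every embedding except the anchor itself
        logits = [cosine(z[i], z[k]) / temperature for k in range(2 * n) if k != i]
        pos_logit = cosine(z[i], z[pos]) / temperature
        total += math.log(sum(math.exp(x) for x in logits)) - pos_logit
    return total / (2 * n)

def joint_loss(cls_loss, z1, z2, simclr_weight=1.0, temperature=0.5):
    """Classification loss plus a weighted SimCLR loss (the weight is an assumption)."""
    return cls_loss + simclr_weight * nt_xent(z1, z2, temperature)
```

When the two views of each image agree, the positive pair dominates the denominator and the contrastive term shrinks; `simclr_weight` controls how much it pulls against the classification loss.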

Update 2021/03/22

Updated ./timm/models/multi_label_model.py, ./train.py and ./validate.py to calculate the accuracy for each label.
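Per-label accuracy means scoring each output head on its own. The sketch below assumes, hypothetically, that outputs are stored as a dict from label name to per-sample logits and targets as a dict from label name to class indices; it is plain Python for clarity, not the code in validate.py.

```python
def argmax(values):
    """Index of the largest value."""
    return max(range(len(values)), key=values.__getitem__)

def per_label_accuracy(outputs, targets):
    """outputs[label]: per-sample logit lists; targets[label]: per-sample class indices."""
    acc = {}
    for label, logits in outputs.items():
        preds = [argmax(row) for row in logits]
        correct = sum(p == t for p, t in zip(preds, targets[label]))
        acc[label] = correct / len(targets[label])
    return acc
```

Each label gets its own number, so a weak head (for example, color) is visible even when the others score well.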

Introduction

This repository is used for multi-label classification.
The code is based on pytorch-image-models by Ross Wightman.
Thanks to Ross for his great work.

I downloaded his code on February 27, 2021.
My multi-label classification code should be compatible with his latest version, but I have not checked.

The main reference for multi-label classification is this website.
Thanks to Dmitry Retinskiy and Satya Mallick.
To understand the context and the dataset, please spend 5 minutes reading the link above; you don't need to read the specific code there.
Here is the link to download the images.
Put all the images into ./fashion-product-images/images/.
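When building the dataset, each annotation row has to be paired with its image file. Here is a minimal sketch that skips rows whose image is missing, assuming images are named `<id>.jpg` and rows are dicts with an `id` key; `collect_samples` is a hypothetical helper, not part of this repo. The `exists` argument makes the file check swappable, for example in tests.

```python
import os

def collect_samples(rows, image_dir, label_keys, exists=os.path.isfile):
    """Return (image_path, labels) pairs, skipping rows without an image file.

    Assumes (hypothetically) that images are stored as <image_dir>/<id>.jpg.
    """
    samples = []
    for row in rows:
        path = os.path.join(image_dir, f"{row['id']}.jpg")
        if not exists(path):
            continue  # annotation without a matching image
        samples.append((path, {key: row[key] for key in label_keys}))
    return samples
```

For example: `collect_samples(rows, './fashion-product-images/images/', ['color', 'gender', 'article'])`.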

To implement multi-label classification, I modified (or added) the following files from Ross's pytorch-image-models:

  1. ./train.py
  2. ./validate.py
  3. ./timm/data/__init__.py
  4. ./timm/data/dataset.py
  5. ./timm/data/loader.py
  6. ./timm/models/__init__.py
  7. ./timm/models/efficientnet.py
  8. ./timm/models/multi_label_model.py (add)

To train on your own dataset, you only need to modify files 1, 2, 4 and 8.
Simply modify the code between the double dashed lines, or search for color/gender/article: those are the labels, and the surrounding code is what you need to change.
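Whatever label names you use, each label needs its own mapping from class names to integer indices. A minimal sketch, assuming annotation rows are dicts keyed by hypothetical label names such as color, gender and article; the repo's own handling in the files listed above may differ.

```python
def build_label_maps(rows, label_keys):
    """Map each label's class names to contiguous indices (sorted for determinism)."""
    return {
        key: {name: idx for idx, name in enumerate(sorted({row[key] for row in rows}))}
        for key in label_keys
    }

def encode(row, label_maps):
    """Convert one annotation row into per-label class indices."""
    return {key: mapping[row[key]] for key, mapping in label_maps.items()}
```

The size of each mapping (`len(label_maps['color'])`, and so on) is the number of outputs the corresponding head needs.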

In terms of backbones, I only modified ./timm/models/efficientnet.py, where I added an as_sequential_for_ML method.
For other models, you need to define the as_sequential_for_ML method yourself within each class. It is simply a part of the as_sequential method.
We only need the backbone here, so remove the last layers (for example, the classifier layer) from what as_sequential returns. Look at the forward_features method in each model class to see which layers you need to remove, and therefore how to define the as_sequential_for_ML method.
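Once the classifier is removed, the backbone produces a shared feature vector and each label gets its own classification head. The sketch below shows that idea in plain Python, with one linear head per label and the per-label cross-entropies summed; equal weighting of the losses is an assumption, and ./timm/models/multi_label_model.py may combine them differently.

```python
import math

def linear(features, weights, bias):
    """One fully connected head: weights has one row per class."""
    return [sum(w * x for w, x in zip(row, features)) + b for row, b in zip(weights, bias)]

def cross_entropy(logits, target):
    """Softmax cross-entropy for one sample, using the max-shift for stability."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum - logits[target]

def multi_label_loss(features, heads, targets):
    """Shared backbone features feed every head; per-label losses are summed (assumed)."""
    total = 0.0
    for label, (weights, bias) in heads.items():
        total += cross_entropy(linear(features, weights, bias), targets[label])
    return total
```

Because every head reads the same features, the backbone receives gradients from all labels at once, which is the point of removing its original classifier.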

In addition to the multi-label classification functionality, I also added gradient centralization to the AdamP optimizer.
Gradient centralization is a simple technique that may improve optimizer performance.
There is no guarantee that it will, but it is worth a try.
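Gradient centralization shifts each weight gradient to zero mean before the optimizer uses it. For a fully connected weight the mean is taken over each output row; for a convolution weight it is taken over all dimensions except the output channel, and 1-D parameters such as biases are usually left alone. The sketch below handles only the 2-D case and pairs it with a plain SGD step for illustration; it is not the code in ./timm/optim/centralization.py.

```python
def centralize_gradient(grad):
    """Gradient centralization for a 2-D weight gradient (rows = output units).

    Each row is shifted to zero mean; 1-D gradients (e.g. biases) are returned unchanged.
    """
    if not grad or not isinstance(grad[0], list):
        return list(grad)  # 1-D: centralization not applied
    return [[g - sum(row) / len(row) for g in row] for row in grad]

def sgd_step_with_gc(weight, grad, lr):
    """Plain SGD update on a 2-D weight, centralizing the gradient first (illustrative)."""
    g = centralize_gradient(grad)
    return [[w - lr * gi for w, gi in zip(w_row, g_row)] for w_row, g_row in zip(weight, g)]
```

The same centralization call can be placed just before the update rule of any optimizer, which is why it is easy to carry over to AdamP or others.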
To add gradient centralization, I modified (or added) the following files:

  1. ./timm/optim/adamp.py
  2. ./timm/optim/centralization.py (add)
  3. ./timm/optim/optim_factory.py

You can add gradient centralization to other optimizers as well.

I also updated ./timm/utils/summary.py so that the learning rate is written to summary.csv during training.
This lets you plot the learning rate together with the loss and accuracy over the whole training process.
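To plot these curves, you first need the columns as numbers. A small standard-library sketch; the column name `lr` is an assumption, so check the header of your summary.csv.

```python
import csv
import io

def read_series(csv_text, columns):
    """Return {column: [float values]} for the requested columns of summary.csv."""
    series = {col: [] for col in columns}
    for row in csv.DictReader(io.StringIO(csv_text)):
        for col in columns:
            series[col].append(float(row[col]))
    return series
```

You would pass the file contents, presumably from ./output/train/YOUR_SPECIFIC_FOLDER/summary.csv, and hand the resulting lists to a plotting library such as matplotlib.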

Here is an example command to start training:

./distributed_train.sh 1 ./fashion-product-images/ --model efficientnet_b2 -b 64 --sched cosine --epochs 50 --decay-epochs 2.4 --decay-rate .97 --opt adamp --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.3 --drop-connect 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .016 --pretrained  
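The command uses `--sched cosine` with `--warmup-lr 1e-6` and `--lr .016`. To get a feel for the resulting curve, here is a generic linear-warmup plus cosine-decay sketch. It is not timm's exact scheduler, and `warmup_epochs=3` and `min_lr=1e-5` are assumed values that the command above does not set.

```python
import math

def warmup_cosine_lr(epoch, base_lr, total_epochs, warmup_lr=1e-6, warmup_epochs=3, min_lr=1e-5):
    """Learning rate for a given epoch: linear warmup, then cosine decay.

    warmup_epochs and min_lr are assumed values, not taken from the command.
    """
    if epoch < warmup_epochs:
        # Linear ramp from warmup_lr towards base_lr
        return warmup_lr + (base_lr - warmup_lr) * epoch / warmup_epochs
    # Fraction of the post-warmup schedule completed, from 0.0 to 1.0
    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

# Values matching --lr .016 --warmup-lr 1e-6 --epochs 50
schedule = [warmup_cosine_lr(e, 0.016, 50) for e in range(50)]
```

Plotting `schedule` next to the learning rate logged in summary.csv is a quick way to confirm the run behaved as expected.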

And an example command to start validation:

python validate.py ./fashion-product-images/ --model efficientnet_b2 --checkpoint ./output/train/YOUR_SPECIFIC_FOLDER/model_best.pth.tar -b 64  

Please give a star if you find this repo helpful.

License

This project is released under the Apache License, Version 2.0.

Citation (BibTeX)

@misc{yrx2021multilabel,
  author = {YANG Ruixin},
  title = {PyTorch Image Models Multi-Label Classification},
  year = {2021},
  publisher = {GitHub}
}