This repository has been archived by the owner on Jul 2, 2021. It is now read-only.
Add ResNet training code #436
Merged
74 commits
d46f680
Merge branch 'random-sized-crop' into resnet-train
yuyu2172 93e0bf9
add train_imagenet
yuyu2172 a4a03d7
Merge remote-tracking branch 'yuyu2172/random-sized-crop' into HEAD
yuyu2172 2808c16
add train_imagenet_mn
yuyu2172 2c3667a
fix learning rate
yuyu2172 ea80eee
add color_jitter
yuyu2172 b5003a1
Merge remote-tracking branch 'yuyu2172/color-jitter' into HEAD
yuyu2172 63ae346
use color_jitter
yuyu2172 c0c536b
update readme
yuyu2172 9efd461
Merge remote-tracking branch 'yuyu2172/random-sized-crop' into resnet…
yuyu2172 056ec6b
make training code follow Training ImageNet 1 hour paper
yuyu2172 89d1ba1
Merge remote-tracking branch 'yuyu2172/resnet-link' into resnet-train
yuyu2172 03d4ac1
set initial gamma to zero for the last bn of block
yuyu2172 193e47f
add corrected_momentum_sgd
yuyu2172 1cb7d37
add observe_lr extension
yuyu2172 007e937
Merge remote-tracking branch 'yuyu2172/resnet-link' into resnet-train
yuyu2172 61f5936
Merge remote-tracking branch 'yuyu2172/resnet-link' into resnet-train
yuyu2172 9fdb6ab
Merge branch 'resnet-link' into resnet-train
yuyu2172 779a4d1
merge
yuyu2172 d39ebf9
remove ResNet18
yuyu2172 fb2e062
Merge remote-tracking branch 'origin/master' into HEAD
yuyu2172 d357af3
delete color_jitter
yuyu2172 c677fcb
move corrected_momentum_sgd to chainer_experimental
yuyu2172 d5c1cd3
delete color_jitter from reference
yuyu2172 7e53af3
doc
yuyu2172 43b548b
fix cpu mode of CorrectedMomentumSGD
yuyu2172 fbf38d6
Merge remote-tracking branch 'origin/master' into HEAD
yuyu2172 a263992
do not use lambdas
yuyu2172 16502cf
Merge remote-tracking branch 'yuyu2172/no-lambda' into HEAD
yuyu2172 60be570
Merge branch 'resnet-train' of https://github.com/yuyu2172/chainercv …
yuyu2172 a4e14f0
fix
yuyu2172 7b0b8a7
Merge remote-tracking branch 'yuyu2172/resnet-train' into HEAD
yuyu2172 898dbbb
fix
yuyu2172 93c410f
style
yuyu2172 77cd4cb
remove redundant improt
yuyu2172 aa8b713
grammar
yuyu2172 e556613
update README
yuyu2172 8bb1135
delete unnecessary cmd option
yuyu2172 5d40516
use original style
yuyu2172 6de2a17
simplify
yuyu2172 42ae6f9
initialize the last BN of each BuildingBlock in 1 hour style
yuyu2172 399a140
fix script
yuyu2172 02b4f30
flake8
yuyu2172 c292e09
add warmpup
yuyu2172 6c6042e
warmup initial lr changed
yuyu2172 aacd11f
update warmup
yuyu2172 bc2b34c
try to fix segfault
yuyu2172 93fbd36
fix
yuyu2172 86e9ded
delete cv2 setNumThreads
yuyu2172 ad77235
Merge remote-tracking branch 'origin/master' into resnet-train
yuyu2172 e299b22
merge
yuyu2172 078bb4c
delete CorrectedMomentumSGD
yuyu2172 f591ead
simplify
yuyu2172 6a9f56b
fix flake8
yuyu2172 a3d2e3e
update README
yuyu2172 54a37f2
update README
yuyu2172 7f289d2
Merge remote-tracking branch 'origin/master' into HEAD
yuyu2172 6d3cca9
use make_shift
yuyu2172 cc4c875
fix bug
yuyu2172 fbd4ae4
change commandline argument: arch -> model
yuyu2172 a3d85d0
fix an error that is raised when gpu <= 8
yuyu2172 315a645
update README
yuyu2172 7c73f05
add url link
yuyu2172 783ba36
change default arch to fb for eval
yuyu2172 fbcbe56
update README
yuyu2172 68479ef
typo
yuyu2172 175e0fb
update README
yuyu2172 efdb580
add cv2 option
yuyu2172 7ca130c
fix doc
yuyu2172 084fac9
delete unnecessary options for iterators
yuyu2172 1e19bdf
delete performance related stuff
yuyu2172 2b0dc98
delete PlotReport
yuyu2172 dd814dc
fix README
yuyu2172 80f449e
remove unnecessary
yuyu2172
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,17 @@ | ||
# Classification | ||
|
||
## Performance | ||
## ImageNet | ||
|
||
|
||
Single crop error rate. | ||
### Weight conversion | ||
|
||
Single crop error rates of the models whose weights were converted from Caffe weights. | ||
|
||
| Model | Top 1 | Reference Top 1 | | ||
|:-:|:-:|:-:| | ||
| VGG16 | 29.0 % | 28.5 % [1] | | ||
| ResNet50 | 24.8 % | 24.7 % [2] | | ||
| ResNet101 | 23.6 % | 23.6 % [2] | | ||
| ResNet152 | 23.2 % | 23.0 % [2] | | ||
| ResNet50 (`arch=he`) | 24.8 % | 24.7 % [2] | | ||
| ResNet101 (`arch=he`) | 23.6 % | 23.6 % [2] | | ||
| ResNet152 (`arch=he`) | 23.2 % | 23.0 % [2] | | ||
| SE-ResNet50 | 22.7 % | 22.4 % [3,4] | | ||
| SE-ResNet101 | 21.8 % | 21.8 % [3,4] | | ||
| SE-ResNet152 | 21.4 % | 21.3 % [3,4] | | ||
|
@@ -21,9 +23,9 @@ Ten crop error rate. | |
| Model | Top 1 | Reference Top 1 | | ||
|:-:|:-:|:-:| | ||
| VGG16 | 27.1 % | | | ||
| ResNet50 | 23.0 % | 22.9 % [2] | | ||
| ResNet101 | 21.8 % | 21.8 % [2] | | ||
| ResNet152 | 21.4 % | 21.4 % [2] | | ||
| ResNet50 (`arch=he`) | 23.0 % | 22.9 % [2] | | ||
| ResNet101 (`arch=he`) | 21.8 % | 21.8 % [2] | | ||
| ResNet152 (`arch=he`) | 21.4 % | 21.4 % [2] | | ||
| SE-ResNet50 | 20.8 % | | | ||
| SE-ResNet101 | 20.1 % | | | ||
| SE-ResNet152 | 19.7 % | | | ||
|
@@ -32,15 +34,48 @@ Ten crop error rate. | |
|
||
|
||
The results can be reproduced by the following command. | ||
The score is reported using a weight converted from a weight trained by Caffe. | ||
These scores were obtained using the OpenCV backend. If Pillow is used, the scores will differ. | ||
|
||
``` | ||
$ python eval_imagenet.py <path_to_val_dataset> [--model vgg16|resnet50|resnet101|resnet152|se-resnet50|se-resnet101|se-resnet152] [--pretrained-model <model_path>] [--batchsize <batchsize>] [--gpu <gpu>] [--crop center|10] | ||
``` | ||
|
||
### Trained model | ||
|
||
Single crop error rates of the models trained with ChainerCV's training script. | ||
|
||
| Model | Top 1 | Reference Top 1 | | ||
|:-:|:-:|:-:| | ||
| ResNet50 (`arch=fb`) | 23.51 % | 23.60 % [5] | | ||
| ResNet101 (`arch=fb`) | 22.07 % | 22.08 % [5] | | ||
| ResNet152 (`arch=fb`) | 21.67 % | | | ||
|
||
|
||
The scores were obtained with models trained with `train_imagenet_multi.py`, which can be executed as shown below. | ||
For the full list of arguments for the training script, run `python train_imagenet_multi.py -h`. | ||
``` | ||
$ mpiexec -n N python train_imagenet_multi.py <path_to_train_dataset> <path_to_val_dataset> | ||
``` | ||
|
||
The training procedure carefully follows the "ResNet in 1 hour" paper [5]. | ||
|
||
#### Performance tip | ||
When training over multiple nodes, set the communicator to `pure_nccl` (requires NCCL2). | ||
The default communicator (`hierarchical`) uses MPI to communicate between nodes, which is slower than the pure NCCL communicator. | ||
Also, cuDNN convolution functions can be optimized with extra commands (see https://docs.chainer.org/en/stable/performance.html#optimize-cudnn-convolution). | ||
|
||
#### Detailed training results | ||
|
||
Here, we investigate the effect of the number of GPUs on the final performance. | ||
For more statistically reliable results, we obtained results from five different random seeds. | ||
|
||
| Model | # GPUs | Top 1 | | ||
|:-:|:-:|:-:| | ||
| ResNet50 (`arch=fb`) | 8 | 23.53 (std=0.06) | | ||
| ResNet50 (`arch=fb`) | 32 | 23.56 (std=0.11) | | ||
|
||
|
||
## How to prepare ImageNet Dataset | ||
## How to prepare ImageNet dataset | ||
|
||
These instructions are based on the instructions found [here](https://github.com/facebook/fb.resnet.torch/blob/master/INSTALL.md#download-the-imagenet-dataset). | ||
|
||
|
@@ -70,3 +105,4 @@ The ImageNet Large Scale Visual Recognition Challenge (ILSVRC) dataset has 1000 | |
2. Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. "Deep Residual Learning for Image Recognition" CVPR 2016 | ||
3. Jie Hu, Li Shen, Gang Sun. "Squeeze-and-Excitation Networks" CVPR 2018 | ||
4. https://github.com/hujie-frank/SENet | ||
5. Priya Goyal, Piotr Dollár, Ross Girshick, Pieter Noordhuis, Lukasz Wesolowski, Aapo Kyrola, Andrew Tulloch, Yangqing Jia, Kaiming He. "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour" https://arxiv.org/abs/1706.02677 |
@@ -0,0 +1,199 @@
from __future__ import division
import argparse
import multiprocessing

import chainer
from chainer.datasets import TransformDataset
from chainer import iterators
from chainer.links import Classifier
from chainer.optimizer import WeightDecay
from chainer.optimizers import CorrectedMomentumSGD
from chainer import training
from chainer.training import extensions

from chainercv.datasets import directory_parsing_label_names
from chainercv.datasets import DirectoryParsingLabelDataset
from chainercv.transforms import center_crop
from chainercv.transforms import random_flip
from chainercv.transforms import random_sized_crop
from chainercv.transforms import resize
from chainercv.transforms import scale

from chainercv.chainer_experimental.training.extensions import make_shift

from chainercv.links.model.resnet import Bottleneck
from chainercv.links import ResNet101
from chainercv.links import ResNet152
from chainercv.links import ResNet50

import chainermn


class TrainTransform(object):

    def __init__(self, mean):
        self.mean = mean

    def __call__(self, in_data):
        img, label = in_data
        img = random_sized_crop(img)
        img = resize(img, (224, 224))
        img = random_flip(img, x_random=True)
        img -= self.mean
        return img, label


class ValTransform(object):

    def __init__(self, mean):
        self.mean = mean

    def __call__(self, in_data):
        img, label = in_data
        img = scale(img, 256)
        img = center_crop(img, (224, 224))
        img -= self.mean
        return img, label


def main():
    model_cfgs = {
        'resnet50': {'class': ResNet50, 'score_layer_name': 'fc6',
                     'kwargs': {'arch': 'fb'}},
        'resnet101': {'class': ResNet101, 'score_layer_name': 'fc6',
                      'kwargs': {'arch': 'fb'}},
        'resnet152': {'class': ResNet152, 'score_layer_name': 'fc6',
                      'kwargs': {'arch': 'fb'}}
    }
    parser = argparse.ArgumentParser(
        description='Learning convnet from ILSVRC2012 dataset')
    parser.add_argument('train', help='Path to root of the train dataset')
    parser.add_argument('val', help='Path to root of the validation dataset')
    parser.add_argument('--model',
                        '-m', choices=model_cfgs.keys(), default='resnet50',
                        help='Convnet models')
    parser.add_argument('--communicator', type=str,
                        default='hierarchical', help='Type of communicator')
    parser.add_argument('--loaderjob', type=int, default=4)
    parser.add_argument('--batchsize', type=int, default=32,
                        help='Batch size for each worker')
    parser.add_argument('--lr', type=float)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--weight_decay', type=float, default=0.0001)
    parser.add_argument('--out', type=str, default='result')
    parser.add_argument('--epoch', type=int, default=90)
    args = parser.parse_args()

    # This fixes a crash caused by a bug with multiprocessing and MPI.
    multiprocessing.set_start_method('forkserver')
    p = multiprocessing.Process()
    p.start()
    p.join()

    comm = chainermn.create_communicator(args.communicator)
    device = comm.intra_rank

    if args.lr is not None:
        lr = args.lr
    else:
        lr = 0.1 * (args.batchsize * comm.size) / 256
        if comm.rank == 0:
            print('lr={}: lr is selected based on the linear '
                  'scaling rule'.format(lr))

    label_names = directory_parsing_label_names(args.train)

    model_cfg = model_cfgs[args.model]
    extractor = model_cfg['class'](
        n_class=len(label_names), **model_cfg['kwargs'])
    extractor.pick = model_cfg['score_layer_name']
    model = Classifier(extractor)
    # Following https://arxiv.org/pdf/1706.02677.pdf,
    # the gamma of the last BN of each resblock is initialized by zeros.
    for l in model.links():
        if isinstance(l, Bottleneck):
            l.conv3.bn.gamma.data[:] = 0

    if comm.rank == 0:
        train_data = DirectoryParsingLabelDataset(args.train)
        val_data = DirectoryParsingLabelDataset(args.val)
        train_data = TransformDataset(
            train_data, TrainTransform(extractor.mean))
        val_data = TransformDataset(val_data, ValTransform(extractor.mean))
        print('finished loading dataset')
    else:
        train_data, val_data = None, None
    train_data = chainermn.scatter_dataset(train_data, comm, shuffle=True)
    val_data = chainermn.scatter_dataset(val_data, comm, shuffle=True)
    train_iter = chainer.iterators.MultiprocessIterator(
        train_data, args.batchsize, n_processes=args.loaderjob)
    val_iter = iterators.MultiprocessIterator(
        val_data, args.batchsize,
        repeat=False, shuffle=False, n_processes=args.loaderjob)

    optimizer = chainermn.create_multi_node_optimizer(
        CorrectedMomentumSGD(lr=lr, momentum=args.momentum), comm)
    optimizer.setup(model)
    for param in model.params():
        if param.name not in ('beta', 'gamma'):
            param.update_rule.add_hook(WeightDecay(args.weight_decay))

    if device >= 0:
        chainer.cuda.get_device(device).use()
        model.to_gpu()

    updater = chainer.training.StandardUpdater(
        train_iter, optimizer, device=device)

    trainer = training.Trainer(
        updater, (args.epoch, 'epoch'), out=args.out)

    @make_shift('lr')
    def warmup_and_exponential_shift(trainer):
        epoch = trainer.updater.epoch_detail
        warmup_epoch = 5
        if epoch < warmup_epoch:
            if lr > 0.1:
                warmup_rate = 0.1 / lr
                rate = warmup_rate \
                    + (1 - warmup_rate) * epoch / warmup_epoch
            else:
                rate = 1
        elif epoch < 30:
            rate = 1
        elif epoch < 60:
            rate = 0.1
        elif epoch < 80:
            rate = 0.01
        else:
            rate = 0.001
        return rate * lr

    trainer.extend(warmup_and_exponential_shift)
    evaluator = chainermn.create_multi_node_evaluator(
        extensions.Evaluator(val_iter, model, device=device), comm)
    trainer.extend(evaluator, trigger=(1, 'epoch'))

    log_interval = 0.1, 'epoch'
    print_interval = 0.1, 'epoch'

    if comm.rank == 0:
        trainer.extend(chainer.training.extensions.observe_lr(),
                       trigger=log_interval)
        trainer.extend(
            extensions.snapshot_object(
                extractor, 'snapshot_model_{.updater.epoch}.npz'),
            trigger=(args.epoch, 'epoch'))
        trainer.extend(extensions.LogReport(trigger=log_interval))
        trainer.extend(extensions.PrintReport(
            ['iteration', 'epoch', 'elapsed_time', 'lr',
             'main/loss', 'validation/main/loss',
             'main/accuracy', 'validation/main/accuracy']
        ), trigger=print_interval)
        trainer.extend(extensions.ProgressBar(update_interval=10))

    trainer.run()


if __name__ == '__main__':
    main()
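The learning-rate logic in the script above (linear scaling of the base rate, a five-epoch warmup, then step decay at epochs 30, 60 and 80) can be checked without Chainer or MPI. The following is a minimal standalone sketch of that schedule. The function names `linear_scaled_lr` and `lr_at_epoch` and the parameter `warmup_start_lr` are illustrative and do not appear in the PR; the constants are copied from the diff.

```python
def linear_scaled_lr(batchsize_per_worker, n_workers,
                     base_lr=0.1, base_batch=256):
    """Linear scaling rule: lr is proportional to the global batch size.

    Hypothetical helper; mirrors `0.1 * (batchsize * comm.size) / 256`.
    """
    return base_lr * (batchsize_per_worker * n_workers) / base_batch


def lr_at_epoch(epoch, lr, warmup_epoch=5, warmup_start_lr=0.1):
    """Return the learning rate at a (possibly fractional) epoch.

    Hypothetical helper; mirrors `warmup_and_exponential_shift` in the diff.
    """
    if epoch < warmup_epoch:
        if lr > warmup_start_lr:
            # Ramp linearly from warmup_start_lr up to lr during warmup.
            start = warmup_start_lr / lr
            rate = start + (1 - start) * epoch / warmup_epoch
        else:
            # Small global batch: no warmup is applied.
            rate = 1
    elif epoch < 30:
        rate = 1
    elif epoch < 60:
        rate = 0.1
    elif epoch < 80:
        rate = 0.01
    else:
        rate = 0.001
    return rate * lr


if __name__ == '__main__':
    # Default per-worker batch size of 32, as in the script's argparse setup.
    for n_workers in (8, 32):
        lr = linear_scaled_lr(32, n_workers)
        schedule = [lr_at_epoch(e, lr) for e in (0, 2.5, 5, 30, 60, 80)]
        print(n_workers, lr, schedule)
```

With the default per-worker batch size of 32, eight workers give a global batch of 256, so the scaled rate equals the base rate and the warmup branch is a no-op. Warmup only takes effect for larger worker counts.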
Don't we need `'cv2': True`?