This repository has been archived by the owner on Jul 2, 2021. It is now read-only.

Add ResNet training code #436

Merged · 74 commits · Nov 29, 2018
Changes from all commits
74 commits
d46f680
Merge branch 'random-sized-crop' into resnet-train
yuyu2172 Sep 28, 2017
93e0bf9
add train_imagenet
yuyu2172 Sep 28, 2017
a4a03d7
Merge remote-tracking branch 'yuyu2172/random-sized-crop' into HEAD
yuyu2172 Sep 29, 2017
2808c16
add train_imagenet_mn
yuyu2172 Sep 29, 2017
2c3667a
fix learning rate
yuyu2172 Sep 29, 2017
ea80eee
add color_jitter
yuyu2172 Sep 29, 2017
b5003a1
Merge remote-tracking branch 'yuyu2172/color-jitter' into HEAD
yuyu2172 Sep 29, 2017
63ae346
use color_jitter
yuyu2172 Sep 29, 2017
c0c536b
update readme
yuyu2172 Sep 29, 2017
9efd461
Merge remote-tracking branch 'yuyu2172/random-sized-crop' into resnet…
yuyu2172 Sep 29, 2017
056ec6b
make training code follow Training ImageNet 1 hour paper
yuyu2172 Oct 4, 2017
89d1ba1
Merge remote-tracking branch 'yuyu2172/resnet-link' into resnet-train
yuyu2172 Oct 4, 2017
03d4ac1
set initial gamma to zero for the last bn of block
yuyu2172 Oct 4, 2017
193e47f
add corrected_momentum_sgd
yuyu2172 Oct 4, 2017
1cb7d37
add observe_lr extension
yuyu2172 Oct 4, 2017
007e937
Merge remote-tracking branch 'yuyu2172/resnet-link' into resnet-train
yuyu2172 Oct 4, 2017
61f5936
Merge remote-tracking branch 'yuyu2172/resnet-link' into resnet-train
yuyu2172 Oct 8, 2017
9fdb6ab
Merge branch 'resnet-link' into resnet-train
yuyu2172 Oct 8, 2017
779a4d1
merge
yuyu2172 Mar 7, 2018
d39ebf9
remove ResNet18
yuyu2172 Mar 7, 2018
fb2e062
Merge remote-tracking branch 'origin/master' into HEAD
yuyu2172 May 24, 2018
d357af3
delete color_jitter
yuyu2172 May 24, 2018
c677fcb
move corrected_momentum_sgd to chainer_experimental
yuyu2172 May 24, 2018
d5c1cd3
delete color_jitter from reference
yuyu2172 May 24, 2018
7e53af3
doc
yuyu2172 Jun 1, 2018
43b548b
fix cpu mode of CorrectedMomentumSGD
yuyu2172 Jun 1, 2018
fbf38d6
Merge remote-tracking branch 'origin/master' into HEAD
yuyu2172 Jun 20, 2018
a263992
do not use lambdas
yuyu2172 Jun 20, 2018
16502cf
Merge remote-tracking branch 'yuyu2172/no-lambda' into HEAD
yuyu2172 Jun 20, 2018
60be570
Merge branch 'resnet-train' of https://github.com/yuyu2172/chainercv …
yuyu2172 Jun 20, 2018
a4e14f0
fix
yuyu2172 Jun 20, 2018
7b0b8a7
Merge remote-tracking branch 'yuyu2172/resnet-train' into HEAD
yuyu2172 Jun 20, 2018
898dbbb
fix
yuyu2172 Jun 20, 2018
93c410f
style
yuyu2172 Jun 20, 2018
77cd4cb
remove redundant improt
yuyu2172 Jun 20, 2018
aa8b713
grammar
yuyu2172 Jun 20, 2018
e556613
update README
yuyu2172 Jun 20, 2018
8bb1135
delete unnecessary cmd option
yuyu2172 Jun 20, 2018
5d40516
use original style
yuyu2172 Jun 20, 2018
6de2a17
simplify
yuyu2172 Jun 20, 2018
42ae6f9
initialize the last BN of each BuildingBlock in 1 hour style
yuyu2172 Jun 28, 2018
399a140
fix script
yuyu2172 Jun 28, 2018
02b4f30
flake8
yuyu2172 Jun 28, 2018
c292e09
add warmpup
yuyu2172 Oct 15, 2018
6c6042e
warmup initial lr changed
yuyu2172 Oct 15, 2018
aacd11f
update warmup
yuyu2172 Oct 16, 2018
bc2b34c
try to fix segfault
yuyu2172 Oct 18, 2018
93fbd36
fix
yuyu2172 Oct 19, 2018
86e9ded
delete cv2 setNumThreads
yuyu2172 Oct 19, 2018
ad77235
Merge remote-tracking branch 'origin/master' into resnet-train
yuyu2172 Oct 20, 2018
e299b22
merge
yuyu2172 Oct 20, 2018
078bb4c
delete CorrectedMomentumSGD
yuyu2172 Oct 20, 2018
f591ead
simplify
yuyu2172 Oct 20, 2018
6a9f56b
fix flake8
yuyu2172 Oct 20, 2018
a3d2e3e
update README
yuyu2172 Oct 20, 2018
54a37f2
update README
yuyu2172 Nov 19, 2018
7f289d2
Merge remote-tracking branch 'origin/master' into HEAD
yuyu2172 Nov 19, 2018
6d3cca9
use make_shift
yuyu2172 Nov 19, 2018
cc4c875
fix bug
yuyu2172 Nov 19, 2018
fbd4ae4
change commandline argument: arch -> model
yuyu2172 Nov 19, 2018
a3d85d0
fix an error that is raised when gpu <= 8
yuyu2172 Nov 19, 2018
315a645
update README
yuyu2172 Nov 26, 2018
7c73f05
add url link
yuyu2172 Nov 26, 2018
783ba36
change default arch to fb for eval
yuyu2172 Nov 26, 2018
fbcbe56
update README
yuyu2172 Nov 26, 2018
68479ef
typo
yuyu2172 Nov 26, 2018
175e0fb
update README
yuyu2172 Nov 26, 2018
efdb580
add cv2 option
yuyu2172 Nov 26, 2018
7ca130c
fix doc
yuyu2172 Nov 26, 2018
084fac9
delete unnecessary options for iterators
yuyu2172 Nov 26, 2018
1e19bdf
delete performance related stuff
yuyu2172 Nov 26, 2018
2b0dc98
delete PlotReport
yuyu2172 Nov 28, 2018
dd814dc
fix README
yuyu2172 Nov 29, 2018
80f449e
remove unnecessary
yuyu2172 Nov 29, 2018
41 changes: 31 additions & 10 deletions chainercv/links/model/resnet/resnet.py
@@ -55,10 +55,11 @@ class ResNet(PickableSequentialChain):
loaded from weights distributed on the Internet.
The list of pretrained models supported are as follows:
* :obj:`imagenet`: Loads weights trained with ImageNet and distributed \
* :obj:`imagenet`: Loads weights trained with ImageNet. \
When :obj:`arch=='he'`, the weights distributed \
at `Model Zoo \
<https://github.com/BVLC/caffe/wiki/Model-Zoo>`_.
This is only supported when :obj:`arch=='he'`.
<https://github.com/BVLC/caffe/wiki/Model-Zoo>`_ \
are used.
Args:
n_layer (int): The number of layers.
@@ -103,9 +104,33 @@ class ResNet(PickableSequentialChain):

_models = {
'fb': {
50: {},
101: {},
152: {}
50: {
'imagenet': {
'param': {'n_class': 1000, 'mean': _imagenet_mean},
'overwritable': {'mean'},
'url': 'https://chainercv-models.preferred.jp/'
'resnet50_imagenet_trained_2018_11_26.npz',
'cv2': True,
},
},
101: {
'imagenet': {
'param': {'n_class': 1000, 'mean': _imagenet_mean},
'overwritable': {'mean'},
'url': 'https://chainercv-models.preferred.jp/'
'resnet101_imagenet_trained_2018_11_26.npz',
'cv2': True,
},
},
152: {
'imagenet': {
'param': {'n_class': 1000, 'mean': _imagenet_mean},
'overwritable': {'mean'},
'url': 'https://chainercv-models.preferred.jp/'
'resnet152_imagenet_trained_2018_11_26.npz',
'cv2': True,
},
},
Review comment (Member): Don't we need 'cv2': True?

},
'he': {
50: {
@@ -140,10 +165,6 @@ def __init__(self, n_layer,
pretrained_model=None,
mean=None, initialW=None, fc_kwargs={}, arch='fb'):
if arch == 'fb':
if pretrained_model == 'imagenet':
raise ValueError(
'Pretrained weights for Facebook ResNet models '
'are not supported. Please set arch to \'he\'.')
stride_first = False
conv1_no_bias = True
elif arch == 'he':
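With this change, `pretrained_model='imagenet'` is no longer rejected for the `fb` architecture. A minimal usage sketch, assuming the `ResNet50` wrapper forwards `arch` and `pretrained_model` to the `ResNet` base class shown above:

```python
from chainercv.links import ResNet50

# Downloads and caches the fb-style ImageNet weights registered in _models
# (the URL in the 'fb' entry above), then loads them into the model.
model = ResNet50(pretrained_model='imagenet', arch='fb')
```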
56 changes: 46 additions & 10 deletions examples/classification/README.md
@@ -1,15 +1,17 @@
# Classification

## Performance
## ImageNet
Review comment (Member): he and fb should be distinguished from each other in the score board.


Single crop error rate.
### Weight conversion

Single crop error rates of the models with weights converted from Caffe.

| Model | Top 1 | Reference Top 1 |
|:-:|:-:|:-:|
| VGG16 | 29.0 % | 28.5 % [1] |
| ResNet50 | 24.8 % | 24.7 % [2] |
| ResNet101 | 23.6 % | 23.6 % [2] |
| ResNet152 | 23.2 % | 23.0 % [2] |
| ResNet50 (`arch=he`) | 24.8 % | 24.7 % [2] |
| ResNet101 (`arch=he`) | 23.6 % | 23.6 % [2] |
| ResNet152 (`arch=he`) | 23.2 % | 23.0 % [2] |
| SE-ResNet50 | 22.7 % | 22.4 % [3,4] |
| SE-ResNet101 | 21.8 % | 21.8 % [3,4] |
| SE-ResNet152 | 21.4 % | 21.3 % [3,4] |
@@ -21,9 +23,9 @@ Ten crop error rate.
| Model | Top 1 | Reference Top 1 |
|:-:|:-:|:-:|
| VGG16 | 27.1 % | |
| ResNet50 | 23.0 % | 22.9 % [2] |
| ResNet101 | 21.8 % | 21.8 % [2] |
| ResNet152 | 21.4 % | 21.4 % [2] |
| ResNet50 (`arch=he`) | 23.0 % | 22.9 % [2] |
| ResNet101 (`arch=he`) | 21.8 % | 21.8 % [2] |
| ResNet152 (`arch=he`) | 21.4 % | 21.4 % [2] |
| SE-ResNet50 | 20.8 % | |
| SE-ResNet101 | 20.1 % | |
| SE-ResNet152 | 19.7 % | |
@@ -32,15 +34,48 @@ Ten crop error rate.


The results can be reproduced by the following command.
The score is reported using a weight converted from a weight trained by Caffe.
These scores are obtained using the OpenCV backend. If Pillow is used, the scores would differ.

```
$ python eval_imagenet.py <path_to_val_dataset> [--model vgg16|resnet50|resnet101|resnet152|se-resnet50|se-resnet101|se-resnet152] [--pretrained-model <model_path>] [--batchsize <batchsize>] [--gpu <gpu>] [--crop center|10]
```

### Trained model

Single crop error rates of the models trained with ChainerCV's training script.

| Model | Top 1 | Reference Top 1 |
|:-:|:-:|:-:|
| ResNet50 (`arch=fb`) | 23.51 % | 23.60 % [5] |
| ResNet101 (`arch=fb`) | 22.07 % | 22.08 % [5] |
| ResNet152 (`arch=fb`) | 21.67 % | |


These are the scores of models trained with `train_imagenet_multi.py`, which can be executed as shown below.
The full list of arguments for the training script is available via `python train_imagenet_multi.py -h`.
```
$ mpiexec -n N python train_imagenet_multi.py <path_to_train_dataset> <path_to_val_dataset>
```

The training procedure carefully follows the "ResNet in 1 hour" paper [5].
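
Concretely, the base learning rate is scaled linearly with the global batch size, and a five-epoch gradual warmup followed by step decay is used. A minimal sketch of the schedule implemented in `train_imagenet_multi.py` (function and argument names here are illustrative):

```
def scaled_lr(batchsize, n_workers, base_lr=0.1, base_batchsize=256):
    # Linear scaling rule from [5]: the learning rate grows with the
    # global batch size (per-worker batch size x number of workers).
    return base_lr * (batchsize * n_workers) / base_batchsize


def lr_at_epoch(epoch, lr, warmup_epoch=5):
    # Gradual warmup from 0.1 up to the scaled lr over the first five
    # epochs, then step decay at epochs 30, 60, and 80, as in [5].
    if epoch < warmup_epoch and lr > 0.1:
        warmup_rate = 0.1 / lr
        return lr * (warmup_rate + (1 - warmup_rate) * epoch / warmup_epoch)
    elif epoch < 30:
        return lr
    elif epoch < 60:
        return lr * 0.1
    elif epoch < 80:
        return lr * 0.01
    else:
        return lr * 0.001
```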

#### Performance tip
When training over multiple nodes, set the communicator to `pure_nccl` (requires NCCL2).
The default communicator (`hierarchical`) uses MPI to communicate between nodes, which is slower than the pure NCCL communicator.
Also, cuDNN convolution functions can be optimized with extra commands (see https://docs.chainer.org/en/stable/performance.html#optimize-cudnn-convolution).
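For example, assuming NCCL2 is installed, the training command above becomes:

```
$ mpiexec -n N python train_imagenet_multi.py <path_to_train_dataset> <path_to_val_dataset> --communicator pure_nccl
```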

#### Detailed training results

Here, we investigate the effect of the number of GPUs on the final performance.
To make the comparison more statistically reliable, we report the mean top-1 error over five runs with different random seeds.

| Model | # GPUs | Top 1 |
|:-:|:-:|:-:|
| ResNet50 (`arch=fb`) | 8 | 23.53 (std=0.06) |
| ResNet50 (`arch=fb`) | 32 | 23.56 (std=0.11) |


## How to prepare ImageNet Dataset
## How to prepare ImageNet dataset

These instructions are based on the instructions found [here](https://github.com/facebook/fb.resnet.torch/blob/master/INSTALL.md#download-the-imagenet-dataset).

@@ -70,3 +105,4 @@ The ImageNet Large Scale Visual Recognition Challenge (ILSVRC) dataset has 1000
2. Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. "Deep Residual Learning for Image Recognition" CVPR 2016
3. Jie Hu, Li Shen, Gang Sun. "Squeeze-and-Excitation Networks" CVPR 2018
4. https://github.com/hujie-frank/SENet
5. Priya Goyal, Piotr Dollár, Ross Girshick, Pieter Noordhuis, Lukasz Wesolowski, Aapo Kyrola, Andrew Tulloch, Yangqing Jia, Kaiming He. "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour" https://arxiv.org/abs/1706.02677
2 changes: 1 addition & 1 deletion examples/classification/eval_imagenet.py
@@ -37,7 +37,7 @@ def main():
parser.add_argument('--gpu', type=int, default=-1)
parser.add_argument('--batchsize', type=int, default=32)
parser.add_argument('--crop', choices=('center', '10'), default='center')
parser.add_argument('--resnet-arch', default='he')
parser.add_argument('--resnet-arch', default='fb')
args = parser.parse_args()

dataset = DirectoryParsingLabelDataset(args.val)
199 changes: 199 additions & 0 deletions examples/classification/train_imagenet_multi.py
@@ -0,0 +1,199 @@
from __future__ import division
import argparse
import multiprocessing

import chainer
from chainer.datasets import TransformDataset
from chainer import iterators
from chainer.links import Classifier
from chainer.optimizer import WeightDecay
from chainer.optimizers import CorrectedMomentumSGD
from chainer import training
from chainer.training import extensions

from chainercv.datasets import directory_parsing_label_names
from chainercv.datasets import DirectoryParsingLabelDataset
from chainercv.transforms import center_crop
from chainercv.transforms import random_flip
from chainercv.transforms import random_sized_crop
from chainercv.transforms import resize
from chainercv.transforms import scale

from chainercv.chainer_experimental.training.extensions import make_shift

from chainercv.links.model.resnet import Bottleneck
from chainercv.links import ResNet101
from chainercv.links import ResNet152
from chainercv.links import ResNet50

import chainermn


class TrainTransform(object):

def __init__(self, mean):
self.mean = mean

def __call__(self, in_data):
img, label = in_data
img = random_sized_crop(img)
img = resize(img, (224, 224))
img = random_flip(img, x_random=True)
img -= self.mean
return img, label


class ValTransform(object):

def __init__(self, mean):
self.mean = mean

def __call__(self, in_data):
img, label = in_data
img = scale(img, 256)
img = center_crop(img, (224, 224))
img -= self.mean
return img, label


def main():
model_cfgs = {
'resnet50': {'class': ResNet50, 'score_layer_name': 'fc6',
'kwargs': {'arch': 'fb'}},
'resnet101': {'class': ResNet101, 'score_layer_name': 'fc6',
'kwargs': {'arch': 'fb'}},
'resnet152': {'class': ResNet152, 'score_layer_name': 'fc6',
'kwargs': {'arch': 'fb'}}
}
parser = argparse.ArgumentParser(
description='Learning convnet from ILSVRC2012 dataset')
parser.add_argument('train', help='Path to root of the train dataset')
parser.add_argument('val', help='Path to root of the validation dataset')
parser.add_argument('--model',
'-m', choices=model_cfgs.keys(), default='resnet50',
help='Convnet models')
parser.add_argument('--communicator', type=str,
default='hierarchical', help='Type of communicator')
parser.add_argument('--loaderjob', type=int, default=4)
parser.add_argument('--batchsize', type=int, default=32,
help='Batch size for each worker')
parser.add_argument('--lr', type=float)
parser.add_argument('--momentum', type=float, default=0.9)
parser.add_argument('--weight_decay', type=float, default=0.0001)
parser.add_argument('--out', type=str, default='result')
parser.add_argument('--epoch', type=int, default=90)
args = parser.parse_args()

# This fixes a crash caused by a bug with multiprocessing and MPI.
multiprocessing.set_start_method('forkserver')
p = multiprocessing.Process()
p.start()
p.join()

comm = chainermn.create_communicator(args.communicator)
device = comm.intra_rank

if args.lr is not None:
lr = args.lr
else:
lr = 0.1 * (args.batchsize * comm.size) / 256
if comm.rank == 0:
print('lr={}: lr is selected based on the linear '
'scaling rule'.format(lr))

label_names = directory_parsing_label_names(args.train)

model_cfg = model_cfgs[args.model]
extractor = model_cfg['class'](
n_class=len(label_names), **model_cfg['kwargs'])
extractor.pick = model_cfg['score_layer_name']
model = Classifier(extractor)
# Following https://arxiv.org/pdf/1706.02677.pdf,
# the gamma of the last BN of each resblock is initialized by zeros.
for l in model.links():
if isinstance(l, Bottleneck):
l.conv3.bn.gamma.data[:] = 0

if comm.rank == 0:
train_data = DirectoryParsingLabelDataset(args.train)
val_data = DirectoryParsingLabelDataset(args.val)
train_data = TransformDataset(
train_data, TrainTransform(extractor.mean))
val_data = TransformDataset(val_data, ValTransform(extractor.mean))
print('finished loading dataset')
else:
train_data, val_data = None, None
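    # Only rank 0 loads and wraps the datasets; scatter_dataset then
    # distributes an even split of each dataset to every MPI worker.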
train_data = chainermn.scatter_dataset(train_data, comm, shuffle=True)
val_data = chainermn.scatter_dataset(val_data, comm, shuffle=True)
train_iter = chainer.iterators.MultiprocessIterator(
train_data, args.batchsize, n_processes=args.loaderjob)
val_iter = iterators.MultiprocessIterator(
val_data, args.batchsize,
repeat=False, shuffle=False, n_processes=args.loaderjob)

optimizer = chainermn.create_multi_node_optimizer(
CorrectedMomentumSGD(lr=lr, momentum=args.momentum), comm)
optimizer.setup(model)
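    # Following https://arxiv.org/pdf/1706.02677.pdf, weight decay is not
    # applied to the beta and gamma parameters of the BN layers.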
for param in model.params():
if param.name not in ('beta', 'gamma'):
param.update_rule.add_hook(WeightDecay(args.weight_decay))

if device >= 0:
chainer.cuda.get_device(device).use()
model.to_gpu()

updater = chainer.training.StandardUpdater(
train_iter, optimizer, device=device)

trainer = training.Trainer(
updater, (args.epoch, 'epoch'), out=args.out)

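    # Gradual warmup over the first five epochs (when the scaled lr exceeds
    # 0.1), followed by step decay at epochs 30, 60, and 80.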
@make_shift('lr')
def warmup_and_exponential_shift(trainer):
epoch = trainer.updater.epoch_detail
warmup_epoch = 5
if epoch < warmup_epoch:
if lr > 0.1:
warmup_rate = 0.1 / lr
rate = warmup_rate \
+ (1 - warmup_rate) * epoch / warmup_epoch
else:
rate = 1
elif epoch < 30:
rate = 1
elif epoch < 60:
rate = 0.1
elif epoch < 80:
rate = 0.01
else:
rate = 0.001
return rate * lr

trainer.extend(warmup_and_exponential_shift)
evaluator = chainermn.create_multi_node_evaluator(
extensions.Evaluator(val_iter, model, device=device), comm)
trainer.extend(evaluator, trigger=(1, 'epoch'))

log_interval = 0.1, 'epoch'
print_interval = 0.1, 'epoch'

if comm.rank == 0:
trainer.extend(chainer.training.extensions.observe_lr(),
trigger=log_interval)
trainer.extend(
extensions.snapshot_object(
extractor, 'snapshot_model_{.updater.epoch}.npz'),
trigger=(args.epoch, 'epoch'))
trainer.extend(extensions.LogReport(trigger=log_interval))
trainer.extend(extensions.PrintReport(
['iteration', 'epoch', 'elapsed_time', 'lr',
'main/loss', 'validation/main/loss',
'main/accuracy', 'validation/main/accuracy']
), trigger=print_interval)
trainer.extend(extensions.ProgressBar(update_interval=10))

trainer.run()


if __name__ == '__main__':
main()