Merge pull request #436 from yuyu2172/resnet-train
Add ResNet training code
Hakuyume authored Nov 29, 2018
2 parents ca9b94c + 80f449e commit 9d8a68e
Showing 4 changed files with 277 additions and 21 deletions.
41 changes: 31 additions & 10 deletions chainercv/links/model/resnet/resnet.py
@@ -55,10 +55,11 @@ class ResNet(PickableSequentialChain):
loaded from weights distributed on the Internet.
The list of pretrained models supported is as follows:
-    * :obj:`imagenet`: Loads weights trained with ImageNet and distributed \
-        at `Model Zoo \
-        <https://github.com/BVLC/caffe/wiki/Model-Zoo>`_.
-        This is only supported when :obj:`arch=='he'`.
+    * :obj:`imagenet`: Loads weights trained with ImageNet. \
+        When :obj:`arch=='he'`, the weights distributed \
+        at `Model Zoo \
+        <https://github.com/BVLC/caffe/wiki/Model-Zoo>`_ \
+        are used.
Args:
n_layer (int): The number of layers.
@@ -103,9 +104,33 @@ class ResNet(PickableSequentialChain):

_models = {
'fb': {
-            50: {},
-            101: {},
-            152: {}
+            50: {
+                'imagenet': {
+                    'param': {'n_class': 1000, 'mean': _imagenet_mean},
+                    'overwritable': {'mean'},
+                    'url': 'https://chainercv-models.preferred.jp/'
+                    'resnet50_imagenet_trained_2018_11_26.npz',
+                    'cv2': True,
+                },
+            },
+            101: {
+                'imagenet': {
+                    'param': {'n_class': 1000, 'mean': _imagenet_mean},
+                    'overwritable': {'mean'},
+                    'url': 'https://chainercv-models.preferred.jp/'
+                    'resnet101_imagenet_trained_2018_11_26.npz',
+                    'cv2': True,
+                },
+            },
+            152: {
+                'imagenet': {
+                    'param': {'n_class': 1000, 'mean': _imagenet_mean},
+                    'overwritable': {'mean'},
+                    'url': 'https://chainercv-models.preferred.jp/'
+                    'resnet152_imagenet_trained_2018_11_26.npz',
+                    'cv2': True,
+                },
+            },
},
'he': {
50: {
@@ -140,10 +165,6 @@ def __init__(self, n_layer,
pretrained_model=None,
mean=None, initialW=None, fc_kwargs={}, arch='fb'):
if arch == 'fb':
-            if pretrained_model == 'imagenet':
-                raise ValueError(
-                    'Pretrained weights for Facebook ResNet models '
-                    'are not supported. Please set arch to \'he\'.')
stride_first = False
conv1_no_bias = True
elif arch == 'he':
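As a result of this change, `pretrained_model='imagenet'` now works with the `fb` architecture. A minimal usage sketch (the weights are downloaded from the URLs registered in `_models` above):

```
from chainercv.links import ResNet50

# Load the fb-architecture ResNet50 with the newly trained ImageNet weights.
model = ResNet50(pretrained_model='imagenet', arch='fb')
```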
56 changes: 46 additions & 10 deletions examples/classification/README.md
@@ -1,15 +1,17 @@
# Classification

-## Performance
+## ImageNet

-Single crop error rate.
+### Weight conversion
+
+Single crop error rates of the models with weights converted from Caffe.

| Model | Top 1 | Reference Top 1 |
|:-:|:-:|:-:|
| VGG16 | 29.0 % | 28.5 % [1] |
-| ResNet50 | 24.8 % | 24.7 % [2] |
-| ResNet101 | 23.6 % | 23.6 % [2] |
-| ResNet152 | 23.2 % | 23.0 % [2] |
+| ResNet50 (`arch=he`) | 24.8 % | 24.7 % [2] |
+| ResNet101 (`arch=he`) | 23.6 % | 23.6 % [2] |
+| ResNet152 (`arch=he`) | 23.2 % | 23.0 % [2] |
| SE-ResNet50 | 22.7 % | 22.4 % [3,4] |
| SE-ResNet101 | 21.8 % | 21.8 % [3,4] |
| SE-ResNet152 | 21.4 % | 21.3 % [3,4] |
@@ -21,9 +23,9 @@ Ten crop error rate.
| Model | Top 1 | Reference Top 1 |
|:-:|:-:|:-:|
| VGG16 | 27.1 % | |
-| ResNet50 | 23.0 % | 22.9 % [2] |
-| ResNet101 | 21.8 % | 21.8 % [2] |
-| ResNet152 | 21.4 % | 21.4 % [2] |
+| ResNet50 (`arch=he`) | 23.0 % | 22.9 % [2] |
+| ResNet101 (`arch=he`) | 21.8 % | 21.8 % [2] |
+| ResNet152 (`arch=he`) | 21.4 % | 21.4 % [2] |
| SE-ResNet50 | 20.8 % | |
| SE-ResNet101 | 20.1 % | |
| SE-ResNet152 | 19.7 % | |
@@ -32,15 +34,48 @@ Ten crop error rate.


The results can be reproduced by the following command.
-The score is reported using a weight converted from a weight trained by Caffe.
+These scores are obtained with the OpenCV backend; if Pillow is used, the scores will differ.

```
$ python eval_imagenet.py <path_to_val_dataset> [--model vgg16|resnet50|resnet101|resnet152|se-resnet50|se-resnet101|se-resnet152] [--pretrained-model <model_path>] [--batchsize <batchsize>] [--gpu <gpu>] [--crop center|10]
```

### Trained model

Single crop error rates of the models trained with ChainerCV's training script.

| Model | Top 1 | Reference Top 1 |
|:-:|:-:|:-:|
| ResNet50 (`arch=fb`) | 23.51 % | 23.60 % [5] |
| ResNet101 (`arch=fb`) | 22.07 % | 22.08 % [5] |
| ResNet152 (`arch=fb`) | 21.67 % | |


These are the scores of models trained with `train_imagenet_multi.py`, which can be run as shown below.
The full list of arguments can be displayed with `python train_imagenet_multi.py -h`.
```
$ mpiexec -n N python train_imagenet_multi.py <path_to_train_dataset> <path_to_val_dataset>
```

The training procedure carefully follows the "ResNet in 1 hour" paper [5].
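For instance, the default learning rate is chosen by the linear scaling rule of [5], mirroring the computation in the script:

```
# Linear scaling rule from [5], as implemented in train_imagenet_multi.py:
# a base lr of 0.1 corresponds to a total batch size of 256.
batchsize = 32  # per-worker batch size (the script's default)
n_workers = 32  # number of MPI processes (the -n argument of mpiexec)
lr = 0.1 * (batchsize * n_workers) / 256  # -> 0.4
```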

#### Performance tip
When training over multiple nodes, set the communicator to `pure_nccl` (requires NCCL2).
The default communicator (`hierarchical`) uses MPI to communicate between nodes, which is slower than the pure NCCL communicator.
Also, cuDNN convolution functions can be optimized with a few extra settings (see https://docs.chainer.org/en/stable/performance.html#optimize-cudnn-convolution).
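For example, a 32-worker run using the NCCL communicator (assuming NCCL2 is installed) could be launched as:

```
$ mpiexec -n 32 python train_imagenet_multi.py <path_to_train_dataset> <path_to_val_dataset> --communicator pure_nccl
```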

#### Detailed training results

Here, we investigate the effect of the number of GPUs on the final performance.
To make the results statistically more reliable, we ran the training with five different random seeds.

| Model | # GPUs | Top 1 |
|:-:|:-:|:-:|
| ResNet50 (`arch=fb`) | 8 | 23.53 % (std=0.06) |
| ResNet50 (`arch=fb`) | 32 | 23.56 % (std=0.11) |


-## How to prepare ImageNet Dataset
+## How to prepare ImageNet dataset

These instructions are based on the ones found [here](https://github.com/facebook/fb.resnet.torch/blob/master/INSTALL.md#download-the-imagenet-dataset).

@@ -70,3 +105,4 @@ The ImageNet Large Scale Visual Recognition Challenge (ILSVRC) dataset has 1000
2. Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. "Deep Residual Learning for Image Recognition" CVPR 2016
3. Jie Hu, Li Shen, Gang Sun. "Squeeze-and-Excitation Networks" CVPR 2018
4. https://github.com/hujie-frank/SENet
5. Priya Goyal, Piotr Dollár, Ross Girshick, Pieter Noordhuis, Lukasz Wesolowski, Aapo Kyrola, Andrew Tulloch, Yangqing Jia, Kaiming He. "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour" https://arxiv.org/abs/1706.02677
2 changes: 1 addition & 1 deletion examples/classification/eval_imagenet.py
@@ -37,7 +37,7 @@ def main():
parser.add_argument('--gpu', type=int, default=-1)
parser.add_argument('--batchsize', type=int, default=32)
parser.add_argument('--crop', choices=('center', '10'), default='center')
-    parser.add_argument('--resnet-arch', default='he')
+    parser.add_argument('--resnet-arch', default='fb')
args = parser.parse_args()

dataset = DirectoryParsingLabelDataset(args.val)
199 changes: 199 additions & 0 deletions examples/classification/train_imagenet_multi.py
@@ -0,0 +1,199 @@
from __future__ import division
import argparse
import multiprocessing

import chainer
from chainer.datasets import TransformDataset
from chainer import iterators
from chainer.links import Classifier
from chainer.optimizer import WeightDecay
from chainer.optimizers import CorrectedMomentumSGD
from chainer import training
from chainer.training import extensions

from chainercv.datasets import directory_parsing_label_names
from chainercv.datasets import DirectoryParsingLabelDataset
from chainercv.transforms import center_crop
from chainercv.transforms import random_flip
from chainercv.transforms import random_sized_crop
from chainercv.transforms import resize
from chainercv.transforms import scale

from chainercv.chainer_experimental.training.extensions import make_shift

from chainercv.links.model.resnet import Bottleneck
from chainercv.links import ResNet101
from chainercv.links import ResNet152
from chainercv.links import ResNet50

import chainermn


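# Training-time preprocessing: random sized crop resized to 224 x 224,
# random horizontal flip, then mean subtraction.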
class TrainTransform(object):

def __init__(self, mean):
self.mean = mean

def __call__(self, in_data):
img, label = in_data
img = random_sized_crop(img)
img = resize(img, (224, 224))
img = random_flip(img, x_random=True)
img -= self.mean
return img, label


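# Validation-time preprocessing: scale the shorter side to 256, take a
# 224 x 224 center crop, then subtract the mean.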
class ValTransform(object):

def __init__(self, mean):
self.mean = mean

def __call__(self, in_data):
img, label = in_data
img = scale(img, 256)
img = center_crop(img, (224, 224))
img -= self.mean
return img, label


def main():
model_cfgs = {
'resnet50': {'class': ResNet50, 'score_layer_name': 'fc6',
'kwargs': {'arch': 'fb'}},
'resnet101': {'class': ResNet101, 'score_layer_name': 'fc6',
'kwargs': {'arch': 'fb'}},
'resnet152': {'class': ResNet152, 'score_layer_name': 'fc6',
'kwargs': {'arch': 'fb'}}
}
parser = argparse.ArgumentParser(
description='Learning convnet from ILSVRC2012 dataset')
parser.add_argument('train', help='Path to root of the train dataset')
parser.add_argument('val', help='Path to root of the validation dataset')
parser.add_argument('--model',
'-m', choices=model_cfgs.keys(), default='resnet50',
help='Convnet models')
parser.add_argument('--communicator', type=str,
default='hierarchical', help='Type of communicator')
parser.add_argument('--loaderjob', type=int, default=4)
parser.add_argument('--batchsize', type=int, default=32,
help='Batch size for each worker')
parser.add_argument('--lr', type=float)
parser.add_argument('--momentum', type=float, default=0.9)
parser.add_argument('--weight_decay', type=float, default=0.0001)
parser.add_argument('--out', type=str, default='result')
parser.add_argument('--epoch', type=int, default=90)
args = parser.parse_args()

# This fixes a crash caused by a bug with multiprocessing and MPI.
multiprocessing.set_start_method('forkserver')
p = multiprocessing.Process()
p.start()
p.join()

comm = chainermn.create_communicator(args.communicator)
device = comm.intra_rank

if args.lr is not None:
lr = args.lr
else:
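        # Linear scaling rule [5]: a base lr of 0.1 corresponds to a
        # total batch size of 256.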
lr = 0.1 * (args.batchsize * comm.size) / 256
if comm.rank == 0:
print('lr={}: lr is selected based on the linear '
'scaling rule'.format(lr))

label_names = directory_parsing_label_names(args.train)

model_cfg = model_cfgs[args.model]
extractor = model_cfg['class'](
n_class=len(label_names), **model_cfg['kwargs'])
extractor.pick = model_cfg['score_layer_name']
model = Classifier(extractor)
# Following https://arxiv.org/pdf/1706.02677.pdf,
# the gamma of the last BN of each resblock is initialized by zeros.
for l in model.links():
if isinstance(l, Bottleneck):
l.conv3.bn.gamma.data[:] = 0

if comm.rank == 0:
train_data = DirectoryParsingLabelDataset(args.train)
val_data = DirectoryParsingLabelDataset(args.val)
train_data = TransformDataset(
train_data, TrainTransform(extractor.mean))
val_data = TransformDataset(val_data, ValTransform(extractor.mean))
print('finished loading dataset')
else:
train_data, val_data = None, None
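    # Rank 0 loads the datasets; scatter_dataset then distributes an
    # equal share to every worker.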
train_data = chainermn.scatter_dataset(train_data, comm, shuffle=True)
val_data = chainermn.scatter_dataset(val_data, comm, shuffle=True)
train_iter = chainer.iterators.MultiprocessIterator(
train_data, args.batchsize, n_processes=args.loaderjob)
val_iter = iterators.MultiprocessIterator(
val_data, args.batchsize,
repeat=False, shuffle=False, n_processes=args.loaderjob)

optimizer = chainermn.create_multi_node_optimizer(
CorrectedMomentumSGD(lr=lr, momentum=args.momentum), comm)
optimizer.setup(model)
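    # Exempt the batch normalization parameters (beta and gamma) from
    # weight decay; all other parameters are decayed.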
for param in model.params():
if param.name not in ('beta', 'gamma'):
param.update_rule.add_hook(WeightDecay(args.weight_decay))

if device >= 0:
chainer.cuda.get_device(device).use()
model.to_gpu()

updater = chainer.training.StandardUpdater(
train_iter, optimizer, device=device)

trainer = training.Trainer(
updater, (args.epoch, 'epoch'), out=args.out)

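    # Learning rate schedule: linear warmup from 0.1 to the target lr over
    # the first 5 epochs (when the target lr exceeds 0.1), then 10x decay
    # at epochs 30, 60 and 80.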
@make_shift('lr')
def warmup_and_exponential_shift(trainer):
epoch = trainer.updater.epoch_detail
warmup_epoch = 5
if epoch < warmup_epoch:
if lr > 0.1:
warmup_rate = 0.1 / lr
rate = warmup_rate \
+ (1 - warmup_rate) * epoch / warmup_epoch
else:
rate = 1
elif epoch < 30:
rate = 1
elif epoch < 60:
rate = 0.1
elif epoch < 80:
rate = 0.01
else:
rate = 0.001
return rate * lr

trainer.extend(warmup_and_exponential_shift)
evaluator = chainermn.create_multi_node_evaluator(
extensions.Evaluator(val_iter, model, device=device), comm)
trainer.extend(evaluator, trigger=(1, 'epoch'))

log_interval = 0.1, 'epoch'
print_interval = 0.1, 'epoch'

if comm.rank == 0:
trainer.extend(chainer.training.extensions.observe_lr(),
trigger=log_interval)
trainer.extend(
extensions.snapshot_object(
extractor, 'snapshot_model_{.updater.epoch}.npz'),
trigger=(args.epoch, 'epoch'))
trainer.extend(extensions.LogReport(trigger=log_interval))
trainer.extend(extensions.PrintReport(
['iteration', 'epoch', 'elapsed_time', 'lr',
'main/loss', 'validation/main/loss',
'main/accuracy', 'validation/main/accuracy']
), trigger=print_interval)
trainer.extend(extensions.ProgressBar(update_interval=10))

trainer.run()


if __name__ == '__main__':
main()
