move corrected_momentum_sgd to chainer_experimental
yuyu2172 committed May 24, 2018
1 parent d357af3 commit 9a637d8
Showing 5 changed files with 65 additions and 8 deletions.
1 change: 1 addition & 0 deletions chainercv/chainer_experimental/__init__.py
@@ -1 +1,2 @@
from chainercv.chainer_experimental import datasets # NOQA
from chainercv.chainer_experimental import optimizers # NOQA
1 change: 1 addition & 0 deletions chainercv/chainer_experimental/optimizers/__init__.py
@@ -0,0 +1 @@
from chainercv.chainer_experimental.optimizers.corrected_momentum_sgd import CorrectedMomentumSGD # NOQA
62 changes: 62 additions & 0 deletions chainercv/chainer_experimental/optimizers/corrected_momentum_sgd.py
@@ -0,0 +1,62 @@
from chainer import cuda
from chainer import optimizer


_default_hyperparam = optimizer.Hyperparameter()
_default_hyperparam.lr = 0.01
_default_hyperparam.momentum = 0.9


class CorrectedMomentumSGDRule(optimizer.UpdateRule):

    # use update rule used in frameworks like Torch.

    def __init__(self, parent_hyperparam=None, lr=None, momentum=None):
        super(CorrectedMomentumSGDRule, self).__init__(
            parent_hyperparam or _default_hyperparam)
        if lr is not None:
            self.hyperparam.lr = lr
        if momentum is not None:
            self.hyperparam.momentum = momentum

    def init_state(self, param):
        xp = cuda.get_array_module(param.data)
        with cuda.get_device_from_array(param.data):
            self.state['v'] = xp.zeros_like(param.data)

    def update_core_cpu(self, param):
        grad = param.grad
        if grad is None:
            return
        v = self.state['v']
        v *= self.hyperparam.momentum
        v -= self.hyperparam.lr * grad
        param.data += v

    def update_core_gpu(self, param):
        grad = param.grad
        if grad is None:
            return
        cuda.elementwise(
            'T grad, T lr, T momentum',
            'T param, T v',
            '''v = momentum * v - grad;
               param += lr * v;''',
            'momentum_sgd')(
                grad, self.hyperparam.lr, self.hyperparam.momentum,
                param.data, self.state['v'])


class CorrectedMomentumSGD(optimizer.GradientMethod):

    def __init__(self, lr=_default_hyperparam.lr,
                 momentum=_default_hyperparam.momentum):
        super(CorrectedMomentumSGD, self).__init__()
        self.hyperparam.lr = lr
        self.hyperparam.momentum = momentum

    lr = optimizer.HyperparameterProxy('lr')
    momentum = optimizer.HyperparameterProxy('momentum')

    def create_update_rule(self):
        return CorrectedMomentumSGDRule(self.hyperparam)
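
For context (added here, not part of the commit): a minimal sketch of how the relocated optimizer could be used from its new import path, assuming an ordinary Chainer setup; the `model` link and the hyperparameter values below are placeholders.

```python
# Hypothetical usage sketch; `model` is a stand-in link, not taken from this commit.
import chainer
from chainercv.chainer_experimental.optimizers import CorrectedMomentumSGD

model = chainer.links.Linear(None, 1000)  # stand-in for e.g. ResNet50
optimizer = CorrectedMomentumSGD(lr=0.1, momentum=0.9)
optimizer.setup(model)
optimizer.add_hook(chainer.optimizer.WeightDecay(1e-4))
```

In the GPU kernel above, the learning rate multiplies the parameter update (`param += lr * v`) rather than the velocity accumulation, which is the Torch-style ("corrected") momentum rule; this mainly matters when the learning rate is rescaled during training, as in the multi-GPU ImageNet example updated by this commit.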
5 changes: 0 additions & 5 deletions examples/classification/README.md
@@ -31,11 +31,6 @@ $ python eval_imagenet.py <path_to_val_dataset> [--model vgg16|resnet50|resnet10

## Training Models

Training with single GPU.
```
$ python train_imagenet.py <path_to_train_dataset> <path_to_val_dataset> [--gpu <gpu>]
```

Training with multiple GPUs. Please install ChainerMN to use this feature.
```
$ mpiexec -n N python train_imagenet_mn.py <path_to_train_dataset> <path_to_val_dataset>
```
4 changes: 1 addition & 3 deletions examples/classification/train_imagenet_mn.py
@@ -21,14 +21,13 @@

from chainercv.datasets import directory_parsing_label_names

from chainercv.chainer_experimental.optimizers import CorrectedMomentumSGD
from chainercv.links import ResNet101
from chainercv.links import ResNet152
from chainercv.links import ResNet50

import chainermn

from corrected_momentum_sgd import CorrectedMomentumSGD


class TrainTransform(object):

@@ -78,7 +77,6 @@ def main():
parser.add_argument('--communicator', type=str,
default='hierarchical', help='Type of communicator')
parser.add_argument('--pretrained_model')
# parser.add_argument('--gpu', type=int, default=-1)
parser.add_argument('--loaderjob', type=int, default=4)
parser.add_argument('--batchsize', type=int, default=32,
help='Batch size for each worker')
