From 9a637d84d90204f9dc971364fb0379518f578d80 Mon Sep 17 00:00:00 2001 From: Yusuke Niitani Date: Thu, 24 May 2018 10:03:26 +0900 Subject: [PATCH] move corrected_momentum_sgd to chainer_experimental --- chainercv/chainer_experimental/__init__.py | 1 + .../optimizers/__init__.py | 1 + .../optimizers/corrected_momentum_sgd.py | 62 +++++++++++++++++++ examples/classification/README.md | 5 -- examples/classification/train_imagenet_mn.py | 4 +- 5 files changed, 65 insertions(+), 8 deletions(-) create mode 100644 chainercv/chainer_experimental/optimizers/__init__.py create mode 100644 chainercv/chainer_experimental/optimizers/corrected_momentum_sgd.py diff --git a/chainercv/chainer_experimental/__init__.py b/chainercv/chainer_experimental/__init__.py index b91a45fb64..f0227014bd 100644 --- a/chainercv/chainer_experimental/__init__.py +++ b/chainercv/chainer_experimental/__init__.py @@ -1 +1,2 @@ from chainercv.chainer_experimental import datasets # NOQA +from chainercv.chainer_experimental import optimizers # NOQA diff --git a/chainercv/chainer_experimental/optimizers/__init__.py b/chainercv/chainer_experimental/optimizers/__init__.py new file mode 100644 index 0000000000..0ce7d43a78 --- /dev/null +++ b/chainercv/chainer_experimental/optimizers/__init__.py @@ -0,0 +1 @@ +from chainercv.chainer_experimental.optimizers.corrected_momentum_sgd import CorrectedMomentumSGD # NOQA diff --git a/chainercv/chainer_experimental/optimizers/corrected_momentum_sgd.py b/chainercv/chainer_experimental/optimizers/corrected_momentum_sgd.py new file mode 100644 index 0000000000..a2c1a66d2c --- /dev/null +++ b/chainercv/chainer_experimental/optimizers/corrected_momentum_sgd.py @@ -0,0 +1,62 @@ +from chainer import cuda +from chainer import optimizer + + +_default_hyperparam = optimizer.Hyperparameter() +_default_hyperparam.lr = 0.01 +_default_hyperparam.momentum = 0.9 + + +class CorrectedMomentumSGDRule(optimizer.UpdateRule): + + # use update rule used in frameworks like Torch. + + def __init__(self, parent_hyperparam=None, lr=None, momentum=None): + super(CorrectedMomentumSGDRule, self).__init__( + parent_hyperparam or _default_hyperparam) + if lr is not None: + self.hyperparam.lr = lr + if momentum is not None: + self.hyperparam.momentum = momentum + + def init_state(self, param): + xp = cuda.get_array_module(param.data) + with cuda.get_device_from_array(param.data): + self.state['v'] = xp.zeros_like(param.data) + + def update_core_cpu(self, param): + grad = param.grad + if grad is None: + return + v = self.state['v'] + v *= self.hyperparam.momentum + v -= self.hyperparam.lr * grad + param.data += v + + def update_core_gpu(self, param): + grad = param.grad + if grad is None: + return + cuda.elementwise( + 'T grad, T lr, T momentum', + 'T param, T v', + '''v = momentum * v - grad; + param += lr * v;''', + 'momentum_sgd')( + grad, self.hyperparam.lr, self.hyperparam.momentum, + param.data, self.state['v']) + + +class CorrectedMomentumSGD(optimizer.GradientMethod): + + def __init__(self, lr=_default_hyperparam.lr, + momentum=_default_hyperparam.momentum): + super(CorrectedMomentumSGD, self).__init__() + self.hyperparam.lr = lr + self.hyperparam.momentum = momentum + + lr = optimizer.HyperparameterProxy('lr') + momentum = optimizer.HyperparameterProxy('momentum') + + def create_update_rule(self): + return CorrectedMomentumSGDRule(self.hyperparam) diff --git a/examples/classification/README.md b/examples/classification/README.md index 67d6f6227b..fa0350dddd 100644 --- a/examples/classification/README.md +++ b/examples/classification/README.md @@ -31,11 +31,6 @@ $ python eval_imagenet.py [--model vgg16|resnet50|resnet10 ## Training Models -Training with single GPU. -``` -$ python train_imagenet.py [--gpu ] -``` - Training with multiple GPUs. Please install ChainerMN to use this feature. ``` $ mpiexec -n N python train_imagenet_mn.py diff --git a/examples/classification/train_imagenet_mn.py b/examples/classification/train_imagenet_mn.py index a8408d6f3d..6ae388bfdb 100644 --- a/examples/classification/train_imagenet_mn.py +++ b/examples/classification/train_imagenet_mn.py @@ -21,14 +21,13 @@ from chainercv.datasets import directory_parsing_label_names +from chainercv.chainer_experimental.optimizers import CorrectedMomentumSGD from chainercv.links import ResNet101 from chainercv.links import ResNet152 from chainercv.links import ResNet50 import chainermn -from corrected_momentum_sgd import CorrectedMomentumSGD - class TrainTransform(object): @@ -78,7 +77,6 @@ def main(): parser.add_argument('--communicator', type=str, default='hierarchical', help='Type of communicator') parser.add_argument('--pretrained_model') - # parser.add_argument('--gpu', type=int, default=-1) parser.add_argument('--loaderjob', type=int, default=4) parser.add_argument('--batchsize', type=int, default=32, help='Batch size for each worker')