This repository has been archived by the owner on Jul 2, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 303
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
move corrected_momentum_sgd to chainer_experimental
- Loading branch information
Showing
5 changed files
with
65 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
from chainercv.chainer_experimental import datasets # NOQA | ||
from chainercv.chainer_experimental import optimizers # NOQA |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from chainercv.chainer_experimental.optimizers.corrected_momentum_sgd import CorrectedMomentumSGD # NOQA |
62 changes: 62 additions & 0 deletions
62
chainercv/chainer_experimental/optimizers/corrected_momentum_sgd.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
from chainer import cuda | ||
from chainer import optimizer | ||
|
||
|
||
_default_hyperparam = optimizer.Hyperparameter() | ||
_default_hyperparam.lr = 0.01 | ||
_default_hyperparam.momentum = 0.9 | ||
|
||
|
||
class CorrectedMomentumSGDRule(optimizer.UpdateRule): | ||
|
||
# use update rule used in frameworks like Torch. | ||
|
||
def __init__(self, parent_hyperparam=None, lr=None, momentum=None): | ||
super(CorrectedMomentumSGDRule, self).__init__( | ||
parent_hyperparam or _default_hyperparam) | ||
if lr is not None: | ||
self.hyperparam.lr = lr | ||
if momentum is not None: | ||
self.hyperparam.momentum = momentum | ||
|
||
def init_state(self, param): | ||
xp = cuda.get_array_module(param.data) | ||
with cuda.get_device_from_array(param.data): | ||
self.state['v'] = xp.zeros_like(param.data) | ||
|
||
def update_core_cpu(self, param): | ||
grad = param.grad | ||
if grad is None: | ||
return | ||
v = self.state['v'] | ||
v *= self.hyperparam.momentum | ||
v -= self.hyperparam.lr * grad | ||
param.data += v | ||
|
||
def update_core_gpu(self, param): | ||
grad = param.grad | ||
if grad is None: | ||
return | ||
cuda.elementwise( | ||
'T grad, T lr, T momentum', | ||
'T param, T v', | ||
'''v = momentum * v - grad; | ||
param += lr * v;''', | ||
'momentum_sgd')( | ||
grad, self.hyperparam.lr, self.hyperparam.momentum, | ||
param.data, self.state['v']) | ||
|
||
|
||
class CorrectedMomentumSGD(optimizer.GradientMethod): | ||
|
||
def __init__(self, lr=_default_hyperparam.lr, | ||
momentum=_default_hyperparam.momentum): | ||
super(CorrectedMomentumSGD, self).__init__() | ||
self.hyperparam.lr = lr | ||
self.hyperparam.momentum = momentum | ||
|
||
lr = optimizer.HyperparameterProxy('lr') | ||
momentum = optimizer.HyperparameterProxy('momentum') | ||
|
||
def create_update_rule(self): | ||
return CorrectedMomentumSGDRule(self.hyperparam) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters