Refactor: Update optimizer.py
speedcell4 committed Mar 3, 2024
1 parent 70e432d commit fa9615f
Showing 1 changed file with 12 additions and 31 deletions.
43 changes: 12 additions & 31 deletions torchglyph/optimizer.py
@@ -1,5 +1,5 @@
 from logging import getLogger
-from typing import Set, Tuple, Type, Union
+from typing import Set, Tuple, Type

 from torch import nn, optim

@@ -47,48 +47,29 @@ def recur(mod: nn.Module):
 def log_parameters(module: nn.Module, with_decay: Set[nn.Parameter], without_decay: Set[nn.Parameter]):
     for name, param in module.named_parameters():
         if not param.requires_grad:
-            logger.critical(f'{name} requires no grad')
+            logger.critical(f'{name} {tuple(param.size())} requires no grad')
         elif param in with_decay:
-            logger.info(f'{name} with decay')
+            logger.info(f'{name} {tuple(param.size())} with decay')
         elif param in without_decay:
-            logger.info(f'{name} without decay')
+            logger.info(f'{name} {tuple(param.size())} without decay')
         else:
-            logger.error(f'{name} is not registered')
+            logger.error(f'{name} {tuple(param.size())} is not registered')


 class SGD(optim.SGD):
     def __init__(self, lr: float = 1e-3, momentum: float = 0.9, dampening: float = 0.0,
-                 weight_decay: float = 1e-4, nesterov: bool = False, *modules: nn.Module, **kwargs) -> None:
-        with_decay, without_decay = group_parameters(*modules)
-        for module in modules:
-            log_parameters(module, with_decay=with_decay, without_decay=without_decay)
-
+                 weight_decay: float = 1e-4, nesterov: bool = False, *, params, **kwargs) -> None:
         super(SGD, self).__init__(
-            lr=lr, momentum=momentum, dampening=dampening, nesterov=nesterov,
-            params=[
-                {'params': list(with_decay), 'weight_decay': weight_decay},
-                {'params': list(without_decay), 'weight_decay': 0.0},
-            ],
+            params=params, lr=lr,
+            momentum=momentum, dampening=dampening,
+            weight_decay=weight_decay, nesterov=nesterov,
         )


 class Adam(optim.AdamW):
     def __init__(self, lr: float = 3e-4, beta1: float = 0.9, beta2: float = 0.98,
-                 weight_decay: float = 1e-4, amsgrad: bool = False, *modules: nn.Module, **kwargs) -> None:
-        with_decay, without_decay = group_parameters(*modules)
-        for module in modules:
-            log_parameters(module, with_decay=with_decay, without_decay=without_decay)
-
+                 weight_decay: float = 1e-4, amsgrad: bool = False, *, params, **kwargs) -> None:
         super(Adam, self).__init__(
-            lr=lr, betas=(beta1, beta2), amsgrad=amsgrad,
-            params=[
-                {'params': list(with_decay), 'weight_decay': weight_decay},
-                {'params': list(without_decay), 'weight_decay': 0.0},
-            ],
+            params=params, lr=lr, betas=(beta1, beta2),
+            weight_decay=weight_decay, amsgrad=amsgrad,
         )
-
-
-Optimizer = Union[
-    Type[SGD],
-    Type[Adam],
-]
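After this commit, SGD and Adam no longer accept *modules and no longer call group_parameters and log_parameters themselves; the caller builds the parameter groups and passes them through the keyword-only params argument. A minimal sketch of the new call site, assuming group_parameters and log_parameters are still exported from torchglyph.optimizer with the signatures shown above (the nn.Linear model is purely illustrative):

from torch import nn

from torchglyph.optimizer import SGD, group_parameters, log_parameters

# Hypothetical model, used only to illustrate the new call site.
model = nn.Linear(4, 2)

# Parameter grouping now happens at the call site instead of inside the optimizer.
with_decay, without_decay = group_parameters(model)
log_parameters(model, with_decay=with_decay, without_decay=without_decay)

optimizer = SGD(
    params=[
        {'params': list(with_decay), 'weight_decay': 1e-4},
        {'params': list(without_decay), 'weight_decay': 0.0},
    ],
    lr=1e-3,
)

Adam would be constructed the same way, with the grouped parameters passed explicitly via params.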
