From fa9615f6a92bbeedd1f68c61b287337043e17a8d Mon Sep 17 00:00:00 2001
From: speedcell4
Date: Sun, 3 Mar 2024 15:30:53 +0900
Subject: [PATCH] Refactor: Update optimizer.py

---
 torchglyph/optimizer.py | 43 ++++++++++++-----------------------------
 1 file changed, 12 insertions(+), 31 deletions(-)

diff --git a/torchglyph/optimizer.py b/torchglyph/optimizer.py
index 2e74b8a..1dfcd20 100644
--- a/torchglyph/optimizer.py
+++ b/torchglyph/optimizer.py
@@ -1,5 +1,5 @@
 from logging import getLogger
-from typing import Set, Tuple, Type, Union
+from typing import Set, Tuple, Type
 
 from torch import nn, optim
 
@@ -47,48 +47,29 @@ def recur(mod: nn.Module):
 def log_parameters(module: nn.Module, with_decay: Set[nn.Parameter], without_decay: Set[nn.Parameter]):
     for name, param in module.named_parameters():
         if not param.requires_grad:
-            logger.critical(f'{name} requires no grad')
+            logger.critical(f'{name} {tuple(param.size())} requires no grad')
         elif param in with_decay:
-            logger.info(f'{name} with decay')
+            logger.info(f'{name} {tuple(param.size())} with decay')
         elif param in without_decay:
-            logger.info(f'{name} without decay')
+            logger.info(f'{name} {tuple(param.size())} without decay')
         else:
-            logger.error(f'{name} is not registered')
+            logger.error(f'{name} {tuple(param.size())} is not registered')
 
 
 class SGD(optim.SGD):
     def __init__(self, lr: float = 1e-3, momentum: float = 0.9, dampening: float = 0.0,
-                 weight_decay: float = 1e-4, nesterov: bool = False, *modules: nn.Module, **kwargs) -> None:
-        with_decay, without_decay = group_parameters(*modules)
-        for module in modules:
-            log_parameters(module, with_decay=with_decay, without_decay=without_decay)
-
+                 weight_decay: float = 1e-4, nesterov: bool = False, *, params, **kwargs) -> None:
         super(SGD, self).__init__(
-            lr=lr, momentum=momentum, dampening=dampening, nesterov=nesterov,
-            params=[
-                {'params': list(with_decay), 'weight_decay': weight_decay},
-                {'params': list(without_decay), 'weight_decay': 0.0},
-            ],
+            params=params, lr=lr,
+            momentum=momentum, dampening=dampening,
+            weight_decay=weight_decay, nesterov=nesterov,
         )
 
 
 class Adam(optim.AdamW):
     def __init__(self, lr: float = 3e-4, beta1: float = 0.9, beta2: float = 0.98,
-                 weight_decay: float = 1e-4, amsgrad: bool = False, *modules: nn.Module, **kwargs) -> None:
-        with_decay, without_decay = group_parameters(*modules)
-        for module in modules:
-            log_parameters(module, with_decay=with_decay, without_decay=without_decay)
-
+                 weight_decay: float = 1e-4, amsgrad: bool = False, *, params, **kwargs) -> None:
         super(Adam, self).__init__(
-            lr=lr, betas=(beta1, beta2), amsgrad=amsgrad,
-            params=[
-                {'params': list(with_decay), 'weight_decay': weight_decay},
-                {'params': list(without_decay), 'weight_decay': 0.0},
-            ],
+            params=params, lr=lr, betas=(beta1, beta2),
+            weight_decay=weight_decay, amsgrad=amsgrad,
        )
-
-
-Optimizer = Union[
-    Type[SGD],
-    Type[Adam],
-]
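
Note: after this patch, SGD and Adam no longer group parameters themselves; the caller builds the
parameter groups and passes them through the new keyword-only `params` argument. A minimal usage
sketch of the intended call site (the toy model and the call site itself are assumptions, not part
of the patch; `group_parameters` and `log_parameters` are the helpers already defined in
torchglyph/optimizer.py, and the group-specific weight_decay values mirror what the old
constructors built internally):

    from torch import nn

    from torchglyph.optimizer import Adam, group_parameters, log_parameters

    # hypothetical model, used only for illustration
    model = nn.Sequential(nn.Linear(4, 8), nn.LayerNorm(8), nn.Linear(8, 2))

    # the grouping the old constructors performed internally is now the caller's job
    with_decay, without_decay = group_parameters(model)
    log_parameters(model, with_decay=with_decay, without_decay=without_decay)

    optimizer = Adam(
        lr=3e-4,
        params=[
            {'params': list(with_decay), 'weight_decay': 1e-4},
            {'params': list(without_decay), 'weight_decay': 0.0},
        ],
    )

Per-group 'weight_decay' entries override the optimizer-level default, so the decay/no-decay split
behaves as before even though the constructors now forward a single `weight_decay` value to
torch.optim.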