diff --git a/torchglyph/optimizer.py b/torchglyph/optimizer.py
index 1dfcd20..95aa541 100644
--- a/torchglyph/optimizer.py
+++ b/torchglyph/optimizer.py
@@ -1,11 +1,13 @@
 from logging import getLogger
-from typing import Set, Tuple, Type
+from typing import Set, Tuple, Type, Union
 
 from torch import nn, optim
 
 logger = getLogger(__name__)
 
 IGNORES = (
+    nn.Embedding, nn.EmbeddingBag,
+
     nn.LayerNorm, nn.GroupNorm, nn.LocalResponseNorm,
     nn.SyncBatchNorm,
 
@@ -17,59 +19,76 @@
 )
 
 
-def group_parameters(*modules: nn.Module, ignores: Tuple[Type[nn.Module], ...] = IGNORES):
-    memory = set()
-    with_decay = set()
-    without_decay = set()
-
-    def recur(mod: nn.Module):
-        if mod in memory:
-            return
+def group_params(modules: Tuple[nn.Module, ...], ignores: Tuple[Type[nn.Module], ...] = IGNORES):
+    visited, require, without = set(), set(), set()
 
-        memory.add(mod)
+    def recur(module: nn.Module):
+        if module not in visited:
+            visited.add(module)
 
-        for name, param in mod.named_parameters(recurse=False):
-            if param.requires_grad:
-                if isinstance(mod, ignores) or 'bias' in name:
-                    without_decay.add(param)
-                else:
-                    with_decay.add(param)
+            for name, param in module.named_parameters(recurse=False):
+                if param.requires_grad:
+                    if isinstance(module, ignores) or 'bias' in name:
+                        without.add(param)
+                    else:
+                        require.add(param)
 
-        for m in mod._modules.values():
-            recur(mod=m)
+            for mod in module._modules.values():
+                recur(module=mod)
 
-    for module in modules:
-        recur(mod=module)
+    for m in modules:
+        recur(module=m)
 
-    return with_decay, without_decay
+    return require, without
 
 
-def log_parameters(module: nn.Module, with_decay: Set[nn.Parameter], without_decay: Set[nn.Parameter]):
-    for name, param in module.named_parameters():
-        if not param.requires_grad:
-            logger.critical(f'{name} {tuple(param.size())} requires no grad')
-        elif param in with_decay:
-            logger.info(f'{name} {tuple(param.size())} with decay')
-        elif param in without_decay:
-            logger.info(f'{name} {tuple(param.size())} without decay')
-        else:
-            logger.error(f'{name} {tuple(param.size())} is not registered')
+def log_params(*modules: nn.Module, require: Set[nn.Parameter], without: Set[nn.Parameter]):
+    for module in modules:
+        for name, param in module.named_parameters():
+            if not param.requires_grad:
+                logger.critical(f'{name} {tuple(param.size())} requires no grad')
+            elif param in require:
+                logger.info(f'{name} {tuple(param.size())} requires weight decay')
+            elif param in without:
+                logger.info(f'{name} {tuple(param.size())} requires no weight decay')
+            else:
+                logger.error(f'{name} {tuple(param.size())} is not registered')
 
 
 class SGD(optim.SGD):
     def __init__(self, lr: float = 1e-3, momentum: float = 0.9, dampening: float = 0.0,
-                 weight_decay: float = 1e-4, nesterov: bool = False, *, params, **kwargs) -> None:
+                 weight_decay: float = 1e-4, nesterov: bool = False, *,
+                 modules: Tuple[nn.Module, ...], **kwargs) -> None:
+        require, without = group_params(modules)
+        log_params(*modules, require=require, without=without)
+
         super(SGD, self).__init__(
-            params=params, lr=lr,
-            momentum=momentum, dampening=dampening,
-            weight_decay=weight_decay, nesterov=nesterov,
+            params=[
+                {'params': list(require), 'weight_decay': weight_decay},
+                {'params': list(without), 'weight_decay': 0},
+            ],
+            lr=lr, momentum=momentum,
+            dampening=dampening, nesterov=nesterov,
         )
 
 
 class Adam(optim.AdamW):
     def __init__(self, lr: float = 3e-4, beta1: float = 0.9, beta2: float = 0.98,
-                 weight_decay: float = 1e-4, amsgrad: bool = False, *, params, **kwargs) -> None:
+                 weight_decay: float = 1e-4, amsgrad: bool = False, *,
+                 modules: Tuple[nn.Module, ...], **kwargs) -> None:
+        require, without = group_params(modules)
+        log_params(*modules, require=require, without=without)
+
         super(Adam, self).__init__(
-            params=params, lr=lr, betas=(beta1, beta2),
-            weight_decay=weight_decay, amsgrad=amsgrad,
+            params=[
+                {'params': list(require), 'weight_decay': weight_decay},
+                {'params': list(without), 'weight_decay': 0},
+            ],
+            lr=lr, betas=(beta1, beta2), amsgrad=amsgrad,
         )
+
+
+Optimizer = Union[
+    Type[SGD],
+    Type[Adam],
+]
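Not part of the patch, but a minimal usage sketch of the changed constructor interface, assuming the classes above are importable from torchglyph.optimizer as in this file. The toy model is hypothetical and only for illustration; the point is that the optimizers now take whole modules via the keyword-only `modules` argument and build their own weight-decay parameter groups, instead of receiving a pre-built `params` iterable.

from torch import nn

from torchglyph.optimizer import Adam  # the class patched above

# Hypothetical toy model, only to show how parameters get grouped.
model = nn.Sequential(
    nn.Embedding(1000, 64),  # in IGNORES -> no weight decay
    nn.Linear(64, 32),       # weight decays; bias does not ('bias' in name)
    nn.LayerNorm(32),        # in IGNORES -> no weight decay
)

# group_params/log_params run inside __init__, producing two AdamW parameter groups:
# one with weight_decay=1e-4 (Linear.weight) and one with weight_decay=0 (everything else).
optimizer = Adam(lr=3e-4, weight_decay=1e-4, modules=(model,))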