Refactor: Move to torchglyph.optimizer
speedcell4 committed Mar 3, 2024
1 parent fa9615f commit 6451fa3
Showing 1 changed file with 57 additions and 38 deletions.
torchglyph/optimizer.py (95 changes: 57 additions & 38 deletions)
@@ -1,11 +1,13 @@
 from logging import getLogger
-from typing import Set, Tuple, Type
+from typing import Set, Tuple, Type, Union
 
 from torch import nn, optim
 
 logger = getLogger(__name__)
 
 IGNORES = (
     nn.Embedding, nn.EmbeddingBag,
+
     nn.LayerNorm, nn.GroupNorm, nn.LocalResponseNorm,
+
     nn.SyncBatchNorm,
@@ -17,59 +19,76 @@
 )
 
 
-def group_parameters(*modules: nn.Module, ignores: Tuple[Type[nn.Module], ...] = IGNORES):
-    memory = set()
-    with_decay = set()
-    without_decay = set()
-
-    def recur(mod: nn.Module):
-        if mod in memory:
-            return
+def group_params(modules: Tuple[nn.Module, ...], ignores: Tuple[Type[nn.Module], ...] = IGNORES):
+    visited, require, without = set(), set(), set()
 
-        memory.add(mod)
+    def recur(module: nn.Module):
+        if module not in visited:
+            visited.add(module)
 
-        for name, param in mod.named_parameters(recurse=False):
-            if param.requires_grad:
-                if isinstance(mod, ignores) or 'bias' in name:
-                    without_decay.add(param)
-                else:
-                    with_decay.add(param)
+            for name, param in module.named_parameters(recurse=False):
+                if param.requires_grad:
+                    if isinstance(module, ignores) or 'bias' in name:
+                        without.add(param)
+                    else:
+                        require.add(param)
 
-        for m in mod._modules.values():
-            recur(mod=m)
+            for mod in module._modules.values():
+                recur(module=mod)
 
-    for module in modules:
-        recur(mod=module)
+    for m in modules:
+        recur(module=m)
 
-    return with_decay, without_decay
+    return require, without
 
 
-def log_parameters(module: nn.Module, with_decay: Set[nn.Parameter], without_decay: Set[nn.Parameter]):
-    for name, param in module.named_parameters():
-        if not param.requires_grad:
-            logger.critical(f'{name} {tuple(param.size())} requires no grad')
-        elif param in with_decay:
-            logger.info(f'{name} {tuple(param.size())} with decay')
-        elif param in without_decay:
-            logger.info(f'{name} {tuple(param.size())} without decay')
-        else:
-            logger.error(f'{name} {tuple(param.size())} is not registered')
+def log_params(*modules: nn.Module, require: Set[nn.Parameter], without: Set[nn.Parameter]):
+    for module in modules:
+        for name, param in module.named_parameters():
+            if not param.requires_grad:
+                logger.critical(f'{name} {tuple(param.size())} requires no grad')
+            elif param in require:
+                logger.info(f'{name} {tuple(param.size())} requires weight decay')
+            elif param in without:
+                logger.info(f'{name} {tuple(param.size())} requires no weight decay')
+            else:
+                logger.error(f'{name} {tuple(param.size())} is not registered')
 
 
 class SGD(optim.SGD):
     def __init__(self, lr: float = 1e-3, momentum: float = 0.9, dampening: float = 0.0,
-                 weight_decay: float = 1e-4, nesterov: bool = False, *, params, **kwargs) -> None:
+                 weight_decay: float = 1e-4, nesterov: bool = False, *,
+                 modules: Tuple[nn.Module, ...], **kwargs) -> None:
+        require, without = group_params(modules)
+        log_params(*modules, require=require, without=without)
+
         super(SGD, self).__init__(
-            params=params, lr=lr,
-            momentum=momentum, dampening=dampening,
-            weight_decay=weight_decay, nesterov=nesterov,
+            params=[
+                {'params': list(require), 'weight_decay': weight_decay},
+                {'params': list(without), 'weight_decay': 0},
+            ],
+            lr=lr, momentum=momentum,
+            dampening=dampening, nesterov=nesterov,
        )
 
 
 class Adam(optim.AdamW):
     def __init__(self, lr: float = 3e-4, beta1: float = 0.9, beta2: float = 0.98,
-                 weight_decay: float = 1e-4, amsgrad: bool = False, *, params, **kwargs) -> None:
+                 weight_decay: float = 1e-4, amsgrad: bool = False, *,
+                 modules: Tuple[nn.Module, ...], **kwargs) -> None:
+        require, without = group_params(modules)
+        log_params(*modules, require=require, without=without)
+
         super(Adam, self).__init__(
-            params=params, lr=lr, betas=(beta1, beta2),
-            weight_decay=weight_decay, amsgrad=amsgrad,
+            params=[
+                {'params': list(require), 'weight_decay': weight_decay},
+                {'params': list(without), 'weight_decay': 0},
+            ],
+            lr=lr, betas=(beta1, beta2), amsgrad=amsgrad,
        )
+
+
+Optimizer = Union[
+    Type[SGD],
+    Type[Adam],
+]
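
Usage sketch (illustrative, not part of the commit): after this refactor the optimizers take modules instead of a pre-built params list; group_params splits the trainable parameters into a decayed and an undecayed group (biases and the IGNORES module types get no weight decay), and log_params reports the assignment. A minimal example against the new interface, assuming a toy model:

from torch import nn

from torchglyph.optimizer import Adam, group_params

# toy model: Embedding and LayerNorm are IGNORES types, Linear is not
model = nn.Sequential(
    nn.Embedding(100, 32),
    nn.LayerNorm(32),
    nn.Linear(32, 10),
)

# Linear.weight is the only decayed tensor; the Embedding and LayerNorm
# parameters and every bias land in the no-decay group
require, without = group_params((model,))

# the refactored interface takes modules, not params; grouping and logging
# happen inside __init__, and weight decay is applied per parameter group
optimizer = Adam(lr=3e-4, weight_decay=1e-4, modules=(model,))

Grouping inside the optimizer keeps call sites free of boilerplate and ensures normalization and embedding weights are never decayed.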
