Skip to content

Commit

Permalink
Refactor: Move to torchglyph.scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
speedcell4 committed Mar 3, 2024
1 parent ec54db8 commit 048b46b
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions torchglyph/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,29 +9,29 @@


class LambdaLR(lr_scheduler.LambdaLR):
def __init__(self, num_warmup_steps: int = 20 if DEBUG else 4000,
num_training_steps: int = 20 if DEBUG else 200000, *,
def __init__(self, num_training_steps: int = 20 if DEBUG else 5_0000,
num_warmup_steps: int = 20 if DEBUG else 3000, *,
optimizer: Optimizer, last_epoch: int = -1, **kwargs) -> None:
self.num_warmup_steps = num_warmup_steps
self.num_training_steps = num_training_steps
self.num_warmup_steps = num_warmup_steps

super(LambdaLR, self).__init__(
optimizer=optimizer, last_epoch=last_epoch,
lr_lambda=self.lr_lambda,
optimizer=optimizer, lr_lambda=self.lr_lambda,
last_epoch=last_epoch,
)

def lr_lambda(self, step: int) -> float:
raise NotImplementedError

def __repr__(self) -> str:
return f'{self.__class__.__name__}({self.extra_repr()})'

def extra_repr(self) -> str:
return ', '.join([
f'num_warmup_steps={self.num_warmup_steps}',
f'num_training_steps={self.num_training_steps}',
f'num_warmup_steps={self.num_warmup_steps}',
])

def lr_lambda(self, step: int) -> float:
raise NotImplementedError


class ConstantScheduler(LambdaLR):
def lr_lambda(self, step: int) -> float:
Expand All @@ -46,15 +46,15 @@ def lr_lambda(self, step: int) -> float:
if step < self.num_warmup_steps:
return float(step / max(1.0, self.num_warmup_steps))

return max(0., self.num_training_steps - step) / max(1.0, self.num_training_steps - self.num_warmup_steps)
return max(0., (self.num_training_steps - step) / max(1.0, self.num_training_steps - self.num_warmup_steps))


class InverseSquareRootScheduler(LambdaLR):
def lr_lambda(self, step: int) -> float:
if step < self.num_warmup_steps:
return float(step / max(1.0, self.num_warmup_steps))

return max(0., (self.num_warmup_steps / step) ** 0.5)
return max(0., (self.num_warmup_steps / max(1.0, step)) ** 0.5)


Scheduler = Union[
Expand Down

0 comments on commit 048b46b

Please sign in to comment.