flatten optimizerwrapper in checkpoint
mori360 committed Dec 4, 2024
1 parent 9d5d113 commit 2851629
Showing 3 changed files with 13 additions and 40 deletions.
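In short, the separate OptimizerInBackwardWrapper is dropped and OptimizerWrapper itself takes an optim_in_bwd flag: when optimizers run in the backward pass, each model part carries a group (list) of per-parameter optimizers, and the wrapper now flattens those groups into one list before the usual flattened-state-dict handling. A minimal sketch of just that flattening step, using plain strings as stand-ins for real torch.optim.Optimizer instances (the stand-ins are illustrative, not torchtitan code):

```python
# Nested structure under optimizer-in-backward: one optimizer group per model
# part, one optimizer per parameter (strings are hypothetical stand-ins).
optim = [["adamw_w0", "adamw_b0"], ["adamw_w1", "adamw_b1"]]

# The same comprehension OptimizerWrapper.__init__ now uses to flatten groups:
flat = [sub_optim for optim_group in optim for sub_optim in optim_group]
print(flat)  # ['adamw_w0', 'adamw_b0', 'adamw_w1', 'adamw_b1']
```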
test_runner.py (2 changes: 1 addition & 1 deletion)
@@ -372,7 +372,7 @@ def build_test_list():
                     "--checkpoint.enable_checkpoint",
                     "--experimental.pipeline_parallel_degree 2",
                     "--training.enable_cpu_offload True",
-                    "--training.enable_optimizer_in_backward True",
+                    "--optimizer.backward True",
                 ],
             ],
             "Enable CPU Offload with PP",
torchtitan/checkpoint.py (46 changes: 11 additions & 35 deletions)
@@ -103,9 +103,15 @@ def __init__(
         self,
         model: Union[nn.Module, List[nn.Module]],
         optim: Union[torch.optim.Optimizer, List[torch.optim.Optimizer]],
+        optim_in_bwd: bool = False,
     ) -> None:
         self.model = [model] if isinstance(model, nn.Module) else model
-        self.optim = [optim] if isinstance(optim, torch.optim.Optimizer) else optim
+        if not optim_in_bwd:
+            self.optim = [optim] if isinstance(optim, torch.optim.Optimizer) else optim
+        else:
+            self.optim = [
+                sub_optim for optim_group in optim for sub_optim in optim_group
+            ]
 
     def state_dict(self) -> Dict[str, Any]:
         func = functools.partial(
@@ -123,30 +129,6 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         list(map(func, self.model, self.optim))
 
 
-class OptimizerInBackwardWrapper(OptimizerWrapper):
-    def state_dict(self) -> Dict[str, Any]:
-        func = functools.partial(
-            get_optimizer_state_dict,
-            options=StateDictOptions(flatten_optimizer_state_dict=True),
-        )
-        state_dict = {}
-        for optim in self.optim:
-            for sub_opt in optim:
-                for sd in map(func, self.model, (sub_opt,)):
-                    state_dict.update(sd)
-        return state_dict
-
-    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
-        func = functools.partial(
-            set_optimizer_state_dict,
-            optim_state_dict=state_dict,
-            options=StateDictOptions(flatten_optimizer_state_dict=True),
-        )
-        for optim in self.optim:
-            for sub_opt in optim:
-                list(map(func, self.model, (sub_opt,)))
-
-
 class Terminate:
     pass
 
@@ -240,16 +222,10 @@ def __init__(
         self.states.update(
             {
                 "model": ModelWrapper(model_parts),
-                "optimizer": (
-                    OptimizerWrapper(
-                        model_parts,
-                        optimizers,
-                    )
-                    if not job_config.optimizer.backward
-                    else OptimizerInBackwardWrapper(
-                        model_parts,
-                        optimizers,
-                    )
+                "optimizer": OptimizerWrapper(
+                    model_parts,
+                    optimizers,
+                    job_config.optimizer.backward,
                 ),
                 "dataloader": dataloader,
             }
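With the subclass gone, CheckpointManager builds one OptimizerWrapper in both modes and simply forwards the flag. The runnable toy below strips the wrapper down to the __init__ logic shown above (state-dict methods omitted; the model/optimizer setup is illustrative, not torchtitan's):

```python
import torch
import torch.nn as nn

class ToyOptimizerWrapper:
    """Stand-in for OptimizerWrapper above, keeping only the __init__ logic."""

    def __init__(self, model, optim, optim_in_bwd=False):
        self.model = [model] if isinstance(model, nn.Module) else model
        if not optim_in_bwd:
            self.optim = [optim] if isinstance(optim, torch.optim.Optimizer) else optim
        else:
            self.optim = [sub for group in optim for sub in group]

model_parts = [nn.Linear(4, 4), nn.Linear(4, 2)]

# Normal mode: one optimizer per model part.
regular = [torch.optim.AdamW(m.parameters()) for m in model_parts]
print(len(ToyOptimizerWrapper(model_parts, regular).optim))  # 2

# Optimizer-in-backward mode: a group of per-parameter optimizers per part.
in_bwd = [[torch.optim.AdamW([p]) for p in m.parameters()] for m in model_parts]
print(len(ToyOptimizerWrapper(model_parts, in_bwd, optim_in_bwd=True).optim))  # 4
```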
torchtitan/optimizer.py (5 changes: 1 addition & 4 deletions)
@@ -170,10 +170,7 @@ def step(self):
             schedulers.step()
 
 class SchedulersInBackwardContainer(SchedulersContainer):
-    """Util for calling step on multiple learning rate schedulers needed for virtual pipeline stages"""
-
-    def __init__(self, schedulers):
-        self.schedulers = schedulers
+    """Util for calling step on multiple learning rate schedulers when optimizers are in backward"""
 
     def step(self):
         for schedulers in self.schedulers: