
Commit

change config name
mori360 committed Dec 4, 2024
1 parent 8445547 commit 9d5d113
Showing 3 changed files with 11 additions and 11 deletions.
torchtitan/checkpoint.py (2 changes: 1 addition & 1 deletion)

@@ -245,7 +245,7 @@ def __init__(
                 model_parts,
                 optimizers,
             )
-            if not job_config.training.enable_optimizer_in_backward
+            if not job_config.optimizer.backward
             else OptimizerInBackwardWrapper(
                 model_parts,
                 optimizers,
torchtitan/config_manager.py (16 changes: 8 additions & 8 deletions)

@@ -187,6 +187,14 @@ def __init__(self):
             action="store_true",
             help="Whether the fused implementation(CUDA only) is used.",
         )
+        self.parser.add_argument(
+            "--optimizer.backward",
+            type=bool,
+            default=False,
+            help="""
+            Whether to apply optimizer in the backward. Caution, optimizer_in_backward
+            is not compatible with gradients clipping.""",
+        )
 
         # training configs
         self.parser.add_argument(
@@ -270,14 +278,6 @@ def __init__(self):
             action="store_true",
             help="Whether to apply loss parallel when sequence parallel is enabled",
         )
-        self.parser.add_argument(
-            "--training.enable_optimizer_in_backward",
-            type=bool,
-            default=False,
-            help="""
-            Whether to apply optimizer in the backward. Caution, optimizer_in_backward
-            cannot compile with gradients clipping.""",
-        )
         self.parser.add_argument(
             "--experimental.enable_async_tensor_parallel",
             default=False,
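
For context, the diff above shows why the lookup changes from job_config.training.enable_optimizer_in_backward to job_config.optimizer.backward: the dotted CLI flag name determines which nested config attribute the option lands on. The following is a minimal, hypothetical sketch of that section.option grouping pattern, not torchtitan's actual config_manager code:

    import argparse
    from collections import defaultdict
    from types import SimpleNamespace

    parser = argparse.ArgumentParser()
    # Mirrors the renamed flag from this commit; note argparse's type=bool treats
    # any non-empty string as True, so only an empty string would yield False.
    parser.add_argument("--optimizer.backward", type=bool, default=False)
    parser.add_argument("--optimizer.name", type=str, default="AdamW")

    args = parser.parse_args(["--optimizer.backward", "1"])

    # Group "section.option" argument names into per-section namespaces.
    sections = defaultdict(dict)
    for key, value in vars(args).items():
        section, option = key.split(".", 1)
        sections[section][option] = value

    job_config = SimpleNamespace(
        **{name: SimpleNamespace(**opts) for name, opts in sections.items()}
    )
    assert job_config.optimizer.backward is True
    assert job_config.optimizer.name == "AdamW"

In a config-file-driven run, the same option would presumably sit under an [optimizer] section rather than a [training] one after this rename, but the exact file layout is torchtitan-specific.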
torchtitan/optimizer.py (4 changes: 2 additions & 2 deletions)

@@ -16,7 +16,7 @@ def build_optimizers(model_parts, job_config: JobConfig):
     """Wrap one optimizer per model part in an OptimizersContainer which provides a single
     step() and zero_grad() method for all the child optimizers.
     """
-    optim_in_bwd = job_config.training.enable_optimizer_in_backward
+    optim_in_bwd = job_config.optimizer.backward
 
     def _build_optimizer(model):
         name = job_config.optimizer.name
@@ -135,7 +135,7 @@ def linear_warmup_linear_decay(
 
 
 def build_lr_schedulers(optimizers, job_config: JobConfig):
-    optim_in_bwd = job_config.training.enable_optimizer_in_backward
+    optim_in_bwd = job_config.optimizer.backward
 
     def _build_lr_scheduler(optimizer):
         """Build a linear warmup and linear decay scheduler"""
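
To make the flag's caution about gradient clipping concrete, here is an illustrative sketch of the optimizer-in-backward idea it guards, written as a generic PyTorch pattern rather than torchtitan's OptimizerInBackwardWrapper: each parameter gets its own optimizer, stepped from a post-accumulate-grad hook, so gradients are applied and freed during backward and never all exist at once, which rules out clipping by global norm.

    import torch

    model = torch.nn.Linear(8, 8)

    # One optimizer per parameter, stepped as soon as that parameter's gradient
    # is ready during backward.
    optim_dict = {p: torch.optim.AdamW([p], lr=1e-3) for p in model.parameters()}

    def _step_and_free(param: torch.Tensor) -> None:
        # Runs right after param.grad has been fully accumulated in backward.
        optim_dict[param].step()
        optim_dict[param].zero_grad()  # drop the gradient immediately to save memory

    for p in model.parameters():
        p.register_post_accumulate_grad_hook(_step_and_free)

    loss = model(torch.randn(4, 8)).sum()
    loss.backward()  # parameters are updated inside backward; no separate optimizer.step()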
