Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Not for land] Show replicated fp32 norm weights #717

Draft
wants to merge 1 commit into
base: gh/awgu/23/base
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions torchtitan/parallelisms/parallelize_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,16 @@ def apply_fsdp(
if cpu_offload:
fsdp_config["offload_policy"] = CPUOffloadPolicy()

from torch.distributed._composable import replicate

def cast_output_to_bf16(module: nn.Module, input: torch.Tensor, output: torch.Tensor):
return output.to(torch.bfloat16)

for module_name, module in model.named_modules():
if "norm" in module_name:
replicate(module, device_mesh=dp_mesh)
module.register_forward_hook(cast_output_to_bf16)

for layer_id, transformer_block in model.layers.items():
if pp_enabled:
# For PP, do not reshard after forward to avoid per-microbatch
Expand Down
17 changes: 10 additions & 7 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import torch

from torch.distributed.elastic.multiprocessing.errors import record
from torch.distributed.tensor.experimental import implicit_replication

from torchtitan import utils
from torchtitan.checkpoint import CheckpointManager, TrainState
Expand Down Expand Up @@ -315,19 +316,21 @@ def loss_fn(pred, labels):
loss.backward()

# clip gradients
utils.clip_grad_norm_(
[p for m in model_parts for p in m.parameters()],
job_config.training.max_norm,
foreach=True,
pp_mesh=pp_mesh if parallel_dims.pp_enabled else None,
)
with implicit_replication():
utils.clip_grad_norm_(
[p for m in model_parts for p in m.parameters()],
job_config.training.max_norm,
foreach=True,
pp_mesh=pp_mesh if parallel_dims.pp_enabled else None,
)

# sync float8 amaxes and scales
float8_handler.sync_float8_amax_and_scale_history(model_parts)

# optimizer step
checkpoint.maybe_wait_for_staging()
optimizers.step()
with implicit_replication():
optimizers.step()
lr_schedulers.step()

# calculate float8 dynamic amax/scale for all-parameter for FSDP2
Expand Down
Loading