Commit

[Not for land] Show replicated fp32 norm weights
ghstack-source-id: 70d400b56a0c6121b33a9b0567628b2ab7ae23d0
Pull Request resolved: #717
awgu committed Dec 4, 2024
1 parent 3e3909a commit b39911d
Showing 2 changed files with 20 additions and 7 deletions.
10 changes: 10 additions & 0 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -343,6 +343,16 @@ def apply_fsdp(
     if cpu_offload:
         fsdp_config["offload_policy"] = CPUOffloadPolicy()
 
+    from torch.distributed._composable import replicate
+
+    def cast_output_to_bf16(module: nn.Module, input: torch.Tensor, output: torch.Tensor):
+        return output.to(torch.bfloat16)
+
+    for module_name, module in model.named_modules():
+        if "norm" in module_name:
+            replicate(module, device_mesh=dp_mesh)
+            module.register_forward_hook(cast_output_to_bf16)
+
     for layer_id, transformer_block in model.layers.items():
         if pp_enabled:
             # For PP, do not reshard after forward to avoid per-microbatch
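For context, a minimal single-process sketch of the hook pattern above (no FSDP or replicate; the toy module names "proj"/"norm" are assumptions for illustration): the norm weight stays in fp32 while downstream modules receive bf16 activations.

from collections import OrderedDict

import torch
import torch.nn as nn

def cast_output_to_bf16(module, args, output):
    # The norm math runs against the module's fp32 weight; only the output is downcast.
    return output.to(torch.bfloat16)

model = nn.Sequential(OrderedDict([("proj", nn.Linear(16, 16)), ("norm", nn.LayerNorm(16))]))
for module_name, module in model.named_modules():
    if "norm" in module_name:  # same name-based match as the patch
        module.register_forward_hook(cast_output_to_bf16)

out = model(torch.randn(2, 16))
print(out.dtype, model.norm.weight.dtype)  # torch.bfloat16 torch.float32

In the patch itself, the intent (per the commit title) is to manage the norm modules with replicate(module, device_mesh=dp_mesh) rather than FSDP sharding, so their weights stay as full fp32 tensors on every rank and the hook casts their fp32 outputs back to the bf16 the FSDP-managed layers compute in.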
17 changes: 10 additions & 7 deletions train.py
@@ -11,6 +11,7 @@
 import torch
 
 from torch.distributed.elastic.multiprocessing.errors import record
+from torch.distributed.tensor.experimental import implicit_replication
 
 from torchtitan import utils
 from torchtitan.checkpoint import CheckpointManager, TrainState
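The new import is implicit_replication, an experimental context manager under which plain torch.Tensors participating in DTensor operations are treated as if they were replicated over the DTensor's mesh. A hedged sketch of the behavior (launch with torchrun; the CPU/gloo mesh, the WORLD_SIZE lookup, and the variable names are illustrative, not part of the patch):

import os

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, distribute_tensor
from torch.distributed.tensor.experimental import implicit_replication

# Building a CPU mesh initializes the default (gloo) process group if needed.
mesh = init_device_mesh("cpu", (int(os.environ["WORLD_SIZE"]),))

dtensor = distribute_tensor(torch.ones(4), mesh, [Replicate()])
plain = torch.full((4,), 2.0)  # a regular torch.Tensor, not a DTensor

# Mixing a DTensor with a non-scalar plain tensor normally errors out; inside the
# context manager the plain tensor is treated as replicated over `mesh`.
with implicit_replication():
    out = dtensor * plain
print(type(out))  # DTensor

This matters below because, with the norm weights replicated as plain fp32 tensors while the remaining parameters are FSDP-managed DTensors, the foreach ops inside utils.clip_grad_norm_ and optimizers.step() operate on mixed DTensor/Tensor lists.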
@@ -315,19 +316,21 @@ def loss_fn(pred, labels):
                 loss.backward()
 
         # clip gradients
-        utils.clip_grad_norm_(
-            [p for m in model_parts for p in m.parameters()],
-            job_config.training.max_norm,
-            foreach=True,
-            pp_mesh=pp_mesh if parallel_dims.pp_enabled else None,
-        )
+        with implicit_replication():
+            utils.clip_grad_norm_(
+                [p for m in model_parts for p in m.parameters()],
+                job_config.training.max_norm,
+                foreach=True,
+                pp_mesh=pp_mesh if parallel_dims.pp_enabled else None,
+            )
 
         # sync float8 amaxes and scales
         float8_handler.sync_float8_amax_and_scale_history(model_parts)
 
         # optimizer step
         checkpoint.maybe_wait_for_staging()
-        optimizers.step()
+        with implicit_replication():
+            optimizers.step()
         lr_schedulers.step()
 
         # calculate float8 dynamic amax/scale for all-parameter for FSDP2
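A hedged way to inspect the resulting parameter mix after apply_fsdp; summarize_params is a hypothetical helper, not part of the patch, and the dtype expectation assumes torchtitan's default mixed-precision policy (fp32 sharded parameters, bf16 compute):

from torch.distributed.tensor import DTensor

def summarize_params(model):
    # FSDP2-managed weights appear as DTensors sharded over the DP mesh, while the
    # replicate()-managed "*norm*" weights remain plain fp32 torch.Tensors.
    for name, p in model.named_parameters():
        kind = "DTensor" if isinstance(p, DTensor) else "Tensor"
        print(f"{name}: {kind}, dtype={p.dtype}, shape={tuple(p.shape)}")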
