diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 4d4c60bc..4692cbfa 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -343,6 +343,16 @@ def apply_fsdp(
     if cpu_offload:
         fsdp_config["offload_policy"] = CPUOffloadPolicy()
 
+    from torch.distributed._composable import replicate
+
+    def cast_output_to_bf16(module: nn.Module, input: torch.Tensor, output: torch.Tensor):
+        return output.to(torch.bfloat16)
+
+    for module_name, module in model.named_modules():
+        if "norm" in module_name:
+            replicate(module, device_mesh=dp_mesh)
+            module.register_forward_hook(cast_output_to_bf16)
+
     for layer_id, transformer_block in model.layers.items():
         if pp_enabled:
             # For PP, do not reshard after forward to avoid per-microbatch
diff --git a/train.py b/train.py
index 9e8b1fa8..3d6d5405 100644
--- a/train.py
+++ b/train.py
@@ -11,6 +11,7 @@
 import torch
 
 from torch.distributed.elastic.multiprocessing.errors import record
+from torch.distributed.tensor.experimental import implicit_replication
 
 from torchtitan import utils
 from torchtitan.checkpoint import CheckpointManager, TrainState
@@ -315,19 +316,21 @@ def loss_fn(pred, labels):
                     loss.backward()
 
             # clip gradients
-            utils.clip_grad_norm_(
-                [p for m in model_parts for p in m.parameters()],
-                job_config.training.max_norm,
-                foreach=True,
-                pp_mesh=pp_mesh if parallel_dims.pp_enabled else None,
-            )
+            with implicit_replication():
+                utils.clip_grad_norm_(
+                    [p for m in model_parts for p in m.parameters()],
+                    job_config.training.max_norm,
+                    foreach=True,
+                    pp_mesh=pp_mesh if parallel_dims.pp_enabled else None,
+                )
 
             # sync float8 amaxes and scales
             float8_handler.sync_float8_amax_and_scale_history(model_parts)
 
             # optimizer step
             checkpoint.maybe_wait_for_staging()
-            optimizers.step()
+            with implicit_replication():
+                optimizers.step()
             lr_schedulers.step()
 
             # calculate float8 dynamic amax/scale for all-parameter for FSDP2
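
Note on the mechanism the first hunk relies on: `nn.Module.register_forward_hook` lets a hook replace a module's output by returning a new value, which is how the replicated norm modules' outputs are cast back to bf16 before flowing into the sharded layers. A minimal standalone sketch of that hook pattern, with no distributed setup; `ToyNorm` is a hypothetical stand-in for the model's norm layers:

    import torch
    import torch.nn as nn

    class ToyNorm(nn.Module):
        # hypothetical RMSNorm-like layer kept in fp32
        def __init__(self, dim: int):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = x.float()
            return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)

    def cast_output_to_bf16(module, args, output):
        # returning a value from a forward hook replaces the module's output
        return output.to(torch.bfloat16)

    norm = ToyNorm(8)
    norm.register_forward_hook(cast_output_to_bf16)
    out = norm(torch.randn(2, 8))
    print(out.dtype)  # torch.bfloat16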