diff --git a/test_fused_rms_norm.py b/test_fused_rms_norm.py
new file mode 100644
index 00000000..39f89b5a
--- /dev/null
+++ b/test_fused_rms_norm.py
@@ -0,0 +1,35 @@
+import torch
+import torch.nn as nn
+
+from torch.distributed._tensor import (
+    distribute_tensor,
+    init_device_mesh,
+    Replicate,
+    Shard,
+)
+
+from torchtitan.models.norms import create_norm, fused_rms_norm_fn
+
+def get_device_type():
+    return (
+        "cuda"
+        if torch.cuda.is_available() and torch.cuda.device_count() >= 4
+        else "cpu"
+    )
+
+
+world_size = 4
+device_type = get_device_type()
+device = torch.device(device_type)
+mesh = init_device_mesh(device_type, (4,))
+x = torch.randn(4, 4, 4, device=device)  # Shard(1)
+w = torch.randn(4, device=device, requires_grad=True)  # Replicate
+
+dx = distribute_tensor(x, mesh, [Shard(1)])
+dw = distribute_tensor(w, mesh, [Replicate()])
+
+# fused rmsnorm
+out = fused_rms_norm_fn(dx, dw)
+grad_out = torch.ones_like(out)
+out.backward(grad_out)
+print(grad_out)
diff --git a/torchtitan/models/norms.py b/torchtitan/models/norms.py
index e29338d9..ad73d60a 100644
--- a/torchtitan/models/norms.py
+++ b/torchtitan/models/norms.py
@@ -12,6 +12,11 @@
 import triton
 import triton.language as tl
 
+from functools import partial
+
+from torch.distributed._tensor.experimental import local_map
+from torch.distributed._tensor.placement_types import _Partial, Replicate, Shard
+from torch.distributed._functional_collectives import AsyncCollectiveTensor
 
 def create_norm(norm_type: str, dim: int, eps: float = 1e-6):
     """
@@ -29,6 +34,7 @@ def create_norm(norm_type: str, dim: int, eps: float = 1e-6):
     Raises:
         NotImplementedError: If an unknown norm_type is provided.
     """
+    print(f"create_norm: {norm_type}; dim={dim}")
     norm_type = norm_type.lower()  # Normalize to lowercase
 
     if norm_type == "layernorm":
@@ -214,8 +220,11 @@ def _rms_norm_bwd_kernel_sm(
 
 
 class TritonFusedRMSNorm(torch.autograd.Function):
+    @partial(local_map, out_placements=[Shard(1)], in_placements=(None, [Shard(1)], [Replicate()], None))
     @staticmethod
     def forward(ctx, x, weight, eps):
+        if isinstance(x, AsyncCollectiveTensor):
+            x = x.wait()
         x_shape_start = x.shape
 
         # Flatten input
@@ -256,8 +265,12 @@ def forward(ctx, x, weight, eps):
         y = y.reshape(x_shape_start)
         return y
 
+    @partial(local_map, out_placements=([Shard(1)], [_Partial()], None), in_placements=(None, [Shard(1)]))
     @staticmethod
     def backward(ctx, dy):
+        if isinstance(dy, AsyncCollectiveTensor):
+            dy = dy.wait()
+
         x, weight, rstd = ctx.saved_tensors
         eps = ctx.eps
         x_shape_start = ctx.x_shape_start
diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 9c8d0a29..3ce9d779 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -140,11 +140,6 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
         raise NotImplementedError("PP not implemented yet.")
 
     if parallel_dims.tp_enabled:
-        if job_config.model.norm_type == "fused_rmsnorm":
-            raise NotImplementedError(
-                "fused_rmsnorm not yet compatible with TP. Please use layernorm or rmsnorm."
-            )
-
         tp_mesh = world_mesh["tp"]
         row_parallel_strategy, col_parallel_strategy = get_tp_parallel_strategy(
             job_config
diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml
index 4541fec7..c5494b37 100644
--- a/train_configs/debug_model.toml
+++ b/train_configs/debug_model.toml
@@ -35,7 +35,7 @@ warmup_steps = 2  # lr scheduler warm up, normally 20% of the train steps
 max_norm = 1.0  # grad norm clipping
 steps = 10
 data_parallel_degree = -1
-tensor_parallel_degree = 1
+tensor_parallel_degree = 2
 pipeline_parallel_degree = 1
 fp8_linear = ""
 compile = false