Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[do NOT land][experiment] use local_map to annotate TritonFusedRMSNorm #363

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions test_fused_rms_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import torch
import torch.nn as nn

from torch.distributed._tensor import (
distribute_tensor,
init_device_mesh,
Replicate,
Shard,
)

from torchtitan.models.norms import create_norm, fused_rms_norm_fn

def get_device_type():
return (
"cuda"
if torch.cuda.is_available() and torch.cuda.device_count() >= 4
else "cpu"
)


world_size = 4
device_type = get_device_type()
device = torch.device(device_type)
mesh = init_device_mesh(device_type, (4,))
x = torch.randn(4, 4, 4, device=device) # Shard(1)
w = torch.randn(4, device=device, requires_grad=True) # Replicate

dx = distribute_tensor(x, mesh, [Shard(1)])
dw = distribute_tensor(w, mesh, [Replicate()])

# fused rmsnorm
out = fused_rms_norm_fn(dx, dw)
grad_out = torch.ones_like(out)
out.backward(grad_out)
print(grad_out)
13 changes: 13 additions & 0 deletions torchtitan/models/norms.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
import triton
import triton.language as tl

from functools import partial

from torch.distributed._tensor.experimental import local_map
from torch.distributed._tensor.placement_types import _Partial, Replicate, Shard
from torch.distributed._functional_collectives import AsyncCollectiveTensor

def create_norm(norm_type: str, dim: int, eps: float = 1e-6):
"""
Expand All @@ -29,6 +34,7 @@ def create_norm(norm_type: str, dim: int, eps: float = 1e-6):
Raises:
NotImplementedError: If an unknown norm_type is provided.
"""
print(f"create_norm: {norm_type}; dim={dim}")
norm_type = norm_type.lower() # Normalize to lowercase

if norm_type == "layernorm":
Expand Down Expand Up @@ -214,8 +220,11 @@ def _rms_norm_bwd_kernel_sm(


class TritonFusedRMSNorm(torch.autograd.Function):
@partial(local_map, out_placements=[Shard(1)], in_placements=(None, [Shard(1)], [Replicate()], None))
@staticmethod
def forward(ctx, x, weight, eps):
if isinstance(x, AsyncCollectiveTensor):
x = x.wait()
x_shape_start = x.shape

# Flatten input
Expand Down Expand Up @@ -256,8 +265,12 @@ def forward(ctx, x, weight, eps):
y = y.reshape(x_shape_start)
return y

@partial(local_map, out_placements=([Shard(1)], [_Partial()], None), in_placements=(None, [Shard(1)]))
@staticmethod
def backward(ctx, dy):
if isinstance(dy, AsyncCollectiveTensor):
dy = dy.wait()

x, weight, rstd = ctx.saved_tensors
eps = ctx.eps
x_shape_start = ctx.x_shape_start
Expand Down
5 changes: 0 additions & 5 deletions torchtitan/parallelisms/parallelize_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,11 +140,6 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
raise NotImplementedError("PP not implemented yet.")

if parallel_dims.tp_enabled:
if job_config.model.norm_type == "fused_rmsnorm":
raise NotImplementedError(
"fused_rmsnorm not yet compatible with TP. Please use layernorm or rmsnorm."
)

tp_mesh = world_mesh["tp"]
row_parallel_strategy, col_parallel_strategy = get_tp_parallel_strategy(
job_config
Expand Down
2 changes: 1 addition & 1 deletion train_configs/debug_model.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps
max_norm = 1.0 # grad norm clipping
steps = 10
data_parallel_degree = -1
tensor_parallel_degree = 1
tensor_parallel_degree = 2
pipeline_parallel_degree = 1
fp8_linear = ""
compile = false
Expand Down
Loading