diff --git a/torchtitan/models/norms.py b/torchtitan/models/norms.py
index b15a8519..dc312f71 100644
--- a/torchtitan/models/norms.py
+++ b/torchtitan/models/norms.py
@@ -14,7 +14,6 @@
 import triton
 import triton.language as tl
 
-from torch.distributed._functional_collectives import AsyncCollectiveTensor
 from torch.distributed._tensor.experimental import local_map
 from torch.distributed._tensor.placement_types import _Partial, Replicate, Shard
 
@@ -227,9 +226,6 @@ class TritonFusedRMSNorm(torch.autograd.Function):
     )
     @staticmethod
     def forward(ctx, x, weight, eps):
-        if isinstance(x, AsyncCollectiveTensor):
-            x = x.wait()
-
         x_shape_start = x.shape
 
         # Flatten input
@@ -277,9 +273,6 @@ def forward(ctx, x, weight, eps):
     )
     @staticmethod
     def backward(ctx, dy):
-        if isinstance(dy, AsyncCollectiveTensor):
-            dy = dy.wait()
-
         x, weight, rstd = ctx.saved_tensors
         eps = ctx.eps
         x_shape_start = ctx.x_shape_start
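
Note: the guards deleted above unwrapped a pending asynchronous functional collective before the Triton kernel consumed the tensor. Below is a minimal sketch of that pattern in isolation; maybe_wait is a hypothetical helper name, not part of torchtitan, and the only API it relies on is the same one the removed code used (an isinstance check plus AsyncCollectiveTensor.wait()).

import torch
from torch.distributed._functional_collectives import AsyncCollectiveTensor


def maybe_wait(t: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper, for illustration only: AsyncCollectiveTensor wraps
    # the output of an async functional collective; wait() blocks until the
    # collective completes and returns a plain tensor.
    if isinstance(t, AsyncCollectiveTensor):
        return t.wait()
    return t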