Skip to content

Commit

Permalink
Update on "enable TritonFusedRMSNorm with local_map annotation"
Browse files (browse the repository at this point in the history)
[ghstack-poisoned]
  • Loading branch information
XilunWu committed Jun 4, 2024
1 parent aa5af1b commit 71659de
Showing 1 changed file with 0 additions and 7 deletions.
7 changes: 0 additions & 7 deletions torchtitan/models/norms.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -14,7 +14,6 @@
import triton
import triton.language as tl

from torch.distributed._functional_collectives import AsyncCollectiveTensor
from torch.distributed._tensor.experimental import local_map
from torch.distributed._tensor.placement_types import _Partial, Replicate, Shard

Expand Down Expand Up @@ -227,9 +226,6 @@ class TritonFusedRMSNorm(torch.autograd.Function):
)
@staticmethod
def forward(ctx, x, weight, eps):
if isinstance(x, AsyncCollectiveTensor):
x = x.wait()

x_shape_start = x.shape

# Flatten input
Expand Down Expand Up @@ -277,9 +273,6 @@ def forward(ctx, x, weight, eps):
)
@staticmethod
def backward(ctx, dy):
if isinstance(dy, AsyncCollectiveTensor):
dy = dy.wait()

x, weight, rstd = ctx.saved_tensors
eps = ctx.eps
x_shape_start = ctx.x_shape_start
Expand Down

0 comments on commit 71659de

Please sign in to comment.