[STANDARD] Fix inf handling in tl.flip (#5447)
Fixes #5439

Currently we end up computing `0 * inf = nan`; the fix is to bitcast to int first, where `x * 0 == 0` holds.
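
A minimal NumPy sketch of the failure mode and of the bit-level workaround (illustration only, not the Triton implementation; the actual change below uses `core.get_int_dtype` and `bitcast`):

```python
import numpy as np

x = np.array([1.0, np.inf], dtype=np.float32)

# flip implemented as multiply-by-one-hot + sum: the zero entries of the
# reversal matrix meet inf and poison the result with nan
rev = np.array([[0.0, 1.0], [1.0, 0.0]], dtype=np.float32)
print(rev @ x)  # [inf nan] instead of [inf  1.]

# the same multiply on the raw int32 bit pattern: integers have no inf,
# so x * 0 == 0 holds and the flip is exact
bits = x.view(np.int32)
flipped = (rev.astype(np.int32) @ bits).astype(np.int32)
print(flipped.view(np.float32))  # [inf  1.]
```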
peterbell10 authored Dec 17, 2024
1 parent e57b468 commit a52c88a
Showing 3 changed files with 32 additions and 5 deletions.
23 changes: 23 additions & 0 deletions python/test/unit/language/test_standard.py
@@ -75,6 +75,29 @@ def flip_kernel(X, Z, N: tl.constexpr, M: tl.constexpr):
     assert (y == z).all(), (y, z)
 
 
+@pytest.mark.interpreter
+def test_flip_inf(device):
+    # Reproducer for https://github.com/triton-lang/triton/issues/5439
+
+    @triton.jit
+    def triton_flip_kernel(out_ptr, x_ptr, N: tl.constexpr):
+        pid = tl.program_id(0)
+        x = tl.load(x_ptr + pid * N + tl.arange(0, N))
+        shape: tl.constexpr = (N // 2, 2)
+        y = x.reshape(shape)
+        y = tl.flip(y, dim=1).reshape(x.shape)
+        tl.store(out_ptr + pid * N + tl.arange(0, N), y)
+
+    x = torch.arange(0, 16, device=device).unsqueeze(0).float()
+    x[:, -1] = float('inf')
+
+    expect = x.reshape(-1, 8, 2).flip(-1).reshape(-1, 16)
+    actual = torch.empty_like(x)
+    triton_flip_kernel[(x.shape[0], )](actual, x, x.shape[1])
+
+    torch.testing.assert_close(expect, actual)
+
+
 @pytest.mark.interpreter
 @pytest.mark.parametrize("size_i, size_j, size_g", [[5, 7, 3]])
 def test_swizzle2d(size_i, size_j, size_g, device):
10 changes: 6 additions & 4 deletions python/triton/language/standard.py
@@ -412,11 +412,13 @@ def flip(x, dim=None):
     """
     core.static_assert(_is_power_of_two(x.shape[_get_flip_dim(dim, x.shape)]))
     core.static_assert(_is_power_of_two(x.numel))
-    # # reshape the tensor to have all dimensions be 2.
-    # # TODO: We shouldn't have to change the dimensions not sorted.
+    # reshape the tensor to have all dimensions be 2.
+    # TODO: We shouldn't have to change the dimensions not sorted.
     steps: core.constexpr = _log2(x.numel)
     start: core.constexpr = _log2(x.numel) - _log2(x.shape[_get_flip_dim(dim, x.shape)])
-    y = core.reshape(x, [2] * steps)
+
+    idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
+    y = core.reshape(x.to(idtype, bitcast=True), [2] * steps)
     y = core.expand_dims(y, start)
     flip = (core.arange(0, 2)[:, None] == 1 - core.arange(0, 2))
     for i in core.static_range(start, steps):
@@ -425,7 +427,7 @@ def flip(x, dim=None):
             if j != i and j != i + 1:
                 flip2 = core.expand_dims(flip2, j)
         y = sum(y * flip2, i + 1, keep_dims=True)
-    x = core.reshape(y, x.shape)
+    x = core.reshape(y, x.shape).to(x.dtype, bitcast=True)
     return x
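
For reference, a signed-integer bitcast round-trip is lossless for every float bit pattern, which is why flipping the integer view and casting back reproduces inf (and nan) exactly. A small NumPy illustration, assuming an fp32 input (the patch matches the integer width to `x.dtype.primitive_bitwidth`):

```python
import numpy as np

x = np.array([1.5, np.inf, -np.inf, np.nan], dtype=np.float32)
# reinterpret the bits as int32 and back; no value is altered
roundtrip = x.view(np.int32).view(np.float32)
print(roundtrip)  # [ 1.5  inf -inf  nan]
```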


4 changes: 3 additions & 1 deletion python/triton/runtime/interpreter.py
@@ -726,10 +726,12 @@ def check_tensor(self, input):
         self.check_axis(arg.shape, self.axis)
 
     def to_tensor(self, ret, dtype):
+        np_dtype = _get_np_dtype(dtype)
         if hasattr(ret, "shape") and ret.shape:
+            ret = ret.astype(np_dtype)
             ret_type = tl.block_type(dtype, list(ret.shape))
         else:
-            ret = np.array([ret]).astype(_get_np_dtype(dtype))
+            ret = np.array([ret], dtype=np_dtype)
             ret_type = dtype
         return tl.core.tensor(TensorHandle(ret, dtype.scalar), ret_type)

