
Update numerical checks on
* test_half_misaligned_vector_load
* test_half_broadcast
csullivan committed Feb 12, 2024
1 parent 38693aa commit e01330c
Showing 1 changed file with 39 additions and 11 deletions.
tests/python/tir-base/test_native_fp8.py
@@ -72,18 +72,18 @@ def add(
 @tvm.testing.requires_cuda_compute_version(9)
 def test_e4m3_packing():
     native_dtype, packed_dtype = ("e4m3_float8x2", "uint32")
-    vector_length = 64
+    length = 64
 
     @T.prim_func
     def add(
-        A: T.Buffer((vector_length,), native_dtype),
-        B: T.Buffer((vector_length,), packed_dtype),
+        A: T.Buffer((length,), native_dtype),
+        B: T.Buffer((length,), packed_dtype),
     ):
         T.func_attr({"tir.noalias": T.bool(True)})
         # with T.block("root"):
-        for i in range(vector_length):
+        for i in range(length):
             with T.block("C"):
-                v_i = T.axis.spatial(length, i)
+                v_i = T.axis.spatial(length, i)
                 T.reads(A[v_i])
                 T.writes(B[v_i])
                 B[v_i] = T.reinterpret(packed_dtype, A[v_i])
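Editorial note: T.reinterpret in the kernel above is a bitwise reinterpretation, not a numeric conversion. A minimal standalone sketch of the same idea using NumPy views (illustrative only, not part of this commit; the byte values are arbitrary, and a uint16 is used here simply because two fp8 lanes occupy two consecutive bytes):

import numpy as np

# Two raw 8-bit e4m3 bit patterns stored as consecutive bytes.
fp8_bits = np.array([0x3C, 0xC0], dtype=np.uint8)
# View the pair as a single 16-bit integer; no bits change.
packed = fp8_bits.view(np.uint16)[0]
print(hex(packed))  # 0xc03c on a little-endian host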
@@ -191,8 +191,18 @@ def vector_broadcast(a: T.Buffer[(), dtype], vec: T.Buffer[(bcast_length,), dtype]):
     sch.bind(tx, "threadIdx.x")
 
     target = "cuda"
-    tvm.build(sch.mod, target=target)
-    # TODO(csullivan): numerical check
+    func = tvm.build(sch.mod, target=target)
+    dev = tvm.device(target, 0)
+
+    a_np = np.random.uniform(low=0, high=4, size=()).astype(dtype)
+    a = tvm.nd.array(a_np, device=dev)
+    b = tvm.nd.empty((bcast_length,), dtype=dtype, device=dev)
+
+    func(a, b)
+
+    b_np = np.full((bcast_length,), a_np)
+
+    tvm.testing.assert_allclose(b.numpy(), b_np)
 
 
 vector_length = tvm.testing.parameter(2, 4)
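One detail worth noting in the check added above: np.full infers its dtype from the fill value when none is given, so b_np stays float16 and the comparison against the device output runs at half precision. A quick standalone illustration (assumed values, not part of the commit):

import numpy as np

a_np = np.random.uniform(low=0, high=4, size=()).astype("float16")
b_np = np.full((4,), a_np)  # dtype inherited from the fill value a_np
assert b_np.dtype == np.float16
assert (b_np == a_np).all()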
@@ -202,17 +212,35 @@ def vector_broadcast(a: T.Buffer[(), dtype], vec: T.Buffer[(bcast_length,), dtype]):
 def test_half_misaligned_vector_load(vector_length):
     dtype = "float16"
     vec_dtype = dtype + "x" + str(vector_length)
+    length = 256
 
     @T.prim_func
-    def vector_load(A: T.Buffer[(128,), dtype], B: T.Buffer[(32,), vec_dtype]):
+    def vector_load(
+        A: T.Buffer[(length,), dtype], B: T.Buffer[(length // vector_length,), vec_dtype]
+    ):
         for b in T.thread_binding(1, thread="blockIdx.x"):
-            for i in T.thread_binding(32, thread="threadIdx.x"):
+            for i in T.thread_binding(length // vector_length, thread="threadIdx.x"):
                 vec_index = T.ramp((i + 1) * vector_length - 1, -1, vector_length)
                 B[i] = A[vec_index]
 
     target = "cuda"
-    tvm.build(vector_load, target=target)
-    # TODO(csullivan): numerical check
+    f = tvm.build(vector_load, target=target)
+
+    dev = tvm.device(target, 0)
+    a_np = np.random.uniform(low=0, high=1, size=(length,)).astype(dtype)
+    a = tvm.nd.array(a_np, device=dev)
+
+    b = tvm.nd.empty((length // vector_length,), dtype=vec_dtype, device=dev)
+
+    f(a, b)
+
+    b_np = np.empty((length // vector_length, vector_length), dtype=dtype)
+
+    for i in range(length // vector_length):
+        start_index = (i + 1) * vector_length - 1
+        b_np[i, :] = a_np[start_index - vector_length + 1 : start_index + 1][::-1]
+
+    tvm.testing.assert_allclose(b.numpy(), b_np)
 
 
 @tvm.testing.requires_cuda_compute_version(8)
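The NumPy reference added above mirrors the descending ramp in the kernel: T.ramp((i + 1) * vector_length - 1, -1, vector_length) reads each vector's lanes starting from its last element and stepping backwards. A small self-contained check of that index arithmetic (illustrative only; vector_length = 4 is one of the parameterized values):

import numpy as np

vector_length = 4
a_np = np.arange(16, dtype="float16")
i = 1
start_index = (i + 1) * vector_length - 1  # 7: last lane of vector 1
lanes = a_np[start_index - vector_length + 1 : start_index + 1][::-1]
assert (lanes == np.array([7, 6, 5, 4], dtype="float16")).all()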