
Commit ef2be51

Update numerical checks on
* test_e4m3_packing
* test_half4_vector_add
csullivan committed Feb 12, 2024
1 parent e01330c commit ef2be51
Showing 1 changed file with 53 additions and 16 deletions.
tests/python/tir-base/test_native_fp8.py (53 additions, 16 deletions)
@@ -71,33 +71,56 @@ def add(
 
 @tvm.testing.requires_cuda_compute_version(9)
 def test_e4m3_packing():
-    native_dtype, packed_dtype = ("e4m3_float8x2", "uint32")
     length = 64
+    vector_length = 4
+    native_dtype, packed_dtype = ("e4m3_float8x4", "uint32")
 
     @T.prim_func
     def add(
         A: T.Buffer((length,), native_dtype),
-        B: T.Buffer((length,), packed_dtype),
+        R: T.Buffer((length,), packed_dtype),
+        B: T.Buffer((length,), native_dtype),
     ):
         T.func_attr({"tir.noalias": T.bool(True)})
         # with T.block("root"):
         for i in range(length):
-            with T.block("C"):
+            with T.block("R"):
                 v_i = T.axis.spatial(length, i)
                 T.reads(A[v_i])
-                T.writes(B[v_i])
-                B[v_i] = T.reinterpret(packed_dtype, A[v_i])
+                T.writes(R[v_i])
+                R[v_i] = T.reinterpret(packed_dtype, A[v_i])
+        for i in range(length):
+            with T.block("B"):
+                v_i = T.axis.spatial(length, i)
+                T.reads(R[v_i])
+                T.writes(B[v_i])
+                B[v_i] = T.reinterpret(native_dtype, R[v_i])
 
     sch = tvm.tir.Schedule(add)
-    block = sch.get_block("C")
+    block = sch.get_block("R")
     b = sch.get_loops(block)
     bx, tx = sch.split(b[0], factors=[None, 32])
     sch.bind(bx, "blockIdx.x")
     sch.bind(tx, "threadIdx.x")
+    block = sch.get_block("B")
+    b = sch.get_loops(block)
+    bx, tx = sch.split(b[0], factors=[None, 32])
+    sch.bind(bx, "blockIdx.x")
+    sch.bind(tx, "threadIdx.x")
 
     target = "cuda"
-    tvm.build(sch.mod, target=target)
-    # TODO(csullivan): numerical check
+    f = tvm.build(sch.mod, target=target)
+    dev = tvm.device(target, 0)
+
+    numpytype = "float8_e4m3fn"
+    np_shape = (length, vector_length)
+    a_np = np.random.uniform(low=0, high=5, size=np_shape).astype(numpytype)
+    a = tvm.nd.empty(shape=(length,), dtype=native_dtype, device=dev)
+    r = tvm.nd.empty(shape=(length,), dtype=packed_dtype, device=dev)
+    b = tvm.nd.empty(shape=(length,), dtype=native_dtype, device=dev)
+    a.copyfrom(a_np)
+    f(a, r, b)
+    tvm.testing.assert_allclose(a.numpy().astype("float16"), b.numpy().astype("float16"))
 
 
 native_dtype, promoted_dtype = tvm.testing.parameters(
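
Note: for context, the round trip the updated test_e4m3_packing verifies can be sketched in plain NumPy. This is a minimal illustrative sketch, not part of the commit; it assumes ml_dtypes is installed, which is what makes the "float8_e4m3fn" dtype (also used by the test above) visible to numpy.

# Sketch of the packing round trip verified by test_e4m3_packing.
import numpy as np
import ml_dtypes  # noqa: F401 -- registers "float8_e4m3fn" with numpy

a = np.random.uniform(low=0, high=5, size=(64, 4)).astype("float8_e4m3fn")
packed = a.view(np.uint32)               # each row of four 1-byte e4m3 lanes -> one uint32
unpacked = packed.view("float8_e4m3fn")  # reinterpret the same bytes back
assert (a == unpacked).all()             # bitwise round trip, hence lossless
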
@@ -244,21 +267,23 @@ def vector_load(
 
 
 @tvm.testing.requires_cuda_compute_version(8)
-def test_half_vector_add():
-    dtype = "float16x4"
-    vector_length = 64
+def test_half4_vector_add():
+    dtype = "float16"
+    length = 64
+    vector_length = 4
+    vec_dtype = dtype + "x" + str(vector_length)
 
     @T.prim_func
     def add(
-        A: T.Buffer((vector_length,), dtype),
-        B: T.Buffer((vector_length,), dtype),
-        C: T.Buffer((vector_length,), dtype),
+        A: T.Buffer((length,), vec_dtype),
+        B: T.Buffer((length,), vec_dtype),
+        C: T.Buffer((length,), vec_dtype),
     ):
         T.func_attr({"tir.noalias": T.bool(True)})
         # with T.block("root"):
-        for i in range(vector_length):
+        for i in range(length):
             with T.block("C"):
-                v_i = T.axis.spatial(vector_length, i)
+                v_i = T.axis.spatial(length, i)
                 T.reads(A[v_i], B[v_i])
                 T.writes(C[v_i])
                 C[v_i] = A[v_i] + B[v_i]
@@ -272,7 +297,19 @@ def add(
 
     target = "cuda"
     fadd = tvm.build(sch.mod, target=target)
-    # TODO(csullivan): numerical check
+    dev = tvm.device(target, 0)
+
+    a_np = np.random.uniform(-1, 1, (length, vector_length)).astype(dtype)
+    a = tvm.nd.empty(shape=(length,), dtype=vec_dtype, device=dev)
+    a.copyfrom(a_np)
+    b_np = np.random.uniform(-1, 1, (length, vector_length)).astype(dtype)
+    b = tvm.nd.empty(shape=(length,), dtype=vec_dtype, device=dev)
+    b.copyfrom(b_np)
+    c = tvm.nd.empty(shape=(length,), dtype=vec_dtype, device=dev)
+
+    fadd(a, b, c)
+    c_expected = a_np + b_np
+    tvm.testing.assert_allclose(c.numpy(), c_expected, atol=1e-5, rtol=1e-5)
 
 
 if __name__ == "__main__":
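
Note: the new numerical check relies on a flat device NDArray with vector dtype "float16x4" exchanging data with a (length, vector_length) float16 host array via copyfrom() and numpy(). A minimal sketch of that layout contract, assuming a CUDA device is available (illustrative only, not part of the commit):

# Host/device layout sketch for vector dtypes, as used by test_half4_vector_add.
import numpy as np
import tvm

dev = tvm.device("cuda", 0)
x_np = np.random.uniform(-1, 1, (64, 4)).astype("float16")
x = tvm.nd.empty(shape=(64,), dtype="float16x4", device=dev)
x.copyfrom(x_np)  # host (64, 4) float16 -> device (64,) of float16x4
np.testing.assert_array_equal(x.numpy(), x_np)  # numpy() recovers the (64, 4) view
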