[Dlight] LowBatchGemv rule only apply to function with spatial symbolic var (#16678)

* squash

* fix
jinhongyii authored Mar 9, 2024
1 parent 48992a4 commit 5bbe1ab
Showing 2 changed files with 34 additions and 2 deletions.
python/tvm/dlight/gpu/low_batch_gemv.py: 12 changes (10 additions & 2 deletions)
@@ -98,7 +98,14 @@ def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffe
         for iter_var in block_stmt.iter_vars
         if isinstance(iter_var.dom.extent, tir.IntImm)
     )
-    if len(const_iter_vars) == len(block_stmt.iter_vars):
+    if len(block_stmt.iter_vars) - len(const_iter_vars) != 1:
         return None
+    symbolic_iter_var = list(
+        iter_var
+        for iter_var in block_stmt.iter_vars
+        if not isinstance(iter_var.dom.extent, tir.IntImm)
+    )[0]
+    if symbolic_iter_var.iter_type != tir.stmt.IterVar.DataPar:
+        return None
     ret = [
         read.buffer
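
Note on this hunk: is_gemv() now matches only blocks with exactly one symbolic (non-constant extent) iteration variable, and that variable must be spatial (data-parallel), which is what the commit title means by "spatial symbolic var". A standalone sketch of the new predicate, assuming a tir.Block statement taken from a schedule (the helper name is illustrative, not part of this commit):

    from tvm import tir

    def has_single_spatial_symbolic_var(block_stmt: tir.Block) -> bool:
        # Iteration variables whose extent is not a compile-time constant.
        symbolic = [
            iter_var
            for iter_var in block_stmt.iter_vars
            if not isinstance(iter_var.dom.extent, tir.IntImm)
        ]
        # Exactly one dynamic axis, and it must be data-parallel (spatial);
        # a symbolic reduction axis disqualifies the block for LowBatchGEMV.
        return len(symbolic) == 1 and symbolic[0].iter_type == tir.IterVar.DataPar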
@@ -220,7 +227,8 @@ def apply(  # pylint: disable=too-many-locals,too-many-branches,too-many-return-
             return None
         sch = tir.Schedule(func)
         block_infos = normalize_prim_func(sch)
-
+        if block_infos is None:
+            return None
         reduction_block_infos = [
             block_info for block_info in block_infos if block_info.is_reduction()
         ]
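
Note on this hunk: normalize_prim_func() can return None when a function cannot be normalized, and apply() previously fell through to iterating over that None. A minimal sketch of the guard pattern, assuming normalize_prim_func is importable from tvm.dlight.base as in this module (the helper name is illustrative):

    from tvm import tir
    from tvm.dlight.base import normalize_prim_func

    def reduction_blocks_or_none(func: tir.PrimFunc):
        sch = tir.Schedule(func)
        block_infos = normalize_prim_func(sch)
        if block_infos is None:  # normalization failed; nothing to schedule
            return None
        return [info for info in block_infos if info.is_reduction()]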
tests/python/dlight/test_gpu_low_batch_gemv.py: 24 changes (24 additions & 0 deletions)
@@ -251,5 +251,29 @@ def expected(var_A: T.handle, B: T.Buffer((T.int64(4096), T.int64(4096)), "float
     tvm.ir.assert_structural_equal(mod["main"], expected)
 
 
+def test_reduction_symbolic_var():
+    # fmt: off
+    @T.prim_func(private=True)
+    def before(var_A: T.handle, var_B: T.handle, matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float32")):
+        T.func_attr({"tir.noalias": T.bool(True)})
+        kv_seq_len = T.int64()
+        A = T.match_buffer(var_A, (T.int64(1), T.int64(32), T.int64(1), kv_seq_len))
+        B = T.match_buffer(var_B, (T.int64(1), T.int64(32), kv_seq_len, T.int64(128)))
+        # with T.block("root"):
+        for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(128), kv_seq_len):
+            with T.block("matmul"):
+                v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k])
+                T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3])
+                T.writes(matmul[v_i0, v_i1, v_i2, v_i3])
+                with T.init():
+                    matmul[v_i0, v_i1, v_i2, v_i3] = T.float32(0)
+                matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_k, v_i3]
+    # fmt: on
+    mod = tvm.IRModule({"main": before})
+    with Target("metal"):
+        mod = dl.ApplyDefaultSchedule(dl.gpu.LowBatchGEMV(4))(mod)
+    tvm.ir.assert_structural_equal(mod["main"], before)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
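
The test feeds LowBatchGEMV a function whose only symbolic extent (kv_seq_len) sits on the reduction axis (the "R" in "SSSSR"), so the tightened rule must leave the function unscheduled; asserting structural equality against before pins that down. For contrast, a function the rule still targets keeps its symbolic extent on a spatial axis, such as a dynamic batch dimension. A hypothetical sketch reusing the test file's imports (before_spatial and n are illustrative, not part of this commit):

    @T.prim_func(private=True)
    def before_spatial(var_A: T.handle, B: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), var_C: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        A = T.match_buffer(var_A, (n, T.int64(4096)), "float16")
        C = T.match_buffer(var_C, (n, T.int64(4096)), "float16")
        # The symbolic var n is spatial ("S"), so is_gemv() accepts the block
        # and LowBatchGEMV(4) would rewrite the schedule.
        for i, j, k in T.grid(n, T.int64(4096), T.int64(4096)):
            with T.block("matmul"):
                v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    C[v_i, v_j] = T.float16(0)
                C[v_i, v_j] = C[v_i, v_j] + A[v_i, v_k] * B[v_j, v_k]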
