[Relax] CUDA graph rewrite treating StringImm as static
The RewriteCUDAGraph pass failed to treat StringImm as a static
expression, causing some CUDA graph rewrite opportunities to be missed.
This PR fixes the issue.
MasterJH5574 committed Mar 9, 2024
1 parent 7b7677f commit 89a73f0
Showing 2 changed files with 57 additions and 3 deletions.
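
To illustrate the effect, a minimal sketch (adapted from the test added below; "dummy_func" is a placeholder and the module is illustrative, not part of the commit). The R.str("string") argument parses to a StringImm node, and with this fix a packed call carrying one becomes eligible for CUDA graph capture:

import tvm
from tvm import relax
from tvm.script import ir as I, relax as R


@I.ir_module
class Module:
    @R.function
    def main():
        R.func_attr({"relax.force_pure": True})
        storage = R.memory.alloc_storage(R.shape([8]), 0, "global", "float32")
        alloc = R.memory.alloc_tensor(storage, 0, R.shape([8]), "float32")
        # R.str("string") lowers to a StringImm argument.
        _ = R.call_packed("dummy_func", alloc, R.dtype("float32"), R.str("string"), sinfo_args=R.Object)
        return R.tuple()


# With this commit, the packed call lands in a cuda_graph_capture
# subfunction; previously the StringImm argument kept it out.
relax.transform.RewriteCUDAGraph()(Module).show()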
src/relax/transform/rewrite_cuda_graph.cc (2 additions, 1 deletion)
@@ -348,7 +348,8 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
   }
 
   bool IsStatic(const Expr& expr, std::vector<const VarNode*>* vars_collector = nullptr) {
-    if (expr->IsInstance<ConstantNode>() || expr->IsInstance<DataTypeImmNode>()) {
+    if (expr->IsInstance<ConstantNode>() || expr->IsInstance<DataTypeImmNode>() ||
+        expr->IsInstance<StringImmNode>()) {
       return true;
     }
     if (const auto* prim_value = expr.as<PrimValueNode>()) {
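
For reference, the three leaf expression kinds the check now treats as static can be constructed from Python via tvm.relax (a sketch for orientation, not part of the commit):

from tvm import relax

relax.const(1.0, "float32")   # ConstantNode: always static
relax.DataTypeImm("float32")  # DataTypeImmNode: always static
relax.StringImm("string")     # StringImmNode: static as of this PR
# relax.PrimValue(0) is handled by the PrimValueNode branch that follows,
# which inspects the wrapped PrimExpr instead.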
tests/python/relax/test_transform_rewrite_cuda_graph.py (55 additions, 2 deletions)
@@ -18,9 +18,11 @@
 import pytest
 
 import tvm
-from tvm import relax
-from tvm.script import tir as T, relax as R, ir as I
 import tvm.testing
+from tvm import relax
+from tvm.script import ir as I
+from tvm.script import relax as R
+from tvm.script import tir as T
 
 
 class BaseCompare(tvm.testing.CompareBeforeAfter):

@@ -704,5 +706,56 @@ def main():
     tvm.ir.assert_structural_equal(Before, AfterWhenDisabled)
 
 
+def test_static_args():
+    @I.ir_module
+    class Before:
+        @R.function
+        def main():
+            storage0 = R.memory.alloc_storage(R.shape([8]), 0, "global", "float32")
+            alloc0 = R.memory.alloc_tensor(storage0, 0, R.shape([8]), "float32")
+            _ = R.call_packed("dummy_func", alloc0, R.dtype("float32"), R.str("string"), sinfo_args=R.Object)
+            return R.tuple()
+
+    @I.ir_module
+    class Expected:
+        @R.function(private=True)
+        def cuda_graph_alloc() -> R.Tuple(R.Object):
+            R.func_attr({"relax.force_pure": True})
+            storage0: R.Object = R.memory.alloc_storage(
+                R.shape([8]), R.prim_value(0), R.str("global"), R.dtype("float32")
+            )
+            gv: R.Tuple(R.Object) = (storage0,)
+            return gv
+
+        @R.function(private=True)
+        def cuda_graph_capture(alloc0: R.Tensor((8,), dtype="float32")) -> R.Tuple:
+            R.func_attr({"relax.force_pure": True})
+            _: R.Object = R.call_packed(
+                "dummy_func", alloc0, R.dtype("float32"), R.str("string"), sinfo_args=(R.Object,)
+            )
+            gv: R.Tuple = R.tuple()
+            return gv
+
+        @R.function
+        def main() -> R.Tuple:
+            cls = Expected
+            gv: R.Tuple(R.Object) = R.call_builtin_with_ctx(
+                "vm.builtin.cuda_graph.get_cached_alloc",
+                (cls.cuda_graph_alloc, R.prim_value(0)),
+                sinfo_args=(R.Tuple(R.Object),),
+            )
+            storage0: R.Object = gv[0]
+            alloc0: R.Tensor((8,), dtype="float32") = R.memory.alloc_tensor(
+                storage0, R.prim_value(0), R.shape([8]), R.dtype("float32")
+            )
+            gv1: R.Tuple = R.call_builtin_with_ctx(
+                "vm.builtin.cuda_graph.run_or_capture",
+                (cls.cuda_graph_capture, (alloc0,), R.prim_value(0)),
+                sinfo_args=(R.Tuple,),
+            )
+            return R.tuple()
+
+    mod = relax.transform.RewriteCUDAGraph()(Before)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
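
To try the change end to end, a usage sketch (not part of the commit; Before refers to the module defined in test_static_args above, and the PassContext config key is an assumption based on TVM's CUDA graph support, so it may vary across versions):

import tvm
from tvm import relax

# Standalone application of the pass, exactly as the test does:
mod = relax.transform.RewriteCUDAGraph()(Before)

# In a normal compilation flow the rewrite runs inside relax.build when
# CUDA graph usage is enabled via PassContext (config key assumed):
with tvm.transform.PassContext(config={"relax.backend.use_cuda_graph": True}):
    ex = relax.build(Before, target="cuda")  # needs a CUDA-enabled TVM build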
