[Bugfix][SLM] Produce well-formed Relax for nn.modules.KVCache (apache#16684)

* [Bugfix][SLM] Produce well-formed Relax for nn.modules.KVCache

Prior to this commit, the `nn.modules.KVCache` implementations used
`R.call_packed(...)` to call the `"vm.builtin.attention_*"` functions.
Since `nn.Module` emits the body of each Relax function within a
`relax.DataflowBlock`, where impure expressions are forbidden, the
resulting functions were ill-formed.

This commit updates the implementations in `nn.modules.KVCache` to use
`R.call_pure_packed` instead of `R.call_packed`.  Asserting that the
callee is pure allows the call to occur within a
`relax.DataflowBlock`.

* Correct import for relax

* Fix unit test
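
As an illustration of the distinction described in the commit message, here is a minimal, self-contained sketch (not part of this commit; the `Module` class, the `main` signature, and the variable names are assumed for the example) showing `R.call_pure_packed` being accepted inside a `relax.DataflowBlock`:

import tvm
from tvm.script import ir as I
from tvm.script import relax as R


@I.ir_module
class Module:
    @R.function
    def main(cache: R.Object, x: R.Tensor((4, 2, 4), "float32")) -> R.Object:
        with R.dataflow():
            # R.call_packed would be rejected here, since a DataflowBlock may
            # not contain possibly-impure calls.  R.call_pure_packed asserts
            # that the callee is pure, so the binding is allowed.
            new_cache: R.Object = R.call_pure_packed(
                "vm.builtin.attention_kv_cache_append",
                cache,
                x,
                sinfo_args=(R.Object,),
            )
            R.output(new_cache)
        return new_cache


assert tvm.relax.analysis.well_formed(Module)

Swapping `R.call_pure_packed` back to `R.call_packed` in this sketch should make the module ill-formed, which is the condition the exporter's new `well_formed` assertion (below) is intended to catch.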
Lunderberg authored and thaisacs committed Apr 3, 2024
1 parent 9484836 commit 484c6c1
Showing 3 changed files with 28 additions and 13 deletions.
python/tvm/relax/frontend/nn/exporter.py (4 changes: 3 additions & 1 deletion)
@@ -21,7 +21,7 @@
from tvm import tir
from tvm.ir import IRModule

-from ... import expr as rx
+from .... import relax as rx
from ...block_builder import BlockBuilder
from ...struct_info import ObjectStructInfo, ShapeStructInfo, TupleStructInfo
from . import core, extern
@@ -136,6 +136,8 @@ def _effects() -> typing.List[typing.Tuple[str, core.Effect]]:
outputs, inputs = _emit_method(self.builder, method_spec, params, effects)
self.builder.emit_func_output(outputs, inputs)
mod = self.builder.finalize()
+assert rx.analysis.well_formed(mod)

return mod, params, ext_mods


python/tvm/relax/frontend/nn/modules.py (27 changes: 20 additions & 7 deletions)
@@ -19,7 +19,7 @@
from typing import List, Optional, Sequence, Union

from tvm import relax as rx
-from tvm import tir
+from tvm import tir, ir

from . import op
from .core import Effect, Module, ModuleList, Parameter, Tensor, get_default_dtype
@@ -600,8 +600,13 @@ def emit_init(self, name_hint: str, bb: rx.BlockBuilder): # pylint: disable=arg
return [
bb.emit(
rx.Call(
rx.extern("vm.builtin.attention_kv_cache_create"),
args=[rx.op.zeros(init_shape, self.dtype), init_shape, rx.PrimValue(0)],
ir.Op.get("relax.call_pure_packed"),
args=[
rx.extern("vm.builtin.attention_kv_cache_create"),
rx.op.zeros(init_shape, self.dtype),
init_shape,
rx.PrimValue(0),
],
sinfo_args=[rx.ObjectStructInfo()],
),
name_hint=name_hint,
@@ -671,8 +676,12 @@ def view(self, seq_len: tir.Var) -> Tensor:
return Tensor(
_expr=rx.BlockBuilder.current().emit(
rx.Call(
rx.extern("vm.builtin.attention_kv_cache_view"),
args=[self.cache, shape],
ir.Op.get("relax.call_pure_packed"),
args=[
rx.extern("vm.builtin.attention_kv_cache_view"),
self.cache,
shape,
],
sinfo_args=[rx.TensorStructInfo(shape, self.dtype)],
)
)
@@ -694,8 +703,12 @@ def append(self, new_element: Tensor) -> None:
)
self.cache = rx.BlockBuilder.current().emit(
rx.Call(
rx.extern("vm.builtin.attention_kv_cache_append"),
args=[self.cache, new_element._expr],
ir.Op.get("relax.call_pure_packed"),
args=[
rx.extern("vm.builtin.attention_kv_cache_append"),
self.cache,
new_element._expr,
],
sinfo_args=[rx.ObjectStructInfo()],
)
)
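
For readers less familiar with the builder-side API used above, the `rx.Call(ir.Op.get("relax.call_pure_packed"), ...)` expressions are the in-memory construction of the same call that prints as `R.call_pure_packed(...)` in TVMScript. A rough, self-contained sketch of the correspondence (the placeholder variables `cache` and `x` are assumed for the example):

from tvm import ir
from tvm import relax as rx

# Placeholder variables standing in for the cache handle and the new entry.
cache = rx.Var("cache", rx.ObjectStructInfo())
x = rx.Var("x", rx.TensorStructInfo([4, 2, 4], "float32"))

# Builder-side construction, mirroring KVCache.append in the diff above.
call = rx.Call(
    ir.Op.get("relax.call_pure_packed"),
    args=[
        rx.extern("vm.builtin.attention_kv_cache_append"),
        cache,
        x,
    ],
    sinfo_args=[rx.ObjectStructInfo()],
)

# In TVMScript this expression corresponds to
#   R.call_pure_packed("vm.builtin.attention_kv_cache_append", cache, x,
#                      sinfo_args=(R.Object,))
# which is the form the updated unit test below expects.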
tests/python/relax/test_frontend_nn_modules.py (10 changes: 5 additions & 5 deletions)
@@ -484,15 +484,15 @@ def _initialize_effect() -> R.Tuple(R.Object, R.Object):
lv: R.Tensor((8, 2, 4), dtype="float32") = R.zeros(
R.shape([8, 2, 4]), dtype="float32"
)
-cache: R.Object = R.call_packed(
+cache: R.Object = R.call_pure_packed(
"vm.builtin.attention_kv_cache_create",
lv,
R.shape([8, 2, 4]),
R.prim_value(0),
sinfo_args=(R.Object,),
)
-lv1: R.Tuple(R.Object, R.Object) = _io, cache
-gv: R.Tuple(R.Object, R.Object) = lv1
+lv1 = _io, cache
+gv = lv1
R.output(gv)
return gv

@@ -502,10 +502,10 @@ def forward(
) -> R.Tuple(R.Tensor((4, 2, 4), dtype="float32"), R.Tuple(R.Object, R.Object)):
R.func_attr({"num_input": 3})
with R.dataflow():
-lv2: R.Object = R.call_packed(
+lv2: R.Object = R.call_pure_packed(
"vm.builtin.attention_kv_cache_append", cache, x, sinfo_args=(R.Object,)
)
-lv3: R.Tensor((4, 2, 4), dtype="float32") = R.call_packed(
+lv3: R.Tensor((4, 2, 4), dtype="float32") = R.call_pure_packed(
"vm.builtin.attention_kv_cache_view",
lv2,
R.shape([4, 2, 4]),
