Bump IREE to 20240226.813 and adapt to API breaks. (nod-ai#482)
This release pulls in iree-org/iree#16486, which makes substantial changes to torch imports:

* Always generates async code, with a special `$async`-suffixed entrypoint that the default entrypoint delegates to (sketched below).
* The internal structure of the generated code is different, invalidating some tests.
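
For orientation, a rough sketch of the new import structure follows. The `@foobar$async` signature is copied from the updated test expectation in `core/tests/aot/args_test.py` below; the synchronous wrapper and both function bodies are schematic illustrations of the convention, not actual compiler output:

```mlir
// Schematic sketch only (bodies elided): the $async entrypoint takes a
// wait fence and a signal fence after the tensor arguments.
util.func public @foobar$async(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view,
                               %arg2: !hal.fence, %arg3: !hal.fence)
    -> (!hal.buffer_view, !hal.buffer_view) {
  // ... begins once %arg2 is signaled; signals %arg3 when results are ready ...
}

// The default entrypoint keeps the old synchronous signature and delegates
// to the $async variant: create fences, call it, block on the signal fence.
util.func public @foobar(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view)
    -> (!hal.buffer_view, !hal.buffer_view) {
  // ... fence setup + util.call @foobar$async + hal.fence.await ...
}
```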
stellaraccident authored Feb 26, 2024
1 parent 18262d4 commit 971231c
Showing 6 changed files with 30 additions and 63 deletions.
4 changes: 2 additions & 2 deletions core/iree-requirements.txt
@@ -1,2 +1,2 @@
-iree-compiler==20240215.802
-iree-runtime==20240215.802
+iree-compiler==20240226.813
+iree-runtime==20240226.813
7 changes: 1 addition & 6 deletions core/shark_turbine/dynamo/backends/cpu.py
@@ -39,12 +39,7 @@
 from torch._dynamo.backends.common import aot_autograd
 from ..passes import turbine_cpu_pass_pipeline
 
-DEFAULT_COMPILER_FLAGS = (
-    # Enable asynchronous calling convention.
-    # TODO: Enable async execution mode.
-    # "--iree-execution-model=async-external",
-    "--iree-input-type=tm_tensor",
-)
+DEFAULT_COMPILER_FLAGS = ("--iree-input-type=torch",)
 
 
 def _base_backend(gm: torch.fx.GraphModule, example_inputs):
2 changes: 1 addition & 1 deletion core/shark_turbine/dynamo/executor.py
@@ -153,7 +153,7 @@ def __init__(
         self,
         user_module: VmModule,
         device_state: DeviceState,
-        entry_name: str = "main",
+        entry_name: str = "main$async",
     ):
         self.user_module = user_module
         self.vm_context = VmContext(
6 changes: 1 addition & 5 deletions core/shark_turbine/dynamo/tensor.py
@@ -53,11 +53,7 @@
 
 from ..importers.fx_importer import FxImporter
 
-DEFAULT_COMPILER_FLAGS = (
-    # Enable asynchronous calling convention.
-    "--iree-execution-model=async-external",
-    "--iree-input-type=torch",
-)
+DEFAULT_COMPILER_FLAGS = ("--iree-input-type=torch",)
 
 ###############################################################################
 # Factories and device enablement
28 changes: 8 additions & 20 deletions core/tests/aot/args_test.py
@@ -24,33 +24,24 @@ def foobar(self, a=AbstractTensor(3, 2), b=AbstractTensor(1, 1)):
         module_str = str(CompiledModule.get_mlir_module(inst))
         print(module_str)
         self.assertIn(
-            "func.func @foobar(%arg0: tensor<3x2xf32>, %arg1: tensor<1x1xf32>) -> (tensor<1x1xf32>, tensor<3x2xf32>)",
+            "util.func public @foobar$async(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.fence, %arg3: !hal.fence) -> (!hal.buffer_view, !hal.buffer_view)",
             module_str,
         )
-        self.assertIn("return %arg1, %arg0", module_str)
 
     def testProcToJitArgs(self):
-        class ProcArgsModule(CompiledModule):
+        class testProcToJitArgs(CompiledModule):
             def foobar(self, a=AbstractTensor(3, 2), b=AbstractTensor(1, 1)):
                 return self.compute(a, b)
 
             @jittable
             def compute(a, b):
                 return a + b
 
-        inst = ProcArgsModule(context=Context())
+        inst = testProcToJitArgs(context=Context())
         module_str = str(CompiledModule.get_mlir_module(inst))
         print(module_str)
         self.assertIn(
-            "func.func @foobar(%arg0: tensor<3x2xf32>, %arg1: tensor<1x1xf32>) -> tensor<3x2xf32>",
-            module_str,
-        )
-        self.assertIn(
-            "func.func private @compute(%arg0: tensor<3x2xf32>, %arg1: tensor<1x1xf32>) -> tensor<3x2xf32>",
-            module_str,
-        )
-        self.assertIn(
-            "%0 = call @compute(%arg0, %arg1)",
+            "linalg.generic",
             module_str,
         )

@@ -68,13 +59,10 @@ def compute(a, b):
         inst = ProcArgsModule(context=Context())
         module_str = str(CompiledModule.get_mlir_module(inst))
         print(module_str)
-        self.assertIn(
-            "%0 = call @compute(%arg0, %arg1)",
-            module_str,
-        )
-        self.assertIn(
-            "%1 = call @compute$1(%0, %arg0)",
-            module_str,
+        self.assertEqual(
+            2,
+            module_str.count("linalg.generic"),
+            msg=f"Did not find two linalg.generics in module: {module_str}",
         )
 
 
46 changes: 17 additions & 29 deletions core/tests/aot/globals_test.py
@@ -26,7 +26,7 @@ def forward(self, x):
         return self.classifier(x)
 
 
-class ArgsTest(unittest.TestCase):
+class GlobalsTest(unittest.TestCase):
     def testGlobalParameters(self):
         m = SimpleParams()
 
@@ -63,10 +63,6 @@ def read_params(self):
             "%_params.classifier.bias = util.global.load @_params.classifier.bias",
             module_str,
         )
-        self.assertIn(
-            "return %_params.classifier.weight, %_params.classifier.bias",
-            module_str,
-        )
 
     def testGlobalLoadFromPyLeaf(self):
         m = SimpleParams()
@@ -84,7 +80,6 @@ def read_weight(self):
             "%_params.classifier.weight = util.global.load @_params.classifier.weight",
             module_str,
         )
-        self.assertIn("return %_params.classifier.weight", module_str)
 
     def testGlobalStoreFromPyTree(self):
         m = SimpleParams()
@@ -100,8 +95,10 @@ def update_params(me, updates=abstractify(params)):
         inst = GlobalModule(context=Context())
         module_str = str(CompiledModule.get_mlir_module(inst))
         print(module_str)
-        self.assertIn("util.global.store %arg0, @_params.classifier.weight", module_str)
-        self.assertIn("util.global.store %arg1, @_params.classifier.bias", module_str)
+        self.assertRegex(
+            module_str, "util.global.store %.*, @_params.classifier.weight"
+        )
+        self.assertRegex(module_str, "util.global.store %.*, @_params.classifier.bias")
 
     def testGlobalStoreFromLeaf(self):
         m = SimpleParams()
@@ -115,7 +112,7 @@ def update_bias(self, new_bias=abstractify(params["classifier.bias"])):
         inst = GlobalModule(context=Context())
         module_str = str(CompiledModule.get_mlir_module(inst))
         print(module_str)
-        self.assertIn("util.global.store %arg0, @_params.classifier.bias", module_str)
+        self.assertRegex(module_str, "util.global.store %.*, @_params.classifier.bias")
 
     def testExportSingleGlobalTensor(self):
         state_example = torch.randn(3, 11)
@@ -131,7 +128,6 @@ def read_state(self):
         print(module_str)
         self.assertIn("util.global private @_state0.global", module_str)
         self.assertIn("%_state0.global = util.global.load @_state0.global", module_str)
-        self.assertIn("return %_state0.global", module_str)
 
     def testExportTreeGlobalTensors(self):
         state_example = {
@@ -160,10 +156,6 @@ def read_state(self):
         self.assertIn("%_state0.seq.0 = util.global.load @_state0.seq.0", module_str)
         self.assertIn("%_state0.seq.1 = util.global.load @_state0.seq.1", module_str)
         self.assertIn("%_state0.seq.2 = util.global.load @_state0.seq.2", module_str)
-        self.assertIn(
-            "return %_state0.data, %_state0.seq.0, %_state0.seq.1, %_state0.seq.2",
-            module_str,
-        )
 
     def testExportGlobalScalars(self):
         class ScalarState(CompiledModule):
@@ -210,9 +202,6 @@ class DerivedState(BaseState):
         print(module_str)
         self.assertIn("@_state_index.global {noinline} = 0 : index", module_str)
         self.assertIn("@_state_f32.global {noinline} = 0.000000e+00 : f32", module_str)
-        self.assertIn(
-            "return %_state_index.global, %_state_f32.global : index, f32", module_str
-        )
 
     def testInheritOverrideBase(self):
         class BaseState(CompiledModule):
@@ -252,8 +241,10 @@ class DerivedModule(BaseModule):
         inst = DerivedModule(context=Context())
         module_str = str(CompiledModule.get_mlir_module(inst))
         print(module_str)
-        self.assertIn("util.global.store %arg0, @_params.classifier.weight", module_str)
-        self.assertIn("util.global.store %arg1, @_params.classifier.bias", module_str)
+        self.assertRegex(
+            module_str, "util.global.store %.*, @_params.classifier.weight"
+        )
+        self.assertRegex(module_str, "util.global.store %.*, @_params.classifier.bias")
 
     def testUpdateGlobalStateTree(self):
         state_example = {
@@ -287,10 +278,10 @@ def read_state(self, updates=abstractify(state_example)):
             module_str,
         )
         self.assertIn("util.global private mutable @_state0.data", module_str)
-        self.assertIn("util.global.store %arg0, @_state0.data", module_str)
-        self.assertIn("util.global.store %arg1, @_state0.seq.0", module_str)
-        self.assertIn("util.global.store %arg2, @_state0.seq.1", module_str)
-        self.assertIn("util.global.store %arg3, @_state0.seq.2", module_str)
+        self.assertRegex(module_str, "util.global.store %.*, @_state0.data")
+        self.assertRegex(module_str, "util.global.store %.*, @_state0.seq.0")
+        self.assertRegex(module_str, "util.global.store %.*, @_state0.seq.1")
+        self.assertRegex(module_str, "util.global.store %.*, @_state0.seq.2")
 
     def testTensorUpdateGlobal(self):
         state_example = torch.randn(5, 20)
@@ -305,9 +296,9 @@ def tensor_update_state(self, update=abstractify(update_example)):
         inst = UpdateState(context=Context())
         module_str = str(CompiledModule.get_mlir_module(inst))
         print(module_str)
-        self.assertIn(
-            "flow.tensor.update %arg0, %_state0.global[%c0, %c0] : tensor<1x20xf32> -> %_state0.global as tensor<5x20xf32>",
+        self.assertRegex(
             module_str,
+            "flow.tensor.update %.*, %_state0.global\\[%c0, %c0\\] : tensor<1x20xf32> -> %_state0.global as tensor<5x20xf32>",
         )
 
     def testTensorUpdateGlobalReturnNone(self):
@@ -325,10 +316,7 @@ def tensor_update_state(self, update=abstractify(update_example)):
         inst = UpdateState(context=Context())
         module_str = str(CompiledModule.get_mlir_module(inst))
         print(module_str)
-        self.assertIn(
-            "flow.tensor.update %arg0, %_state0.global[%c4, %c0, %c0] : tensor<1x1x4xf32> -> %_state0.global as tensor<5x20x4xf32>",
-            module_str,
-        )
+        self.assertIn("flow.tensor.update", module_str)
 
     def testExternalGlobalParametersDefaults(self):
         m = SimpleParams()