From 971231c6ab8f3721573692621bb77d3ebf1f04eb Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Mon, 26 Feb 2024 14:45:15 -0800 Subject: [PATCH] Bump IREE to 20240226.813 and adapt to API breaks. (#482) This release pulls in https://github.com/openxla/iree/pull/16486 which makes substantial changes to torch imports: * Always generates async code (with a special `$async` suffixed entrypoint that the default entrypoint delegates to). * Internal structure of the generated code is different, invalidating some tests. --- core/iree-requirements.txt | 4 +- core/shark_turbine/dynamo/backends/cpu.py | 7 +--- core/shark_turbine/dynamo/executor.py | 2 +- core/shark_turbine/dynamo/tensor.py | 6 +-- core/tests/aot/args_test.py | 28 ++++---------- core/tests/aot/globals_test.py | 46 +++++++++-------------- 6 files changed, 30 insertions(+), 63 deletions(-) diff --git a/core/iree-requirements.txt b/core/iree-requirements.txt index 30a3b268d..95630b1db 100644 --- a/core/iree-requirements.txt +++ b/core/iree-requirements.txt @@ -1,2 +1,2 @@ -iree-compiler==20240215.802 -iree-runtime==20240215.802 +iree-compiler==20240226.813 +iree-runtime==20240226.813 diff --git a/core/shark_turbine/dynamo/backends/cpu.py b/core/shark_turbine/dynamo/backends/cpu.py index f92299c2e..34f9d2085 100644 --- a/core/shark_turbine/dynamo/backends/cpu.py +++ b/core/shark_turbine/dynamo/backends/cpu.py @@ -39,12 +39,7 @@ from torch._dynamo.backends.common import aot_autograd from ..passes import turbine_cpu_pass_pipeline -DEFAULT_COMPILER_FLAGS = ( - # Enable asynchronous calling convention. - # TODO: Enable async execution mode. - # "--iree-execution-model=async-external", - "--iree-input-type=tm_tensor", -) +DEFAULT_COMPILER_FLAGS = ("--iree-input-type=torch",) def _base_backend(gm: torch.fx.GraphModule, example_inputs): diff --git a/core/shark_turbine/dynamo/executor.py b/core/shark_turbine/dynamo/executor.py index 5208210a6..a18c61749 100644 --- a/core/shark_turbine/dynamo/executor.py +++ b/core/shark_turbine/dynamo/executor.py @@ -153,7 +153,7 @@ def __init__( self, user_module: VmModule, device_state: DeviceState, - entry_name: str = "main", + entry_name: str = "main$async", ): self.user_module = user_module self.vm_context = VmContext( diff --git a/core/shark_turbine/dynamo/tensor.py b/core/shark_turbine/dynamo/tensor.py index c85515c5a..1a0a683d8 100644 --- a/core/shark_turbine/dynamo/tensor.py +++ b/core/shark_turbine/dynamo/tensor.py @@ -53,11 +53,7 @@ from ..importers.fx_importer import FxImporter -DEFAULT_COMPILER_FLAGS = ( - # Enable asynchronous calling convention. - "--iree-execution-model=async-external", - "--iree-input-type=torch", -) +DEFAULT_COMPILER_FLAGS = ("--iree-input-type=torch",) ############################################################################### # Factories and device enablement diff --git a/core/tests/aot/args_test.py b/core/tests/aot/args_test.py index c14c92463..b03ac4273 100644 --- a/core/tests/aot/args_test.py +++ b/core/tests/aot/args_test.py @@ -24,13 +24,12 @@ def foobar(self, a=AbstractTensor(3, 2), b=AbstractTensor(1, 1)): module_str = str(CompiledModule.get_mlir_module(inst)) print(module_str) self.assertIn( - "func.func @foobar(%arg0: tensor<3x2xf32>, %arg1: tensor<1x1xf32>) -> (tensor<1x1xf32>, tensor<3x2xf32>)", + "util.func public @foobar$async(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.fence, %arg3: !hal.fence) -> (!hal.buffer_view, !hal.buffer_view)", module_str, ) - self.assertIn("return %arg1, %arg0", module_str) def testProcToJitArgs(self): - class ProcArgsModule(CompiledModule): + class testProcToJitArgs(CompiledModule): def foobar(self, a=AbstractTensor(3, 2), b=AbstractTensor(1, 1)): return self.compute(a, b) @@ -38,19 +37,11 @@ def foobar(self, a=AbstractTensor(3, 2), b=AbstractTensor(1, 1)): def compute(a, b): return a + b - inst = ProcArgsModule(context=Context()) + inst = testProcToJitArgs(context=Context()) module_str = str(CompiledModule.get_mlir_module(inst)) print(module_str) self.assertIn( - "func.func @foobar(%arg0: tensor<3x2xf32>, %arg1: tensor<1x1xf32>) -> tensor<3x2xf32>", - module_str, - ) - self.assertIn( - "func.func private @compute(%arg0: tensor<3x2xf32>, %arg1: tensor<1x1xf32>) -> tensor<3x2xf32>", - module_str, - ) - self.assertIn( - "%0 = call @compute(%arg0, %arg1)", + "linalg.generic", module_str, ) @@ -68,13 +59,10 @@ def compute(a, b): inst = ProcArgsModule(context=Context()) module_str = str(CompiledModule.get_mlir_module(inst)) print(module_str) - self.assertIn( - "%0 = call @compute(%arg0, %arg1)", - module_str, - ) - self.assertIn( - "%1 = call @compute$1(%0, %arg0)", - module_str, + self.assertEqual( + 2, + module_str.count("linalg.generic"), + msg=f"Did not find two linalg.generics in module: module_str", ) diff --git a/core/tests/aot/globals_test.py b/core/tests/aot/globals_test.py index 11618155f..657f3cd89 100644 --- a/core/tests/aot/globals_test.py +++ b/core/tests/aot/globals_test.py @@ -26,7 +26,7 @@ def forward(self, x): return self.classifier(x) -class ArgsTest(unittest.TestCase): +class GlobalsTest(unittest.TestCase): def testGlobalParameters(self): m = SimpleParams() @@ -63,10 +63,6 @@ def read_params(self): "%_params.classifier.bias = util.global.load @_params.classifier.bias", module_str, ) - self.assertIn( - "return %_params.classifier.weight, %_params.classifier.bias", - module_str, - ) def testGlobalLoadFromPyLeaf(self): m = SimpleParams() @@ -84,7 +80,6 @@ def read_weight(self): "%_params.classifier.weight = util.global.load @_params.classifier.weight", module_str, ) - self.assertIn("return %_params.classifier.weight", module_str) def testGlobalStoreFromPyTree(self): m = SimpleParams() @@ -100,8 +95,10 @@ def update_params(me, updates=abstractify(params)): inst = GlobalModule(context=Context()) module_str = str(CompiledModule.get_mlir_module(inst)) print(module_str) - self.assertIn("util.global.store %arg0, @_params.classifier.weight", module_str) - self.assertIn("util.global.store %arg1, @_params.classifier.bias", module_str) + self.assertRegex( + module_str, "util.global.store %.*, @_params.classifier.weight" + ) + self.assertRegex(module_str, "util.global.store %.*, @_params.classifier.bias") def testGlobalStoreFromLeaf(self): m = SimpleParams() @@ -115,7 +112,7 @@ def update_bias(self, new_bias=abstractify(params["classifier.bias"])): inst = GlobalModule(context=Context()) module_str = str(CompiledModule.get_mlir_module(inst)) print(module_str) - self.assertIn("util.global.store %arg0, @_params.classifier.bias", module_str) + self.assertRegex(module_str, "util.global.store %.*, @_params.classifier.bias") def testExportSingleGlobalTensor(self): state_example = torch.randn(3, 11) @@ -131,7 +128,6 @@ def read_state(self): print(module_str) self.assertIn("util.global private @_state0.global", module_str) self.assertIn("%_state0.global = util.global.load @_state0.global", module_str) - self.assertIn("return %_state0.global", module_str) def testExportTreeGlobalTensors(self): state_example = { @@ -160,10 +156,6 @@ def read_state(self): self.assertIn("%_state0.seq.0 = util.global.load @_state0.seq.0", module_str) self.assertIn("%_state0.seq.1 = util.global.load @_state0.seq.1", module_str) self.assertIn("%_state0.seq.2 = util.global.load @_state0.seq.2", module_str) - self.assertIn( - "return %_state0.data, %_state0.seq.0, %_state0.seq.1, %_state0.seq.2", - module_str, - ) def testExportGlobalScalars(self): class ScalarState(CompiledModule): @@ -210,9 +202,6 @@ class DerivedState(BaseState): print(module_str) self.assertIn("@_state_index.global {noinline} = 0 : index", module_str) self.assertIn("@_state_f32.global {noinline} = 0.000000e+00 : f32", module_str) - self.assertIn( - "return %_state_index.global, %_state_f32.global : index, f32", module_str - ) def testInheritOverrideBase(self): class BaseState(CompiledModule): @@ -252,8 +241,10 @@ class DerivedModule(BaseModule): inst = DerivedModule(context=Context()) module_str = str(CompiledModule.get_mlir_module(inst)) print(module_str) - self.assertIn("util.global.store %arg0, @_params.classifier.weight", module_str) - self.assertIn("util.global.store %arg1, @_params.classifier.bias", module_str) + self.assertRegex( + module_str, "util.global.store %.*, @_params.classifier.weight" + ) + self.assertRegex(module_str, "util.global.store %.*, @_params.classifier.bias") def testUpdateGlobalStateTree(self): state_example = { @@ -287,10 +278,10 @@ def read_state(self, updates=abstractify(state_example)): module_str, ) self.assertIn("util.global private mutable @_state0.data", module_str) - self.assertIn("util.global.store %arg0, @_state0.data", module_str) - self.assertIn("util.global.store %arg1, @_state0.seq.0", module_str) - self.assertIn("util.global.store %arg2, @_state0.seq.1", module_str) - self.assertIn("util.global.store %arg3, @_state0.seq.2", module_str) + self.assertRegex(module_str, "util.global.store %.*, @_state0.data") + self.assertRegex(module_str, "util.global.store %.*, @_state0.seq.0") + self.assertRegex(module_str, "util.global.store %.*, @_state0.seq.1") + self.assertRegex(module_str, "util.global.store %.*, @_state0.seq.2") def testTensorUpdateGlobal(self): state_example = torch.randn(5, 20) @@ -305,9 +296,9 @@ def tensor_update_state(self, update=abstractify(update_example)): inst = UpdateState(context=Context()) module_str = str(CompiledModule.get_mlir_module(inst)) print(module_str) - self.assertIn( - "flow.tensor.update %arg0, %_state0.global[%c0, %c0] : tensor<1x20xf32> -> %_state0.global as tensor<5x20xf32>", + self.assertRegex( module_str, + "flow.tensor.update %.*, %_state0.global\\[%c0, %c0\\] : tensor<1x20xf32> -> %_state0.global as tensor<5x20xf32>", ) def testTensorUpdateGlobalReturnNone(self): @@ -325,10 +316,7 @@ def tensor_update_state(self, update=abstractify(update_example)): inst = UpdateState(context=Context()) module_str = str(CompiledModule.get_mlir_module(inst)) print(module_str) - self.assertIn( - "flow.tensor.update %arg0, %_state0.global[%c4, %c0, %c0] : tensor<1x1x4xf32> -> %_state0.global as tensor<5x20x4xf32>", - module_str, - ) + self.assertIn("flow.tensor.update", module_str) def testExternalGlobalParametersDefaults(self): m = SimpleParams()