From c0e688bdd2aa9c811dffbcb4a85bbda15c6b9fcf Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Sat, 12 Oct 2024 15:58:59 +0800 Subject: [PATCH] =?UTF-8?q?[torch-frontend]=20use=20min(target=5Fversion,?= =?UTF-8?q?=20current=5Fversion)=20as=20target=5Fv=E2=80=A6=20(#460)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …ersion --- .../python/test/test_stablehlo_bytecode.py | 3 +++ .../python/torch_frontend/compile.py | 16 +++++++++++++--- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/frontends/torch-frontend/torch-frontend/python/test/test_stablehlo_bytecode.py b/frontends/torch-frontend/torch-frontend/python/test/test_stablehlo_bytecode.py index 6b82aad9b..628592087 100644 --- a/frontends/torch-frontend/torch-frontend/python/test/test_stablehlo_bytecode.py +++ b/frontends/torch-frontend/torch-frontend/python/test/test_stablehlo_bytecode.py @@ -9,6 +9,9 @@ def serialize_helper(module, inputs): stablehlo_bytecode = compile(module, inputs, "stablehlo+0.16.2") deserialize_str = deserialize_portable_artifact(stablehlo_bytecode) print(deserialize_str) + stablehlo_bytecode = compile(module, inputs, "stablehlo+10.0.0") + deserialize_str = deserialize_portable_artifact(stablehlo_bytecode) + print(deserialize_str) # ============================================================================== class SoftmaxModule(torch.nn.Module): diff --git a/frontends/torch-frontend/torch-frontend/python/torch_frontend/compile.py b/frontends/torch-frontend/torch-frontend/python/torch_frontend/compile.py index 487312b04..28056e457 100644 --- a/frontends/torch-frontend/torch-frontend/python/torch_frontend/compile.py +++ b/frontends/torch-frontend/torch-frontend/python/torch_frontend/compile.py @@ -11,7 +11,7 @@ from . import ir from .passmanager import PassManager -from ._mlir_libs._stablehlo import serialize_portable_artifact +from ._mlir_libs._stablehlo import serialize_portable_artifact, get_current_version from .extra_shape_fn import byteir_extra_library @@ -259,8 +259,13 @@ def compile( ############################################ # serialize stablehlo to target version ############################################ + from packaging import version + target_version = version.Version(output_type.split("+")[1]) + current_version = version.Version(get_current_version()) + if target_version > current_version: + target_version = current_version return serialize_portable_artifact( - module.operation.get_asm(), output_type.split("+")[1] + module.operation.get_asm(), str(target_version) ) @@ -380,6 +385,11 @@ def compile_dynamo_model( ############################################ # serialize stablehlo to target version ############################################ + from packaging import version + target_version = version.Version(output_type.split("+")[1]) + current_version = version.Version(get_current_version()) + if target_version > current_version: + target_version = current_version return serialize_portable_artifact( - module.operation.get_asm(), output_type.split("+")[1] + module.operation.get_asm(), str(target_version) )