Skip to content

Commit

Permalink
[torch-frontend] use min(target_version, current_version) as target_v… (
Browse files Browse the repository at this point in the history
#460)

…ersion
  • Loading branch information
qingyunqu authored Oct 12, 2024
1 parent 5b9a49e commit c0e688b
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ def serialize_helper(module, inputs):
stablehlo_bytecode = compile(module, inputs, "stablehlo+0.16.2")
deserialize_str = deserialize_portable_artifact(stablehlo_bytecode)
print(deserialize_str)
stablehlo_bytecode = compile(module, inputs, "stablehlo+10.0.0")
deserialize_str = deserialize_portable_artifact(stablehlo_bytecode)
print(deserialize_str)

# ==============================================================================
class SoftmaxModule(torch.nn.Module):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from . import ir
from .passmanager import PassManager
from ._mlir_libs._stablehlo import serialize_portable_artifact
from ._mlir_libs._stablehlo import serialize_portable_artifact, get_current_version

from .extra_shape_fn import byteir_extra_library

Expand Down Expand Up @@ -259,8 +259,13 @@ def compile(
############################################
# serialize stablehlo to target version
############################################
from packaging import version
target_version = version.Version(output_type.split("+")[1])
current_version = version.Version(get_current_version())
if target_version > current_version:
target_version = current_version
return serialize_portable_artifact(
module.operation.get_asm(), output_type.split("+")[1]
module.operation.get_asm(), str(target_version)
)


Expand Down Expand Up @@ -380,6 +385,11 @@ def compile_dynamo_model(
############################################
# serialize stablehlo to target version
############################################
from packaging import version
target_version = version.Version(output_type.split("+")[1])
current_version = version.Version(get_current_version())
if target_version > current_version:
target_version = current_version
return serialize_portable_artifact(
module.operation.get_asm(), output_type.split("+")[1]
module.operation.get_asm(), str(target_version)
)

0 comments on commit c0e688b

Please sign in to comment.