Revert "[FxImporter] Fix sympy_int_to_int utility (#3657)"
This reverts commit fa39d91.
patel-vimal committed Sep 10, 2024
1 parent 35be39e commit 64652cc
Showing 3 changed files with 13 additions and 41 deletions.
4 changes: 0 additions & 4 deletions python/TorchMLIRModule.cpp
@@ -28,8 +28,4 @@ PYBIND11_MODULE(_torchMlir, m) {
        }
      },
      py::arg("context"), py::arg("load") = true);
-
-  m.def("get_int64_max", []() { return INT64_MAX; });
-
-  m.def("get_int64_min", []() { return INT64_MIN; });
 }
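
For context, the two bindings deleted here merely exposed the C int64 limits to Python. A rough pure-Python equivalent of the values they returned (hypothetical stand-in names, not part of torch-mlir after this revert):

# Values that the removed pybind11 helpers returned.
INT64_MAX = 2**63 - 1   # 9223372036854775807, was get_int64_max()
INT64_MIN = -(2**63)    # -9223372036854775808, was get_int64_min()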
34 changes: 6 additions & 28 deletions python/torch_mlir/extras/fx_importer.py
@@ -78,16 +78,6 @@
 # conditional.
 ml_dtypes = None

-try:
-    from torch.utils._sympy.numbers import int_oo, IntInfinity, NegativeIntInfinity
-except ModuleNotFoundError:
-    # This commit on PyTorch repo introduced IntInfinity and NegativeIntInfinity:
-    # https://github.com/pytorch/pytorch/commit/2229884102ac95c9dda0aeadbded1b04295d892e
-    # Required module may not be present in the stable version of PyTorch.
-    int_oo = None
-    IntInfinity = None
-    NegativeIntInfinity = None
-
 from torch.fx.node import (
     Argument as NodeArgument,
 )
@@ -135,8 +125,6 @@
     func as func_dialect,
 )

-from .._mlir_libs._torchMlir import get_int64_max, get_int64_min
-
 __all__ = [
     "FxImporter",
 ]
@@ -1178,32 +1166,22 @@ def set_symbolic_guards(
         self, prog: torch.export.ExportedProgram
     ) -> Dict[str, RangeConstraint]:

-        # Recent PyTorch versions use `int_oo` to represent integer infinity.
-        # Older PyTorch versions like PyTorch stable version may not have
-        # `int_oo` defined just yet.
-        infs = (sympy.oo, int_oo) if int_oo is not None else (sympy.oo,)
-
         def _sympy_int_to_int(val: sympy.Expr, adjust_func: Callable):
             # Convert simple sympy Integers into concrete int
-            if val in infs:
-                return get_int64_max()
-            if val in tuple(-inf for inf in infs):
-                return get_int64_min()
+            if val == sympy.oo:
+                return math.inf
+            if val == -sympy.oo:
+                return -math.inf
             if isinstance(val, sympy.Integer):
                 return int(val)
             # TODO: Remove this adjustment when fractional ranges are removed
             return adjust_func(val)

         contains_symbolic_ints = False
-        sym_int_types = (
-            (sympy.Integer, IntInfinity, NegativeIntInfinity)
-            if IntInfinity is not None
-            else sympy.Integer
-        )
         for val in prog.range_constraints.values():
             if (
-                isinstance(val.lower, sym_int_types)
-                and isinstance(val.upper, sym_int_types)
+                isinstance(val.lower, sympy.Integer)
+                and isinstance(val.upper, sympy.Integer)
                 and not val.is_bool
             ):
                 contains_symbolic_ints = True
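
To make the behavioral difference concrete, here is a standalone sketch (not torch-mlir code) contrasting the conversion being removed with the one being restored. It assumes only sympy; the int64 bounds stand in for the removed get_int64_max/get_int64_min bindings, the int_oo case is omitted since that symbol may be absent in stable PyTorch, and the adjust_func fallback for non-integer expressions is left out:

import math
import sympy

INT64_MAX, INT64_MIN = 2**63 - 1, -(2**63)  # stand-ins for the removed bindings

def sympy_int_to_int_reverted(val):
    # Behavior removed by this commit: clamp infinities into the int64 range
    # so the bound stays a concrete Python int.
    if val == sympy.oo:
        return INT64_MAX
    if val == -sympy.oo:
        return INT64_MIN
    return int(val)

def sympy_int_to_int_restored(val):
    # Behavior restored by this commit: map infinities to float inf / -inf.
    if val == sympy.oo:
        return math.inf
    if val == -sympy.oo:
        return -math.inf
    return int(val)

print(sympy_int_to_int_reverted(sympy.oo))          # 9223372036854775807
print(sympy_int_to_int_restored(sympy.oo))          # inf
print(sympy_int_to_int_restored(sympy.Integer(4)))  # 4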
16 changes: 7 additions & 9 deletions test/python/fx_importer/basic_test.py
@@ -88,13 +88,12 @@ def forward(self, x):

 @run
 # CHECK-LABEL: test_import_frozen_exported_program_with_dynamic_shapes
-# CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,5],f32>) -> !torch.vtensor<[?,?,5],f32>
+# CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,4],f32>) -> !torch.vtensor<[?,4],f32>
 # CHECK: %[[S0:.*]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
-# CHECK: %[[S1:.*]] = torch.symbolic_int "s1" {min_val = 2, max_val = {{[0-9]+}}} : !torch.int
-# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 5)> : !torch.vtensor<[?,?,5],f32>
-# CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<[?,?,5],f32> -> !torch.vtensor<[?,?,5],f32>
-# CHECK: torch.bind_symbolic_shape %[[TANH]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 5)> : !torch.vtensor<[?,?,5],f32>
-# CHECK: return %[[TANH]] : !torch.vtensor<[?,?,5],f32>
+# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0, 4)> : !torch.vtensor<[?,4],f32>
+# CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<[?,4],f32> -> !torch.vtensor<[?,4],f32>
+# CHECK: torch.bind_symbolic_shape %[[TANH]], [%[[S0]]], affine_map<()[s0] -> (s0, 4)> : !torch.vtensor<[?,4],f32>
+# CHECK: return %[[TANH]] : !torch.vtensor<[?,4],f32>
 def test_import_frozen_exported_program_with_dynamic_shapes():
     class Basic(nn.Module):
         def __init__(self):
@@ -104,11 +103,10 @@ def forward(self, x):
             return torch.tanh(x)

     batch = Dim("batch", max=10)
-    channel = Dim("channel", min=2)
-    dynamic_shapes = {"x": {0: batch, 1: channel}}
+    dynamic_shapes = {"x": {0: batch}}
     m = fx.export_and_import(
         Basic(),
-        torch.randn(3, 4, 5),
+        torch.randn(3, 4),
         dynamic_shapes=dynamic_shapes,
         func_name="test_net",
         import_symbolic_shape_expressions=True,
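
For reference, the dynamic-shape export this test now exercises can be reproduced with plain torch.export, independent of torch-mlir. A minimal sketch (module and Dim names mirror the test above):

import torch
from torch import nn
from torch.export import Dim, export

class Basic(nn.Module):
    def forward(self, x):
        return torch.tanh(x)

# After this revert only dim 0 is symbolic; dim 1 is fixed at 4.
batch = Dim("batch", max=10)
prog = export(Basic(), (torch.randn(3, 4),), dynamic_shapes={"x": {0: batch}})

# range_constraints maps each shape symbol (e.g. s0) to its integer bounds;
# FxImporter.set_symbolic_guards consumes these to emit torch.symbolic_int ops.
print(prog.range_constraints)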
