From a6c84bd3e7dca8f56762fd77b595e4eaf91e75ab Mon Sep 17 00:00:00 2001 From: Andreas Falkenberg <149819731+afalkenberg1@users.noreply.github.com> Date: Thu, 18 Apr 2024 17:44:12 -0700 Subject: [PATCH] Added a second gridsampler testcase to test both align_corners modes (#178) gridsampler has a parameter to select how corners and edges are aligned. The two options are now implemented and therefore two cases are provided with this PR. --- .../pytorch/operators/gridsampler/model.py | 11 ++++-- .../pytorch/operators/gridsampler2/model.py | 38 +++++++++++++++++++ 2 files changed, 45 insertions(+), 4 deletions(-) mode change 100755 => 100644 e2eshark/pytorch/operators/gridsampler/model.py create mode 100755 e2eshark/pytorch/operators/gridsampler2/model.py diff --git a/e2eshark/pytorch/operators/gridsampler/model.py b/e2eshark/pytorch/operators/gridsampler/model.py old mode 100755 new mode 100644 index 6c24280c0..3674195a3 --- a/e2eshark/pytorch/operators/gridsampler/model.py +++ b/e2eshark/pytorch/operators/gridsampler/model.py @@ -22,13 +22,16 @@ def __init__(self): super().__init__() def forward(self, x, g): - z = nn.functional.grid_sample(x, g, mode="bilinear", padding_mode="zeros", align_corners=True) + interpolation_mode=0, + padding_mode=0, + align_corners=False, + z = nn.functional.grid_sample(x, g, mode="bilinear", padding_mode="zeros", align_corners=False) return z model = op_gridsampler() - -X = torch.rand(4, 7, 8, 11) -Y = torch.rand(4, 9, 13, 2)*2-1 +# torch.manual_seed(42) +X = torch.rand(7, 8, 12, 4) +Y = torch.rand(7, 11, 13, 2)*2.0-1.0 Z = model(X, Y) E2ESHARK_CHECK["input"] = [X, Y] diff --git a/e2eshark/pytorch/operators/gridsampler2/model.py b/e2eshark/pytorch/operators/gridsampler2/model.py new file mode 100755 index 000000000..d6076d7d5 --- /dev/null +++ b/e2eshark/pytorch/operators/gridsampler2/model.py @@ -0,0 +1,38 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import sys, argparse +import torch +import torch.nn as nn +import torch_mlir + +# import from e2eshark/tools to allow running in current dir, for run through +# run.pl, commutils is symbolically linked to allow any rundir to work +sys.path.insert(0, "../../../tools/stubs") +from commonutils import E2ESHARK_CHECK_DEF + +# Create an instance of it for this test +E2ESHARK_CHECK = dict(E2ESHARK_CHECK_DEF) + +class op_gridsampler(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, g): + z = nn.functional.grid_sample(x, g, mode="bilinear", padding_mode="zeros", align_corners=True) + return z + +model = op_gridsampler() +torch.manual_seed(42) +X = torch.rand(4, 7, 8, 11) +Y = torch.rand(4, 9, 13, 2)*2-1 +Z = model(X, Y) + +E2ESHARK_CHECK["input"] = [X, Y] +E2ESHARK_CHECK["output"] = Z +print("Input:", X) +print("Grid:", Y) +print("Output:", Z)