Skip to content

Commit

Permalink
Added a second gridsampler testcase to test both align_corners modes (#…
Browse files Browse the repository at this point in the history
…178)

gridsampler has a parameter to select how corners and edges are aligned.
The two options are now implemented and therefore two cases are provided
with this PR.
  • Loading branch information
afalkenberg1 authored Apr 19, 2024
1 parent 1d77dcf commit a6c84bd
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 4 deletions.
11 changes: 7 additions & 4 deletions e2eshark/pytorch/operators/gridsampler/model.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,16 @@ def __init__(self):
super().__init__()

def forward(self, x, g):
z = nn.functional.grid_sample(x, g, mode="bilinear", padding_mode="zeros", align_corners=True)
interpolation_mode=0,
padding_mode=0,
align_corners=False,
z = nn.functional.grid_sample(x, g, mode="bilinear", padding_mode="zeros", align_corners=False)
return z

model = op_gridsampler()

X = torch.rand(4, 7, 8, 11)
Y = torch.rand(4, 9, 13, 2)*2-1
# torch.manual_seed(42)
X = torch.rand(7, 8, 12, 4)
Y = torch.rand(7, 11, 13, 2)*2.0-1.0
Z = model(X, Y)

E2ESHARK_CHECK["input"] = [X, Y]
Expand Down
38 changes: 38 additions & 0 deletions e2eshark/pytorch/operators/gridsampler2/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import sys, argparse
import torch
import torch.nn as nn
import torch_mlir

# import from e2eshark/tools to allow running in current dir, for run through
# run.pl, commutils is symbolically linked to allow any rundir to work
sys.path.insert(0, "../../../tools/stubs")
from commonutils import E2ESHARK_CHECK_DEF

# Create an instance of it for this test
E2ESHARK_CHECK = dict(E2ESHARK_CHECK_DEF)

class op_gridsampler(nn.Module):
def __init__(self):
super().__init__()

def forward(self, x, g):
z = nn.functional.grid_sample(x, g, mode="bilinear", padding_mode="zeros", align_corners=True)
return z

model = op_gridsampler()
torch.manual_seed(42)
X = torch.rand(4, 7, 8, 11)
Y = torch.rand(4, 9, 13, 2)*2-1
Z = model(X, Y)

E2ESHARK_CHECK["input"] = [X, Y]
E2ESHARK_CHECK["output"] = Z
print("Input:", X)
print("Grid:", Y)
print("Output:", Z)

0 comments on commit a6c84bd

Please sign in to comment.