Skip to content

Commit

Permalink
Add additional default decompositions for upsample operators
Browse files Browse the repository at this point in the history
Summary:
There are several core ATen ops that are not yet supported on ExecuTorch, including upsample_bilinear2d.vec and upsample_nearest2d.vec. These ops are currently not decomposed by default with PyTorch export default decompositions, but should be. Existing ET consumers rely on this behavior, so we need to preserve it until we have upsample kernels ready.

This change allows ET to opt-into decomposing these ops, regardless of the PyTorch default export decomposition table. This will unblock updating PyTorch with the correct behavior (see pytorch/pytorch#116684).

Once the upsample kernels land in ET, we can remove these decompositions. This is currently blocked by pin bumps, which may take a while to resolve.

Differential Revision: D67443180
  • Loading branch information
GregoryComer committed Dec 19, 2024
1 parent 2ed5ce3 commit e61c9b5
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 15 deletions.
57 changes: 43 additions & 14 deletions exir/program/test/test_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import copy
import unittest
from collections.abc import Iterable
from typing import Any, Dict

import torch
Expand All @@ -21,6 +22,7 @@
from executorch.exir.lowered_backend_module import get_lowered_submodules
from executorch.exir.pass_base import ExportPass
from executorch.exir.passes import MemoryPlanningPass
from executorch.exir.passes.replace_aten_with_edge_pass import aten_to_edge
from executorch.exir.program._program import (
EdgeProgramManager,
ExecutorchProgramManager,
Expand All @@ -41,6 +43,16 @@
from torch.nn import functional as F


def count_nodes(graph_module, target):
targets = target if isinstance(target, Iterable) else [target]

count = 0
for node in graph_module.graph.nodes:
if node.op == "call_function" and node.target in targets:
count += 1
return count


class TestLinear(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -662,13 +674,6 @@ def _get_random_inputs(cls):
partitioner=[NonDecompTestPartitioner()],
)

def count_nodes(graph_module, target):
count = 0
for node in graph_module.graph.nodes:
if node.op == "call_function" and node.target == target:
count += 1
return count

# There should be 1 call_delegate node and 1 node for aten.mm.default for the
# linear that doesn't have a bias which was decomposed as the partitioner
# said this node wasn't supported.
Expand Down Expand Up @@ -723,13 +728,6 @@ def _test_to_edge_with_preserved_ops(
):
edge = to_edge_with_preserved_ops(program, preserve_ops=preserved_ops)

def count_nodes(graph_module, target):
count = 0
for node in graph_module.graph.nodes:
if node.op == "call_function" and node.target in target:
count += 1
return count

aten_ops_non_decomposed = count_nodes(
program.graph_module,
preserved_ops,
Expand Down Expand Up @@ -811,3 +809,34 @@ def test_save_fails(self):
et = edge.to_executorch()
with self.assertRaises(ValueError):
_ = et.save("/tmp/test_save.pt")

def test_additional_decomposed_ops(self):
"""
Validate that EXECUTORCH_ADDITIONAL_DECOMPOSED_OPS are decomposed.
"""

class TestModel(torch.nn.Module):
def forward(self, x):
y = torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest")
y = torch.nn.functional.interpolate(y, scale_factor=2, mode="bilinear")
return y

test_ops = [
torch.ops.aten.upsample_bilinear2d.vec,
torch.ops.aten.upsample_nearest2d.vec,
]
inputs = (torch.randn(1, 1, 4, 4),)
program = torch.export.export(TestModel(), inputs)

for op in test_ops:
self.assertEqual(1, count_nodes(program.graph_module, op))

edge1 = to_edge(program)
edge2 = to_edge_transform_and_lower(program)

for edge in [edge1, edge2]:
for op in test_ops:
edge_op = aten_to_edge(op)
self.assertEqual(
0, count_nodes(edge.exported_program().graph_module, edge_op)
)
16 changes: 15 additions & 1 deletion exir/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,16 @@

torchdynamo_enabled = False

"""
Additional decompositions to apply by during to_edge or
to to_edge_transform_and_lower in addition to the default decompositions from
PyTorch export.
"""
EXECUTORCH_ADDITIONAL_DECOMPOSITIONS = [
torch.ops.aten.upsample_bilinear2d.vec,
torch.ops.aten.upsample_nearest2d.vec,
]


def get_stacktrace() -> List[Dict[str, str]]:
"""
Expand Down Expand Up @@ -631,8 +641,12 @@ def _default_decomposition_table(
]
# pyre-fixme[7]: Expected `Dict[OpOverload, typing.Callable[..., executorch.e...
return get_decompositions(decomp_opset)

# pyre-fixme[7]: Expected `Dict[OpOverload, typing.Callable[..., executorch.exir....
return default_decompositions()
table = default_decompositions()
additional_decompositions = get_decompositions(EXECUTORCH_ADDITIONAL_DECOMPOSITIONS)
table.decomp_table.update(additional_decompositions)
return table


def dynamo_trace(
Expand Down

0 comments on commit e61c9b5

Please sign in to comment.