Target Independent Inference Optimizations #2525

Open
Wheest opened this issue Sep 4, 2024 · 0 comments

Wheest commented Sep 4, 2024

I've been working with StableHLO to import graphs from PyTorch for inference. I'm trying to understand which target-independent, graph-level optimizations are available (and which are currently missing), and how I can apply them.
Essentially, I'm looking for something similar to onnx-simplifier, or TVM's PassManager at O3.
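For concreteness, here is a toy version of the kind of rewrites I mean: constant folding followed by dead-code elimination, two of the simplifications onnx-simplifier performs. The graph representation (`Op`, `Graph`) and the pass functions are hypothetical and are not StableHLO or MLIR APIs. In MLIR they would more likely be rewrite patterns scheduled by a pass manager.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional


@dataclass
class Op:
    """One node of a hypothetical toy IR (not StableHLO's real data structures)."""
    result: str                      # name of the value this op defines
    kind: str                        # "input", "const", "add" or "mul"
    operands: List[str] = field(default_factory=list)
    value: Optional[float] = None    # payload for "const" ops


@dataclass
class Graph:
    ops: List[Op]                    # ops in topological order
    outputs: List[str]               # values the graph returns


FOLDABLE = {"add": lambda a, b: a + b, "mul": lambda a, b: a * b}


def constant_fold(g: Graph) -> Graph:
    """Replace ops whose operands are all constants by a single constant."""
    consts: Dict[str, float] = {}
    new_ops = []
    for op in g.ops:
        if op.kind == "const":
            consts[op.result] = op.value
            new_ops.append(op)
        elif op.kind in FOLDABLE and all(o in consts for o in op.operands):
            folded = FOLDABLE[op.kind](*(consts[o] for o in op.operands))
            consts[op.result] = folded
            new_ops.append(Op(op.result, "const", [], folded))
        else:
            new_ops.append(op)
    return Graph(new_ops, g.outputs)


def dead_code_elimination(g: Graph) -> Graph:
    """Drop ops whose results never reach an output (graph inputs are kept)."""
    live = set(g.outputs)
    kept = []
    for op in reversed(g.ops):
        if op.result in live or op.kind == "input":
            kept.append(op)
            live.update(op.operands)
    return Graph(list(reversed(kept)), g.outputs)


def run_pipeline(g: Graph, passes: List[Callable[[Graph], Graph]]) -> Graph:
    """Apply target-independent passes in order, like a fixed optimization level."""
    for p in passes:
        g = p(g)
    return g


g = Graph(
    ops=[
        Op("x", "input"),
        Op("c1", "const", value=2.0),
        Op("c2", "const", value=3.0),
        Op("s", "mul", ["c1", "c2"]),  # foldable: both operands are constants
        Op("y", "add", ["x", "s"]),    # not foldable: depends on the input x
    ],
    outputs=["y"],
)
opt = run_pipeline(g, [constant_fold, dead_code_elimination])
print([(op.result, op.kind) for op in opt.ops])  # [('x', 'input'), ('s', 'const'), ('y', 'add')]
```

Neither pass knows anything about the hardware, which is what makes them target-independent. Folding exposes dead constants, and DCE then removes them.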

One optimization I'm looking at: for inference, batch normalization can in most cases be folded into the weights and bias of the preceding layer.
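For reference, this is the folding I have in mind, written as a minimal pure-Python sketch. It assumes inference-mode batch norm with fixed running mean and variance, and it uses a dense layer as a stand-in for the convolution. For a conv, `W_c` would be the whole kernel of output channel `c`, and the per-channel scaling is the same. The helper names are illustrative and do not come from any library.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]  # weight[out_channel][in_feature]
Vector = List[float]


def fold_batchnorm(weight: Matrix, bias: Vector, gamma: Vector, beta: Vector,
                   mean: Vector, var: Vector, eps: float = 1e-5) -> Tuple[Matrix, Vector]:
    """Fold inference-mode BN (fixed running stats) into the preceding layer.

    For each output channel c:
        scale_c = gamma_c / sqrt(var_c + eps)
        W'_c    = scale_c * W_c
        b'_c    = scale_c * (b_c - mean_c) + beta_c
    """
    new_w, new_b = [], []
    for c, row in enumerate(weight):
        scale = gamma[c] / math.sqrt(var[c] + eps)
        new_w.append([scale * w for w in row])
        new_b.append(scale * (bias[c] - mean[c]) + beta[c])
    return new_w, new_b


def linear(weight: Matrix, bias: Vector, x: Vector) -> Vector:
    """y_c = W_c . x + b_c (one conv output channel at one position looks the same)."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def batchnorm_inference(y: Vector, gamma: Vector, beta: Vector,
                        mean: Vector, var: Vector, eps: float = 1e-5) -> Vector:
    return [g * (yi - m) / math.sqrt(v + eps) + bt
            for yi, g, bt, m, v in zip(y, gamma, beta, mean, var)]


# Small made-up example: 2 output channels, 3 input features.
W = [[0.5, -1.0, 2.0], [1.5, 0.25, -0.75]]
b = [0.1, -0.2]
gamma, beta = [1.2, 0.8], [0.05, -0.3]
mean, var = [0.4, -0.1], [0.9, 2.5]
x = [1.0, 2.0, -3.0]

reference = batchnorm_inference(linear(W, b, x), gamma, beta, mean, var)
W_f, b_f = fold_batchnorm(W, b, gamma, beta, mean, var)
folded = linear(W_f, b_f, x)  # one layer instead of two, same result up to rounding
assert all(math.isclose(r, f, rel_tol=1e-9) for r, f in zip(reference, folded))
```

Because the folded weights depend only on constants, a graph-level pass could in principle compute them once at compile time.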

However, when I export a graph from PyTorch, the batch normalization layer is still present. Is there a way to apply this optimization in StableHLO, or is it planned for the future?

For example, take this one-layer PyTorch model, exported to StableHLO with the text of its `forward` function written to a file:

```python
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.export import export
from torch_xla.stablehlo import exported_program_to_stablehlo


class WeeNetBN(nn.Module):
    """Takes 4D input, applies a convolution, a batch normalization layer, and ReLU."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 6, 5)  # 3 -> 6 channels, 5x5 kernel
        self.bn = nn.BatchNorm2d(6)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = F.relu(x)
        return x


model = WeeNetBN()

dummy_input = torch.randn(1, 3, 32, 32)
output = model(dummy_input)  # eager run as a sanity check
sample_input = (dummy_input,)
exported = export(model, sample_input)

# Lower the exported program to StableHLO and dump the text of `forward`.
stablehlo_program = exported_program_to_stablehlo(exported)
os.makedirs("examples", exist_ok=True)  # make sure the output directory exists
with open("examples/stablehlo_weenet_bn.mlir", "w") as f:
    f.write(stablehlo_program.get_stablehlo_text("forward"))
```

The resulting MLIR still contains a batch norm op:

```mlir
%output, %batch_mean, %batch_var = "stablehlo.batch_norm_training"(%2, %arg3, %arg2) <{epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64}> : (tensor<1x6x28x28xf32>, tensor<6xf32>, tensor<6xf32>) -> (tensor<1x6x28x28xf32>, tensor<6xf32>, tensor<6xf32>)
```

Using existing passes doesn't seem to remove this op.
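One caveat on the op above: `stablehlo.batch_norm_training` normalizes with statistics computed from the current batch (mean and variance over every dimension except `feature_index`) and returns them as `batch_mean` and `batch_var`. Only `stablehlo.batch_norm_inference`, which takes a precomputed mean and variance, has the fixed per-channel affine form that can be folded. The model above is exported without calling `model.eval()`, which is presumably why the training variant appears. The pure-Python sketch below mirrors those semantics for a tiny NCHW tensor (`feature_index = 1`). The function name and nested-list layout are illustrative assumptions, not StableHLO's reference interpreter.

```python
import math
from typing import List

# x[n][c][i]: batch index n, channel c (feature_index = 1), flattened H*W position i.
Tensor = List[List[List[float]]]


def batch_norm_training(x: Tensor, scale: List[float], offset: List[float], eps: float):
    """Illustrative mirror of stablehlo.batch_norm_training for a feature_index=1 layout."""
    n_batch, n_chan = len(x), len(x[0])
    batch_mean, batch_var = [], []
    for c in range(n_chan):
        # Statistics come from the data itself: reduce over the batch and spatial dims.
        vals = [v for n in range(n_batch) for v in x[n][c]]
        m = sum(vals) / len(vals)
        batch_mean.append(m)
        batch_var.append(sum((v - m) ** 2 for v in vals) / len(vals))  # population variance
    out = [
        [
            [scale[c] * (v - batch_mean[c]) / math.sqrt(batch_var[c] + eps) + offset[c]
             for v in x[n][c]]
            for c in range(n_chan)
        ]
        for n in range(n_batch)
    ]
    return out, batch_mean, batch_var


def flatten(t: Tensor) -> List[float]:
    return [v for sample in t for ch in sample for v in ch]


x = [[[1.0, 2.0], [0.5, -0.5]], [[3.0, 4.0], [1.5, 2.5]]]  # N=2, C=2, H*W=2
scale, offset, eps = [1.2, 0.8], [0.05, -0.3], 1e-5

out, batch_mean, batch_var = batch_norm_training(x, scale, offset, eps)
shifted = [[[v + 10.0 for v in ch] for ch in sample] for sample in x]
out_shifted, _, _ = batch_norm_training(shifted, scale, offset, eps)

print(batch_mean, batch_var)  # [2.5, 1.0] [1.25, 1.25]
# Shifting the input leaves the output unchanged because the batch mean moves with it.
# A fixed scale and bias baked into the conv weights could not reproduce that.
assert all(math.isclose(a, b, abs_tol=1e-9) for a, b in zip(flatten(out), flatten(out_shifted)))
```

So folding only becomes possible once the graph uses fixed (running) statistics.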

abhigunj self-assigned this Sep 5, 2024