Skip to content

Commit

Permalink
Fix quant type handling in stablehlo-legalize-quant-to-int pass (#2385)
Browse files Browse the repository at this point in the history
Originally brought up at
#2383 (comment)

The `stablehlo-legalize-quant-to-int` pass propagated from mho (#2383 )
uses std::variant to handle per-tensor and per-axis quantized type.
However, the base class quant::QuantizedType provides a better
polymorphic way to accomplish the same.

## Direction to reviewers

Please review the change only in the following commit:
ee8b499

Reason: This PR is based on
#2383 which is not merged yet
and hence has many files which are already reviewed as part of #2383.
  • Loading branch information
sdasgup3 authored Jun 10, 2024
1 parent 6ebc59d commit cca7073
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 161 deletions.
8 changes: 4 additions & 4 deletions stablehlo/tests/stablehlo_legalize_quant_to_int.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ func.func @add_per_channel_i8(
%arg0: tensor<?x3x4x2x!quant.uniform<i8:f32:3, {2.9455460163317514E-5,5.8952903030815205E-5}>>,
%arg1: tensor<?x3x4x2x!quant.uniform<i8:f32:3, {2.9455460163317514E-5,5.8952903030815205E-5}>>
) -> tensor<?x3x4x2x!quant.uniform<i8:f32:3, {2.9455460163317514E-5,5.8952903030815205E-5}>> {
// expected-error@+2 {{Per-channel quantized AddOp requires i32 storage type}}
// expected-error@+2 {{Per-axis quantized AddOp requires i32 storage type}}
// expected-error@+1 {{failed to legalize operation 'stablehlo.add' that was explicitly marked illegal}}
%11 = stablehlo.add %arg0, %arg1 : tensor<?x3x4x2x!quant.uniform<i8:f32:3, {2.9455460163317514E-5,5.8952903030815205E-5}>>
return %11 : tensor<?x3x4x2x!quant.uniform<i8:f32:3, {2.9455460163317514E-5,5.8952903030815205E-5}>>
Expand All @@ -350,7 +350,7 @@ func.func @add_per_channel_different_quant_types(
%arg0: tensor<?x3x4x2x!quant.uniform<i32:f32:3, {2.9455460163317514E-5,5.8952903030815205E-5}>>,
%arg1: tensor<?x3x4x2x!quant.uniform<i32:f32:3, {1.1:2,0.4:-3}>>
) -> tensor<?x3x4x2x!quant.uniform<i32:f32:3, {2.9455460163317514E-5,5.8952903030815205E-5}>> {
// expected-error@+2 {{Per-channel quantized AddOp requires the same quantized element type for all operands and results}}
// expected-error@+2 {{Per-axis quantized AddOp requires the same quantized element type for all operands and results}}
// expected-error@+1 {{failed to legalize operation 'stablehlo.add' that was explicitly marked illegal}}
%11 = stablehlo.add %arg0, %arg1 : (
tensor<?x3x4x2x!quant.uniform<i32:f32:3, {2.9455460163317514E-5,5.8952903030815205E-5}>>,
Expand All @@ -365,7 +365,7 @@ func.func @add_per_channel_per_tensor_mix(
%arg0: tensor<?x3x4x2x!quant.uniform<i32:f32:3, {2.9455460163317514E-5,5.8952903030815205E-5}>>,
%arg1: tensor<?x3x4x2x!quant.uniform<i32:f32, 1.1:2>>
) -> tensor<?x3x4x2x!quant.uniform<i32:f32:3, {2.9455460163317514E-5,5.8952903030815205E-5}>> {
// expected-error@+2 {{Per-channel quantized AddOp requires the same quantized element type for all operands and results}}
// expected-error@+2 {{Per-axis quantized AddOp requires the same quantized element type for all operands and results}}
// expected-error@+1 {{failed to legalize operation 'stablehlo.add' that was explicitly marked illegal}}
%11 = stablehlo.add %arg0, %arg1 : (
tensor<?x3x4x2x!quant.uniform<i32:f32:3, {2.9455460163317514E-5,5.8952903030815205E-5}>>,
Expand Down Expand Up @@ -1624,7 +1624,7 @@ func.func @conv2d_per_channel_rhs_result_scale_ratio_different(
%arg0: tensor<128x28x28x1x!quant.uniform<i8:f32, 2.000000e+00:4>>,
%arg1: tensor<3x3x1x2x!quant.uniform<i8:f32:3, {2.000000e+00:0, 1.000000e+00:0}>>
) -> tensor<128x26x26x2x!quant.uniform<i32:f32:3, {4.000000e+00:0, 2.200000e+00:0}>> {
// expected-error@+2 {{Per-channel quantizated Conv must have same RHS/Result scale ratio for each channel}}
// expected-error@+2 {{Per-axis quantizated Conv must have same RHS/Result scale ratio for each channel}}
// expected-error@+1 {{failed to legalize operation 'stablehlo.convolution' that was explicitly marked illegal}}
%0 = stablehlo.convolution(%arg0, %arg1)
dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
Expand Down
Loading

0 comments on commit cca7073

Please sign in to comment.