From cca707378f03c3646d075907371cdb6a2806258f Mon Sep 17 00:00:00 2001
From: Sandeep Dasgupta
Date: Mon, 10 Jun 2024 23:31:39 +0000
Subject: [PATCH] Fix quant type handling in stablehlo-legalize-quant-to-int
 pass (#2385)

Originally brought up at
https://github.com/openxla/stablehlo/pull/2383#discussion_r1630510134

The `stablehlo-legalize-quant-to-int` pass ported from mhlo (#2383) uses
std::variant to handle per-tensor and per-axis quantized types. However,
the base class quant::QuantizedType provides a cleaner polymorphic way to
accomplish the same (see the illustrative sketch after the diff).

## Direction to reviewers

Please review the change only in the following commit:
https://github.com/openxla/stablehlo/pull/2385/commits/ee8b499fa691bbb7b6b3948bdc581dce05abaa39

Reason: this PR is based on https://github.com/openxla/stablehlo/pull/2383,
which is not merged yet and therefore touches many files that were already
reviewed as part of #2383.
---
 .../stablehlo_legalize_quant_to_int.mlir |   8 +-
 .../StablehloLegalizeQuantToInt.cpp      | 265 +++++++-----------
 2 files changed, 112 insertions(+), 161 deletions(-)

diff --git a/stablehlo/tests/stablehlo_legalize_quant_to_int.mlir b/stablehlo/tests/stablehlo_legalize_quant_to_int.mlir
index fd303b578ec..b865f4bf64d 100644
--- a/stablehlo/tests/stablehlo_legalize_quant_to_int.mlir
+++ b/stablehlo/tests/stablehlo_legalize_quant_to_int.mlir
@@ -338,7 +338,7 @@ func.func @add_per_channel_i8(
   %arg0: tensor>,
   %arg1: tensor>
 ) -> tensor> {
-  // expected-error@+2 {{Per-channel quantized AddOp requires i32 storage type}}
+  // expected-error@+2 {{Per-axis quantized AddOp requires i32 storage type}}
   // expected-error@+1 {{failed to legalize operation 'stablehlo.add' that was explicitly marked illegal}}
   %11 = stablehlo.add %arg0, %arg1 : tensor>
   return %11 : tensor>
 }
@@ -350,7 +350,7 @@ func.func @add_per_channel_different_quant_types(
   %arg0: tensor>,
   %arg1: tensor>
 ) -> tensor> {
-  // expected-error@+2 {{Per-channel quantized AddOp requires the same quantized element type for all operands and results}}
+  // expected-error@+2 {{Per-axis quantized AddOp requires the same quantized element type for all operands and results}}
   // expected-error@+1 {{failed to legalize operation 'stablehlo.add' that was explicitly marked illegal}}
   %11 = stablehlo.add %arg0, %arg1 : (
     tensor>,
     tensor>
   ) -> tensor>
@@ -365,7 +365,7 @@ func.func @add_per_channel_per_tensor_mix(
   %arg0: tensor>,
   %arg1: tensor>
 ) -> tensor> {
-  // expected-error@+2 {{Per-channel quantized AddOp requires the same quantized element type for all operands and results}}
+  // expected-error@+2 {{Per-axis quantized AddOp requires the same quantized element type for all operands and results}}
   // expected-error@+1 {{failed to legalize operation 'stablehlo.add' that was explicitly marked illegal}}
   %11 = stablehlo.add %arg0, %arg1 : (
     tensor>,
     tensor>
   ) -> tensor>
@@ -1624,7 +1624,7 @@ func.func @conv2d_per_channel_rhs_result_scale_ratio_different(
   %arg0: tensor<128x28x28x1x!quant.uniform>,
   %arg1: tensor<3x3x1x2x!quant.uniform>
 ) -> tensor<128x26x26x2x!quant.uniform> {
-  // expected-error@+2 {{Per-channel quantizated Conv must have same RHS/Result scale ratio for each channel}}
+  // expected-error@+2 {{Per-axis quantized Conv must have same RHS/Result scale ratio for each channel}}
   // expected-error@+1 {{failed to legalize operation 'stablehlo.convolution' that was explicitly marked illegal}}
   %0 = stablehlo.convolution(%arg0, %arg1)
     dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
diff --git a/stablehlo/transforms/StablehloLegalizeQuantToInt.cpp b/stablehlo/transforms/StablehloLegalizeQuantToInt.cpp
index 0888c0556a4..fd9cec03de7 100644
--- a/stablehlo/transforms/StablehloLegalizeQuantToInt.cpp
+++ b/stablehlo/transforms/StablehloLegalizeQuantToInt.cpp
@@ -48,36 +48,28 @@ limitations under the License.
 namespace mlir::stablehlo {
 namespace {

-// TODO: b/311218165 - consider extract this to common utils and better ways to
-// handle polymorphism.
-using QuantType = std::variant<quant::UniformQuantizedType,
-                               quant::UniformQuantizedPerAxisType>;
+using QuantType = quant::QuantizedType;

 FailureOr<QuantType> getQuantType(Type type) {
   if (auto quantType =
-          dyn_cast<quant::UniformQuantizedType>(getElementTypeOrSelf(type))) {
-    return QuantType(quantType);
-  }
-  if (auto quantType = dyn_cast<quant::UniformQuantizedPerAxisType>(
-          getElementTypeOrSelf(type))) {
-    return QuantType(quantType);
-  }
+          dyn_cast<quant::QuantizedType>(getElementTypeOrSelf(type)))
+    return quantType;
   return failure();
 }

-bool isPerTensorType(QuantType quantType) {
-  return std::holds_alternative<quant::UniformQuantizedType>(quantType);
+bool isPerTensorType(Type type) {
+  return isa<quant::UniformQuantizedType>(getElementTypeOrSelf(type));
 }

-bool isPerChannelType(QuantType quantType) {
-  return std::holds_alternative<quant::UniformQuantizedPerAxisType>(quantType);
+bool isPerAxisType(Type type) {
+  return isa<quant::UniformQuantizedPerAxisType>(getElementTypeOrSelf(type));
 }

-quant::UniformQuantizedType getPerTensorType(QuantType quantType) {
-  return std::get<quant::UniformQuantizedType>(quantType);
+quant::UniformQuantizedType getPerTensorType(Type type) {
+  return cast<quant::UniformQuantizedType>(getElementTypeOrSelf(type));
 }

-quant::UniformQuantizedPerAxisType getPerChannelType(QuantType quantType) {
-  return std::get<quant::UniformQuantizedPerAxisType>(quantType);
+quant::UniformQuantizedPerAxisType getPerAxisType(Type type) {
+  return cast<quant::UniformQuantizedPerAxisType>(getElementTypeOrSelf(type));
 }

 // Extracts scale and zero point info from input quant type info.
@@ -86,58 +78,56 @@ void getQuantizationParams(OpBuilder &builder, Location loc,
                            Value &zeroPoints, bool outputZeroPointInFp,
                            DenseI64ArrayAttr &broadcastDims) {
   // Get scales/zero points for per-tensor and per-axis quantization cases.
-  if (auto *quantPerTensorType =
-          std::get_if<quant::UniformQuantizedType>(&quantType)) {
+  if (auto quantPerTensorType =
+          dyn_cast<quant::UniformQuantizedType>(quantType)) {
     scales = builder.create<stablehlo::ConstantOp>(
-        loc, builder.getF32FloatAttr(quantPerTensorType->getScale()));
+        loc, builder.getF32FloatAttr(quantPerTensorType.getScale()));
     if (outputZeroPointInFp) {
       zeroPoints = builder.create<stablehlo::ConstantOp>(
           loc, builder.getF32FloatAttr(
-                   static_cast<float>(quantPerTensorType->getZeroPoint())));
+                   static_cast<float>(quantPerTensorType.getZeroPoint())));
     } else {
       zeroPoints = builder.create<stablehlo::ConstantOp>(
           loc, builder.getI32IntegerAttr(
-                   static_cast<int32_t>(quantPerTensorType->getZeroPoint())));
+                   static_cast<int32_t>(quantPerTensorType.getZeroPoint())));
     }
   } else {
-    auto &quantPerChannelType =
-        std::get<quant::UniformQuantizedPerAxisType>(quantType);
+    auto quantPerAxisType = getPerAxisType(quantType);
     SmallVector<float> scalesVec;
-    for (auto scale : quantPerChannelType.getScales())
-      scalesVec.push_back(scale);
+    for (auto scale : quantPerAxisType.getScales()) scalesVec.push_back(scale);
     scales = builder.create<stablehlo::ConstantOp>(
         loc, DenseFPElementsAttr::get(
                  RankedTensorType::get(
-                     {static_cast<int64_t>(quantPerChannelType.getScales().size())},
+                     {static_cast<int64_t>(quantPerAxisType.getScales().size())},
                      builder.getF32Type()),
                  scalesVec));
     if (outputZeroPointInFp) {
       SmallVector<float> zeroPointsVec;
-      for (auto zeroPoint : quantPerChannelType.getZeroPoints())
+      for (auto zeroPoint : quantPerAxisType.getZeroPoints())
         zeroPointsVec.push_back(zeroPoint);
       zeroPoints = builder.create<stablehlo::ConstantOp>(
           loc, DenseFPElementsAttr::get(
                    RankedTensorType::get(
                        {static_cast<int64_t>(
-                            quantPerChannelType.getZeroPoints().size())},
+                            quantPerAxisType.getZeroPoints().size())},
                        builder.getF32Type()),
                    zeroPointsVec));
     } else {
       SmallVector<int32_t> zeroPointsVec;
-      for (auto zeroPoint : quantPerChannelType.getZeroPoints())
+      for (auto zeroPoint : quantPerAxisType.getZeroPoints())
         zeroPointsVec.push_back(zeroPoint);
       zeroPoints = builder.create<stablehlo::ConstantOp>(
           loc, DenseIntElementsAttr::get(
                    RankedTensorType::get(
                        {static_cast<int64_t>(
-                            quantPerChannelType.getZeroPoints().size())},
+                            quantPerAxisType.getZeroPoints().size())},
                        builder.getI32Type()),
                    zeroPointsVec));
     }
     broadcastDims = DenseI64ArrayAttr::get(
         builder.getContext(),
-        {static_cast<int64_t>(quantPerChannelType.getQuantizedDimension())});
+        {static_cast<int64_t>(quantPerAxisType.getQuantizedDimension())});
   }
 }

@@ -145,50 +135,29 @@ void getQuantizationParams(OpBuilder &builder, Location loc,
 void getQuantizationStorageInfo(OpBuilder &builder, Location loc,
                                 QuantType quantType, Value &storageMin,
                                 Value &storageMax) {
-  if (auto *quantPerTensorType =
-          std::get_if<quant::UniformQuantizedType>(&quantType)) {
-    storageMin = builder.create<stablehlo::ConstantOp>(
-        loc, builder.getF32FloatAttr(
-                 static_cast<float>(quantPerTensorType->getStorageTypeMin())));
-    storageMax = builder.create<stablehlo::ConstantOp>(
-        loc, builder.getF32FloatAttr(
-                 static_cast<float>(quantPerTensorType->getStorageTypeMax())));
-  } else {
-    auto &quantPerChannelType =
-        std::get<quant::UniformQuantizedPerAxisType>(quantType);
-    storageMin = builder.create<stablehlo::ConstantOp>(
-        loc, builder.getF32FloatAttr(
-                 static_cast<float>(quantPerChannelType.getStorageTypeMin())));
-    storageMax = builder.create<stablehlo::ConstantOp>(
-        loc, builder.getF32FloatAttr(
-                 static_cast<float>(quantPerChannelType.getStorageTypeMax())));
-  }
+  storageMin = builder.create<stablehlo::ConstantOp>(
+      loc, builder.getF32FloatAttr(
+               static_cast<float>(quantType.getStorageTypeMin())));
+  storageMax = builder.create<stablehlo::ConstantOp>(
+      loc, builder.getF32FloatAttr(
+               static_cast<float>(quantType.getStorageTypeMax())));
 }

+Type getQuantStorageType(QuantType type) { return type.getStorageType(); }
+
 // Extracts the storage type of a UQ type. Returns the original type if it is
 // not a UQ type.
 Type getQuantStorageType(Type type) {
   if (auto shaped = dyn_cast<ShapedType>(type)) {
     return shaped.clone(getQuantStorageType(shaped.getElementType()));
   }

-  if (auto elementType =
-          dyn_cast<quant::UniformQuantizedType>(getElementTypeOrSelf(type))) {
-    return elementType.getStorageType();
-  }
-  if (auto elementType = dyn_cast<quant::UniformQuantizedPerAxisType>(
-          getElementTypeOrSelf(type))) {
-    return elementType.getStorageType();
+  auto quantizedType = getQuantType(type);
+  if (succeeded(quantizedType)) {
+    return getQuantStorageType(*quantizedType);
   }
   return type;
 }

-Type getQuantStorageType(QuantType type) {
-  if (isPerTensorType(type)) {
-    return getPerTensorType(type).getStorageType();
-  }
-  return getPerChannelType(type).getStorageType();
-}
-
 Value applyMergedScalesAndZps(OpBuilder &builder, Location loc,
                               QuantType inputQuantType,
                               QuantType outputQuantType,
@@ -218,30 +187,30 @@ Value applyMergedScalesAndZps(OpBuilder &builder, Location loc,
     }
   } else {
     int64_t channelSize =
-        isPerChannelType(outputQuantType)
-            ? getPerChannelType(outputQuantType).getScales().size()
-            : getPerChannelType(inputQuantType).getScales().size();
+        isPerAxisType(outputQuantType)
+            ? getPerAxisType(outputQuantType).getScales().size()
+            : getPerAxisType(inputQuantType).getScales().size();
     int64_t quantizedDimension =
-        isPerChannelType(outputQuantType)
-            ? getPerChannelType(outputQuantType).getQuantizedDimension()
-            : getPerChannelType(inputQuantType).getQuantizedDimension();
+        isPerAxisType(outputQuantType)
+            ? getPerAxisType(outputQuantType).getQuantizedDimension()
+            : getPerAxisType(inputQuantType).getQuantizedDimension();
     SmallVector<double> mergedScaleDouble, mergedZpDouble;
     mergedScaleDouble.resize(channelSize);
     mergedZpDouble.resize(channelSize);
     for (int i = 0; i < channelSize; ++i) {
       mergedScaleDouble[i] =
-          (isPerChannelType(inputQuantType)
-               ? getPerChannelType(inputQuantType).getScales()[i]
+          (isPerAxisType(inputQuantType)
+               ? getPerAxisType(inputQuantType).getScales()[i]
                : getPerTensorType(inputQuantType).getScale()) /
-          (isPerChannelType(outputQuantType)
-               ? getPerChannelType(outputQuantType).getScales()[i]
+          (isPerAxisType(outputQuantType)
+               ? getPerAxisType(outputQuantType).getScales()[i]
                : getPerTensorType(outputQuantType).getScale());
       mergedZpDouble[i] =
-          (isPerChannelType(outputQuantType)
-               ? getPerChannelType(outputQuantType).getZeroPoints()[i]
+          (isPerAxisType(outputQuantType)
+               ? getPerAxisType(outputQuantType).getZeroPoints()[i]
                : getPerTensorType(outputQuantType).getZeroPoint()) -
-          (isPerChannelType(inputQuantType)
-               ? getPerChannelType(inputQuantType).getZeroPoints()[i]
+          (isPerAxisType(inputQuantType)
+               ? getPerAxisType(inputQuantType).getZeroPoints()[i]
                : getPerTensorType(inputQuantType).getZeroPoint()) *
               mergedScaleDouble[i];
     }
@@ -271,7 +240,7 @@ Value applyMergedScalesAndZps(OpBuilder &builder, Location loc,

 // This helper function creates ops to requantize `input` tensor and returns the
 // output tensor. Clamping is done if output integer bit-width < i32. It assumes
-// that if both input and output tensor are per-channel quantized, they have the
+// that if both input and output tensor are per-axis quantized, they have the
 // same quantization axis.
 //
 // Requantization is essentially dequantize --> quantize.
@@ -339,10 +308,9 @@ class ConvertUniformQuantizeOp
     auto inputQuantType = getQuantType(inputElementType);
     auto outputQuantType = getQuantType(op.getResult().getType());
     if (succeeded(inputQuantType) && succeeded(outputQuantType)) {
-      if (isPerChannelType(*inputQuantType) &&
-          isPerChannelType(*outputQuantType) &&
-          getPerChannelType(*inputQuantType).getQuantizedDimension() !=
-              getPerChannelType(*outputQuantType).getQuantizedDimension()) {
+      if (isPerAxisType(*inputQuantType) && isPerAxisType(*outputQuantType) &&
+          getPerAxisType(*inputQuantType).getQuantizedDimension() !=
+              getPerAxisType(*outputQuantType).getQuantizedDimension()) {
         op->emitError("Cannot requantize while changing quantization_axis");
         return failure();
       }
@@ -464,30 +432,27 @@ class ConvertUniformQuantizedAddOp
       return failure();
     }

-    if (isPerChannelType(*lhsQuantType) || isPerChannelType(*rhsQuantType) ||
-        isPerChannelType(*resQuantType)) {
-      // Handle Per-Channel Quantized Types. We only support lhs/rhs/result with
-      // exact same per-channel quantized types with I32 storage type.
-      if (!isPerChannelType(*lhsQuantType) ||
-          !isPerChannelType(*rhsQuantType) ||
-          !isPerChannelType(*resQuantType) ||
-          getPerChannelType(*lhsQuantType) !=
-              getPerChannelType(*rhsQuantType) ||
-          getPerChannelType(*lhsQuantType) !=
-              getPerChannelType(*resQuantType)) {
+    if (isPerAxisType(*lhsQuantType) || isPerAxisType(*rhsQuantType) ||
+        isPerAxisType(*resQuantType)) {
+      // Handle Per-Axis Quantized Types. We only support lhs/rhs/result with
+      // exact same per-axis quantized types with I32 storage type.
+      if (!isPerAxisType(*lhsQuantType) || !isPerAxisType(*rhsQuantType) ||
+          !isPerAxisType(*resQuantType) ||
+          getPerAxisType(*lhsQuantType) != getPerAxisType(*rhsQuantType) ||
+          getPerAxisType(*lhsQuantType) != getPerAxisType(*resQuantType)) {
         op->emitError(
-            "Per-channel quantized AddOp requires the same quantized element "
+            "Per-axis quantized AddOp requires the same quantized element "
            "type for all operands and results");
         return failure();
       }
-      if (!getPerChannelType(*lhsQuantType).getStorageType().isInteger(32)) {
+      if (!getPerAxisType(*lhsQuantType).getStorageType().isInteger(32)) {
         // For server-side StableHLO Quantization, add is quantized only when
         // fused with conv/dot ops, whose output must be i32.
-        op->emitError("Per-channel quantized AddOp requires i32 storage type");
+        op->emitError("Per-axis quantized AddOp requires i32 storage type");
         return failure();
       }
-      return matchAndRewritePerChannel(op, adaptor, rewriter,
-                                       getPerChannelType(*lhsQuantType));
+      return matchAndRewritePerAxis(op, adaptor, rewriter,
+                                    getPerAxisType(*lhsQuantType));
     }

     // TODO: b/260280919 - Consider avoiding conversion to int32.
@@ -551,14 +516,14 @@ class ConvertUniformQuantizedAddOp
     return success();
   }

-  LogicalResult matchAndRewritePerChannel(
+  LogicalResult matchAndRewritePerAxis(
       stablehlo::AddOp op, stablehlo::AddOpAdaptor adaptor,
       ConversionPatternRewriter &rewriter,
       quant::UniformQuantizedPerAxisType quantType) const {
     // We assume lhs/rhs/result have the same quantized type with i32 storage.
     Value addResult = rewriter.create<stablehlo::AddOp>(
         op->getLoc(), adaptor.getLhs(), adaptor.getRhs());
-    // Add zp contribution if it is non-zero for any channel.
+    // Add zp contribution if it is non-zero for any axis.
     if (llvm::any_of(quantType.getZeroPoints(),
                      [](int64_t zp) { return zp != 0; })) {
       SmallVector<int64_t> zpsVec(quantType.getZeroPoints().begin(),
@@ -600,8 +565,8 @@ bool isZeroPointZero(QuantType type) {
   if (isPerTensorType(type)) {
     return getPerTensorType(type).getZeroPoint() == 0;
   }
-  if (isPerChannelType(type)) {
-    ArrayRef<int64_t> zeroPoints = getPerChannelType(type).getZeroPoints();
+  if (isPerAxisType(type)) {
+    ArrayRef<int64_t> zeroPoints = getPerAxisType(type).getZeroPoints();
     return llvm::all_of(zeroPoints, [](int64_t zp) { return zp == 0; });
   }
   return false;
 }
@@ -894,9 +859,7 @@ Value createDotLikeKernel(
       loc, DenseIntElementsAttr::get(
                RankedTensorType::get({}, builder.getI8Type()),
               {static_cast<int8_t>(
-                   cast<quant::UniformQuantizedType>(
-                       getElementTypeOrSelf(op.getLhs().getType()))
-                       .getZeroPoint())}));
+                   getPerTensorType(op.getLhs().getType()).getZeroPoint())}));
   // Convert Padding attributes from stablehlo::Convolution to stablehlo::Pad.
   // Note that Padding is applied for spatial dimensions [1...rank-1) only for
   // stablehlo::Convolution. But stablehlo::Pad requires those for all
@@ -950,36 +913,35 @@ LogicalResult matchAndRewriteDotLikeOp(DotLikeOp op, DotLikeOpAdaptor adaptor,
   Value resI32 = createDotLikeKernel(rewriter, op->getLoc(), op,
                                      resInt32TensorType, lhs, rhs, attrs);

-  auto lhsElementQuantType = cast<quant::UniformQuantizedType>(
-      getElementTypeOrSelf(op.getLhs().getType()));
+  auto lhsElementQuantType = getPerTensorType(op.getLhs().getType());
   auto rhsElementQuantType = dyn_cast<quant::UniformQuantizedType>(
       getElementTypeOrSelf(op.getRhs().getType()));
-  auto rhsElementQuantPerChannelType =
+  auto rhsElementQuantPerAxisType =
       dyn_cast<quant::UniformQuantizedPerAxisType>(
           getElementTypeOrSelf(op.getRhs().getType()));
   auto resElementQuantType = dyn_cast<quant::UniformQuantizedType>(
       getElementTypeOrSelf(op.getResult()));
-  auto resElementQuantPerChannelType =
+  auto resElementQuantPerAxisType =
       dyn_cast<quant::UniformQuantizedPerAxisType>(
           getElementTypeOrSelf(op.getResult()));

   // Here we assume LHS must be per-tensor quantized.
-  // If RHS is per-channel quantized, it must has 0 zp.
+  // If RHS is per-axis quantized, it must have 0 zp.
   Value zpOffset = calculateZeroPointOffset(
       rewriter, op->getLoc(), lhs, rhs, resI32,
       lhsElementQuantType.getZeroPoint(),
       (rhsElementQuantType ? rhsElementQuantType.getZeroPoint() : 0),
       resInt32TensorType, dims);

-  // For per-channel quantization, we assume that result scales are proportional
-  // to rhs scales for each channels.
-  double combinedScaleFp =
-      rhsElementQuantType
-          ? lhsElementQuantType.getScale() * rhsElementQuantType.getScale() /
-                resElementQuantType.getScale()
-          : lhsElementQuantType.getScale() *
-                rhsElementQuantPerChannelType.getScales()[0] /
-                resElementQuantPerChannelType.getScales()[0];
+  // For per-axis quantization, we assume that result scales are proportional
+  // to rhs scales for each channel axis.
+  double combinedScaleFp = rhsElementQuantType
+                               ? lhsElementQuantType.getScale() *
+                                     rhsElementQuantType.getScale() /
+                                     resElementQuantType.getScale()
+                               : lhsElementQuantType.getScale() *
+                                     rhsElementQuantPerAxisType.getScales()[0] /
+                                     resElementQuantPerAxisType.getScales()[0];

   // Multiply dot result and zp_offset by combined_scale only if it is not 1.0.
   if (std::abs(combinedScaleFp - 1.0) > 0.001) {
@@ -1010,7 +972,7 @@ LogicalResult matchAndRewriteDotLikeOp(DotLikeOp op, DotLikeOpAdaptor adaptor,
     }
   }

-  // If result is per-channel quantized, it must has 0 zp.
+  // If result is per-axis quantized, it must have 0 zp.
   Value combinedZp = rewriter.create<stablehlo::ConstantOp>(
       op->getLoc(),
       rewriter.getI32IntegerAttr(
@@ -1030,26 +992,25 @@ FailureOr<bool> isDotLikeOpHybrid(DotLikeOp op) {
   // Returns failure() when the type is not supported.
   bool isLhsQuant = isa<quant::UniformQuantizedType>(
       getElementTypeOrSelf(op.getLhs().getType()));
-  bool isLhsQuantPerChannel = isa<quant::UniformQuantizedPerAxisType>(
+  bool isLhsQuantPerAxis = isa<quant::UniformQuantizedPerAxisType>(
       getElementTypeOrSelf(op.getLhs().getType()));
   bool isRhsQuant = isa<quant::UniformQuantizedType>(
       getElementTypeOrSelf(op.getRhs().getType()));
-  bool isRhsQuantPerChannel = isa<quant::UniformQuantizedPerAxisType>(
+  bool isRhsQuantPerAxis = isa<quant::UniformQuantizedPerAxisType>(
       getElementTypeOrSelf(op.getRhs().getType()));
   bool isResQuant =
       isa<quant::UniformQuantizedType>(getElementTypeOrSelf(op.getResult()));
-  bool isResQuantPerChannel = isa<quant::UniformQuantizedPerAxisType>(
+  bool isResQuantPerAxis = isa<quant::UniformQuantizedPerAxisType>(
       getElementTypeOrSelf(op.getResult()));

   if (isLhsQuant && ((isRhsQuant && isResQuant) ||
-                     (isRhsQuantPerChannel && isResQuantPerChannel))) {
-    // For quantized ops, RHS and result must be both per-channel quantized or
+                     (isRhsQuantPerAxis && isResQuantPerAxis))) {
+    // For quantized ops, RHS and result must be both per-axis quantized or
     // both per-tensor quantized.
     return false;
   }
-  if (!isLhsQuant && !isLhsQuantPerChannel &&
-      (isRhsQuant || isRhsQuantPerChannel) && !isResQuant &&
-      !isResQuantPerChannel) {
+  if (!isLhsQuant && !isLhsQuantPerAxis && (isRhsQuant || isRhsQuantPerAxis) &&
+      !isResQuant && !isResQuantPerAxis) {
     return true;
   }
   op->emitError("Invalid input/output type for Dot/Convolution op");
@@ -1164,56 +1125,46 @@ bool isConvNDHWC(const stablehlo::ConvDimensionNumbersAttr &dims) {
 FailureOr<DotLikeDimensionNumbers> verifyAndConstructDims(
     stablehlo::ConvolutionOp op) {
   // RHS (weight) must have zero zp.
-  // Here assumes RHS/result must be both per-tensor or both per-channel
+  // Here assumes RHS/result must be both per-tensor or both per-axis
   // quantized.
   auto failedOr = getQuantType(op.getRhs().getType());
   if (failed(failedOr)) {
     return failure();
   }
   QuantType rhsElementQuantType = *failedOr;
-  bool isRhsQuantPerTensor =
-      std::get_if<quant::UniformQuantizedType>(&rhsElementQuantType);
+  bool isRhsQuantPerTensor = isPerTensorType(rhsElementQuantType);

   if (isRhsQuantPerTensor
-          ? (std::get<quant::UniformQuantizedType>(rhsElementQuantType)
-                 .getZeroPoint() != 0)
-          : llvm::any_of(llvm::concat<const int64_t>(
-                             std::get<quant::UniformQuantizedPerAxisType>(
-                                 rhsElementQuantType)
-                                 .getZeroPoints(),
-                             cast<quant::UniformQuantizedPerAxisType>(
-                                 getElementTypeOrSelf(op.getResult()))
-                                 .getZeroPoints()),
-                         [](int64_t zp) { return zp != 0; })) {
+          ? getPerTensorType(rhsElementQuantType).getZeroPoint() != 0
+          : llvm::any_of(
+                llvm::concat<const int64_t>(
+                    getPerAxisType(rhsElementQuantType).getZeroPoints(),
+                    getPerAxisType(op.getType()).getZeroPoints()),
+                [](int64_t zp) { return zp != 0; })) {
     op->emitError("RHS/result UQ type must have zero zp.");
     return failure();
   }
-  // For per-channel quantization, RHS quantized axis must be out channel axis.
+  // For per-axis quantization, RHS quantized axis must be out channel axis.
   if (!isRhsQuantPerTensor &&
-      (std::get<quant::UniformQuantizedPerAxisType>(rhsElementQuantType)
-           .getQuantizedDimension() !=
+      (getPerAxisType(rhsElementQuantType).getQuantizedDimension() !=
       cast<TensorType>(op.getRhs().getType()).getRank() - 1)) {
     op->emitError("Conv quantized axis must be out channel axis");
     return failure();
   }
-  // For per-channel quantization, ratio between RHS and Result scales must be
+  // For per-axis quantization, ratio between RHS and Result scales must be
   // the same for each channel.
   if (!isRhsQuantPerTensor) {
-    auto resElementQuantPerChannelType =
-        cast<quant::UniformQuantizedPerAxisType>(
-            getElementTypeOrSelf(op.getResult()));
+    auto resElementQuantPerAxisType = getPerAxisType(op.getType());
     SmallVector<double> scaleRatios(
-        resElementQuantPerChannelType.getScales().size());
+        resElementQuantPerAxisType.getScales().size());
     for (size_t i = 0; i < scaleRatios.size(); ++i) {
-      scaleRatios[i] =
-          resElementQuantPerChannelType.getScales()[i] /
-          std::get<quant::UniformQuantizedPerAxisType>(rhsElementQuantType)
-              .getScales()[i];
+      scaleRatios[i] = resElementQuantPerAxisType.getScales()[i] /
+                       getPerAxisType(rhsElementQuantType).getScales()[i];
       auto diff = (scaleRatios[i] - scaleRatios[0]) / scaleRatios[0];
       // Check all ratios within a threshold.
       if (std::abs(diff) > 0.001) {
         op->emitError(
-            "Per-channel quantizated Conv must have same RHS/Result scale "
+            "Per-axis quantized Conv must have same RHS/Result scale "
            "ratio for each channel");
         return failure();
       }
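--
Reviewer aside (illustration only, not part of the patch): a minimal sketch of
the dispatch style this PR removes versus the one it adopts, assuming only the
MLIR quant dialect types used above; the two storageType* helper names are
hypothetical.

  #include <variant>

  #include "mlir/Dialect/Quant/QuantTypes.h"

  // Before: a closed std::variant, so every accessor must dispatch explicitly
  // over the two alternatives.
  using VariantQuantType =
      std::variant<mlir::quant::UniformQuantizedType,
                   mlir::quant::UniformQuantizedPerAxisType>;

  mlir::Type storageTypeViaVariant(VariantQuantType type) {
    if (auto *perTensor =
            std::get_if<mlir::quant::UniformQuantizedType>(&type))
      return perTensor->getStorageType();
    return std::get<mlir::quant::UniformQuantizedPerAxisType>(type)
        .getStorageType();
  }

  // After: both uniform quantized types derive from quant::QuantizedType, so
  // the shared API is available with no per-alternative dispatch, mirroring
  // the one-line getQuantStorageType(QuantType) in the patch.
  mlir::Type storageTypeViaBaseClass(mlir::quant::QuantizedType type) {
    return type.getStorageType();
  }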