Cleanup: Remove duplicate quantization checks (#2566)
A recent upstream
[change](llvm/llvm-project#100667)
introduced quantization checks that are already present in the StableHLO
core library. This commit removes the duplicate checks from StableHLO to
avoid redundancy and potential inconsistencies.


| Checks proposed to be removed | StableHLO Code | Upstream MLIR |
|-|-|-|
| `channel-axis >= 0` | [cs](https://github.com/openxla/stablehlo/blob/1c0547f391dff5ac71d36dc20a916260afa78c61/stablehlo/dialect/Base.cpp#L795) | [cs](https://github.com/llvm/llvm-project/blob/96f37ae45310885e09195be09d9c05e1c1dff86b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp#L399) |
| scale within the smallest and largest finite numbers determined by `expressed_type` | [cs](https://github.com/openxla/stablehlo/blob/1c0547f391dff5ac71d36dc20a916260afa78c61/stablehlo/dialect/Base.cpp#L765) | [cs1](https://github.com/llvm/llvm-project/blob/96f37ae45310885e09195be09d9c05e1c1dff86b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp#L327) [cs2](https://github.com/llvm/llvm-project/blob/96f37ae45310885e09195be09d9c05e1c1dff86b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp#L393C9-L393C45) |
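
For readers unfamiliar with the scale check listed in the table, here is a minimal standalone sketch of the idea, not the StableHLO or MLIR implementation. It assumes an `f32` expressed type and uses `std::numeric_limits<float>` in place of `APFloat`. The helper name `scalesWithinF32Range` is hypothetical, and `denorm_min()` is used on the assumption that `APFloat::getSmallest` refers to the smallest positive (denormal) value.

```cpp
#include <algorithm>
#include <limits>
#include <vector>

// Hypothetical helper (not part of StableHLO or MLIR): returns true when every
// scale lies in [smallest positive f32, largest finite f32].
// Assumption: the expressed type is f32; the real check derives the bounds
// from the float semantics of the expressed type.
bool scalesWithinF32Range(const std::vector<double>& scales) {
  const double minPosFinite =
      static_cast<double>(std::numeric_limits<float>::denorm_min());
  const double maxPosFinite =
      static_cast<double>(std::numeric_limits<float>::max());
  // Reject the whole list if any single scale falls outside the range.
  return std::none_of(scales.begin(), scales.end(), [&](double scale) {
    return scale < minPosFinite || scale > maxPosFinite;
  });
}
```

Under these assumptions, zero, negative, and overly large scales are rejected, which is the behavior that the StableHLO check and the upstream check would both enforce.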


Note that StableHLO also has checks such as `quantization_dimension <
rank(self)` and
`dim(self, quantization_dimension) = size(scales)`, implemented at
[cs](https://github.com/openxla/stablehlo/blob/1c0547f391dff5ac71d36dc20a916260afa78c61/stablehlo/dialect/Base.cpp#L795).
In upstream MLIR, similar checks
[cs](https://github.com/llvm/llvm-project/blob/96f37ae45310885e09195be09d9c05e1c1dff86b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp#L51)
are encoded as part of the
[dcast](https://github.com/llvm/llvm-project/blob/96f37ae45310885e09195be09d9c05e1c1dff86b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp#L110)
and
[qcast](https://github.com/llvm/llvm-project/blob/96f37ae45310885e09195be09d9c05e1c1dff86b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp#L139)
ops rather than the quantized type itself, so they cannot be considered
duplicates and are kept in StableHLO.

Related upstream clean-up: llvm/llvm-project#110604
sdasgup3 authored Oct 3, 2024
1 parent 8dd667a commit c495957
Showing 1 changed file with 2 additions and 19 deletions.
21 changes: 2 additions & 19 deletions stablehlo/dialect/Base.cpp
@@ -751,23 +751,6 @@ bool isValidStablehloQuantizedElementType(Type elementType) {
quantizedPerAxisElementType.getScales().end());
}

-// quantized_type_c6
-auto maxPosFiniteNum =
-APFloat::getLargest(
-cast<FloatType>(quantizedElementType.getExpressedType())
-.getFloatSemantics())
-.convertToDouble();
-auto minPosFiniteNum =
-APFloat::getSmallest(
-cast<FloatType>(quantizedElementType.getExpressedType())
-.getFloatSemantics())
-.convertToDouble();
-if (llvm::any_of(scales, [&](double scale) {
-return scale < minPosFiniteNum || scale > maxPosFiniteNum;
-})) {
-return false;
-}
-
// quantized_type_c7, quantized_type_c8
if (llvm::any_of(zeroPoints, [&](int64_t zeroPoint) {
return storageTypeMin > zeroPoint || zeroPoint > storageTypeMax;
@@ -788,11 +771,11 @@ bool isValidQuantizedDimension(Type type) {

if (!quantizedPerAxisElementType) return true;

-// quantized_type_c11, quantized_type_c12, quantized_type_c13
+// quantized_type_c12, quantized_type_c13
int64_t quantDim = quantizedPerAxisElementType.getQuantizedDimension();
int64_t numScales =
static_cast<int64_t>(quantizedPerAxisElementType.getScales().size());
-return quantDim >= 0 && quantDim < rankedType.getRank() &&
+return quantDim < rankedType.getRank() &&
(!rankedType.isDynamicDim(quantDim) &&
numScales == rankedType.getDimSize(quantDim));
}