diff --git a/stablehlo/dialect/Base.cpp b/stablehlo/dialect/Base.cpp index 1f1811a058..bafcf5b606 100644 --- a/stablehlo/dialect/Base.cpp +++ b/stablehlo/dialect/Base.cpp @@ -751,23 +751,6 @@ bool isValidStablehloQuantizedElementType(Type elementType) { quantizedPerAxisElementType.getScales().end()); } - // quantized_type_c6 - auto maxPosFiniteNum = - APFloat::getLargest( - cast(quantizedElementType.getExpressedType()) - .getFloatSemantics()) - .convertToDouble(); - auto minPosFiniteNum = - APFloat::getSmallest( - cast(quantizedElementType.getExpressedType()) - .getFloatSemantics()) - .convertToDouble(); - if (llvm::any_of(scales, [&](double scale) { - return scale < minPosFiniteNum || scale > maxPosFiniteNum; - })) { - return false; - } - // quantized_type_c7, quantized_type_c8 if (llvm::any_of(zeroPoints, [&](int64_t zeroPoint) { return storageTypeMin > zeroPoint || zeroPoint > storageTypeMax; @@ -788,11 +771,11 @@ bool isValidQuantizedDimension(Type type) { if (!quantizedPerAxisElementType) return true; - // quantized_type_c11, quantized_type_c12, quantized_type_c13 + // quantized_type_c12, quantized_type_c13 int64_t quantDim = quantizedPerAxisElementType.getQuantizedDimension(); int64_t numScales = static_cast(quantizedPerAxisElementType.getScales().size()); - return quantDim >= 0 && quantDim < rankedType.getRank() && + return quantDim < rankedType.getRank() && (!rankedType.isDynamicDim(quantDim) && numScales == rankedType.getDimSize(quantDim)); }