Skip to content

Commit

Permalink
[Stablehlo] support uint8 (#3367)
Browse files Browse the repository at this point in the history
Support lowering unsigned integer type to stablehlo as discussed in
#2184.

The things I do in this PR:
1. create `setupBackendTypeConversionForStablehlo()`,
`createFuncBackendTypeConversionForStablehloPass` and
`createFinalizingBackendTypeConversionForStablehloPass`.
2. remove `InferTypeOpInterface` from `torch_c.to_builtin_tensor`,
because the op's result type differs between the linalg backend and the
stablehlo backend:
```
// linalg backend
func.func @forward(%arg0: !torch.vtensor<[3],ui8>) -> tensor<3xf32> {
    %c = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[3],ui8> -> tensor<3xi8>
    %0 = tensor.empty() : tensor<3xf32>
    %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%c : tensor<3xi8>) outs(%0 : tensor<3xf32>) {
    ^bb0(%in: i8, %out: f32):
      %2 = arith.uitofp %in : i8 to f32
      linalg.yield %2 : f32
    } -> tensor<3xf32>
    return %1 : tensor<3xf32>
}
// stablehlo backend
func.func @forward(%arg0: !torch.vtensor<[3],ui8>) -> tensor<3xf32> {
    %c = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[3],ui8> -> tensor<3xui8>
    %0 = stablehlo.convert %c : (tensor<3xui8>) -> tensor<3xf32>
    return %0 : tensor<3xf32>
}
```
3. fix stablehlo and linalg's conversion
  • Loading branch information
qingyunqu authored Jun 4, 2024
1 parent 56d21cb commit 50f7103
Show file tree
Hide file tree
Showing 17 changed files with 245 additions and 54 deletions.
2 changes: 1 addition & 1 deletion externals/stablehlo
Submodule stablehlo updated 39 files
+2 −0 BUILD.bazel
+4 −0 docs/generated/stablehlo_passes.md
+20 −14 stablehlo/integrations/python/tests/stablehlo.py
+0 −1 stablehlo/tests/interpret/dynamic_gather.mlir
+0 −5 stablehlo/tests/ops_speculatability.mlir
+0 −1 stablehlo/tests/ops_stablehlo.mlir
+40 −0 stablehlo/tests/stablehlo_aggressive_folder.mlir
+26 −0 stablehlo/tests/stablehlo_convert_to_signless.mlir
+2 −2 stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_11_0.mlir
+ stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_11_0.mlir.bc
+2 −2 stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_12_0.mlir
+ stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_12_0.mlir.bc
+2 −2 stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_13_0.mlir
+ stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_13_0.mlir.bc
+2 −2 stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_14_0.mlir
+ stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_14_0.mlir.bc
+2 −2 stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_15_0.mlir
+ stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_15_0.mlir.bc
+2 −2 stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_16_0.mlir
+ stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_16_0.mlir.bc
+2 −2 stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_17_0.mlir
+ stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_17_0.mlir.bc
+2 −2 stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_18_0.mlir
+ stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_18_0.mlir.bc
+2 −2 stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_19_0.mlir
+ stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_19_0.mlir.bc
+2 −2 stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_20_0.mlir
+ stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_20_0.mlir.bc
+2 −2 stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_9_0.mlir
+ stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_9_0.mlir.bc
+2 −2 stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_0_0.mlir
+ stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_0_0.mlir.bc
+2 −2 stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_1_0.mlir
+ stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_1_0.mlir.bc
+2 −2 stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.mlir
+1 −0 stablehlo/transforms/CMakeLists.txt
+4 −0 stablehlo/transforms/Passes.td
+46 −0 stablehlo/transforms/StablehloAggressiveFolder.cpp
+137 −0 stablehlo/transforms/StablehloConvertToSignless.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@ class TorchConversion_Op<string mnemonic, list<Trait> traits = []>
// Conversions to backend types.
//===----------------------------------------------------------------------===//

def TorchConversion_ToBuiltinTensorOp : TorchConversion_Op<"to_builtin_tensor", [
DeclareOpInterfaceMethods<InferTypeOpInterface>
]> {
def TorchConversion_ToBuiltinTensorOp : TorchConversion_Op<"to_builtin_tensor"> {
let summary = "Convert a `!torch.vtensor` to a `tensor`";
let description = [{
This op only operates on ValueTensorType, to avoid conflating conversions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ void getBackendTypeConversionDependentDialects(DialectRegistry &registry);
/// boundary (which currently consist only of builtin types).
void setupBackendTypeConversion(ConversionTarget &target,
TypeConverter &typeConverter);

#ifdef TORCH_MLIR_ENABLE_STABLEHLO
void setupBackendTypeConversionForStablehlo(ConversionTarget &target,
TypeConverter &typeConverter);
#endif
} // namespace TorchConversion
} // namespace torch
} // namespace mlir
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,13 @@ struct StablehloBackendPipelineOptions

void createTorchBackendToStablehloBackendPipeline(
OpPassManager &pm, const StablehloBackendPipelineOptions &options);

std::unique_ptr<OperationPass<ModuleOp>>
createFuncBackendTypeConversionForStablehloPass();

std::unique_ptr<InterfacePass<FunctionOpInterface>>
createFinalizingBackendTypeConversionForStablehloPass();

std::unique_ptr<OperationPass<ModuleOp>>
createVerifyStablehloBackendContractPass();
#endif
Expand Down
24 changes: 24 additions & 0 deletions include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,17 @@ def FuncBackendTypeConversion : Pass<"torch-func-backend-type-conversion", "Modu
}];
}

#ifdef TORCH_MLIR_ENABLE_STABLEHLO
def FuncBackendTypeConversionForStablehlo : Pass<"torch-func-backend-type-conversion-for-stablehlo", "ModuleOp"> {
let summary = "Convert functions to operate on builtin tensors for stablehlo backend";
let constructor = "mlir::torch::TorchConversion::createFuncBackendTypeConversionForStablehloPass()";
let description = [{
Partial type conversion pass analogous in scope to the upstream
`func-bufferize` pass. See details there.
}];
}
#endif // TORCH_MLIR_ENABLE_STABLEHLO

def FinalizingBackendTypeConversion
: InterfacePass<"torch-finalizing-backend-type-conversion", "mlir::FunctionOpInterface"> {
let summary = "Finalizes a partial conversion to builtin tensors";
Expand All @@ -32,6 +43,19 @@ def FinalizingBackendTypeConversion
}];
}

#ifdef TORCH_MLIR_ENABLE_STABLEHLO
def FinalizingBackendTypeConversionForStablehlo
: InterfacePass<"torch-finalizing-backend-type-conversion-for-stablehlo", "mlir::FunctionOpInterface"> {
let summary = "Finalizes a partial conversion to builtin tensors for stablehlo";
let constructor =
"mlir::torch::TorchConversion::createFinalizingBackendTypeConversionForStablehloPass()";
let description = [{
Analogous in scope to the upstream `finalizing-bufferize` pass.
See details there.
}];
}
#endif // TORCH_MLIR_ENABLE_STABLEHLO

def VerifyLinalgOnTensorsBackendContract : Pass<"torch-verify-linalg-on-tensors-backend-contract", "ModuleOp"> {
let summary = "Verifies conformity to the linalg-on-tensors backend contract";
let constructor = "mlir::torch::TorchConversion::createVerifyLinalgOnTensorsBackendContractPass()";
Expand Down
4 changes: 3 additions & 1 deletion lib/Conversion/TorchToLinalg/Uncategorized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1197,6 +1197,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
}
if (auto atenToDtype = dyn_cast<AtenToDtypeOp>(op)) {
Value input = payloadArgs[0];
Type inputElementType =
cast<BaseTensorType>(atenToDtype.getSelf().getType()).getDtype();
Type dtype =
cast<RankedTensorType>(converter->convertType(atenToDtype.getType()))
.getElementType();
Expand All @@ -1215,7 +1217,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
}
resultElementType = *maybeResultElementType;
Value result = convertScalarToDtype(b, loc, input, dtype,
/*srcOriginalDtype=*/std::nullopt,
/*srcOriginalDtype=*/inputElementType,
/*dstOriginalDtype=*/resultElementType);
return result;
}
Expand Down
16 changes: 13 additions & 3 deletions lib/Conversion/TorchToStablehlo/Basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,8 +277,8 @@ class ConvertAtenTensorToScalarLikeOp : public OpConversionPattern<AtenOpT> {
ConversionPatternRewriter &rewriter) const override {
auto inputType = dyn_cast<RankedTensorType>(adaptor.getA().getType());
if (!inputType)

op.emitError("only Tensor types supported in StableHLO");

Location loc = op.getLoc();
Value input = adaptor.getA();
SmallVector<Value> inputSizes = getTensorSizes(rewriter, loc, input);
Expand All @@ -290,14 +290,24 @@ class ConvertAtenTensorToScalarLikeOp : public OpConversionPattern<AtenOpT> {
for (int64_t i = 0; i < inputRank; i++)
checkDimEqualHelper(rewriter, loc, inputSizes[i], constantOne);

// handle unsigned integer
if (inputType.getElementType().isUnsignedInteger()) {
input = rewriter.create<stablehlo::ConvertOp>(
loc, input,
rewriter.getIntegerType(
inputType.getElementType().getIntOrFloatBitWidth()));
}

Value constantZero =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
SmallVector<Value> indices(inputRank, constantZero);
Value result = rewriter.create<tensor::ExtractOp>(loc, input, indices);
Type resultType =
this->getTypeConverter()->convertType(op->getResult(0).getType());
rewriter.replaceOp(op, convertScalarToDtype(rewriter, loc, result,
resultType, inputDtype));
rewriter.replaceOp(
op,
convertScalarToDtype(rewriter, loc, result, resultType, inputDtype,
/*srcOriginalDtype=*/inputType.getElementType()));
return success();
}
};
Expand Down
2 changes: 1 addition & 1 deletion lib/Conversion/TorchToStablehlo/GatherScatter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -900,7 +900,7 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
for (int64_t i = maxIndexRank; i < inputRank; ++i) {
updateWindowDims.push_back(i);
}
llvm::outs() << "maxIndexRank: " << maxIndexRank << "\n";

auto scatterDimensionNumbers = stablehlo::ScatterDimensionNumbersAttr::get(
rewriter.getContext(),
/*updateWindowDims=*/updateWindowDims,
Expand Down
3 changes: 2 additions & 1 deletion lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ class ConvertTorchToStablehlo

TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
TorchConversion::setupBackendTypeConversion(target, typeConverter);
TorchConversion::setupBackendTypeConversionForStablehlo(target,
typeConverter);

RewritePatternSet patterns(context);

Expand Down
25 changes: 12 additions & 13 deletions lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,18 @@ static bool haveSameSizeAndElementType(TensorType lhs, TensorType rhs) {
if (lhs.hasRank() != rhs.hasRank())
return false;
bool sameSize = lhs.hasRank() ? lhs.getShape().equals(rhs.getShape()) : true;
bool sameElementType = lhs.getElementType() == rhs.getElementType();
bool sameElementType = false;
// Namely, it is worth mentioning that the backends can have different
// expectations for signedness when converting from and to the builtin MLIR
// types. Therefore, the verifier cannot expect the input and output types to
// match in their signedness.
if (isa<IntegerType>(lhs.getElementType()) &&
isa<IntegerType>(rhs.getElementType())) {
sameElementType = lhs.getElementType().getIntOrFloatBitWidth() ==
rhs.getElementType().getIntOrFloatBitWidth();
} else {
sameElementType = lhs.getElementType() == rhs.getElementType();
}
return sameElementType && sameSize;
}

Expand All @@ -42,18 +53,6 @@ LogicalResult ToBuiltinTensorOp::verify() {
return success();
}

LogicalResult ToBuiltinTensorOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
auto resultType =
cast<Torch::ValueTensorType>(operands[0].getType()).toBuiltinTensor();
if (!resultType)
return failure();
inferredReturnTypes.push_back(resultType);
return success();
}

//===----------------------------------------------------------------------===//
// FromBuiltinTensorOp
//===----------------------------------------------------------------------===//
Expand Down
43 changes: 34 additions & 9 deletions lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,22 @@ void mlir::torch::TorchConversion::getBackendTypeConversionDependentDialects(
// Type conversion setup.
//===----------------------------------------------------------------------===//

static void
setupValueTensorToBuiltinTensorConversion(ConversionTarget &target,
TypeConverter &typeConverter) {
using ValueTensorTypeConversionFn =
std::function<std::optional<Type>(Torch::ValueTensorType)>;

static void setupValueTensorToBuiltinTensorConversion(
ConversionTarget &target, TypeConverter &typeConverter,
const ValueTensorTypeConversionFn &conversionFn) {
target.addLegalOp<TorchConversion::ToBuiltinTensorOp,
TorchConversion::FromBuiltinTensorOp>();
typeConverter.addConversion(
[](Torch::ValueTensorType type) -> std::optional<Type> {
return type.toBuiltinTensor();
});
typeConverter.addConversion(conversionFn);
typeConverter.addTargetMaterialization([](OpBuilder &builder, TensorType type,
ValueRange inputs,
Location loc) -> Value {
assert(inputs.size() == 1);
if (!isa<Torch::BaseTensorType>(inputs[0].getType()))
return {};
return builder.create<ToBuiltinTensorOp>(loc, inputs[0]);
return builder.create<ToBuiltinTensorOp>(loc, type, inputs[0]);
});
auto sourceMaterialization = [](OpBuilder &builder,
Torch::ValueTensorType type,
Expand Down Expand Up @@ -162,9 +162,34 @@ static void setupTorchGeneratorToI64Conversion(ConversionTarget &target,

void mlir::torch::TorchConversion::setupBackendTypeConversion(
ConversionTarget &target, TypeConverter &typeConverter) {
setupValueTensorToBuiltinTensorConversion(target, typeConverter);
auto valueTensorTypeConversion =
[](Torch::ValueTensorType type) -> std::optional<Type> {
return type.toBuiltinTensor();
};
setupValueTensorToBuiltinTensorConversion(target, typeConverter,
valueTensorTypeConversion);
setupTorchBoolToI1Conversion(target, typeConverter);
setupTorchIntToI64Conversion(target, typeConverter);
setupTorchFloatToF64Conversion(target, typeConverter);
setupTorchGeneratorToI64Conversion(target, typeConverter);
}

#ifdef TORCH_MLIR_ENABLE_STABLEHLO
void mlir::torch::TorchConversion::setupBackendTypeConversionForStablehlo(
ConversionTarget &target, TypeConverter &typeConverter) {
auto valueTensorTypeConversion =
[](Torch::ValueTensorType type) -> std::optional<Type> {
auto builtinType = type.toBuiltinTensor();
if (type.getDtype().isUnsignedInteger()) {
return builtinType.clone(type.getDtype());
}
return builtinType;
};
setupValueTensorToBuiltinTensorConversion(target, typeConverter,
valueTensorTypeConversion);
setupTorchBoolToI1Conversion(target, typeConverter);
setupTorchIntToI64Conversion(target, typeConverter);
setupTorchFloatToF64Conversion(target, typeConverter);
setupTorchGeneratorToI64Conversion(target, typeConverter);
}
#endif
Loading

0 comments on commit 50f7103

Please sign in to comment.