From e821f5ee17916de52cd679482631b9567dd24032 Mon Sep 17 00:00:00 2001 From: Sam Date: Wed, 18 Sep 2024 12:26:01 -0500 Subject: [PATCH] Enable `stablehlo.concatenate` to `tensor.concat` conversion --- .../conversions/linalg/tests/miscellaneous.mlir | 6 ++++++ .../transforms/StablehloLegalizeToLinalg.cpp | 17 +++++++++++++++-- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/stablehlo/conversions/linalg/tests/miscellaneous.mlir b/stablehlo/conversions/linalg/tests/miscellaneous.mlir index c231d385d02..97c319823b4 100644 --- a/stablehlo/conversions/linalg/tests/miscellaneous.mlir +++ b/stablehlo/conversions/linalg/tests/miscellaneous.mlir @@ -89,6 +89,9 @@ func.func @bitcast_convert_contract(%input: tensor<7x4xi8>) -> tensor<7xi32> { // CHECK-SAME: %[[VAL_0:[a-zA-Z0-9_]*]] // CHECK-SAME: %[[VAL_1:[a-zA-Z0-9_]*]] // CHECK-SAME: %[[VAL_2:[a-zA-Z0-9_]*]] +// CHECK-PRIMTIVE-NOT: linalg.generic +// CHECK-PRIMITIVE: %[[CONCAT:.*]] = tensor.concat dim(1) +// CHECK-PRIMITIVE: return %[[CONCAT:.*]] // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_0]], %[[C0]] : tensor @@ -139,6 +142,9 @@ func.func @concatenate(%a: tensor, %b: tensor, %c: tensor to tensor diff --git a/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp b/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp index 81654acb64a..76fec51e9e2 100644 --- a/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp +++ b/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp @@ -1298,7 +1298,10 @@ struct IotaToMapConverter final : OpConversionPattern { /// Converts stablehlo.concatenate operation to a linalg.generic op. struct ConcatenateConverter final : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + explicit ConcatenateConverter(TypeConverter &converter, MLIRContext *context, + bool enablePrimitiveOps) + : OpConversionPattern(converter, context), + enablePrimitiveOps(enablePrimitiveOps) {} LogicalResult matchAndRewrite( mlir::stablehlo::ConcatenateOp op, OpAdaptor adaptor, @@ -1327,6 +1330,13 @@ struct ConcatenateConverter final uint64_t dim = op.getDimension(); Location loc = op.getLoc(); + + if (enablePrimitiveOps) { + auto concatOp = + rewriter.create(loc, dim, adaptor.getOperands()); + rewriter.replaceOp(op, concatOp); + return success(); + } Value zero = rewriter.create(loc, 0); // Allocate the output tensor with tensor.empty. @@ -1394,6 +1404,8 @@ struct ConcatenateConverter final linalg::getPrunedAttributeList(op)); return success(); } + + bool enablePrimitiveOps = false; }; /// Converts stablehlo.concatenate operation to a sparse_tensor.concatenate op. @@ -2594,9 +2606,10 @@ static void populateConversionPatterns(MLIRContext *context, bool enablePrimitiveOps, bool enableSparseOps) { // clang-format off + patterns->add(typeConverter, context, + enablePrimitiveOps); patterns->add< BitcastConvertConverter, - ConcatenateConverter, ConstConverterTensor, EinsumToLinalgConverter, GatherConversion,