diff --git a/stablehlo/conversions/linalg/tests/miscellaneous.mlir b/stablehlo/conversions/linalg/tests/miscellaneous.mlir
index 97c319823b..3cc398a6e9 100644
--- a/stablehlo/conversions/linalg/tests/miscellaneous.mlir
+++ b/stablehlo/conversions/linalg/tests/miscellaneous.mlir
@@ -545,6 +545,25 @@ func.func @broadcast_in_dim_as_broadcast(%arg: tensor<4x3x16xf32>) -> tensor<4x2
 
 // -----
 
+// CHECK-DAG: #[[OPERAND_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, 0, 0)>
+// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: func @broadcast_in_dim_as_broadcast_with_compatible_ones
+func.func @broadcast_in_dim_as_broadcast_with_compatible_ones(%arg: tensor<3x1x1xf32>) -> tensor<1x3x5x7xf32> {
+  %0 = stablehlo.broadcast_in_dim %arg, dims = [1, 2, 3] : (tensor<3x1x1xf32>) -> tensor<1x3x5x7xf32>
+  func.return %0: tensor<1x3x5x7xf32>
+}
+// CHECK: %{{.*}} = tensor.empty() : tensor<1x3x5x7xf32>
+// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
+// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %{{.*}}: f32):
+// CHECK-NEXT:   linalg.yield %[[OPERAND]] : f32
+
+// CHECK-PRIMITIVE-LABEL: func @broadcast
+// CHECK-PRIMITIVE: %{{.*}} = tensor.empty() : tensor<1x3x5x7xf32>
+// CHECK-PRIMITIVE: linalg.broadcast
+// CHECK-PRIMITIVE:   dimensions = [0, 2, 3]
+
+// -----
+
 // CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK: func @iota_f32
 func.func @iota_f32() -> tensor<7x10xf32> {
diff --git a/stablehlo/dialect/StablehloOps.cpp b/stablehlo/dialect/StablehloOps.cpp
index f36dfeb1c7..9af289f4ce 100644
--- a/stablehlo/dialect/StablehloOps.cpp
+++ b/stablehlo/dialect/StablehloOps.cpp
@@ -1517,11 +1517,36 @@ DenseI64ArrayAttr getBroadcastDimensionsFromBroadcastSizes(
       getBroadcastDimensionsFromBroadcast(broadcastSizesSize, operandRank));
 }
 
+namespace {
+// Returns true when every operand dimension equals the result dimension it
+// broadcasts into, i.e. no size-1 operand dimension would have to be
+// expanded for this broadcast.
+bool haveSimpleCompatibleDimensions(RankedTensorType operand,
+                                    RankedTensorType result) {
+  ArrayRef<int64_t> operandShape = operand.getShape();
+  ArrayRef<int64_t> resultShape = result.getShape();
+  if (operandShape.size() > resultShape.size()) return false;
+  // Compare against the *trailing* result dims: isSimpleBroadcast only
+  // accepts the canonical broadcast_dimensions [rankDiff, ..., resultRank-1],
+  // under which operand dim i maps to result dim rankDiff + i. Any mismatch
+  // (including a 1 that would need expansion) makes the broadcast non-simple.
+  return llvm::equal(operandShape, resultShape.take_back(operandShape.size()));
+}
+}  // namespace
+
 bool BroadcastInDimOp::isSimpleBroadcast() {
-  auto operandRank = getOperand().getType().getRank();
-  auto broadcastSizesSize = getType().getRank() - operandRank;
-  return llvm::to_vector(getBroadcastDimensions()) ==
-         getBroadcastDimensionsFromBroadcast(broadcastSizesSize, operandRank);
+  auto operandTy = getOperand().getType();
+  auto resultTy = getType();
+  auto operandRank = operandTy.getRank();
+  auto broadcastSizesSize = resultTy.getRank() - operandRank;
+  // Simple only if the op uses the canonical trailing broadcast_dimensions
+  // AND no operand dimension needs expansion into a larger result dimension.
+  bool haveCompatibleDimensions =
+      haveSimpleCompatibleDimensions(operandTy, resultTy);
+  return haveCompatibleDimensions &&
+         llvm::to_vector(getBroadcastDimensions()) ==
+             getBroadcastDimensionsFromBroadcast(broadcastSizesSize,
+                                                 operandRank);
 }
 
 //===----------------------------------------------------------------------===//