Fix BroadcastInDimOps linalg lowering with compatible dim=1
For the simple broadcast case, check that we don't have to expand any
dimensions; skip the conversion to BroadcastOp when only one of them is 1.

Error before:

    test.mlir:5:8: error: 'linalg.broadcast' op input dim 1 should match init dim 2. input: 1, init: 5
      %0 = stablehlo.broadcast_in_dim %arg, dims = [1, 2, 3] : (tensor<3x1x1xf32>) -> tensor<1x3x5x7xf32>
           ^
    test.mlir:5:8: note: see current operation:
    %1 = "linalg.broadcast"(%arg0, %0) <{dimensions = array<i64: 0>}> ({
    ^bb0(%arg1: f32, %arg2: f32):
      "linalg.yield"(%arg1) : (f32) -> ()
    }) : (tensor<3x1x1xf32>, tensor<1x3x5x7xf32>) -> tensor<1x3x5x7xf32>
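
With the fix, this op takes the linalg.generic path instead. A sketch of the
expected lowering, pieced together from the CHECK lines in the new test (the
all-parallel iterator types and the SSA names are assumptions, not checked
output):

    %empty = tensor.empty() : tensor<1x3x5x7xf32>
    %0 = linalg.generic {
           indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, 0, 0)>,
                            affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
           iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
           ins(%arg : tensor<3x1x1xf32>) outs(%empty : tensor<1x3x5x7xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<1x3x5x7xf32>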
FruitClover committed Oct 8, 2024
1 parent 8c7d87b commit b919ed0
Showing 2 changed files with 48 additions and 4 deletions.
19 changes: 19 additions & 0 deletions stablehlo/conversions/linalg/tests/miscellaneous.mlir
@@ -545,6 +545,25 @@ func.func @broadcast_in_dim_as_broadcast(%arg: tensor<4x3x16xf32>) -> tensor<4x2

 // -----
+
+// CHECK-DAG: #[[OPERAND_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, 0, 0)>
+// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: func @broadcast_in_dim_as_broadcast_with_compatible_ones
+func.func @broadcast_in_dim_as_broadcast_with_compatible_ones(%arg: tensor<3x1x1xf32>) -> tensor<1x3x5x7xf32> {
+  %0 = stablehlo.broadcast_in_dim %arg, dims = [1, 2, 3] : (tensor<3x1x1xf32>) -> tensor<1x3x5x7xf32>
+  func.return %0: tensor<1x3x5x7xf32>
+}
+// CHECK: %{{.*}} = tensor.empty() : tensor<1x3x5x7xf32>
+// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
+// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %{{.*}}: f32):
+// CHECK-NEXT:   linalg.yield %[[OPERAND]] : f32
+
+// CHECK-PRIMITIVE-LABEL: func @broadcast
+// CHECK-PRIMITIVE: %{{.*}} = tensor.empty() : tensor<1x3x5x7xf32>
+// CHECK-PRIMITIVE: linalg.broadcast
+// CHECK-PRIMITIVE: dimensions = [0, 3]
+
+// -----
 
 // CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK: func @iota_f32
 func.func @iota_f32() -> tensor<7x10xf32> {
33 changes: 29 additions & 4 deletions stablehlo/dialect/StablehloOps.cpp
@@ -1517,11 +1517,36 @@ DenseI64ArrayAttr getBroadcastDimensionsFromBroadcastSizes(
       getBroadcastDimensionsFromBroadcast(broadcastSizesSize, operandRank));
 }
 
+namespace {
+bool haveSimpleCompatibleDimensions(RankedTensorType operand,
+                                    RankedTensorType result) {
+  // Operand dimensions map to the trailing result dimensions.
+  ArrayRef<int64_t> operandShape = operand.getShape();
+  ArrayRef<int64_t> resultShape =
+      result.getShape().take_back(operandShape.size());
+  // Two dimensions are compatible when:
+  //   - they are equal, or
+  //   - one of them is 1.
+  // A simple broadcast must not have to expand any dimension, so reject
+  // the case where only one of them is 1.
+  bool isCompatible = true;
+  for (auto [operandDim, resultDim] : llvm::zip(operandShape, resultShape))
+    isCompatible &= operandDim == resultDim;
+  return isCompatible;
+}
+} // namespace
+
 bool BroadcastInDimOp::isSimpleBroadcast() {
-  auto operandRank = getOperand().getType().getRank();
-  auto broadcastSizesSize = getType().getRank() - operandRank;
-  return llvm::to_vector(getBroadcastDimensions()) ==
-         getBroadcastDimensionsFromBroadcast(broadcastSizesSize, operandRank);
+  auto operandTy = getOperand().getType();
+  auto resultTy = getType();
+  auto operandRank = operandTy.getRank();
+  auto broadcastSizesSize = resultTy.getRank() - operandRank;
+  bool haveCompatibleDimensions =
+      haveSimpleCompatibleDimensions(operandTy, resultTy);
+  return haveCompatibleDimensions &&
+         llvm::to_vector(getBroadcastDimensions()) ==
+             getBroadcastDimensionsFromBroadcast(broadcastSizesSize,
+                                                 operandRank);
 }
 
 //===----------------------------------------------------------------------===//
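For illustration, a self-contained sketch of the trailing-dimension rule the
patch relies on (not part of the commit; the helper name and driver are
hypothetical, only llvm::zip and ArrayRef::take_back come from LLVM):

    #include "llvm/ADT/ArrayRef.h"
    #include "llvm/ADT/STLExtras.h"
    #include <cstdint>
    #include <cstdio>

    // Hypothetical standalone check mirroring haveSimpleCompatibleDimensions:
    // compare operand dims against the trailing result dims of the same size.
    static bool isSimpleCompatible(llvm::ArrayRef<int64_t> operandShape,
                                   llvm::ArrayRef<int64_t> resultShape) {
      bool compatible = true;
      for (auto [operandDim, resultDim] :
           llvm::zip(operandShape, resultShape.take_back(operandShape.size())))
        compatible &= operandDim == resultDim;
      return compatible;
    }

    int main() {
      // tensor<3x1x1xf32> -> tensor<1x3x5x7xf32>: the unit dims would need to
      // be expanded to 5 and 7, so this must not become linalg.broadcast.
      std::printf("%d\n", isSimpleCompatible({3, 1, 1}, {1, 3, 5, 7})); // 0
      // tensor<3x5x7xf32> -> tensor<2x3x5x7xf32>: a pure prepend stays simple.
      std::printf("%d\n", isSimpleCompatible({3, 5, 7}, {2, 3, 5, 7})); // 1
    }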
