diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/FuseAndHoistParallelLoops.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/FuseAndHoistParallelLoops.cpp
index 09160d69f27b..3b6236bc83f1 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/FuseAndHoistParallelLoops.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/FuseAndHoistParallelLoops.cpp
@@ -276,6 +276,8 @@ struct FuseTilableForallConsumers final
     }
 
     tensor::ParallelInsertSliceOp producerSlice;
+    scf::ForallOp sliceOwner;
+    Value fusionOperand;
     for (auto operand : dpsOp.getDpsInputs()) {
       auto forallProducer = operand.getDefiningOp<scf::ForallOp>();
       if (!forallProducer) {
@@ -288,6 +290,8 @@ struct FuseTilableForallConsumers final
         auto sliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(user);
         if (sliceOp && sliceOp.getDest() == iterArg) {
           producerSlice = sliceOp;
+          sliceOwner = forallProducer;
+          fusionOperand = operand;
           break;
         }
       }
@@ -297,7 +301,16 @@ struct FuseTilableForallConsumers final
     }
 
     if (!producerSlice) {
-      return failure();
+      return rewriter.notifyMatchFailure(tilableOp,
+                                         "no scf.forall producer to fuse into");
+    }
+
+    for (auto operand : tilableOp->getOperands()) {
+      if (operand != fusionOperand && operand.getDefiningOp() == sliceOwner) {
+        return rewriter.notifyMatchFailure(tilableOp,
+                                           "unimplemented: Cannot fuse op with "
+                                           "multiple uses of producer loop");
+      }
     }
 
     FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/fuse_and_hoist_forall.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/fuse_and_hoist_forall.mlir
index 6b66bf718e7a..0b97a1880ea7 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/fuse_and_hoist_forall.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/fuse_and_hoist_forall.mlir
@@ -486,3 +486,35 @@ func.func @forall_hoist_unit_loop_with_fill(%3: tensor<1x128xf16>, %4: tensor<12
 // CHECK: scf.forall.in_parallel
 // CHECK-NEXT: tensor.parallel_insert_slice %[[LOOP]] into %[[ITER]]
 // CHECK: return %[[OUTER_PARALLEL]]
+
+// -----
+
+func.func @no_fuse_multi_use(%2: tensor<128x128xf16>, %3: tensor<128x128xf16>) -> tensor<128x128xf16> {
+  %c4 = arith.constant 4 : index
+  %c128 = arith.constant 128 : index
+  %c0 = arith.constant 0 : index
+  %empty = tensor.empty() : tensor<128x128xf16>
+  %10:2 = scf.forall (%arg5, %arg6) in (32, 32) shared_outs(%arg7 = %empty, %arg8 = %empty) -> (tensor<128x128xf16>, tensor<128x128xf16>) {
+    %extracted_slice_1 = tensor.extract_slice %2[%arg5, %arg6] [2, 2] [1, 1] : tensor<128x128xf16> to tensor<2x2xf16>
+    %extracted_slice_2 = tensor.extract_slice %arg7[%arg5, %arg6] [2, 2] [1, 1] : tensor<128x128xf16> to tensor<2x2xf16>
+    %extracted_slice_3 = tensor.extract_slice %arg8[%arg6, %arg5] [2, 2] [1, 1] : tensor<128x128xf16> to tensor<2x2xf16>
+    %16 = linalg.copy ins(%extracted_slice_1 : tensor<2x2xf16>) outs(%extracted_slice_2 : tensor<2x2xf16>) -> tensor<2x2xf16>
+    %17 = linalg.transpose ins(%extracted_slice_1 : tensor<2x2xf16>) outs(%extracted_slice_3 : tensor<2x2xf16>) permutation = [1, 0]
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %16 into %arg7[%arg5, %arg6] [2, 2] [1, 1] : tensor<2x2xf16> into tensor<128x128xf16>
+      tensor.parallel_insert_slice %17 into %arg8[%arg6, %arg5] [2, 2] [1, 1] : tensor<2x2xf16> into tensor<128x128xf16>
+    }
+  } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
+  %add = linalg.add
+    ins(%10#0, %10#1 : tensor<128x128xf16>, tensor<128x128xf16>)
+    outs(%empty: tensor<128x128xf16>) -> tensor<128x128xf16>
+  return %add : tensor<128x128xf16>
+}
+
+// CHECK-LABEL: func @no_fuse_multi_use
+// CHECK: scf.forall
+// CHECK: linalg.copy
+// CHECK: linalg.transpose
+// CHECK: scf.forall.in_parallel
+// CHECK: linalg.add
+// CHECK: return