From 0e16a89f7dbb89cc550254b0336abbea8b8c8f4b Mon Sep 17 00:00:00 2001
From: Quinn Dawkins
Date: Tue, 8 Oct 2024 18:24:08 -0400
Subject: [PATCH] [Codegen][GPU] Disable consumer fusion for multi use cases
 (#18723)

The upstream patterns for consumer fusion currently don't support cases
where multiple operands of the consumer come from the producer loop.
This disables fusion for such cases and sends them down the fallback
path.

---
 .../Transforms/FuseAndHoistParallelLoops.cpp  | 15 ++++++++-
 .../test/fuse_and_hoist_forall.mlir           | 32 +++++++++++++++++++
 2 files changed, 46 insertions(+), 1 deletion(-)

diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/FuseAndHoistParallelLoops.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/FuseAndHoistParallelLoops.cpp
index 09160d69f27b..3b6236bc83f1 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/FuseAndHoistParallelLoops.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/FuseAndHoistParallelLoops.cpp
@@ -276,6 +276,8 @@ struct FuseTilableForallConsumers final
     }
 
     tensor::ParallelInsertSliceOp producerSlice;
+    scf::ForallOp sliceOwner;
+    Value fusionOperand;
     for (auto operand : dpsOp.getDpsInputs()) {
       auto forallProducer = operand.getDefiningOp<scf::ForallOp>();
       if (!forallProducer) {
@@ -288,6 +290,8 @@
         auto sliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(user);
         if (sliceOp && sliceOp.getDest() == iterArg) {
           producerSlice = sliceOp;
+          sliceOwner = forallProducer;
+          fusionOperand = operand;
           break;
         }
       }
@@ -297,7 +301,16 @@ struct FuseTilableForallConsumers final
     }
 
     if (!producerSlice) {
-      return failure();
+      return rewriter.notifyMatchFailure(tilableOp,
+                                         "no scf.forall producer to fuse into");
+    }
+
+    for (auto operand : tilableOp->getOperands()) {
+      if (operand != fusionOperand && operand.getDefiningOp() == sliceOwner) {
+        return rewriter.notifyMatchFailure(tilableOp,
+                                           "unimplemented: Cannot fuse op with "
+                                           "multiple uses of producer loop");
+      }
     }
 
     FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/fuse_and_hoist_forall.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/fuse_and_hoist_forall.mlir
index 6b66bf718e7a..0b97a1880ea7 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/fuse_and_hoist_forall.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/fuse_and_hoist_forall.mlir
@@ -486,3 +486,35 @@ func.func @forall_hoist_unit_loop_with_fill(%3: tensor<1x128xf16>, %4: tensor<12
 // CHECK:       scf.forall.in_parallel
 // CHECK-NEXT:    tensor.parallel_insert_slice %[[LOOP]] into %[[ITER]]
 // CHECK:     return %[[OUTER_PARALLEL]]
+
+// -----
+
+func.func @no_fuse_multi_use(%2: tensor<128x128xf16>, %3: tensor<128x128xf16>) -> tensor<128x128xf16> {
+  %c4 = arith.constant 4 : index
+  %c128 = arith.constant 128 : index
+  %c0 = arith.constant 0 : index
+  %empty = tensor.empty() : tensor<128x128xf16>
+  %10:2 = scf.forall (%arg5, %arg6) in (32, 32) shared_outs(%arg7 = %empty, %arg8 = %empty) -> (tensor<128x128xf16>, tensor<128x128xf16>) {
+    %extracted_slice_1 = tensor.extract_slice %2[%arg5, %arg6] [2, 2] [1, 1] : tensor<128x128xf16> to tensor<2x2xf16>
+    %extracted_slice_2 = tensor.extract_slice %arg7[%arg5, %arg6] [2, 2] [1, 1] : tensor<128x128xf16> to tensor<2x2xf16>
+    %extracted_slice_3 = tensor.extract_slice %arg8[%arg6, %arg5] [2, 2] [1, 1] : tensor<128x128xf16> to tensor<2x2xf16>
+    %16 = linalg.copy ins(%extracted_slice_1 : tensor<2x2xf16>) outs(%extracted_slice_2 : tensor<2x2xf16>) -> tensor<2x2xf16>
+    %17 = linalg.transpose ins(%extracted_slice_1 : tensor<2x2xf16>) outs(%extracted_slice_3 : tensor<2x2xf16>) permutation = [1, 0]
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %16 into %arg7[%arg5, %arg6] [2, 2] [1, 1] : tensor<2x2xf16> into tensor<128x128xf16>
+      tensor.parallel_insert_slice %17 into %arg8[%arg6, %arg5] [2, 2] [1, 1] : tensor<2x2xf16> into tensor<128x128xf16>
+    }
+  } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
+  %add = linalg.add
+    ins(%10#0, %10#1 : tensor<128x128xf16>, tensor<128x128xf16>)
+    outs(%empty: tensor<128x128xf16>) -> tensor<128x128xf16>
+  return %add : tensor<128x128xf16>
+}
+
+// CHECK-LABEL: func @no_fuse_multi_use
+// CHECK:       scf.forall
+// CHECK:         linalg.copy
+// CHECK:         linalg.transpose
+// CHECK:       scf.forall.in_parallel
+// CHECK:       linalg.add
+// CHECK:       return