From 0889d13501b377decc2cc414dd2483dcddc45d33 Mon Sep 17 00:00:00 2001
From: Prashant Kumar
Date: Wed, 9 Oct 2024 01:29:51 +0530
Subject: [PATCH] Add support to yield multiple results (#18717)

When a dispatch has multiple results, tileDispatchUsingForall yields a mix
of tiled and untiled values. This fix yields a replacement for every fused
producer that is still used outside the tiled op, so the scf.forall carries
multiple shared_outs and only tiled values are yielded.
---
 .../Common/TileDispatchUsingForall.cpp        | 50 +++++++++++++++++--
 ...nd_distribute_workgroups_using_forall.mlir | 34 +++++++++++++
 2 files changed, 81 insertions(+), 3 deletions(-)

diff --git a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp
index e722fcb8240b..ebbe585bf53e 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp
@@ -306,6 +306,27 @@ static void fuseConsumers(RewriterBase &rewriter, Operation *tiledOp) {
   }
 }
 
+/// Starting from `rootOp` walk all operands backwards to find all
+/// potentially fusable operations, i.e. operations that implement
+/// the `TilingInterface`.
+static void collectTiledAndFusedOps(Operation *rootOp,
+                                    llvm::SmallDenseSet<Operation *> &result) {
+  SmallVector<Operation *> worklist;
+  worklist.push_back(rootOp);
+  result.insert(rootOp);
+  while (!worklist.empty()) {
+    Operation *current = worklist.pop_back_val();
+    for (OpOperand &operand : current->getOpOperands()) {
+      Operation *producer = operand.get().getDefiningOp();
+      if (!producer || !isa<TilingInterface>(producer) ||
+          result.count(producer))
+        continue;
+      worklist.push_back(producer);
+      result.insert(producer);
+    }
+  }
+}
+
 void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() {
   auto funcOp = getOperation();
   auto *context = &getContext();
@@ -322,6 +343,18 @@ void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() {
     // Did not find a tileable op. So do nothing.
     return;
   }
+  mlir::DominanceInfo dominanceInfo(tilableOp);
+  llvm::SmallDenseSet<Operation *> tiledAndFusedOps;
+  collectTiledAndFusedOps(tilableOp, tiledAndFusedOps);
+
+  llvm::DenseSet<Operation *> yieldReplacementsFor;
+  for (auto op : tiledAndFusedOps) {
+    if (llvm::any_of(op->getUsers(), [&](Operation *user) {
+          return dominanceInfo.properlyDominates(tilableOp, user);
+        })) {
+      yieldReplacementsFor.insert(op);
+    }
+  }
 
   scf::SCFTilingOptions tilingOptions;
   tilingOptions.setTileSizes(tilingInfo->tileSizes);
@@ -337,9 +370,20 @@ void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() {
 
   scf::SCFTileAndFuseOptions tileAndFuseOptions;
   tileAndFuseOptions.setTilingOptions(tilingOptions);
-  // TODO: For now use the default tile and fuse control function. That needs
-  // to be modified to allow for returning the values of the producer when
-  // needed.
+
+  // The control function that determines whether a tiled producer should yield
+  // its replacement.
+  scf::SCFTileAndFuseOptions::ControlFnTy controlFn =
+      [&](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
+          bool isDestinationOperand)
+      -> std::optional<scf::SCFTileAndFuseOptions::ControlFnResult> {
+    Operation *owner = originalProducer.getOwner();
+    bool yieldProducerReplacement = yieldReplacementsFor.contains(owner);
+    return scf::SCFTileAndFuseOptions::ControlFnResult{
+        yieldProducerReplacement};
+    return std::nullopt;
+  };
+  tileAndFuseOptions.setFusionControlFn(controlFn);
 
   rewriter.setInsertionPoint(tilableOp);
   // If the `tilableOp` is a `memref` op, then just tile the operation.
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_workgroups_using_forall.mlir b/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_workgroups_using_forall.mlir
index 8f40d251c92d..06987c14e661 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_workgroups_using_forall.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_workgroups_using_forall.mlir
@@ -485,3 +485,37 @@ func.func @matmul_consumer_fusion_test(%arg0 : tensor,
 // CHECK:     scf.forall.in_parallel
 // CHECK:       tensor.parallel_insert_slice %[[RELU]]
 // CHECK:   return %[[RESULT]]
+
+// -----
+
+func.func @multi_result(%arg0: tensor<64x128xf32>, %arg1: tensor<128x256xf32>, %arg2: tensor<256xf32>) -> (tensor<64x256xf32>, tensor<64x256xf32>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<64x256xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<64x256xf32>) -> tensor<64x256xf32>
+  %2 = linalg.matmul ins(%arg0, %arg1 : tensor<64x128xf32>, tensor<128x256xf32>) outs(%1 : tensor<64x256xf32>) -> tensor<64x256xf32>
+  %3 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                       affine_map<(d0, d1) -> (d1)>,
+                       affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%2, %arg2 : tensor<64x256xf32>, tensor<256xf32>)
+      outs(%0 : tensor<64x256xf32>)
+      attrs = {lowering_config =
+          #iree_codegen.lowering_config} {
+    ^bb0(%in: f32, %in_0: f32, %out: f32):
+      %4 = arith.addf %in, %in_0 : f32
+      linalg.yield %4 : f32
+  } -> tensor<64x256xf32>
+  return %2, %3 : tensor<64x256xf32>, tensor<64x256xf32>
+}
+
+// CHECK-LABEL: func @multi_result(
+// CHECK:   %[[RESULT:.+]]:2 = scf.forall (%[[IV0:[a-zA-Z0-9]+]], %[[IV1:[a-zA-Z0-9]+]])
+// CHECK-SAME:   shared_outs(%[[OUTS0:.+]] = {{.*}}, %[[OUTS1:.+]] = {{.*}})
+// CHECK:     linalg.matmul
+// CHECK:     linalg.generic
+// CHECK:     scf.forall.in_parallel
+// CHECK:       tensor.parallel_insert_slice
+// CHECK:       tensor.parallel_insert_slice
+// CHECK:   return %[[RESULT]]#1, %[[RESULT]]#0
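
Note on the expected structure: the sketch below is a hand-written, minimal illustration of the scf.forall shape the new @multi_result test checks for, where both results are carried as shared_outs and written back through tensor.parallel_insert_slice so that only tiled values are yielded. The function name, the 64x64 tile sizes, and the trivial slice-copy body are illustrative assumptions for this note, not output of the pass or part of the patch.

// Hand-written sketch (not pass output): a forall carrying two results as
// shared_outs. In the real output the inserted tiles would be the tiled
// linalg.matmul and linalg.generic results; here they are plain slices of
// the destinations to keep the example self-contained.
func.func @forall_two_results_sketch(%init0: tensor<64x256xf32>,
                                     %init1: tensor<64x256xf32>)
    -> (tensor<64x256xf32>, tensor<64x256xf32>) {
  %0:2 = scf.forall (%i, %j) = (0, 0) to (64, 256) step (64, 64)
      shared_outs(%out0 = %init0, %out1 = %init1)
      -> (tensor<64x256xf32>, tensor<64x256xf32>) {
    %tile0 = tensor.extract_slice %out0[%i, %j] [64, 64] [1, 1]
        : tensor<64x256xf32> to tensor<64x64xf32>
    %tile1 = tensor.extract_slice %out1[%i, %j] [64, 64] [1, 1]
        : tensor<64x256xf32> to tensor<64x64xf32>
    scf.forall.in_parallel {
      // One parallel_insert_slice per shared_out, i.e. one per yielded result.
      tensor.parallel_insert_slice %tile0 into %out0[%i, %j] [64, 64] [1, 1]
          : tensor<64x64xf32> into tensor<64x256xf32>
      tensor.parallel_insert_slice %tile1 into %out1[%i, %j] [64, 64] [1, 1]
          : tensor<64x64xf32> into tensor<64x256xf32>
    }
  }
  return %0#0, %0#1 : tensor<64x256xf32>, tensor<64x256xf32>
}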