Add support to yield multiple results (iree-org#18717)
When an operation has multiple results, tileDispatchUsingForall yields a mix
of tiled and untiled values. This fix yields one shared_out per result, so
only the tiled versions are yielded; the new multi_result test below
exercises this.
pashu123 authored Oct 8, 2024
1 parent 7fb28e0 commit 0889d13
Showing 2 changed files with 81 additions and 3 deletions.
@@ -306,6 +306,27 @@ static void fuseConsumers(RewriterBase &rewriter, Operation *tiledOp) {
  }
}

/// Starting from `rootOp`, walk its operands backwards to find all
/// potentially fusable operations, i.e. operations that implement the
/// `TilingInterface`. For the test below, starting from the trailing
/// `linalg.generic` this collects the generic itself, the `linalg.matmul`
/// producing its input, and the `linalg.fill` producing the matmul's init.
static void collectTiledAndFusedOps(Operation *rootOp,
                                    llvm::SmallDenseSet<Operation *> &result) {
  SmallVector<Operation *> worklist;
  worklist.push_back(rootOp);
  result.insert(rootOp);
  while (!worklist.empty()) {
    Operation *current = worklist.pop_back_val();
    for (OpOperand &operand : current->getOpOperands()) {
      Operation *producer = operand.get().getDefiningOp();
      if (!producer || !isa<TilingInterface>(producer) ||
          result.count(producer))
        continue;
      worklist.push_back(producer);
      result.insert(producer);
    }
  }
}

void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() {
  auto funcOp = getOperation();
  auto *context = &getContext();
@@ -322,6 +343,18 @@ void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() {
    // Did not find a tileable op. So do nothing.
    return;
  }
  mlir::DominanceInfo dominanceInfo(tilableOp);
  llvm::SmallDenseSet<Operation *> tiledAndFusedOps;
  collectTiledAndFusedOps(tilableOp, tiledAndFusedOps);

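  // Yield a replacement for an op when some user of its result lies outside
  // the region being tiled, i.e. is properly dominated by `tilableOp`; such a
  // user must consume the tiled value via the loop's results.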
  llvm::DenseSet<Operation *> yieldReplacementsFor;
  for (auto op : tiledAndFusedOps) {
    if (llvm::any_of(op->getUsers(), [&](Operation *user) {
          return dominanceInfo.properlyDominates(tilableOp, user);
        })) {
      yieldReplacementsFor.insert(op);
    }
  }

  scf::SCFTilingOptions tilingOptions;
  tilingOptions.setTileSizes(tilingInfo->tileSizes);
@@ -337,9 +370,20 @@ void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() {

  scf::SCFTileAndFuseOptions tileAndFuseOptions;
  tileAndFuseOptions.setTilingOptions(tilingOptions);

  // The control function that determines whether a tiled producer should
  // yield its replacement.
  scf::SCFTileAndFuseOptions::ControlFnTy controlFn =
      [&](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
          bool isDestinationOperand)
      -> std::optional<scf::SCFTileAndFuseOptions::ControlFnResult> {
    Operation *owner = originalProducer.getOwner();
    bool yieldProducerReplacement = yieldReplacementsFor.contains(owner);
    // Returning a ControlFnResult (never std::nullopt) accepts every fusion
    // candidate; the flag only controls whether the fused producer's tiled
    // value is also yielded from the loop.
    return scf::SCFTileAndFuseOptions::ControlFnResult{
        yieldProducerReplacement};
  };
  tileAndFuseOptions.setFusionControlFn(controlFn);
  rewriter.setInsertionPoint(tilableOp);

  // If the `tilableOp` is a `memref` op, then just tile the operation.
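For context, a minimal sketch of how such a control function plugs into
MLIR's upstream tile-and-fuse driver; `tilingOptions`, `yieldReplacementsFor`,
`rewriter`, and `tilableOp` are assumed to be the values computed above, and
the surrounding pass boilerplate is elided:

  // Sketch: wire a fusion control function into the upstream driver.
  scf::SCFTileAndFuseOptions options;
  options.setTilingOptions(tilingOptions);
  options.setFusionControlFn(
      [&](tensor::ExtractSliceOp sliceOp, OpResult producer,
          bool isDestinationOperand)
          -> std::optional<scf::SCFTileAndFuseOptions::ControlFnResult> {
        // Fuse every producer; yield its tiled value from the loop only when
        // the original value still has uses outside the tiled region.
        return scf::SCFTileAndFuseOptions::ControlFnResult{
            yieldReplacementsFor.contains(producer.getOwner())};
      });
  FailureOr<scf::SCFTileAndFuseResult> tileAndFuseResult =
      scf::tileConsumerAndFuseProducersUsingSCF(
          rewriter, cast<TilingInterface>(tilableOp), options);

Each entry in `tileAndFuseResult->replacements` then maps an original result
to the loop result that replaces it outside the `scf.forall`.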
@@ -485,3 +485,37 @@ func.func @matmul_consumer_fusion_test(%arg0 : tensor<?x?xf16>,
// CHECK: scf.forall.in_parallel
// CHECK: tensor.parallel_insert_slice %[[RELU]]
// CHECK: return %[[RESULT]]

// -----

func.func @multi_result(%arg0: tensor<64x128xf32>, %arg1: tensor<128x256xf32>, %arg2: tensor<256xf32>) -> (tensor<64x256xf32>, tensor<64x256xf32>) {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty() : tensor<64x256xf32>
  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<64x256xf32>) -> tensor<64x256xf32>
  %2 = linalg.matmul ins(%arg0, %arg1 : tensor<64x128xf32>, tensor<128x256xf32>) outs(%1 : tensor<64x256xf32>) -> tensor<64x256xf32>
  %3 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d1)>,
                       affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%2, %arg2 : tensor<64x256xf32>, tensor<256xf32>)
      outs(%0 : tensor<64x256xf32>)
      attrs = {lowering_config =
          #iree_codegen.lowering_config<tile_sizes = [[16, 64, 0]]>} {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %4 = arith.addf %in, %in_0 : f32
    linalg.yield %4 : f32
  } -> tensor<64x256xf32>
  return %2, %3 : tensor<64x256xf32>, tensor<64x256xf32>
}
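// Note: %2 feeds both the linalg.generic and the function return, so the
// tiled matmul must be yielded from the scf.forall as a second shared_out
// rather than leaving an untiled matmul outside the loop.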

// CHECK-LABEL: func @multi_result(
// CHECK: %[[RESULT:.+]]:2 = scf.forall (%[[IV0:[a-zA-Z0-9]+]], %[[IV1:[a-zA-Z0-9]+]])
// CHECK-SAME: shared_outs(%[[OUTS0:[a-zA-Z0-9]+]] = {{.*}}, %[[OUTS1:[a-zA-Z0-9]+]] = {{.*}})
// CHECK: linalg.matmul
// CHECK: linalg.generic
// CHECK: scf.forall.in_parallel
// CHECK: tensor.parallel_insert_slice
// CHECK: tensor.parallel_insert_slice
// CHECK: return %[[RESULT]]#1, %[[RESULT]]#0
