Skip to content

Commit

Permalink
[CodeGen] Introduce EraseDeadAllocAndStores pass. (iree-org#14365)
Browse files Browse the repository at this point in the history
It also refactors the implementation to Utils/.
  • Loading branch information
hanhanW authored Jul 13, 2023
1 parent 93e492b commit 554d40a
Show file tree
Hide file tree
Showing 12 changed files with 95 additions and 72 deletions.
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ iree_compiler_cc_library(
"DecomposeConvolutionToLowerDimOps.cpp",
"DecomposeLinalgGeneric.cpp",
"DecomposePackUnPackOps.cpp",
"EraseDeadAllocAndStores.cpp",
"EraseHALDescriptorTypeFromMemRef.cpp",
"ExtractAddressComputation.cpp",
"FlattenMemRefSubspanPass.cpp",
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ iree_cc_library(
"DecomposeConvolutionToLowerDimOps.cpp"
"DecomposeLinalgGeneric.cpp"
"DecomposePackUnPackOps.cpp"
"EraseDeadAllocAndStores.cpp"
"EraseHALDescriptorTypeFromMemRef.cpp"
"ExtractAddressComputation.cpp"
"FlattenMemRefSubspanPass.cpp"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Common/PassDetail.h"
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Pass/Pass.h"

namespace mlir {
namespace iree_compiler {
namespace {
class EraseDeadAllocAndStoresPass
: public EraseDeadAllocAndStoresBase<EraseDeadAllocAndStoresPass> {
public:
using EraseDeadAllocAndStoresBase::EraseDeadAllocAndStoresBase;

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<scf::SCFDialect, vector::VectorDialect>();
}
void runOnOperation() override;
};

void EraseDeadAllocAndStoresPass::runOnOperation() {
auto funcOp = getOperation();
eraseDeadAllocAndStores(funcOp);
}
} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
createEraseDeadAllocAndStoresPass() {
return std::make_unique<EraseDeadAllocAndStoresPass>();
}

} // namespace iree_compiler
} // namespace mlir
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "iree/compiler/Codegen/Common/PassDetail.h"
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
Expand All @@ -19,40 +20,6 @@

namespace mlir {
namespace iree_compiler {

// Return true if all the uses of op are either Store/transfer_write.
// There can be SubviewOp users as long as all its users are also
// StoreOp/transfer_write. If return true it also fills out the uses, if it
// returns false uses is unchanged.
static bool allUsesAreStores(Operation *op, std::vector<Operation *> &uses) {
std::vector<Operation *> opUses;
for (OpOperand &use : op->getUses()) {
Operation *useOp = use.getOwner();
if (isa<vector::TransferWriteOp, memref::StoreOp>(useOp) ||
(isa<memref::SubViewOp>(useOp) && allUsesAreStores(useOp, opUses))) {
opUses.push_back(useOp);
continue;
}
return false;
}
uses.insert(uses.end(), opUses.begin(), opUses.end());
return true;
}

// Track temporary allocations that are never read from. If this is the case
// it means both the allocations and associated stores can be removed.
static void eraseDeadAllocAndStores(func::FuncOp funcOp) {
std::vector<Operation *> opToErase;
funcOp.walk([&](memref::AllocOp op) {
if (allUsesAreStores(op, opToErase)) {
opToErase.push_back(op.getOperation());
}
});
for (Operation *op : opToErase) {
op->erase();
}
}

namespace {

// Pattern to canonialize tranpose where only one dimension is not unit
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ createDecomposePackUnPackOpsPass(bool tileOuterToOne = false);
/// during bufferization.
std::unique_ptr<OperationPass<ModuleOp>> createEliminateEmptyTensorsPass();

/// Creates a pass to erase dead alloc ops where all uses are just store ops.
std::unique_ptr<OperationPass<func::FuncOp>>
createEraseDeadAllocAndStoresPass();

std::unique_ptr<OperationPass<func::FuncOp>>
createEraseHALDescriptorTypeFromMemRefPass();

Expand Down
7 changes: 7 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,13 @@ def EliminateEmptyTensors :
let constructor = "mlir::iree_compiler::createEliminateEmptyTensorsPass()";
}

def EraseDeadAllocAndStores :
Pass<"iree-codegen-erase-dead-alloc-and-stores", "func::FuncOp"> {
let summary = "Erase alloc ops if all the uses are just stores";
let constructor =
"mlir::iree_compiler::createEraseDeadAllocAndStoresPass()";
}

def EraseHALDescriptorTypeFromMemRef :
Pass<"iree-codegen-erase-hal-descriptor-type-from-memref", "func::FuncOp"> {
let summary = "Erase #hal.descriptor_type from MemRef memory space";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,40 +67,6 @@ void mlir::iree_compiler::registerTransformDialectCommonExtension(
mlir::iree_compiler::IREE::transform_dialect::CommonExtensions>();
}

// Return true if all the uses of op are either Store/transfer_write.
// There can be SubviewOp users as long as all its users are also
// StoreOp/transfer_write. If return true it also fills out the uses, if it
// returns false uses is unchanged.
static bool allUsesAreStores(Operation *op, std::vector<Operation *> &uses) {
std::vector<Operation *> opUses;
for (OpOperand &use : op->getUses()) {
Operation *useOp = use.getOwner();
if (isa<memref::DeallocOp, vector::TransferWriteOp, memref::StoreOp>(
useOp) ||
(isa<memref::SubViewOp>(useOp) && allUsesAreStores(useOp, opUses))) {
opUses.push_back(useOp);
continue;
}
return false;
}
uses.insert(uses.end(), opUses.begin(), opUses.end());
return true;
}

// Track temporary allocations that are never read from. If this is the case
// it means both the allocations and associated stores can be removed.
static void eraseDeadAllocAndStores(RewriterBase &rewriter,
Operation *parentOp) {
std::vector<Operation *> opToErase;
parentOp->walk([&](memref::AllocOp op) {
if (allUsesAreStores(op, opToErase)) {
opToErase.push_back(op.getOperation());
}
});
for (Operation *op : opToErase)
rewriter.eraseOp(op);
}

//===---------------------------------------------------------------------===//
// ApplyBufferOptimizationsOp
//===---------------------------------------------------------------------===//
Expand All @@ -112,7 +78,7 @@ transform_dialect::ApplyBufferOptimizationsOp::applyToOne(
transform::TransformState &state) {
// Apply store to load forwarding and dead store elimination.
vector::transferOpflowOpt(rewriter, target);
eraseDeadAllocAndStores(rewriter, target);
eraseDeadAllocAndStores(target);
return DiagnosedSilenceableFailure::success();
}

Expand Down
2 changes: 1 addition & 1 deletion compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ iree_lit_test_suite(
"convert_bf16_arith_to_f32.mlir",
"convert_to_destination_passing_style.mlir",
"convolutions.mlir",
"dead_alloc.mlir",
"erase_dead_alloc_and_stores.mlir",
"decompose_affine_ops.mlir",
"decompose_linalg_generic.mlir",
"decompose_pack_unpack_ops.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ iree_lit_test_suite(
"convert_bf16_to_uint16_buffers.mlir"
"convert_to_destination_passing_style.mlir"
"convolutions.mlir"
"dead_alloc.mlir"
"decompose_affine_ops.mlir"
"decompose_linalg_generic.mlir"
"decompose_pack_unpack_ops.mlir"
"eliminate_empty_tensors.mlir"
"erase_dead_alloc_and_stores.mlir"
"erase_hal_descriptor_type.mlir"
"extract_address_computation.mlir"
"flatten_memref_subspan.mlir"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: iree-opt --iree-codegen-optimize-vector-transfer %s | FileCheck %s
// RUN: iree-opt --iree-codegen-erase-dead-alloc-and-stores %s | FileCheck %s

module {
func.func @dead_alloc() {
Expand Down
33 changes: 33 additions & 0 deletions compiler/src/iree/compiler/Codegen/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/SymbolTable.h"
Expand Down Expand Up @@ -879,5 +880,37 @@ SmallVector<int64_t> getStaticNumWorkgroups(func::FuncOp funcOp) {
return result;
}

// Return true if all the uses of op are either Store/transfer_write.
// There can be SubviewOp users as long as all its users are also
// StoreOp/transfer_write. If return true it also fills out the uses, if it
// returns false uses is unchanged.
static bool allUsesAreStores(Operation *op, std::vector<Operation *> &uses) {
std::vector<Operation *> opUses;
for (OpOperand &use : op->getUses()) {
Operation *useOp = use.getOwner();
if (isa<memref::DeallocOp, vector::TransferWriteOp, memref::StoreOp>(
useOp) ||
(isa<memref::SubViewOp>(useOp) && allUsesAreStores(useOp, opUses))) {
opUses.push_back(useOp);
continue;
}
return false;
}
uses.insert(uses.end(), opUses.begin(), opUses.end());
return true;
}

void eraseDeadAllocAndStores(Operation *parentOp) {
std::vector<Operation *> opToErase;
parentOp->walk([&](memref::AllocOp op) {
if (allUsesAreStores(op, opToErase)) {
opToErase.push_back(op.getOperation());
}
});
for (Operation *op : opToErase) {
op->erase();
}
}

} // namespace iree_compiler
} // namespace mlir
4 changes: 4 additions & 0 deletions compiler/src/iree/compiler/Codegen/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,10 @@ void replaceMemrefUsesAndPropagateType(RewriterBase &rewriter, Location loc,
void sinkOpsInCFG(const SmallVector<Operation *> &allocs,
DominanceInfo &dominators);

// Track temporary allocations that are never read from. If this is the case
// it means both the allocations and associated stores can be removed.
void eraseDeadAllocAndStores(Operation *parentOp);

} // namespace iree_compiler
} // namespace mlir

Expand Down

0 comments on commit 554d40a

Please sign in to comment.