Skip to content

Commit

Permalink
[AMD] re-enable canonicalize pointers
Browse files Browse the repository at this point in the history
  • Loading branch information
makslevental committed Dec 4, 2024
1 parent b7e0601 commit 9390a46
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 19 deletions.
2 changes: 1 addition & 1 deletion third_party/amd/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,8 +250,8 @@ def make_ttgir(mod, metadata, options):
if amd.has_matrix_core_feature(options.arch):
amd.passes.ttgpuir.add_reorder_instructions(pm)

amd.passes.ttgpuir.add_canonicalize_pointers(pm)
if use_buffer_ops:
amd.passes.ttgpuir.add_canonicalize_pointers(pm)
passes.common.add_canonicalizer(pm)
amd.passes.ttgpuir.add_convert_to_buffer_ops(pm)
passes.common.add_canonicalizer(pm)
Expand Down
78 changes: 60 additions & 18 deletions third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -611,28 +611,71 @@ LogicalResult PointerCanonicalizer::rewriteForOp(scf::ForOp forOp,
Location curLoc,
OpOperand *curOperand,
Value &nextPtr) {
size_t operandNum = curOperand->getOperandNumber();
FatPtr fatPtr = pointers[curOperand->get()];
Value offset = fatPtr.offset;
Value basePtr = fatPtr.basePtr;

// Replace the forOp with two additional argument (i.e., the curOperand's
// scalar pointer and the offset)
Value tensorPtr = createTensorPointer(fatPtr, curLoc);
auto newForOp =
replaceForOpWithNewSignature(rewriter, forOp, {basePtr, offset});
size_t operandNum = curOperand->getOperandNumber();
auto numLeadingInitsIncludingCurOp =
operandNum + 1 - forOp.getNumControlOperands();

SmallVector<Value> newInitArgs(
forOp.getInitArgs().take_front(numLeadingInitsIncludingCurOp));
newInitArgs.append({basePtr, offset});
auto trailingInits =
forOp.getInitArgs().drop_front(numLeadingInitsIncludingCurOp);
newInitArgs.append(trailingInits.begin(), trailingInits.end());

scf::ForOp newForOp;
{
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(forOp);
newForOp = rewriter.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
forOp.getStep(), newInitArgs);
newForOp->setAttrs(forOp->getAttrs());
newForOp.getBody()->erase();
newForOp.getRegion().getBlocks().splice(
newForOp.getRegion().getBlocks().begin(),
forOp.getRegion().getBlocks());

newForOp.getBody()->insertArgument(numLeadingInitsIncludingCurOp + 1,
offset.getType(), offset.getLoc());
newForOp.getBody()->insertArgument(numLeadingInitsIncludingCurOp + 1,
basePtr.getType(), basePtr.getLoc());

for (auto [src, tgt] : llvm::zip(
forOp.getResults().take_front(numLeadingInitsIncludingCurOp),
newForOp.getResults().take_front(numLeadingInitsIncludingCurOp)))
rewriter.replaceAllUsesWith(src, tgt);
for (auto [src, tgt] :
llvm::zip(forOp.getResults().drop_front(numLeadingInitsIncludingCurOp),
newForOp.getResults().drop_front(
numLeadingInitsIncludingCurOp + 2)))
rewriter.replaceAllUsesWith(src, tgt);

newForOp.getBody()->getTerminator()->insertOperands(
numLeadingInitsIncludingCurOp,
newForOp.getRegionIterArg(numLeadingInitsIncludingCurOp));
newForOp.getBody()->getTerminator()->insertOperands(
numLeadingInitsIncludingCurOp + 1,
newForOp.getRegionIterArg(numLeadingInitsIncludingCurOp + 1));
}

rewriteOpMap[forOp] = newForOp;

newForOp->setOperand(operandNum, tensorPtr);
{
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(newForOp);
Value tensorPtr = createTensorPointer(fatPtr, curLoc);
newForOp->setOperand(operandNum, tensorPtr);
}

OpOperand *forOperand = &newForOp->getOpOperand(operandNum);
// This is making sure we propagate the visit from the forOp result
nextPtr = newForOp.getTiedLoopResult(forOperand);

// This is making sure we visit the uses within the forOp region
Value arg = newForOp.getTiedLoopRegionIterArg(forOperand);
size_t numIterArgs = newForOp.getNumRegionIterArgs();
pointers[arg] = fatPtr.copy(newForOp.getRegionIterArg(numIterArgs - 2),
newForOp.getRegionIterArg(numIterArgs - 1));
pointers[arg] = fatPtr.copy(basePtr, offset);

// Collect attributes before continuing the visit
collectFatPointerAttributes(newForOp, arg);
Expand All @@ -642,9 +685,9 @@ LogicalResult PointerCanonicalizer::rewriteForOp(scf::ForOp forOp,

// This is setting the fat pointer for the users of the loop
// and then propagate the result
size_t numResults = newForOp->getNumResults();
pointers[nextPtr] = fatPtr.copy(newForOp->getResult(numResults - 2),
newForOp.getResult(numResults - 1));
nextPtr = newForOp.getTiedLoopResult(forOperand);
pointers[nextPtr] = fatPtr.copy(basePtr, offset);

opToDelete.insert(forOp);
return success();
}
Expand All @@ -659,9 +702,6 @@ LogicalResult PointerCanonicalizer::rewriteYieldOp(scf::YieldOp yieldOp,
// IfOp
size_t operandNum = curOperand->getOperandNumber();
FatPtr fatPtr = pointers[curOperand->get()];
yieldOp.getResultsMutable().append(fatPtr.basePtr);
yieldOp.getResultsMutable().append(fatPtr.offset);

if (auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp())) {
yieldOp->setOperand(operandNum, forOp.getRegionIterArg(operandNum));
} else if (auto ifOp = dyn_cast<scf::IfOp>(yieldOp->getParentOp())) {
Expand All @@ -683,6 +723,8 @@ LogicalResult PointerCanonicalizer::rewriteYieldOp(scf::YieldOp yieldOp,

} else if (auto whileOp = resolveOp<scf::WhileOp>(yieldOp->getParentOp(),
rewriteOpMap)) {
yieldOp.getResultsMutable().append(fatPtr.basePtr);
yieldOp.getResultsMutable().append(fatPtr.offset);
// Case 2: the yieldOp is contained within the AfterRegion of a
// WhileOp. In this case, we know that the before region should have
// already been replaced (when we met the WhileOp), hence we can
Expand Down

0 comments on commit 9390a46

Please sign in to comment.