diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index 81b07f2e7d86..fff8421af5a3 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -250,8 +250,8 @@ def make_ttgir(mod, metadata, options): if amd.has_matrix_core_feature(options.arch): amd.passes.ttgpuir.add_reorder_instructions(pm) + amd.passes.ttgpuir.add_canonicalize_pointers(pm) if use_buffer_ops: - amd.passes.ttgpuir.add_canonicalize_pointers(pm) passes.common.add_canonicalizer(pm) amd.passes.ttgpuir.add_convert_to_buffer_ops(pm) passes.common.add_canonicalizer(pm) diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp index fa7d54e0fbee..20ab4246b4aa 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp @@ -611,28 +611,71 @@ LogicalResult PointerCanonicalizer::rewriteForOp(scf::ForOp forOp, Location curLoc, OpOperand *curOperand, Value &nextPtr) { - size_t operandNum = curOperand->getOperandNumber(); FatPtr fatPtr = pointers[curOperand->get()]; Value offset = fatPtr.offset; Value basePtr = fatPtr.basePtr; - // Replace the forOp with two additional argument (i.e., the curOperand's - // scalar pointer and the offset) - Value tensorPtr = createTensorPointer(fatPtr, curLoc); - auto newForOp = - replaceForOpWithNewSignature(rewriter, forOp, {basePtr, offset}); + size_t operandNum = curOperand->getOperandNumber(); + auto numLeadingInitsIncludingCurOp = + operandNum + 1 - forOp.getNumControlOperands(); + + SmallVector newInitArgs( + forOp.getInitArgs().take_front(numLeadingInitsIncludingCurOp)); + newInitArgs.append({basePtr, offset}); + auto trailingInits = + forOp.getInitArgs().drop_front(numLeadingInitsIncludingCurOp); + newInitArgs.append(trailingInits.begin(), trailingInits.end()); + + scf::ForOp newForOp; + { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(forOp); + newForOp = rewriter.create( + forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), + forOp.getStep(), newInitArgs); + newForOp->setAttrs(forOp->getAttrs()); + newForOp.getBody()->erase(); + newForOp.getRegion().getBlocks().splice( + newForOp.getRegion().getBlocks().begin(), + forOp.getRegion().getBlocks()); + + newForOp.getBody()->insertArgument(numLeadingInitsIncludingCurOp + 1, + offset.getType(), offset.getLoc()); + newForOp.getBody()->insertArgument(numLeadingInitsIncludingCurOp + 1, + basePtr.getType(), basePtr.getLoc()); + + for (auto [src, tgt] : llvm::zip( + forOp.getResults().take_front(numLeadingInitsIncludingCurOp), + newForOp.getResults().take_front(numLeadingInitsIncludingCurOp))) + rewriter.replaceAllUsesWith(src, tgt); + for (auto [src, tgt] : + llvm::zip(forOp.getResults().drop_front(numLeadingInitsIncludingCurOp), + newForOp.getResults().drop_front( + numLeadingInitsIncludingCurOp + 2))) + rewriter.replaceAllUsesWith(src, tgt); + + newForOp.getBody()->getTerminator()->insertOperands( + numLeadingInitsIncludingCurOp, + newForOp.getRegionIterArg(numLeadingInitsIncludingCurOp)); + newForOp.getBody()->getTerminator()->insertOperands( + numLeadingInitsIncludingCurOp + 1, + newForOp.getRegionIterArg(numLeadingInitsIncludingCurOp + 1)); + } + rewriteOpMap[forOp] = newForOp; - newForOp->setOperand(operandNum, tensorPtr); + { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(newForOp); + Value tensorPtr = createTensorPointer(fatPtr, curLoc); + newForOp->setOperand(operandNum, tensorPtr); + } + OpOperand *forOperand = &newForOp->getOpOperand(operandNum); // This is making sure we propagate the visit from the forOp result - nextPtr = newForOp.getTiedLoopResult(forOperand); - // This is making sure we visit the uses within the forOp region Value arg = newForOp.getTiedLoopRegionIterArg(forOperand); - size_t numIterArgs = newForOp.getNumRegionIterArgs(); - pointers[arg] = fatPtr.copy(newForOp.getRegionIterArg(numIterArgs - 2), - newForOp.getRegionIterArg(numIterArgs - 1)); + pointers[arg] = fatPtr.copy(basePtr, offset); // Collect attributes before continuing the visit collectFatPointerAttributes(newForOp, arg); @@ -642,9 +685,9 @@ LogicalResult PointerCanonicalizer::rewriteForOp(scf::ForOp forOp, // This is setting the fat pointer for the users of the loop // and then propagate the result - size_t numResults = newForOp->getNumResults(); - pointers[nextPtr] = fatPtr.copy(newForOp->getResult(numResults - 2), - newForOp.getResult(numResults - 1)); + nextPtr = newForOp.getTiedLoopResult(forOperand); + pointers[nextPtr] = fatPtr.copy(basePtr, offset); + opToDelete.insert(forOp); return success(); } @@ -659,9 +702,6 @@ LogicalResult PointerCanonicalizer::rewriteYieldOp(scf::YieldOp yieldOp, // IfOp size_t operandNum = curOperand->getOperandNumber(); FatPtr fatPtr = pointers[curOperand->get()]; - yieldOp.getResultsMutable().append(fatPtr.basePtr); - yieldOp.getResultsMutable().append(fatPtr.offset); - if (auto forOp = dyn_cast(yieldOp->getParentOp())) { yieldOp->setOperand(operandNum, forOp.getRegionIterArg(operandNum)); } else if (auto ifOp = dyn_cast(yieldOp->getParentOp())) { @@ -683,6 +723,8 @@ LogicalResult PointerCanonicalizer::rewriteYieldOp(scf::YieldOp yieldOp, } else if (auto whileOp = resolveOp(yieldOp->getParentOp(), rewriteOpMap)) { + yieldOp.getResultsMutable().append(fatPtr.basePtr); + yieldOp.getResultsMutable().append(fatPtr.offset); // Case 2: the yieldOp is contained within the AfterRegion of a // WhileOp. In this case, we know that the before region should have // already been replaced (when we met the WhileOp), hence we can