From d1f4d6a4b93935ea180400186c4ff5b8bfb20f64 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Mon, 16 Dec 2024 14:58:59 -0500 Subject: [PATCH] handle scf.while (TODO: handle cond arg) --- .../CanonicalizePointers.cpp | 105 +++++++++++++++--- 1 file changed, 88 insertions(+), 17 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp index 2ad8ca7de944..d325e438c902 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp @@ -1329,11 +1329,6 @@ class ConvertSCFYieldOp : public PointerCanonPattern { matchAndRewrite(scf::YieldOp yieldOp, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { SmallVector newYieldedValues = flattenValues(adaptor.getOperands()); - - // Value tensorPtr = - // createTensorPointer(rewriter, fatPtr.basePtr, fatPtr.offset, curLoc, - // fatPtr.canNarrow, fatPtr.attributes); - rewriter.modifyOpInPlace(yieldOp, [&]() { yieldOp.getResultsMutable().clear(); yieldOp.getResultsMutable().append(newYieldedValues); @@ -1400,18 +1395,10 @@ class ConvertSCFForOp : public PointerCanonPattern { LogicalResult matchAndRewrite(scf::ForOp forOp, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - SmallVector> fatPtrInits; SmallVector valRangeLens; ArrayRef remappedInits = adaptor.getInitArgs(); - for (ValueRange remappedInit : remappedInits) { - if (remappedInit.size() == 2) { - Value fatPtrBase = remappedInit[0]; - Value fatPtrOffset = remappedInit[1]; - fatPtrInits.emplace_back(fatPtrBase, fatPtrOffset); - } + for (ValueRange remappedInit : remappedInits) valRangeLens.push_back(remappedInit.size()); - } - TypeConverter hackTypeConverter; unsigned inputNo = 0; hackTypeConverter.addConversion( @@ -1436,7 +1423,6 @@ class ConvertSCFForOp : public PointerCanonPattern { forOp.getLoc(), getSingleValue(adaptor.getLowerBound()), getSingleValue(adaptor.getUpperBound()), getSingleValue(adaptor.getStep()), initArgs); - // replaceWithAdditionalYields newForOp->setAttrs(forOp->getAttrs()); rewriter.eraseBlock(newForOp.getBody(0)); @@ -1504,6 +1490,88 @@ class ConvertSCFIfOp : public PointerCanonPattern { } }; +class ConvertSCFWhileOp : public PointerCanonPattern { +public: + using PointerCanonPattern::PointerCanonPattern; + LogicalResult + matchAndRewrite(scf::WhileOp whileOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector valRangeLens; + ArrayRef remappedInits = adaptor.getInits(); + for (ValueRange remappedInit : remappedInits) + valRangeLens.push_back(remappedInit.size()); + TypeConverter hackTypeConverter; + unsigned inputNo = 0; + hackTypeConverter.addConversion( + [&inputNo, &remappedInits = std::as_const(remappedInits)]( + Type inputType, SmallVectorImpl &types) { + SmallVector remappedInitTypes = + llvm::to_vector(remappedInits[inputNo].getTypes()); + types.append(remappedInitTypes); + inputNo++; + return success(); + }); + if (failed(rewriter.convertRegionTypes(&whileOp.getBefore(), + hackTypeConverter))) + return failure(); + if (failed(rewriter.convertRegionTypes(&whileOp.getAfter(), + hackTypeConverter))) + return failure(); + + SmallVector initArgs = flattenValues(remappedInits); + SmallVector resultTypes = + llvm::map_to_vector(initArgs, [](Value v) { return v.getType(); }); + auto newWhileOp = + rewriter.create(whileOp.getLoc(), resultTypes, initArgs); + + newWhileOp->setAttrs(whileOp->getAttrs()); + rewriter.inlineRegionBefore(whileOp.getBefore(), newWhileOp.getBefore(), + newWhileOp.getBefore().end()); + rewriter.inlineRegionBefore(whileOp.getAfter(), newWhileOp.getAfter(), + newWhileOp.getAfter().end()); + + SmallVector packedRets; + for (unsigned i = 0, offset = 0; i < valRangeLens.size(); i++) { + size_t len = valRangeLens[i]; + assert(offset < newWhileOp->getNumResults() && + "expected offset to be within bounds of results"); + ValueRange mappedValue = newWhileOp->getResults().slice(offset, len); + packedRets.push_back(mappedValue); + offset += len; + } + + rewriter.modifyOpInPlace(whileOp, [&] { + whileOp->setDiscardableAttr("rewritten", rewriter.getUnitAttr()); + }); + rewriter.modifyOpInPlace(newWhileOp, [&] { + newWhileOp->setDiscardableAttr("legal", rewriter.getUnitAttr()); + }); + rewriter.replaceOpWithMultiple(whileOp, packedRets); + + return success(); + } +}; + +class ConvertSCFConditionOp : public PointerCanonPattern { +public: + using PointerCanonPattern::PointerCanonPattern; + LogicalResult + matchAndRewrite(scf::ConditionOp condOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector newArgs = flattenValues(adaptor.getArgs()); + rewriter.modifyOpInPlace(condOp, [&]() { + condOp.getArgsMutable().clear(); + condOp.getArgsMutable().append(newArgs); + }); + + rewriter.modifyOpInPlace(condOp, [&] { + condOp->setDiscardableAttr("legal", rewriter.getUnitAttr()); + }); + + return success(); + } +}; + void TritonAMDGPUCanonicalizePointersPass::runOnOperation() { ModuleOp module = getOperation(); mlir::MLIRContext *context = &getContext(); @@ -1531,14 +1599,17 @@ void TritonAMDGPUCanonicalizePointersPass::runOnOperation() { if (auto ifOp = llvm::dyn_cast(op)) return !(ifOp->hasAttr("then_rewritten") and ifOp->hasAttr("else_rewritten")); + if (llvm::isa(op) && !op->hasAttr("legal")) + return false; return isLegal(op); }); FatPointers fatPrs; patterns.add( - patterns.getContext(), fatPrs); + ConvertSCFForOp, ConvertSCFYieldOp, ConvertSCFIfOp, + ConvertSCFConditionOp, ConvertSCFWhileOp>(patterns.getContext(), + fatPrs); ConversionConfig config; config.buildMaterializations = false; if (failed(