diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp
index 8b83231e85f6..8ab4780d2cba 100644
--- a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp
@@ -473,8 +473,9 @@ Value createTensorPointer(
   if (canNarrow)
     offset = createNarrow64bitOffsetTo32bits(rewriter, loc, offset);
 
-  Value tensorPtr =
-      rewriter.create<triton::SplatOp>(loc, tensorPtrType, basePtr);
+  Value tensorPtr = rewriter.create<triton::SplatOp>(
+      loc, TypeRange{tensorPtrType}, ValueRange{basePtr},
+      SmallVector<NamedAttribute>{rewriter.getNamedAttr("legal", rewriter.getUnitAttr())});
 
   auto addPtrOp =
       rewriter.create<triton::AddPtrOp>(loc, tensorPtrType, tensorPtr, offset);
@@ -1049,9 +1050,91 @@ class TritonAMDGPUCanonicalizePointersPass
   void runOnOperation() override;
 };
 
-class ConvertAddPtrOp : public OpConversionPattern<triton::AddPtrOp> {
+// Per-fat-pointer bookkeeping, keyed by the (base, offset) pair a pointer has
+// been decomposed into: whether the 64-bit offset can be narrowed to 32 bits,
+// and any attributes to propagate from the original tensor pointer.
+struct FatPointers {
+  struct FatPtr {
+    bool canNarrow = false;
+    llvm::SmallDenseMap<StringAttr, Attribute> attributes;
+  };
+  using KeyT = std::pair<Value, Value>;
+  using ValueT = FatPtr;
+  using DenseMapT = DenseMap<KeyT, ValueT>;
+  DenseMapT pointers;
+  ValueT &operator[](const KeyT &k) { return pointers[k]; }
+  ValueT &operator[](KeyT &&k) { return pointers[k]; }
+  template <typename T>
+  using const_arg_type_t = typename llvm::const_pointer_or_const_ref<T>::type;
+  const ValueT &at(const_arg_type_t<KeyT> k) const { return pointers.at(k); }
+};
+
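+// Returns the unrealized_conversion_cast op that repacks the given
+// (base, offset) pair into a single fat pointer value, if one exists. The
+// debug-build walk over the offset's users double-checks that both walks
+// find the same cast op.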
+std::optional<UnrealizedConversionCastOp> getFatPtrCastOp(Value base,
+                                                          Value offset) {
+  std::optional<UnrealizedConversionCastOp> maybeCastOp;
+  for (Operation *user : base.getUsers()) {
+    if (auto castOp = llvm::dyn_cast<UnrealizedConversionCastOp>(user)) {
+      if (castOp.getNumOperands() == 2 && castOp.getOperand(0) == base &&
+          castOp.getOperand(1) == offset) {
+        maybeCastOp = castOp;
+      }
+    }
+  }
+#ifndef NDEBUG
+  for (Operation *user : offset.getUsers()) {
+    if (auto castOp = llvm::dyn_cast<UnrealizedConversionCastOp>(user)) {
+      if (castOp.getNumOperands() == 2 && castOp.getOperand(0) == base &&
+          castOp.getOperand(1) == offset) {
+        assert(
+            castOp == *maybeCastOp &&
+            "expected castop through base and castop through offset to match");
+      }
+    }
+  }
+#endif
+  return maybeCastOp;
+}
+
+std::optional<UnrealizedConversionCastOp> getFatPtrCastOp(OpOperand &operand) {
+  Value operandVal = operand.get();
+  for (Operation *user : operandVal.getUsers()) {
+    if (auto castOp = llvm::dyn_cast<UnrealizedConversionCastOp>(user)) {
+      if (castOp.getNumOperands() == 2 &&
+          (castOp.getOperand(0) == operandVal ||
+           castOp.getOperand(1) == operandVal) &&
+          castOp.getNumResults() == 1 &&
+          std::distance(castOp->getUsers().begin(),
+                        castOp->getUsers().end()) == 1 &&
+          *castOp->getUsers().begin() == operand.getOwner()) {
+        return castOp;
+      }
+    }
+  }
+  return {};
+}
+
+/// Flatten the given value ranges into a single vector of values.
+static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
+  SmallVector<Value> result;
+  for (const ValueRange &vals : values)
+    llvm::append_range(result, vals);
+  return result;
+}
+
+/// Assert that the given value range contains a single value and return it.
+static Value getSingleValue(ValueRange values) {
+  assert(values.size() == 1 && "expected single value");
+  return values.front();
+}
+
+// Base pattern for this pass: every pattern shares the pass-wide FatPointers
+// map so that rewrites can record and look up the bookkeeping of the
+// (base, offset) pairs they produce and consume.
+template <typename SourceOp>
+struct PointerCanonPattern : OpConversionPattern<SourceOp> {
+  PointerCanonPattern(MLIRContext *context, FatPointers &fatPtrs,
+                      PatternBenefit benefit = 1)
+      : OpConversionPattern<SourceOp>(context, benefit), fatPtrs(fatPtrs) {}
+  FatPointers &fatPtrs;
+};
+
+class ConvertAddPtrOp : public PointerCanonPattern<triton::AddPtrOp> {
 public:
-  using OpConversionPattern::OpConversionPattern;
+  using PointerCanonPattern::PointerCanonPattern;
 
   LogicalResult
   matchAndRewrite(triton::AddPtrOp addPtrOp, OneToNOpAdaptor adaptor,
@@ -1061,11 +1144,11 @@ class ConvertAddPtrOp : public OpConversionPattern<triton::AddPtrOp> {
     ArrayRef<ValueRange> remappedOperands = adaptor.getOperands();
     assert(remappedOperands.size() == 2 && remappedOperands[0].size() == 2 &&
-           "expected adaptor to have 2,1 remapped values");
+           "expected adaptor to have 2 remapped values");
     Value fatPtrBase = remappedOperands[0][0];
     Value fatPtrOffset = remappedOperands[0][1];
     Value origOffset = remappedOperands[1][0];
-    auto curLoc = addPtrOp.getLoc();
+    Location curLoc = addPtrOp.getLoc();
 
     // If it is a scalar pointer update, simply bump the base pointer
     if (!isa<RankedTensorType>(addPtrOp.getPtr().getType())) {
@@ -1094,10 +1177,8 @@ class ConvertAddPtrOp : public OpConversionPattern<triton::AddPtrOp> {
       rewriter.replaceOpWithMultiple(addPtrOp, {{newAddPtrOp, fatPtrOffset}});
       // If we are updating the tensor pointer with a uniform value, we can
       // propagate the attributes of the tensor pointer to the fat pointer.
-      // TODO(max): re-enable
-      // for (auto attribute : fatPtr.attributes)
-      //   pointers[nextPtr].setAttr(attribute.getFirst(), attribute.getSecond());
-      // opToDelete.insert(addPtrOp);
+      fatPtrs[{newAddPtrOp.getResult(), fatPtrOffset}].attributes =
+          fatPtrs[{fatPtrBase, fatPtrOffset}].attributes;
       return success();
     }
 
@@ -1107,10 +1188,8 @@ class ConvertAddPtrOp : public OpConversionPattern<triton::AddPtrOp> {
         decomposeOffsetFromExpr(rewriter, curLoc, origOffset, bitness);
 
     // Vector offset update (if any): bump the tensor offset
-    // TODO(max): stash somewhere
    bool canNarrow = false;
-    bool propagateAtrs = false;
-
+    bool propagateAtrs = true;
     Value newOffset = fatPtrOffset;
     if (!isZeroConst(nonUniformOffset)) {
       Type addPtrOffsetType = getElementTypeOrSelf(nonUniformOffset);
@@ -1139,22 +1218,19 @@ class ConvertAddPtrOp : public OpConversionPattern<triton::AddPtrOp> {
       addPtrOp->setDiscardableAttr("rewritten", rewriter.getUnitAttr());
     });
     rewriter.replaceOpWithMultiple(addPtrOp, {{newAddPtrOp, newOffset}});
-    //
-    //    // If we are updating the tensor pointer with a uniform value, we can
-    //    // propagate the attributes of the tensor pointer to the fat pointer.
-    // TODO(max): re-enable
-    // if (propagateAtrs)
-    //   for (auto attribute : fatPtr.attributes)
-    //     pointers[nextPtr].setAttr(attribute.getFirst(),
-    //                               attribute.getSecond());
+    // If we are updating the tensor pointer with a uniform value, propagate
+    // the attributes of the tensor pointer to the fat pointer.
+    auto nextFatPtr = std::pair{newAddPtrOp.getResult(), newOffset};
+    fatPtrs[nextFatPtr].canNarrow = canNarrow;
+    if (propagateAtrs)
+      fatPtrs[nextFatPtr].attributes =
+          fatPtrs.at({fatPtrBase, fatPtrOffset}).attributes;
     return success();
   }
 };
 
-class ConvertSplatOp : public OpConversionPattern<triton::SplatOp> {
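+// Splat of a scalar fat pointer: only the offset is splatted to the result
+// shape; the base stays scalar, and the full tensor of pointers is only
+// materialized by the ops that need it (e.g., ConvertLoadOp below via
+// createTensorPointer).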
+class ConvertSplatOp : public PointerCanonPattern<triton::SplatOp> {
 public:
-  using OpConversionPattern::OpConversionPattern;
+  using PointerCanonPattern::PointerCanonPattern;
 
   LogicalResult
   matchAndRewrite(triton::SplatOp splatOp, OneToNOpAdaptor adaptor,
@@ -1171,8 +1247,8 @@ class ConvertSplatOp : public OpConversionPattern<triton::SplatOp> {
     assert(llvm::isa<IntegerType>(fatPtrOffset.getType()) &&
            "expected fatPtrOffset to be an integer type");
 
-    auto outType = splatOp.getResult().getType();
-    auto ptrShape = outType.getShape();
+    RankedTensorType outType = splatOp.getResult().getType();
+    llvm::ArrayRef<int64_t> ptrShape = outType.getShape();
     auto newOffsetType = RankedTensorType::get(ptrShape, fatPtrOffset.getType(),
                                                outType.getEncoding());
     Value offset = rewriter.create<triton::SplatOp>(
@@ -1186,38 +1262,37 @@ class ConvertSplatOp : public OpConversionPattern<triton::SplatOp> {
   }
 };
 
-class ConvertLoadOp : public OpConversionPattern<triton::LoadOp> {
+class ConvertLoadOp : public PointerCanonPattern<triton::LoadOp> {
 public:
-  using OpConversionPattern::OpConversionPattern;
+  using PointerCanonPattern::PointerCanonPattern;
 
   LogicalResult
   matchAndRewrite(triton::LoadOp loadOp, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto fatPtr = *adaptor.getOperands().begin();
+    ValueRange fatPtr = *adaptor.getOperands().begin();
     Value fatPtrBase = fatPtr.front();
     Value fatPtrOffset = fatPtr.back();
     Location curLoc = loadOp.getLoc();
 
-    llvm::SmallDenseMap<StringAttr, Attribute> attributes{};
-    auto newPtr =
-        createTensorPointer(rewriter, fatPtrBase, fatPtrOffset, curLoc,
-                            // TODO(max):
-                            /*canNarrow*/ true, attributes);
+    llvm::SmallDenseMap<StringAttr, Attribute> attributes{
+        {rewriter.getStringAttr("legal"), rewriter.getUnitAttr()}};
+    Value newPtr = createTensorPointer(
+        rewriter, fatPtrBase, fatPtrOffset, curLoc,
+        fatPtrs[{fatPtrBase, fatPtrOffset}].canNarrow, attributes);
     SmallVector<Value> operands =
         loadOp.getOperands().take_back(loadOp.getNumOperands() - 1);
     operands.insert(operands.begin(), newPtr);
-    auto newLoadPtrOp = rewriter.replaceOpWithNewOp<triton::LoadOp>(
-        loadOp, operands, loadOp->getAttrs());
-    rewriter.modifyOpInPlace(newLoadPtrOp, [&] {
-      newLoadPtrOp->setDiscardableAttr("legal", rewriter.getUnitAttr());
-    });
+    SmallVector<NamedAttribute> attrs = llvm::to_vector(loadOp->getAttrs());
+    attrs.append({rewriter.getNamedAttr("legal", rewriter.getUnitAttr())});
+    auto newLoadPtrOp =
+        rewriter.replaceOpWithNewOp<triton::LoadOp>(loadOp, operands, attrs);
     return success();
   }
 };
 
-class ConvertFuncOp : public OpConversionPattern<triton::FuncOp> {
+class ConvertFuncOp : public PointerCanonPattern<triton::FuncOp> {
 public:
-  using OpConversionPattern::OpConversionPattern;
+  using PointerCanonPattern::PointerCanonPattern;
 
   LogicalResult
   matchAndRewrite(triton::FuncOp funcOp, OpAdaptor adaptor,
@@ -1246,15 +1321,35 @@ class ConvertFuncOp : public OpConversionPattern<triton::FuncOp> {
   }
 };
 
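+// Rewrite scf.yield over the flattened operand ranges so that a yielded fat
+// pointer becomes two yielded values (base and offset); the original op is
+// marked "rewritten" and its replacement "legal" for the dynamic legality
+// check in runOnOperation.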
+class ConvertSCFYieldOp : public PointerCanonPattern<scf::YieldOp> {
+public:
+  using PointerCanonPattern::PointerCanonPattern;
+
+  LogicalResult
+  matchAndRewrite(scf::YieldOp yieldOp, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.modifyOpInPlace(yieldOp, [&] {
+      yieldOp->setDiscardableAttr("rewritten", rewriter.getUnitAttr());
+    });
+    auto newYieldOp = rewriter.replaceOpWithNewOp<scf::YieldOp>(
+        yieldOp, flattenValues(adaptor.getOperands()));
+    rewriter.modifyOpInPlace(newYieldOp, [&] {
+      newYieldOp->setDiscardableAttr("legal", rewriter.getUnitAttr());
+    });
+    return success();
+  }
+};
+
 class ConvertUnrealizedConversionCastOp
-    : public OpConversionPattern<UnrealizedConversionCastOp> {
+    : public PointerCanonPattern<UnrealizedConversionCastOp> {
 public:
-  using OpConversionPattern::OpConversionPattern;
+  using PointerCanonPattern::PointerCanonPattern;
 
   LogicalResult
   matchAndRewrite(UnrealizedConversionCastOp castOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     assert(castOp->hasOneUse() && "expected at least 1 use of unrealized_cast");
+    // Note: spelling this with `auto` triggers a -Wdangling warning; bind the
+    // ArrayRef explicitly.
     ArrayRef<ValueRange> remappedOperands = adaptor.getOperands();
     assert(remappedOperands.size() == 1 && remappedOperands[0].size() == 2 &&
            "expected adaptor to have 2 remapped values");
@@ -1277,30 +1372,105 @@ class ConvertUnrealizedConversionCastOp
   }
 };
 
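+// Rewrite scf.for by flattening any fat-pointer init args into (base, offset)
+// pairs. Region types are converted 1:N with a per-input callback (input 0 is
+// the induction variable and is passed through unchanged), and the results
+// are repacked into ranges of the original arity for replaceOpWithMultiple.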
+class ConvertSCFForOp : public PointerCanonPattern<scf::ForOp> {
+  using PointerCanonPattern::PointerCanonPattern;
+
+public:
+  LogicalResult
+  matchAndRewrite(scf::ForOp forOp, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    SmallVector<std::pair<Value, Value>> fatPtrInits;
+    SmallVector<size_t> valRangeLens;
+    ArrayRef<ValueRange> remappedInits = adaptor.getInitArgs();
+    for (ValueRange remappedInit : remappedInits) {
+      if (remappedInit.size() == 2) {
+        Value fatPtrBase = remappedInit[0];
+        Value fatPtrOffset = remappedInit[1];
+        fatPtrInits.emplace_back(fatPtrBase, fatPtrOffset);
+      }
+      valRangeLens.push_back(remappedInit.size());
+    }
+
+    TypeConverter hackTypeConverter;
+    unsigned inputNo = 0;
+    hackTypeConverter.addConversion(
+        [&inputNo, &remappedInits = std::as_const(remappedInits)](
+            Type inputType, SmallVectorImpl<Type> &types) {
+          // Handle the 0th block argument, i.e., the induction variable.
+          if (inputNo == 0) {
+            types.append({inputType});
+          } else {
+            SmallVector<Type> remappedInitTypes =
+                llvm::to_vector(remappedInits[inputNo - 1].getTypes());
+            types.append(remappedInitTypes);
+          }
+          inputNo++;
+          return success();
+        });
+    if (failed(rewriter.convertRegionTypes(&forOp.getRegion(),
+                                           hackTypeConverter)))
+      return failure();
+    SmallVector<Value> initArgs = flattenValues(adaptor.getInitArgs());
+    auto newForOp = rewriter.create<scf::ForOp>(
+        forOp.getLoc(), getSingleValue(adaptor.getLowerBound()),
+        getSingleValue(adaptor.getUpperBound()),
+        getSingleValue(adaptor.getStep()), initArgs);
+    // TODO: consider scf::ForOp::replaceWithAdditionalYields instead of
+    // inlining the region by hand.
+
+    newForOp->setAttrs(forOp->getAttrs());
+    rewriter.eraseBlock(newForOp.getBody(0));
+    rewriter.inlineRegionBefore(forOp.getRegion(), newForOp.getRegion(),
+                                newForOp.getRegion().end());
+
+    // Repack the flat result list into ranges matching the lengths of the
+    // remapped init args so replaceOpWithMultiple can do a 1:N replacement.
+    SmallVector<ValueRange> packedRets;
+    for (unsigned i = 0, offset = 0; i < valRangeLens.size(); i++) {
+      size_t len = valRangeLens[i];
+      assert(offset < newForOp->getNumResults() &&
+             "expected offset to be within bounds of results");
+      ValueRange mappedValue = newForOp->getResults().slice(offset, len);
+      packedRets.push_back(mappedValue);
+      offset += len;
+    }
+
+    rewriter.modifyOpInPlace(forOp, [&] {
+      forOp->setDiscardableAttr("rewritten", rewriter.getUnitAttr());
+    });
+    rewriter.modifyOpInPlace(newForOp, [&] {
+      newForOp->setDiscardableAttr("legal", rewriter.getUnitAttr());
+    });
+    rewriter.replaceOpWithMultiple(forOp, packedRets);
+
+    return success();
+  }
+};
+
 void TritonAMDGPUCanonicalizePointersPass::runOnOperation() {
   ModuleOp module = getOperation();
-  auto *context = &getContext();
+  mlir::MLIRContext *context = &getContext();
   ConversionTarget target(*context);
   RewritePatternSet patterns(context);
   target.addLegalDialect<arith::ArithDialect>();
-  target.addDynamicallyLegalDialect<triton::TritonDialect>([](Operation *op) {
+  auto isLegal = [](Operation *op) {
     if (llvm::isa<triton::FuncOp>(op) && !op->hasAttr("rewritten"))
       return false;
-
     if (op->hasAttr("rewritten") || op->hasAttr("legal"))
       return true;
-    for (auto operand : op->getOperands()) {
-      if (llvm::isa<BlockArgument>(operand))
+    for (OpOperand &operand : op->getOpOperands()) {
+      if (llvm::isa<BlockArgument>(operand.get()))
        return false;
-      if (operand.getDefiningOp()->hasAttr("rewritten"))
+      if (operand.get().getDefiningOp()->hasAttr("rewritten"))
        return false;
    }
-    return true;
-  });
+    return true;
+  };
+  target.addDynamicallyLegalDialect<triton::TritonDialect>(isLegal);
+  target.addDynamicallyLegalDialect<scf::SCFDialect>(isLegal);
+
+  FatPointers fatPtrs;
 
-  patterns.add<ConvertFuncOp, ConvertSplatOp, ConvertAddPtrOp, ConvertLoadOp>(
-      patterns.getContext());
+  patterns.add<ConvertFuncOp, ConvertSplatOp, ConvertAddPtrOp, ConvertLoadOp,
+               ConvertSCFForOp, ConvertSCFYieldOp>(patterns.getContext(),
+                                                   fatPtrs);
   ConversionConfig config;
   config.buildMaterializations = false;
   if (failed(
@@ -1309,7 +1479,8 @@ void TritonAMDGPUCanonicalizePointersPass::runOnOperation() {
 
   patterns.clear();
   target.addIllegalOp<UnrealizedConversionCastOp>();
-  patterns.add<ConvertUnrealizedConversionCastOp>(patterns.getContext());
+  patterns.add<ConvertUnrealizedConversionCastOp>(patterns.getContext(),
+                                                  fatPtrs);
   if (failed(
           applyPartialConversion(module, target, std::move(patterns), config)))
     return signalPassFailure();