propagate canNarrow through scf.if, scf.forop, scf.while, cf.br
makslevental committed Dec 18, 2024
1 parent dc314dd commit 1b6bd72
Showing 1 changed file with 192 additions and 60 deletions.
252 changes: 192 additions & 60 deletions third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp
@@ -263,12 +263,15 @@ std::pair<Value, Value> createDecomposeOffsetFromExpr(RewriterBase &rewriter,
return offsets;
}

static const char kLegalAttr[] = "__amd-pointer-canonicalize-legal__";
static const char kRewrittenAttr[] = "__amd-pointer-canonicalize-rewritten__";
static const char kSCFThenRewrittenAttr[] =
"__amd-pointer-canonicalize-scf-then-rewritten__";
static const char kSCFElseRewrittenAttr[] =
"__amd-pointer-canonicalize-scf-else-rewritten__";
static const std::string kPtrCanonPrefix = "__amdpointercanonicalize.";
static const std::string kLegalAttr = kPtrCanonPrefix + "legal__";
static const std::string kRewrittenAttr = kPtrCanonPrefix + "rewritten__";
static const std::string kSCFThenRewrittenAttr =
kPtrCanonPrefix + "scf-then-rewritten__";
static const std::string kSCFElseRewrittenAttr =
kPtrCanonPrefix + "scf-else-rewritten__";
static const std::string kSCFIfOpYieldFatPtrOffsets =
kPtrCanonPrefix + "scf-if-yield-fatptr-offsets__";

Value createTensorPointer(
RewriterBase &rewriter, Value basePtr, Value offset, Location loc,
@@ -340,6 +343,7 @@ struct FatPointers {
template <typename T>
using const_arg_type_t = typename llvm::const_pointer_or_const_ref<T>::type;
const ValueT &at(const_arg_type_t<KeyT> k) const { return pointers.at(k); }
bool contains(const KeyT &k) const { return pointers.contains(k); }

private:
DenseMapT pointers;
@@ -636,8 +640,22 @@ class ConvertSCFForOp : public PointerCanonicalizationPattern<scf::ForOp> {
hackTypeConverter.convertBlockSignature(forOp.getBody());
if (!conversion)
return failure();
rewriter.applySignatureConversion(forOp.getBody(), *conversion,
&hackTypeConverter);
auto newBodyBlock = rewriter.applySignatureConversion(
forOp.getBody(), *conversion, &hackTypeConverter);

// propagate canNarrow to bb arg fatPtrs in for body bb
// skip iv at index 0
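// a remapped init of width 2 is a decomposed fat pointer (base, offset);
// width 1 means the init was left untouched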
int offset = 1;
for (auto operands : remappedInits) {
if (operands.size() == 2) {
assert(fatPtrs.contains({operands[0], operands[1]}) &&
"expected fatPtrs to contain remapped fat pointer");
fatPtrs[{newBodyBlock->getArgument(offset),
newBodyBlock->getArgument(offset + 1)}]
.canNarrow = fatPtrs[{operands[0], operands[1]}].canNarrow;
}
offset += operands.size();
}

SmallVector<Value> initArgs = flattenValues(adaptor.getInitArgs());
auto newForOp = rewriter.create<scf::ForOp>(
@@ -676,7 +694,8 @@ class ConvertSCFYieldOp : public PointerCanonicalizationPattern<scf::YieldOp> {
LogicalResult
matchAndRewrite(scf::YieldOp yieldOp, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Value> newYieldedValues = flattenValues(adaptor.getOperands());
ArrayRef<ValueRange> remappedYields = adaptor.getOperands();
SmallVector<Value> newYieldedValues = flattenValues(remappedYields);
// have to mutate here because otherwise scf.if, scf.for, and scf.while will
// get confused about which yield is the "correct" yield (since there will
// be two of them before the rewriter DCEs)
@@ -689,16 +708,33 @@ class ConvertSCFYieldOp : public PointerCanonicalizationPattern<scf::YieldOp> {
// other to indicate to the parent IfOp that the result type can now be
// rewritten and not before.
if (auto ifOp = dyn_cast<scf::IfOp>(yieldOp->getParentOp())) {
if (ifOp.thenBlock() == yieldOp->getBlock())
if (ifOp.thenBlock() == yieldOp->getBlock()) {
rewriter.modifyOpInPlace(ifOp, [&] {
ifOp->setDiscardableAttr(kSCFThenRewrittenAttr,
rewriter.getUnitAttr());
});
else
} else {
rewriter.modifyOpInPlace(ifOp, [&] {
ifOp->setDiscardableAttr(kSCFElseRewrittenAttr,
rewriter.getUnitAttr());
});
}
// set indices of fatPtrs so that IfOp can propagate canNarrow to
// result users
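// each recorded offset is the position of a fat pointer's base in the
// flattened yield operands; its offset value sits at position + 1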
int offset = 0;
SmallVector<int64_t> fatPtrOffsets;
for (auto operands : remappedYields) {
if (operands.size() == 2) {
assert(fatPtrs.contains({operands[0], operands[1]}) &&
"expected fatPtrs to contain remapped fat pointer");
fatPtrOffsets.push_back(offset);
}
offset += operands.size();
}
if (!fatPtrOffsets.empty())
yieldOp->setDiscardableAttr(
kSCFIfOpYieldFatPtrOffsets,
rewriter.getDenseI64ArrayAttr(fatPtrOffsets));
}

setLegalAttr(rewriter, yieldOp);
Expand All @@ -717,26 +753,52 @@ class ConvertSCFWhileOp : public PointerCanonicalizationPattern<scf::WhileOp> {
ArrayRef<ValueRange> remappedInits = adaptor.getInits();
for (ValueRange remappedInit : remappedInits)
valRangeLens.push_back(remappedInit.size());
// rewrite the "before" region (bb args)
TypeConverter hackTypeConverter;

// rewrite the "before" block bb args
unsigned inputNo = 0;
TypeConverter hackTypeConverter;
hackTypeConverter.addConversion(
[&inputNo, &remappedInits = std::as_const(remappedInits)](
Type inputType, SmallVectorImpl<Type> &types) {
[&inputNo, remappedInits](Type _inputType,
SmallVectorImpl<Type> &types) {
SmallVector<Type> remappedInitTypes =
llvm::to_vector(remappedInits[inputNo].getTypes());
types.append(remappedInitTypes);
inputNo++;
return success();
});
if (failed(rewriter.convertRegionTypes(&whileOp.getBefore(),
hackTypeConverter)))
std::optional<TypeConverter::SignatureConversion> conversion =
hackTypeConverter.convertBlockSignature(whileOp.getBeforeBody());
if (!conversion)
return failure();
// rewrite the "after" region (bb args)
inputNo = 0;
if (failed(rewriter.convertRegionTypes(&whileOp.getAfter(),
hackTypeConverter)))
auto newBeforeBodyBlock = rewriter.applySignatureConversion(
whileOp.getBeforeBody(), *conversion, &hackTypeConverter);

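// copy canNarrow from each remapped (base, offset) init pair to the
// corresponding pair of block arguments in the given block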
auto propagateCanNarrowToBlock = [remappedInits, this](Block *block) {
int offset = 0;
for (auto operands : remappedInits) {
if (operands.size() == 2) {
assert(fatPtrs.contains({operands[0], operands[1]}) &&
"expected fatPtrs to contain remapped fat pointer");
fatPtrs[{block->getArgument(offset), block->getArgument(offset + 1)}]
.canNarrow = fatPtrs[{operands[0], operands[1]}].canNarrow;
}
offset += operands.size();
}
};

// propagate canNarrow to bb arg fatPtrs in before bb
propagateCanNarrowToBlock(newBeforeBodyBlock);

// rewrite the "after" block bb args
conversion =
hackTypeConverter.convertBlockSignature(whileOp.getAfterBody());
if (!conversion)
return failure();
auto newAfterBodyBlock = rewriter.applySignatureConversion(
whileOp.getAfterBody(), *conversion, &hackTypeConverter);

// propagate canNarrow to bb arg fatPtrs in after bb
propagateCanNarrowToBlock(newAfterBodyBlock);

SmallVector<Value> initArgs = flattenValues(remappedInits);
SmallVector<Type> resultTypes =
@@ -796,54 +858,75 @@ class ConvertCFCondBranch
LogicalResult
matchAndRewrite(cf::CondBranchOp branchOp, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Value> trueOperands =
flattenValues(adaptor.getTrueDestOperands());
SmallVector<Value> falseOperands =
flattenValues(adaptor.getFalseDestOperands());
ArrayRef<ValueRange> remappedTrueOperands = adaptor.getTrueDestOperands();
ArrayRef<ValueRange> remappedFalseOperands = adaptor.getFalseDestOperands();
SmallVector<Value> trueOperands = flattenValues(remappedTrueOperands);
SmallVector<Value> falseOperands = flattenValues(remappedFalseOperands);

setRewrittenAttr(rewriter, branchOp);
auto newBranchOp = rewriter.replaceOpWithNewOp<cf::CondBranchOp>(
branchOp, branchOp.getCondition(), branchOp.getTrueDest(), trueOperands,
branchOp.getFalseDest(), falseOperands);

// inputNo must live in the caller: the returned conversion lambda captures it
// by reference, so a local defined inside makeTypeConv would dangle as soon as
// makeTypeConv returns
auto makeTypeConv = [](unsigned &inputNo,
ArrayRef<ValueRange> remappedOperands) {
return [&inputNo, remappedOperands](Type inputType,
SmallVectorImpl<Type> &types) {
SmallVector<Type> remappedInitTypes =
llvm::to_vector(remappedOperands[inputNo].getTypes());
types.append(remappedInitTypes);
inputNo++;
return success();
};
};

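// copies canNarrow from the remapped operands' fat pointers to the matching
// (base, offset) block-argument pairs of the successor block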
auto propagateCanNarrowToBlock = [this](Block *block,
ArrayRef<ValueRange>
remappedOperands) {
int offset = 0;
for (auto operands : remappedOperands) {
if (operands.size() == 2) {
assert(fatPtrs.contains({operands[0], operands[1]}) &&
"expected fatPtrs to contain remapped fat pointer");
fatPtrs[{block->getArgument(offset), block->getArgument(offset + 1)}]
.canNarrow = fatPtrs[{operands[0], operands[1]}].canNarrow;
}
offset += operands.size();
}
};

// convert the type signature of the true dest bb
unsigned inputNo = 0;
TypeConverter hackTypeConverterTrueDest;
unsigned inputNo = 0;
hackTypeConverterTrueDest.addConversion(
[&inputNo, remappedOperands = adaptor.getTrueDestOperands()](
Type inputType, SmallVectorImpl<Type> &types) {
SmallVector<Type> remappedInitTypes =
llvm::to_vector(remappedOperands[inputNo].getTypes());
types.append(remappedInitTypes);
inputNo++;
return success();
});
makeTypeConv(inputNo, remappedTrueOperands));
std::optional<TypeConverter::SignatureConversion> conversion =
hackTypeConverterTrueDest.convertBlockSignature(branchOp.getTrueDest());
if (!conversion)
return failure();
rewriter.applySignatureConversion(branchOp.getTrueDest(), *conversion,
&hackTypeConverterTrueDest);
auto newTrueBlock = rewriter.applySignatureConversion(
branchOp.getTrueDest(), *conversion, &hackTypeConverterTrueDest);

// propagate canNarrow to bb arg fatPtrs in true bb
propagateCanNarrowToBlock(newTrueBlock, remappedTrueOperands);

// convert the type signature of the false dest bb
inputNo = 0;
TypeConverter hackTypeConverterFalseDest;
hackTypeConverterFalseDest.addConversion(
[&inputNo, remappedOperands = adaptor.getFalseDestOperands()](
Type inputType, SmallVectorImpl<Type> &types) {
SmallVector<Type> remappedInitTypes =
llvm::to_vector(remappedOperands[inputNo].getTypes());
types.append(remappedInitTypes);
inputNo++;
return success();
});

makeTypeConv(inputNo, remappedFalseOperands));
conversion = hackTypeConverterFalseDest.convertBlockSignature(
branchOp.getFalseDest());
if (!conversion)
return failure();
rewriter.applySignatureConversion(branchOp.getFalseDest(), *conversion,
&hackTypeConverterFalseDest);
auto newFalseBlock = rewriter.applySignatureConversion(
branchOp.getFalseDest(), *conversion, &hackTypeConverterFalseDest);

// propagate canNarrow to bb arg fatPtrs in false bb
propagateCanNarrowToBlock(newFalseBlock, remappedFalseOperands);

setLegalAttr(rewriter, newBranchOp);
return success();
}
Expand Down Expand Up @@ -912,10 +995,8 @@ class ConvertSCFIfOp : public PointerCanonicalizationPattern<scf::IfOp> {
LogicalResult
matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (ifOp.getNumResults() != 1 ||
ifOp.thenYield().getOperandTypes().size() != 2)
return rewriter.notifyMatchFailure(
ifOp, "only 1 -> 2 supported for scf::IfOp rewrite");
assert(ifOp.thenYield()->hasAttr(kSCFIfOpYieldFatPtrOffsets) &&
"expected then yield to report fat ptr indices");

bool withElseRegion = ifOp.getNumRegions() > 1;

@@ -924,6 +1005,32 @@ class ConvertSCFIfOp : public PointerCanonicalizationPattern<scf::IfOp> {
assert(ifOp.thenYield().getOperandTypes() ==
ifOp.elseYield().getOperandTypes() &&
"ifOp types must match in both arms");
assert(ifOp.elseYield()->hasAttr(kSCFIfOpYieldFatPtrOffsets) &&
"expected then yield to report fat ptr indices");
if (auto thenFatPtrIndxs = ifOp.thenYield()->getDiscardableAttr(
kSCFIfOpYieldFatPtrOffsets)) {
auto elseFatPtrIndxs =
ifOp.elseYield()->getDiscardableAttr(kSCFIfOpYieldFatPtrOffsets);
assert(elseFatPtrIndxs &&
"expected else fat ptr indices as well as then fat ptr indices");
for (auto [i, j] : llvm::zip(
llvm::cast<DenseI64ArrayAttr>(thenFatPtrIndxs).asArrayRef(),
llvm::cast<DenseI64ArrayAttr>(elseFatPtrIndxs).asArrayRef())) {
assert(i == j &&
"expected thenFatPtrIndxs and elseFatPtrIndxs to agree");
assert(i < ifOp.thenYield().getNumOperands() &&
i + 1 < ifOp.thenYield().getNumOperands() &&
"expected idx to be within bounds of IfOp's results");
Value thenFatPtrBase = ifOp.thenYield().getOperand(i);
Value thenFatPtrOffset = ifOp.thenYield().getOperand(i + 1);
Value elseFatPtrBase = ifOp.elseYield().getOperand(i);
Value elseFatPtrOffset = ifOp.elseYield().getOperand(i + 1);
assert((fatPtrs[{thenFatPtrBase, thenFatPtrOffset}].canNarrow ==
fatPtrs[{elseFatPtrBase, elseFatPtrOffset}].canNarrow) &&
"expected then fat ptr canNarrow and else fat ptr canNarrow "
"to be equal");
}
}
}
#endif

@@ -939,6 +1046,16 @@ class ConvertSCFIfOp : public PointerCanonicalizationPattern<scf::IfOp> {
setRewrittenLegalAttrs(rewriter, ifOp, newIfOp);
rewriter.replaceOpWithMultiple(ifOp, {newIfOp.getResults()});

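// propagate canNarrow from the yielded fat pointers to the corresponding
// results of the new IfOp, using the offsets recorded by ConvertSCFYieldOp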
for (int64_t idx :
llvm::cast<DenseI64ArrayAttr>(newIfOp.thenYield()->getDiscardableAttr(
kSCFIfOpYieldFatPtrOffsets))
.asArrayRef()) {
Value thenFatPtrBase = newIfOp.thenYield().getOperand(idx);
Value thenFatPtrOffset = newIfOp.thenYield().getOperand(idx + 1);
fatPtrs[{newIfOp.getResult(idx), newIfOp.getResult(idx + 1)}].canNarrow =
fatPtrs[{thenFatPtrBase, thenFatPtrOffset}].canNarrow;
}

return success();
}
};
Expand All @@ -950,7 +1067,8 @@ class ConvertCFBranch : public PointerCanonicalizationPattern<cf::BranchOp> {
LogicalResult
matchAndRewrite(cf::BranchOp branchOp, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Value> trueOperands = flattenValues(adaptor.getDestOperands());
ArrayRef<ValueRange> remappedDestOperands = adaptor.getDestOperands();
SmallVector<Value> trueOperands = flattenValues(remappedDestOperands);

setRewrittenAttr(rewriter, branchOp);
auto newBranchOp = rewriter.replaceOpWithNewOp<cf::BranchOp>(
@@ -959,21 +1077,33 @@ class ConvertCFBranch : public PointerCanonicalizationPattern<cf::BranchOp> {
unsigned inputNo = 0;
TypeConverter hackTypeConverterTrueDest;
hackTypeConverterTrueDest.addConversion(
[&inputNo, remappedOperands = adaptor.getDestOperands()](
Type _inputType, SmallVectorImpl<Type> &types) {
[&inputNo, remappedDestOperands](Type _inputType,
SmallVectorImpl<Type> &types) {
SmallVector<Type> remappedInitTypes =
llvm::to_vector(remappedOperands[inputNo].getTypes());
llvm::to_vector(remappedDestOperands[inputNo].getTypes());
types.append(remappedInitTypes);
inputNo++;
return success();
});

std::optional<TypeConverter::SignatureConversion> conversion =
hackTypeConverterTrueDest.convertBlockSignature(branchOp.getDest());
if (!conversion)
return failure();
rewriter.applySignatureConversion(branchOp.getDest(), *conversion,
&hackTypeConverterTrueDest);
auto newDestBlock = rewriter.applySignatureConversion(
branchOp.getDest(), *conversion, &hackTypeConverterTrueDest);

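// propagate canNarrow to bb arg fatPtrs in the dest bb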
int offset = 0;
for (auto operands : remappedDestOperands) {
if (operands.size() == 2) {
assert(fatPtrs.contains({operands[0], operands[1]}) &&
"expected fatPtrs to contain remapped fat pointer");
fatPtrs[{newDestBlock->getArgument(offset),
newDestBlock->getArgument(offset + 1)}]
.canNarrow = fatPtrs[{operands[0], operands[1]}].canNarrow;
}
offset += operands.size();
}

setLegalAttr(rewriter, newBranchOp);
return success();
}
@@ -1181,8 +1311,10 @@ void TritonAMDGPUCanonicalizePointersPass::runOnOperation() {
return signalPassFailure();

module.walk<WalkOrder::PreOrder>([](Operation *op) {
op->removeDiscardableAttr(kRewrittenAttr);
op->removeDiscardableAttr(kLegalAttr);
for (auto attr : op->getDiscardableAttrs()) {
if (attr.getName().strref().starts_with(kPtrCanonPrefix))
op->removeDiscardableAttr(attr.getName());
}
});
}

Expand Down

0 comments on commit 1b6bd72

Please sign in to comment.