Skip to content

Commit

Permalink
handle scf.while (TODO: handle cond arg)
Browse files Browse the repository at this point in the history
  • Loading branch information
makslevental committed Dec 16, 2024
1 parent ce211fd commit d1f4d6a
Showing 1 changed file with 88 additions and 17 deletions.
105 changes: 88 additions & 17 deletions third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1329,11 +1329,6 @@ class ConvertSCFYieldOp : public PointerCanonPattern<scf::YieldOp> {
matchAndRewrite(scf::YieldOp yieldOp, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Value> newYieldedValues = flattenValues(adaptor.getOperands());

// Value tensorPtr =
// createTensorPointer(rewriter, fatPtr.basePtr, fatPtr.offset, curLoc,
// fatPtr.canNarrow, fatPtr.attributes);

rewriter.modifyOpInPlace(yieldOp, [&]() {
yieldOp.getResultsMutable().clear();
yieldOp.getResultsMutable().append(newYieldedValues);
Expand Down Expand Up @@ -1400,18 +1395,10 @@ class ConvertSCFForOp : public PointerCanonPattern<scf::ForOp> {
LogicalResult
matchAndRewrite(scf::ForOp forOp, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<std::pair<Value, Value>> fatPtrInits;
SmallVector<size_t> valRangeLens;
ArrayRef<ValueRange> remappedInits = adaptor.getInitArgs();
for (ValueRange remappedInit : remappedInits) {
if (remappedInit.size() == 2) {
Value fatPtrBase = remappedInit[0];
Value fatPtrOffset = remappedInit[1];
fatPtrInits.emplace_back(fatPtrBase, fatPtrOffset);
}
for (ValueRange remappedInit : remappedInits)
valRangeLens.push_back(remappedInit.size());
}

TypeConverter hackTypeConverter;
unsigned inputNo = 0;
hackTypeConverter.addConversion(
Expand All @@ -1436,7 +1423,6 @@ class ConvertSCFForOp : public PointerCanonPattern<scf::ForOp> {
forOp.getLoc(), getSingleValue(adaptor.getLowerBound()),
getSingleValue(adaptor.getUpperBound()),
getSingleValue(adaptor.getStep()), initArgs);
// replaceWithAdditionalYields

newForOp->setAttrs(forOp->getAttrs());
rewriter.eraseBlock(newForOp.getBody(0));
Expand Down Expand Up @@ -1504,6 +1490,88 @@ class ConvertSCFIfOp : public PointerCanonPattern<scf::IfOp> {
}
};

class ConvertSCFWhileOp : public PointerCanonPattern<scf::WhileOp> {
public:
using PointerCanonPattern::PointerCanonPattern;
LogicalResult
matchAndRewrite(scf::WhileOp whileOp, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<size_t> valRangeLens;
ArrayRef<ValueRange> remappedInits = adaptor.getInits();
for (ValueRange remappedInit : remappedInits)
valRangeLens.push_back(remappedInit.size());
TypeConverter hackTypeConverter;
unsigned inputNo = 0;
hackTypeConverter.addConversion(
[&inputNo, &remappedInits = std::as_const(remappedInits)](
Type inputType, SmallVectorImpl<Type> &types) {
SmallVector<Type> remappedInitTypes =
llvm::to_vector(remappedInits[inputNo].getTypes());
types.append(remappedInitTypes);
inputNo++;
return success();
});
if (failed(rewriter.convertRegionTypes(&whileOp.getBefore(),
hackTypeConverter)))
return failure();
if (failed(rewriter.convertRegionTypes(&whileOp.getAfter(),
hackTypeConverter)))
return failure();

SmallVector<Value> initArgs = flattenValues(remappedInits);
SmallVector<Type> resultTypes =
llvm::map_to_vector(initArgs, [](Value v) { return v.getType(); });
auto newWhileOp =
rewriter.create<scf::WhileOp>(whileOp.getLoc(), resultTypes, initArgs);

newWhileOp->setAttrs(whileOp->getAttrs());
rewriter.inlineRegionBefore(whileOp.getBefore(), newWhileOp.getBefore(),
newWhileOp.getBefore().end());
rewriter.inlineRegionBefore(whileOp.getAfter(), newWhileOp.getAfter(),
newWhileOp.getAfter().end());

SmallVector<ValueRange> packedRets;
for (unsigned i = 0, offset = 0; i < valRangeLens.size(); i++) {
size_t len = valRangeLens[i];
assert(offset < newWhileOp->getNumResults() &&
"expected offset to be within bounds of results");
ValueRange mappedValue = newWhileOp->getResults().slice(offset, len);
packedRets.push_back(mappedValue);
offset += len;
}

rewriter.modifyOpInPlace(whileOp, [&] {
whileOp->setDiscardableAttr("rewritten", rewriter.getUnitAttr());
});
rewriter.modifyOpInPlace(newWhileOp, [&] {
newWhileOp->setDiscardableAttr("legal", rewriter.getUnitAttr());
});
rewriter.replaceOpWithMultiple(whileOp, packedRets);

return success();
}
};

class ConvertSCFConditionOp : public PointerCanonPattern<scf::ConditionOp> {
public:
using PointerCanonPattern::PointerCanonPattern;
LogicalResult
matchAndRewrite(scf::ConditionOp condOp, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Value> newArgs = flattenValues(adaptor.getArgs());
rewriter.modifyOpInPlace(condOp, [&]() {
condOp.getArgsMutable().clear();
condOp.getArgsMutable().append(newArgs);
});

rewriter.modifyOpInPlace(condOp, [&] {
condOp->setDiscardableAttr("legal", rewriter.getUnitAttr());
});

return success();
}
};

void TritonAMDGPUCanonicalizePointersPass::runOnOperation() {
ModuleOp module = getOperation();
mlir::MLIRContext *context = &getContext();
Expand Down Expand Up @@ -1531,14 +1599,17 @@ void TritonAMDGPUCanonicalizePointersPass::runOnOperation() {
if (auto ifOp = llvm::dyn_cast<scf::IfOp>(op))
return !(ifOp->hasAttr("then_rewritten") and
ifOp->hasAttr("else_rewritten"));
if (llvm::isa<scf::ConditionOp>(op) && !op->hasAttr("legal"))
return false;
return isLegal(op);
});

FatPointers fatPrs;

patterns.add<ConvertFuncOp, ConvertSplatOp, ConvertAddPtrOp, ConvertLoadOp,
ConvertSCFForOp, ConvertSCFYieldOp, ConvertSCFIfOp>(
patterns.getContext(), fatPrs);
ConvertSCFForOp, ConvertSCFYieldOp, ConvertSCFIfOp,
ConvertSCFConditionOp, ConvertSCFWhileOp>(patterns.getContext(),
fatPrs);
ConversionConfig config;
config.buildMaterializations = false;
if (failed(
Expand Down

0 comments on commit d1f4d6a

Please sign in to comment.