Skip to content

Commit

Permalink
handle scf.while (TODO: handle cond arg)
Browse files Browse the repository at this point in the history
  • Loading branch information
makslevental committed Dec 16, 2024
1 parent ce211fd commit e4b27bd
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 52 deletions.
70 changes: 35 additions & 35 deletions test/TritonGPU/amd/amd-canonicalize-pointers.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -329,41 +329,41 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32}
}

// -----
//
//#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>
//module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
// // CHECK-LABEL: tt.func @whileOp
// tt.func @whileOp(%arg0: !tt.ptr<f32>, %init : tensor<1024xf32, #blocked>, %cond : i1)-> tensor<1024xf32, #blocked>{
// %c1024_i32 = arith.constant 1024 : i32
// %c0 = arith.constant 0: index
// %c128 = arith.constant 128: index
// %c1 = arith.constant 1 : index
// %0 = tt.get_program_id x : i32
// %1 = arith.muli %0, %c1024_i32 : i32
// %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
// // CHECK: %[[base_offset:.*]] = tt.splat %{{.*}} : i64
// %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked>
// %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked>
// %5 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked>
// // CHECK: %[[whileOut:.*]]:3 = scf.while ({{.*}}, %[[loopPtr:.*]] = %arg0, %[[loopOffset:.*]] = %[[base_offset]])
// %6 = scf.while (%arg1 = %5, %arg2 = %cond) : (tensor<1024x!tt.ptr<f32>, #blocked>, i1) -> (tensor<1024x!tt.ptr<f32>, #blocked>) {
// // CHECK: scf.condition({{.*}}) %{{.*}}, %[[loopPtr]], %[[loopOffset]]
// scf.condition(%arg2) %arg1 : tensor<1024x!tt.ptr<f32>, #blocked>
// } do {
// // CHECK: ^bb{{.*}}(%{{.*}}, %[[blockPtr:.*]]: !tt.ptr<f32>, %[[blockOffset:.*]]: tensor<1024xi64, #blocked>):
// ^bb0(%arg1: tensor<1024x!tt.ptr<f32>, #blocked>):
// // CHECK: scf.yield {{.*}}, %[[blockPtr]], %[[blockOffset]]
// scf.yield %arg1, %cond : tensor<1024x!tt.ptr<f32>, #blocked>, i1
// }
// // CHECK: %[[trunc_offset:.*]] = arith.trunci %[[whileOut]]#2
// // CHECK: %[[base_ptr:.*]] = tt.splat %[[whileOut]]#1
// // CHECK: %[[newPtr:.*]] = tt.addptr %[[base_ptr]], %[[trunc_offset]]
// // CHECK: tt.load %[[newPtr]]
// %11 = tt.load %6 : tensor<1024x!tt.ptr<f32>, #blocked>
// tt.return %11 : tensor<1024xf32, #blocked>
// }
//}
//

#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>
module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
// Verifies that pointer canonicalization rewrites scf.while so the
// tensor-of-pointers iter arg is carried as a (scalar base ptr, i64 offset
// tensor) pair, and that the load after the loop is rebuilt from that pair
// via trunci + splat + addptr.
// NOTE(review): the CHECK patterns expect three while results / condition
// operands (%[[whileOut]]:3) for a single pointer init — presumably the
// original value plus base and offset; confirm against the pass output.
// CHECK-LABEL: tt.func @whileOp
tt.func @whileOp(%arg0: !tt.ptr<f32>, %init : tensor<1024xf32, #blocked>, %cond : i1)-> tensor<1024xf32, #blocked>{
%c1024_i32 = arith.constant 1024 : i32
// NOTE(review): %c0, %c128, %c1 appear unused in this function — likely
// leftovers from a forOp-based test; consider removing.
%c0 = arith.constant 0: index
%c128 = arith.constant 128: index
%c1 = arith.constant 1 : index
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c1024_i32 : i32
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
// CHECK: %[[base_offset:.*]] = tt.splat %{{.*}} : i64
%3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked>
%4 = arith.addi %3, %2 : tensor<1024xi32, #blocked>
%5 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked>
// The single pointer init is expected to be expanded into extra loop-carried
// base/offset values on the rewritten scf.while.
// CHECK: %[[whileOut:.*]]:3 = scf.while ({{.*}}, %[[loopPtr:.*]] = %arg0, %[[loopOffset:.*]] = %[[base_offset]])
%6 = scf.while (%arg1 = %5) : (tensor<1024x!tt.ptr<f32>, #blocked>) -> (tensor<1024x!tt.ptr<f32>, #blocked>) {
// CHECK: scf.condition({{.*}}) %{{.*}}, %[[loopPtr]], %[[loopOffset]]
scf.condition(%cond) %arg1 : tensor<1024x!tt.ptr<f32>, #blocked>
} do {
// CHECK: ^bb{{.*}}(%{{.*}}, %[[blockPtr:.*]]: !tt.ptr<f32>, %[[blockOffset:.*]]: tensor<1024xi64, #blocked>):
^bb0(%arg1: tensor<1024x!tt.ptr<f32>, #blocked>):
// CHECK: scf.yield {{.*}}, %[[blockPtr]], %[[blockOffset]]
scf.yield %arg1 : tensor<1024x!tt.ptr<f32>, #blocked>
}
// After the loop, the fat pointer is reassembled for the load.
// CHECK: %[[trunc_offset:.*]] = arith.trunci %[[whileOut]]#2
// CHECK: %[[base_ptr:.*]] = tt.splat %[[whileOut]]#1
// CHECK: %[[newPtr:.*]] = tt.addptr %[[base_ptr]], %[[trunc_offset]]
// CHECK: tt.load %[[newPtr]]
%11 = tt.load %6 : tensor<1024x!tt.ptr<f32>, #blocked>
tt.return %11 : tensor<1024xf32, #blocked>
}
}

// -----

//#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>
Expand Down
105 changes: 88 additions & 17 deletions third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1329,11 +1329,6 @@ class ConvertSCFYieldOp : public PointerCanonPattern<scf::YieldOp> {
matchAndRewrite(scf::YieldOp yieldOp, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Value> newYieldedValues = flattenValues(adaptor.getOperands());

// Value tensorPtr =
// createTensorPointer(rewriter, fatPtr.basePtr, fatPtr.offset, curLoc,
// fatPtr.canNarrow, fatPtr.attributes);

rewriter.modifyOpInPlace(yieldOp, [&]() {
yieldOp.getResultsMutable().clear();
yieldOp.getResultsMutable().append(newYieldedValues);
Expand Down Expand Up @@ -1400,18 +1395,10 @@ class ConvertSCFForOp : public PointerCanonPattern<scf::ForOp> {
LogicalResult
matchAndRewrite(scf::ForOp forOp, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<std::pair<Value, Value>> fatPtrInits;
SmallVector<size_t> valRangeLens;
ArrayRef<ValueRange> remappedInits = adaptor.getInitArgs();
for (ValueRange remappedInit : remappedInits) {
if (remappedInit.size() == 2) {
Value fatPtrBase = remappedInit[0];
Value fatPtrOffset = remappedInit[1];
fatPtrInits.emplace_back(fatPtrBase, fatPtrOffset);
}
for (ValueRange remappedInit : remappedInits)
valRangeLens.push_back(remappedInit.size());
}

TypeConverter hackTypeConverter;
unsigned inputNo = 0;
hackTypeConverter.addConversion(
Expand All @@ -1436,7 +1423,6 @@ class ConvertSCFForOp : public PointerCanonPattern<scf::ForOp> {
forOp.getLoc(), getSingleValue(adaptor.getLowerBound()),
getSingleValue(adaptor.getUpperBound()),
getSingleValue(adaptor.getStep()), initArgs);
// replaceWithAdditionalYields

newForOp->setAttrs(forOp->getAttrs());
rewriter.eraseBlock(newForOp.getBody(0));
Expand Down Expand Up @@ -1504,6 +1490,88 @@ class ConvertSCFIfOp : public PointerCanonPattern<scf::IfOp> {
}
};

// 1:N-converts scf.while so that loop-carried fat pointers (expanded by
// earlier patterns into base-ptr + offset pairs) are threaded through both
// regions as separate values. TODO: handle scf.condition forwarded args that
// do not mirror the inits one-to-one.
class ConvertSCFWhileOp : public PointerCanonPattern<scf::WhileOp> {
public:
  using PointerCanonPattern::PointerCanonPattern;
  LogicalResult
  matchAndRewrite(scf::WhileOp whileOp, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Record how many converted values each original init expanded to, so the
    // new op's flat results can be regrouped per original result below.
    SmallVector<size_t> valRangeLens;
    ArrayRef<ValueRange> remappedInits = adaptor.getInits();
    for (ValueRange remappedInit : remappedInits)
      valRangeLens.push_back(remappedInit.size());

    // Convert a region's block-argument types: the i-th entry-block argument
    // expands to the types of remappedInits[i]. Each region gets its own
    // converter with a fresh counter. NOTE: the "before" and "after" regions
    // must NOT share one running counter — the first conversion leaves the
    // counter at the number of inits, so reusing it for the second region
    // would index remappedInits out of bounds. This assumes the after-region
    // arguments mirror the inits one-to-one (holds while scf.condition args
    // are not yet handled; see the TODO above).
    auto convertRegion = [&](Region &region) -> LogicalResult {
      unsigned inputNo = 0;
      TypeConverter hackTypeConverter;
      hackTypeConverter.addConversion(
          [&inputNo, &remappedInits = std::as_const(remappedInits)](
              Type inputType, SmallVectorImpl<Type> &types) {
            SmallVector<Type> remappedInitTypes =
                llvm::to_vector(remappedInits[inputNo].getTypes());
            types.append(remappedInitTypes);
            inputNo++;
            return success();
          });
      return rewriter.convertRegionTypes(&region, hackTypeConverter);
    };
    if (failed(convertRegion(whileOp.getBefore())) ||
        failed(convertRegion(whileOp.getAfter())))
      return failure();

    // Build the replacement while op over the flattened inits; the result
    // types match the flattened init types.
    SmallVector<Value> initArgs = flattenValues(remappedInits);
    SmallVector<Type> resultTypes =
        llvm::map_to_vector(initArgs, [](Value v) { return v.getType(); });
    auto newWhileOp =
        rewriter.create<scf::WhileOp>(whileOp.getLoc(), resultTypes, initArgs);

    newWhileOp->setAttrs(whileOp->getAttrs());
    rewriter.inlineRegionBefore(whileOp.getBefore(), newWhileOp.getBefore(),
                                newWhileOp.getBefore().end());
    rewriter.inlineRegionBefore(whileOp.getAfter(), newWhileOp.getAfter(),
                                newWhileOp.getAfter().end());

    // Regroup the flat results into per-original-result ranges for the 1:N
    // replacement.
    SmallVector<ValueRange> packedRets;
    for (unsigned i = 0, offset = 0; i < valRangeLens.size(); i++) {
      size_t len = valRangeLens[i];
      assert(offset < newWhileOp->getNumResults() &&
             "expected offset to be within bounds of results");
      ValueRange mappedValue = newWhileOp->getResults().slice(offset, len);
      packedRets.push_back(mappedValue);
      offset += len;
    }

    // Tag the ops so the dynamic legality callback treats the old op as
    // rewritten and the new op as legal.
    rewriter.modifyOpInPlace(whileOp, [&] {
      whileOp->setDiscardableAttr("rewritten", rewriter.getUnitAttr());
    });
    rewriter.modifyOpInPlace(newWhileOp, [&] {
      newWhileOp->setDiscardableAttr("legal", rewriter.getUnitAttr());
    });
    rewriter.replaceOpWithMultiple(whileOp, packedRets);

    return success();
  }
};

// Flattens the 1:N-remapped forwarded operands of scf.condition in place and
// tags the op as legal for the conversion target.
class ConvertSCFConditionOp : public PointerCanonPattern<scf::ConditionOp> {
public:
  using PointerCanonPattern::PointerCanonPattern;
  LogicalResult
  matchAndRewrite(scf::ConditionOp condOp, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Splice the flattened remapped values in as the forwarded args and mark
    // the op "legal" so the dynamic legality callback accepts it; both edits
    // are in-place modifications of the same op, done under one notification.
    SmallVector<Value> flattenedArgs = flattenValues(adaptor.getArgs());
    rewriter.modifyOpInPlace(condOp, [&] {
      condOp.getArgsMutable().clear();
      condOp.getArgsMutable().append(flattenedArgs);
      condOp->setDiscardableAttr("legal", rewriter.getUnitAttr());
    });
    return success();
  }
};

void TritonAMDGPUCanonicalizePointersPass::runOnOperation() {
ModuleOp module = getOperation();
mlir::MLIRContext *context = &getContext();
Expand Down Expand Up @@ -1531,14 +1599,17 @@ void TritonAMDGPUCanonicalizePointersPass::runOnOperation() {
if (auto ifOp = llvm::dyn_cast<scf::IfOp>(op))
return !(ifOp->hasAttr("then_rewritten") and
ifOp->hasAttr("else_rewritten"));
if (llvm::isa<scf::ConditionOp>(op) && !op->hasAttr("legal"))
return false;
return isLegal(op);
});

FatPointers fatPrs;

patterns.add<ConvertFuncOp, ConvertSplatOp, ConvertAddPtrOp, ConvertLoadOp,
ConvertSCFForOp, ConvertSCFYieldOp, ConvertSCFIfOp>(
patterns.getContext(), fatPrs);
ConvertSCFForOp, ConvertSCFYieldOp, ConvertSCFIfOp,
ConvertSCFConditionOp, ConvertSCFWhileOp>(patterns.getContext(),
fatPrs);
ConversionConfig config;
config.buildMaterializations = false;
if (failed(
Expand Down

0 comments on commit e4b27bd

Please sign in to comment.