Skip to content

Commit

Permalink
handle arith.select
Browse files Browse the repository at this point in the history
  • Loading branch information
makslevental committed Dec 17, 2024
1 parent 3e1dd91 commit 5091eef
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 61 deletions.
119 changes: 60 additions & 59 deletions test/TritonGPU/amd/amd-canonicalize-pointers.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -254,66 +254,67 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
}
}

// -----

//// -----
//#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>
//module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
// // CHECK-LABEL: tt.func @select
// tt.func @select(%arg0 : !tt.ptr<f32>, %i1 : i1) -> tensor<1024xf32, #blocked>{
// %c1024_i32 = arith.constant 1024 : i32
// %c0 = arith.constant 0: index
// %c128 = arith.constant 128: index
// %c1 = arith.constant 1 : index
// %0 = tt.get_program_id x : i32
// %1 = arith.muli %0, %c1024_i32 : i32
// %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
// // CHECK: %[[scalarOffset:.*]] = arith.addi {{.*}}, {{.*}} : i32
// // CHECK: %[[variableOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor<1024xi32, #blocked>
// // CHECK: %[[baseOffset:.*]] = tt.splat %{{.*}} : i64
// // CHECK: %[[scalarPtr:.*]] = tt.addptr %arg0, %[[scalarOffset]]
// // CHECK: %[[extVariableOffset:.*]] = arith.extsi %[[variableOffset]]
// %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked>
// %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked>
// %5 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked>
// %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked>
// // CHECK: %[[offset2:.*]] = arith.addi %[[extVariableOffset]], %[[baseOffset]]
// // CHECK: %[[scalarPtr1:.*]] = arith.select %arg1, %arg0, %[[scalarPtr]]
// // CHECK: %[[offset0:.*]] = arith.select %arg1, {{.*}}, %[[offset2]]
// // CHECK: %[[offset1:.*]] = arith.trunci %[[offset0]]
// // CHECK: %[[ptr:.*]] = tt.splat %[[scalarPtr1]]
// // CHECK: tt.addptr %[[ptr]], %[[offset1]]
// %7 = arith.select %i1, %5 , %6 : tensor<1024x!tt.ptr<f32>, #blocked>
// %out = tt.load %7: tensor<1024x!tt.ptr<f32>, #blocked>
// tt.return %out : tensor<1024xf32, #blocked>
// }
//}
//
//// -----
//#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
//module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1100", "ttg.threads-per-warp" = 32 : i32} {
// // CHECK-LABEL: tt.func @where_kernel
// tt.func @where_kernel(%arg1: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<i64> {tt.divisibility = 16 : i32}){
// %c0_i8 = arith.constant 0 : i8
// %c1024_i32 = arith.constant 1024 : i32
// %0 = tt.get_program_id x : i32
// %1 = arith.muli %0, %c1024_i32 : i32
// %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
// %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked>
// %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked>
// %9 = arith.cmpi ne, %c0_i8, %c0_i8 : i8
// %10 = arith.select %9, %arg1, %arg2 : !tt.ptr<i64>
// // CHECK: %[[selectPtr:.*]] = arith.select {{.*}} : !tt.ptr<i64>
// %11 = tt.splat %10: !tt.ptr<i64> -> tensor<1024x!tt.ptr<i64>, #blocked>
// %13 = tt.addptr %11, %4 : tensor<1024x!tt.ptr<i64>, #blocked>, tensor<1024xi32, #blocked>
// // CHECK: %[[selectPtr0:.*]] = tt.addptr %[[selectPtr]]
// // CHECK: %[[tensorPtr:.*]] = tt.splat %[[selectPtr0]]
// // CHECK: tt.addptr %[[tensorPtr]]
// %14 = tt.load %13 : tensor<1024x!tt.ptr<i64>, #blocked>
// tt.return
// }
//}
//
//// -----
#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>
module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
// CHECK-LABEL: tt.func @select
tt.func @select(%arg0 : !tt.ptr<f32>, %i1 : i1) -> tensor<1024xf32, #blocked>{
%c1024_i32 = arith.constant 1024 : i32
%c0 = arith.constant 0: index
%c128 = arith.constant 128: index
%c1 = arith.constant 1 : index
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c1024_i32 : i32
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
// CHECK: %[[scalarOffset:.*]] = arith.addi {{.*}}, {{.*}} : i32
// CHECK: %[[variableOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor<1024xi32, #blocked>
// CHECK: %[[baseOffset:.*]] = tt.splat %{{.*}} : i64
// CHECK: %[[scalarPtr:.*]] = tt.addptr %arg0, %[[scalarOffset]]
// CHECK: %[[extVariableOffset:.*]] = arith.extsi %[[variableOffset]]
%3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked>
%4 = arith.addi %3, %2 : tensor<1024xi32, #blocked>
%5 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked>
%6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked>
// CHECK: %[[offset2:.*]] = arith.addi %[[extVariableOffset]], %[[baseOffset]]
// CHECK: %[[scalarPtr1:.*]] = arith.select %arg1, %arg0, %[[scalarPtr]]
// CHECK: %[[offset0:.*]] = arith.select %arg1, {{.*}}, %[[offset2]]
// CHECK: %[[offset1:.*]] = arith.trunci %[[offset0]]
// CHECK: %[[ptr:.*]] = tt.splat %[[scalarPtr1]]
// CHECK: tt.addptr %[[ptr]], %[[offset1]]
%7 = arith.select %i1, %5 , %6 : tensor<1024x!tt.ptr<f32>, #blocked>
%out = tt.load %7: tensor<1024x!tt.ptr<f32>, #blocked>
tt.return %out : tensor<1024xf32, #blocked>
}
}

// -----

#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1100", "ttg.threads-per-warp" = 32 : i32} {
// CHECK-LABEL: tt.func @where_kernel
tt.func @where_kernel(%arg1: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<i64> {tt.divisibility = 16 : i32}){
%c0_i8 = arith.constant 0 : i8
%c1024_i32 = arith.constant 1024 : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c1024_i32 : i32
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
%3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked>
%4 = arith.addi %3, %2 : tensor<1024xi32, #blocked>
%9 = arith.cmpi ne, %c0_i8, %c0_i8 : i8
%10 = arith.select %9, %arg1, %arg2 : !tt.ptr<i64>
// CHECK: %[[selectPtr:.*]] = arith.select {{.*}} : !tt.ptr<i64>
%11 = tt.splat %10: !tt.ptr<i64> -> tensor<1024x!tt.ptr<i64>, #blocked>
%13 = tt.addptr %11, %4 : tensor<1024x!tt.ptr<i64>, #blocked>, tensor<1024xi32, #blocked>
// CHECK: %[[selectPtr0:.*]] = tt.addptr %[[selectPtr]]
// CHECK: %[[tensorPtr:.*]] = tt.splat %[[selectPtr0]]
// CHECK: tt.addptr %[[tensorPtr]]
%14 = tt.load %13 : tensor<1024x!tt.ptr<i64>, #blocked>
tt.return
}
}

// -----

#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>
module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1675,12 +1675,67 @@ class ConvertCFBranch : public PointerCanonPattern<cf::BranchOp> {
}
};

class ConvertArithSelectOp : public PointerCanonPattern<arith::SelectOp> {
public:
using PointerCanonPattern::PointerCanonPattern;
LogicalResult
matchAndRewrite(arith::SelectOp selectOp, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ArrayRef<ValueRange> remappedOperands = adaptor.getOperands();
assert(remappedOperands.size() == 3 && remappedOperands[1].size() == 2 &&
remappedOperands[2].size() == 2 &&
"expected adaptor to have 3 remapped values, 1 for cond, 2 for each "
"arm");
// If both have been traversed, then we can rewrite select of pointers as a
// select of base and offset
ValueRange fatPtrT = remappedOperands[1];
ValueRange fatPtrF = remappedOperands[2];
// Simple case of a scalar select: update the base pointer
if (!isa<RankedTensorType>(selectOp.getType())) {
auto newSelectOp = rewriter.create<arith::SelectOp>(
selectOp.getLoc(), selectOp.getType(),
// TODO(max): why fatPtrTrue here?
selectOp.getCondition(), fatPtrT[0], selectOp.getFalseValue());
rewriter.modifyOpInPlace(selectOp, [&] {
selectOp->setDiscardableAttr("rewritten", rewriter.getUnitAttr());
});
rewriter.modifyOpInPlace(newSelectOp, [&] {
newSelectOp->setDiscardableAttr("legal", rewriter.getUnitAttr());
});
rewriter.replaceOpWithMultiple(selectOp, {{newSelectOp, fatPtrT[1]}});
return success();
}

// Rewrite `select` for base and offset
auto newBase = rewriter.create<arith::SelectOp>(
selectOp.getLoc(), selectOp.getCondition(), fatPtrT[0], fatPtrF[0]);
auto newOffset = rewriter.create<arith::SelectOp>(
selectOp.getLoc(), selectOp.getCondition(), fatPtrT[1], fatPtrF[1]);

assert((fatPtrs[{fatPtrT[0], fatPtrT[1]}].canNarrow ==
fatPtrs[{fatPtrF[0], fatPtrF[1]}].canNarrow));

rewriter.modifyOpInPlace(selectOp, [&] {
selectOp->setDiscardableAttr("rewritten", rewriter.getUnitAttr());
});
rewriter.modifyOpInPlace(newBase, [&] {
newBase->setDiscardableAttr("legal", rewriter.getUnitAttr());
});
rewriter.modifyOpInPlace(newOffset, [&] {
newOffset->setDiscardableAttr("legal", rewriter.getUnitAttr());
});

rewriter.replaceOpWithMultiple(selectOp, {{newBase, newOffset}});

return success();
}
};

void TritonAMDGPUCanonicalizePointersPass::runOnOperation() {
ModuleOp module = getOperation();
mlir::MLIRContext *context = &getContext();
ConversionTarget target(*context);
RewritePatternSet patterns(context);
target.addLegalDialect<arith::ArithDialect>();
auto isLegal = [](Operation *op) {
if (op->hasAttr("rewritten") || op->hasAttr("legal"))
return true;
Expand Down Expand Up @@ -1711,13 +1766,20 @@ void TritonAMDGPUCanonicalizePointersPass::runOnOperation() {
});
target.addDynamicallyLegalDialect<cf::ControlFlowDialect>(
[&isLegal](Operation *op) { return isLegal(op); });
target.addDynamicallyLegalDialect<arith::ArithDialect>(
[&isLegal](Operation *op) {
if (llvm::isa<arith::SelectOp>(op))
return isLegal(op);
return true;
});

FatPointers fatPrs;

patterns.add<ConvertFuncOp, ConvertSplatOp, ConvertAddPtrOp, ConvertLoadOp,
ConvertSCFForOp, ConvertSCFYieldOp, ConvertSCFIfOp,
ConvertSCFConditionOp, ConvertSCFWhileOp, ConvertCFCondBranch,
ConvertCFBranch>(patterns.getContext(), fatPrs);
ConvertCFBranch, ConvertArithSelectOp>(patterns.getContext(),
fatPrs);
ConversionConfig config;
config.buildMaterializations = false;
if (failed(
Expand Down

0 comments on commit 5091eef

Please sign in to comment.