From bc6ba644c45759f8e7ad8b410a2679e08f1f9c93 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Wed, 4 Dec 2024 02:15:41 -0500 Subject: [PATCH 01/17] [AMD] re-enable canonicalize pointers --- third_party/amd/backend/compiler.py | 2 +- .../CanonicalizePointers.cpp | 80 ++++++++++++++----- 2 files changed, 63 insertions(+), 19 deletions(-) diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index 81b07f2e7d86..fff8421af5a3 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -250,8 +250,8 @@ def make_ttgir(mod, metadata, options): if amd.has_matrix_core_feature(options.arch): amd.passes.ttgpuir.add_reorder_instructions(pm) + amd.passes.ttgpuir.add_canonicalize_pointers(pm) if use_buffer_ops: - amd.passes.ttgpuir.add_canonicalize_pointers(pm) passes.common.add_canonicalizer(pm) amd.passes.ttgpuir.add_convert_to_buffer_ops(pm) passes.common.add_canonicalizer(pm) diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp index fa7d54e0fbee..f798edca1676 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp @@ -611,28 +611,71 @@ LogicalResult PointerCanonicalizer::rewriteForOp(scf::ForOp forOp, Location curLoc, OpOperand *curOperand, Value &nextPtr) { - size_t operandNum = curOperand->getOperandNumber(); FatPtr fatPtr = pointers[curOperand->get()]; Value offset = fatPtr.offset; Value basePtr = fatPtr.basePtr; - // Replace the forOp with two additional argument (i.e., the curOperand's - // scalar pointer and the offset) - Value tensorPtr = createTensorPointer(fatPtr, curLoc); - auto newForOp = - replaceForOpWithNewSignature(rewriter, forOp, {basePtr, offset}); + size_t operandNum = curOperand->getOperandNumber(); + auto numLeadingInitsIncludingCurOp = + operandNum + 1 - forOp.getNumControlOperands(); + + SmallVector newInitArgs( + forOp.getInitArgs().take_front(numLeadingInitsIncludingCurOp)); + newInitArgs.append({basePtr, offset}); + auto trailingInits = + forOp.getInitArgs().drop_front(numLeadingInitsIncludingCurOp); + newInitArgs.append(trailingInits.begin(), trailingInits.end()); + + scf::ForOp newForOp; + { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(forOp); + newForOp = rewriter.create( + forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), + forOp.getStep(), newInitArgs); + newForOp->setAttrs(forOp->getAttrs()); + newForOp.getBody()->erase(); + newForOp.getRegion().getBlocks().splice( + newForOp.getRegion().getBlocks().begin(), + forOp.getRegion().getBlocks()); + + newForOp.getBody()->insertArgument(numLeadingInitsIncludingCurOp + 1, + offset.getType(), offset.getLoc()); + newForOp.getBody()->insertArgument(numLeadingInitsIncludingCurOp + 1, + basePtr.getType(), basePtr.getLoc()); + + for (auto [src, tgt] : llvm::zip( + forOp.getResults().take_front(numLeadingInitsIncludingCurOp), + newForOp.getResults().take_front(numLeadingInitsIncludingCurOp))) + rewriter.replaceAllUsesWith(src, tgt); + for (auto [src, tgt] : + llvm::zip(forOp.getResults().drop_front(numLeadingInitsIncludingCurOp), + newForOp.getResults().drop_front( + numLeadingInitsIncludingCurOp + 2))) + rewriter.replaceAllUsesWith(src, tgt); + + newForOp.getBody()->getTerminator()->insertOperands( + numLeadingInitsIncludingCurOp, + newForOp.getRegionIterArg(numLeadingInitsIncludingCurOp)); + 
newForOp.getBody()->getTerminator()->insertOperands( + numLeadingInitsIncludingCurOp + 1, + newForOp.getRegionIterArg(numLeadingInitsIncludingCurOp + 1)); + } + rewriteOpMap[forOp] = newForOp; - newForOp->setOperand(operandNum, tensorPtr); + { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(newForOp); + Value tensorPtr = createTensorPointer(fatPtr, curLoc); + newForOp->setOperand(operandNum, tensorPtr); + } + OpOperand *forOperand = &newForOp->getOpOperand(operandNum); // This is making sure we propagate the visit from the forOp result - nextPtr = newForOp.getTiedLoopResult(forOperand); - // This is making sure we visit the uses within the forOp region Value arg = newForOp.getTiedLoopRegionIterArg(forOperand); - size_t numIterArgs = newForOp.getNumRegionIterArgs(); - pointers[arg] = fatPtr.copy(newForOp.getRegionIterArg(numIterArgs - 2), - newForOp.getRegionIterArg(numIterArgs - 1)); + pointers[arg] = fatPtr.copy(basePtr, offset); // Collect attributes before continuing the visit collectFatPointerAttributes(newForOp, arg); @@ -642,9 +685,9 @@ LogicalResult PointerCanonicalizer::rewriteForOp(scf::ForOp forOp, // This is setting the fat pointer for the users of the loop // and then propagate the result - size_t numResults = newForOp->getNumResults(); - pointers[nextPtr] = fatPtr.copy(newForOp->getResult(numResults - 2), - newForOp.getResult(numResults - 1)); + nextPtr = newForOp.getTiedLoopResult(forOperand); + pointers[nextPtr] = fatPtr.copy(basePtr, offset); + opToDelete.insert(forOp); return success(); } @@ -659,12 +702,11 @@ LogicalResult PointerCanonicalizer::rewriteYieldOp(scf::YieldOp yieldOp, // IfOp size_t operandNum = curOperand->getOperandNumber(); FatPtr fatPtr = pointers[curOperand->get()]; - yieldOp.getResultsMutable().append(fatPtr.basePtr); - yieldOp.getResultsMutable().append(fatPtr.offset); - if (auto forOp = dyn_cast(yieldOp->getParentOp())) { yieldOp->setOperand(operandNum, forOp.getRegionIterArg(operandNum)); } else if (auto ifOp = dyn_cast(yieldOp->getParentOp())) { + yieldOp.getResultsMutable().append(fatPtr.basePtr); + yieldOp.getResultsMutable().append(fatPtr.offset); // Case 1: the yieldOp is contained within an IfOp. One of the // two branches is responsible to rewrite the operation. The other // branch only update the yieldOp with the right parameters @@ -683,6 +725,8 @@ LogicalResult PointerCanonicalizer::rewriteYieldOp(scf::YieldOp yieldOp, } else if (auto whileOp = resolveOp(yieldOp->getParentOp(), rewriteOpMap)) { + yieldOp.getResultsMutable().append(fatPtr.basePtr); + yieldOp.getResultsMutable().append(fatPtr.offset); // Case 2: the yieldOp is contained within the AfterRegion of a // WhileOp. 
In this case, we know that the before region should have // already been replaced (when we met the WhileOp), hence we can From e1d1025f59bded48c2db948f804fe5cbf6b735e2 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Thu, 12 Dec 2024 23:23:58 -0500 Subject: [PATCH 02/17] [wip] dialect conversion --- .../amd/amd-canonicalize-pointers.mlir | 1443 +++++++++-------- .../CanonicalizePointers.cpp | 500 ++++-- 2 files changed, 1105 insertions(+), 838 deletions(-) diff --git a/test/TritonGPU/amd/amd-canonicalize-pointers.mlir b/test/TritonGPU/amd/amd-canonicalize-pointers.mlir index ed47e1512da9..0c3b3b9f977c 100644 --- a/test/TritonGPU/amd/amd-canonicalize-pointers.mlir +++ b/test/TritonGPU/amd/amd-canonicalize-pointers.mlir @@ -1,4 +1,25 @@ // RUN: triton-opt %s -split-input-file -tritonamdgpu-canonicalize-pointers | FileCheck %s + +module { + tt.func public @add_kernel( + %in_ptr0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} , + %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} , + %arg2: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} ) -> tensor<1024xf32> attributes {noinline = false} { + %c1024_i32 = arith.constant 1024 : i32 + %pid = tt.get_program_id x : i32 + %block_start = arith.muli %pid, %c1024_i32 : i32 + %make_range = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %block_start_splat = tt.splat %block_start : i32 -> tensor<1024xi32> + %offsets = arith.addi %block_start_splat, %make_range : tensor<1024xi32> + %in_ptr0_splat = tt.splat %in_ptr0 : !tt.ptr -> tensor<1024x!tt.ptr> + %addr = tt.addptr %in_ptr0_splat, %offsets : tensor<1024x!tt.ptr>, tensor<1024xi32> + %val = tt.load %addr : tensor<1024x!tt.ptr> + tt.return %val : tensor<1024xf32> + } +} + +// ----- + #blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: tt.func @conversion1 @@ -22,714 +43,714 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} // ----- -#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { - // CHECK-LABEL: tt.func @conversion2 - tt.func @conversion2(%arg0: !tt.ptr)-> tensor<1024xf32, #blocked>{ - %c1024_i32 = arith.constant 1024 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c1024_i32 : i32 - %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> - %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> - %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> - %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> - // CHECK: %[[scalarOffset:.*]] = arith.addi {{.*}}, {{.*}} : i32 - // CHECK: %[[baseOffset064bit:.*]] = tt.splat {{.*}} : i64 - // CHECK: %[[newScalarPtr:.*]] = tt.addptr %arg0, %[[scalarOffset]] - // CHECK: %[[offset064bit:.*]] = arith.extsi {{.*}} - // CHECK: %[[offset164bit:.*]] = arith.addi %[[offset064bit]], %[[baseOffset064bit]] - // CHECK: %[[offset132bit:.*]] = arith.trunci %[[offset164bit]] : tensor<1024xi64, #blocked> to tensor<1024xi32, #blocked> - // CHECK: %[[basePtr:.*]] = tt.splat %[[newScalarPtr]] - // CHECK: %[[newPtr:.*]] = tt.addptr %[[basePtr]], %[[offset132bit]] - // CHECK: tt.load %[[newPtr]] - %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> - %7 = tt.load %6 : tensor<1024x!tt.ptr, #blocked> - 
tt.return %7 : tensor<1024xf32, #blocked> - } -} - -// ----- - -#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { - // CHECK-LABEL: tt.func @conversion3 - tt.func @conversion3(%arg0: !tt.ptr)-> tensor<1024xf32, #blocked>{ - %c1024_i32 = arith.constant 1024 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c1024_i32 : i32 - %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> - %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> - %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> - - //CHECK: %0 = tt.get_program_id x : i32 - //CHECK: %[[pid:.*]] = arith.muli %0, {{.*}} : i32 - //CHECK: %[[makerange:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> - //CHECK: %[[uniformOffset1:.*]] = arith.addi %[[pid]], {{.*}} : i32 - //CHECK: %[[tensorOffset1:.*]] = arith.addi %{{.*}}, %[[makerange]] : tensor<1024xi32, #blocked> - //CHECK: %[[uniformOffset0:.*]] = arith.addi %[[pid:.*]], %{{.*}} : i32 - //CHECK: %[[tensorOffset3:.*]] = arith.addi %{{.*}}, %[[makerange]] : tensor<1024xi32, #blocked> - //CHECK: %[[zero:.*]] = tt.splat %{{.*}} : i64 -> tensor<1024xi64, #blocked> - //CHECK: %[[uniformPtr0:.*]] = tt.addptr %arg0, %[[uniformOffset0:.*]] : !tt.ptr, i32 - //CHECK: %[[tensorOffset3ext:.*]] = arith.extsi %[[tensorOffset3]] : tensor<1024xi32, #blocked> to tensor<1024xi64, #blocked> - //CHECK: %[[tensorOffset0:.*]]= arith.addi %[[tensorOffset3ext]], %[[zero]] : tensor<1024xi64, #blocked> - //CHECK: %[[uniformPtr1:.*]] = tt.addptr %[[uniformPtr0]], %[[uniformOffset1]] : !tt.ptr, i32 - //CHECK: %[[tensorOffset1ext:.*]] = arith.extsi %[[tensorOffset1]] : tensor<1024xi32, #blocked> to tensor<1024xi64, #blocked> - //CHECK: %[[tensorOffset2:.*]] = arith.addi %[[tensorOffset1ext]], %[[tensorOffset0]]: tensor<1024xi64, #blocked> - //CHECK: %[[scalarPtr:.*]] = tt.splat %[[uniformPtr1]] : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> - //CHECK: %[[newPtr:.*]] = tt.addptr %[[scalarPtr]], %[[tensorOffset2]] : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi64, #blocked> - //CHECK: tt.load %[[newPtr]] - - %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> - %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> - %7 = tt.addptr %6, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> - %8 = tt.load %7 : tensor<1024x!tt.ptr, #blocked> - tt.return %8 : tensor<1024xf32, #blocked> - } -} - -// ----- - -#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { - // - // This is the same as conversion3, but now the `arith.extsi` operations - // disappeared and all the offsets are 32 bits. 
- // - // CHECK-LABEL: tt.func @conversion4 - tt.func @conversion4(%arg0: !tt.ptr{tt.pointer_range = 32 : i32})-> tensor<1024xf32, #blocked>{ - %c1024_i32 = arith.constant 1024 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c1024_i32 : i32 - %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> - %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> - %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> - - //CHECK: %0 = tt.get_program_id x : i32 - //CHECK: %[[pid:.*]] = arith.muli %0, {{.*}} : i32 - //CHECK: %[[makerange:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> - //CHECK: %[[uniformOffset1:.*]] = arith.addi %[[pid]], {{.*}} : i32 - //CHECK: %[[tensorOffset1:.*]] = arith.addi %{{.*}}, %[[makerange]] : tensor<1024xi32, #blocked> - //CHECK: %[[uniformOffset0:.*]] = arith.addi %[[pid:.*]], %{{.*}} : i32 - //CHECK: %[[tensorOffset3:.*]] = arith.addi %{{.*}}, %[[makerange]] : tensor<1024xi32, #blocked> - //CHECK: %[[zero:.*]] = tt.splat %{{.*}} : i32 -> tensor<1024xi32, #blocked> - //CHECK: %[[uniformPtr0:.*]] = tt.addptr %arg0, %[[uniformOffset0:.*]] : !tt.ptr, i32 - //CHECK: %[[tensorOffset0:.*]]= arith.addi %[[tensorOffset3]], %[[zero]] : tensor<1024xi32, #blocked> - //CHECK: %[[uniformPtr1:.*]] = tt.addptr %[[uniformPtr0]], %[[uniformOffset1]] : !tt.ptr, i32 - //CHECK: %[[tensorOffset2:.*]] = arith.addi %[[tensorOffset1]], %[[tensorOffset0]]: tensor<1024xi32, #blocked> - //CHECK: %[[scalarPtr:.*]] = tt.splat %[[uniformPtr1]] : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> - //CHECK: %[[newPtr:.*]] = tt.addptr %[[scalarPtr]], %[[tensorOffset2]] : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> - //CHECK: tt.load %[[newPtr]] - %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> - %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> - %7 = tt.addptr %6, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> - %8 = tt.load %7 : tensor<1024x!tt.ptr, #blocked> - tt.return %8 : tensor<1024xf32, #blocked> - } -} - -// ----- - -#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { - // CHECK-LABEL: tt.func @forOp - tt.func @forOp(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>)-> tensor<1024xf32, #blocked>{ - %c1024_i32 = arith.constant 1024 : i32 - %c0 = arith.constant 0: index - %c128 = arith.constant 128: index - %c1 = arith.constant 1 : index - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c1024_i32 : i32 - %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> - // CHECK: %[[scalarOffsetLoop:.*]] = arith.addi {{.*}}, {{.*}} : i32 - // CHECK: %[[variableOffset1:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor - // CHECK: %[[scalarOffset1:.*]] = arith.addi %{{.*}}, %{{.*}} : i32 - // CHECK: %[[scalarOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : i32 - // CHECK: %[[variableOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor - // CHECK: %[[scalarPtrUpdate:.*]] = tt.addptr %arg0, %[[scalarOffset]] - // CHECK: %[[ext_offset0:.*]] = arith.extsi %[[variableOffset]] - // CHECK: %[[offset1:.*]] = arith.addi %[[ext_offset0]], %{{.*}} : tensor<1024xi64, #blocked> - %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> - %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> - %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> - %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, 
tensor<1024xi32, #blocked> - // CHECK: %[[loop:.*]]:4 = scf.for {{.*}} iter_args({{.*}}, {{.*}}, %[[loopScalarPtr:.*]] = %{{.*}}, %[[loopOffset:.*]] = %[[offset1]]) -> {{.*}} { - %52:2 = scf.for %arg9 = %c0 to %c128 step %c1 iter_args(%arg1 = %6, %arg2 = %init) -> (tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked>){ - // CHECK: %[[scalarPtrUpdateLoop:.*]] = tt.addptr %[[loopScalarPtr]], %[[scalarOffsetLoop]] - // CHECK: %[[ext_offset0i:.*]] = arith.extsi %[[variableOffset1]] - // CHECK: %[[offset_i:.*]] = arith.addi %[[ext_offset0i]], %[[loopOffset]] - // CHECK: %[[base_ptr:.*]] = tt.splat %[[scalarPtrUpdateLoop]] - // CHECK: %[[newPtr:.*]] = tt.addptr %[[base_ptr]], %[[offset_i]] - // CHECK: tt.load %[[newPtr]] - // CHECK: scf.yield {{.*}}, {{.*}}, %[[scalarPtrUpdateLoop]], %[[offset_i]] - %11 = tt.addptr %arg1, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> - %9 = tt.load %11 : tensor<1024x!tt.ptr, #blocked> - %10 = arith.addf %9, %arg2 : tensor<1024xf32, #blocked> - scf.yield %11, %10 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked> - } - // CHECK: tt.addptr %[[loop]]#2, %[[scalarOffset1]] : !tt.ptr, i32 - %8 = tt.addptr %52#0, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> - %11 = tt.load %8 : tensor<1024x!tt.ptr, #blocked> - tt.return %11 : tensor<1024xf32, #blocked> - } -} - -// ----- - -#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { - // CHECK-LABEL: tt.func @forOp2 - tt.func @forOp2(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>)-> tensor<1024xf32, #blocked>{ - %c1024_i32 = arith.constant 1024 : i32 - %c0 = arith.constant 0: index - %c128 = arith.constant 128: index - %c1 = arith.constant 1 : index - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c1024_i32 : i32 - %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> - // CHECK: %[[scalarOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : i32 - // CHECK: %[[variableOffset0:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor<1024xi32, #blocked> - // CHECK: %[[finalScalarOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : i32 - // CHECK: %[[variableOffset1:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor<1024xi32, #blocked> - // CHECK: %[[base_offset:.*]] = tt.splat {{.*}} : i64 - %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> - %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> - %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> - // CHECK: %[[forOut:.*]]:4 = scf.for {{.*}} iter_args(%{{.*}}, {{.*}}, %[[scalarPtr:.*]] = %arg0, %[[loopOffset:.*]] = %[[base_offset]]) - %52:2 = scf.for %arg9 = %c0 to %c128 step %c1 iter_args(%arg1 = %5, %arg2 = %init) -> (tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked>){ - // CHECK: %[[scalarPtrUpdate:.*]] = tt.addptr %[[scalarPtr]], %[[scalarOffset]] - // CHECK: %[[ext_offset0i:.*]] = arith.extsi %[[variableOffset0]] - // CHECK: %[[ext_offset_i:.*]] = arith.addi %[[ext_offset0i]], %[[loopOffset]] - // CHECK: %[[base_ptr:.*]] = tt.splat %[[scalarPtrUpdate]] - // CHECK: %[[newPtr:.*]] = tt.addptr %[[base_ptr]], %[[ext_offset_i]] - // CHECK: tt.load %[[newPtr]] - %11 = tt.addptr %arg1, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> - %9 = tt.load %11 : tensor<1024x!tt.ptr, #blocked> - %10 = arith.addf %9, %arg2 : tensor<1024xf32, #blocked> - scf.yield %11, %10 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked> - } - // CHECK: 
%[[scalarPtrFinalUpdate:.*]] = tt.addptr %[[forOut]]#2, %[[finalScalarOffset]] - // CHECK: %[[ext_offset0:.*]] = arith.extsi %[[variableOffset1]] - // CHECK: %[[tailOffset:.*]] = arith.addi %[[ext_offset0]], %[[forOut]]#3 - // CHECK: %[[tail_base_ptr:.*]] = tt.splat %[[scalarPtrFinalUpdate]] - // CHECK: %[[tailPtr:.*]] = tt.addptr %[[tail_base_ptr]], %[[tailOffset]] - // CHECK: tt.load %[[tailPtr]] - %8 = tt.addptr %52#0, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> - %11 = tt.load %8 : tensor<1024x!tt.ptr, #blocked> - tt.return %11 : tensor<1024xf32, #blocked> - } -} - -// ----- - -#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { - // CHECK-LABEL: tt.func @forNested - tt.func @forNested(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>)-> tensor<1024xf32, #blocked>{ - %c1024_i32 = arith.constant 1024 : i32 - %c0 = arith.constant 0: index - %c128 = arith.constant 128: index - %c1 = arith.constant 1 : index - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c1024_i32 : i32 - %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> - // CHECK: %[[scalarOffset:.*]] = arith.addi {{.*}}, {{.*}} : i32 - // CHECK: %[[variableOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor<1024xi32, #blocked> - // CHECK: %[[base_offset:.*]] = tt.splat {{.*}} : i64 - %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> - %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> - %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> - - // CHECK: %[[forOut0:.*]]:4 = scf.for {{.*}} iter_args(%{{.*}}, {{.*}}, %[[scalarPtr0:.*]] = %arg0, %[[loopOffset0:.*]] = %[[base_offset]]){{.*}}{ - // CHECK: %[[forOut1:.*]]:4 = scf.for {{.*}} iter_args(%{{.*}}, {{.*}}, %[[scalarPtr1:.*]] = %[[scalarPtr0]], %[[loopOffset1:.*]] = %[[loopOffset0]]){{.*}}{ - // CHECK: %[[scalarPtrUpdate:.*]] = tt.addptr %[[scalarPtr1]], %{{.*}} - // CHECK: %[[ext_loop_offset1:.*]] = arith.extsi %[[variableOffset]] - // CHECK: %[[offset_i:.*]] = arith.addi %[[ext_loop_offset1]], %[[loopOffset1]] - // CHECK: %[[base_ptr:.*]] = tt.splat %[[scalarPtrUpdate]] - // CHECK: %[[newPtr:.*]] = tt.addptr %[[base_ptr]], %[[offset_i]] - // CHECK: tt.load %[[newPtr]] - // CHECK: scf.yield %{{.*}}, {{.*}}, %[[scalarPtrUpdate]], %[[offset_i]] - // CHECK: scf.yield %{{.*}}, {{.*}}, %[[forOut1]]#2, %[[forOut1]]#3 - - %52:2 = scf.for %arg9 = %c0 to %c128 step %c1 iter_args(%arg1 = %5, %arg2 = %init) -> (tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked>){ - %53:2 = scf.for %arg10 = %c0 to %c128 step %c1 iter_args(%arg3 = %arg1, %arg4 = %arg2) -> (tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked>){ - %11 = tt.addptr %arg3, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> - %9 = tt.load %11 : tensor<1024x!tt.ptr, #blocked> - %10 = arith.addf %9, %arg4 : tensor<1024xf32, #blocked> - scf.yield %11, %10 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked> - } - scf.yield %53#0, %53#1: tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked> - } - %8 = tt.addptr %52#0, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> - %11 = tt.load %8 : tensor<1024x!tt.ptr, #blocked> - tt.return %11 : tensor<1024xf32, #blocked> - } -} - -// ----- - -#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { - // 
CHECK-LABEL: tt.func @ifOp - tt.func @ifOp(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>, %cond : i1)-> tensor<1024xf32, #blocked>{ - %c1024_i32 = arith.constant 1024 : i32 - %c0 = arith.constant 0: index - %c128 = arith.constant 128: index - %c1 = arith.constant 1 : index - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c1024_i32 : i32 - %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> - // CHECK: %[[scalarOffset:.*]] = arith.addi {{.*}}, {{.*}} : i32 - // CHECK: %[[variableOffset:.*]] = arith.addi - %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> - %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> - %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> - // CHECK: %[[baseOffsetVariable:.*]] = tt.splat {{.*}} : i64 -> tensor<1024xi64, #blocked> - // CHECK: %[[ifOut:.*]]:3 = scf.if {{.*}} -> (tensor<1024x!tt.ptr, #blocked>, !tt.ptr, tensor<1024xi64, #blocked>) - %6 = scf.if %cond -> (tensor<1024x!tt.ptr, #blocked>){ - // CHECK: %[[scalarOffsetUpdate:.*]] = tt.addptr %arg0, %[[scalarOffset]] - // CHECK: %[[ext_offset0:.*]] = arith.extsi %[[variableOffset]] - // CHECK: %[[if_offset:.*]] = arith.addi %[[ext_offset0]], %[[baseOffsetVariable]] - // CHECK: scf.yield %{{.*}}, %[[scalarOffsetUpdate]], %[[if_offset]] - %true = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> - scf.yield %true : tensor<1024x!tt.ptr, #blocked> - } else { - // CHECK: %[[new_scalar_ptr:.*]] = tt.addptr %arg0, {{.*}} - // CHECK: scf.yield %{{.*}}, %[[new_scalar_ptr]], %[[baseOffsetVariable]] - %false = tt.addptr %5, %3 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> - scf.yield %false : tensor<1024x!tt.ptr, #blocked> - } - // CHECK: %[[trunc_offset:.*]] = arith.trunci %[[ifOut]]#2 - // CHECK: %[[base_ptr:.*]] = tt.splat %[[ifOut]]#1 - // CHECK: %[[newPtr:.*]] = tt.addptr %[[base_ptr]], %[[trunc_offset]] - // CHECK: tt.load %[[newPtr]] - %11 = tt.load %6 : tensor<1024x!tt.ptr, #blocked> - tt.return %11 : tensor<1024xf32, #blocked> - } -} - -// ----- - -#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { - // CHECK-LABEL: tt.func @whileOp - tt.func @whileOp(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>, %cond : i1)-> tensor<1024xf32, #blocked>{ - %c1024_i32 = arith.constant 1024 : i32 - %c0 = arith.constant 0: index - %c128 = arith.constant 128: index - %c1 = arith.constant 1 : index - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c1024_i32 : i32 - %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> - // CHECK: %[[base_offset:.*]] = tt.splat %{{.*}} : i64 - %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> - %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> - %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> - // CHECK: %[[whileOut:.*]]:3 = scf.while ({{.*}}, %[[loopPtr:.*]] = %arg0, %[[loopOffset:.*]] = %[[base_offset]]) - %6 = scf.while (%arg1 = %5, %arg2 = %cond) : (tensor<1024x!tt.ptr, #blocked>, i1) -> (tensor<1024x!tt.ptr, #blocked>) { - // CHECK: scf.condition({{.*}}) %{{.*}}, %[[loopPtr]], %[[loopOffset]] - scf.condition(%arg2) %arg1 : tensor<1024x!tt.ptr, #blocked> - } do { - // CHECK: ^bb{{.*}}(%{{.*}}, %[[blockPtr:.*]]: !tt.ptr, %[[blockOffset:.*]]: tensor<1024xi64, #blocked>): - ^bb0(%arg1: tensor<1024x!tt.ptr, #blocked>): - // CHECK: scf.yield {{.*}}, %[[blockPtr]], %[[blockOffset]] - scf.yield %arg1, %cond : 
tensor<1024x!tt.ptr, #blocked>, i1 - } - // CHECK: %[[trunc_offset:.*]] = arith.trunci %[[whileOut]]#2 - // CHECK: %[[base_ptr:.*]] = tt.splat %[[whileOut]]#1 - // CHECK: %[[newPtr:.*]] = tt.addptr %[[base_ptr]], %[[trunc_offset]] - // CHECK: tt.load %[[newPtr]] - %11 = tt.load %6 : tensor<1024x!tt.ptr, #blocked> - tt.return %11 : tensor<1024xf32, #blocked> - } -} - -// ----- - -#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { - // CHECK-LABEL: tt.func @condBranch - tt.func @condBranch(%arg0 : !tt.ptr, %i1 : i1) -> tensor<1024xf32, #blocked>{ - %c1024_i32 = arith.constant 1024 : i32 - %c0 = arith.constant 0: index - %c128 = arith.constant 128: index - %c1 = arith.constant 1 : index - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c1024_i32 : i32 - %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> - // CHECK: %[[scalarOffset:.*]] = arith.addi {{.*}}, {{.*}} : i32 - // CHECK: %[[variableOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor<1024xi32, #blocked> - // CHECK: %[[base_offset:.*]] = tt.splat %{{.*}} : i64 - // CHECK: %[[scalarPtr:.*]] = tt.addptr %arg0, %[[scalarOffset]] - // CHECK: %[[ext_offset0:.*]] = arith.extsi %[[variableOffset]] - %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> - %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> - %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> - // CHECK: %[[offset1:.*]] = arith.addi %[[ext_offset0]], %[[base_offset]] - %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> - // CHECK: cf.cond_br {{.*}}, ^bb1(%{{.*}}, %arg0, %[[base_offset]] : {{.*}}), ^bb2(%{{.*}}, %[[scalarPtr]], %[[offset1]] : {{.*}}) - cf.cond_br %i1, ^bb1(%5 : tensor<1024x!tt.ptr, #blocked>), ^bb2(%6 : tensor<1024x!tt.ptr, #blocked>) - // CHECK: ^bb1({{.*}}, %[[block1ScalarPtr:.*]]: !tt.ptr, %[[block1Offset:.*]]: tensor<1024xi64, #blocked>) - ^bb1(%arg1 : tensor<1024x!tt.ptr, #blocked>): - // CHECK: %[[trunc_offset_1:.*]] = arith.trunci %[[block1Offset]] - // CHECK: %[[basePtr1:.*]] = tt.splat %[[block1ScalarPtr]] - // CHECK: %[[newPtr1:.*]] = tt.addptr %[[basePtr1]], %[[trunc_offset_1]] - // CHECK: tt.load %[[newPtr1]] - %out1 = tt.load %arg1 : tensor<1024x!tt.ptr, #blocked> - tt.return %out1 : tensor<1024xf32, #blocked> - // CHECK: ^bb2({{.*}}, %[[block2ScalarPtr:.*]]: !tt.ptr, %[[block2Offset:.*]]: tensor<1024xi64, #blocked>) - ^bb2(%arg2 : tensor<1024x!tt.ptr, #blocked>): // 2 preds: ^bb0, ^bb1 - // CHECK: %[[trunc_offset_2:.*]] = arith.trunci %[[block2Offset]] - // CHECK: %[[basePtr2:.*]] = tt.splat %[[block2ScalarPtr]] - // CHECK: %[[newPtr2:.*]] = tt.addptr %[[basePtr2]], %[[trunc_offset_2]] - // CHECK: tt.load %[[newPtr2]] - %out2 = tt.load %arg2 : tensor<1024x!tt.ptr, #blocked> - tt.return %out2 : tensor<1024xf32, #blocked> - } -} - -// ----- - -#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { - // CHECK-LABEL: tt.func @branch - tt.func @branch(%arg0 : !tt.ptr, %i1 : i1) -> tensor<1024xf32, #blocked>{ - %c1024_i32 = arith.constant 1024 : i32 - %c0 = arith.constant 0: index - %c128 = arith.constant 128: index - %c1 = arith.constant 1 : index - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c1024_i32 : i32 - %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> - // 
CHECK: %[[scalarOffset:.*]] = arith.addi {{.*}}, {{.*}} : i32 - // CHECK: %[[variableOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor<1024xi32, #blocked> - // CHECK: %[[base_offset:.*]] = tt.splat %{{.*}} : i64 - // CHECK: %[[scalarPtr:.*]] = tt.addptr %arg0, %[[scalarOffset]] - // CHECK: %[[ext_offset0:.*]] = arith.extsi %[[variableOffset]] - %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> - %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> - %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> - // CHECK: %[[offset1:.*]] = arith.addi %[[ext_offset0]], %[[base_offset]] - %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> - // CHECK: cf.br ^bb1(%{{.*}}, %[[scalarPtr]], %[[offset1]] : {{.*}}) - // CHECK: ^bb1({{.*}}, %[[block1ScalarPtr:.*]]: {{.*}}, %[[block1Offset:.*]]: {{.*}}) - cf.br ^bb1(%6 : tensor<1024x!tt.ptr, #blocked>) - ^bb1(%arg1 : tensor<1024x!tt.ptr, #blocked>): - // CHECK: %[[trunc_offset_1:.*]] = arith.trunci %[[block1Offset]] - // CHECK: %[[basePtr1:.*]] = tt.splat %[[block1ScalarPtr]] - // CHECK: %[[newPtr1:.*]] = tt.addptr %[[basePtr1]], %[[trunc_offset_1]] - // CHECK: tt.load %[[newPtr1]] - %out1 = tt.load %arg1 : tensor<1024x!tt.ptr, #blocked> - tt.return %out1 : tensor<1024xf32, #blocked> - } -} - -// ----- - -// The following is a simple case of a tile offset like: (A*B + C + D) where B,C are Uniform and A,D are not. So -// we expect that the Uniform offset (which can be added to the scalar pointer) will be simply C and the NonUniform -// offset will be A*B+D -#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { - // CHECK-LABEL: tt.func @tile_offset - tt.func @tile_offset(%arg1: !tt.ptr, %arg5: i32 , %arg7: i32 ) { - %c128_i32 = arith.constant 128 : i32 - %c256_i32 = arith.constant 256 : i32 - %1 = tt.get_program_id x : i32 - %20 = arith.muli %1, %c256_i32 : i32 - %22 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %24 = tt.splat %20 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %26 = arith.addi %24, %22 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %36 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %37 = tt.expand_dims %36 {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi32, #blocked> - %38 = tt.splat %arg7 : i32 -> tensor<16x1xi32, #blocked> - %39 = arith.muli %37, %38 : tensor<16x1xi32, #blocked> - %41 = tt.expand_dims %26 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> - %42 = tt.broadcast %39 : tensor<16x1xi32, #blocked> -> tensor<16x256xi32, #blocked> - %43 = tt.broadcast %41 : tensor<1x256xi32, #blocked> -> tensor<16x256xi32, #blocked> - %44 = arith.addi %42, %43 : tensor<16x256xi32, #blocked> - %45 = tt.splat %arg1 : !tt.ptr -> tensor<16x256x!tt.ptr, #blocked> - %46 = tt.addptr %45, %44 : tensor<16x256x!tt.ptr, #blocked>, tensor<16x256xi32, #blocked> - // CHECK: %[[uniformOffset1:.*]] = arith.muli %c0_i32_0, %arg2 : i32 - // CHECK: {{.*}} = tt.expand_dims %{{.*}} {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> - // CHECK: %[[tensorOffset6:.*]] = tt.expand_dims %{{.*}} {axis = 0 : i32} : tensor<256xi32, 
#ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> - // CHECK: {{.*}} = tt.broadcast %{{.*}} : tensor<16x1xi32, #blocked> -> tensor<16x256xi32, #blocked> - // CHECK: %[[tensorOffset3:.*]] = tt.broadcast %{{.*}} : tensor<16x1xi32, #blocked> -> tensor<16x256xi32, #blocked> - // CHECK: %[[tensorOffset4:.*]] = tt.broadcast %{{.*}} : tensor<1x256xi32, #blocked> -> tensor<16x256xi32, #blocked> - // CHECK: %[[tensorOffset5:.*]] = tt.broadcast %[[tensorOffset6]] : tensor<1x256xi32, #blocked> -> tensor<16x256xi32, #blocked> - // CHECK: %[[uniformOffset:.*]] = arith.addi %[[uniformOffset1]], %{{.*}}: i32 - // CHECK: %[[tensorOffset2:.*]] = arith.addi %[[tensorOffset3]], %[[tensorOffset5]] : tensor<16x256xi32, #blocked> - // CHECK: %[[scalarPtr:.*]] = tt.addptr %arg0, %[[uniformOffset]] : !tt.ptr, i32 - // CHECK: %[[tensorOffset2ext:.*]] = arith.extsi %[[tensorOffset2]] : tensor<16x256xi32, #blocked> to tensor<16x256xi64, #blocked> - // CHECK: %[[tensorOffset1:.*]] = arith.addi %[[tensorOffset2ext]], %{{.*}} : tensor<16x256xi64, #blocked> - // CHECK: %[[tensorOffset:.*]] = arith.trunci %[[tensorOffset1:.*]] : tensor<16x256xi64, #blocked> to tensor<16x256xi32, #blocked> - // CHECK: %[[ptr:.*]] = tt.splat %[[scalarPtr]] : !tt.ptr -> tensor<16x256x!tt.ptr, #blocked> - // CHECK: tt.addptr %[[ptr]], %[[tensorOffset]] : tensor<16x256x!tt.ptr, #block - %61 = tt.load %46 : tensor<16x256x!tt.ptr, #blocked> - tt.return - } -} - -// ----- - -// The following is a more complex case where also a multiplication is involved. It's useful to walk through the case. -// We have that the offset to the pointer is the following: -// %12 = %10 + 11 -// This can be transformed in: -// = %7 + %9 -// = %5*%6 + %8 -// = %4*%arg1 + %8 -// = (%3+%2)*%arg1 + %8 -// = (%1 + %2) * %arg1 + %8 -// = (U + N)*U + N -// Where U means uniform (e.g., a splat) and N means NonUniform (e.g., a make_range) -// The scalar offset we want is (%1*%arg1), while the variable offset should be (%2*%arg1 + %8) -#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { - // CHECK-LABEL: tt.func public @matmul_kernel - tt.func public @matmul_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}) { - %c128_i32 = arith.constant 128 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c128_i32 : i32 - %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %3 = tt.splat %1 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %4 = arith.addi %3, %2 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %5 = tt.expand_dims %4 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> - %6 = tt.splat %arg1 : i32 -> tensor<128x1xi32, #blocked> - %7 = arith.muli %5, %6 : tensor<128x1xi32, #blocked> - %8 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> - %9 = tt.expand_dims %8 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> - %10 = tt.broadcast %7 : tensor<128x1xi32, #blocked> -> tensor<128x16xi32, #blocked> - %11 = tt.broadcast %9 : tensor<1x16xi32, #blocked> -> tensor<128x16xi32, #blocked> - %12 = arith.addi %10, %11 : tensor<128x16xi32, #blocked> - %13 = tt.splat %arg0 : !tt.ptr -> 
tensor<128x16x!tt.ptr, #blocked> - %14 = tt.addptr %13, %12 : tensor<128x16x!tt.ptr, #blocked>, tensor<128x16xi32, #blocked> - %15 = tt.load %14 : tensor<128x16x!tt.ptr, #blocked> - // CHECK: %[[pid:.*]] = tt.get_program_id x : i32 - // CHECK: %[[uniformOffset3:.*]] = arith.muli %[[pid]], %{{.*}} : i32 - // CHECK: %[[uniformOffset2:.*]] = arith.addi %[[uniformOffset3]], %{{.*}} : i32 - // CHECK: %[[uniformOffset1:.*]] = arith.muli %[[uniformOffset2]], %arg1 : i32 - // CHECK: %[[makerange:.*]] = tt.make_range - // CHECK: %{{.*}} = tt.expand_dims %[[makerange]] {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> - // CHECK: %[[tensorOffset6:.*]] = tt.expand_dims %[[makerange]] {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> - // CHECK: %{{.*}} = tt.broadcast %{{.*}} : tensor<128x1xi32, #blocked> -> tensor<128x16xi32, #blocked> - // CHECK: %[[tensorOffset3:.*]] = tt.broadcast %{{.*}} : tensor<128x1xi32, #blocked> -> tensor<128x16xi32, #blocked> - // CHECK: %{{.*}} = tt.broadcast %{{.*}} : tensor<1x16xi32, #blocked> -> tensor<128x16xi32, #blocked> - // CHECK: %[[tensorOffset4:.*]] = tt.broadcast %[[tensorOffset6]] : tensor<1x16xi32, #blocked> -> tensor<128x16xi32, #blocked> - // CHECK: %{{.*}} = arith.addi %{{.*}}, %{{.*}} : tensor<128x16xi32, #blocked> - // CHECK: %[[uniformOffset:.*]] = arith.addi %[[uniformOffset1]], %{{.*}} : i32 - // CHECK: %[[tensorOffset2:.*]] = arith.addi %[[tensorOffset3]], %[[tensorOffset4]] : tensor<128x16xi32, #blocked> - // CHECK: %[[scalarPtr:.*]] = tt.addptr %arg0, %[[uniformOffset]] : !tt.ptr, i32 - // CHECK: %[[tensorOffset1Ext:.*]] = arith.extsi %[[tensorOffset2]] : tensor<128x16xi32, #blocked> to tensor<128x16xi64, #blocked> - // CHECK: %[[tensorOffset:.*]] = arith.addi %[[tensorOffset1Ext]], %{{.*}} : tensor<128x16xi64, #blocked> - // CHECK: %[[tensorOffsetTrunc:.*]] = arith.trunci %[[tensorOffset]] : tensor<128x16xi64, #blocked> to tensor<128x16xi32, #blocked> - // CHECK: %[[ptr:.*]] = tt.splat %[[scalarPtr]] : !tt.ptr -> tensor<128x16x!tt.ptr, #blocked> - // CHECK: tt.addptr %[[ptr]], %[[tensorOffsetTrunc]] : tensor<128x16x!tt.ptr, #blocked>, tensor<128x16xi32, #blocked> - tt.return - } -} - - -// ----- -#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { - // CHECK-LABEL: tt.func @select - tt.func @select(%arg0 : !tt.ptr, %i1 : i1) -> tensor<1024xf32, #blocked>{ - %c1024_i32 = arith.constant 1024 : i32 - %c0 = arith.constant 0: index - %c128 = arith.constant 128: index - %c1 = arith.constant 1 : index - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c1024_i32 : i32 - %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> - // CHECK: %[[scalarOffset:.*]] = arith.addi {{.*}}, {{.*}} : i32 - // CHECK: %[[variableOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor<1024xi32, #blocked> - // CHECK: %[[baseOffset:.*]] = tt.splat %{{.*}} : i64 - // CHECK: %[[scalarPtr:.*]] = tt.addptr %arg0, %[[scalarOffset]] - // CHECK: %[[extVariableOffset:.*]] = arith.extsi %[[variableOffset]] - %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> - %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> - %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> - %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> - // CHECK: %[[offset2:.*]] = arith.addi 
%[[extVariableOffset]], %[[baseOffset]] - // CHECK: %[[scalarPtr1:.*]] = arith.select %arg1, %arg0, %[[scalarPtr]] - // CHECK: %[[offset0:.*]] = arith.select %arg1, {{.*}}, %[[offset2]] - // CHECK: %[[offset1:.*]] = arith.trunci %[[offset0]] - // CHECK: %[[ptr:.*]] = tt.splat %[[scalarPtr1]] - // CHECK: tt.addptr %[[ptr]], %[[offset1]] - %7 = arith.select %i1, %5 , %6 : tensor<1024x!tt.ptr, #blocked> - %out = tt.load %7: tensor<1024x!tt.ptr, #blocked> - tt.return %out : tensor<1024xf32, #blocked> - } -} - -// ----- -#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1100", "ttg.threads-per-warp" = 32 : i32} { - // CHECK-LABEL: tt.func @where_kernel - tt.func @where_kernel(%arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}){ - %c0_i8 = arith.constant 0 : i8 - %c1024_i32 = arith.constant 1024 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c1024_i32 : i32 - %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> - %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> - %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> - %9 = arith.cmpi ne, %c0_i8, %c0_i8 : i8 - %10 = arith.select %9, %arg1, %arg2 : !tt.ptr - // CHECK: %[[selectPtr:.*]] = arith.select {{.*}} : !tt.ptr - %11 = tt.splat %10: !tt.ptr -> tensor<1024x!tt.ptr, #blocked> - %13 = tt.addptr %11, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> - // CHECK: %[[selectPtr0:.*]] = tt.addptr %[[selectPtr]] - // CHECK: %[[tensorPtr:.*]] = tt.splat %[[selectPtr0]] - // CHECK: tt.addptr %[[tensorPtr]] - %14 = tt.load %13 : tensor<1024x!tt.ptr, #blocked> - tt.return - } -} - -// ----- - -#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { - // CHECK-LABEL: tt.func @forOpWithHints - tt.func @forOpWithHints(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>)-> tensor<1024xf32, #blocked>{ - %c0 = arith.constant 0: index - %c1 = arith.constant 1 : index - %c128 = arith.constant 128: index - %0 = tt.get_program_id x : i32 - %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> - %3 = tt.splat %0 : i32 -> tensor<1024xi32, #blocked> - %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> - %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> - %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> - %52:2 = scf.for %arg9 = %c0 to %c128 step %c1 iter_args(%arg1 = %6, %arg2 = %init) -> (tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked>){ - %9 = tt.load %arg1: tensor<1024x!tt.ptr, #blocked> - // CHECK: tt.addptr {{.*}}, {{.*}} {tt.divisibility = dense<16> : tensor<1xi32>} - %11 = tt.addptr %arg1, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> - %12 = tt.addptr %11, %3 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> - %10 = arith.addf %9, %arg2 : tensor<1024xf32, #blocked> - scf.yield %12, %10 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked> - } {"tt.divisibility_arg1"=dense<[16]> : tensor<1xi32>} - // CHECK: tt.divisibility_arg1 - // CHECK-SAME: tt.divisibility_arg4 - %8 = tt.addptr %52#0, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> - %11 = tt.load %8 : tensor<1024x!tt.ptr, #blocked> - tt.return %11 : 
tensor<1024xf32, #blocked> - } -} - -// ----- - -#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { - // CHECK-LABEL: scalar_pointers - tt.func public @scalar_pointers(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { - %0 = tt.get_program_id x : i32 - %c1_i32 = arith.constant 1 : i32 - %c0_i64 = arith.constant 0 : i64 - %c10_i64 = arith.constant 10 : i64 - %c100_i32 = arith.constant 100 : i32 - %5 = tt.addptr %arg0, %c1_i32 : !tt.ptr, i32 - // CHECK: arith.constant 0 : i64 - // CHECK: arith.constant 0 : i64 - // CHECK: %[[offset0:.*]] = arith.constant 0 : i64 - // CHECK: %[[ptr0:.*]] = tt.addptr %arg0, %c1_i32 : !tt.ptr, i32 - // CHECK: scf.for {{.*}} iter_args({{.*}}, %[[ptr1:.*]] = %[[ptr0]], %[[offset1:.*]] = %[[offset0]]) - %10:1 = scf.for %arg3 = %c1_i32 to %c100_i32 step %c1_i32 iter_args(%arg4 = %5) -> (!tt.ptr) : i32 { - // CHECK: tt.store %[[ptr1]] - tt.store %arg4, %c0_i64 : !tt.ptr - // CHECK: tt.addptr %[[ptr1]] - %11 = tt.addptr %arg4, %c1_i32 : !tt.ptr, i32 - scf.yield %11 : !tt.ptr - } - tt.return - } -} - -// ----- - -#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { - // CHECK-LABEL: @scalar_if - tt.func @scalar_if(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>, %cond : i1)->f32{ - %0 = tt.get_program_id x : i32 - %c1_i32 = arith.constant 1 : i32 - %c0_i64 = arith.constant 0 : i64 - %c10_i64 = arith.constant 10 : i64 - %c100_i32 = arith.constant 100 : i32 - %5 = tt.addptr %arg0, %c1_i32 : !tt.ptr, i32 - // CHECK: %[[ptr0:.*]] = tt.addptr %arg0, %{{.*}} - // CHECK: scf.if {{.*}} -> ({{.*}}, !tt.ptr, i64) - %6 = scf.if %cond -> (!tt.ptr){ - %true = tt.addptr %5, %c1_i32 : !tt.ptr, i32 - // CHECK: %[[ptr1:.*]] = tt.addptr %[[ptr0]] - // CHECK: scf.yield {{.*}}, %[[ptr1]] - scf.yield %true : !tt.ptr - } else { - %false = tt.addptr %5, %c100_i32 : !tt.ptr, i32 - // CHECK: %[[ptr2:.*]] = tt.addptr %[[ptr0]] - // CHECK: scf.yield {{.*}}, %[[ptr2]] - scf.yield %false : !tt.ptr - } - %11 = tt.load %6 : !tt.ptr - tt.return %11 : f32 - } -} - -// ----- - -#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { - // CHECK-LABEL: tt.func @scalar_while - tt.func @scalar_while(%arg0: !tt.ptr, %init : f32, %cond : i1)->f32{ - %c1024_i32 = arith.constant 1024 : i32 - %c0 = arith.constant 0: index - %c128 = arith.constant 128: index - %c1 = arith.constant 1 : index - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c1024_i32 : i32 - // CHECK: %[[ptr0:.*]] = tt.addptr %arg0, %{{.*}} - // CHECK: scf.while ({{.*}}, {{.*}} = %arg2, %[[ptr1:.*]] = %[[ptr0]], {{.*}}) - %2 = tt.addptr %arg0, %0 : !tt.ptr, i32 - %6 = scf.while (%arg1 = %2, %arg2 = %cond) : (!tt.ptr, i1) -> (!tt.ptr) { - // CHECK: scf.condition({{.*}}) {{.*}}, %[[ptr1]] - scf.condition(%arg2) %arg1 : !tt.ptr - } do { - // CHECK: ^bb0({{.*}}: !tt.ptr, %[[ptr2:.*]]: !tt.ptr, {{.*}}) - // CHECK: scf.yield %{{.*}}, {{.*}} %[[ptr2]], {{.*}}, {{.*}} - ^bb0(%arg1: !tt.ptr): - scf.yield %arg1, %cond : !tt.ptr, i1 - } - 
%11 = tt.load %6 : !tt.ptr - tt.return %11 : f32 - } -} - -// ----- - -#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { - // CHECK-LABEL: tt.func @scalar_cond_branch - tt.func @scalar_cond_branch(%arg0 : !tt.ptr, %i1 : i1) -> f32{ - %c1024_i32 = arith.constant 1024 : i32 - %c0 = arith.constant 0: index - %c128 = arith.constant 128: index - %c1 = arith.constant 1 : index - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c1024_i32 : i32 - %6 = tt.addptr %arg0, %0 : !tt.ptr, i32 - // CHECK: %[[ptr0:.*]] = tt.addptr %arg0 - // CHECK: cf.cond_br %arg1, ^bb1(%{{.*}}, %[[ptr0]], {{.*}}), ^bb2(%{{.*}}, %arg0, {{.*}}) - cf.cond_br %i1, ^bb1(%6 : !tt.ptr), ^bb2(%arg0 : !tt.ptr) - // CHECK: ^bb1({{.*}}, %[[ptr1:.*]]: !tt.ptr, {{.*}}): - ^bb1(%arg1 : !tt.ptr): - // CHECK: tt.load %[[ptr1]] - %out1 = tt.load %arg1 : !tt.ptr - tt.return %out1 : f32 - // CHECK: ^bb2({{.*}}, %[[ptr2:.*]]: !tt.ptr, {{.*}}): - ^bb2(%arg2 : !tt.ptr): // 2 preds: ^bb0, ^bb1 - // CHECK: tt.load %[[ptr2]] - %out2 = tt.load %arg2 : !tt.ptr - tt.return %out2 : f32 - } -} +//#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +//module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { +// // CHECK-LABEL: tt.func @conversion2 +// tt.func @conversion2(%arg0: !tt.ptr)-> tensor<1024xf32, #blocked>{ +// %c1024_i32 = arith.constant 1024 : i32 +// %0 = tt.get_program_id x : i32 +// %1 = arith.muli %0, %c1024_i32 : i32 +// %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> +// %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> +// %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> +// %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> +// // CHECK: %[[scalarOffset:.*]] = arith.addi {{.*}}, {{.*}} : i32 +// // CHECK: %[[baseOffset064bit:.*]] = tt.splat {{.*}} : i64 +// // CHECK: %[[newScalarPtr:.*]] = tt.addptr %arg0, %[[scalarOffset]] +// // CHECK: %[[offset064bit:.*]] = arith.extsi {{.*}} +// // CHECK: %[[offset164bit:.*]] = arith.addi %[[offset064bit]], %[[baseOffset064bit]] +// // CHECK: %[[offset132bit:.*]] = arith.trunci %[[offset164bit]] : tensor<1024xi64, #blocked> to tensor<1024xi32, #blocked> +// // CHECK: %[[basePtr:.*]] = tt.splat %[[newScalarPtr]] +// // CHECK: %[[newPtr:.*]] = tt.addptr %[[basePtr]], %[[offset132bit]] +// // CHECK: tt.load %[[newPtr]] +// %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> +// %7 = tt.load %6 : tensor<1024x!tt.ptr, #blocked> +// tt.return %7 : tensor<1024xf32, #blocked> +// } +//} +// +//// ----- +// +//#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +//module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { +// // CHECK-LABEL: tt.func @conversion3 +// tt.func @conversion3(%arg0: !tt.ptr)-> tensor<1024xf32, #blocked>{ +// %c1024_i32 = arith.constant 1024 : i32 +// %0 = tt.get_program_id x : i32 +// %1 = arith.muli %0, %c1024_i32 : i32 +// %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> +// %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> +// %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> +// +// //CHECK: %0 = tt.get_program_id x : i32 +// //CHECK: %[[pid:.*]] = arith.muli %0, {{.*}} : i32 +// //CHECK: %[[makerange:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : 
tensor<1024xi32, #blocked> +// //CHECK: %[[uniformOffset1:.*]] = arith.addi %[[pid]], {{.*}} : i32 +// //CHECK: %[[tensorOffset1:.*]] = arith.addi %{{.*}}, %[[makerange]] : tensor<1024xi32, #blocked> +// //CHECK: %[[uniformOffset0:.*]] = arith.addi %[[pid:.*]], %{{.*}} : i32 +// //CHECK: %[[tensorOffset3:.*]] = arith.addi %{{.*}}, %[[makerange]] : tensor<1024xi32, #blocked> +// //CHECK: %[[zero:.*]] = tt.splat %{{.*}} : i64 -> tensor<1024xi64, #blocked> +// //CHECK: %[[uniformPtr0:.*]] = tt.addptr %arg0, %[[uniformOffset0:.*]] : !tt.ptr, i32 +// //CHECK: %[[tensorOffset3ext:.*]] = arith.extsi %[[tensorOffset3]] : tensor<1024xi32, #blocked> to tensor<1024xi64, #blocked> +// //CHECK: %[[tensorOffset0:.*]]= arith.addi %[[tensorOffset3ext]], %[[zero]] : tensor<1024xi64, #blocked> +// //CHECK: %[[uniformPtr1:.*]] = tt.addptr %[[uniformPtr0]], %[[uniformOffset1]] : !tt.ptr, i32 +// //CHECK: %[[tensorOffset1ext:.*]] = arith.extsi %[[tensorOffset1]] : tensor<1024xi32, #blocked> to tensor<1024xi64, #blocked> +// //CHECK: %[[tensorOffset2:.*]] = arith.addi %[[tensorOffset1ext]], %[[tensorOffset0]]: tensor<1024xi64, #blocked> +// //CHECK: %[[scalarPtr:.*]] = tt.splat %[[uniformPtr1]] : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> +// //CHECK: %[[newPtr:.*]] = tt.addptr %[[scalarPtr]], %[[tensorOffset2]] : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi64, #blocked> +// //CHECK: tt.load %[[newPtr]] +// +// %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> +// %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> +// %7 = tt.addptr %6, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> +// %8 = tt.load %7 : tensor<1024x!tt.ptr, #blocked> +// tt.return %8 : tensor<1024xf32, #blocked> +// } +//} +// +//// ----- +// +//#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +//module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { +// // +// // This is the same as conversion3, but now the `arith.extsi` operations +// // disappeared and all the offsets are 32 bits. 
+// // +// // CHECK-LABEL: tt.func @conversion4 +// tt.func @conversion4(%arg0: !tt.ptr{tt.pointer_range = 32 : i32})-> tensor<1024xf32, #blocked>{ +// %c1024_i32 = arith.constant 1024 : i32 +// %0 = tt.get_program_id x : i32 +// %1 = arith.muli %0, %c1024_i32 : i32 +// %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> +// %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> +// %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> +// +// //CHECK: %0 = tt.get_program_id x : i32 +// //CHECK: %[[pid:.*]] = arith.muli %0, {{.*}} : i32 +// //CHECK: %[[makerange:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> +// //CHECK: %[[uniformOffset1:.*]] = arith.addi %[[pid]], {{.*}} : i32 +// //CHECK: %[[tensorOffset1:.*]] = arith.addi %{{.*}}, %[[makerange]] : tensor<1024xi32, #blocked> +// //CHECK: %[[uniformOffset0:.*]] = arith.addi %[[pid:.*]], %{{.*}} : i32 +// //CHECK: %[[tensorOffset3:.*]] = arith.addi %{{.*}}, %[[makerange]] : tensor<1024xi32, #blocked> +// //CHECK: %[[zero:.*]] = tt.splat %{{.*}} : i32 -> tensor<1024xi32, #blocked> +// //CHECK: %[[uniformPtr0:.*]] = tt.addptr %arg0, %[[uniformOffset0:.*]] : !tt.ptr, i32 +// //CHECK: %[[tensorOffset0:.*]]= arith.addi %[[tensorOffset3]], %[[zero]] : tensor<1024xi32, #blocked> +// //CHECK: %[[uniformPtr1:.*]] = tt.addptr %[[uniformPtr0]], %[[uniformOffset1]] : !tt.ptr, i32 +// //CHECK: %[[tensorOffset2:.*]] = arith.addi %[[tensorOffset1]], %[[tensorOffset0]]: tensor<1024xi32, #blocked> +// //CHECK: %[[scalarPtr:.*]] = tt.splat %[[uniformPtr1]] : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> +// //CHECK: %[[newPtr:.*]] = tt.addptr %[[scalarPtr]], %[[tensorOffset2]] : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> +// //CHECK: tt.load %[[newPtr]] +// %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> +// %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> +// %7 = tt.addptr %6, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> +// %8 = tt.load %7 : tensor<1024x!tt.ptr, #blocked> +// tt.return %8 : tensor<1024xf32, #blocked> +// } +//} +// +//// ----- +// +//#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +//module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { +// // CHECK-LABEL: tt.func @forOp +// tt.func @forOp(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>)-> tensor<1024xf32, #blocked>{ +// %c1024_i32 = arith.constant 1024 : i32 +// %c0 = arith.constant 0: index +// %c128 = arith.constant 128: index +// %c1 = arith.constant 1 : index +// %0 = tt.get_program_id x : i32 +// %1 = arith.muli %0, %c1024_i32 : i32 +// %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> +// // CHECK: %[[scalarOffsetLoop:.*]] = arith.addi {{.*}}, {{.*}} : i32 +// // CHECK: %[[variableOffset1:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor +// // CHECK: %[[scalarOffset1:.*]] = arith.addi %{{.*}}, %{{.*}} : i32 +// // CHECK: %[[scalarOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : i32 +// // CHECK: %[[variableOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor +// // CHECK: %[[scalarPtrUpdate:.*]] = tt.addptr %arg0, %[[scalarOffset]] +// // CHECK: %[[ext_offset0:.*]] = arith.extsi %[[variableOffset]] +// // CHECK: %[[offset1:.*]] = arith.addi %[[ext_offset0]], %{{.*}} : tensor<1024xi64, #blocked> +// %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> +// %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> +// %5 = tt.splat %arg0 
: !tt.ptr -> tensor<1024x!tt.ptr, #blocked> +// %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> +// // CHECK: %[[loop:.*]]:4 = scf.for {{.*}} iter_args({{.*}}, {{.*}}, %[[loopScalarPtr:.*]] = %{{.*}}, %[[loopOffset:.*]] = %[[offset1]]) -> {{.*}} { +// %52:2 = scf.for %arg9 = %c0 to %c128 step %c1 iter_args(%arg1 = %6, %arg2 = %init) -> (tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked>){ +// // CHECK: %[[scalarPtrUpdateLoop:.*]] = tt.addptr %[[loopScalarPtr]], %[[scalarOffsetLoop]] +// // CHECK: %[[ext_offset0i:.*]] = arith.extsi %[[variableOffset1]] +// // CHECK: %[[offset_i:.*]] = arith.addi %[[ext_offset0i]], %[[loopOffset]] +// // CHECK: %[[base_ptr:.*]] = tt.splat %[[scalarPtrUpdateLoop]] +// // CHECK: %[[newPtr:.*]] = tt.addptr %[[base_ptr]], %[[offset_i]] +// // CHECK: tt.load %[[newPtr]] +// // CHECK: scf.yield {{.*}}, {{.*}}, %[[scalarPtrUpdateLoop]], %[[offset_i]] +// %11 = tt.addptr %arg1, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> +// %9 = tt.load %11 : tensor<1024x!tt.ptr, #blocked> +// %10 = arith.addf %9, %arg2 : tensor<1024xf32, #blocked> +// scf.yield %11, %10 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked> +// } +// // CHECK: tt.addptr %[[loop]]#2, %[[scalarOffset1]] : !tt.ptr, i32 +// %8 = tt.addptr %52#0, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> +// %11 = tt.load %8 : tensor<1024x!tt.ptr, #blocked> +// tt.return %11 : tensor<1024xf32, #blocked> +// } +//} +// +//// ----- +// +//#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +//module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { +// // CHECK-LABEL: tt.func @forOp2 +// tt.func @forOp2(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>)-> tensor<1024xf32, #blocked>{ +// %c1024_i32 = arith.constant 1024 : i32 +// %c0 = arith.constant 0: index +// %c128 = arith.constant 128: index +// %c1 = arith.constant 1 : index +// %0 = tt.get_program_id x : i32 +// %1 = arith.muli %0, %c1024_i32 : i32 +// %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> +// // CHECK: %[[scalarOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : i32 +// // CHECK: %[[variableOffset0:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor<1024xi32, #blocked> +// // CHECK: %[[finalScalarOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : i32 +// // CHECK: %[[variableOffset1:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor<1024xi32, #blocked> +// // CHECK: %[[base_offset:.*]] = tt.splat {{.*}} : i64 +// %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> +// %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> +// %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> +// // CHECK: %[[forOut:.*]]:4 = scf.for {{.*}} iter_args(%{{.*}}, {{.*}}, %[[scalarPtr:.*]] = %arg0, %[[loopOffset:.*]] = %[[base_offset]]) +// %52:2 = scf.for %arg9 = %c0 to %c128 step %c1 iter_args(%arg1 = %5, %arg2 = %init) -> (tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked>){ +// // CHECK: %[[scalarPtrUpdate:.*]] = tt.addptr %[[scalarPtr]], %[[scalarOffset]] +// // CHECK: %[[ext_offset0i:.*]] = arith.extsi %[[variableOffset0]] +// // CHECK: %[[ext_offset_i:.*]] = arith.addi %[[ext_offset0i]], %[[loopOffset]] +// // CHECK: %[[base_ptr:.*]] = tt.splat %[[scalarPtrUpdate]] +// // CHECK: %[[newPtr:.*]] = tt.addptr %[[base_ptr]], %[[ext_offset_i]] +// // CHECK: tt.load %[[newPtr]] +// %11 = tt.addptr %arg1, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> +// %9 = tt.load %11 
: tensor<1024x!tt.ptr, #blocked> +// %10 = arith.addf %9, %arg2 : tensor<1024xf32, #blocked> +// scf.yield %11, %10 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked> +// } +// // CHECK: %[[scalarPtrFinalUpdate:.*]] = tt.addptr %[[forOut]]#2, %[[finalScalarOffset]] +// // CHECK: %[[ext_offset0:.*]] = arith.extsi %[[variableOffset1]] +// // CHECK: %[[tailOffset:.*]] = arith.addi %[[ext_offset0]], %[[forOut]]#3 +// // CHECK: %[[tail_base_ptr:.*]] = tt.splat %[[scalarPtrFinalUpdate]] +// // CHECK: %[[tailPtr:.*]] = tt.addptr %[[tail_base_ptr]], %[[tailOffset]] +// // CHECK: tt.load %[[tailPtr]] +// %8 = tt.addptr %52#0, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> +// %11 = tt.load %8 : tensor<1024x!tt.ptr, #blocked> +// tt.return %11 : tensor<1024xf32, #blocked> +// } +//} +// +//// ----- +// +//#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +//module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { +// // CHECK-LABEL: tt.func @forNested +// tt.func @forNested(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>)-> tensor<1024xf32, #blocked>{ +// %c1024_i32 = arith.constant 1024 : i32 +// %c0 = arith.constant 0: index +// %c128 = arith.constant 128: index +// %c1 = arith.constant 1 : index +// %0 = tt.get_program_id x : i32 +// %1 = arith.muli %0, %c1024_i32 : i32 +// %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> +// // CHECK: %[[scalarOffset:.*]] = arith.addi {{.*}}, {{.*}} : i32 +// // CHECK: %[[variableOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor<1024xi32, #blocked> +// // CHECK: %[[base_offset:.*]] = tt.splat {{.*}} : i64 +// %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> +// %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> +// %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> +// +// // CHECK: %[[forOut0:.*]]:4 = scf.for {{.*}} iter_args(%{{.*}}, {{.*}}, %[[scalarPtr0:.*]] = %arg0, %[[loopOffset0:.*]] = %[[base_offset]]){{.*}}{ +// // CHECK: %[[forOut1:.*]]:4 = scf.for {{.*}} iter_args(%{{.*}}, {{.*}}, %[[scalarPtr1:.*]] = %[[scalarPtr0]], %[[loopOffset1:.*]] = %[[loopOffset0]]){{.*}}{ +// // CHECK: %[[scalarPtrUpdate:.*]] = tt.addptr %[[scalarPtr1]], %{{.*}} +// // CHECK: %[[ext_loop_offset1:.*]] = arith.extsi %[[variableOffset]] +// // CHECK: %[[offset_i:.*]] = arith.addi %[[ext_loop_offset1]], %[[loopOffset1]] +// // CHECK: %[[base_ptr:.*]] = tt.splat %[[scalarPtrUpdate]] +// // CHECK: %[[newPtr:.*]] = tt.addptr %[[base_ptr]], %[[offset_i]] +// // CHECK: tt.load %[[newPtr]] +// // CHECK: scf.yield %{{.*}}, {{.*}}, %[[scalarPtrUpdate]], %[[offset_i]] +// // CHECK: scf.yield %{{.*}}, {{.*}}, %[[forOut1]]#2, %[[forOut1]]#3 +// +// %52:2 = scf.for %arg9 = %c0 to %c128 step %c1 iter_args(%arg1 = %5, %arg2 = %init) -> (tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked>){ +// %53:2 = scf.for %arg10 = %c0 to %c128 step %c1 iter_args(%arg3 = %arg1, %arg4 = %arg2) -> (tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked>){ +// %11 = tt.addptr %arg3, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> +// %9 = tt.load %11 : tensor<1024x!tt.ptr, #blocked> +// %10 = arith.addf %9, %arg4 : tensor<1024xf32, #blocked> +// scf.yield %11, %10 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked> +// } +// scf.yield %53#0, %53#1: tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked> +// } +// %8 = tt.addptr %52#0, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> +// %11 = 
tt.load %8 : tensor<1024x!tt.ptr, #blocked> +// tt.return %11 : tensor<1024xf32, #blocked> +// } +//} +// +//// ----- +// +//#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +//module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { +// // CHECK-LABEL: tt.func @ifOp +// tt.func @ifOp(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>, %cond : i1)-> tensor<1024xf32, #blocked>{ +// %c1024_i32 = arith.constant 1024 : i32 +// %c0 = arith.constant 0: index +// %c128 = arith.constant 128: index +// %c1 = arith.constant 1 : index +// %0 = tt.get_program_id x : i32 +// %1 = arith.muli %0, %c1024_i32 : i32 +// %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> +// // CHECK: %[[scalarOffset:.*]] = arith.addi {{.*}}, {{.*}} : i32 +// // CHECK: %[[variableOffset:.*]] = arith.addi +// %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> +// %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> +// %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> +// // CHECK: %[[baseOffsetVariable:.*]] = tt.splat {{.*}} : i64 -> tensor<1024xi64, #blocked> +// // CHECK: %[[ifOut:.*]]:3 = scf.if {{.*}} -> (tensor<1024x!tt.ptr, #blocked>, !tt.ptr, tensor<1024xi64, #blocked>) +// %6 = scf.if %cond -> (tensor<1024x!tt.ptr, #blocked>){ +// // CHECK: %[[scalarOffsetUpdate:.*]] = tt.addptr %arg0, %[[scalarOffset]] +// // CHECK: %[[ext_offset0:.*]] = arith.extsi %[[variableOffset]] +// // CHECK: %[[if_offset:.*]] = arith.addi %[[ext_offset0]], %[[baseOffsetVariable]] +// // CHECK: scf.yield %{{.*}}, %[[scalarOffsetUpdate]], %[[if_offset]] +// %true = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> +// scf.yield %true : tensor<1024x!tt.ptr, #blocked> +// } else { +// // CHECK: %[[new_scalar_ptr:.*]] = tt.addptr %arg0, {{.*}} +// // CHECK: scf.yield %{{.*}}, %[[new_scalar_ptr]], %[[baseOffsetVariable]] +// %false = tt.addptr %5, %3 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> +// scf.yield %false : tensor<1024x!tt.ptr, #blocked> +// } +// // CHECK: %[[trunc_offset:.*]] = arith.trunci %[[ifOut]]#2 +// // CHECK: %[[base_ptr:.*]] = tt.splat %[[ifOut]]#1 +// // CHECK: %[[newPtr:.*]] = tt.addptr %[[base_ptr]], %[[trunc_offset]] +// // CHECK: tt.load %[[newPtr]] +// %11 = tt.load %6 : tensor<1024x!tt.ptr, #blocked> +// tt.return %11 : tensor<1024xf32, #blocked> +// } +//} +// +//// ----- +// +//#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +//module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { +// // CHECK-LABEL: tt.func @whileOp +// tt.func @whileOp(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>, %cond : i1)-> tensor<1024xf32, #blocked>{ +// %c1024_i32 = arith.constant 1024 : i32 +// %c0 = arith.constant 0: index +// %c128 = arith.constant 128: index +// %c1 = arith.constant 1 : index +// %0 = tt.get_program_id x : i32 +// %1 = arith.muli %0, %c1024_i32 : i32 +// %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> +// // CHECK: %[[base_offset:.*]] = tt.splat %{{.*}} : i64 +// %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> +// %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> +// %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> +// // CHECK: %[[whileOut:.*]]:3 = scf.while ({{.*}}, %[[loopPtr:.*]] = %arg0, %[[loopOffset:.*]] = %[[base_offset]]) +// %6 = scf.while (%arg1 = %5, %arg2 = %cond) : (tensor<1024x!tt.ptr, #blocked>, 
i1) -> (tensor<1024x!tt.ptr, #blocked>) { +// // CHECK: scf.condition({{.*}}) %{{.*}}, %[[loopPtr]], %[[loopOffset]] +// scf.condition(%arg2) %arg1 : tensor<1024x!tt.ptr, #blocked> +// } do { +// // CHECK: ^bb{{.*}}(%{{.*}}, %[[blockPtr:.*]]: !tt.ptr, %[[blockOffset:.*]]: tensor<1024xi64, #blocked>): +// ^bb0(%arg1: tensor<1024x!tt.ptr, #blocked>): +// // CHECK: scf.yield {{.*}}, %[[blockPtr]], %[[blockOffset]] +// scf.yield %arg1, %cond : tensor<1024x!tt.ptr, #blocked>, i1 +// } +// // CHECK: %[[trunc_offset:.*]] = arith.trunci %[[whileOut]]#2 +// // CHECK: %[[base_ptr:.*]] = tt.splat %[[whileOut]]#1 +// // CHECK: %[[newPtr:.*]] = tt.addptr %[[base_ptr]], %[[trunc_offset]] +// // CHECK: tt.load %[[newPtr]] +// %11 = tt.load %6 : tensor<1024x!tt.ptr, #blocked> +// tt.return %11 : tensor<1024xf32, #blocked> +// } +//} +// +//// ----- +// +//#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +//module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { +// // CHECK-LABEL: tt.func @condBranch +// tt.func @condBranch(%arg0 : !tt.ptr, %i1 : i1) -> tensor<1024xf32, #blocked>{ +// %c1024_i32 = arith.constant 1024 : i32 +// %c0 = arith.constant 0: index +// %c128 = arith.constant 128: index +// %c1 = arith.constant 1 : index +// %0 = tt.get_program_id x : i32 +// %1 = arith.muli %0, %c1024_i32 : i32 +// %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> +// // CHECK: %[[scalarOffset:.*]] = arith.addi {{.*}}, {{.*}} : i32 +// // CHECK: %[[variableOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor<1024xi32, #blocked> +// // CHECK: %[[base_offset:.*]] = tt.splat %{{.*}} : i64 +// // CHECK: %[[scalarPtr:.*]] = tt.addptr %arg0, %[[scalarOffset]] +// // CHECK: %[[ext_offset0:.*]] = arith.extsi %[[variableOffset]] +// %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> +// %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> +// %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> +// // CHECK: %[[offset1:.*]] = arith.addi %[[ext_offset0]], %[[base_offset]] +// %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> +// // CHECK: cf.cond_br {{.*}}, ^bb1(%{{.*}}, %arg0, %[[base_offset]] : {{.*}}), ^bb2(%{{.*}}, %[[scalarPtr]], %[[offset1]] : {{.*}}) +// cf.cond_br %i1, ^bb1(%5 : tensor<1024x!tt.ptr, #blocked>), ^bb2(%6 : tensor<1024x!tt.ptr, #blocked>) +// // CHECK: ^bb1({{.*}}, %[[block1ScalarPtr:.*]]: !tt.ptr, %[[block1Offset:.*]]: tensor<1024xi64, #blocked>) +// ^bb1(%arg1 : tensor<1024x!tt.ptr, #blocked>): +// // CHECK: %[[trunc_offset_1:.*]] = arith.trunci %[[block1Offset]] +// // CHECK: %[[basePtr1:.*]] = tt.splat %[[block1ScalarPtr]] +// // CHECK: %[[newPtr1:.*]] = tt.addptr %[[basePtr1]], %[[trunc_offset_1]] +// // CHECK: tt.load %[[newPtr1]] +// %out1 = tt.load %arg1 : tensor<1024x!tt.ptr, #blocked> +// tt.return %out1 : tensor<1024xf32, #blocked> +// // CHECK: ^bb2({{.*}}, %[[block2ScalarPtr:.*]]: !tt.ptr, %[[block2Offset:.*]]: tensor<1024xi64, #blocked>) +// ^bb2(%arg2 : tensor<1024x!tt.ptr, #blocked>): // 2 preds: ^bb0, ^bb1 +// // CHECK: %[[trunc_offset_2:.*]] = arith.trunci %[[block2Offset]] +// // CHECK: %[[basePtr2:.*]] = tt.splat %[[block2ScalarPtr]] +// // CHECK: %[[newPtr2:.*]] = tt.addptr %[[basePtr2]], %[[trunc_offset_2]] +// // CHECK: tt.load %[[newPtr2]] +// %out2 = tt.load %arg2 : tensor<1024x!tt.ptr, #blocked> +// tt.return %out2 : tensor<1024xf32, #blocked> +// } +//} +// +//// ----- +// +//#blocked = #ttg.blocked<{sizePerThread = [4], 
threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +//module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { +// // CHECK-LABEL: tt.func @branch +// tt.func @branch(%arg0 : !tt.ptr, %i1 : i1) -> tensor<1024xf32, #blocked>{ +// %c1024_i32 = arith.constant 1024 : i32 +// %c0 = arith.constant 0: index +// %c128 = arith.constant 128: index +// %c1 = arith.constant 1 : index +// %0 = tt.get_program_id x : i32 +// %1 = arith.muli %0, %c1024_i32 : i32 +// %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> +// // CHECK: %[[scalarOffset:.*]] = arith.addi {{.*}}, {{.*}} : i32 +// // CHECK: %[[variableOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor<1024xi32, #blocked> +// // CHECK: %[[base_offset:.*]] = tt.splat %{{.*}} : i64 +// // CHECK: %[[scalarPtr:.*]] = tt.addptr %arg0, %[[scalarOffset]] +// // CHECK: %[[ext_offset0:.*]] = arith.extsi %[[variableOffset]] +// %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> +// %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> +// %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> +// // CHECK: %[[offset1:.*]] = arith.addi %[[ext_offset0]], %[[base_offset]] +// %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> +// // CHECK: cf.br ^bb1(%{{.*}}, %[[scalarPtr]], %[[offset1]] : {{.*}}) +// // CHECK: ^bb1({{.*}}, %[[block1ScalarPtr:.*]]: {{.*}}, %[[block1Offset:.*]]: {{.*}}) +// cf.br ^bb1(%6 : tensor<1024x!tt.ptr, #blocked>) +// ^bb1(%arg1 : tensor<1024x!tt.ptr, #blocked>): +// // CHECK: %[[trunc_offset_1:.*]] = arith.trunci %[[block1Offset]] +// // CHECK: %[[basePtr1:.*]] = tt.splat %[[block1ScalarPtr]] +// // CHECK: %[[newPtr1:.*]] = tt.addptr %[[basePtr1]], %[[trunc_offset_1]] +// // CHECK: tt.load %[[newPtr1]] +// %out1 = tt.load %arg1 : tensor<1024x!tt.ptr, #blocked> +// tt.return %out1 : tensor<1024xf32, #blocked> +// } +//} +// +//// ----- +// +//// The following is a simple case of a tile offset like: (A*B + C + D) where B,C are Uniform and A,D are not. 
So +//// we expect that the Uniform offset (which can be added to the scalar pointer) will be simply C and the NonUniform +//// offset will be A*B+D +//#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +//module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { +// // CHECK-LABEL: tt.func @tile_offset +// tt.func @tile_offset(%arg1: !tt.ptr, %arg5: i32 , %arg7: i32 ) { +// %c128_i32 = arith.constant 128 : i32 +// %c256_i32 = arith.constant 256 : i32 +// %1 = tt.get_program_id x : i32 +// %20 = arith.muli %1, %c256_i32 : i32 +// %22 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> +// %24 = tt.splat %20 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> +// %26 = arith.addi %24, %22 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> +// %36 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> +// %37 = tt.expand_dims %36 {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi32, #blocked> +// %38 = tt.splat %arg7 : i32 -> tensor<16x1xi32, #blocked> +// %39 = arith.muli %37, %38 : tensor<16x1xi32, #blocked> +// %41 = tt.expand_dims %26 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> +// %42 = tt.broadcast %39 : tensor<16x1xi32, #blocked> -> tensor<16x256xi32, #blocked> +// %43 = tt.broadcast %41 : tensor<1x256xi32, #blocked> -> tensor<16x256xi32, #blocked> +// %44 = arith.addi %42, %43 : tensor<16x256xi32, #blocked> +// %45 = tt.splat %arg1 : !tt.ptr -> tensor<16x256x!tt.ptr, #blocked> +// %46 = tt.addptr %45, %44 : tensor<16x256x!tt.ptr, #blocked>, tensor<16x256xi32, #blocked> +// // CHECK: %[[uniformOffset1:.*]] = arith.muli %c0_i32_0, %arg2 : i32 +// // CHECK: {{.*}} = tt.expand_dims %{{.*}} {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> +// // CHECK: %[[tensorOffset6:.*]] = tt.expand_dims %{{.*}} {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> +// // CHECK: {{.*}} = tt.broadcast %{{.*}} : tensor<16x1xi32, #blocked> -> tensor<16x256xi32, #blocked> +// // CHECK: %[[tensorOffset3:.*]] = tt.broadcast %{{.*}} : tensor<16x1xi32, #blocked> -> tensor<16x256xi32, #blocked> +// // CHECK: %[[tensorOffset4:.*]] = tt.broadcast %{{.*}} : tensor<1x256xi32, #blocked> -> tensor<16x256xi32, #blocked> +// // CHECK: %[[tensorOffset5:.*]] = tt.broadcast %[[tensorOffset6]] : tensor<1x256xi32, #blocked> -> tensor<16x256xi32, #blocked> +// // CHECK: %[[uniformOffset:.*]] = arith.addi %[[uniformOffset1]], %{{.*}}: i32 +// // CHECK: %[[tensorOffset2:.*]] = arith.addi %[[tensorOffset3]], %[[tensorOffset5]] : tensor<16x256xi32, #blocked> +// // CHECK: %[[scalarPtr:.*]] = tt.addptr %arg0, %[[uniformOffset]] : !tt.ptr, i32 +// // CHECK: %[[tensorOffset2ext:.*]] = arith.extsi %[[tensorOffset2]] : tensor<16x256xi32, #blocked> to tensor<16x256xi64, #blocked> +// // CHECK: %[[tensorOffset1:.*]] = arith.addi %[[tensorOffset2ext]], %{{.*}} : tensor<16x256xi64, #blocked> +// // CHECK: %[[tensorOffset:.*]] = arith.trunci %[[tensorOffset1:.*]] : tensor<16x256xi64, #blocked> to tensor<16x256xi32, #blocked> +// // CHECK: %[[ptr:.*]] = tt.splat %[[scalarPtr]] : !tt.ptr -> tensor<16x256x!tt.ptr, #blocked> +// // CHECK: tt.addptr %[[ptr]], %[[tensorOffset]] : 
tensor<16x256x!tt.ptr, #block +// %61 = tt.load %46 : tensor<16x256x!tt.ptr, #blocked> +// tt.return +// } +//} +// +//// ----- +// +//// The following is a more complex case where also a multiplication is involved. It's useful to walk through the case. +//// We have that the offset to the pointer is the following: +//// %12 = %10 + 11 +//// This can be transformed in: +//// = %7 + %9 +//// = %5*%6 + %8 +//// = %4*%arg1 + %8 +//// = (%3+%2)*%arg1 + %8 +//// = (%1 + %2) * %arg1 + %8 +//// = (U + N)*U + N +//// Where U means uniform (e.g., a splat) and N means NonUniform (e.g., a make_range) +//// The scalar offset we want is (%1*%arg1), while the variable offset should be (%2*%arg1 + %8) +//#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +//module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { +// // CHECK-LABEL: tt.func public @matmul_kernel +// tt.func public @matmul_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}) { +// %c128_i32 = arith.constant 128 : i32 +// %0 = tt.get_program_id x : i32 +// %1 = arith.muli %0, %c128_i32 : i32 +// %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> +// %3 = tt.splat %1 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> +// %4 = arith.addi %3, %2 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> +// %5 = tt.expand_dims %4 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> +// %6 = tt.splat %arg1 : i32 -> tensor<128x1xi32, #blocked> +// %7 = arith.muli %5, %6 : tensor<128x1xi32, #blocked> +// %8 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> +// %9 = tt.expand_dims %8 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> +// %10 = tt.broadcast %7 : tensor<128x1xi32, #blocked> -> tensor<128x16xi32, #blocked> +// %11 = tt.broadcast %9 : tensor<1x16xi32, #blocked> -> tensor<128x16xi32, #blocked> +// %12 = arith.addi %10, %11 : tensor<128x16xi32, #blocked> +// %13 = tt.splat %arg0 : !tt.ptr -> tensor<128x16x!tt.ptr, #blocked> +// %14 = tt.addptr %13, %12 : tensor<128x16x!tt.ptr, #blocked>, tensor<128x16xi32, #blocked> +// %15 = tt.load %14 : tensor<128x16x!tt.ptr, #blocked> +// // CHECK: %[[pid:.*]] = tt.get_program_id x : i32 +// // CHECK: %[[uniformOffset3:.*]] = arith.muli %[[pid]], %{{.*}} : i32 +// // CHECK: %[[uniformOffset2:.*]] = arith.addi %[[uniformOffset3]], %{{.*}} : i32 +// // CHECK: %[[uniformOffset1:.*]] = arith.muli %[[uniformOffset2]], %arg1 : i32 +// // CHECK: %[[makerange:.*]] = tt.make_range +// // CHECK: %{{.*}} = tt.expand_dims %[[makerange]] {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> +// // CHECK: %[[tensorOffset6:.*]] = tt.expand_dims %[[makerange]] {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> +// // CHECK: %{{.*}} = tt.broadcast %{{.*}} : tensor<128x1xi32, #blocked> -> tensor<128x16xi32, #blocked> +// // CHECK: %[[tensorOffset3:.*]] = tt.broadcast %{{.*}} : tensor<128x1xi32, #blocked> -> tensor<128x16xi32, #blocked> +// // CHECK: %{{.*}} = tt.broadcast %{{.*}} : tensor<1x16xi32, #blocked> -> tensor<128x16xi32, #blocked> +// // CHECK: %[[tensorOffset4:.*]] = tt.broadcast %[[tensorOffset6]] : 
tensor<1x16xi32, #blocked> -> tensor<128x16xi32, #blocked> +// // CHECK: %{{.*}} = arith.addi %{{.*}}, %{{.*}} : tensor<128x16xi32, #blocked> +// // CHECK: %[[uniformOffset:.*]] = arith.addi %[[uniformOffset1]], %{{.*}} : i32 +// // CHECK: %[[tensorOffset2:.*]] = arith.addi %[[tensorOffset3]], %[[tensorOffset4]] : tensor<128x16xi32, #blocked> +// // CHECK: %[[scalarPtr:.*]] = tt.addptr %arg0, %[[uniformOffset]] : !tt.ptr, i32 +// // CHECK: %[[tensorOffset1Ext:.*]] = arith.extsi %[[tensorOffset2]] : tensor<128x16xi32, #blocked> to tensor<128x16xi64, #blocked> +// // CHECK: %[[tensorOffset:.*]] = arith.addi %[[tensorOffset1Ext]], %{{.*}} : tensor<128x16xi64, #blocked> +// // CHECK: %[[tensorOffsetTrunc:.*]] = arith.trunci %[[tensorOffset]] : tensor<128x16xi64, #blocked> to tensor<128x16xi32, #blocked> +// // CHECK: %[[ptr:.*]] = tt.splat %[[scalarPtr]] : !tt.ptr -> tensor<128x16x!tt.ptr, #blocked> +// // CHECK: tt.addptr %[[ptr]], %[[tensorOffsetTrunc]] : tensor<128x16x!tt.ptr, #blocked>, tensor<128x16xi32, #blocked> +// tt.return +// } +//} +// +// +//// ----- +//#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +//module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { +// // CHECK-LABEL: tt.func @select +// tt.func @select(%arg0 : !tt.ptr, %i1 : i1) -> tensor<1024xf32, #blocked>{ +// %c1024_i32 = arith.constant 1024 : i32 +// %c0 = arith.constant 0: index +// %c128 = arith.constant 128: index +// %c1 = arith.constant 1 : index +// %0 = tt.get_program_id x : i32 +// %1 = arith.muli %0, %c1024_i32 : i32 +// %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> +// // CHECK: %[[scalarOffset:.*]] = arith.addi {{.*}}, {{.*}} : i32 +// // CHECK: %[[variableOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor<1024xi32, #blocked> +// // CHECK: %[[baseOffset:.*]] = tt.splat %{{.*}} : i64 +// // CHECK: %[[scalarPtr:.*]] = tt.addptr %arg0, %[[scalarOffset]] +// // CHECK: %[[extVariableOffset:.*]] = arith.extsi %[[variableOffset]] +// %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> +// %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> +// %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> +// %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> +// // CHECK: %[[offset2:.*]] = arith.addi %[[extVariableOffset]], %[[baseOffset]] +// // CHECK: %[[scalarPtr1:.*]] = arith.select %arg1, %arg0, %[[scalarPtr]] +// // CHECK: %[[offset0:.*]] = arith.select %arg1, {{.*}}, %[[offset2]] +// // CHECK: %[[offset1:.*]] = arith.trunci %[[offset0]] +// // CHECK: %[[ptr:.*]] = tt.splat %[[scalarPtr1]] +// // CHECK: tt.addptr %[[ptr]], %[[offset1]] +// %7 = arith.select %i1, %5 , %6 : tensor<1024x!tt.ptr, #blocked> +// %out = tt.load %7: tensor<1024x!tt.ptr, #blocked> +// tt.return %out : tensor<1024xf32, #blocked> +// } +//} +// +//// ----- +//#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +//module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1100", "ttg.threads-per-warp" = 32 : i32} { +// // CHECK-LABEL: tt.func @where_kernel +// tt.func @where_kernel(%arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}){ +// %c0_i8 = arith.constant 0 : i8 +// %c1024_i32 = arith.constant 1024 : i32 +// %0 = tt.get_program_id x : i32 +// %1 = arith.muli %0, %c1024_i32 : i32 +// %2 = tt.make_range {end = 1024 
: i32, start = 0 : i32} : tensor<1024xi32, #blocked> +// %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> +// %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> +// %9 = arith.cmpi ne, %c0_i8, %c0_i8 : i8 +// %10 = arith.select %9, %arg1, %arg2 : !tt.ptr +// // CHECK: %[[selectPtr:.*]] = arith.select {{.*}} : !tt.ptr +// %11 = tt.splat %10: !tt.ptr -> tensor<1024x!tt.ptr, #blocked> +// %13 = tt.addptr %11, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> +// // CHECK: %[[selectPtr0:.*]] = tt.addptr %[[selectPtr]] +// // CHECK: %[[tensorPtr:.*]] = tt.splat %[[selectPtr0]] +// // CHECK: tt.addptr %[[tensorPtr]] +// %14 = tt.load %13 : tensor<1024x!tt.ptr, #blocked> +// tt.return +// } +//} +// +//// ----- +// +//#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +//module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { +// // CHECK-LABEL: tt.func @forOpWithHints +// tt.func @forOpWithHints(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>)-> tensor<1024xf32, #blocked>{ +// %c0 = arith.constant 0: index +// %c1 = arith.constant 1 : index +// %c128 = arith.constant 128: index +// %0 = tt.get_program_id x : i32 +// %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> +// %3 = tt.splat %0 : i32 -> tensor<1024xi32, #blocked> +// %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> +// %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> +// %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> +// %52:2 = scf.for %arg9 = %c0 to %c128 step %c1 iter_args(%arg1 = %6, %arg2 = %init) -> (tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked>){ +// %9 = tt.load %arg1: tensor<1024x!tt.ptr, #blocked> +// // CHECK: tt.addptr {{.*}}, {{.*}} {tt.divisibility = dense<16> : tensor<1xi32>} +// %11 = tt.addptr %arg1, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> +// %12 = tt.addptr %11, %3 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> +// %10 = arith.addf %9, %arg2 : tensor<1024xf32, #blocked> +// scf.yield %12, %10 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked> +// } {"tt.divisibility_arg1"=dense<[16]> : tensor<1xi32>} +// // CHECK: tt.divisibility_arg1 +// // CHECK-SAME: tt.divisibility_arg4 +// %8 = tt.addptr %52#0, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> +// %11 = tt.load %8 : tensor<1024x!tt.ptr, #blocked> +// tt.return %11 : tensor<1024xf32, #blocked> +// } +//} +// +//// ----- +// +//#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +//module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { +// // CHECK-LABEL: scalar_pointers +// tt.func public @scalar_pointers(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { +// %0 = tt.get_program_id x : i32 +// %c1_i32 = arith.constant 1 : i32 +// %c0_i64 = arith.constant 0 : i64 +// %c10_i64 = arith.constant 10 : i64 +// %c100_i32 = arith.constant 100 : i32 +// %5 = tt.addptr %arg0, %c1_i32 : !tt.ptr, i32 +// // CHECK: arith.constant 0 : i64 +// // CHECK: arith.constant 0 : i64 +// // CHECK: %[[offset0:.*]] = arith.constant 0 : i64 +// // CHECK: %[[ptr0:.*]] = tt.addptr %arg0, %c1_i32 : !tt.ptr, i32 +// // CHECK: scf.for {{.*}} iter_args({{.*}}, %[[ptr1:.*]] = %[[ptr0]], 
%[[offset1:.*]] = %[[offset0]]) +// %10:1 = scf.for %arg3 = %c1_i32 to %c100_i32 step %c1_i32 iter_args(%arg4 = %5) -> (!tt.ptr) : i32 { +// // CHECK: tt.store %[[ptr1]] +// tt.store %arg4, %c0_i64 : !tt.ptr +// // CHECK: tt.addptr %[[ptr1]] +// %11 = tt.addptr %arg4, %c1_i32 : !tt.ptr, i32 +// scf.yield %11 : !tt.ptr +// } +// tt.return +// } +//} +// +//// ----- +// +//#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +//module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { +// // CHECK-LABEL: @scalar_if +// tt.func @scalar_if(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>, %cond : i1)->f32{ +// %0 = tt.get_program_id x : i32 +// %c1_i32 = arith.constant 1 : i32 +// %c0_i64 = arith.constant 0 : i64 +// %c10_i64 = arith.constant 10 : i64 +// %c100_i32 = arith.constant 100 : i32 +// %5 = tt.addptr %arg0, %c1_i32 : !tt.ptr, i32 +// // CHECK: %[[ptr0:.*]] = tt.addptr %arg0, %{{.*}} +// // CHECK: scf.if {{.*}} -> ({{.*}}, !tt.ptr, i64) +// %6 = scf.if %cond -> (!tt.ptr){ +// %true = tt.addptr %5, %c1_i32 : !tt.ptr, i32 +// // CHECK: %[[ptr1:.*]] = tt.addptr %[[ptr0]] +// // CHECK: scf.yield {{.*}}, %[[ptr1]] +// scf.yield %true : !tt.ptr +// } else { +// %false = tt.addptr %5, %c100_i32 : !tt.ptr, i32 +// // CHECK: %[[ptr2:.*]] = tt.addptr %[[ptr0]] +// // CHECK: scf.yield {{.*}}, %[[ptr2]] +// scf.yield %false : !tt.ptr +// } +// %11 = tt.load %6 : !tt.ptr +// tt.return %11 : f32 +// } +//} +// +//// ----- +// +//#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +//module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { +// // CHECK-LABEL: tt.func @scalar_while +// tt.func @scalar_while(%arg0: !tt.ptr, %init : f32, %cond : i1)->f32{ +// %c1024_i32 = arith.constant 1024 : i32 +// %c0 = arith.constant 0: index +// %c128 = arith.constant 128: index +// %c1 = arith.constant 1 : index +// %0 = tt.get_program_id x : i32 +// %1 = arith.muli %0, %c1024_i32 : i32 +// // CHECK: %[[ptr0:.*]] = tt.addptr %arg0, %{{.*}} +// // CHECK: scf.while ({{.*}}, {{.*}} = %arg2, %[[ptr1:.*]] = %[[ptr0]], {{.*}}) +// %2 = tt.addptr %arg0, %0 : !tt.ptr, i32 +// %6 = scf.while (%arg1 = %2, %arg2 = %cond) : (!tt.ptr, i1) -> (!tt.ptr) { +// // CHECK: scf.condition({{.*}}) {{.*}}, %[[ptr1]] +// scf.condition(%arg2) %arg1 : !tt.ptr +// } do { +// // CHECK: ^bb0({{.*}}: !tt.ptr, %[[ptr2:.*]]: !tt.ptr, {{.*}}) +// // CHECK: scf.yield %{{.*}}, {{.*}} %[[ptr2]], {{.*}}, {{.*}} +// ^bb0(%arg1: !tt.ptr): +// scf.yield %arg1, %cond : !tt.ptr, i1 +// } +// %11 = tt.load %6 : !tt.ptr +// tt.return %11 : f32 +// } +//} +// +//// ----- +// +//#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +//module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { +// // CHECK-LABEL: tt.func @scalar_cond_branch +// tt.func @scalar_cond_branch(%arg0 : !tt.ptr, %i1 : i1) -> f32{ +// %c1024_i32 = arith.constant 1024 : i32 +// %c0 = arith.constant 0: index +// %c128 = arith.constant 128: index +// %c1 = arith.constant 1 : index +// %0 = tt.get_program_id x : i32 +// %1 = arith.muli %0, %c1024_i32 : i32 +// %6 = tt.addptr %arg0, %0 : !tt.ptr, i32 +// // CHECK: %[[ptr0:.*]] = tt.addptr %arg0 +// // CHECK: cf.cond_br %arg1, ^bb1(%{{.*}}, %[[ptr0]], {{.*}}), ^bb2(%{{.*}}, %arg0, {{.*}}) +// cf.cond_br %i1, ^bb1(%6 : !tt.ptr), ^bb2(%arg0 : !tt.ptr) +// // CHECK: 
^bb1({{.*}}, %[[ptr1:.*]]: !tt.ptr, {{.*}}): +// ^bb1(%arg1 : !tt.ptr): +// // CHECK: tt.load %[[ptr1]] +// %out1 = tt.load %arg1 : !tt.ptr +// tt.return %out1 : f32 +// // CHECK: ^bb2({{.*}}, %[[ptr2:.*]]: !tt.ptr, {{.*}}): +// ^bb2(%arg2 : !tt.ptr): // 2 preds: ^bb0, ^bb1 +// // CHECK: tt.load %[[ptr2]] +// %out2 = tt.load %arg2 : !tt.ptr +// tt.return %out2 : f32 +// } +//} diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp index f798edca1676..8b83231e85f6 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp @@ -1,5 +1,7 @@ +#include "TritonAMDGPUTransforms/Passes.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Block.h" @@ -12,7 +14,8 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Value.h" -#include "mlir/Support/LogicalResult.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/OneToNTypeConversion.h" #include "triton/Analysis/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Types.h" @@ -26,9 +29,6 @@ #include "llvm/Support/LogicalResult.h" #include -#include "TritonAMDGPUTransforms/Passes.h" -#include "mlir/Pass/Pass.h" - #define GEN_PASS_CLASSES #include "TritonAMDGPUTransforms/Passes.h.inc" @@ -92,7 +92,6 @@ class PointerCanonicalizer { // Propagate fat pointers in all the functions of the module LogicalResult run(); -private: // A fat pointer is represented as `basePtr + offset` internally. struct FatPtr { // Scalar base pointer. Needs to be `tt.splat`ed before used @@ -102,7 +101,7 @@ class PointerCanonicalizer { // Flag to express if we can narrow the uses of the offset down to 32 bits bool canNarrow = false; // Collection of attributes that need to be applied to the pointer - SmallVector attributes; + llvm::SmallDenseMap attributes{}; // Utility copy functions FatPtr copy(Value newBasePtr, Value newOffset) { @@ -115,9 +114,15 @@ class PointerCanonicalizer { return FatPtr{newBase, offset, canNarrow}; } // Attribute functions - void setAttr(NamedAttribute attr) { attributes.push_back(attr); } + void setAttr(StringAttr name, Attribute value) { + attributes.insert({name, value}); + } + void setAttr(NamedAttribute attr) { + attributes.insert({attr.getName(), attr.getValue()}); + } void setAttrs(ArrayRef attrs) { - llvm::append_range(attributes, attrs); + for (auto attr : attrs) + attributes.insert({attr.getName(), attr.getValue()}); } }; @@ -130,7 +135,6 @@ class PointerCanonicalizer { // Create a tensor pointer from a fat pointer `fatPtr`. The tensor pointer is // obtained by splatting the `fatPtr.basePtr` using the `fatPtr.offset` shape // and adding the offset to it. 
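// As an illustration only (a hypothetical sketch, not IR emitted verbatim by this pass): materializing a fat pointer {basePtr, offset} into a tensor of pointers amounts to splatting the scalar base and adding the tensor offset, e.g.
//   %splat = tt.splat %base : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked>
//   %ptr   = tt.addptr %splat, %offset : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked>
// with the 64-bit offset first truncated to 32 bits when `canNarrow` holds.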
- Value createTensorPointer(FatPtr fatPtr, Location loc); // Push the attributes of the given operation `op` to the fat pointer // corresponding to `val` @@ -181,12 +185,12 @@ class PointerCanonicalizer { // // The function returns the two components of the given offset as a // std::pair{U, NU} - std::pair decomposeOffsetFromExpr(Location loc, Value expr, - int64_t bitness); - std::pair decomposeOffsetFromAdd(Location loc, Value expr, - int64_t bitness); - std::pair decomposeOffsetFromMul(Location loc, Value expr, - int64_t bitness); + // std::pair decomposeOffsetFromExpr(Location loc, Value expr, + // int64_t bitness); + // std::pair decomposeOffsetFromAdd(Location loc, Value expr, + // int64_t bitness); + // std::pair decomposeOffsetFromMul(Location loc, Value expr, + // int64_t bitness); // Return either the operation or its rewritten op template @@ -224,10 +228,8 @@ class PointerCanonicalizer { namespace { // Extend a 32bit `offset` into 64bit using a arith.extsi operation -static Value extend32bitOffsetTo64Bits(IRRewriter &rewriter, Location loc, - Value offset) { - Type elementType = getElementTypeOrSelf(offset); - +static Value createExtend32bitOffsetTo64Bits(RewriterBase &rewriter, + Location loc, Value offset) { if (auto tensorType = dyn_cast(offset.getType())) { auto shape = tensorType.getShape(); auto newTensorType = RankedTensorType::get(shape, rewriter.getI64Type(), @@ -238,8 +240,8 @@ static Value extend32bitOffsetTo64Bits(IRRewriter &rewriter, Location loc, } // Narrow a 64bit `offset` into 32bit using a arith.trunci operation -static Value narrow64bitOffsetTo32bits(IRRewriter &rewriter, Location loc, - Value offset) { +static Value createNarrow64bitOffsetTo32bits(RewriterBase &rewriter, + Location loc, Value offset) { Type elementType = getElementTypeOrSelf(offset); if (elementType.isInteger(32)) return offset; @@ -255,7 +257,7 @@ static Value narrow64bitOffsetTo32bits(IRRewriter &rewriter, Location loc, // Helper function to determine if the given `op` is a constant tensor and in // that case return the scalar value. 
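// For instance (illustrative only, names assumed): a dense splat constant
//   %cst = arith.constant dense<42> : tensor<1024xi32, #blocked>
// and an explicit splat of a scalar
//   %s = tt.splat %x : i32 -> tensor<1024xi32, #blocked>
// both yield a usable scalar (42 and %x respectively); a non-uniform tensor yields nothing and the caller falls back to the offset decomposition below.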
-Value getScalarConstant(IRRewriter &rewriter, Location loc, Value expr) { +Value getScalarConstant(RewriterBase &rewriter, Location loc, Value expr) { Operation *op = expr.getDefiningOp(); // Check for splatness @@ -291,7 +293,7 @@ bool canNarrowOffset(Value baseOffset, Value addOffset) { } // Create a zero tensor with a given `type` -Value createTensorZero(IRRewriter &rw, Location loc, RankedTensorType type) { +Value createTensorZero(RewriterBase &rw, Location loc, RankedTensorType type) { mlir::Attribute zeroAttr = rw.getZeroAttr(type.getElementType()); auto zeroDenseAttr = DenseElementsAttr::get(type, zeroAttr); return rw.create(loc, zeroDenseAttr); @@ -339,16 +341,19 @@ void PointerCanonicalizer::collectFatPointerAttributes(Operation *op, pointers[val].setAttr(attr); } +std::pair decomposeOffsetFromExpr(RewriterBase &rewriter, + Location loc, Value expr, + int64_t bitness); // Offset extraction logic for an addition op: // decompose(A+B) = {U(A)+U(B), NU(A)+NU(B)} -std::pair -PointerCanonicalizer::decomposeOffsetFromAdd(Location loc, Value expr, - int64_t bitness) { +std::pair decomposeOffsetFromAdd(RewriterBase &rewriter, + Location loc, Value expr, + int64_t bitness) { auto addOp = expr.getDefiningOp(); auto [uniformOffsetL, nonUniformOffsetL] = - decomposeOffsetFromExpr(loc, addOp.getLhs(), bitness); + decomposeOffsetFromExpr(rewriter, loc, addOp.getLhs(), bitness); auto [uniformOffsetR, nonUniformOffsetR] = - decomposeOffsetFromExpr(loc, addOp.getRhs(), bitness); + decomposeOffsetFromExpr(rewriter, loc, addOp.getRhs(), bitness); Value uniformAdd = rewriter.create(loc, uniformOffsetL, uniformOffsetR); Value nonUniformAdd = @@ -358,14 +363,14 @@ PointerCanonicalizer::decomposeOffsetFromAdd(Location loc, Value expr, // Offset extraction logic for a multiplication op: // decompose(A*B) = {U(A)*U(B), NU(A)*NU(B)+NU(A)*U(B)+U(A)*NU(B)} -std::pair -PointerCanonicalizer::decomposeOffsetFromMul(Location loc, Value expr, - int64_t bitness) { +std::pair decomposeOffsetFromMul(RewriterBase &rewriter, + Location loc, Value expr, + int64_t bitness) { auto mulOp = expr.getDefiningOp(); auto [uniformOffsetL, nonUniformOffsetL] = - decomposeOffsetFromExpr(loc, mulOp.getLhs(), bitness); + decomposeOffsetFromExpr(rewriter, loc, mulOp.getLhs(), bitness); auto [uniformOffsetR, nonUniformOffsetR] = - decomposeOffsetFromExpr(loc, mulOp.getRhs(), bitness); + decomposeOffsetFromExpr(rewriter, loc, mulOp.getRhs(), bitness); Value uniformMul = rewriter.create(loc, uniformOffsetL, uniformOffsetR); @@ -386,12 +391,12 @@ PointerCanonicalizer::decomposeOffsetFromMul(Location loc, Value expr, return {uniformMul, nonUniformMul}; } -std::pair -PointerCanonicalizer::decomposeOffsetFromExpr(Location loc, Value expr, - int64_t bitness) { +std::pair decomposeOffsetFromExpr(RewriterBase &rewriter, + Location loc, Value expr, + int64_t bitness) { - RewriterBase::InsertionGuard guard(rewriter); - rewriter.setInsertionPointAfterValue(expr); + // RewriterBase::InsertionGuard guard(rewriter); + // rewriter.setInsertionPointAfterValue(expr); // Base case 1: it is a splat. 
Return the scalar constant as the uniform part if (Value scalarConst = getScalarConstant(rewriter, loc, expr)) { @@ -412,24 +417,24 @@ PointerCanonicalizer::decomposeOffsetFromExpr(Location loc, Value expr, llvm::TypeSwitch>( expr.getDefiningOp()) .Case([&](auto broadcastOp) { - auto [uniform, nonUniform] = - decomposeOffsetFromExpr(loc, broadcastOp.getSrc(), bitness); + auto [uniform, nonUniform] = decomposeOffsetFromExpr( + rewriter, loc, broadcastOp.getSrc(), bitness); auto broadcastNonUniform = rewriter.create( loc, broadcastOp.getType(), nonUniform); return std::make_pair(uniform, broadcastNonUniform); }) .Case([&](auto expandOp) { - auto [uniform, nonUniform] = - decomposeOffsetFromExpr(loc, expandOp.getSrc(), bitness); + auto [uniform, nonUniform] = decomposeOffsetFromExpr( + rewriter, loc, expandOp.getSrc(), bitness); auto expandNonUniform = rewriter.create( loc, nonUniform, expandOp.getAxis()); return std::make_pair(uniform, expandNonUniform); }) .Case([&](Operation *op) { - return decomposeOffsetFromAdd(loc, expr, bitness); + return decomposeOffsetFromAdd(rewriter, loc, expr, bitness); }) .Case([&](Operation *op) { - return decomposeOffsetFromMul(loc, expr, bitness); + return decomposeOffsetFromMul(rewriter, loc, expr, bitness); }) .Default([&](Operation *op) { // Base case 3: it is not a supported operation. We assume no @@ -442,16 +447,18 @@ PointerCanonicalizer::decomposeOffsetFromExpr(Location loc, Value expr, return offsets; } -Value PointerCanonicalizer::createTensorPointer(FatPtr fatPtr, Location loc) { - Value basePtr = fatPtr.basePtr; - Value offset = fatPtr.offset; +Value createTensorPointer( + RewriterBase &rewriter, Value basePtr, Value offset, Location loc, + bool canNarrow, + const llvm::SmallDenseMap &attributes) { auto tensorType = dyn_cast(offset.getType()); // Scalar case: we only need to `tt.addptr %basePtr, %offset` if (!tensorType) { auto addPtrOp = rewriter.create(loc, basePtr.getType(), basePtr, offset); - addPtrOp->setAttrs(fatPtr.attributes); + for (auto attribute : attributes) + addPtrOp->setAttr(attribute.getFirst(), attribute.getSecond()); return addPtrOp.getResult(); } @@ -463,8 +470,8 @@ Value PointerCanonicalizer::createTensorPointer(FatPtr fatPtr, Location loc) { ArrayRef offsetShape = tensorType.getShape(); auto tensorPtrType = RankedTensorType::get(offsetShape, basePtr.getType(), tensorType.getEncoding()); - if (fatPtr.canNarrow) - offset = narrow64bitOffsetTo32bits(rewriter, loc, offset); + if (canNarrow) + offset = createNarrow64bitOffsetTo32bits(rewriter, loc, offset); Value tensorPtr = rewriter.create(loc, tensorPtrType, basePtr); @@ -472,7 +479,8 @@ Value PointerCanonicalizer::createTensorPointer(FatPtr fatPtr, Location loc) { auto addPtrOp = rewriter.create(loc, tensorPtrType, tensorPtr, offset); - addPtrOp->setAttrs(fatPtr.attributes); + for (auto attribute : attributes) + addPtrOp->setAttr(attribute.getFirst(), attribute.getSecond()); return addPtrOp.getResult(); } @@ -487,7 +495,8 @@ LogicalResult PointerCanonicalizer::materializeFatPointer(Operation *op, // Create the tensor pointer (i.e., splat the base && add the offset) Value newPtr = basePtr; if (isa(ptr.getType())) - newPtr = createTensorPointer(fatPtr, loc); + newPtr = createTensorPointer(rewriter, fatPtr.basePtr, fatPtr.offset, loc, + fatPtr.canNarrow, fatPtr.attributes); // Save the fat pointer in the table pointers[newPtr] = fatPtr; @@ -561,7 +570,8 @@ LogicalResult PointerCanonicalizer::rewriteAddPtrOp(triton::AddPtrOp addPtrOp, pointers[nextPtr] = 
fatPtr.copyWithOffset(newPtr); // If we are updating the tensor pointer with a uniform value, we can // propagate the attributes of the tensor pointer to the fat pointer. - pointers[nextPtr].setAttrs(fatPtr.attributes); + for (auto attribute : fatPtr.attributes) + pointers[nextPtr].setAttr(attribute.getFirst(), attribute.getSecond()); opToDelete.insert(addPtrOp); return success(); } @@ -569,7 +579,7 @@ LogicalResult PointerCanonicalizer::rewriteAddPtrOp(triton::AddPtrOp addPtrOp, int64_t bitness = cast(offset.getType()).getElementTypeBitWidth(); auto [uniformOffset, nonUniformOffset] = - decomposeOffsetFromExpr(curLoc, offset, bitness); + decomposeOffsetFromExpr(rewriter, curLoc, offset, bitness); // Scalar pointer update: bump the scalar pointer newPtr = rewriter.create(curLoc, newPtr.getType(), newPtr, @@ -588,10 +598,10 @@ LogicalResult PointerCanonicalizer::rewriteAddPtrOp(triton::AddPtrOp addPtrOp, // Upcast or downcast the offset accordingly if (addPtrOffsetType.isInteger(32) && fatPtrOffsetType.isInteger(64)) nonUniformOffset = - extend32bitOffsetTo64Bits(rewriter, curLoc, nonUniformOffset); + createExtend32bitOffsetTo64Bits(rewriter, curLoc, nonUniformOffset); else if (addPtrOffsetType.isInteger(64) && fatPtrOffsetType.isInteger(32)) nonUniformOffset = - narrow64bitOffsetTo32bits(rewriter, curLoc, nonUniformOffset); + createNarrow64bitOffsetTo32bits(rewriter, curLoc, nonUniformOffset); newOffset = rewriter.create(curLoc, nonUniformOffset, fatPtrOffset); @@ -603,7 +613,8 @@ LogicalResult PointerCanonicalizer::rewriteAddPtrOp(triton::AddPtrOp addPtrOp, // If we are updating the tensor pointer with a uniform value, we can // propagate the attributes of the tensor pointer to the fat pointer. if (propagateAtrs) - pointers[nextPtr].setAttrs(fatPtr.attributes); + for (auto attribute : fatPtr.attributes) + pointers[nextPtr].setAttr(attribute.getFirst(), attribute.getSecond()); return success(); } @@ -611,71 +622,30 @@ LogicalResult PointerCanonicalizer::rewriteForOp(scf::ForOp forOp, Location curLoc, OpOperand *curOperand, Value &nextPtr) { + size_t operandNum = curOperand->getOperandNumber(); FatPtr fatPtr = pointers[curOperand->get()]; Value offset = fatPtr.offset; Value basePtr = fatPtr.basePtr; - size_t operandNum = curOperand->getOperandNumber(); - auto numLeadingInitsIncludingCurOp = - operandNum + 1 - forOp.getNumControlOperands(); - - SmallVector newInitArgs( - forOp.getInitArgs().take_front(numLeadingInitsIncludingCurOp)); - newInitArgs.append({basePtr, offset}); - auto trailingInits = - forOp.getInitArgs().drop_front(numLeadingInitsIncludingCurOp); - newInitArgs.append(trailingInits.begin(), trailingInits.end()); - - scf::ForOp newForOp; - { - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(forOp); - newForOp = rewriter.create( - forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), - forOp.getStep(), newInitArgs); - newForOp->setAttrs(forOp->getAttrs()); - newForOp.getBody()->erase(); - newForOp.getRegion().getBlocks().splice( - newForOp.getRegion().getBlocks().begin(), - forOp.getRegion().getBlocks()); - - newForOp.getBody()->insertArgument(numLeadingInitsIncludingCurOp + 1, - offset.getType(), offset.getLoc()); - newForOp.getBody()->insertArgument(numLeadingInitsIncludingCurOp + 1, - basePtr.getType(), basePtr.getLoc()); - - for (auto [src, tgt] : llvm::zip( - forOp.getResults().take_front(numLeadingInitsIncludingCurOp), - newForOp.getResults().take_front(numLeadingInitsIncludingCurOp))) - rewriter.replaceAllUsesWith(src, tgt); - for (auto [src, 
tgt] : - llvm::zip(forOp.getResults().drop_front(numLeadingInitsIncludingCurOp), - newForOp.getResults().drop_front( - numLeadingInitsIncludingCurOp + 2))) - rewriter.replaceAllUsesWith(src, tgt); - - newForOp.getBody()->getTerminator()->insertOperands( - numLeadingInitsIncludingCurOp, - newForOp.getRegionIterArg(numLeadingInitsIncludingCurOp)); - newForOp.getBody()->getTerminator()->insertOperands( - numLeadingInitsIncludingCurOp + 1, - newForOp.getRegionIterArg(numLeadingInitsIncludingCurOp + 1)); - } - + // Replace the forOp with two additional argument (i.e., the curOperand's + // scalar pointer and the offset) + Value tensorPtr = + createTensorPointer(rewriter, fatPtr.basePtr, fatPtr.offset, curLoc, + fatPtr.canNarrow, fatPtr.attributes); + auto newForOp = + replaceForOpWithNewSignature(rewriter, forOp, {basePtr, offset}); rewriteOpMap[forOp] = newForOp; - { - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(newForOp); - Value tensorPtr = createTensorPointer(fatPtr, curLoc); - newForOp->setOperand(operandNum, tensorPtr); - } - + newForOp->setOperand(operandNum, tensorPtr); OpOperand *forOperand = &newForOp->getOpOperand(operandNum); // This is making sure we propagate the visit from the forOp result + nextPtr = newForOp.getTiedLoopResult(forOperand); + // This is making sure we visit the uses within the forOp region Value arg = newForOp.getTiedLoopRegionIterArg(forOperand); - pointers[arg] = fatPtr.copy(basePtr, offset); + size_t numIterArgs = newForOp.getNumRegionIterArgs(); + pointers[arg] = fatPtr.copy(newForOp.getRegionIterArg(numIterArgs - 2), + newForOp.getRegionIterArg(numIterArgs - 1)); // Collect attributes before continuing the visit collectFatPointerAttributes(newForOp, arg); @@ -685,9 +655,9 @@ LogicalResult PointerCanonicalizer::rewriteForOp(scf::ForOp forOp, // This is setting the fat pointer for the users of the loop // and then propagate the result - nextPtr = newForOp.getTiedLoopResult(forOperand); - pointers[nextPtr] = fatPtr.copy(basePtr, offset); - + size_t numResults = newForOp->getNumResults(); + pointers[nextPtr] = fatPtr.copy(newForOp->getResult(numResults - 2), + newForOp.getResult(numResults - 1)); opToDelete.insert(forOp); return success(); } @@ -702,15 +672,18 @@ LogicalResult PointerCanonicalizer::rewriteYieldOp(scf::YieldOp yieldOp, // IfOp size_t operandNum = curOperand->getOperandNumber(); FatPtr fatPtr = pointers[curOperand->get()]; + yieldOp.getResultsMutable().append(fatPtr.basePtr); + yieldOp.getResultsMutable().append(fatPtr.offset); + if (auto forOp = dyn_cast(yieldOp->getParentOp())) { yieldOp->setOperand(operandNum, forOp.getRegionIterArg(operandNum)); } else if (auto ifOp = dyn_cast(yieldOp->getParentOp())) { - yieldOp.getResultsMutable().append(fatPtr.basePtr); - yieldOp.getResultsMutable().append(fatPtr.offset); // Case 1: the yieldOp is contained within an IfOp. One of the // two branches is responsible to rewrite the operation. 
The other // branch only update the yieldOp with the right parameters - Value tensorPtr = createTensorPointer(fatPtr, curLoc); + Value tensorPtr = + createTensorPointer(rewriter, fatPtr.basePtr, fatPtr.offset, curLoc, + fatPtr.canNarrow, fatPtr.attributes); yieldOp->setOperand(operandNum, tensorPtr); if (yieldOp->getBlock() == &ifOp.getThenRegion().front()) { @@ -725,8 +698,6 @@ LogicalResult PointerCanonicalizer::rewriteYieldOp(scf::YieldOp yieldOp, } else if (auto whileOp = resolveOp(yieldOp->getParentOp(), rewriteOpMap)) { - yieldOp.getResultsMutable().append(fatPtr.basePtr); - yieldOp.getResultsMutable().append(fatPtr.offset); // Case 2: the yieldOp is contained within the AfterRegion of a // WhileOp. In this case, we know that the before region should have // already been replaced (when we met the WhileOp), hence we can @@ -758,7 +729,9 @@ LogicalResult PointerCanonicalizer::rewriteWhileOp(scf::WhileOp whileOp, Value basePtr = fatPtr.basePtr; // Rewrite the while op with a new set of operands (but with the same // set of return types) - Value tensorPtr = createTensorPointer(fatPtr, curLoc); + Value tensorPtr = + createTensorPointer(rewriter, fatPtr.basePtr, fatPtr.offset, curLoc, + fatPtr.canNarrow, fatPtr.attributes); auto newWhileOp = replaceWhileOpWithNewSignature(rewriter, whileOp, {basePtr, offset}, {}); newWhileOp->setOperand(operandNum, tensorPtr); @@ -873,7 +846,9 @@ LogicalResult PointerCanonicalizer::rewriteCondBranchOp( // Create a new condBranch. We cannot simply extend the operands, // because this would invalidate other operands pointing at the same // cond branch - Value tensorPtr = createTensorPointer(fatPtr, curLoc); + Value tensorPtr = + createTensorPointer(rewriter, fatPtr.basePtr, fatPtr.offset, curLoc, + fatPtr.canNarrow, fatPtr.attributes); auto newCondBranch = rewriter.create( curLoc, condBrOp.getCondition(), trueDest, trueOperands, falseDest, falseOperands); @@ -930,7 +905,9 @@ LogicalResult PointerCanonicalizer::rewriteBranchOp(cf::BranchOp branchOp, Value offset = fatPtr.offset; Value basePtr = fatPtr.basePtr; branchOp.getDestOperandsMutable().append({basePtr, fatPtr.offset}); - Value tensorPtr = createTensorPointer(fatPtr, curLoc); + Value tensorPtr = + createTensorPointer(rewriter, fatPtr.basePtr, fatPtr.offset, curLoc, + fatPtr.canNarrow, fatPtr.attributes); branchOp->setOperand(operandNum, tensorPtr); Block *dest = branchOp.getDest(); @@ -1068,13 +1045,282 @@ class TritonAMDGPUCanonicalizePointersPass public: TritonAMDGPUCanonicalizePointersPass() = default; - void runOnOperation() override { - ModuleOp m = getOperation(); - if (failed(PointerCanonicalizer(m).run())) - signalPassFailure(); + void runOnOperation() override; + void runOnOperationmine(); +}; + +class ConvertAddPtrOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::AddPtrOp addPtrOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + RewriterBase::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(addPtrOp); + + ArrayRef remappedOperands = adaptor.getOperands(); + assert(remappedOperands.size() == 2 && remappedOperands[0].size() == 2 && + "expected adaptor to have 2,1 remapped values"); + Value fatPtrBase = remappedOperands[0][0]; + Value fatPtrOffset = remappedOperands[0][1]; + Value origOffset = remappedOperands[1][0]; + auto curLoc = addPtrOp.getLoc(); + + // If it is a scalar pointer update, simply bump the base pointer + if (!isa(addPtrOp.getPtr().getType())) { + 
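// (Sketch of the intent here, relying on the 1:N adaptor semantics used by this pattern: for a scalar !tt.ptr the increment is folded into the base pointer and the existing fat-pointer offset is forwarded unchanged, so uses of the old addptr result are remapped to the value pair {newAddPtrOp, fatPtrOffset} rather than to a single replacement value.)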
auto newAddPtrOp = rewriter.create( + curLoc, TypeRange{fatPtrBase.getType()}, + ValueRange{fatPtrBase, origOffset}, + llvm::ArrayRef{ + rewriter.getNamedAttr("legal", rewriter.getUnitAttr())}); + rewriter.modifyOpInPlace(addPtrOp, [&] { + addPtrOp->setDiscardableAttr("rewritten", rewriter.getUnitAttr()); + }); + rewriter.replaceOpWithMultiple(addPtrOp, {{newAddPtrOp, fatPtrOffset}}); + return success(); + } + + // Early exit for the case of a constant tensor + if (Value scalarConst = getScalarConstant(rewriter, curLoc, origOffset)) { + auto newAddPtrOp = rewriter.create( + curLoc, TypeRange{fatPtrBase.getType()}, + ValueRange{fatPtrBase, scalarConst}, + llvm::ArrayRef{ + rewriter.getNamedAttr("legal", rewriter.getUnitAttr())}); + rewriter.modifyOpInPlace(addPtrOp, [&] { + addPtrOp->setDiscardableAttr("rewritten", rewriter.getUnitAttr()); + }); + rewriter.replaceOpWithMultiple(addPtrOp, {{newAddPtrOp, fatPtrOffset}}); + // If we are updating the tensor pointer with a uniform value, we can + // propagate the attributes of the tensor pointer to the fat pointer. + // TODO(max): re-enable + // for (auto attribute : fatPtr.attributes) + // pointers[nextPtr].setAttr(attribute.getFirst(), attribute.getSecond()); + // opToDelete.insert(addPtrOp); + return success(); + } + + int64_t bitness = + cast(origOffset.getType()).getElementTypeBitWidth(); + auto [uniformOffset, nonUniformOffset] = + decomposeOffsetFromExpr(rewriter, curLoc, origOffset, bitness); + + // Vector offset update (if any): bump the tensor offset + // TODO(max): stash somewhere + bool canNarrow = false; + bool propagateAtrs = false; + + Value newOffset = fatPtrOffset; + if (!isZeroConst(nonUniformOffset)) { + Type addPtrOffsetType = getElementTypeOrSelf(nonUniformOffset); + Type fatPtrOffsetType = getElementTypeOrSelf(fatPtrOffset); + canNarrow = canNarrow && canNarrowOffset(fatPtrOffset, nonUniformOffset); + + // Upcast or downcast the offset accordingly + if (addPtrOffsetType.isInteger(32) && fatPtrOffsetType.isInteger(64)) + nonUniformOffset = + createExtend32bitOffsetTo64Bits(rewriter, curLoc, nonUniformOffset); + else if (addPtrOffsetType.isInteger(64) && fatPtrOffsetType.isInteger(32)) + nonUniformOffset = + createNarrow64bitOffsetTo32bits(rewriter, curLoc, nonUniformOffset); + + newOffset = rewriter.create(curLoc, nonUniformOffset, + fatPtrOffset); + propagateAtrs = false; + } + + // Scalar pointer update: bump the scalar pointer + auto newAddPtrOp = rewriter.create( + curLoc, TypeRange{fatPtrBase.getType()}, + ValueRange{fatPtrBase, uniformOffset}, + llvm::ArrayRef{rewriter.getNamedAttr("legal", rewriter.getUnitAttr())}); + rewriter.modifyOpInPlace(addPtrOp, [&] { + addPtrOp->setDiscardableAttr("rewritten", rewriter.getUnitAttr()); + }); + rewriter.replaceOpWithMultiple(addPtrOp, {{newAddPtrOp, newOffset}}); + // + // // If we are updating the tensor pointer with a uniform value, we can + // // propagate the attributes of the tensor pointer to the fat pointer. 
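// (For example, pointer-arithmetic hints such as
//   tt.addptr %p, %off {tt.divisibility = dense<16> : tensor<1xi32>}
// are the kind of attribute this propagation is meant to preserve; see the commented-out @forOpWithHints test earlier in this patch.)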
+ // TODO(max): re-enable + // if (propagateAtrs) + // for (auto attribute : fatPtr.attributes) + // pointers[nextPtr].setAttr(attribute.getFirst(), + // attribute.getSecond()); + + return success(); + } +}; + +class ConvertSplatOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::SplatOp splatOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ArrayRef remappedOperands = adaptor.getOperands(); + // see + // https://github.com/llvm/llvm-project/blob/58389b220a9354ed6c34bdb9310a35165579c5e3/mlir/lib/Transforms/Utils/DialectConversion.cpp#L1177 + assert(remappedOperands.size() == 1 && remappedOperands[0].size() == 2 && + "expected adaptor to have 2 remapped values"); + Value fatPtrBase = remappedOperands[0][0]; + Value fatPtrOffset = remappedOperands[0][1]; + assert(llvm::isa(fatPtrBase.getType()) && + "expected fatPtrBase to be a tt.ptr"); + assert(llvm::isa(fatPtrOffset.getType()) && + "expected fatPtrOffset to be an integer type"); + + auto outType = splatOp.getResult().getType(); + auto ptrShape = outType.getShape(); + auto newOffsetType = RankedTensorType::get(ptrShape, fatPtrOffset.getType(), + outType.getEncoding()); + Value offset = rewriter.create( + splatOp.getLoc(), newOffsetType, fatPtrOffset); + rewriter.modifyOpInPlace(splatOp, [&] { + splatOp->setDiscardableAttr("rewritten", rewriter.getUnitAttr()); + }); + rewriter.replaceOpWithMultiple(splatOp, {{splatOp.getSrc(), offset}}); + + return success(); + } +}; + +class ConvertLoadOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::LoadOp loadOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto fatPtr = *adaptor.getOperands().begin(); + Value fatPtrBase = fatPtr.front(); + Value fatPtrOffset = fatPtr.back(); + Location curLoc = loadOp.getLoc(); + + llvm::SmallDenseMap attributes{}; + auto newPtr = + createTensorPointer(rewriter, fatPtrBase, fatPtrOffset, curLoc, + // TODO(max): + /*canNarrow*/ true, attributes); + SmallVector operands = + loadOp.getOperands().take_back(loadOp.getNumOperands() - 1); + operands.insert(operands.begin(), newPtr); + auto newLoadPtrOp = rewriter.replaceOpWithNewOp( + loadOp, operands, loadOp->getAttrs()); + rewriter.modifyOpInPlace(newLoadPtrOp, [&] { + newLoadPtrOp->setDiscardableAttr("legal", rewriter.getUnitAttr()); + }); + return success(); + } +}; + +class ConvertFuncOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::FuncOp funcOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + int64_t bitness = 64; + rewriter.setInsertionPointToStart(&funcOp.getBody().front()); + rewriter.modifyOpInPlace(funcOp, [&] { + for (auto [idx, arg] : llvm::enumerate(funcOp.getArguments())) { + // The pointer argument needs to be a scalar + if (!isa(arg.getType())) + continue; + if (auto pointerRangeAttr = + funcOp.getArgAttrOfType(idx, "tt.pointer_range")) + bitness = pointerRangeAttr.getInt(); + Value zeroOffset = + rewriter.create(funcOp.getLoc(), 0, bitness); + auto dummyCast = rewriter.create( + arg.getLoc(), TypeRange{arg.getType()}, ValueRange{arg}); + rewriter.replaceUsesOfBlockArgument(arg, dummyCast.getResult(0)); + rewriter.replaceOpWithMultiple(dummyCast, {{arg, zeroOffset}}); + } + funcOp->setDiscardableAttr("rewritten", rewriter.getUnitAttr()); + }); + + 
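+    // At this point every scalar pointer argument is remapped to the
+    // fat-pointer pair (arg, zero offset). The offset bitness comes from the
+    // tt.pointer_range argument attribute when present and defaults to 64;
+    // the unrealized_conversion_cast is only a temporary hook that lets the
+    // 1:N driver forward uses of the original block argument to the pair,
+    // and the cleanup pattern below erases leftover casts whose offset is a
+    // constant zero. Illustrative sketch (names assumed, not produced
+    // verbatim) for an argument %arg0 : !tt.ptr<f32>:
+    //   %c0 = arith.constant 0 : i64
+    //   %cast = builtin.unrealized_conversion_cast %arg0 : !tt.ptr<f32> to !tt.ptr<f32>
+    //   // uses of %arg0 now read through %cast, which maps to (%arg0, %c0)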
return success(); + } +}; + +class ConvertUnrealizedConversionCastOp + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(UnrealizedConversionCastOp castOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + assert(castOp->hasOneUse() && "expected at least 1 use of unrealized_cast"); + ArrayRef remappedOperands = adaptor.getOperands(); + assert(remappedOperands.size() == 1 && remappedOperands[0].size() == 2 && + "expected adaptor to have 2 remapped values"); + Value fatPtrBase = remappedOperands[0][0]; + Value fatPtrOffset = remappedOperands[0][1]; + assert(llvm::isa(fatPtrBase.getType()) && + "expected fatPtrBase to be a tt.ptr"); + assert(llvm::isa(fatPtrOffset.getType()) && + "expected fatPtrOffset to be an integer type"); + OpFoldResult maybeScalar = getAsOpFoldResult(fatPtrOffset); + if (auto attr = llvm::dyn_cast(maybeScalar)) { + auto integerAttr = llvm::cast(attr); + if (integerAttr.getValue() == 0) { + rewriter.replaceAllUsesWith(castOp.getResult(0), fatPtrBase); + rewriter.eraseOp(castOp); + return success(); + } + } + return failure(); } }; +void TritonAMDGPUCanonicalizePointersPass::runOnOperation() { + ModuleOp module = getOperation(); + auto *context = &getContext(); + ConversionTarget target(*context); + RewritePatternSet patterns(context); + target.addLegalDialect(); + target.addDynamicallyLegalDialect([](Operation *op) { + if (llvm::isa(op) && !op->hasAttr("rewritten")) + return false; + + if (op->hasAttr("rewritten") || op->hasAttr("legal")) + return true; + for (auto operand : op->getOperands()) { + if (llvm::isa(operand)) + return false; + if (operand.getDefiningOp()->hasAttr("rewritten")) + return false; + } + + return true; + }); + + patterns.add( + patterns.getContext()); + ConversionConfig config; + config.buildMaterializations = false; + if (failed( + applyPartialConversion(module, target, std::move(patterns), config))) + return signalPassFailure(); + + patterns.clear(); + target.addIllegalOp(); + patterns.add(patterns.getContext()); + if (failed( + applyPartialConversion(module, target, std::move(patterns), config))) + return signalPassFailure(); +} + +void TritonAMDGPUCanonicalizePointersPass::runOnOperationmine() { + ModuleOp m = getOperation(); + if (failed(PointerCanonicalizer(m).run())) + signalPassFailure(); +} + std::unique_ptr mlir::createTritonAMDGPUCanonicalizePointersPass() { return std::make_unique(); } From ce211fd0a6285ef09fd5202c03befce3ce990ce3 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Fri, 13 Dec 2024 20:03:07 -0500 Subject: [PATCH 03/17] scf.for handled --- .../amd/amd-canonicalize-pointers.mlir | 638 +++++++++--------- .../CanonicalizePointers.cpp | 351 ++++++++-- 2 files changed, 614 insertions(+), 375 deletions(-) diff --git a/test/TritonGPU/amd/amd-canonicalize-pointers.mlir b/test/TritonGPU/amd/amd-canonicalize-pointers.mlir index 0c3b3b9f977c..b8aa788e77cb 100644 --- a/test/TritonGPU/amd/amd-canonicalize-pointers.mlir +++ b/test/TritonGPU/amd/amd-canonicalize-pointers.mlir @@ -43,292 +43,292 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} // ----- -//#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -//module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { -// // CHECK-LABEL: tt.func @conversion2 -// tt.func @conversion2(%arg0: !tt.ptr)-> tensor<1024xf32, #blocked>{ -// %c1024_i32 = arith.constant 1024 : i32 -// %0 = 
tt.get_program_id x : i32 -// %1 = arith.muli %0, %c1024_i32 : i32 -// %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> -// %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> -// %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> -// %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> -// // CHECK: %[[scalarOffset:.*]] = arith.addi {{.*}}, {{.*}} : i32 -// // CHECK: %[[baseOffset064bit:.*]] = tt.splat {{.*}} : i64 -// // CHECK: %[[newScalarPtr:.*]] = tt.addptr %arg0, %[[scalarOffset]] -// // CHECK: %[[offset064bit:.*]] = arith.extsi {{.*}} -// // CHECK: %[[offset164bit:.*]] = arith.addi %[[offset064bit]], %[[baseOffset064bit]] -// // CHECK: %[[offset132bit:.*]] = arith.trunci %[[offset164bit]] : tensor<1024xi64, #blocked> to tensor<1024xi32, #blocked> -// // CHECK: %[[basePtr:.*]] = tt.splat %[[newScalarPtr]] -// // CHECK: %[[newPtr:.*]] = tt.addptr %[[basePtr]], %[[offset132bit]] -// // CHECK: tt.load %[[newPtr]] -// %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> -// %7 = tt.load %6 : tensor<1024x!tt.ptr, #blocked> -// tt.return %7 : tensor<1024xf32, #blocked> -// } -//} -// -//// ----- -// -//#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -//module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { -// // CHECK-LABEL: tt.func @conversion3 -// tt.func @conversion3(%arg0: !tt.ptr)-> tensor<1024xf32, #blocked>{ -// %c1024_i32 = arith.constant 1024 : i32 -// %0 = tt.get_program_id x : i32 -// %1 = arith.muli %0, %c1024_i32 : i32 -// %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> -// %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> -// %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> -// -// //CHECK: %0 = tt.get_program_id x : i32 -// //CHECK: %[[pid:.*]] = arith.muli %0, {{.*}} : i32 -// //CHECK: %[[makerange:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> -// //CHECK: %[[uniformOffset1:.*]] = arith.addi %[[pid]], {{.*}} : i32 -// //CHECK: %[[tensorOffset1:.*]] = arith.addi %{{.*}}, %[[makerange]] : tensor<1024xi32, #blocked> -// //CHECK: %[[uniformOffset0:.*]] = arith.addi %[[pid:.*]], %{{.*}} : i32 -// //CHECK: %[[tensorOffset3:.*]] = arith.addi %{{.*}}, %[[makerange]] : tensor<1024xi32, #blocked> -// //CHECK: %[[zero:.*]] = tt.splat %{{.*}} : i64 -> tensor<1024xi64, #blocked> -// //CHECK: %[[uniformPtr0:.*]] = tt.addptr %arg0, %[[uniformOffset0:.*]] : !tt.ptr, i32 -// //CHECK: %[[tensorOffset3ext:.*]] = arith.extsi %[[tensorOffset3]] : tensor<1024xi32, #blocked> to tensor<1024xi64, #blocked> -// //CHECK: %[[tensorOffset0:.*]]= arith.addi %[[tensorOffset3ext]], %[[zero]] : tensor<1024xi64, #blocked> -// //CHECK: %[[uniformPtr1:.*]] = tt.addptr %[[uniformPtr0]], %[[uniformOffset1]] : !tt.ptr, i32 -// //CHECK: %[[tensorOffset1ext:.*]] = arith.extsi %[[tensorOffset1]] : tensor<1024xi32, #blocked> to tensor<1024xi64, #blocked> -// //CHECK: %[[tensorOffset2:.*]] = arith.addi %[[tensorOffset1ext]], %[[tensorOffset0]]: tensor<1024xi64, #blocked> -// //CHECK: %[[scalarPtr:.*]] = tt.splat %[[uniformPtr1]] : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> -// //CHECK: %[[newPtr:.*]] = tt.addptr %[[scalarPtr]], %[[tensorOffset2]] : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi64, #blocked> -// //CHECK: tt.load %[[newPtr]] -// -// %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> -// %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, 
tensor<1024xi32, #blocked> -// %7 = tt.addptr %6, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> -// %8 = tt.load %7 : tensor<1024x!tt.ptr, #blocked> -// tt.return %8 : tensor<1024xf32, #blocked> -// } -//} -// -//// ----- -// -//#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -//module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { -// // -// // This is the same as conversion3, but now the `arith.extsi` operations -// // disappeared and all the offsets are 32 bits. -// // -// // CHECK-LABEL: tt.func @conversion4 -// tt.func @conversion4(%arg0: !tt.ptr{tt.pointer_range = 32 : i32})-> tensor<1024xf32, #blocked>{ -// %c1024_i32 = arith.constant 1024 : i32 -// %0 = tt.get_program_id x : i32 -// %1 = arith.muli %0, %c1024_i32 : i32 -// %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> -// %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> -// %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> -// -// //CHECK: %0 = tt.get_program_id x : i32 -// //CHECK: %[[pid:.*]] = arith.muli %0, {{.*}} : i32 -// //CHECK: %[[makerange:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> -// //CHECK: %[[uniformOffset1:.*]] = arith.addi %[[pid]], {{.*}} : i32 -// //CHECK: %[[tensorOffset1:.*]] = arith.addi %{{.*}}, %[[makerange]] : tensor<1024xi32, #blocked> -// //CHECK: %[[uniformOffset0:.*]] = arith.addi %[[pid:.*]], %{{.*}} : i32 -// //CHECK: %[[tensorOffset3:.*]] = arith.addi %{{.*}}, %[[makerange]] : tensor<1024xi32, #blocked> -// //CHECK: %[[zero:.*]] = tt.splat %{{.*}} : i32 -> tensor<1024xi32, #blocked> -// //CHECK: %[[uniformPtr0:.*]] = tt.addptr %arg0, %[[uniformOffset0:.*]] : !tt.ptr, i32 -// //CHECK: %[[tensorOffset0:.*]]= arith.addi %[[tensorOffset3]], %[[zero]] : tensor<1024xi32, #blocked> -// //CHECK: %[[uniformPtr1:.*]] = tt.addptr %[[uniformPtr0]], %[[uniformOffset1]] : !tt.ptr, i32 -// //CHECK: %[[tensorOffset2:.*]] = arith.addi %[[tensorOffset1]], %[[tensorOffset0]]: tensor<1024xi32, #blocked> -// //CHECK: %[[scalarPtr:.*]] = tt.splat %[[uniformPtr1]] : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> -// //CHECK: %[[newPtr:.*]] = tt.addptr %[[scalarPtr]], %[[tensorOffset2]] : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> -// //CHECK: tt.load %[[newPtr]] -// %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> -// %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> -// %7 = tt.addptr %6, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> -// %8 = tt.load %7 : tensor<1024x!tt.ptr, #blocked> -// tt.return %8 : tensor<1024xf32, #blocked> -// } -//} -// -//// ----- -// -//#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -//module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { -// // CHECK-LABEL: tt.func @forOp -// tt.func @forOp(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>)-> tensor<1024xf32, #blocked>{ -// %c1024_i32 = arith.constant 1024 : i32 -// %c0 = arith.constant 0: index -// %c128 = arith.constant 128: index -// %c1 = arith.constant 1 : index -// %0 = tt.get_program_id x : i32 -// %1 = arith.muli %0, %c1024_i32 : i32 -// %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> -// // CHECK: %[[scalarOffsetLoop:.*]] = arith.addi {{.*}}, {{.*}} : i32 -// // CHECK: %[[variableOffset1:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor -// // CHECK: 
%[[scalarOffset1:.*]] = arith.addi %{{.*}}, %{{.*}} : i32 -// // CHECK: %[[scalarOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : i32 -// // CHECK: %[[variableOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor -// // CHECK: %[[scalarPtrUpdate:.*]] = tt.addptr %arg0, %[[scalarOffset]] -// // CHECK: %[[ext_offset0:.*]] = arith.extsi %[[variableOffset]] -// // CHECK: %[[offset1:.*]] = arith.addi %[[ext_offset0]], %{{.*}} : tensor<1024xi64, #blocked> -// %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> -// %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> -// %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> -// %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> -// // CHECK: %[[loop:.*]]:4 = scf.for {{.*}} iter_args({{.*}}, {{.*}}, %[[loopScalarPtr:.*]] = %{{.*}}, %[[loopOffset:.*]] = %[[offset1]]) -> {{.*}} { -// %52:2 = scf.for %arg9 = %c0 to %c128 step %c1 iter_args(%arg1 = %6, %arg2 = %init) -> (tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked>){ -// // CHECK: %[[scalarPtrUpdateLoop:.*]] = tt.addptr %[[loopScalarPtr]], %[[scalarOffsetLoop]] -// // CHECK: %[[ext_offset0i:.*]] = arith.extsi %[[variableOffset1]] -// // CHECK: %[[offset_i:.*]] = arith.addi %[[ext_offset0i]], %[[loopOffset]] -// // CHECK: %[[base_ptr:.*]] = tt.splat %[[scalarPtrUpdateLoop]] -// // CHECK: %[[newPtr:.*]] = tt.addptr %[[base_ptr]], %[[offset_i]] -// // CHECK: tt.load %[[newPtr]] -// // CHECK: scf.yield {{.*}}, {{.*}}, %[[scalarPtrUpdateLoop]], %[[offset_i]] -// %11 = tt.addptr %arg1, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> -// %9 = tt.load %11 : tensor<1024x!tt.ptr, #blocked> -// %10 = arith.addf %9, %arg2 : tensor<1024xf32, #blocked> -// scf.yield %11, %10 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked> -// } -// // CHECK: tt.addptr %[[loop]]#2, %[[scalarOffset1]] : !tt.ptr, i32 -// %8 = tt.addptr %52#0, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> -// %11 = tt.load %8 : tensor<1024x!tt.ptr, #blocked> -// tt.return %11 : tensor<1024xf32, #blocked> -// } -//} -// -//// ----- -// -//#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -//module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { -// // CHECK-LABEL: tt.func @forOp2 -// tt.func @forOp2(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>)-> tensor<1024xf32, #blocked>{ -// %c1024_i32 = arith.constant 1024 : i32 -// %c0 = arith.constant 0: index -// %c128 = arith.constant 128: index -// %c1 = arith.constant 1 : index -// %0 = tt.get_program_id x : i32 -// %1 = arith.muli %0, %c1024_i32 : i32 -// %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> -// // CHECK: %[[scalarOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : i32 -// // CHECK: %[[variableOffset0:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor<1024xi32, #blocked> -// // CHECK: %[[finalScalarOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : i32 -// // CHECK: %[[variableOffset1:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor<1024xi32, #blocked> -// // CHECK: %[[base_offset:.*]] = tt.splat {{.*}} : i64 -// %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> -// %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> -// %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> -// // CHECK: %[[forOut:.*]]:4 = scf.for {{.*}} iter_args(%{{.*}}, {{.*}}, %[[scalarPtr:.*]] = %arg0, %[[loopOffset:.*]] = %[[base_offset]]) -// %52:2 = scf.for %arg9 = %c0 to %c128 step %c1 iter_args(%arg1 = %5, %arg2 = %init) 
-> (tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked>){ -// // CHECK: %[[scalarPtrUpdate:.*]] = tt.addptr %[[scalarPtr]], %[[scalarOffset]] -// // CHECK: %[[ext_offset0i:.*]] = arith.extsi %[[variableOffset0]] -// // CHECK: %[[ext_offset_i:.*]] = arith.addi %[[ext_offset0i]], %[[loopOffset]] -// // CHECK: %[[base_ptr:.*]] = tt.splat %[[scalarPtrUpdate]] -// // CHECK: %[[newPtr:.*]] = tt.addptr %[[base_ptr]], %[[ext_offset_i]] -// // CHECK: tt.load %[[newPtr]] -// %11 = tt.addptr %arg1, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> -// %9 = tt.load %11 : tensor<1024x!tt.ptr, #blocked> -// %10 = arith.addf %9, %arg2 : tensor<1024xf32, #blocked> -// scf.yield %11, %10 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked> -// } -// // CHECK: %[[scalarPtrFinalUpdate:.*]] = tt.addptr %[[forOut]]#2, %[[finalScalarOffset]] -// // CHECK: %[[ext_offset0:.*]] = arith.extsi %[[variableOffset1]] -// // CHECK: %[[tailOffset:.*]] = arith.addi %[[ext_offset0]], %[[forOut]]#3 -// // CHECK: %[[tail_base_ptr:.*]] = tt.splat %[[scalarPtrFinalUpdate]] -// // CHECK: %[[tailPtr:.*]] = tt.addptr %[[tail_base_ptr]], %[[tailOffset]] -// // CHECK: tt.load %[[tailPtr]] -// %8 = tt.addptr %52#0, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> -// %11 = tt.load %8 : tensor<1024x!tt.ptr, #blocked> -// tt.return %11 : tensor<1024xf32, #blocked> -// } -//} -// -//// ----- -// -//#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -//module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { -// // CHECK-LABEL: tt.func @forNested -// tt.func @forNested(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>)-> tensor<1024xf32, #blocked>{ -// %c1024_i32 = arith.constant 1024 : i32 -// %c0 = arith.constant 0: index -// %c128 = arith.constant 128: index -// %c1 = arith.constant 1 : index -// %0 = tt.get_program_id x : i32 -// %1 = arith.muli %0, %c1024_i32 : i32 -// %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> -// // CHECK: %[[scalarOffset:.*]] = arith.addi {{.*}}, {{.*}} : i32 -// // CHECK: %[[variableOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor<1024xi32, #blocked> -// // CHECK: %[[base_offset:.*]] = tt.splat {{.*}} : i64 -// %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> -// %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> -// %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> -// -// // CHECK: %[[forOut0:.*]]:4 = scf.for {{.*}} iter_args(%{{.*}}, {{.*}}, %[[scalarPtr0:.*]] = %arg0, %[[loopOffset0:.*]] = %[[base_offset]]){{.*}}{ -// // CHECK: %[[forOut1:.*]]:4 = scf.for {{.*}} iter_args(%{{.*}}, {{.*}}, %[[scalarPtr1:.*]] = %[[scalarPtr0]], %[[loopOffset1:.*]] = %[[loopOffset0]]){{.*}}{ -// // CHECK: %[[scalarPtrUpdate:.*]] = tt.addptr %[[scalarPtr1]], %{{.*}} -// // CHECK: %[[ext_loop_offset1:.*]] = arith.extsi %[[variableOffset]] -// // CHECK: %[[offset_i:.*]] = arith.addi %[[ext_loop_offset1]], %[[loopOffset1]] -// // CHECK: %[[base_ptr:.*]] = tt.splat %[[scalarPtrUpdate]] -// // CHECK: %[[newPtr:.*]] = tt.addptr %[[base_ptr]], %[[offset_i]] -// // CHECK: tt.load %[[newPtr]] -// // CHECK: scf.yield %{{.*}}, {{.*}}, %[[scalarPtrUpdate]], %[[offset_i]] -// // CHECK: scf.yield %{{.*}}, {{.*}}, %[[forOut1]]#2, %[[forOut1]]#3 -// -// %52:2 = scf.for %arg9 = %c0 to %c128 step %c1 iter_args(%arg1 = %5, %arg2 = %init) -> (tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked>){ -// %53:2 = scf.for %arg10 = %c0 to %c128 step %c1 
iter_args(%arg3 = %arg1, %arg4 = %arg2) -> (tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked>){ -// %11 = tt.addptr %arg3, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> -// %9 = tt.load %11 : tensor<1024x!tt.ptr, #blocked> -// %10 = arith.addf %9, %arg4 : tensor<1024xf32, #blocked> -// scf.yield %11, %10 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked> -// } -// scf.yield %53#0, %53#1: tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked> -// } -// %8 = tt.addptr %52#0, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> -// %11 = tt.load %8 : tensor<1024x!tt.ptr, #blocked> -// tt.return %11 : tensor<1024xf32, #blocked> -// } -//} -// -//// ----- -// -//#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -//module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { -// // CHECK-LABEL: tt.func @ifOp -// tt.func @ifOp(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>, %cond : i1)-> tensor<1024xf32, #blocked>{ -// %c1024_i32 = arith.constant 1024 : i32 -// %c0 = arith.constant 0: index -// %c128 = arith.constant 128: index -// %c1 = arith.constant 1 : index -// %0 = tt.get_program_id x : i32 -// %1 = arith.muli %0, %c1024_i32 : i32 -// %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> -// // CHECK: %[[scalarOffset:.*]] = arith.addi {{.*}}, {{.*}} : i32 -// // CHECK: %[[variableOffset:.*]] = arith.addi -// %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> -// %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> -// %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> -// // CHECK: %[[baseOffsetVariable:.*]] = tt.splat {{.*}} : i64 -> tensor<1024xi64, #blocked> -// // CHECK: %[[ifOut:.*]]:3 = scf.if {{.*}} -> (tensor<1024x!tt.ptr, #blocked>, !tt.ptr, tensor<1024xi64, #blocked>) -// %6 = scf.if %cond -> (tensor<1024x!tt.ptr, #blocked>){ -// // CHECK: %[[scalarOffsetUpdate:.*]] = tt.addptr %arg0, %[[scalarOffset]] -// // CHECK: %[[ext_offset0:.*]] = arith.extsi %[[variableOffset]] -// // CHECK: %[[if_offset:.*]] = arith.addi %[[ext_offset0]], %[[baseOffsetVariable]] -// // CHECK: scf.yield %{{.*}}, %[[scalarOffsetUpdate]], %[[if_offset]] -// %true = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> -// scf.yield %true : tensor<1024x!tt.ptr, #blocked> -// } else { -// // CHECK: %[[new_scalar_ptr:.*]] = tt.addptr %arg0, {{.*}} -// // CHECK: scf.yield %{{.*}}, %[[new_scalar_ptr]], %[[baseOffsetVariable]] -// %false = tt.addptr %5, %3 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> -// scf.yield %false : tensor<1024x!tt.ptr, #blocked> -// } -// // CHECK: %[[trunc_offset:.*]] = arith.trunci %[[ifOut]]#2 -// // CHECK: %[[base_ptr:.*]] = tt.splat %[[ifOut]]#1 -// // CHECK: %[[newPtr:.*]] = tt.addptr %[[base_ptr]], %[[trunc_offset]] -// // CHECK: tt.load %[[newPtr]] -// %11 = tt.load %6 : tensor<1024x!tt.ptr, #blocked> -// tt.return %11 : tensor<1024xf32, #blocked> -// } -//} -// -//// ----- +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: tt.func @conversion2 + tt.func @conversion2(%arg0: !tt.ptr)-> tensor<1024xf32, #blocked>{ + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %3 = 
tt.splat %1 : i32 -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + // CHECK: %[[scalarOffset:.*]] = arith.addi {{.*}}, {{.*}} : i32 + // CHECK: %[[baseOffset064bit:.*]] = tt.splat {{.*}} : i64 + // CHECK: %[[newScalarPtr:.*]] = tt.addptr %arg0, %[[scalarOffset]] + // CHECK: %[[offset064bit:.*]] = arith.extsi {{.*}} + // CHECK: %[[offset164bit:.*]] = arith.addi %[[offset064bit]], %[[baseOffset064bit]] + // CHECK: %[[offset132bit:.*]] = arith.trunci %[[offset164bit]] : tensor<1024xi64, #blocked> to tensor<1024xi32, #blocked> + // CHECK: %[[basePtr:.*]] = tt.splat %[[newScalarPtr]] + // CHECK: %[[newPtr:.*]] = tt.addptr %[[basePtr]], %[[offset132bit]] + // CHECK: tt.load %[[newPtr]] + %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %7 = tt.load %6 : tensor<1024x!tt.ptr, #blocked> + tt.return %7 : tensor<1024xf32, #blocked> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: tt.func @conversion3 + tt.func @conversion3(%arg0: !tt.ptr)-> tensor<1024xf32, #blocked>{ + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + + //CHECK: %0 = tt.get_program_id x : i32 + //CHECK: %[[pid:.*]] = arith.muli %0, {{.*}} : i32 + //CHECK: %[[makerange:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + //CHECK: %[[uniformOffset1:.*]] = arith.addi %[[pid]], {{.*}} : i32 + //CHECK: %[[tensorOffset1:.*]] = arith.addi %{{.*}}, %[[makerange]] : tensor<1024xi32, #blocked> + //CHECK: %[[uniformOffset0:.*]] = arith.addi %[[pid:.*]], %{{.*}} : i32 + //CHECK: %[[tensorOffset3:.*]] = arith.addi %{{.*}}, %[[makerange]] : tensor<1024xi32, #blocked> + //CHECK: %[[zero:.*]] = tt.splat %{{.*}} : i64 -> tensor<1024xi64, #blocked> + //CHECK: %[[uniformPtr0:.*]] = tt.addptr %arg0, %[[uniformOffset0:.*]] : !tt.ptr, i32 + //CHECK: %[[tensorOffset3ext:.*]] = arith.extsi %[[tensorOffset3]] : tensor<1024xi32, #blocked> to tensor<1024xi64, #blocked> + //CHECK: %[[tensorOffset0:.*]]= arith.addi %[[tensorOffset3ext]], %[[zero]] : tensor<1024xi64, #blocked> + //CHECK: %[[uniformPtr1:.*]] = tt.addptr %[[uniformPtr0]], %[[uniformOffset1]] : !tt.ptr, i32 + //CHECK: %[[tensorOffset1ext:.*]] = arith.extsi %[[tensorOffset1]] : tensor<1024xi32, #blocked> to tensor<1024xi64, #blocked> + //CHECK: %[[tensorOffset2:.*]] = arith.addi %[[tensorOffset1ext]], %[[tensorOffset0]]: tensor<1024xi64, #blocked> + //CHECK: %[[scalarPtr:.*]] = tt.splat %[[uniformPtr1]] : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + //CHECK: %[[newPtr:.*]] = tt.addptr %[[scalarPtr]], %[[tensorOffset2]] : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi64, #blocked> + //CHECK: tt.load %[[newPtr]] + + %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %7 = tt.addptr %6, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %8 = tt.load %7 : tensor<1024x!tt.ptr, #blocked> + tt.return %8 : tensor<1024xf32, #blocked> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4], 
threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + // + // This is the same as conversion3, but now the `arith.extsi` operations + // disappeared and all the offsets are 32 bits. + // + // CHECK-LABEL: tt.func @conversion4 + tt.func @conversion4(%arg0: !tt.ptr{tt.pointer_range = 32 : i32})-> tensor<1024xf32, #blocked>{ + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + + //CHECK: %0 = tt.get_program_id x : i32 + //CHECK: %[[pid:.*]] = arith.muli %0, {{.*}} : i32 + //CHECK: %[[makerange:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + //CHECK: %[[uniformOffset1:.*]] = arith.addi %[[pid]], {{.*}} : i32 + //CHECK: %[[tensorOffset1:.*]] = arith.addi %{{.*}}, %[[makerange]] : tensor<1024xi32, #blocked> + //CHECK: %[[uniformOffset0:.*]] = arith.addi %[[pid:.*]], %{{.*}} : i32 + //CHECK: %[[tensorOffset3:.*]] = arith.addi %{{.*}}, %[[makerange]] : tensor<1024xi32, #blocked> + //CHECK: %[[zero:.*]] = tt.splat %{{.*}} : i32 -> tensor<1024xi32, #blocked> + //CHECK: %[[uniformPtr0:.*]] = tt.addptr %arg0, %[[uniformOffset0:.*]] : !tt.ptr, i32 + //CHECK: %[[tensorOffset0:.*]]= arith.addi %[[tensorOffset3]], %[[zero]] : tensor<1024xi32, #blocked> + //CHECK: %[[uniformPtr1:.*]] = tt.addptr %[[uniformPtr0]], %[[uniformOffset1]] : !tt.ptr, i32 + //CHECK: %[[tensorOffset2:.*]] = arith.addi %[[tensorOffset1]], %[[tensorOffset0]]: tensor<1024xi32, #blocked> + //CHECK: %[[scalarPtr:.*]] = tt.splat %[[uniformPtr1]] : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + //CHECK: %[[newPtr:.*]] = tt.addptr %[[scalarPtr]], %[[tensorOffset2]] : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + //CHECK: tt.load %[[newPtr]] + %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %7 = tt.addptr %6, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %8 = tt.load %7 : tensor<1024x!tt.ptr, #blocked> + tt.return %8 : tensor<1024xf32, #blocked> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: tt.func @forOp + tt.func @forOp(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>)-> tensor<1024xf32, #blocked>{ + %c1024_i32 = arith.constant 1024 : i32 + %c0 = arith.constant 0: index + %c128 = arith.constant 128: index + %c1 = arith.constant 1 : index + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + // CHECK: %[[scalarOffsetLoop:.*]] = arith.addi {{.*}}, {{.*}} : i32 + // CHECK: %[[variableOffset1:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor + // CHECK: %[[scalarOffset1:.*]] = arith.addi %{{.*}}, %{{.*}} : i32 + // CHECK: %[[scalarOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : i32 + // CHECK: %[[variableOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor + // CHECK: %[[scalarPtrUpdate:.*]] = tt.addptr %arg0, %[[scalarOffset]] + // CHECK: %[[ext_offset0:.*]] = arith.extsi %[[variableOffset]] + // CHECK: %[[offset1:.*]] = arith.addi %[[ext_offset0]], %{{.*}} : 
tensor<1024xi64, #blocked> + %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + // CHECK: %[[loop:.*]]:4 = scf.for {{.*}} iter_args({{.*}}, {{.*}}, %[[loopScalarPtr:.*]] = %{{.*}}, %[[loopOffset:.*]] = %[[offset1]]) -> {{.*}} { + %52:2 = scf.for %arg9 = %c0 to %c128 step %c1 iter_args(%arg1 = %6, %arg2 = %init) -> (tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked>){ + // CHECK: %[[scalarPtrUpdateLoop:.*]] = tt.addptr %[[loopScalarPtr]], %[[scalarOffsetLoop]] + // CHECK: %[[ext_offset0i:.*]] = arith.extsi %[[variableOffset1]] + // CHECK: %[[offset_i:.*]] = arith.addi %[[ext_offset0i]], %[[loopOffset]] + // CHECK: %[[base_ptr:.*]] = tt.splat %[[scalarPtrUpdateLoop]] + // CHECK: %[[newPtr:.*]] = tt.addptr %[[base_ptr]], %[[offset_i]] + // CHECK: tt.load %[[newPtr]] + // CHECK: scf.yield {{.*}}, {{.*}}, %[[scalarPtrUpdateLoop]], %[[offset_i]] + %11 = tt.addptr %arg1, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %9 = tt.load %11 : tensor<1024x!tt.ptr, #blocked> + %10 = arith.addf %9, %arg2 : tensor<1024xf32, #blocked> + scf.yield %11, %10 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked> + } + // CHECK: tt.addptr %[[loop]]#2, %[[scalarOffset1]] : !tt.ptr, i32 + %8 = tt.addptr %52#0, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %11 = tt.load %8 : tensor<1024x!tt.ptr, #blocked> + tt.return %11 : tensor<1024xf32, #blocked> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: tt.func @forOp2 + tt.func @forOp2(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>)-> tensor<1024xf32, #blocked>{ + %c1024_i32 = arith.constant 1024 : i32 + %c0 = arith.constant 0: index + %c128 = arith.constant 128: index + %c1 = arith.constant 1 : index + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + // CHECK: %[[scalarOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : i32 + // CHECK: %[[variableOffset0:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor<1024xi32, #blocked> + // CHECK: %[[finalScalarOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : i32 + // CHECK: %[[variableOffset1:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor<1024xi32, #blocked> + // CHECK: %[[base_offset:.*]] = tt.splat {{.*}} : i64 + %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + // CHECK: %[[forOut:.*]]:4 = scf.for {{.*}} iter_args(%{{.*}}, {{.*}}, %[[scalarPtr:.*]] = %arg0, %[[loopOffset:.*]] = %[[base_offset]]) + %52:2 = scf.for %arg9 = %c0 to %c128 step %c1 iter_args(%arg1 = %5, %arg2 = %init) -> (tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked>){ + // CHECK: %[[scalarPtrUpdate:.*]] = tt.addptr %[[scalarPtr]], %[[scalarOffset]] + // CHECK: %[[ext_offset0i:.*]] = arith.extsi %[[variableOffset0]] + // CHECK: %[[ext_offset_i:.*]] = arith.addi %[[ext_offset0i]], %[[loopOffset]] + // CHECK: %[[base_ptr:.*]] = tt.splat %[[scalarPtrUpdate]] + // CHECK: %[[newPtr:.*]] = tt.addptr %[[base_ptr]], %[[ext_offset_i]] + // CHECK: tt.load %[[newPtr]] + %11 = tt.addptr %arg1, %4 : tensor<1024x!tt.ptr, 
#blocked>, tensor<1024xi32, #blocked> + %9 = tt.load %11 : tensor<1024x!tt.ptr, #blocked> + %10 = arith.addf %9, %arg2 : tensor<1024xf32, #blocked> + scf.yield %11, %10 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked> + } + // CHECK: %[[scalarPtrFinalUpdate:.*]] = tt.addptr %[[forOut]]#2, %[[finalScalarOffset]] + // CHECK: %[[ext_offset0:.*]] = arith.extsi %[[variableOffset1]] + // CHECK: %[[tailOffset:.*]] = arith.addi %[[ext_offset0]], %[[forOut]]#3 + // CHECK: %[[tail_base_ptr:.*]] = tt.splat %[[scalarPtrFinalUpdate]] + // CHECK: %[[tailPtr:.*]] = tt.addptr %[[tail_base_ptr]], %[[tailOffset]] + // CHECK: tt.load %[[tailPtr]] + %8 = tt.addptr %52#0, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %11 = tt.load %8 : tensor<1024x!tt.ptr, #blocked> + tt.return %11 : tensor<1024xf32, #blocked> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: tt.func @forNested + tt.func @forNested(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>)-> tensor<1024xf32, #blocked>{ + %c1024_i32 = arith.constant 1024 : i32 + %c0 = arith.constant 0: index + %c128 = arith.constant 128: index + %c1 = arith.constant 1 : index + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + // CHECK: %[[scalarOffset:.*]] = arith.addi {{.*}}, {{.*}} : i32 + // CHECK: %[[variableOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor<1024xi32, #blocked> + // CHECK: %[[base_offset:.*]] = tt.splat {{.*}} : i64 + %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + + // CHECK: %[[forOut0:.*]]:4 = scf.for {{.*}} iter_args(%{{.*}}, {{.*}}, %[[scalarPtr0:.*]] = %arg0, %[[loopOffset0:.*]] = %[[base_offset]]){{.*}}{ + // CHECK: %[[forOut1:.*]]:4 = scf.for {{.*}} iter_args(%{{.*}}, {{.*}}, %[[scalarPtr1:.*]] = %[[scalarPtr0]], %[[loopOffset1:.*]] = %[[loopOffset0]]){{.*}}{ + // CHECK: %[[scalarPtrUpdate:.*]] = tt.addptr %[[scalarPtr1]], %{{.*}} + // CHECK: %[[ext_loop_offset1:.*]] = arith.extsi %[[variableOffset]] + // CHECK: %[[offset_i:.*]] = arith.addi %[[ext_loop_offset1]], %[[loopOffset1]] + // CHECK: %[[base_ptr:.*]] = tt.splat %[[scalarPtrUpdate]] + // CHECK: %[[newPtr:.*]] = tt.addptr %[[base_ptr]], %[[offset_i]] + // CHECK: tt.load %[[newPtr]] + // CHECK: scf.yield %{{.*}}, {{.*}}, %[[scalarPtrUpdate]], %[[offset_i]] + // CHECK: scf.yield %{{.*}}, {{.*}}, %[[forOut1]]#2, %[[forOut1]]#3 + + %52:2 = scf.for %arg9 = %c0 to %c128 step %c1 iter_args(%arg1 = %5, %arg2 = %init) -> (tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked>){ + %53:2 = scf.for %arg10 = %c0 to %c128 step %c1 iter_args(%arg3 = %arg1, %arg4 = %arg2) -> (tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked>){ + %11 = tt.addptr %arg3, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %9 = tt.load %11 : tensor<1024x!tt.ptr, #blocked> + %10 = arith.addf %9, %arg4 : tensor<1024xf32, #blocked> + scf.yield %11, %10 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked> + } + scf.yield %53#0, %53#1: tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked> + } + %8 = tt.addptr %52#0, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %11 = tt.load %8 : tensor<1024x!tt.ptr, #blocked> + tt.return %11 : 
tensor<1024xf32, #blocked> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: tt.func @ifOp + tt.func @ifOp(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>, %cond : i1)-> tensor<1024xf32, #blocked>{ + %c1024_i32 = arith.constant 1024 : i32 + %c0 = arith.constant 0: index + %c128 = arith.constant 128: index + %c1 = arith.constant 1 : index + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + // CHECK: %[[scalarOffset:.*]] = arith.addi {{.*}}, {{.*}} : i32 + // CHECK: %[[variableOffset:.*]] = arith.addi + %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + // CHECK: %[[baseOffsetVariable:.*]] = tt.splat {{.*}} : i64 -> tensor<1024xi64, #blocked> + // CHECK: %[[ifOut:.*]]:3 = scf.if {{.*}} -> (tensor<1024x!tt.ptr, #blocked>, !tt.ptr, tensor<1024xi64, #blocked>) + %6 = scf.if %cond -> (tensor<1024x!tt.ptr, #blocked>){ + // CHECK: %[[scalarOffsetUpdate:.*]] = tt.addptr %arg0, %[[scalarOffset]] + // CHECK: %[[ext_offset0:.*]] = arith.extsi %[[variableOffset]] + // CHECK: %[[if_offset:.*]] = arith.addi %[[ext_offset0]], %[[baseOffsetVariable]] + // CHECK: scf.yield %{{.*}}, %[[scalarOffsetUpdate]], %[[if_offset]] + %true = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + scf.yield %true : tensor<1024x!tt.ptr, #blocked> + } else { + // CHECK: %[[new_scalar_ptr:.*]] = tt.addptr %arg0, {{.*}} + // CHECK: scf.yield %{{.*}}, %[[new_scalar_ptr]], %[[baseOffsetVariable]] + %false = tt.addptr %5, %3 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + scf.yield %false : tensor<1024x!tt.ptr, #blocked> + } + // CHECK: %[[trunc_offset:.*]] = arith.trunci %[[ifOut]]#2 + // CHECK: %[[base_ptr:.*]] = tt.splat %[[ifOut]]#1 + // CHECK: %[[newPtr:.*]] = tt.addptr %[[base_ptr]], %[[trunc_offset]] + // CHECK: tt.load %[[newPtr]] + %11 = tt.load %6 : tensor<1024x!tt.ptr, #blocked> + tt.return %11 : tensor<1024xf32, #blocked> + } +} + +// ----- // //#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> //module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { @@ -364,8 +364,8 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} // } //} // -//// ----- -// +// ----- + //#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> //module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { // // CHECK-LABEL: tt.func @condBranch @@ -609,37 +609,37 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} //} // //// ----- -// -//#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -//module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { -// // CHECK-LABEL: tt.func @forOpWithHints -// tt.func @forOpWithHints(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>)-> tensor<1024xf32, #blocked>{ -// %c0 = arith.constant 0: index -// %c1 = arith.constant 1 : index -// %c128 = arith.constant 128: index -// %0 = tt.get_program_id x : i32 -// %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, 
#blocked> -// %3 = tt.splat %0 : i32 -> tensor<1024xi32, #blocked> -// %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> -// %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> -// %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> -// %52:2 = scf.for %arg9 = %c0 to %c128 step %c1 iter_args(%arg1 = %6, %arg2 = %init) -> (tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked>){ -// %9 = tt.load %arg1: tensor<1024x!tt.ptr, #blocked> -// // CHECK: tt.addptr {{.*}}, {{.*}} {tt.divisibility = dense<16> : tensor<1xi32>} -// %11 = tt.addptr %arg1, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> -// %12 = tt.addptr %11, %3 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> -// %10 = arith.addf %9, %arg2 : tensor<1024xf32, #blocked> -// scf.yield %12, %10 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked> -// } {"tt.divisibility_arg1"=dense<[16]> : tensor<1xi32>} -// // CHECK: tt.divisibility_arg1 -// // CHECK-SAME: tt.divisibility_arg4 -// %8 = tt.addptr %52#0, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> -// %11 = tt.load %8 : tensor<1024x!tt.ptr, #blocked> -// tt.return %11 : tensor<1024xf32, #blocked> -// } -//} -// -//// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: tt.func @forOpWithHints + tt.func @forOpWithHints(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>)-> tensor<1024xf32, #blocked>{ + %c0 = arith.constant 0: index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128: index + %0 = tt.get_program_id x : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %3 = tt.splat %0 : i32 -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %52:2 = scf.for %arg9 = %c0 to %c128 step %c1 iter_args(%arg1 = %6, %arg2 = %init) -> (tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked>){ + %9 = tt.load %arg1: tensor<1024x!tt.ptr, #blocked> + // CHECK: tt.addptr {{.*}}, {{.*}} {tt.divisibility = dense<16> : tensor<1xi32>} + %11 = tt.addptr %arg1, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %12 = tt.addptr %11, %3 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %10 = arith.addf %9, %arg2 : tensor<1024xf32, #blocked> + scf.yield %12, %10 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked> + } {"tt.divisibility_arg1"=dense<[16]> : tensor<1xi32>} + // CHECK: tt.divisibility_arg1 + // CHECK-SAME: tt.divisibility_arg4 + %8 = tt.addptr %52#0, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %11 = tt.load %8 : tensor<1024x!tt.ptr, #blocked> + tt.return %11 : tensor<1024xf32, #blocked> + } +} + +// ----- // //#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> //module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp index 8b83231e85f6..2ad8ca7de944 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp +++ 
b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp @@ -473,8 +473,9 @@ Value createTensorPointer( if (canNarrow) offset = createNarrow64bitOffsetTo32bits(rewriter, loc, offset); - Value tensorPtr = - rewriter.create(loc, tensorPtrType, basePtr); + Value tensorPtr = rewriter.create( + loc, TypeRange{tensorPtrType}, ValueRange{basePtr}, + SmallVector{rewriter.getNamedAttr("legal", rewriter.getUnitAttr())}); auto addPtrOp = rewriter.create(loc, tensorPtrType, tensorPtr, offset); @@ -1049,9 +1050,91 @@ class TritonAMDGPUCanonicalizePointersPass void runOnOperationmine(); }; -class ConvertAddPtrOp : public OpConversionPattern { +struct FatPointers { + struct FatPtr { + bool canNarrow = false; + llvm::SmallDenseMap attributes; + }; + using KeyT = std::pair; + using ValueT = FatPtr; + using DenseMapT = DenseMap; + DenseMapT pointers; + ValueT &operator[](const KeyT &k) { return pointers[k]; } + ValueT &operator[](KeyT &&k) { return pointers[k]; } + template + using const_arg_type_t = typename llvm::const_pointer_or_const_ref::type; + const ValueT &at(const_arg_type_t k) const { return pointers.at(k); } +}; + +std::optional getFatPtrCastOp(Value base, + Value offset) { + std::optional maybeCastOp; + for (Operation *user : base.getUsers()) { + if (auto castOp = llvm::dyn_cast(user)) { + if (castOp.getNumOperands() == 2 && castOp.getOperand(0) == base && + castOp.getOperand(1) == offset) { + maybeCastOp = castOp; + } + } + } +#ifndef NDEBUG + for (Operation *user : offset.getUsers()) { + if (auto castOp = llvm::dyn_cast(user)) { + if (castOp.getNumOperands() == 2 && castOp.getOperand(0) == base && + castOp.getOperand(1) == offset) { + assert( + castOp == *maybeCastOp && + "expected castop through base and castop through offset to match"); + } + } + } +#endif + return maybeCastOp; +} + +std::optional getFatPtrCastOp(OpOperand &operand) { + Value operandVal = operand.get(); + for (Operation *user : operandVal.getUsers()) { + if (auto castOp = llvm::dyn_cast(user)) { + if (castOp.getNumOperands() == 2 && + (castOp.getOperand(0) == operandVal || + castOp.getOperand(1) == operandVal) && + castOp.getNumResults() == 1 && + std::distance(castOp->getUsers().begin(), castOp->getUsers().end()) == + 1 && + *castOp->getUsers().begin() == operand.getOwner()) { + return castOp; + } + } + } + return {}; +} + +/// Flatten the given value ranges into a single vector of values. +static SmallVector flattenValues(ArrayRef values) { + SmallVector result; + for (const ValueRange &vals : values) + llvm::append_range(result, vals); + return result; +} + +/// Assert that the given value range contains a single value and return it. 
+static Value getSingleValue(ValueRange values) { + assert(values.size() == 1 && "expected single value"); + return values.front(); +} + +template +struct PointerCanonPattern : OpConversionPattern { + PointerCanonPattern(MLIRContext *context, FatPointers &fatPtrs, + PatternBenefit benefit = 1) + : OpConversionPattern(context, benefit), fatPtrs(fatPtrs) {} + FatPointers &fatPtrs; +}; + +class ConvertAddPtrOp : public PointerCanonPattern { public: - using OpConversionPattern::OpConversionPattern; + using PointerCanonPattern::PointerCanonPattern; LogicalResult matchAndRewrite(triton::AddPtrOp addPtrOp, OneToNOpAdaptor adaptor, @@ -1061,11 +1144,11 @@ class ConvertAddPtrOp : public OpConversionPattern { ArrayRef remappedOperands = adaptor.getOperands(); assert(remappedOperands.size() == 2 && remappedOperands[0].size() == 2 && - "expected adaptor to have 2,1 remapped values"); + "expected adaptor to have 2 remapped values"); Value fatPtrBase = remappedOperands[0][0]; Value fatPtrOffset = remappedOperands[0][1]; Value origOffset = remappedOperands[1][0]; - auto curLoc = addPtrOp.getLoc(); + Location curLoc = addPtrOp.getLoc(); // If it is a scalar pointer update, simply bump the base pointer if (!isa(addPtrOp.getPtr().getType())) { @@ -1094,10 +1177,8 @@ class ConvertAddPtrOp : public OpConversionPattern { rewriter.replaceOpWithMultiple(addPtrOp, {{newAddPtrOp, fatPtrOffset}}); // If we are updating the tensor pointer with a uniform value, we can // propagate the attributes of the tensor pointer to the fat pointer. - // TODO(max): re-enable - // for (auto attribute : fatPtr.attributes) - // pointers[nextPtr].setAttr(attribute.getFirst(), attribute.getSecond()); - // opToDelete.insert(addPtrOp); + fatPtrs[{newAddPtrOp.getResult(), fatPtrOffset}].attributes = + fatPtrs[{fatPtrBase, fatPtrOffset}].attributes; return success(); } @@ -1107,10 +1188,8 @@ class ConvertAddPtrOp : public OpConversionPattern { decomposeOffsetFromExpr(rewriter, curLoc, origOffset, bitness); // Vector offset update (if any): bump the tensor offset - // TODO(max): stash somewhere bool canNarrow = false; - bool propagateAtrs = false; - + bool propagateAtrs = true; Value newOffset = fatPtrOffset; if (!isZeroConst(nonUniformOffset)) { Type addPtrOffsetType = getElementTypeOrSelf(nonUniformOffset); @@ -1139,22 +1218,19 @@ class ConvertAddPtrOp : public OpConversionPattern { addPtrOp->setDiscardableAttr("rewritten", rewriter.getUnitAttr()); }); rewriter.replaceOpWithMultiple(addPtrOp, {{newAddPtrOp, newOffset}}); - // - // // If we are updating the tensor pointer with a uniform value, we can - // // propagate the attributes of the tensor pointer to the fat pointer. 
- // TODO(max): re-enable - // if (propagateAtrs) - // for (auto attribute : fatPtr.attributes) - // pointers[nextPtr].setAttr(attribute.getFirst(), - // attribute.getSecond()); + auto nextFatPtr = std::pair{newAddPtrOp.getResult(), newOffset}; + fatPtrs[nextFatPtr].canNarrow = canNarrow; + if (propagateAtrs) + fatPtrs[nextFatPtr].attributes = + fatPtrs.at({fatPtrBase, fatPtrOffset}).attributes; return success(); } }; -class ConvertSplatOp : public OpConversionPattern { +class ConvertSplatOp : public PointerCanonPattern { public: - using OpConversionPattern::OpConversionPattern; + using PointerCanonPattern::PointerCanonPattern; LogicalResult matchAndRewrite(triton::SplatOp splatOp, OneToNOpAdaptor adaptor, @@ -1171,8 +1247,8 @@ class ConvertSplatOp : public OpConversionPattern { assert(llvm::isa(fatPtrOffset.getType()) && "expected fatPtrOffset to be an integer type"); - auto outType = splatOp.getResult().getType(); - auto ptrShape = outType.getShape(); + RankedTensorType outType = splatOp.getResult().getType(); + llvm::ArrayRef ptrShape = outType.getShape(); auto newOffsetType = RankedTensorType::get(ptrShape, fatPtrOffset.getType(), outType.getEncoding()); Value offset = rewriter.create( @@ -1186,38 +1262,37 @@ class ConvertSplatOp : public OpConversionPattern { } }; -class ConvertLoadOp : public OpConversionPattern { +class ConvertLoadOp : public PointerCanonPattern { public: - using OpConversionPattern::OpConversionPattern; + using PointerCanonPattern::PointerCanonPattern; LogicalResult matchAndRewrite(triton::LoadOp loadOp, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto fatPtr = *adaptor.getOperands().begin(); + ValueRange fatPtr = *adaptor.getOperands().begin(); Value fatPtrBase = fatPtr.front(); Value fatPtrOffset = fatPtr.back(); Location curLoc = loadOp.getLoc(); - llvm::SmallDenseMap attributes{}; - auto newPtr = - createTensorPointer(rewriter, fatPtrBase, fatPtrOffset, curLoc, - // TODO(max): - /*canNarrow*/ true, attributes); + llvm::SmallDenseMap attributes{ + {rewriter.getStringAttr("legal"), rewriter.getUnitAttr()}}; + Value newPtr = createTensorPointer( + rewriter, fatPtrBase, fatPtrOffset, curLoc, + fatPtrs[{fatPtrBase, fatPtrOffset}].canNarrow, attributes); SmallVector operands = loadOp.getOperands().take_back(loadOp.getNumOperands() - 1); operands.insert(operands.begin(), newPtr); - auto newLoadPtrOp = rewriter.replaceOpWithNewOp( - loadOp, operands, loadOp->getAttrs()); - rewriter.modifyOpInPlace(newLoadPtrOp, [&] { - newLoadPtrOp->setDiscardableAttr("legal", rewriter.getUnitAttr()); - }); + SmallVector attrs = llvm::to_vector(loadOp->getAttrs()); + attrs.append({rewriter.getNamedAttr("legal", rewriter.getUnitAttr())}); + auto newLoadPtrOp = + rewriter.replaceOpWithNewOp(loadOp, operands, attrs); return success(); } }; -class ConvertFuncOp : public OpConversionPattern { +class ConvertFuncOp : public PointerCanonPattern { public: - using OpConversionPattern::OpConversionPattern; + using PointerCanonPattern::PointerCanonPattern; LogicalResult matchAndRewrite(triton::FuncOp funcOp, OpAdaptor adaptor, @@ -1246,15 +1321,56 @@ class ConvertFuncOp : public OpConversionPattern { } }; +class ConvertSCFYieldOp : public PointerCanonPattern { +public: + using PointerCanonPattern::PointerCanonPattern; + + LogicalResult + matchAndRewrite(scf::YieldOp yieldOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector newYieldedValues = flattenValues(adaptor.getOperands()); + + // Value tensorPtr = + // 
createTensorPointer(rewriter, fatPtr.basePtr, fatPtr.offset, curLoc, + // fatPtr.canNarrow, fatPtr.attributes); + + rewriter.modifyOpInPlace(yieldOp, [&]() { + yieldOp.getResultsMutable().clear(); + yieldOp.getResultsMutable().append(newYieldedValues); + }); + + // TODO(max): this is bad + if (auto ifOp = dyn_cast(yieldOp->getParentOp())) { + if (ifOp.thenBlock() == yieldOp->getBlock()) + rewriter.modifyOpInPlace(ifOp, [&] { + ifOp->setDiscardableAttr("then_rewritten", rewriter.getUnitAttr()); + }); + else + rewriter.modifyOpInPlace(ifOp, [&] { + ifOp->setDiscardableAttr("else_rewritten", rewriter.getUnitAttr()); + }); + } + + rewriter.modifyOpInPlace(yieldOp, [&] { + yieldOp->setDiscardableAttr("legal", rewriter.getUnitAttr()); + }); + + return success(); + } +}; + class ConvertUnrealizedConversionCastOp - : public OpConversionPattern { + : public PointerCanonPattern { public: - using OpConversionPattern::OpConversionPattern; + using PointerCanonPattern::PointerCanonPattern; LogicalResult matchAndRewrite(UnrealizedConversionCastOp castOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - assert(castOp->hasOneUse() && "expected at least 1 use of unrealized_cast"); + assert(std::distance(castOp->getUses().begin(), castOp->getUses().end()) > + 0 && + "expected at least 1 use of unrealized_cast"); + // dunno why but i get -Wdangling here... ArrayRef remappedOperands = adaptor.getOperands(); assert(remappedOperands.size() == 1 && remappedOperands[0].size() == 2 && "expected adaptor to have 2 remapped values"); @@ -1277,30 +1393,152 @@ class ConvertUnrealizedConversionCastOp } }; +class ConvertSCFForOp : public PointerCanonPattern { + using PointerCanonPattern::PointerCanonPattern; + +public: + LogicalResult + matchAndRewrite(scf::ForOp forOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector> fatPtrInits; + SmallVector valRangeLens; + ArrayRef remappedInits = adaptor.getInitArgs(); + for (ValueRange remappedInit : remappedInits) { + if (remappedInit.size() == 2) { + Value fatPtrBase = remappedInit[0]; + Value fatPtrOffset = remappedInit[1]; + fatPtrInits.emplace_back(fatPtrBase, fatPtrOffset); + } + valRangeLens.push_back(remappedInit.size()); + } + + TypeConverter hackTypeConverter; + unsigned inputNo = 0; + hackTypeConverter.addConversion( + [&inputNo, &remappedInits = std::as_const(remappedInits)]( + Type inputType, SmallVectorImpl &types) { + // handle the 0th iv + if (inputNo == 0) { + types.append({inputType}); + } else { + SmallVector remappedInitTypes = + llvm::to_vector(remappedInits[inputNo - 1].getTypes()); + types.append(remappedInitTypes); + } + inputNo++; + return success(); + }); + if (failed( + rewriter.convertRegionTypes(&forOp.getRegion(), hackTypeConverter))) + return failure(); + SmallVector initArgs = flattenValues(adaptor.getInitArgs()); + auto newForOp = rewriter.create( + forOp.getLoc(), getSingleValue(adaptor.getLowerBound()), + getSingleValue(adaptor.getUpperBound()), + getSingleValue(adaptor.getStep()), initArgs); + // replaceWithAdditionalYields + + newForOp->setAttrs(forOp->getAttrs()); + rewriter.eraseBlock(newForOp.getBody(0)); + rewriter.inlineRegionBefore(forOp.getRegion(), newForOp.getRegion(), + newForOp.getRegion().end()); + + SmallVector packedRets; + for (unsigned i = 0, offset = 0; i < valRangeLens.size(); i++) { + size_t len = valRangeLens[i]; + assert(offset < newForOp->getNumResults() && + "expected offset to be within bounds of results"); + ValueRange mappedValue = 
newForOp->getResults().slice(offset, len); + packedRets.push_back(mappedValue); + offset += len; + } + + rewriter.modifyOpInPlace(forOp, [&] { + forOp->setDiscardableAttr("rewritten", rewriter.getUnitAttr()); + }); + rewriter.modifyOpInPlace(newForOp, [&] { + newForOp->setDiscardableAttr("legal", rewriter.getUnitAttr()); + }); + rewriter.replaceOpWithMultiple(forOp, packedRets); + + return success(); + } +}; + +class ConvertSCFIfOp : public PointerCanonPattern { +public: + using PointerCanonPattern::PointerCanonPattern; + // One of the two branches is responsible to rewrite the operation. The other + // branch only update the yieldOp with the right parameters + LogicalResult + matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + assert(ifOp.getNumResults() == 1 && + ifOp.thenYield().getOperandTypes().size() == 2 && + "only 1 -> 2 supported for scf::IfOp rewrite"); + bool withElseRegion = ifOp.getNumRegions() > 1; + if (withElseRegion) { + assert(ifOp.thenYield().getOperandTypes() == + ifOp.elseYield().getOperandTypes() && + "ifOp types must match in both arms"); + } + + auto newIfOp = rewriter.create( + ifOp.getLoc(), ifOp.thenYield().getOperandTypes(), ifOp.getCondition(), + withElseRegion); + rewriter.inlineBlockBefore(ifOp.thenBlock(), newIfOp.thenBlock(), + newIfOp.thenBlock()->begin()); + if (withElseRegion) + rewriter.inlineBlockBefore(ifOp.elseBlock(), newIfOp.elseBlock(), + newIfOp.elseBlock()->begin()); + + rewriter.modifyOpInPlace(ifOp, [&] { + ifOp->setDiscardableAttr("rewritten", rewriter.getUnitAttr()); + }); + rewriter.modifyOpInPlace(newIfOp, [&] { + newIfOp->setDiscardableAttr("legal", rewriter.getUnitAttr()); + }); + rewriter.replaceOpWithMultiple(ifOp, {newIfOp.getResults()}); + + return success(); + } +}; + void TritonAMDGPUCanonicalizePointersPass::runOnOperation() { ModuleOp module = getOperation(); - auto *context = &getContext(); + mlir::MLIRContext *context = &getContext(); ConversionTarget target(*context); RewritePatternSet patterns(context); target.addLegalDialect(); - target.addDynamicallyLegalDialect([](Operation *op) { - if (llvm::isa(op) && !op->hasAttr("rewritten")) - return false; - + auto isLegal = [](Operation *op) { if (op->hasAttr("rewritten") || op->hasAttr("legal")) return true; - for (auto operand : op->getOperands()) { - if (llvm::isa(operand)) - return false; - if (operand.getDefiningOp()->hasAttr("rewritten")) + for (OpOperand &operand : op->getOpOperands()) { + if (auto arg = llvm::dyn_cast(operand.get())) + return !llvm::isa(getElementTypeOrSelf(arg)); + if (operand.get().getDefiningOp()->hasAttr("rewritten")) return false; } - return true; + }; + target.addDynamicallyLegalDialect( + [&isLegal](Operation *op) { + if (llvm::isa(op) && !op->hasAttr("rewritten")) + return false; + return isLegal(op); + }); + target.addDynamicallyLegalDialect([&isLegal](Operation *op) { + if (auto ifOp = llvm::dyn_cast(op)) + return !(ifOp->hasAttr("then_rewritten") and + ifOp->hasAttr("else_rewritten")); + return isLegal(op); }); - patterns.add( - patterns.getContext()); + FatPointers fatPrs; + + patterns.add( + patterns.getContext(), fatPrs); ConversionConfig config; config.buildMaterializations = false; if (failed( @@ -1309,7 +1547,8 @@ void TritonAMDGPUCanonicalizePointersPass::runOnOperation() { patterns.clear(); target.addIllegalOp(); - patterns.add(patterns.getContext()); + patterns.add(patterns.getContext(), + fatPrs); if (failed( applyPartialConversion(module, target, std::move(patterns), config))) 
return signalPassFailure(); From e4b27bd0c4d052892751e658ad738daf3263c116 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Mon, 16 Dec 2024 14:58:59 -0500 Subject: [PATCH 04/17] handle scf.while (TODO: handle cond arg) --- .../amd/amd-canonicalize-pointers.mlir | 70 ++++++------ .../CanonicalizePointers.cpp | 105 +++++++++++++++--- 2 files changed, 123 insertions(+), 52 deletions(-) diff --git a/test/TritonGPU/amd/amd-canonicalize-pointers.mlir b/test/TritonGPU/amd/amd-canonicalize-pointers.mlir index b8aa788e77cb..a7fd50cbf051 100644 --- a/test/TritonGPU/amd/amd-canonicalize-pointers.mlir +++ b/test/TritonGPU/amd/amd-canonicalize-pointers.mlir @@ -329,41 +329,41 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} } // ----- -// -//#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -//module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { -// // CHECK-LABEL: tt.func @whileOp -// tt.func @whileOp(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>, %cond : i1)-> tensor<1024xf32, #blocked>{ -// %c1024_i32 = arith.constant 1024 : i32 -// %c0 = arith.constant 0: index -// %c128 = arith.constant 128: index -// %c1 = arith.constant 1 : index -// %0 = tt.get_program_id x : i32 -// %1 = arith.muli %0, %c1024_i32 : i32 -// %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> -// // CHECK: %[[base_offset:.*]] = tt.splat %{{.*}} : i64 -// %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> -// %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> -// %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> -// // CHECK: %[[whileOut:.*]]:3 = scf.while ({{.*}}, %[[loopPtr:.*]] = %arg0, %[[loopOffset:.*]] = %[[base_offset]]) -// %6 = scf.while (%arg1 = %5, %arg2 = %cond) : (tensor<1024x!tt.ptr, #blocked>, i1) -> (tensor<1024x!tt.ptr, #blocked>) { -// // CHECK: scf.condition({{.*}}) %{{.*}}, %[[loopPtr]], %[[loopOffset]] -// scf.condition(%arg2) %arg1 : tensor<1024x!tt.ptr, #blocked> -// } do { -// // CHECK: ^bb{{.*}}(%{{.*}}, %[[blockPtr:.*]]: !tt.ptr, %[[blockOffset:.*]]: tensor<1024xi64, #blocked>): -// ^bb0(%arg1: tensor<1024x!tt.ptr, #blocked>): -// // CHECK: scf.yield {{.*}}, %[[blockPtr]], %[[blockOffset]] -// scf.yield %arg1, %cond : tensor<1024x!tt.ptr, #blocked>, i1 -// } -// // CHECK: %[[trunc_offset:.*]] = arith.trunci %[[whileOut]]#2 -// // CHECK: %[[base_ptr:.*]] = tt.splat %[[whileOut]]#1 -// // CHECK: %[[newPtr:.*]] = tt.addptr %[[base_ptr]], %[[trunc_offset]] -// // CHECK: tt.load %[[newPtr]] -// %11 = tt.load %6 : tensor<1024x!tt.ptr, #blocked> -// tt.return %11 : tensor<1024xf32, #blocked> -// } -//} -// + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: tt.func @whileOp + tt.func @whileOp(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>, %cond : i1)-> tensor<1024xf32, #blocked>{ + %c1024_i32 = arith.constant 1024 : i32 + %c0 = arith.constant 0: index + %c128 = arith.constant 128: index + %c1 = arith.constant 1 : index + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + // CHECK: %[[base_offset:.*]] = tt.splat %{{.*}} : i64 + %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + %5 = tt.splat %arg0 : !tt.ptr -> 
tensor<1024x!tt.ptr, #blocked> + // CHECK: %[[whileOut:.*]]:3 = scf.while ({{.*}}, %[[loopPtr:.*]] = %arg0, %[[loopOffset:.*]] = %[[base_offset]]) + %6 = scf.while (%arg1 = %5) : (tensor<1024x!tt.ptr, #blocked>) -> (tensor<1024x!tt.ptr, #blocked>) { + // CHECK: scf.condition({{.*}}) %{{.*}}, %[[loopPtr]], %[[loopOffset]] + scf.condition(%cond) %arg1 : tensor<1024x!tt.ptr, #blocked> + } do { + // CHECK: ^bb{{.*}}(%{{.*}}, %[[blockPtr:.*]]: !tt.ptr, %[[blockOffset:.*]]: tensor<1024xi64, #blocked>): + ^bb0(%arg1: tensor<1024x!tt.ptr, #blocked>): + // CHECK: scf.yield {{.*}}, %[[blockPtr]], %[[blockOffset]] + scf.yield %arg1 : tensor<1024x!tt.ptr, #blocked> + } + // CHECK: %[[trunc_offset:.*]] = arith.trunci %[[whileOut]]#2 + // CHECK: %[[base_ptr:.*]] = tt.splat %[[whileOut]]#1 + // CHECK: %[[newPtr:.*]] = tt.addptr %[[base_ptr]], %[[trunc_offset]] + // CHECK: tt.load %[[newPtr]] + %11 = tt.load %6 : tensor<1024x!tt.ptr, #blocked> + tt.return %11 : tensor<1024xf32, #blocked> + } +} + // ----- //#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp index 2ad8ca7de944..d325e438c902 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp @@ -1329,11 +1329,6 @@ class ConvertSCFYieldOp : public PointerCanonPattern { matchAndRewrite(scf::YieldOp yieldOp, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { SmallVector newYieldedValues = flattenValues(adaptor.getOperands()); - - // Value tensorPtr = - // createTensorPointer(rewriter, fatPtr.basePtr, fatPtr.offset, curLoc, - // fatPtr.canNarrow, fatPtr.attributes); - rewriter.modifyOpInPlace(yieldOp, [&]() { yieldOp.getResultsMutable().clear(); yieldOp.getResultsMutable().append(newYieldedValues); @@ -1400,18 +1395,10 @@ class ConvertSCFForOp : public PointerCanonPattern { LogicalResult matchAndRewrite(scf::ForOp forOp, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - SmallVector> fatPtrInits; SmallVector valRangeLens; ArrayRef remappedInits = adaptor.getInitArgs(); - for (ValueRange remappedInit : remappedInits) { - if (remappedInit.size() == 2) { - Value fatPtrBase = remappedInit[0]; - Value fatPtrOffset = remappedInit[1]; - fatPtrInits.emplace_back(fatPtrBase, fatPtrOffset); - } + for (ValueRange remappedInit : remappedInits) valRangeLens.push_back(remappedInit.size()); - } - TypeConverter hackTypeConverter; unsigned inputNo = 0; hackTypeConverter.addConversion( @@ -1436,7 +1423,6 @@ class ConvertSCFForOp : public PointerCanonPattern { forOp.getLoc(), getSingleValue(adaptor.getLowerBound()), getSingleValue(adaptor.getUpperBound()), getSingleValue(adaptor.getStep()), initArgs); - // replaceWithAdditionalYields newForOp->setAttrs(forOp->getAttrs()); rewriter.eraseBlock(newForOp.getBody(0)); @@ -1504,6 +1490,88 @@ class ConvertSCFIfOp : public PointerCanonPattern { } }; +class ConvertSCFWhileOp : public PointerCanonPattern { +public: + using PointerCanonPattern::PointerCanonPattern; + LogicalResult + matchAndRewrite(scf::WhileOp whileOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector valRangeLens; + ArrayRef remappedInits = adaptor.getInits(); + for (ValueRange remappedInit : remappedInits) + valRangeLens.push_back(remappedInit.size()); + TypeConverter 
hackTypeConverter; + unsigned inputNo = 0; + hackTypeConverter.addConversion( + [&inputNo, &remappedInits = std::as_const(remappedInits)]( + Type inputType, SmallVectorImpl &types) { + SmallVector remappedInitTypes = + llvm::to_vector(remappedInits[inputNo].getTypes()); + types.append(remappedInitTypes); + inputNo++; + return success(); + }); + if (failed(rewriter.convertRegionTypes(&whileOp.getBefore(), + hackTypeConverter))) + return failure(); + if (failed(rewriter.convertRegionTypes(&whileOp.getAfter(), + hackTypeConverter))) + return failure(); + + SmallVector initArgs = flattenValues(remappedInits); + SmallVector resultTypes = + llvm::map_to_vector(initArgs, [](Value v) { return v.getType(); }); + auto newWhileOp = + rewriter.create(whileOp.getLoc(), resultTypes, initArgs); + + newWhileOp->setAttrs(whileOp->getAttrs()); + rewriter.inlineRegionBefore(whileOp.getBefore(), newWhileOp.getBefore(), + newWhileOp.getBefore().end()); + rewriter.inlineRegionBefore(whileOp.getAfter(), newWhileOp.getAfter(), + newWhileOp.getAfter().end()); + + SmallVector packedRets; + for (unsigned i = 0, offset = 0; i < valRangeLens.size(); i++) { + size_t len = valRangeLens[i]; + assert(offset < newWhileOp->getNumResults() && + "expected offset to be within bounds of results"); + ValueRange mappedValue = newWhileOp->getResults().slice(offset, len); + packedRets.push_back(mappedValue); + offset += len; + } + + rewriter.modifyOpInPlace(whileOp, [&] { + whileOp->setDiscardableAttr("rewritten", rewriter.getUnitAttr()); + }); + rewriter.modifyOpInPlace(newWhileOp, [&] { + newWhileOp->setDiscardableAttr("legal", rewriter.getUnitAttr()); + }); + rewriter.replaceOpWithMultiple(whileOp, packedRets); + + return success(); + } +}; + +class ConvertSCFConditionOp : public PointerCanonPattern { +public: + using PointerCanonPattern::PointerCanonPattern; + LogicalResult + matchAndRewrite(scf::ConditionOp condOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector newArgs = flattenValues(adaptor.getArgs()); + rewriter.modifyOpInPlace(condOp, [&]() { + condOp.getArgsMutable().clear(); + condOp.getArgsMutable().append(newArgs); + }); + + rewriter.modifyOpInPlace(condOp, [&] { + condOp->setDiscardableAttr("legal", rewriter.getUnitAttr()); + }); + + return success(); + } +}; + void TritonAMDGPUCanonicalizePointersPass::runOnOperation() { ModuleOp module = getOperation(); mlir::MLIRContext *context = &getContext(); @@ -1531,14 +1599,17 @@ void TritonAMDGPUCanonicalizePointersPass::runOnOperation() { if (auto ifOp = llvm::dyn_cast(op)) return !(ifOp->hasAttr("then_rewritten") and ifOp->hasAttr("else_rewritten")); + if (llvm::isa(op) && !op->hasAttr("legal")) + return false; return isLegal(op); }); FatPointers fatPrs; patterns.add( - patterns.getContext(), fatPrs); + ConvertSCFForOp, ConvertSCFYieldOp, ConvertSCFIfOp, + ConvertSCFConditionOp, ConvertSCFWhileOp>(patterns.getContext(), + fatPrs); ConversionConfig config; config.buildMaterializations = false; if (failed( From 3e1dd919e65b9c4450a705b0786db38d10587ade Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Mon, 16 Dec 2024 17:41:48 -0500 Subject: [PATCH 05/17] handle cf.cond_br and cf.br --- .../amd/amd-canonicalize-pointers.mlir | 683 +++++------------- .../CanonicalizePointers.cpp | 116 ++- 2 files changed, 306 insertions(+), 493 deletions(-) diff --git a/test/TritonGPU/amd/amd-canonicalize-pointers.mlir b/test/TritonGPU/amd/amd-canonicalize-pointers.mlir index a7fd50cbf051..f2cbd4e93128 100644 --- 
a/test/TritonGPU/amd/amd-canonicalize-pointers.mlir +++ b/test/TritonGPU/amd/amd-canonicalize-pointers.mlir @@ -73,87 +73,8 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} #blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { - // CHECK-LABEL: tt.func @conversion3 - tt.func @conversion3(%arg0: !tt.ptr)-> tensor<1024xf32, #blocked>{ - %c1024_i32 = arith.constant 1024 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c1024_i32 : i32 - %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> - %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> - %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> - - //CHECK: %0 = tt.get_program_id x : i32 - //CHECK: %[[pid:.*]] = arith.muli %0, {{.*}} : i32 - //CHECK: %[[makerange:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> - //CHECK: %[[uniformOffset1:.*]] = arith.addi %[[pid]], {{.*}} : i32 - //CHECK: %[[tensorOffset1:.*]] = arith.addi %{{.*}}, %[[makerange]] : tensor<1024xi32, #blocked> - //CHECK: %[[uniformOffset0:.*]] = arith.addi %[[pid:.*]], %{{.*}} : i32 - //CHECK: %[[tensorOffset3:.*]] = arith.addi %{{.*}}, %[[makerange]] : tensor<1024xi32, #blocked> - //CHECK: %[[zero:.*]] = tt.splat %{{.*}} : i64 -> tensor<1024xi64, #blocked> - //CHECK: %[[uniformPtr0:.*]] = tt.addptr %arg0, %[[uniformOffset0:.*]] : !tt.ptr, i32 - //CHECK: %[[tensorOffset3ext:.*]] = arith.extsi %[[tensorOffset3]] : tensor<1024xi32, #blocked> to tensor<1024xi64, #blocked> - //CHECK: %[[tensorOffset0:.*]]= arith.addi %[[tensorOffset3ext]], %[[zero]] : tensor<1024xi64, #blocked> - //CHECK: %[[uniformPtr1:.*]] = tt.addptr %[[uniformPtr0]], %[[uniformOffset1]] : !tt.ptr, i32 - //CHECK: %[[tensorOffset1ext:.*]] = arith.extsi %[[tensorOffset1]] : tensor<1024xi32, #blocked> to tensor<1024xi64, #blocked> - //CHECK: %[[tensorOffset2:.*]] = arith.addi %[[tensorOffset1ext]], %[[tensorOffset0]]: tensor<1024xi64, #blocked> - //CHECK: %[[scalarPtr:.*]] = tt.splat %[[uniformPtr1]] : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> - //CHECK: %[[newPtr:.*]] = tt.addptr %[[scalarPtr]], %[[tensorOffset2]] : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi64, #blocked> - //CHECK: tt.load %[[newPtr]] - - %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> - %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> - %7 = tt.addptr %6, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> - %8 = tt.load %7 : tensor<1024x!tt.ptr, #blocked> - tt.return %8 : tensor<1024xf32, #blocked> - } -} - -// ----- - -#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { - // - // This is the same as conversion3, but now the `arith.extsi` operations - // disappeared and all the offsets are 32 bits. 
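The comment just above states the property behind the 32-bit offset path: when the base pointer carries tt.pointer_range = 32, every accumulated offset fits in i32, so widening to i64 and truncating back is a no-op and the arith.extsi ops can be dropped entirely (presumably what the canNarrow flag in CanonicalizePointers.cpp tracks). A quick, illustrative round-trip check with made-up values, not taken from the test and not part of the patch:

import numpy as np

# Offsets that stay inside a 32-bit pointer range; the values are arbitrary.
offsets_i32 = np.arange(0, 1 << 20, 7, dtype=np.int32)

# The wide path: arith.extsi to i64 followed by arith.trunci back to i32.
round_tripped = offsets_i32.astype(np.int64).astype(np.int32)

# Within the 32-bit range the round trip is the identity, which is why the
# 32-bit variant of this test expects no extsi at all.
assert np.array_equal(round_tripped, offsets_i32)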
- // - // CHECK-LABEL: tt.func @conversion4 - tt.func @conversion4(%arg0: !tt.ptr{tt.pointer_range = 32 : i32})-> tensor<1024xf32, #blocked>{ - %c1024_i32 = arith.constant 1024 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c1024_i32 : i32 - %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> - %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> - %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> - - //CHECK: %0 = tt.get_program_id x : i32 - //CHECK: %[[pid:.*]] = arith.muli %0, {{.*}} : i32 - //CHECK: %[[makerange:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> - //CHECK: %[[uniformOffset1:.*]] = arith.addi %[[pid]], {{.*}} : i32 - //CHECK: %[[tensorOffset1:.*]] = arith.addi %{{.*}}, %[[makerange]] : tensor<1024xi32, #blocked> - //CHECK: %[[uniformOffset0:.*]] = arith.addi %[[pid:.*]], %{{.*}} : i32 - //CHECK: %[[tensorOffset3:.*]] = arith.addi %{{.*}}, %[[makerange]] : tensor<1024xi32, #blocked> - //CHECK: %[[zero:.*]] = tt.splat %{{.*}} : i32 -> tensor<1024xi32, #blocked> - //CHECK: %[[uniformPtr0:.*]] = tt.addptr %arg0, %[[uniformOffset0:.*]] : !tt.ptr, i32 - //CHECK: %[[tensorOffset0:.*]]= arith.addi %[[tensorOffset3]], %[[zero]] : tensor<1024xi32, #blocked> - //CHECK: %[[uniformPtr1:.*]] = tt.addptr %[[uniformPtr0]], %[[uniformOffset1]] : !tt.ptr, i32 - //CHECK: %[[tensorOffset2:.*]] = arith.addi %[[tensorOffset1]], %[[tensorOffset0]]: tensor<1024xi32, #blocked> - //CHECK: %[[scalarPtr:.*]] = tt.splat %[[uniformPtr1]] : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> - //CHECK: %[[newPtr:.*]] = tt.addptr %[[scalarPtr]], %[[tensorOffset2]] : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> - //CHECK: tt.load %[[newPtr]] - %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> - %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> - %7 = tt.addptr %6, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> - %8 = tt.load %7 : tensor<1024x!tt.ptr, #blocked> - tt.return %8 : tensor<1024xf32, #blocked> - } -} - -// ----- - -#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { - // CHECK-LABEL: tt.func @forOp - tt.func @forOp(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>)-> tensor<1024xf32, #blocked>{ + // CHECK-LABEL: tt.func @condBranch + tt.func @condBranch(%arg0 : !tt.ptr, %i1 : i1) -> tensor<1024xf32, #blocked>{ %c1024_i32 = arith.constant 1024 : i32 %c0 = arith.constant 0: index %c128 = arith.constant 128: index @@ -161,82 +82,34 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} %0 = tt.get_program_id x : i32 %1 = arith.muli %0, %c1024_i32 : i32 %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> - // CHECK: %[[scalarOffsetLoop:.*]] = arith.addi {{.*}}, {{.*}} : i32 - // CHECK: %[[variableOffset1:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor - // CHECK: %[[scalarOffset1:.*]] = arith.addi %{{.*}}, %{{.*}} : i32 - // CHECK: %[[scalarOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : i32 - // CHECK: %[[variableOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor - // CHECK: %[[scalarPtrUpdate:.*]] = tt.addptr %arg0, %[[scalarOffset]] + // CHECK: %[[scalarOffset:.*]] = arith.addi {{.*}}, {{.*}} : i32 + // CHECK: %[[variableOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor<1024xi32, #blocked> + // CHECK: %[[base_offset:.*]] = tt.splat %{{.*}} : i64 + // CHECK: 
%[[scalarPtr:.*]] = tt.addptr %arg0, %[[scalarOffset]] // CHECK: %[[ext_offset0:.*]] = arith.extsi %[[variableOffset]] - // CHECK: %[[offset1:.*]] = arith.addi %[[ext_offset0]], %{{.*}} : tensor<1024xi64, #blocked> %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + // CHECK: %[[offset1:.*]] = arith.addi %[[ext_offset0]], %[[base_offset]] %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> - // CHECK: %[[loop:.*]]:4 = scf.for {{.*}} iter_args({{.*}}, {{.*}}, %[[loopScalarPtr:.*]] = %{{.*}}, %[[loopOffset:.*]] = %[[offset1]]) -> {{.*}} { - %52:2 = scf.for %arg9 = %c0 to %c128 step %c1 iter_args(%arg1 = %6, %arg2 = %init) -> (tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked>){ - // CHECK: %[[scalarPtrUpdateLoop:.*]] = tt.addptr %[[loopScalarPtr]], %[[scalarOffsetLoop]] - // CHECK: %[[ext_offset0i:.*]] = arith.extsi %[[variableOffset1]] - // CHECK: %[[offset_i:.*]] = arith.addi %[[ext_offset0i]], %[[loopOffset]] - // CHECK: %[[base_ptr:.*]] = tt.splat %[[scalarPtrUpdateLoop]] - // CHECK: %[[newPtr:.*]] = tt.addptr %[[base_ptr]], %[[offset_i]] - // CHECK: tt.load %[[newPtr]] - // CHECK: scf.yield {{.*}}, {{.*}}, %[[scalarPtrUpdateLoop]], %[[offset_i]] - %11 = tt.addptr %arg1, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> - %9 = tt.load %11 : tensor<1024x!tt.ptr, #blocked> - %10 = arith.addf %9, %arg2 : tensor<1024xf32, #blocked> - scf.yield %11, %10 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked> - } - // CHECK: tt.addptr %[[loop]]#2, %[[scalarOffset1]] : !tt.ptr, i32 - %8 = tt.addptr %52#0, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> - %11 = tt.load %8 : tensor<1024x!tt.ptr, #blocked> - tt.return %11 : tensor<1024xf32, #blocked> - } -} - -// ----- - -#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { - // CHECK-LABEL: tt.func @forOp2 - tt.func @forOp2(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>)-> tensor<1024xf32, #blocked>{ - %c1024_i32 = arith.constant 1024 : i32 - %c0 = arith.constant 0: index - %c128 = arith.constant 128: index - %c1 = arith.constant 1 : index - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c1024_i32 : i32 - %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> - // CHECK: %[[scalarOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : i32 - // CHECK: %[[variableOffset0:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor<1024xi32, #blocked> - // CHECK: %[[finalScalarOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : i32 - // CHECK: %[[variableOffset1:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor<1024xi32, #blocked> - // CHECK: %[[base_offset:.*]] = tt.splat {{.*}} : i64 - %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> - %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> - %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> - // CHECK: %[[forOut:.*]]:4 = scf.for {{.*}} iter_args(%{{.*}}, {{.*}}, %[[scalarPtr:.*]] = %arg0, %[[loopOffset:.*]] = %[[base_offset]]) - %52:2 = scf.for %arg9 = %c0 to %c128 step %c1 iter_args(%arg1 = %5, %arg2 = %init) -> (tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked>){ - // CHECK: %[[scalarPtrUpdate:.*]] = tt.addptr %[[scalarPtr]], %[[scalarOffset]] - // CHECK: %[[ext_offset0i:.*]] = arith.extsi %[[variableOffset0]] - // CHECK: %[[ext_offset_i:.*]] = arith.addi 
%[[ext_offset0i]], %[[loopOffset]] - // CHECK: %[[base_ptr:.*]] = tt.splat %[[scalarPtrUpdate]] - // CHECK: %[[newPtr:.*]] = tt.addptr %[[base_ptr]], %[[ext_offset_i]] - // CHECK: tt.load %[[newPtr]] - %11 = tt.addptr %arg1, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> - %9 = tt.load %11 : tensor<1024x!tt.ptr, #blocked> - %10 = arith.addf %9, %arg2 : tensor<1024xf32, #blocked> - scf.yield %11, %10 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked> - } - // CHECK: %[[scalarPtrFinalUpdate:.*]] = tt.addptr %[[forOut]]#2, %[[finalScalarOffset]] - // CHECK: %[[ext_offset0:.*]] = arith.extsi %[[variableOffset1]] - // CHECK: %[[tailOffset:.*]] = arith.addi %[[ext_offset0]], %[[forOut]]#3 - // CHECK: %[[tail_base_ptr:.*]] = tt.splat %[[scalarPtrFinalUpdate]] - // CHECK: %[[tailPtr:.*]] = tt.addptr %[[tail_base_ptr]], %[[tailOffset]] - // CHECK: tt.load %[[tailPtr]] - %8 = tt.addptr %52#0, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> - %11 = tt.load %8 : tensor<1024x!tt.ptr, #blocked> - tt.return %11 : tensor<1024xf32, #blocked> + // CHECK: cf.cond_br {{.*}}, ^bb1(%{{.*}}, %arg0, %[[base_offset]] : {{.*}}), ^bb2(%{{.*}}, %[[scalarPtr]], %[[offset1]] : {{.*}}) + cf.cond_br %i1, ^bb1(%5 : tensor<1024x!tt.ptr, #blocked>), ^bb2(%6 : tensor<1024x!tt.ptr, #blocked>) + // CHECK: ^bb1({{.*}}, %[[block1ScalarPtr:.*]]: !tt.ptr, %[[block1Offset:.*]]: tensor<1024xi64, #blocked>) + ^bb1(%arg1 : tensor<1024x!tt.ptr, #blocked>): + // CHECK: %[[trunc_offset_1:.*]] = arith.trunci %[[block1Offset]] + // CHECK: %[[basePtr1:.*]] = tt.splat %[[block1ScalarPtr]] + // CHECK: %[[newPtr1:.*]] = tt.addptr %[[basePtr1]], %[[trunc_offset_1]] + // CHECK: tt.load %[[newPtr1]] + %out1 = tt.load %arg1 : tensor<1024x!tt.ptr, #blocked> + tt.return %out1 : tensor<1024xf32, #blocked> + // CHECK: ^bb2({{.*}}, %[[block2ScalarPtr:.*]]: !tt.ptr, %[[block2Offset:.*]]: tensor<1024xi64, #blocked>) + ^bb2(%arg2 : tensor<1024x!tt.ptr, #blocked>): // 2 preds: ^bb0, ^bb1 + // CHECK: %[[trunc_offset_2:.*]] = arith.trunci %[[block2Offset]] + // CHECK: %[[basePtr2:.*]] = tt.splat %[[block2ScalarPtr]] + // CHECK: %[[newPtr2:.*]] = tt.addptr %[[basePtr2]], %[[trunc_offset_2]] + // CHECK: tt.load %[[newPtr2]] + %out2 = tt.load %arg2 : tensor<1024x!tt.ptr, #blocked> + tt.return %out2 : tensor<1024xf32, #blocked> } } @@ -244,8 +117,8 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} #blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { - // CHECK-LABEL: tt.func @forNested - tt.func @forNested(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>)-> tensor<1024xf32, #blocked>{ + // CHECK-LABEL: tt.func @branch + tt.func @branch(%arg0 : !tt.ptr, %i1 : i1) -> tensor<1024xf32, #blocked>{ %c1024_i32 = arith.constant 1024 : i32 %c0 = arith.constant 0: index %c128 = arith.constant 128: index @@ -255,301 +128,133 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> // CHECK: %[[scalarOffset:.*]] = arith.addi {{.*}}, {{.*}} : i32 // CHECK: %[[variableOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor<1024xi32, #blocked> - // CHECK: %[[base_offset:.*]] = tt.splat {{.*}} : i64 + // CHECK: %[[base_offset:.*]] = tt.splat %{{.*}} : i64 + // CHECK: %[[scalarPtr:.*]] = tt.addptr %arg0, %[[scalarOffset]] + // CHECK: %[[ext_offset0:.*]] = arith.extsi 
%[[variableOffset]] %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> - - // CHECK: %[[forOut0:.*]]:4 = scf.for {{.*}} iter_args(%{{.*}}, {{.*}}, %[[scalarPtr0:.*]] = %arg0, %[[loopOffset0:.*]] = %[[base_offset]]){{.*}}{ - // CHECK: %[[forOut1:.*]]:4 = scf.for {{.*}} iter_args(%{{.*}}, {{.*}}, %[[scalarPtr1:.*]] = %[[scalarPtr0]], %[[loopOffset1:.*]] = %[[loopOffset0]]){{.*}}{ - // CHECK: %[[scalarPtrUpdate:.*]] = tt.addptr %[[scalarPtr1]], %{{.*}} - // CHECK: %[[ext_loop_offset1:.*]] = arith.extsi %[[variableOffset]] - // CHECK: %[[offset_i:.*]] = arith.addi %[[ext_loop_offset1]], %[[loopOffset1]] - // CHECK: %[[base_ptr:.*]] = tt.splat %[[scalarPtrUpdate]] - // CHECK: %[[newPtr:.*]] = tt.addptr %[[base_ptr]], %[[offset_i]] - // CHECK: tt.load %[[newPtr]] - // CHECK: scf.yield %{{.*}}, {{.*}}, %[[scalarPtrUpdate]], %[[offset_i]] - // CHECK: scf.yield %{{.*}}, {{.*}}, %[[forOut1]]#2, %[[forOut1]]#3 - - %52:2 = scf.for %arg9 = %c0 to %c128 step %c1 iter_args(%arg1 = %5, %arg2 = %init) -> (tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked>){ - %53:2 = scf.for %arg10 = %c0 to %c128 step %c1 iter_args(%arg3 = %arg1, %arg4 = %arg2) -> (tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked>){ - %11 = tt.addptr %arg3, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> - %9 = tt.load %11 : tensor<1024x!tt.ptr, #blocked> - %10 = arith.addf %9, %arg4 : tensor<1024xf32, #blocked> - scf.yield %11, %10 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked> - } - scf.yield %53#0, %53#1: tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked> - } - %8 = tt.addptr %52#0, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> - %11 = tt.load %8 : tensor<1024x!tt.ptr, #blocked> - tt.return %11 : tensor<1024xf32, #blocked> + // CHECK: %[[offset1:.*]] = arith.addi %[[ext_offset0]], %[[base_offset]] + %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + // CHECK: cf.br ^bb1(%{{.*}}, %[[scalarPtr]], %[[offset1]] : {{.*}}) + // CHECK: ^bb1({{.*}}, %[[block1ScalarPtr:.*]]: {{.*}}, %[[block1Offset:.*]]: {{.*}}) + cf.br ^bb1(%6 : tensor<1024x!tt.ptr, #blocked>) + ^bb1(%arg1 : tensor<1024x!tt.ptr, #blocked>): + // CHECK: %[[trunc_offset_1:.*]] = arith.trunci %[[block1Offset]] + // CHECK: %[[basePtr1:.*]] = tt.splat %[[block1ScalarPtr]] + // CHECK: %[[newPtr1:.*]] = tt.addptr %[[basePtr1]], %[[trunc_offset_1]] + // CHECK: tt.load %[[newPtr1]] + %out1 = tt.load %arg1 : tensor<1024x!tt.ptr, #blocked> + tt.return %out1 : tensor<1024xf32, #blocked> } } // ----- -#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { - // CHECK-LABEL: tt.func @ifOp - tt.func @ifOp(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>, %cond : i1)-> tensor<1024xf32, #blocked>{ - %c1024_i32 = arith.constant 1024 : i32 - %c0 = arith.constant 0: index - %c128 = arith.constant 128: index - %c1 = arith.constant 1 : index - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c1024_i32 : i32 - %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> - // CHECK: %[[scalarOffset:.*]] = arith.addi {{.*}}, {{.*}} : i32 - // CHECK: %[[variableOffset:.*]] = arith.addi - %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> - %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> - %5 = tt.splat %arg0 : !tt.ptr 
-> tensor<1024x!tt.ptr, #blocked> - // CHECK: %[[baseOffsetVariable:.*]] = tt.splat {{.*}} : i64 -> tensor<1024xi64, #blocked> - // CHECK: %[[ifOut:.*]]:3 = scf.if {{.*}} -> (tensor<1024x!tt.ptr, #blocked>, !tt.ptr, tensor<1024xi64, #blocked>) - %6 = scf.if %cond -> (tensor<1024x!tt.ptr, #blocked>){ - // CHECK: %[[scalarOffsetUpdate:.*]] = tt.addptr %arg0, %[[scalarOffset]] - // CHECK: %[[ext_offset0:.*]] = arith.extsi %[[variableOffset]] - // CHECK: %[[if_offset:.*]] = arith.addi %[[ext_offset0]], %[[baseOffsetVariable]] - // CHECK: scf.yield %{{.*}}, %[[scalarOffsetUpdate]], %[[if_offset]] - %true = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> - scf.yield %true : tensor<1024x!tt.ptr, #blocked> - } else { - // CHECK: %[[new_scalar_ptr:.*]] = tt.addptr %arg0, {{.*}} - // CHECK: scf.yield %{{.*}}, %[[new_scalar_ptr]], %[[baseOffsetVariable]] - %false = tt.addptr %5, %3 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> - scf.yield %false : tensor<1024x!tt.ptr, #blocked> - } - // CHECK: %[[trunc_offset:.*]] = arith.trunci %[[ifOut]]#2 - // CHECK: %[[base_ptr:.*]] = tt.splat %[[ifOut]]#1 - // CHECK: %[[newPtr:.*]] = tt.addptr %[[base_ptr]], %[[trunc_offset]] - // CHECK: tt.load %[[newPtr]] - %11 = tt.load %6 : tensor<1024x!tt.ptr, #blocked> - tt.return %11 : tensor<1024xf32, #blocked> +// The following is a simple case of a tile offset like: (A*B + C + D) where B,C are Uniform and A,D are not. So +// we expect that the Uniform offset (which can be added to the scalar pointer) will be simply C and the NonUniform +// offset will be A*B+D +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: tt.func @tile_offset + tt.func @tile_offset(%arg1: !tt.ptr, %arg5: i32 , %arg7: i32 ) { + %c128_i32 = arith.constant 128 : i32 + %c256_i32 = arith.constant 256 : i32 + %1 = tt.get_program_id x : i32 + %20 = arith.muli %1, %c256_i32 : i32 + %22 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %24 = tt.splat %20 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %26 = arith.addi %24, %22 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %36 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %37 = tt.expand_dims %36 {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi32, #blocked> + %38 = tt.splat %arg7 : i32 -> tensor<16x1xi32, #blocked> + %39 = arith.muli %37, %38 : tensor<16x1xi32, #blocked> + %41 = tt.expand_dims %26 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> + %42 = tt.broadcast %39 : tensor<16x1xi32, #blocked> -> tensor<16x256xi32, #blocked> + %43 = tt.broadcast %41 : tensor<1x256xi32, #blocked> -> tensor<16x256xi32, #blocked> + %44 = arith.addi %42, %43 : tensor<16x256xi32, #blocked> + %45 = tt.splat %arg1 : !tt.ptr -> tensor<16x256x!tt.ptr, #blocked> + %46 = tt.addptr %45, %44 : tensor<16x256x!tt.ptr, #blocked>, tensor<16x256xi32, #blocked> + // CHECK: %[[uniformOffset1:.*]] = arith.muli %c0_i32_0, %arg2 : i32 + // CHECK: {{.*}} = tt.expand_dims %{{.*}} {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> + // CHECK: %[[tensorOffset6:.*]] = tt.expand_dims %{{.*}} 
{axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> + // CHECK: {{.*}} = tt.broadcast %{{.*}} : tensor<16x1xi32, #blocked> -> tensor<16x256xi32, #blocked> + // CHECK: %[[tensorOffset3:.*]] = tt.broadcast %{{.*}} : tensor<16x1xi32, #blocked> -> tensor<16x256xi32, #blocked> + // CHECK: %[[tensorOffset4:.*]] = tt.broadcast %{{.*}} : tensor<1x256xi32, #blocked> -> tensor<16x256xi32, #blocked> + // CHECK: %[[tensorOffset5:.*]] = tt.broadcast %[[tensorOffset6]] : tensor<1x256xi32, #blocked> -> tensor<16x256xi32, #blocked> + // CHECK: %[[uniformOffset:.*]] = arith.addi %[[uniformOffset1]], %{{.*}}: i32 + // CHECK: %[[tensorOffset2:.*]] = arith.addi %[[tensorOffset3]], %[[tensorOffset5]] : tensor<16x256xi32, #blocked> + // CHECK: %[[scalarPtr:.*]] = tt.addptr %arg0, %[[uniformOffset]] : !tt.ptr, i32 + // CHECK: %[[tensorOffset2ext:.*]] = arith.extsi %[[tensorOffset2]] : tensor<16x256xi32, #blocked> to tensor<16x256xi64, #blocked> + // CHECK: %[[tensorOffset1:.*]] = arith.addi %[[tensorOffset2ext]], %{{.*}} : tensor<16x256xi64, #blocked> + // CHECK: %[[tensorOffset:.*]] = arith.trunci %[[tensorOffset1:.*]] : tensor<16x256xi64, #blocked> to tensor<16x256xi32, #blocked> + // CHECK: %[[ptr:.*]] = tt.splat %[[scalarPtr]] : !tt.ptr -> tensor<16x256x!tt.ptr, #blocked> + // CHECK: tt.addptr %[[ptr]], %[[tensorOffset]] : tensor<16x256x!tt.ptr, #block + %61 = tt.load %46 : tensor<16x256x!tt.ptr, #blocked> + tt.return } } // ----- -#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { - // CHECK-LABEL: tt.func @whileOp - tt.func @whileOp(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>, %cond : i1)-> tensor<1024xf32, #blocked>{ - %c1024_i32 = arith.constant 1024 : i32 - %c0 = arith.constant 0: index - %c128 = arith.constant 128: index - %c1 = arith.constant 1 : index +// The following is a more complex case where also a multiplication is involved. It's useful to walk through the case. 
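Before the line-by-line walk-through that follows, a small numeric check of the split it arrives at may help. The sketch below is illustrative only: the values and the 4-element ranges are made up and it is not part of the patch or the pass; it just verifies the algebra (U + N)*U' + N' == U*U' + (N*U' + N'), with U standing in for %1, U' for %arg1, N for %2, and N' for %8.

import numpy as np

U = 128                # stands in for %1 = pid * c128_i32 (uniform)
Uprime = 64            # stands in for %arg1, the row stride (uniform)
N = np.arange(4)       # stands in for %2 = tt.make_range (non-uniform)
Nprime = np.arange(4)  # stands in for %8 = tt.make_range (non-uniform)

full_offset   = (U + N) * Uprime + Nprime  # offset fed to the original tt.addptr
scalar_offset = U * Uprime                 # uniform part, folded into the scalar base pointer
tensor_offset = N * Uprime + Nprime        # non-uniform part, kept as the tensor offset

assert np.array_equal(full_offset, scalar_offset + tensor_offset)

This is essentially the equivalence the CHECK lines of matmul_kernel encode: a scalar tt.addptr applied to %arg0 with the uniform offset, followed by a tensor tt.addptr with the remaining per-lane offset.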
+// We have that the offset to the pointer is the following: +// %12 = %10 + 11 +// This can be transformed in: +// = %7 + %9 +// = %5*%6 + %8 +// = %4*%arg1 + %8 +// = (%3+%2)*%arg1 + %8 +// = (%1 + %2) * %arg1 + %8 +// = (U + N)*U + N +// Where U means uniform (e.g., a splat) and N means NonUniform (e.g., a make_range) +// The scalar offset we want is (%1*%arg1), while the variable offset should be (%2*%arg1 + %8) +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: tt.func public @matmul_kernel + tt.func public @matmul_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}) { + %c128_i32 = arith.constant 128 : i32 %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c1024_i32 : i32 - %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> - // CHECK: %[[base_offset:.*]] = tt.splat %{{.*}} : i64 - %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> - %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> - %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> - // CHECK: %[[whileOut:.*]]:3 = scf.while ({{.*}}, %[[loopPtr:.*]] = %arg0, %[[loopOffset:.*]] = %[[base_offset]]) - %6 = scf.while (%arg1 = %5) : (tensor<1024x!tt.ptr, #blocked>) -> (tensor<1024x!tt.ptr, #blocked>) { - // CHECK: scf.condition({{.*}}) %{{.*}}, %[[loopPtr]], %[[loopOffset]] - scf.condition(%cond) %arg1 : tensor<1024x!tt.ptr, #blocked> - } do { - // CHECK: ^bb{{.*}}(%{{.*}}, %[[blockPtr:.*]]: !tt.ptr, %[[blockOffset:.*]]: tensor<1024xi64, #blocked>): - ^bb0(%arg1: tensor<1024x!tt.ptr, #blocked>): - // CHECK: scf.yield {{.*}}, %[[blockPtr]], %[[blockOffset]] - scf.yield %arg1 : tensor<1024x!tt.ptr, #blocked> - } - // CHECK: %[[trunc_offset:.*]] = arith.trunci %[[whileOut]]#2 - // CHECK: %[[base_ptr:.*]] = tt.splat %[[whileOut]]#1 - // CHECK: %[[newPtr:.*]] = tt.addptr %[[base_ptr]], %[[trunc_offset]] - // CHECK: tt.load %[[newPtr]] - %11 = tt.load %6 : tensor<1024x!tt.ptr, #blocked> - tt.return %11 : tensor<1024xf32, #blocked> + %1 = arith.muli %0, %c128_i32 : i32 + %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %3 = tt.splat %1 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %4 = arith.addi %3, %2 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %5 = tt.expand_dims %4 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> + %6 = tt.splat %arg1 : i32 -> tensor<128x1xi32, #blocked> + %7 = arith.muli %5, %6 : tensor<128x1xi32, #blocked> + %8 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %9 = tt.expand_dims %8 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> + %10 = tt.broadcast %7 : tensor<128x1xi32, #blocked> -> tensor<128x16xi32, #blocked> + %11 = tt.broadcast %9 : tensor<1x16xi32, #blocked> -> tensor<128x16xi32, #blocked> + %12 = arith.addi %10, %11 : tensor<128x16xi32, #blocked> + %13 = tt.splat %arg0 : !tt.ptr -> tensor<128x16x!tt.ptr, #blocked> + %14 = tt.addptr %13, %12 : tensor<128x16x!tt.ptr, #blocked>, tensor<128x16xi32, #blocked> + %15 = tt.load %14 : tensor<128x16x!tt.ptr, #blocked> + // CHECK: %[[pid:.*]] = tt.get_program_id x : i32 + // CHECK: %[[uniformOffset3:.*]] = 
arith.muli %[[pid]], %{{.*}} : i32 + // CHECK: %[[uniformOffset2:.*]] = arith.addi %[[uniformOffset3]], %{{.*}} : i32 + // CHECK: %[[uniformOffset1:.*]] = arith.muli %[[uniformOffset2]], %arg1 : i32 + // CHECK: %[[makerange:.*]] = tt.make_range + // CHECK: %{{.*}} = tt.expand_dims %[[makerange]] {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> + // CHECK: %[[tensorOffset6:.*]] = tt.expand_dims %[[makerange]] {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> + // CHECK: %{{.*}} = tt.broadcast %{{.*}} : tensor<128x1xi32, #blocked> -> tensor<128x16xi32, #blocked> + // CHECK: %[[tensorOffset3:.*]] = tt.broadcast %{{.*}} : tensor<128x1xi32, #blocked> -> tensor<128x16xi32, #blocked> + // CHECK: %{{.*}} = tt.broadcast %{{.*}} : tensor<1x16xi32, #blocked> -> tensor<128x16xi32, #blocked> + // CHECK: %[[tensorOffset4:.*]] = tt.broadcast %[[tensorOffset6]] : tensor<1x16xi32, #blocked> -> tensor<128x16xi32, #blocked> + // CHECK: %{{.*}} = arith.addi %{{.*}}, %{{.*}} : tensor<128x16xi32, #blocked> + // CHECK: %[[uniformOffset:.*]] = arith.addi %[[uniformOffset1]], %{{.*}} : i32 + // CHECK: %[[tensorOffset2:.*]] = arith.addi %[[tensorOffset3]], %[[tensorOffset4]] : tensor<128x16xi32, #blocked> + // CHECK: %[[scalarPtr:.*]] = tt.addptr %arg0, %[[uniformOffset]] : !tt.ptr, i32 + // CHECK: %[[tensorOffset1Ext:.*]] = arith.extsi %[[tensorOffset2]] : tensor<128x16xi32, #blocked> to tensor<128x16xi64, #blocked> + // CHECK: %[[tensorOffset:.*]] = arith.addi %[[tensorOffset1Ext]], %{{.*}} : tensor<128x16xi64, #blocked> + // CHECK: %[[tensorOffsetTrunc:.*]] = arith.trunci %[[tensorOffset]] : tensor<128x16xi64, #blocked> to tensor<128x16xi32, #blocked> + // CHECK: %[[ptr:.*]] = tt.splat %[[scalarPtr]] : !tt.ptr -> tensor<128x16x!tt.ptr, #blocked> + // CHECK: tt.addptr %[[ptr]], %[[tensorOffsetTrunc]] : tensor<128x16x!tt.ptr, #blocked>, tensor<128x16xi32, #blocked> + tt.return } } -// ----- -//#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -//module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { -// // CHECK-LABEL: tt.func @condBranch -// tt.func @condBranch(%arg0 : !tt.ptr, %i1 : i1) -> tensor<1024xf32, #blocked>{ -// %c1024_i32 = arith.constant 1024 : i32 -// %c0 = arith.constant 0: index -// %c128 = arith.constant 128: index -// %c1 = arith.constant 1 : index -// %0 = tt.get_program_id x : i32 -// %1 = arith.muli %0, %c1024_i32 : i32 -// %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> -// // CHECK: %[[scalarOffset:.*]] = arith.addi {{.*}}, {{.*}} : i32 -// // CHECK: %[[variableOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor<1024xi32, #blocked> -// // CHECK: %[[base_offset:.*]] = tt.splat %{{.*}} : i64 -// // CHECK: %[[scalarPtr:.*]] = tt.addptr %arg0, %[[scalarOffset]] -// // CHECK: %[[ext_offset0:.*]] = arith.extsi %[[variableOffset]] -// %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> -// %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> -// %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> -// // CHECK: %[[offset1:.*]] = arith.addi %[[ext_offset0]], %[[base_offset]] -// %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> -// // CHECK: cf.cond_br {{.*}}, ^bb1(%{{.*}}, %arg0, %[[base_offset]] : {{.*}}), ^bb2(%{{.*}}, %[[scalarPtr]], %[[offset1]] : {{.*}}) -// cf.cond_br %i1, ^bb1(%5 : tensor<1024x!tt.ptr, #blocked>), ^bb2(%6 : 
tensor<1024x!tt.ptr, #blocked>) -// // CHECK: ^bb1({{.*}}, %[[block1ScalarPtr:.*]]: !tt.ptr, %[[block1Offset:.*]]: tensor<1024xi64, #blocked>) -// ^bb1(%arg1 : tensor<1024x!tt.ptr, #blocked>): -// // CHECK: %[[trunc_offset_1:.*]] = arith.trunci %[[block1Offset]] -// // CHECK: %[[basePtr1:.*]] = tt.splat %[[block1ScalarPtr]] -// // CHECK: %[[newPtr1:.*]] = tt.addptr %[[basePtr1]], %[[trunc_offset_1]] -// // CHECK: tt.load %[[newPtr1]] -// %out1 = tt.load %arg1 : tensor<1024x!tt.ptr, #blocked> -// tt.return %out1 : tensor<1024xf32, #blocked> -// // CHECK: ^bb2({{.*}}, %[[block2ScalarPtr:.*]]: !tt.ptr, %[[block2Offset:.*]]: tensor<1024xi64, #blocked>) -// ^bb2(%arg2 : tensor<1024x!tt.ptr, #blocked>): // 2 preds: ^bb0, ^bb1 -// // CHECK: %[[trunc_offset_2:.*]] = arith.trunci %[[block2Offset]] -// // CHECK: %[[basePtr2:.*]] = tt.splat %[[block2ScalarPtr]] -// // CHECK: %[[newPtr2:.*]] = tt.addptr %[[basePtr2]], %[[trunc_offset_2]] -// // CHECK: tt.load %[[newPtr2]] -// %out2 = tt.load %arg2 : tensor<1024x!tt.ptr, #blocked> -// tt.return %out2 : tensor<1024xf32, #blocked> -// } -//} -// -//// ----- -// -//#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -//module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { -// // CHECK-LABEL: tt.func @branch -// tt.func @branch(%arg0 : !tt.ptr, %i1 : i1) -> tensor<1024xf32, #blocked>{ -// %c1024_i32 = arith.constant 1024 : i32 -// %c0 = arith.constant 0: index -// %c128 = arith.constant 128: index -// %c1 = arith.constant 1 : index -// %0 = tt.get_program_id x : i32 -// %1 = arith.muli %0, %c1024_i32 : i32 -// %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> -// // CHECK: %[[scalarOffset:.*]] = arith.addi {{.*}}, {{.*}} : i32 -// // CHECK: %[[variableOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor<1024xi32, #blocked> -// // CHECK: %[[base_offset:.*]] = tt.splat %{{.*}} : i64 -// // CHECK: %[[scalarPtr:.*]] = tt.addptr %arg0, %[[scalarOffset]] -// // CHECK: %[[ext_offset0:.*]] = arith.extsi %[[variableOffset]] -// %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> -// %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> -// %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> -// // CHECK: %[[offset1:.*]] = arith.addi %[[ext_offset0]], %[[base_offset]] -// %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> -// // CHECK: cf.br ^bb1(%{{.*}}, %[[scalarPtr]], %[[offset1]] : {{.*}}) -// // CHECK: ^bb1({{.*}}, %[[block1ScalarPtr:.*]]: {{.*}}, %[[block1Offset:.*]]: {{.*}}) -// cf.br ^bb1(%6 : tensor<1024x!tt.ptr, #blocked>) -// ^bb1(%arg1 : tensor<1024x!tt.ptr, #blocked>): -// // CHECK: %[[trunc_offset_1:.*]] = arith.trunci %[[block1Offset]] -// // CHECK: %[[basePtr1:.*]] = tt.splat %[[block1ScalarPtr]] -// // CHECK: %[[newPtr1:.*]] = tt.addptr %[[basePtr1]], %[[trunc_offset_1]] -// // CHECK: tt.load %[[newPtr1]] -// %out1 = tt.load %arg1 : tensor<1024x!tt.ptr, #blocked> -// tt.return %out1 : tensor<1024xf32, #blocked> -// } -//} -// -//// ----- -// -//// The following is a simple case of a tile offset like: (A*B + C + D) where B,C are Uniform and A,D are not. 
So -//// we expect that the Uniform offset (which can be added to the scalar pointer) will be simply C and the NonUniform -//// offset will be A*B+D -//#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -//module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { -// // CHECK-LABEL: tt.func @tile_offset -// tt.func @tile_offset(%arg1: !tt.ptr, %arg5: i32 , %arg7: i32 ) { -// %c128_i32 = arith.constant 128 : i32 -// %c256_i32 = arith.constant 256 : i32 -// %1 = tt.get_program_id x : i32 -// %20 = arith.muli %1, %c256_i32 : i32 -// %22 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -// %24 = tt.splat %20 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -// %26 = arith.addi %24, %22 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -// %36 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -// %37 = tt.expand_dims %36 {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi32, #blocked> -// %38 = tt.splat %arg7 : i32 -> tensor<16x1xi32, #blocked> -// %39 = arith.muli %37, %38 : tensor<16x1xi32, #blocked> -// %41 = tt.expand_dims %26 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> -// %42 = tt.broadcast %39 : tensor<16x1xi32, #blocked> -> tensor<16x256xi32, #blocked> -// %43 = tt.broadcast %41 : tensor<1x256xi32, #blocked> -> tensor<16x256xi32, #blocked> -// %44 = arith.addi %42, %43 : tensor<16x256xi32, #blocked> -// %45 = tt.splat %arg1 : !tt.ptr -> tensor<16x256x!tt.ptr, #blocked> -// %46 = tt.addptr %45, %44 : tensor<16x256x!tt.ptr, #blocked>, tensor<16x256xi32, #blocked> -// // CHECK: %[[uniformOffset1:.*]] = arith.muli %c0_i32_0, %arg2 : i32 -// // CHECK: {{.*}} = tt.expand_dims %{{.*}} {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> -// // CHECK: %[[tensorOffset6:.*]] = tt.expand_dims %{{.*}} {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> -// // CHECK: {{.*}} = tt.broadcast %{{.*}} : tensor<16x1xi32, #blocked> -> tensor<16x256xi32, #blocked> -// // CHECK: %[[tensorOffset3:.*]] = tt.broadcast %{{.*}} : tensor<16x1xi32, #blocked> -> tensor<16x256xi32, #blocked> -// // CHECK: %[[tensorOffset4:.*]] = tt.broadcast %{{.*}} : tensor<1x256xi32, #blocked> -> tensor<16x256xi32, #blocked> -// // CHECK: %[[tensorOffset5:.*]] = tt.broadcast %[[tensorOffset6]] : tensor<1x256xi32, #blocked> -> tensor<16x256xi32, #blocked> -// // CHECK: %[[uniformOffset:.*]] = arith.addi %[[uniformOffset1]], %{{.*}}: i32 -// // CHECK: %[[tensorOffset2:.*]] = arith.addi %[[tensorOffset3]], %[[tensorOffset5]] : tensor<16x256xi32, #blocked> -// // CHECK: %[[scalarPtr:.*]] = tt.addptr %arg0, %[[uniformOffset]] : !tt.ptr, i32 -// // CHECK: %[[tensorOffset2ext:.*]] = arith.extsi %[[tensorOffset2]] : tensor<16x256xi32, #blocked> to tensor<16x256xi64, #blocked> -// // CHECK: %[[tensorOffset1:.*]] = arith.addi %[[tensorOffset2ext]], %{{.*}} : tensor<16x256xi64, #blocked> -// // CHECK: %[[tensorOffset:.*]] = arith.trunci %[[tensorOffset1:.*]] : tensor<16x256xi64, #blocked> to tensor<16x256xi32, #blocked> -// // CHECK: %[[ptr:.*]] = tt.splat %[[scalarPtr]] : !tt.ptr -> tensor<16x256x!tt.ptr, #blocked> -// // CHECK: tt.addptr %[[ptr]], %[[tensorOffset]] : 
tensor<16x256x!tt.ptr, #block -// %61 = tt.load %46 : tensor<16x256x!tt.ptr, #blocked> -// tt.return -// } -//} -// -//// ----- -// -//// The following is a more complex case where also a multiplication is involved. It's useful to walk through the case. -//// We have that the offset to the pointer is the following: -//// %12 = %10 + 11 -//// This can be transformed in: -//// = %7 + %9 -//// = %5*%6 + %8 -//// = %4*%arg1 + %8 -//// = (%3+%2)*%arg1 + %8 -//// = (%1 + %2) * %arg1 + %8 -//// = (U + N)*U + N -//// Where U means uniform (e.g., a splat) and N means NonUniform (e.g., a make_range) -//// The scalar offset we want is (%1*%arg1), while the variable offset should be (%2*%arg1 + %8) -//#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> -//module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { -// // CHECK-LABEL: tt.func public @matmul_kernel -// tt.func public @matmul_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}) { -// %c128_i32 = arith.constant 128 : i32 -// %0 = tt.get_program_id x : i32 -// %1 = arith.muli %0, %c128_i32 : i32 -// %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -// %3 = tt.splat %1 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -// %4 = arith.addi %3, %2 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -// %5 = tt.expand_dims %4 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> -// %6 = tt.splat %arg1 : i32 -> tensor<128x1xi32, #blocked> -// %7 = arith.muli %5, %6 : tensor<128x1xi32, #blocked> -// %8 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -// %9 = tt.expand_dims %8 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> -// %10 = tt.broadcast %7 : tensor<128x1xi32, #blocked> -> tensor<128x16xi32, #blocked> -// %11 = tt.broadcast %9 : tensor<1x16xi32, #blocked> -> tensor<128x16xi32, #blocked> -// %12 = arith.addi %10, %11 : tensor<128x16xi32, #blocked> -// %13 = tt.splat %arg0 : !tt.ptr -> tensor<128x16x!tt.ptr, #blocked> -// %14 = tt.addptr %13, %12 : tensor<128x16x!tt.ptr, #blocked>, tensor<128x16xi32, #blocked> -// %15 = tt.load %14 : tensor<128x16x!tt.ptr, #blocked> -// // CHECK: %[[pid:.*]] = tt.get_program_id x : i32 -// // CHECK: %[[uniformOffset3:.*]] = arith.muli %[[pid]], %{{.*}} : i32 -// // CHECK: %[[uniformOffset2:.*]] = arith.addi %[[uniformOffset3]], %{{.*}} : i32 -// // CHECK: %[[uniformOffset1:.*]] = arith.muli %[[uniformOffset2]], %arg1 : i32 -// // CHECK: %[[makerange:.*]] = tt.make_range -// // CHECK: %{{.*}} = tt.expand_dims %[[makerange]] {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> -// // CHECK: %[[tensorOffset6:.*]] = tt.expand_dims %[[makerange]] {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> -// // CHECK: %{{.*}} = tt.broadcast %{{.*}} : tensor<128x1xi32, #blocked> -> tensor<128x16xi32, #blocked> -// // CHECK: %[[tensorOffset3:.*]] = tt.broadcast %{{.*}} : tensor<128x1xi32, #blocked> -> tensor<128x16xi32, #blocked> -// // CHECK: %{{.*}} = tt.broadcast %{{.*}} : tensor<1x16xi32, #blocked> -> tensor<128x16xi32, #blocked> -// // CHECK: %[[tensorOffset4:.*]] = tt.broadcast %[[tensorOffset6]] : 
tensor<1x16xi32, #blocked> -> tensor<128x16xi32, #blocked> -// // CHECK: %{{.*}} = arith.addi %{{.*}}, %{{.*}} : tensor<128x16xi32, #blocked> -// // CHECK: %[[uniformOffset:.*]] = arith.addi %[[uniformOffset1]], %{{.*}} : i32 -// // CHECK: %[[tensorOffset2:.*]] = arith.addi %[[tensorOffset3]], %[[tensorOffset4]] : tensor<128x16xi32, #blocked> -// // CHECK: %[[scalarPtr:.*]] = tt.addptr %arg0, %[[uniformOffset]] : !tt.ptr, i32 -// // CHECK: %[[tensorOffset1Ext:.*]] = arith.extsi %[[tensorOffset2]] : tensor<128x16xi32, #blocked> to tensor<128x16xi64, #blocked> -// // CHECK: %[[tensorOffset:.*]] = arith.addi %[[tensorOffset1Ext]], %{{.*}} : tensor<128x16xi64, #blocked> -// // CHECK: %[[tensorOffsetTrunc:.*]] = arith.trunci %[[tensorOffset]] : tensor<128x16xi64, #blocked> to tensor<128x16xi32, #blocked> -// // CHECK: %[[ptr:.*]] = tt.splat %[[scalarPtr]] : !tt.ptr -> tensor<128x16x!tt.ptr, #blocked> -// // CHECK: tt.addptr %[[ptr]], %[[tensorOffsetTrunc]] : tensor<128x16x!tt.ptr, #blocked>, tensor<128x16xi32, #blocked> -// tt.return -// } -//} -// -// //// ----- //#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> //module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { @@ -667,37 +372,37 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} // } //} // -//// ----- -// -//#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -//module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { -// // CHECK-LABEL: @scalar_if -// tt.func @scalar_if(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>, %cond : i1)->f32{ -// %0 = tt.get_program_id x : i32 -// %c1_i32 = arith.constant 1 : i32 -// %c0_i64 = arith.constant 0 : i64 -// %c10_i64 = arith.constant 10 : i64 -// %c100_i32 = arith.constant 100 : i32 -// %5 = tt.addptr %arg0, %c1_i32 : !tt.ptr, i32 -// // CHECK: %[[ptr0:.*]] = tt.addptr %arg0, %{{.*}} -// // CHECK: scf.if {{.*}} -> ({{.*}}, !tt.ptr, i64) -// %6 = scf.if %cond -> (!tt.ptr){ -// %true = tt.addptr %5, %c1_i32 : !tt.ptr, i32 -// // CHECK: %[[ptr1:.*]] = tt.addptr %[[ptr0]] -// // CHECK: scf.yield {{.*}}, %[[ptr1]] -// scf.yield %true : !tt.ptr -// } else { -// %false = tt.addptr %5, %c100_i32 : !tt.ptr, i32 -// // CHECK: %[[ptr2:.*]] = tt.addptr %[[ptr0]] -// // CHECK: scf.yield {{.*}}, %[[ptr2]] -// scf.yield %false : !tt.ptr -// } -// %11 = tt.load %6 : !tt.ptr -// tt.return %11 : f32 -// } -//} -// -//// ----- +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: @scalar_if + tt.func @scalar_if(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>, %cond : i1)->f32{ + %0 = tt.get_program_id x : i32 + %c1_i32 = arith.constant 1 : i32 + %c0_i64 = arith.constant 0 : i64 + %c10_i64 = arith.constant 10 : i64 + %c100_i32 = arith.constant 100 : i32 + %5 = tt.addptr %arg0, %c1_i32 : !tt.ptr, i32 + // CHECK: %[[ptr0:.*]] = tt.addptr %arg0, %{{.*}} + // CHECK: scf.if {{.*}} -> ({{.*}}, !tt.ptr, i64) + %6 = scf.if %cond -> (!tt.ptr){ + %true = tt.addptr %5, %c1_i32 : !tt.ptr, i32 + // CHECK: %[[ptr1:.*]] = tt.addptr %[[ptr0]] + // CHECK: scf.yield {{.*}}, %[[ptr1]] + scf.yield %true : !tt.ptr + } else { + %false = tt.addptr %5, %c100_i32 : !tt.ptr, i32 + // 
CHECK: %[[ptr2:.*]] = tt.addptr %[[ptr0]] + // CHECK: scf.yield {{.*}}, %[[ptr2]] + scf.yield %false : !tt.ptr + } + %11 = tt.load %6 : !tt.ptr + tt.return %11 : f32 + } +} + +// ----- // //#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> //module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { @@ -727,30 +432,30 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} //} // //// ----- -// -//#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -//module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { -// // CHECK-LABEL: tt.func @scalar_cond_branch -// tt.func @scalar_cond_branch(%arg0 : !tt.ptr, %i1 : i1) -> f32{ -// %c1024_i32 = arith.constant 1024 : i32 -// %c0 = arith.constant 0: index -// %c128 = arith.constant 128: index -// %c1 = arith.constant 1 : index -// %0 = tt.get_program_id x : i32 -// %1 = arith.muli %0, %c1024_i32 : i32 -// %6 = tt.addptr %arg0, %0 : !tt.ptr, i32 -// // CHECK: %[[ptr0:.*]] = tt.addptr %arg0 -// // CHECK: cf.cond_br %arg1, ^bb1(%{{.*}}, %[[ptr0]], {{.*}}), ^bb2(%{{.*}}, %arg0, {{.*}}) -// cf.cond_br %i1, ^bb1(%6 : !tt.ptr), ^bb2(%arg0 : !tt.ptr) -// // CHECK: ^bb1({{.*}}, %[[ptr1:.*]]: !tt.ptr, {{.*}}): -// ^bb1(%arg1 : !tt.ptr): -// // CHECK: tt.load %[[ptr1]] -// %out1 = tt.load %arg1 : !tt.ptr -// tt.return %out1 : f32 -// // CHECK: ^bb2({{.*}}, %[[ptr2:.*]]: !tt.ptr, {{.*}}): -// ^bb2(%arg2 : !tt.ptr): // 2 preds: ^bb0, ^bb1 -// // CHECK: tt.load %[[ptr2]] -// %out2 = tt.load %arg2 : !tt.ptr -// tt.return %out2 : f32 -// } -//} + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: tt.func @scalar_cond_branch + tt.func @scalar_cond_branch(%arg0 : !tt.ptr, %i1 : i1) -> f32{ + %c1024_i32 = arith.constant 1024 : i32 + %c0 = arith.constant 0: index + %c128 = arith.constant 128: index + %c1 = arith.constant 1 : index + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %6 = tt.addptr %arg0, %0 : !tt.ptr, i32 + // CHECK: %[[ptr0:.*]] = tt.addptr %arg0 + // CHECK: cf.cond_br %arg1, ^bb1(%{{.*}}, %[[ptr0]], {{.*}}), ^bb2(%{{.*}}, %arg0, {{.*}}) + cf.cond_br %i1, ^bb1(%6 : !tt.ptr), ^bb2(%arg0 : !tt.ptr) + // CHECK: ^bb1({{.*}}, %[[ptr1:.*]]: !tt.ptr, {{.*}}): + ^bb1(%arg1 : !tt.ptr): + // CHECK: tt.load %[[ptr1]] + %out1 = tt.load %arg1 : !tt.ptr + tt.return %out1 : f32 + // CHECK: ^bb2({{.*}}, %[[ptr2:.*]]: !tt.ptr, {{.*}}): + ^bb2(%arg2 : !tt.ptr): // 2 preds: ^bb0, ^bb1 + // CHECK: tt.load %[[ptr2]] + %out2 = tt.load %arg2 : !tt.ptr + tt.return %out2 : f32 + } +} diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp index d325e438c902..d0cec873fa19 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp @@ -1514,6 +1514,7 @@ class ConvertSCFWhileOp : public PointerCanonPattern { if (failed(rewriter.convertRegionTypes(&whileOp.getBefore(), hackTypeConverter))) return failure(); + inputNo = 0; if (failed(rewriter.convertRegionTypes(&whileOp.getAfter(), hackTypeConverter))) return failure(); @@ -1572,6 +1573,108 @@ class ConvertSCFConditionOp : public PointerCanonPattern { } }; +class ConvertCFCondBranch : public 
PointerCanonPattern { +public: + using PointerCanonPattern::PointerCanonPattern; + LogicalResult + matchAndRewrite(cf::CondBranchOp branchOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector trueOperands = + flattenValues(adaptor.getTrueDestOperands()); + SmallVector falseOperands = + flattenValues(adaptor.getFalseDestOperands()); + + rewriter.modifyOpInPlace(branchOp, [&] { + branchOp->setDiscardableAttr("rewritten", rewriter.getUnitAttr()); + }); + auto newBrancOp = rewriter.replaceOpWithNewOp( + branchOp, branchOp.getCondition(), branchOp.getTrueDest(), trueOperands, + branchOp.getFalseDest(), falseOperands); + + unsigned inputNo = 0; + TypeConverter hackTypeConverterTrueDest; + hackTypeConverterTrueDest.addConversion( + [&inputNo, remappedOperands = adaptor.getTrueDestOperands()]( + Type inputType, SmallVectorImpl &types) { + SmallVector remappedInitTypes = + llvm::to_vector(remappedOperands[inputNo].getTypes()); + types.append(remappedInitTypes); + inputNo++; + return success(); + }); + + std::optional conversion = + hackTypeConverterTrueDest.convertBlockSignature(branchOp.getTrueDest()); + if (!conversion) + return failure(); + rewriter.applySignatureConversion(branchOp.getTrueDest(), *conversion, + &hackTypeConverterTrueDest); + + inputNo = 0; + TypeConverter hackTypeConverterFalseDest; + hackTypeConverterFalseDest.addConversion( + [&inputNo, remappedOperands = adaptor.getFalseDestOperands()]( + Type inputType, SmallVectorImpl &types) { + SmallVector remappedInitTypes = + llvm::to_vector(remappedOperands[inputNo].getTypes()); + types.append(remappedInitTypes); + inputNo++; + return success(); + }); + + conversion = hackTypeConverterFalseDest.convertBlockSignature( + branchOp.getFalseDest()); + if (!conversion) + return failure(); + rewriter.applySignatureConversion(branchOp.getFalseDest(), *conversion, + &hackTypeConverterFalseDest); + rewriter.modifyOpInPlace(newBrancOp, [&] { + newBrancOp->setDiscardableAttr("legal", rewriter.getUnitAttr()); + }); + return success(); + } +}; + +class ConvertCFBranch : public PointerCanonPattern { +public: + using PointerCanonPattern::PointerCanonPattern; + LogicalResult + matchAndRewrite(cf::BranchOp branchOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector trueOperands = flattenValues(adaptor.getDestOperands()); + + rewriter.modifyOpInPlace(branchOp, [&] { + branchOp->setDiscardableAttr("rewritten", rewriter.getUnitAttr()); + }); + auto newBrancOp = rewriter.replaceOpWithNewOp( + branchOp, branchOp.getDest(), trueOperands); + + unsigned inputNo = 0; + TypeConverter hackTypeConverterTrueDest; + hackTypeConverterTrueDest.addConversion( + [&inputNo, remappedOperands = adaptor.getDestOperands()]( + Type inputType, SmallVectorImpl &types) { + SmallVector remappedInitTypes = + llvm::to_vector(remappedOperands[inputNo].getTypes()); + types.append(remappedInitTypes); + inputNo++; + return success(); + }); + + std::optional conversion = + hackTypeConverterTrueDest.convertBlockSignature(branchOp.getDest()); + if (!conversion) + return failure(); + rewriter.applySignatureConversion(branchOp.getDest(), *conversion, + &hackTypeConverterTrueDest); + + rewriter.modifyOpInPlace(newBrancOp, [&] { + newBrancOp->setDiscardableAttr("legal", rewriter.getUnitAttr()); + }); + return success(); + } +}; + void TritonAMDGPUCanonicalizePointersPass::runOnOperation() { ModuleOp module = getOperation(); mlir::MLIRContext *context = &getContext(); @@ -1582,8 +1685,11 @@ void 
TritonAMDGPUCanonicalizePointersPass::runOnOperation() { if (op->hasAttr("rewritten") || op->hasAttr("legal")) return true; for (OpOperand &operand : op->getOpOperands()) { - if (auto arg = llvm::dyn_cast(operand.get())) - return !llvm::isa(getElementTypeOrSelf(arg)); + if (auto arg = llvm::dyn_cast(operand.get())) { + if (!llvm::isa(getElementTypeOrSelf(arg))) + continue; + return false; + } if (operand.get().getDefiningOp()->hasAttr("rewritten")) return false; } @@ -1603,13 +1709,15 @@ void TritonAMDGPUCanonicalizePointersPass::runOnOperation() { return false; return isLegal(op); }); + target.addDynamicallyLegalDialect( + [&isLegal](Operation *op) { return isLegal(op); }); FatPointers fatPrs; patterns.add(patterns.getContext(), - fatPrs); + ConvertSCFConditionOp, ConvertSCFWhileOp, ConvertCFCondBranch, + ConvertCFBranch>(patterns.getContext(), fatPrs); ConversionConfig config; config.buildMaterializations = false; if (failed( From 5091eefa9f74e1daf5aa76e5bd96406d7fc3536a Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Mon, 16 Dec 2024 19:29:58 -0500 Subject: [PATCH 06/17] handle arith.select --- .../amd/amd-canonicalize-pointers.mlir | 119 +++++++++--------- .../CanonicalizePointers.cpp | 66 +++++++++- 2 files changed, 124 insertions(+), 61 deletions(-) diff --git a/test/TritonGPU/amd/amd-canonicalize-pointers.mlir b/test/TritonGPU/amd/amd-canonicalize-pointers.mlir index f2cbd4e93128..93f37fd3c6cc 100644 --- a/test/TritonGPU/amd/amd-canonicalize-pointers.mlir +++ b/test/TritonGPU/amd/amd-canonicalize-pointers.mlir @@ -254,66 +254,67 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr } } +// ----- -//// ----- -//#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -//module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { -// // CHECK-LABEL: tt.func @select -// tt.func @select(%arg0 : !tt.ptr, %i1 : i1) -> tensor<1024xf32, #blocked>{ -// %c1024_i32 = arith.constant 1024 : i32 -// %c0 = arith.constant 0: index -// %c128 = arith.constant 128: index -// %c1 = arith.constant 1 : index -// %0 = tt.get_program_id x : i32 -// %1 = arith.muli %0, %c1024_i32 : i32 -// %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> -// // CHECK: %[[scalarOffset:.*]] = arith.addi {{.*}}, {{.*}} : i32 -// // CHECK: %[[variableOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor<1024xi32, #blocked> -// // CHECK: %[[baseOffset:.*]] = tt.splat %{{.*}} : i64 -// // CHECK: %[[scalarPtr:.*]] = tt.addptr %arg0, %[[scalarOffset]] -// // CHECK: %[[extVariableOffset:.*]] = arith.extsi %[[variableOffset]] -// %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> -// %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> -// %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> -// %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> -// // CHECK: %[[offset2:.*]] = arith.addi %[[extVariableOffset]], %[[baseOffset]] -// // CHECK: %[[scalarPtr1:.*]] = arith.select %arg1, %arg0, %[[scalarPtr]] -// // CHECK: %[[offset0:.*]] = arith.select %arg1, {{.*}}, %[[offset2]] -// // CHECK: %[[offset1:.*]] = arith.trunci %[[offset0]] -// // CHECK: %[[ptr:.*]] = tt.splat %[[scalarPtr1]] -// // CHECK: tt.addptr %[[ptr]], %[[offset1]] -// %7 = arith.select %i1, %5 , %6 : tensor<1024x!tt.ptr, #blocked> -// %out = tt.load %7: tensor<1024x!tt.ptr, #blocked> -// tt.return %out : tensor<1024xf32, #blocked> -// } -//} -// -//// ----- -//#blocked = 
#ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -//module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1100", "ttg.threads-per-warp" = 32 : i32} { -// // CHECK-LABEL: tt.func @where_kernel -// tt.func @where_kernel(%arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}){ -// %c0_i8 = arith.constant 0 : i8 -// %c1024_i32 = arith.constant 1024 : i32 -// %0 = tt.get_program_id x : i32 -// %1 = arith.muli %0, %c1024_i32 : i32 -// %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> -// %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> -// %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> -// %9 = arith.cmpi ne, %c0_i8, %c0_i8 : i8 -// %10 = arith.select %9, %arg1, %arg2 : !tt.ptr -// // CHECK: %[[selectPtr:.*]] = arith.select {{.*}} : !tt.ptr -// %11 = tt.splat %10: !tt.ptr -> tensor<1024x!tt.ptr, #blocked> -// %13 = tt.addptr %11, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> -// // CHECK: %[[selectPtr0:.*]] = tt.addptr %[[selectPtr]] -// // CHECK: %[[tensorPtr:.*]] = tt.splat %[[selectPtr0]] -// // CHECK: tt.addptr %[[tensorPtr]] -// %14 = tt.load %13 : tensor<1024x!tt.ptr, #blocked> -// tt.return -// } -//} -// -//// ----- +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: tt.func @select + tt.func @select(%arg0 : !tt.ptr, %i1 : i1) -> tensor<1024xf32, #blocked>{ + %c1024_i32 = arith.constant 1024 : i32 + %c0 = arith.constant 0: index + %c128 = arith.constant 128: index + %c1 = arith.constant 1 : index + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + // CHECK: %[[scalarOffset:.*]] = arith.addi {{.*}}, {{.*}} : i32 + // CHECK: %[[variableOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor<1024xi32, #blocked> + // CHECK: %[[baseOffset:.*]] = tt.splat %{{.*}} : i64 + // CHECK: %[[scalarPtr:.*]] = tt.addptr %arg0, %[[scalarOffset]] + // CHECK: %[[extVariableOffset:.*]] = arith.extsi %[[variableOffset]] + %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + // CHECK: %[[offset2:.*]] = arith.addi %[[extVariableOffset]], %[[baseOffset]] + // CHECK: %[[scalarPtr1:.*]] = arith.select %arg1, %arg0, %[[scalarPtr]] + // CHECK: %[[offset0:.*]] = arith.select %arg1, {{.*}}, %[[offset2]] + // CHECK: %[[offset1:.*]] = arith.trunci %[[offset0]] + // CHECK: %[[ptr:.*]] = tt.splat %[[scalarPtr1]] + // CHECK: tt.addptr %[[ptr]], %[[offset1]] + %7 = arith.select %i1, %5 , %6 : tensor<1024x!tt.ptr, #blocked> + %out = tt.load %7: tensor<1024x!tt.ptr, #blocked> + tt.return %out : tensor<1024xf32, #blocked> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1100", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: tt.func @where_kernel + tt.func @where_kernel(%arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr 
{tt.divisibility = 16 : i32}){ + %c0_i8 = arith.constant 0 : i8 + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + %9 = arith.cmpi ne, %c0_i8, %c0_i8 : i8 + %10 = arith.select %9, %arg1, %arg2 : !tt.ptr + // CHECK: %[[selectPtr:.*]] = arith.select {{.*}} : !tt.ptr + %11 = tt.splat %10: !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %13 = tt.addptr %11, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + // CHECK: %[[selectPtr0:.*]] = tt.addptr %[[selectPtr]] + // CHECK: %[[tensorPtr:.*]] = tt.splat %[[selectPtr0]] + // CHECK: tt.addptr %[[tensorPtr]] + %14 = tt.load %13 : tensor<1024x!tt.ptr, #blocked> + tt.return + } +} + +// ----- #blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp index d0cec873fa19..4b6947de5f1d 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp @@ -1675,12 +1675,67 @@ class ConvertCFBranch : public PointerCanonPattern { } }; +class ConvertArithSelectOp : public PointerCanonPattern { +public: + using PointerCanonPattern::PointerCanonPattern; + LogicalResult + matchAndRewrite(arith::SelectOp selectOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ArrayRef remappedOperands = adaptor.getOperands(); + assert(remappedOperands.size() == 3 && remappedOperands[1].size() == 2 && + remappedOperands[2].size() == 2 && + "expected adaptor to have 3 remapped values, 1 for cond, 2 for each " + "arm"); + // If both have been traversed, then we can rewrite select of pointers as a + // select of base and offset + ValueRange fatPtrT = remappedOperands[1]; + ValueRange fatPtrF = remappedOperands[2]; + // Simple case of a scalar select: update the base pointer + if (!isa(selectOp.getType())) { + auto newSelectOp = rewriter.create( + selectOp.getLoc(), selectOp.getType(), + // TODO(max): why fatPtrTrue here? 
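+          // (Presumably both remapped arms carry the same offset in this
+          // scalar case, so only the base pointer needs to be selected here
+          // and fatPtrT's offset can be forwarded below; that presumption is
+          // the open question the TODO above refers to.)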
+ selectOp.getCondition(), fatPtrT[0], selectOp.getFalseValue()); + rewriter.modifyOpInPlace(selectOp, [&] { + selectOp->setDiscardableAttr("rewritten", rewriter.getUnitAttr()); + }); + rewriter.modifyOpInPlace(newSelectOp, [&] { + newSelectOp->setDiscardableAttr("legal", rewriter.getUnitAttr()); + }); + rewriter.replaceOpWithMultiple(selectOp, {{newSelectOp, fatPtrT[1]}}); + return success(); + } + + // Rewrite `select` for base and offset + auto newBase = rewriter.create( + selectOp.getLoc(), selectOp.getCondition(), fatPtrT[0], fatPtrF[0]); + auto newOffset = rewriter.create( + selectOp.getLoc(), selectOp.getCondition(), fatPtrT[1], fatPtrF[1]); + + assert((fatPtrs[{fatPtrT[0], fatPtrT[1]}].canNarrow == + fatPtrs[{fatPtrF[0], fatPtrF[1]}].canNarrow)); + + rewriter.modifyOpInPlace(selectOp, [&] { + selectOp->setDiscardableAttr("rewritten", rewriter.getUnitAttr()); + }); + rewriter.modifyOpInPlace(newBase, [&] { + newBase->setDiscardableAttr("legal", rewriter.getUnitAttr()); + }); + rewriter.modifyOpInPlace(newOffset, [&] { + newOffset->setDiscardableAttr("legal", rewriter.getUnitAttr()); + }); + + rewriter.replaceOpWithMultiple(selectOp, {{newBase, newOffset}}); + + return success(); + } +}; + void TritonAMDGPUCanonicalizePointersPass::runOnOperation() { ModuleOp module = getOperation(); mlir::MLIRContext *context = &getContext(); ConversionTarget target(*context); RewritePatternSet patterns(context); - target.addLegalDialect(); auto isLegal = [](Operation *op) { if (op->hasAttr("rewritten") || op->hasAttr("legal")) return true; @@ -1711,13 +1766,20 @@ void TritonAMDGPUCanonicalizePointersPass::runOnOperation() { }); target.addDynamicallyLegalDialect( [&isLegal](Operation *op) { return isLegal(op); }); + target.addDynamicallyLegalDialect( + [&isLegal](Operation *op) { + if (llvm::isa(op)) + return isLegal(op); + return true; + }); FatPointers fatPrs; patterns.add(patterns.getContext(), fatPrs); + ConvertCFBranch, ConvertArithSelectOp>(patterns.getContext(), + fatPrs); ConversionConfig config; config.buildMaterializations = false; if (failed( From be4fe7d4f2dd49a3422564910c2825e5dd395081 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Mon, 16 Dec 2024 19:49:55 -0500 Subject: [PATCH 07/17] handle tt.store --- .../amd/amd-canonicalize-pointers.mlir | 56 ++++++------- .../CanonicalizePointers.cpp | 79 +++++++++++++++++-- 2 files changed, 102 insertions(+), 33 deletions(-) diff --git a/test/TritonGPU/amd/amd-canonicalize-pointers.mlir b/test/TritonGPU/amd/amd-canonicalize-pointers.mlir index 93f37fd3c6cc..a90aa72c72e1 100644 --- a/test/TritonGPU/amd/amd-canonicalize-pointers.mlir +++ b/test/TritonGPU/amd/amd-canonicalize-pointers.mlir @@ -346,33 +346,33 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} } // ----- -// -//#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -//module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { -// // CHECK-LABEL: scalar_pointers -// tt.func public @scalar_pointers(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { -// %0 = tt.get_program_id x : i32 -// %c1_i32 = arith.constant 1 : i32 -// %c0_i64 = arith.constant 0 : i64 -// %c10_i64 = arith.constant 10 : i64 -// %c100_i32 = arith.constant 100 : i32 -// %5 = tt.addptr %arg0, %c1_i32 : !tt.ptr, i32 -// // CHECK: arith.constant 0 
: i64 -// // CHECK: arith.constant 0 : i64 -// // CHECK: %[[offset0:.*]] = arith.constant 0 : i64 -// // CHECK: %[[ptr0:.*]] = tt.addptr %arg0, %c1_i32 : !tt.ptr, i32 -// // CHECK: scf.for {{.*}} iter_args({{.*}}, %[[ptr1:.*]] = %[[ptr0]], %[[offset1:.*]] = %[[offset0]]) -// %10:1 = scf.for %arg3 = %c1_i32 to %c100_i32 step %c1_i32 iter_args(%arg4 = %5) -> (!tt.ptr) : i32 { -// // CHECK: tt.store %[[ptr1]] -// tt.store %arg4, %c0_i64 : !tt.ptr -// // CHECK: tt.addptr %[[ptr1]] -// %11 = tt.addptr %arg4, %c1_i32 : !tt.ptr, i32 -// scf.yield %11 : !tt.ptr -// } -// tt.return -// } -//} -// + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: scalar_pointers + tt.func public @scalar_pointers(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %0 = tt.get_program_id x : i32 + %c1_i32 = arith.constant 1 : i32 + %c0_i64 = arith.constant 0 : i64 + %c10_i64 = arith.constant 10 : i64 + %c100_i32 = arith.constant 100 : i32 + %5 = tt.addptr %arg0, %c1_i32 : !tt.ptr, i32 + // CHECK: arith.constant 0 : i64 + // CHECK: arith.constant 0 : i64 + // CHECK: %[[offset0:.*]] = arith.constant 0 : i64 + // CHECK: %[[ptr0:.*]] = tt.addptr %arg0, %c1_i32 : !tt.ptr, i32 + // CHECK: scf.for {{.*}} iter_args({{.*}}, %[[ptr1:.*]] = %[[ptr0]], %[[offset1:.*]] = %[[offset0]]) + %10:1 = scf.for %arg3 = %c1_i32 to %c100_i32 step %c1_i32 iter_args(%arg4 = %5) -> (!tt.ptr) : i32 { + // CHECK: tt.store %[[ptr1]] + tt.store %arg4, %c0_i64 : !tt.ptr + // CHECK: tt.addptr %[[ptr1]] + %11 = tt.addptr %arg4, %c1_i32 : !tt.ptr, i32 + scf.yield %11 : !tt.ptr + } + tt.return + } +} + // ----- #blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> @@ -432,7 +432,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // } //} // -//// ----- +// ----- #blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp index 4b6947de5f1d..0307a60d527c 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp @@ -1262,6 +1262,46 @@ class ConvertSplatOp : public PointerCanonPattern { } }; +class ConvertBroadcastOp : public PointerCanonPattern { +public: + using PointerCanonPattern::PointerCanonPattern; + + LogicalResult + matchAndRewrite(triton::BroadcastOp broadcastOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ArrayRef remappedOperands = adaptor.getOperands(); + // see + // https://github.com/llvm/llvm-project/blob/58389b220a9354ed6c34bdb9310a35165579c5e3/mlir/lib/Transforms/Utils/DialectConversion.cpp#L1177 + assert(remappedOperands.size() == 1 && remappedOperands[0].size() == 2 && + "expected adaptor to have 2 remapped values"); + Value fatPtrBase = remappedOperands[0][0]; + Value fatPtrOffset = remappedOperands[0][1]; + assert(llvm::isa(fatPtrBase.getType()) && + "expected fatPtrBase to be a tt.ptr"); + 
assert(llvm::isa(fatPtrOffset.getType()) && + "expected fatPtrOffset to be an integer type"); + + auto outType = + dyn_cast(broadcastOp.getResult().getType()); + auto ptrShape = outType.getShape(); + auto offsetType = dyn_cast(fatPtrOffset.getType()); + if (!offsetType) + return failure(); + + auto newOffsetType = RankedTensorType::get( + ptrShape, offsetType.getElementType(), outType.getEncoding()); + Value offset = rewriter.create( + broadcastOp.getLoc(), newOffsetType, fatPtrOffset); + rewriter.modifyOpInPlace(broadcastOp, [&] { + broadcastOp->setDiscardableAttr("rewritten", rewriter.getUnitAttr()); + }); + rewriter.replaceOpWithMultiple(broadcastOp, + {{broadcastOp.getSrc(), offset}}); + + return success(); + } +}; + class ConvertLoadOp : public PointerCanonPattern { public: using PointerCanonPattern::PointerCanonPattern; @@ -1290,6 +1330,34 @@ class ConvertLoadOp : public PointerCanonPattern { } }; +class ConvertStoreOp : public PointerCanonPattern { +public: + using PointerCanonPattern::PointerCanonPattern; + + LogicalResult + matchAndRewrite(triton::StoreOp storeOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ValueRange fatPtr = *adaptor.getOperands().begin(); + Value fatPtrBase = fatPtr.front(); + Value fatPtrOffset = fatPtr.back(); + Location curLoc = storeOp.getLoc(); + + llvm::SmallDenseMap attributes{ + {rewriter.getStringAttr("legal"), rewriter.getUnitAttr()}}; + Value newPtr = createTensorPointer( + rewriter, fatPtrBase, fatPtrOffset, curLoc, + fatPtrs[{fatPtrBase, fatPtrOffset}].canNarrow, attributes); + SmallVector operands = + storeOp.getOperands().take_back(storeOp.getNumOperands() - 1); + operands.insert(operands.begin(), newPtr); + SmallVector attrs = llvm::to_vector(storeOp->getAttrs()); + attrs.append({rewriter.getNamedAttr("legal", rewriter.getUnitAttr())}); + auto newStoreOp = rewriter.replaceOpWithNewOp( + storeOp, TypeRange{}, ValueRange{operands}, attrs); + return success(); + } +}; + class ConvertFuncOp : public PointerCanonPattern { public: using PointerCanonPattern::PointerCanonPattern; @@ -1775,11 +1843,12 @@ void TritonAMDGPUCanonicalizePointersPass::runOnOperation() { FatPointers fatPrs; - patterns.add(patterns.getContext(), - fatPrs); + patterns + .add( + patterns.getContext(), fatPrs); ConversionConfig config; config.buildMaterializations = false; if (failed( From d45cf77477a19659d89f709f281f23c91c98670f Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Mon, 16 Dec 2024 19:56:15 -0500 Subject: [PATCH 08/17] uncomment all test cases --- .../amd/amd-canonicalize-pointers.mlir | 351 ++++++++++++++++-- third_party/amd/backend/compiler.py | 2 +- .../CanonicalizePointers.cpp | 17 +- 3 files changed, 338 insertions(+), 32 deletions(-) diff --git a/test/TritonGPU/amd/amd-canonicalize-pointers.mlir b/test/TritonGPU/amd/amd-canonicalize-pointers.mlir index a90aa72c72e1..8185758a43dc 100644 --- a/test/TritonGPU/amd/amd-canonicalize-pointers.mlir +++ b/test/TritonGPU/amd/amd-canonicalize-pointers.mlir @@ -71,6 +71,301 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} // ----- +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: tt.func @conversion3 + tt.func @conversion3(%arg0: !tt.ptr)-> tensor<1024xf32, #blocked>{ + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = 
tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + + //CHECK: %0 = tt.get_program_id x : i32 + //CHECK: %[[pid:.*]] = arith.muli %0, {{.*}} : i32 + //CHECK: %[[makerange:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + //CHECK: %[[uniformOffset1:.*]] = arith.addi %[[pid]], {{.*}} : i32 + //CHECK: %[[tensorOffset1:.*]] = arith.addi %{{.*}}, %[[makerange]] : tensor<1024xi32, #blocked> + //CHECK: %[[uniformOffset0:.*]] = arith.addi %[[pid:.*]], %{{.*}} : i32 + //CHECK: %[[tensorOffset3:.*]] = arith.addi %{{.*}}, %[[makerange]] : tensor<1024xi32, #blocked> + //CHECK: %[[zero:.*]] = tt.splat %{{.*}} : i64 -> tensor<1024xi64, #blocked> + //CHECK: %[[uniformPtr0:.*]] = tt.addptr %arg0, %[[uniformOffset0:.*]] : !tt.ptr, i32 + //CHECK: %[[tensorOffset3ext:.*]] = arith.extsi %[[tensorOffset3]] : tensor<1024xi32, #blocked> to tensor<1024xi64, #blocked> + //CHECK: %[[tensorOffset0:.*]]= arith.addi %[[tensorOffset3ext]], %[[zero]] : tensor<1024xi64, #blocked> + //CHECK: %[[uniformPtr1:.*]] = tt.addptr %[[uniformPtr0]], %[[uniformOffset1]] : !tt.ptr, i32 + //CHECK: %[[tensorOffset1ext:.*]] = arith.extsi %[[tensorOffset1]] : tensor<1024xi32, #blocked> to tensor<1024xi64, #blocked> + //CHECK: %[[tensorOffset2:.*]] = arith.addi %[[tensorOffset1ext]], %[[tensorOffset0]]: tensor<1024xi64, #blocked> + //CHECK: %[[scalarPtr:.*]] = tt.splat %[[uniformPtr1]] : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + //CHECK: %[[newPtr:.*]] = tt.addptr %[[scalarPtr]], %[[tensorOffset2]] : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi64, #blocked> + //CHECK: tt.load %[[newPtr]] + + %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %7 = tt.addptr %6, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %8 = tt.load %7 : tensor<1024x!tt.ptr, #blocked> + tt.return %8 : tensor<1024xf32, #blocked> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + // + // This is the same as conversion3, but now the `arith.extsi` operations + // disappeared and all the offsets are 32 bits. 
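+  // (The pointer argument is annotated with tt.pointer_range = 32, so the
+  // pass can presumably keep the accumulated offset in i32 rather than
+  // widening it to i64 and truncating it back before the final tt.addptr.)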
+ // + // CHECK-LABEL: tt.func @conversion4 + tt.func @conversion4(%arg0: !tt.ptr{tt.pointer_range = 32 : i32})-> tensor<1024xf32, #blocked>{ + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + + //CHECK: %0 = tt.get_program_id x : i32 + //CHECK: %[[pid:.*]] = arith.muli %0, {{.*}} : i32 + //CHECK: %[[makerange:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + //CHECK: %[[uniformOffset1:.*]] = arith.addi %[[pid]], {{.*}} : i32 + //CHECK: %[[tensorOffset1:.*]] = arith.addi %{{.*}}, %[[makerange]] : tensor<1024xi32, #blocked> + //CHECK: %[[uniformOffset0:.*]] = arith.addi %[[pid:.*]], %{{.*}} : i32 + //CHECK: %[[tensorOffset3:.*]] = arith.addi %{{.*}}, %[[makerange]] : tensor<1024xi32, #blocked> + //CHECK: %[[zero:.*]] = tt.splat %{{.*}} : i32 -> tensor<1024xi32, #blocked> + //CHECK: %[[uniformPtr0:.*]] = tt.addptr %arg0, %[[uniformOffset0:.*]] : !tt.ptr, i32 + //CHECK: %[[tensorOffset0:.*]]= arith.addi %[[tensorOffset3]], %[[zero]] : tensor<1024xi32, #blocked> + //CHECK: %[[uniformPtr1:.*]] = tt.addptr %[[uniformPtr0]], %[[uniformOffset1]] : !tt.ptr, i32 + //CHECK: %[[tensorOffset2:.*]] = arith.addi %[[tensorOffset1]], %[[tensorOffset0]]: tensor<1024xi32, #blocked> + //CHECK: %[[scalarPtr:.*]] = tt.splat %[[uniformPtr1]] : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + //CHECK: %[[newPtr:.*]] = tt.addptr %[[scalarPtr]], %[[tensorOffset2]] : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + //CHECK: tt.load %[[newPtr]] + %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %7 = tt.addptr %6, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %8 = tt.load %7 : tensor<1024x!tt.ptr, #blocked> + tt.return %8 : tensor<1024xf32, #blocked> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: tt.func @forOp + tt.func @forOp(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>)-> tensor<1024xf32, #blocked>{ + %c1024_i32 = arith.constant 1024 : i32 + %c0 = arith.constant 0: index + %c128 = arith.constant 128: index + %c1 = arith.constant 1 : index + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + // CHECK: %[[scalarOffsetLoop:.*]] = arith.addi {{.*}}, {{.*}} : i32 + // CHECK: %[[variableOffset1:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor + // CHECK: %[[scalarOffset1:.*]] = arith.addi %{{.*}}, %{{.*}} : i32 + // CHECK: %[[scalarOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : i32 + // CHECK: %[[variableOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor + // CHECK: %[[scalarPtrUpdate:.*]] = tt.addptr %arg0, %[[scalarOffset]] + // CHECK: %[[ext_offset0:.*]] = arith.extsi %[[variableOffset]] + // CHECK: %[[offset1:.*]] = arith.addi %[[ext_offset0]], %{{.*}} : tensor<1024xi64, #blocked> + %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, 
tensor<1024xi32, #blocked> + // CHECK: %[[loop:.*]]:4 = scf.for {{.*}} iter_args({{.*}}, {{.*}}, %[[loopScalarPtr:.*]] = %{{.*}}, %[[loopOffset:.*]] = %[[offset1]]) -> {{.*}} { + %52:2 = scf.for %arg9 = %c0 to %c128 step %c1 iter_args(%arg1 = %6, %arg2 = %init) -> (tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked>){ + // CHECK: %[[scalarPtrUpdateLoop:.*]] = tt.addptr %[[loopScalarPtr]], %[[scalarOffsetLoop]] + // CHECK: %[[ext_offset0i:.*]] = arith.extsi %[[variableOffset1]] + // CHECK: %[[offset_i:.*]] = arith.addi %[[ext_offset0i]], %[[loopOffset]] + // CHECK: %[[base_ptr:.*]] = tt.splat %[[scalarPtrUpdateLoop]] + // CHECK: %[[newPtr:.*]] = tt.addptr %[[base_ptr]], %[[offset_i]] + // CHECK: tt.load %[[newPtr]] + // CHECK: scf.yield {{.*}}, {{.*}}, %[[scalarPtrUpdateLoop]], %[[offset_i]] + %11 = tt.addptr %arg1, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %9 = tt.load %11 : tensor<1024x!tt.ptr, #blocked> + %10 = arith.addf %9, %arg2 : tensor<1024xf32, #blocked> + scf.yield %11, %10 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked> + } + // CHECK: tt.addptr %[[loop]]#2, %[[scalarOffset1]] : !tt.ptr, i32 + %8 = tt.addptr %52#0, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %11 = tt.load %8 : tensor<1024x!tt.ptr, #blocked> + tt.return %11 : tensor<1024xf32, #blocked> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: tt.func @forOp2 + tt.func @forOp2(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>)-> tensor<1024xf32, #blocked>{ + %c1024_i32 = arith.constant 1024 : i32 + %c0 = arith.constant 0: index + %c128 = arith.constant 128: index + %c1 = arith.constant 1 : index + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + // CHECK: %[[scalarOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : i32 + // CHECK: %[[variableOffset0:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor<1024xi32, #blocked> + // CHECK: %[[finalScalarOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : i32 + // CHECK: %[[variableOffset1:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor<1024xi32, #blocked> + // CHECK: %[[base_offset:.*]] = tt.splat {{.*}} : i64 + %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + // CHECK: %[[forOut:.*]]:4 = scf.for {{.*}} iter_args(%{{.*}}, {{.*}}, %[[scalarPtr:.*]] = %arg0, %[[loopOffset:.*]] = %[[base_offset]]) + %52:2 = scf.for %arg9 = %c0 to %c128 step %c1 iter_args(%arg1 = %5, %arg2 = %init) -> (tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked>){ + // CHECK: %[[scalarPtrUpdate:.*]] = tt.addptr %[[scalarPtr]], %[[scalarOffset]] + // CHECK: %[[ext_offset0i:.*]] = arith.extsi %[[variableOffset0]] + // CHECK: %[[ext_offset_i:.*]] = arith.addi %[[ext_offset0i]], %[[loopOffset]] + // CHECK: %[[base_ptr:.*]] = tt.splat %[[scalarPtrUpdate]] + // CHECK: %[[newPtr:.*]] = tt.addptr %[[base_ptr]], %[[ext_offset_i]] + // CHECK: tt.load %[[newPtr]] + %11 = tt.addptr %arg1, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %9 = tt.load %11 : tensor<1024x!tt.ptr, #blocked> + %10 = arith.addf %9, %arg2 : tensor<1024xf32, #blocked> + scf.yield %11, %10 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked> + } + // CHECK: 
%[[scalarPtrFinalUpdate:.*]] = tt.addptr %[[forOut]]#2, %[[finalScalarOffset]] + // CHECK: %[[ext_offset0:.*]] = arith.extsi %[[variableOffset1]] + // CHECK: %[[tailOffset:.*]] = arith.addi %[[ext_offset0]], %[[forOut]]#3 + // CHECK: %[[tail_base_ptr:.*]] = tt.splat %[[scalarPtrFinalUpdate]] + // CHECK: %[[tailPtr:.*]] = tt.addptr %[[tail_base_ptr]], %[[tailOffset]] + // CHECK: tt.load %[[tailPtr]] + %8 = tt.addptr %52#0, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %11 = tt.load %8 : tensor<1024x!tt.ptr, #blocked> + tt.return %11 : tensor<1024xf32, #blocked> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: tt.func @forNested + tt.func @forNested(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>)-> tensor<1024xf32, #blocked>{ + %c1024_i32 = arith.constant 1024 : i32 + %c0 = arith.constant 0: index + %c128 = arith.constant 128: index + %c1 = arith.constant 1 : index + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + // CHECK: %[[scalarOffset:.*]] = arith.addi {{.*}}, {{.*}} : i32 + // CHECK: %[[variableOffset:.*]] = arith.addi %{{.*}}, %{{.*}} : tensor<1024xi32, #blocked> + // CHECK: %[[base_offset:.*]] = tt.splat {{.*}} : i64 + %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + + // CHECK: %[[forOut0:.*]]:4 = scf.for {{.*}} iter_args(%{{.*}}, {{.*}}, %[[scalarPtr0:.*]] = %arg0, %[[loopOffset0:.*]] = %[[base_offset]]){{.*}}{ + // CHECK: %[[forOut1:.*]]:4 = scf.for {{.*}} iter_args(%{{.*}}, {{.*}}, %[[scalarPtr1:.*]] = %[[scalarPtr0]], %[[loopOffset1:.*]] = %[[loopOffset0]]){{.*}}{ + // CHECK: %[[scalarPtrUpdate:.*]] = tt.addptr %[[scalarPtr1]], %{{.*}} + // CHECK: %[[ext_loop_offset1:.*]] = arith.extsi %[[variableOffset]] + // CHECK: %[[offset_i:.*]] = arith.addi %[[ext_loop_offset1]], %[[loopOffset1]] + // CHECK: %[[base_ptr:.*]] = tt.splat %[[scalarPtrUpdate]] + // CHECK: %[[newPtr:.*]] = tt.addptr %[[base_ptr]], %[[offset_i]] + // CHECK: tt.load %[[newPtr]] + // CHECK: scf.yield %{{.*}}, {{.*}}, %[[scalarPtrUpdate]], %[[offset_i]] + // CHECK: scf.yield %{{.*}}, {{.*}}, %[[forOut1]]#2, %[[forOut1]]#3 + + %52:2 = scf.for %arg9 = %c0 to %c128 step %c1 iter_args(%arg1 = %5, %arg2 = %init) -> (tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked>){ + %53:2 = scf.for %arg10 = %c0 to %c128 step %c1 iter_args(%arg3 = %arg1, %arg4 = %arg2) -> (tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked>){ + %11 = tt.addptr %arg3, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %9 = tt.load %11 : tensor<1024x!tt.ptr, #blocked> + %10 = arith.addf %9, %arg4 : tensor<1024xf32, #blocked> + scf.yield %11, %10 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked> + } + scf.yield %53#0, %53#1: tensor<1024x!tt.ptr, #blocked>, tensor<1024xf32, #blocked> + } + %8 = tt.addptr %52#0, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + %11 = tt.load %8 : tensor<1024x!tt.ptr, #blocked> + tt.return %11 : tensor<1024xf32, #blocked> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + // 
CHECK-LABEL: tt.func @ifOp + tt.func @ifOp(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>, %cond : i1)-> tensor<1024xf32, #blocked>{ + %c1024_i32 = arith.constant 1024 : i32 + %c0 = arith.constant 0: index + %c128 = arith.constant 128: index + %c1 = arith.constant 1 : index + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + // CHECK: %[[scalarOffset:.*]] = arith.addi {{.*}}, {{.*}} : i32 + // CHECK: %[[variableOffset:.*]] = arith.addi + %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + // CHECK: %[[baseOffsetVariable:.*]] = tt.splat {{.*}} : i64 -> tensor<1024xi64, #blocked> + // CHECK: %[[ifOut:.*]]:3 = scf.if {{.*}} -> (tensor<1024x!tt.ptr, #blocked>, !tt.ptr, tensor<1024xi64, #blocked>) + %6 = scf.if %cond -> (tensor<1024x!tt.ptr, #blocked>){ + // CHECK: %[[scalarOffsetUpdate:.*]] = tt.addptr %arg0, %[[scalarOffset]] + // CHECK: %[[ext_offset0:.*]] = arith.extsi %[[variableOffset]] + // CHECK: %[[if_offset:.*]] = arith.addi %[[ext_offset0]], %[[baseOffsetVariable]] + // CHECK: scf.yield %{{.*}}, %[[scalarOffsetUpdate]], %[[if_offset]] + %true = tt.addptr %5, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + scf.yield %true : tensor<1024x!tt.ptr, #blocked> + } else { + // CHECK: %[[new_scalar_ptr:.*]] = tt.addptr %arg0, {{.*}} + // CHECK: scf.yield %{{.*}}, %[[new_scalar_ptr]], %[[baseOffsetVariable]] + %false = tt.addptr %5, %3 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> + scf.yield %false : tensor<1024x!tt.ptr, #blocked> + } + // CHECK: %[[trunc_offset:.*]] = arith.trunci %[[ifOut]]#2 + // CHECK: %[[base_ptr:.*]] = tt.splat %[[ifOut]]#1 + // CHECK: %[[newPtr:.*]] = tt.addptr %[[base_ptr]], %[[trunc_offset]] + // CHECK: tt.load %[[newPtr]] + %11 = tt.load %6 : tensor<1024x!tt.ptr, #blocked> + tt.return %11 : tensor<1024xf32, #blocked> + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: tt.func @whileOp + tt.func @whileOp(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>, %cond : i1)-> tensor<1024xf32, #blocked>{ + %c1024_i32 = arith.constant 1024 : i32 + %c0 = arith.constant 0: index + %c128 = arith.constant 128: index + %c1 = arith.constant 1 : index + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + // CHECK: %[[base_offset:.*]] = tt.splat %{{.*}} : i64 + %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> + // CHECK: %[[whileOut:.*]]:3 = scf.while ({{.*}}, %[[loopPtr:.*]] = %arg0, %[[loopOffset:.*]] = %[[base_offset]]) + %6 = scf.while (%arg1 = %5) : (tensor<1024x!tt.ptr, #blocked>) -> (tensor<1024x!tt.ptr, #blocked>) { + // CHECK: scf.condition({{.*}}) %{{.*}}, %[[loopPtr]], %[[loopOffset]] + scf.condition(%cond) %arg1 : tensor<1024x!tt.ptr, #blocked> + } do { + // CHECK: ^bb{{.*}}(%{{.*}}, %[[blockPtr:.*]]: !tt.ptr, %[[blockOffset:.*]]: tensor<1024xi64, #blocked>): + ^bb0(%arg1: tensor<1024x!tt.ptr, #blocked>): + // CHECK: scf.yield {{.*}}, %[[blockPtr]], %[[blockOffset]] + scf.yield %arg1 : tensor<1024x!tt.ptr, 
#blocked> + } + // CHECK: %[[trunc_offset:.*]] = arith.trunci %[[whileOut]]#2 + // CHECK: %[[base_ptr:.*]] = tt.splat %[[whileOut]]#1 + // CHECK: %[[newPtr:.*]] = tt.addptr %[[base_ptr]], %[[trunc_offset]] + // CHECK: tt.load %[[newPtr]] + %11 = tt.load %6 : tensor<1024x!tt.ptr, #blocked> + tt.return %11 : tensor<1024xf32, #blocked> + } +} + +// ----- + #blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: tt.func @condBranch @@ -404,34 +699,34 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ } // ----- -// -//#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> -//module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { -// // CHECK-LABEL: tt.func @scalar_while -// tt.func @scalar_while(%arg0: !tt.ptr, %init : f32, %cond : i1)->f32{ -// %c1024_i32 = arith.constant 1024 : i32 -// %c0 = arith.constant 0: index -// %c128 = arith.constant 128: index -// %c1 = arith.constant 1 : index -// %0 = tt.get_program_id x : i32 -// %1 = arith.muli %0, %c1024_i32 : i32 -// // CHECK: %[[ptr0:.*]] = tt.addptr %arg0, %{{.*}} -// // CHECK: scf.while ({{.*}}, {{.*}} = %arg2, %[[ptr1:.*]] = %[[ptr0]], {{.*}}) -// %2 = tt.addptr %arg0, %0 : !tt.ptr, i32 -// %6 = scf.while (%arg1 = %2, %arg2 = %cond) : (!tt.ptr, i1) -> (!tt.ptr) { -// // CHECK: scf.condition({{.*}}) {{.*}}, %[[ptr1]] -// scf.condition(%arg2) %arg1 : !tt.ptr -// } do { -// // CHECK: ^bb0({{.*}}: !tt.ptr, %[[ptr2:.*]]: !tt.ptr, {{.*}}) -// // CHECK: scf.yield %{{.*}}, {{.*}} %[[ptr2]], {{.*}}, {{.*}} -// ^bb0(%arg1: !tt.ptr): -// scf.yield %arg1, %cond : !tt.ptr, i1 -// } -// %11 = tt.load %6 : !tt.ptr -// tt.return %11 : f32 -// } -//} -// + +#blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: tt.func @scalar_while + tt.func @scalar_while(%arg0: !tt.ptr, %init : f32, %cond : i1) -> f32 { + %c1024_i32 = arith.constant 1024 : i32 + %c0 = arith.constant 0: index + %c128 = arith.constant 128: index + %c1 = arith.constant 1 : index + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + // CHECK: %[[ptr0:.*]] = tt.addptr %arg0, %{{.*}} + // CHECK: scf.while ({{.*}}, {{.*}} = %arg2, %[[ptr1:.*]] = %[[ptr0]], {{.*}}) + %2 = tt.addptr %arg0, %0 : !tt.ptr, i32 + %6 = scf.while (%arg1 = %2) : (!tt.ptr) -> (!tt.ptr) { + // CHECK: scf.condition({{.*}}) {{.*}}, %[[ptr1]] + scf.condition(%cond) %arg1 : !tt.ptr + } do { + // CHECK: ^bb0({{.*}}: !tt.ptr, %[[ptr2:.*]]: !tt.ptr, {{.*}}) + // CHECK: scf.yield %{{.*}}, {{.*}} %[[ptr2]], {{.*}}, {{.*}} + ^bb0(%arg1: !tt.ptr): + scf.yield %arg1 : !tt.ptr + } + %11 = tt.load %6 : !tt.ptr + tt.return %11 : f32 + } +} + // ----- #blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index fff8421af5a3..81b07f2e7d86 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -250,8 +250,8 @@ def make_ttgir(mod, metadata, options): if amd.has_matrix_core_feature(options.arch): amd.passes.ttgpuir.add_reorder_instructions(pm) - amd.passes.ttgpuir.add_canonicalize_pointers(pm) if use_buffer_ops: + 
amd.passes.ttgpuir.add_canonicalize_pointers(pm) passes.common.add_canonicalizer(pm) amd.passes.ttgpuir.add_convert_to_buffer_ops(pm) passes.common.add_canonicalizer(pm) diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp index 0307a60d527c..baf0d8102a34 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp @@ -1179,6 +1179,8 @@ class ConvertAddPtrOp : public PointerCanonPattern { // propagate the attributes of the tensor pointer to the fat pointer. fatPtrs[{newAddPtrOp.getResult(), fatPtrOffset}].attributes = fatPtrs[{fatPtrBase, fatPtrOffset}].attributes; + fatPtrs[{newAddPtrOp.getResult(), fatPtrOffset}].canNarrow = + fatPtrs[{fatPtrBase, fatPtrOffset}].canNarrow; return success(); } @@ -1188,7 +1190,7 @@ class ConvertAddPtrOp : public PointerCanonPattern { decomposeOffsetFromExpr(rewriter, curLoc, origOffset, bitness); // Vector offset update (if any): bump the tensor offset - bool canNarrow = false; + bool canNarrow = fatPtrs[{fatPtrBase, fatPtrOffset}].canNarrow; bool propagateAtrs = true; Value newOffset = fatPtrOffset; if (!isZeroConst(nonUniformOffset)) { @@ -1257,7 +1259,8 @@ class ConvertSplatOp : public PointerCanonPattern { splatOp->setDiscardableAttr("rewritten", rewriter.getUnitAttr()); }); rewriter.replaceOpWithMultiple(splatOp, {{splatOp.getSrc(), offset}}); - + fatPtrs[{splatOp.getSrc(), offset}].canNarrow = + fatPtrs[{fatPtrBase, fatPtrOffset}].canNarrow; return success(); } }; @@ -1297,7 +1300,8 @@ class ConvertBroadcastOp : public PointerCanonPattern { }); rewriter.replaceOpWithMultiple(broadcastOp, {{broadcastOp.getSrc(), offset}}); - + fatPtrs[{broadcastOp.getSrc(), offset}].canNarrow = + fatPtrs[{fatPtrBase, fatPtrOffset}].canNarrow; return success(); } }; @@ -1380,6 +1384,8 @@ class ConvertFuncOp : public PointerCanonPattern { auto dummyCast = rewriter.create( arg.getLoc(), TypeRange{arg.getType()}, ValueRange{arg}); rewriter.replaceUsesOfBlockArgument(arg, dummyCast.getResult(0)); + // TODO(max): why is this true? 
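+      // (Presumably safe because the argument's offset is the zero constant
+      // created just above, so nothing is lost by allowing 32-bit offset
+      // arithmetic at this point.)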
+ fatPtrs[{arg, zeroOffset}].canNarrow = true; rewriter.replaceOpWithMultiple(dummyCast, {{arg, zeroOffset}}); } funcOp->setDiscardableAttr("rewritten", rewriter.getUnitAttr()); @@ -1862,6 +1868,11 @@ void TritonAMDGPUCanonicalizePointersPass::runOnOperation() { if (failed( applyPartialConversion(module, target, std::move(patterns), config))) return signalPassFailure(); + + module.walk([](Operation *op) { + op->removeDiscardableAttr("rewritten"); + op->removeDiscardableAttr("legal"); + }); } void TritonAMDGPUCanonicalizePointersPass::runOnOperationmine() { From dc314dd01fc87c781379d5bb7d37fc2921d6149e Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Tue, 17 Dec 2024 17:22:05 -0500 Subject: [PATCH 09/17] cleanup --- .../CanonicalizePointers.cpp | 1749 +++++------------ 1 file changed, 527 insertions(+), 1222 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp index baf0d8102a34..f3902f770f55 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp @@ -8,9 +8,7 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/IRMapping.h" #include "mlir/IR/Matchers.h" -#include "mlir/IR/OperationSupport.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Value.h" @@ -84,147 +82,6 @@ using namespace mlir; // if we ever meet 64 bit operations (because we know that the offset can be // contained in 32 bits) // -class PointerCanonicalizer { -public: - explicit PointerCanonicalizer(ModuleOp moduleOp) - : rewriter(moduleOp.getContext()), mod(moduleOp) {} - - // Propagate fat pointers in all the functions of the module - LogicalResult run(); - - // A fat pointer is represented as `basePtr + offset` internally. - struct FatPtr { - // Scalar base pointer. Needs to be `tt.splat`ed before used - Value basePtr; - // Tensor offset - Value offset; - // Flag to express if we can narrow the uses of the offset down to 32 bits - bool canNarrow = false; - // Collection of attributes that need to be applied to the pointer - llvm::SmallDenseMap attributes{}; - - // Utility copy functions - FatPtr copy(Value newBasePtr, Value newOffset) { - return FatPtr{newBasePtr, newOffset, canNarrow}; - } - FatPtr copyWithBase(Value newOffset) { - return FatPtr{basePtr, newOffset, canNarrow}; - } - FatPtr copyWithOffset(Value newBase) { - return FatPtr{newBase, offset, canNarrow}; - } - // Attribute functions - void setAttr(StringAttr name, Attribute value) { - attributes.insert({name, value}); - } - void setAttr(NamedAttribute attr) { - attributes.insert({attr.getName(), attr.getValue()}); - } - void setAttrs(ArrayRef attrs) { - for (auto attr : attrs) - attributes.insert({attr.getName(), attr.getValue()}); - } - }; - - // Rewrite any operation that needs a pointer - LogicalResult materializeFatPointer(Operation *op, Location loc, Value ptr); - - // Start from an argument of a function and propagate its fat pointers - LogicalResult rewritePointer(Value argPtr); - - // Create a tensor pointer from a fat pointer `fatPtr`. The tensor pointer is - // obtained by splatting the `fatPtr.basePtr` using the `fatPtr.offset` shape - // and adding the offset to it. 
- - // Push the attributes of the given operation `op` to the fat pointer - // corresponding to `val` - void collectFatPointerAttributes(Operation *op, Value val); - - // Rewrite a given function, canonicalizing the different pointer arguments of - // the region - LogicalResult rewriteFunction(triton::FuncOp funcOp); - - // Rewriters for different operation a pointer can walk into - LogicalResult rewriteSplatOp(triton::SplatOp splatOp, Location curLoc, - Value &nextPtr); - LogicalResult rewriteBroadcastOp(triton::BroadcastOp broadcastOp, - Location curLoc, Value &nextPtr); - LogicalResult rewriteAddPtrOp(triton::AddPtrOp addPtrOp, Location curLoc, - Value &nextPtr); - LogicalResult rewriteForOp(scf::ForOp forOp, Location curLoc, - OpOperand *operand, Value &nextPtr); - LogicalResult rewriteYieldOp(scf::YieldOp yieldOp, Location curLoc, - OpOperand *operand, Value &nextPtr); - LogicalResult rewriteWhileOp(scf::WhileOp whileOp, Location curLoc, - OpOperand *operand, Value &nextPtr); - LogicalResult rewriteConditionOp(scf::ConditionOp conditionOp, - Location curLoc, OpOperand *operand, - Value &nextPtr); - LogicalResult rewriteCondBranchOp(cf::CondBranchOp condBrOp, Location curLoc, - OpOperand *operand, Value &nextPtr); - LogicalResult rewriteSelectOp(arith::SelectOp selectOp, Location curLoc, - OpOperand *operand, Value &nextPtr); - LogicalResult rewriteBranchOp(cf::BranchOp branchOp, Location curLoc, - OpOperand *operand, Value &nextPtr); - - // Perform simplified scalar extraction. An offset can be composed by Unifrom - // (U) and non-uniform(N) components. A uniform component is basically a - // tensor constant (or a splat). A NonUniform value is a `make_range` or - // whatever we multiply with a `make_range` operation. We consider the generic - // expressions: - // offset = (N+U)*(N+U) - // - // Where the `uniformOffset=U*U` and the `nonUniformOffset=(N*U+U*N+N*N). - // - // We do not consider any expression not involving * and +. - // - // The function accepts the `rewriter`, the `location` and start recursing at - // the given `expr`. - // - // We also pass the bitness of the offset. 
- // - // The function returns the two components of the given offset as a - // std::pair{U, NU} - // std::pair decomposeOffsetFromExpr(Location loc, Value expr, - // int64_t bitness); - // std::pair decomposeOffsetFromAdd(Location loc, Value expr, - // int64_t bitness); - // std::pair decomposeOffsetFromMul(Location loc, Value expr, - // int64_t bitness); - - // Return either the operation or its rewritten op - template - OpTy resolveOp(Operation *op, - const DenseMap &rewriteOpMap) { - OpTy resolvedOp = dyn_cast(op); - if (rewriteOpMap.contains(op)) - resolvedOp = dyn_cast(rewriteOpMap.at(op)); - return resolvedOp; - } - - mlir::IRRewriter rewriter; - ModuleOp mod; - - // Symbol table: association between pointers and fatPointers - llvm::MapVector pointers; - - void clearFunctionState() { - rewriteOpMap.clear(); - queue.clear(); - opToDelete.clear(); - } - - // This structure is used to point to the right operation during the traversal - // of a function - DenseMap rewriteOpMap; - - // Queue of operations to visit in the current function - SmallVector queue; - - // List of IR to delete in the current function - SetVector opToDelete; -}; - namespace { // Extend a 32bit `offset` into 64bit using a arith.extsi operation @@ -257,7 +114,8 @@ static Value createNarrow64bitOffsetTo32bits(RewriterBase &rewriter, // Helper function to determine if the given `op` is a constant tensor and in // that case return the scalar value. -Value getScalarConstant(RewriterBase &rewriter, Location loc, Value expr) { +std::optional maybeGetOrCreateScalarConstant(RewriterBase &rewriter, + Location loc, Value expr) { Operation *op = expr.getDefiningOp(); // Check for splatness @@ -280,12 +138,13 @@ Value getScalarConstant(RewriterBase &rewriter, Location loc, Value expr) { return blockArg; } - return Value(); + return {}; } // Narrowing logic // For now we allow to narrow down to 32 bits only in the following case: // - `baseOffset` is 32-bits and `addOffset`(64-bits) is zero +// TODO(max): is this correct? bool canNarrowOffset(Value baseOffset, Value addOffset) { Type addOffsetType = getElementTypeOrSelf(addOffset); auto baseSplatOp = baseOffset.getDefiningOp(); @@ -301,59 +160,19 @@ Value createTensorZero(RewriterBase &rw, Location loc, RankedTensorType type) { } // namespace -void PointerCanonicalizer::collectFatPointerAttributes(Operation *op, - Value val) { - auto addBlockArgumentAttr = [&](BlockArgument arg) { - // If the value is a block parameter, the operation can specify - // an attribute for the given parameter by using `tt.property_argi` - // where `argi` refers to the arg number of the given parameter. - // So we need to iterate through the property, find the right one - // and push the property onto the pointers attributes. 
- llvm::SmallString<8> scratchStr; - for (NamedAttribute namedAttr : op->getAttrs()) { - scratchStr.clear(); - llvm::raw_svector_ostream sstream(scratchStr); - sstream << "_arg" << arg.getArgNumber(); - StringRef attrName = namedAttr.getName().getValue(); - if (attrName.ends_with(scratchStr)) { - StringRef newAttrName = attrName.drop_back(scratchStr.size()); - namedAttr.setName(rewriter.getStringAttr(newAttrName)); - pointers[val].setAttr(namedAttr); - // Propagate the argument to the offset if it is also a block argument - if (auto offsetArg = dyn_cast(pointers[val].offset)) { - scratchStr.clear(); - sstream << newAttrName << "_arg" << offsetArg.getArgNumber(); - op->setAttr(scratchStr, namedAttr.getValue()); - } - } - } - }; - - // If it is the i-th block argument, then look if the operation defined some - // _argi attribute and add it to the fat pointer attributes - if (auto arg = dyn_cast(val)) { - addBlockArgumentAttr(arg); - return; - } - - // Otherwise add the attributes of the operation to the fat pointer - for (NamedAttribute attr : op->getAttrs()) - pointers[val].setAttr(attr); -} - -std::pair decomposeOffsetFromExpr(RewriterBase &rewriter, - Location loc, Value expr, - int64_t bitness); +std::pair createDecomposeOffsetFromExpr(RewriterBase &rewriter, + Location loc, Value expr, + int64_t bitness); // Offset extraction logic for an addition op: // decompose(A+B) = {U(A)+U(B), NU(A)+NU(B)} -std::pair decomposeOffsetFromAdd(RewriterBase &rewriter, - Location loc, Value expr, - int64_t bitness) { +std::pair createDecomposeOffsetFromAdd(RewriterBase &rewriter, + Location loc, Value expr, + int64_t bitness) { auto addOp = expr.getDefiningOp(); auto [uniformOffsetL, nonUniformOffsetL] = - decomposeOffsetFromExpr(rewriter, loc, addOp.getLhs(), bitness); + createDecomposeOffsetFromExpr(rewriter, loc, addOp.getLhs(), bitness); auto [uniformOffsetR, nonUniformOffsetR] = - decomposeOffsetFromExpr(rewriter, loc, addOp.getRhs(), bitness); + createDecomposeOffsetFromExpr(rewriter, loc, addOp.getRhs(), bitness); Value uniformAdd = rewriter.create(loc, uniformOffsetL, uniformOffsetR); Value nonUniformAdd = @@ -363,14 +182,14 @@ std::pair decomposeOffsetFromAdd(RewriterBase &rewriter, // Offset extraction logic for a multiplication op: // decompose(A*B) = {U(A)*U(B), NU(A)*NU(B)+NU(B)*U(A)+U(A)*NU(B)} -std::pair decomposeOffsetFromMul(RewriterBase &rewriter, - Location loc, Value expr, - int64_t bitness) { +std::pair createDecomposeOffsetFromMul(RewriterBase &rewriter, + Location loc, Value expr, + int64_t bitness) { auto mulOp = expr.getDefiningOp(); auto [uniformOffsetL, nonUniformOffsetL] = - decomposeOffsetFromExpr(rewriter, loc, mulOp.getLhs(), bitness); + createDecomposeOffsetFromExpr(rewriter, loc, mulOp.getLhs(), bitness); auto [uniformOffsetR, nonUniformOffsetR] = - decomposeOffsetFromExpr(rewriter, loc, mulOp.getRhs(), bitness); + createDecomposeOffsetFromExpr(rewriter, loc, mulOp.getRhs(), bitness); Value uniformMul = rewriter.create(loc, uniformOffsetL, uniformOffsetR); @@ -391,50 +210,47 @@ std::pair decomposeOffsetFromMul(RewriterBase &rewriter, return {uniformMul, nonUniformMul}; } -std::pair decomposeOffsetFromExpr(RewriterBase &rewriter, - Location loc, Value expr, - int64_t bitness) { - - // RewriterBase::InsertionGuard guard(rewriter); - // rewriter.setInsertionPointAfterValue(expr); +std::pair createDecomposeOffsetFromExpr(RewriterBase &rewriter, + Location loc, Value expr, + int64_t bitness) { // Base case 1: it is a splat. 
Return the scalar constant as the uniform part - if (Value scalarConst = getScalarConstant(rewriter, loc, expr)) { + if (auto scalarConst = maybeGetOrCreateScalarConstant(rewriter, loc, expr)) { auto tensorZero = createTensorZero(rewriter, loc, cast(expr.getType())); - return {scalarConst, tensorZero}; + return {*scalarConst, tensorZero}; } // Base case 2: block argument. Since it is not a scalar constant, it must be // a tensor. Note that this means we won't be able to decompose across loop // boundaries (TODO: giuseros). - if (auto blockArg = dyn_cast(expr)) { + if (llvm::isa(expr)) { Value scalarZero = rewriter.create(loc, 0, bitness); - return std::make_pair(scalarZero, expr); + return {scalarZero, expr}; } auto offsets = llvm::TypeSwitch>( expr.getDefiningOp()) .Case([&](auto broadcastOp) { - auto [uniform, nonUniform] = decomposeOffsetFromExpr( + auto [uniform, nonUniform] = createDecomposeOffsetFromExpr( rewriter, loc, broadcastOp.getSrc(), bitness); auto broadcastNonUniform = rewriter.create( loc, broadcastOp.getType(), nonUniform); return std::make_pair(uniform, broadcastNonUniform); }) .Case([&](auto expandOp) { - auto [uniform, nonUniform] = decomposeOffsetFromExpr( + auto [uniform, nonUniform] = createDecomposeOffsetFromExpr( rewriter, loc, expandOp.getSrc(), bitness); auto expandNonUniform = rewriter.create( loc, nonUniform, expandOp.getAxis()); return std::make_pair(uniform, expandNonUniform); }) .Case([&](Operation *op) { - return decomposeOffsetFromAdd(rewriter, loc, expr, bitness); + return createDecomposeOffsetFromAdd(rewriter, loc, expr, bitness); }) .Case([&](Operation *op) { - return decomposeOffsetFromMul(rewriter, loc, expr, bitness); + return createDecomposeOffsetFromMul(rewriter, loc, expr, bitness); }) .Default([&](Operation *op) { // Base case 3: it is not a supported operation. 
We assume no @@ -447,6 +263,13 @@ std::pair decomposeOffsetFromExpr(RewriterBase &rewriter, return offsets; } +static const char kLegalAttr[] = "__amd-pointer-canonicalize-legal__"; +static const char kRewrittenAttr[] = "__amd-pointer-canonicalize-rewritten__"; +static const char kSCFThenRewrittenAttr[] = + "__amd-pointer-canonicalize-scf-then-rewritten__"; +static const char kSCFElseRewrittenAttr[] = + "__amd-pointer-canonicalize-scf-else-rewritten__"; + Value createTensorPointer( RewriterBase &rewriter, Value basePtr, Value offset, Location loc, bool canNarrow, @@ -475,7 +298,7 @@ Value createTensorPointer( Value tensorPtr = rewriter.create( loc, TypeRange{tensorPtrType}, ValueRange{basePtr}, - SmallVector{rewriter.getNamedAttr("legal", rewriter.getUnitAttr())}); + SmallVector{rewriter.getNamedAttr(kLegalAttr, rewriter.getUnitAttr())}); auto addPtrOp = rewriter.create(loc, tensorPtrType, tensorPtr, offset); @@ -485,561 +308,6 @@ Value createTensorPointer( return addPtrOp.getResult(); } -// Rewrite a memory operation -LogicalResult PointerCanonicalizer::materializeFatPointer(Operation *op, - Location loc, - Value ptr) { - auto fatPtr = pointers[ptr]; - Value basePtr = fatPtr.basePtr; - Value offset = fatPtr.offset; - - // Create the tensor pointer (i.e., splat the base && add the offset) - Value newPtr = basePtr; - if (isa(ptr.getType())) - newPtr = createTensorPointer(rewriter, fatPtr.basePtr, fatPtr.offset, loc, - fatPtr.canNarrow, fatPtr.attributes); - - // Save the fat pointer in the table - pointers[newPtr] = fatPtr; - - // Map and replace the load - IRMapping mapper; - mapper.map(ptr, newPtr); - Operation *newOp = rewriter.clone(*op, mapper); - rewriter.replaceAllOpUsesWith(op, newOp); - opToDelete.insert(op); - return success(); -} - -LogicalResult PointerCanonicalizer::rewriteSplatOp(triton::SplatOp splatOp, - Location curLoc, - Value &nextPtr) { - nextPtr = splatOp.getResult(); - auto fatPtr = pointers[splatOp.getSrc()]; - auto outType = splatOp.getResult().getType(); - auto ptrShape = outType.getShape(); - auto newOffsetType = RankedTensorType::get(ptrShape, fatPtr.offset.getType(), - outType.getEncoding()); - Value offset = - rewriter.create(curLoc, newOffsetType, fatPtr.offset); - // The shape of the fat pointer is contained within the offset. We don't - // need to keep the `splat` operation here. 
- opToDelete.insert(splatOp); - pointers[nextPtr] = fatPtr.copy(splatOp.getSrc(), offset); - return success(); -} - -LogicalResult -PointerCanonicalizer::rewriteBroadcastOp(triton::BroadcastOp broadcastOp, - Location curLoc, Value &nextPtr) { - nextPtr = broadcastOp.getResult(); - auto fatPtr = pointers[broadcastOp.getSrc()]; - auto outType = dyn_cast(broadcastOp.getResult().getType()); - auto ptrShape = outType.getShape(); - auto offsetType = dyn_cast(fatPtr.offset.getType()); - if (!offsetType) - return failure(); - - opToDelete.insert(broadcastOp); - - auto newOffsetType = RankedTensorType::get( - ptrShape, offsetType.getElementType(), outType.getEncoding()); - Value offset = rewriter.create(curLoc, newOffsetType, - fatPtr.offset); - pointers[nextPtr] = fatPtr.copyWithBase(offset); - return success(); -} - -LogicalResult PointerCanonicalizer::rewriteAddPtrOp(triton::AddPtrOp addPtrOp, - Location curLoc, - Value &nextPtr) { - nextPtr = addPtrOp.getResult(); - auto fatPtr = pointers[addPtrOp.getPtr()]; - Value newPtr = fatPtr.basePtr; - // If it is a scalar pointer update, simply bump the base pointer - if (!isa(addPtrOp.getPtr().getType())) { - addPtrOp->setOperand(0, newPtr); - pointers[nextPtr] = fatPtr.copyWithOffset(nextPtr); - return success(); - } - Value offset = addPtrOp.getOffset(); - - // Early exit for the case of a constant tensor - if (Value scalarConst = getScalarConstant(rewriter, curLoc, offset)) { - newPtr = rewriter.create(curLoc, newPtr.getType(), newPtr, - scalarConst); - pointers[nextPtr] = fatPtr.copyWithOffset(newPtr); - // If we are updating the tensor pointer with a uniform value, we can - // propagate the attributes of the tensor pointer to the fat pointer. - for (auto attribute : fatPtr.attributes) - pointers[nextPtr].setAttr(attribute.getFirst(), attribute.getSecond()); - opToDelete.insert(addPtrOp); - return success(); - } - - int64_t bitness = - cast(offset.getType()).getElementTypeBitWidth(); - auto [uniformOffset, nonUniformOffset] = - decomposeOffsetFromExpr(rewriter, curLoc, offset, bitness); - - // Scalar pointer update: bump the scalar pointer - newPtr = rewriter.create(curLoc, newPtr.getType(), newPtr, - uniformOffset); - - // Vector offset update (if any): bump the tensor offset - Value fatPtrOffset = fatPtr.offset; - bool canNarrow = fatPtr.canNarrow; - Value newOffset = fatPtrOffset; - bool propagateAtrs = true; - if (!isZeroConst(nonUniformOffset)) { - Type addPtrOffsetType = getElementTypeOrSelf(nonUniformOffset); - Type fatPtrOffsetType = getElementTypeOrSelf(fatPtrOffset); - canNarrow = canNarrow && canNarrowOffset(fatPtrOffset, nonUniformOffset); - - // Upcast or downcast the offset accordingly - if (addPtrOffsetType.isInteger(32) && fatPtrOffsetType.isInteger(64)) - nonUniformOffset = - createExtend32bitOffsetTo64Bits(rewriter, curLoc, nonUniformOffset); - else if (addPtrOffsetType.isInteger(64) && fatPtrOffsetType.isInteger(32)) - nonUniformOffset = - createNarrow64bitOffsetTo32bits(rewriter, curLoc, nonUniformOffset); - - newOffset = - rewriter.create(curLoc, nonUniformOffset, fatPtrOffset); - propagateAtrs = false; - } - opToDelete.insert(addPtrOp); - pointers[nextPtr] = FatPtr{newPtr, newOffset, canNarrow}; - - // If we are updating the tensor pointer with a uniform value, we can - // propagate the attributes of the tensor pointer to the fat pointer. 
- if (propagateAtrs) - for (auto attribute : fatPtr.attributes) - pointers[nextPtr].setAttr(attribute.getFirst(), attribute.getSecond()); - return success(); -} - -LogicalResult PointerCanonicalizer::rewriteForOp(scf::ForOp forOp, - Location curLoc, - OpOperand *curOperand, - Value &nextPtr) { - size_t operandNum = curOperand->getOperandNumber(); - FatPtr fatPtr = pointers[curOperand->get()]; - Value offset = fatPtr.offset; - Value basePtr = fatPtr.basePtr; - - // Replace the forOp with two additional argument (i.e., the curOperand's - // scalar pointer and the offset) - Value tensorPtr = - createTensorPointer(rewriter, fatPtr.basePtr, fatPtr.offset, curLoc, - fatPtr.canNarrow, fatPtr.attributes); - auto newForOp = - replaceForOpWithNewSignature(rewriter, forOp, {basePtr, offset}); - rewriteOpMap[forOp] = newForOp; - - newForOp->setOperand(operandNum, tensorPtr); - OpOperand *forOperand = &newForOp->getOpOperand(operandNum); - // This is making sure we propagate the visit from the forOp result - nextPtr = newForOp.getTiedLoopResult(forOperand); - - // This is making sure we visit the uses within the forOp region - Value arg = newForOp.getTiedLoopRegionIterArg(forOperand); - size_t numIterArgs = newForOp.getNumRegionIterArgs(); - pointers[arg] = fatPtr.copy(newForOp.getRegionIterArg(numIterArgs - 2), - newForOp.getRegionIterArg(numIterArgs - 1)); - - // Collect attributes before continuing the visit - collectFatPointerAttributes(newForOp, arg); - - for (OpOperand &use : arg.getUses()) - queue.push_back(&use); - - // This is setting the fat pointer for the users of the loop - // and then propagate the result - size_t numResults = newForOp->getNumResults(); - pointers[nextPtr] = fatPtr.copy(newForOp->getResult(numResults - 2), - newForOp.getResult(numResults - 1)); - opToDelete.insert(forOp); - return success(); -} - -LogicalResult PointerCanonicalizer::rewriteYieldOp(scf::YieldOp yieldOp, - Location curLoc, - OpOperand *curOperand, - Value &nextPtr) { - - // Rewriting the yield op is a bit more complicated, because a - // yield op can be inside of a ForOp, WhileOp(in the AfterRegion) or - // IfOp - size_t operandNum = curOperand->getOperandNumber(); - FatPtr fatPtr = pointers[curOperand->get()]; - yieldOp.getResultsMutable().append(fatPtr.basePtr); - yieldOp.getResultsMutable().append(fatPtr.offset); - - if (auto forOp = dyn_cast(yieldOp->getParentOp())) { - yieldOp->setOperand(operandNum, forOp.getRegionIterArg(operandNum)); - } else if (auto ifOp = dyn_cast(yieldOp->getParentOp())) { - // Case 1: the yieldOp is contained within an IfOp. One of the - // two branches is responsible to rewrite the operation. The other - // branch only update the yieldOp with the right parameters - Value tensorPtr = - createTensorPointer(rewriter, fatPtr.basePtr, fatPtr.offset, curLoc, - fatPtr.canNarrow, fatPtr.attributes); - yieldOp->setOperand(operandNum, tensorPtr); - - if (yieldOp->getBlock() == &ifOp.getThenRegion().front()) { - auto newIfOp = replaceIfOpWithNewSignature( - rewriter, ifOp, {fatPtr.basePtr.getType(), fatPtr.offset.getType()}); - nextPtr = newIfOp.getResult(operandNum); - size_t numResults = newIfOp->getNumResults(); - pointers[nextPtr] = fatPtr.copy(newIfOp->getResult(numResults - 2), - newIfOp.getResult(numResults - 1)); - opToDelete.insert(ifOp); - } - - } else if (auto whileOp = resolveOp(yieldOp->getParentOp(), - rewriteOpMap)) { - // Case 2: the yieldOp is contained within the AfterRegion of a - // WhileOp. 
In this case, we know that the before region should have - // already been replaced (when we met the WhileOp), hence we can - // simply replace the WhileOp with a new AfterRegion (and hance a new - // set of return types) - auto newWhileOp = replaceWhileOpWithNewSignature( - rewriter, whileOp, {}, - {fatPtr.basePtr.getType(), fatPtr.offset.getType()}); - nextPtr = newWhileOp.getResult(operandNum); - size_t numResults = newWhileOp->getNumResults(); - pointers[nextPtr] = fatPtr.copy(newWhileOp->getResult(numResults - 2), - newWhileOp->getResult(numResults - 1)); - rewriteOpMap[whileOp] = newWhileOp; - opToDelete.insert(whileOp.getOperation()); - yieldOp.setOperand(operandNum, newWhileOp.getAfterArguments()[operandNum]); - } - return success(); -} - -LogicalResult PointerCanonicalizer::rewriteWhileOp(scf::WhileOp whileOp, - Location curLoc, - OpOperand *curOperand, - Value &nextPtr) { - // WhileOp rewrite happens in two phases: first rewrite the operand list - // and then rewrite the types when we meet the yieldOp - size_t operandNum = curOperand->getOperandNumber(); - FatPtr fatPtr = pointers[curOperand->get()]; - Value offset = fatPtr.offset; - Value basePtr = fatPtr.basePtr; - // Rewrite the while op with a new set of operands (but with the same - // set of return types) - Value tensorPtr = - createTensorPointer(rewriter, fatPtr.basePtr, fatPtr.offset, curLoc, - fatPtr.canNarrow, fatPtr.attributes); - auto newWhileOp = - replaceWhileOpWithNewSignature(rewriter, whileOp, {basePtr, offset}, {}); - newWhileOp->setOperand(operandNum, tensorPtr); - Value arg = newWhileOp.getBeforeBody()->getArgument(operandNum); - // Propagate inside the BeforeRegion - size_t numArguments = newWhileOp.getBeforeBody()->getNumArguments(); - pointers[arg] = - fatPtr.copy(newWhileOp.getBeforeBody()->getArgument(numArguments - 2), - newWhileOp.getBeforeBody()->getArgument(numArguments - 1)); - nextPtr = arg; - rewriteOpMap[whileOp] = newWhileOp; - opToDelete.insert(whileOp); - return success(); -} - -// ConditionOp can only be contained within the BeforeRegion of a -// WhileOp. We already rewrote the WhileOp with the right operands, so -// we need only to add the offset the current operand to be the base -// pointer and continue the walk inside the AfterRegion -LogicalResult -PointerCanonicalizer::rewriteConditionOp(scf::ConditionOp conditionOp, - Location curLoc, OpOperand *curOperand, - Value &nextPtr) { - - size_t operandNum = curOperand->getOperandNumber(); - FatPtr fatPtr = pointers[curOperand->get()]; - Value offset = fatPtr.offset; - Value basePtr = fatPtr.basePtr; - auto whileOp = cast(conditionOp->getParentOp()); - - // Update the condition op - auto afterBlock = whileOp.getAfterBody(); - conditionOp.getArgsMutable().append({basePtr, offset}); - - // Propagate through the after region - afterBlock->addArgument(basePtr.getType(), curLoc); - afterBlock->addArgument(offset.getType(), curLoc); - nextPtr = afterBlock->getArgument(operandNum - 1); - size_t numArguments = afterBlock->getNumArguments(); - conditionOp.setOperand(operandNum, - whileOp.getRegionIterArgs()[operandNum - 1]); - pointers[nextPtr] = fatPtr.copy(afterBlock->getArgument(numArguments - 2), - afterBlock->getArgument(numArguments - 1)); - return success(); -} - -LogicalResult PointerCanonicalizer::rewriteCondBranchOp( - cf::CondBranchOp condBrOp, Location curLoc, OpOperand *curOperand, - Value &nextPtr) { - // CondBranchOp is a bit tricky to handle. 
Because we might be inserting - // the basePtr+offset as a TrueDestOperand(s), which is not the end of - // `condBrOp.getOperands()` - auto falseOperands = llvm::to_vector(condBrOp.getFalseDestOperands()); - auto trueOperands = llvm::to_vector(condBrOp.getTrueOperands()); - auto it = llvm::find(falseOperands, curOperand->get()); - bool isFalseOperand = (it != falseOperands.end()); - size_t operandNum = curOperand->getOperandNumber(); - - if (rewriteOpMap.contains(condBrOp)) { - // If we need to use a different condBrOp, we might also need to - // update `operandNum` - auto condBranchReplacement = - dyn_cast(rewriteOpMap[condBrOp]); - if (isFalseOperand) { - // basePtr+offset need to be added if we are on the FalseOperands - // side, but the true operands have been rewritten - bool needOffset = (condBranchReplacement.getTrueDestOperands().size() != - condBrOp.getTrueDestOperands().size()); - int maybeOffset = (needOffset ? 2 : 0); - operandNum += maybeOffset; - curOperand = &condBranchReplacement->getOpOperand(operandNum); - } - // Now we need to recompute the currentOperation and its {true,false} - // operands - falseOperands = - llvm::to_vector(condBranchReplacement.getFalseDestOperands()); - trueOperands = llvm::to_vector(condBranchReplacement.getTrueDestOperands()); - condBrOp = condBranchReplacement; - } - - // Now we can proceed almost normally - FatPtr fatPtr = pointers[curOperand->get()]; - Value offset = fatPtr.offset; - Value basePtr = fatPtr.basePtr; - - Block *falseDest = condBrOp.getFalseDest(); - Block *trueDest = condBrOp.getTrueDest(); - // Walk the destination block only if you don't have visited it yet - if (isFalseOperand) { - falseOperands.push_back(basePtr); - falseOperands.push_back(offset); - Value falseDestArg = - falseDest->getArgument(operandNum - condBrOp.getNumTrueOperands() - 1); - if (!pointers.contains(falseDestArg)) { - nextPtr = falseDestArg; - Value basePtrArg = falseDest->addArgument(basePtr.getType(), curLoc); - Value offsetArg = falseDest->addArgument(offset.getType(), curLoc); - pointers[nextPtr] = fatPtr.copy(basePtrArg, offsetArg); - } - } else { - trueOperands.push_back(basePtr); - trueOperands.push_back(offset); - Value trueDestArg = trueDest->getArgument(operandNum - 1); - if (!pointers.contains(trueDestArg)) { - nextPtr = trueDestArg; - Value basePtrArg = trueDest->addArgument(basePtr.getType(), curLoc); - Value offsetArg = trueDest->addArgument(offset.getType(), curLoc); - pointers[nextPtr] = fatPtr.copy(basePtrArg, offsetArg); - } - } - - // Create a new condBranch. 
We cannot simply extend the operands, - // because this would invalidate other operands pointing at the same - // cond branch - Value tensorPtr = - createTensorPointer(rewriter, fatPtr.basePtr, fatPtr.offset, curLoc, - fatPtr.canNarrow, fatPtr.attributes); - auto newCondBranch = rewriter.create( - curLoc, condBrOp.getCondition(), trueDest, trueOperands, falseDest, - falseOperands); - - newCondBranch.setOperand(operandNum, tensorPtr); - rewriteOpMap[condBrOp] = newCondBranch; - opToDelete.insert(condBrOp); - return success(); -} - -LogicalResult PointerCanonicalizer::rewriteSelectOp(arith::SelectOp selectOp, - Location curLoc, - OpOperand *curOperand, - Value &nextPtr) { - Value trueVal = selectOp.getTrueValue(); - Value falseVal = selectOp.getFalseValue(); - Value cond = selectOp.getCondition(); - // If we didn't traverse both operands, simply materialize the pointer - if (!pointers.contains(trueVal) || !pointers.contains(falseVal)) - return materializeFatPointer(selectOp, curLoc, curOperand->get()); - - // If both have been traversed, then we can rewrite select of pointers as a - // select of base and offset - FatPtr fatPtrT = pointers[trueVal]; - FatPtr fatPtrF = pointers[falseVal]; - nextPtr = selectOp.getResult(); - - // Simple case of a scalar select: update the base pointer - if (!isa(selectOp.getType())) { - FatPtr fatPtr = pointers[trueVal]; - pointers[nextPtr] = fatPtr.copyWithOffset(nextPtr); - nextPtr = selectOp.getResult(); - return success(); - } - - // Rewrite `select` for base and offset - Value newBase = rewriter.create( - curLoc, cond, fatPtrT.basePtr, fatPtrF.basePtr); - Value newOffset = rewriter.create( - curLoc, cond, fatPtrT.offset, fatPtrF.offset); - assert(fatPtrT.canNarrow == fatPtrF.canNarrow); - - pointers[nextPtr] = fatPtrT.copy(newBase, newOffset); - opToDelete.insert(selectOp); - return success(); -} - -LogicalResult PointerCanonicalizer::rewriteBranchOp(cf::BranchOp branchOp, - Location curLoc, - OpOperand *curOperand, - Value &nextPtr) { - size_t operandNum = curOperand->getOperandNumber(); - FatPtr fatPtr = pointers[curOperand->get()]; - Value offset = fatPtr.offset; - Value basePtr = fatPtr.basePtr; - branchOp.getDestOperandsMutable().append({basePtr, fatPtr.offset}); - Value tensorPtr = - createTensorPointer(rewriter, fatPtr.basePtr, fatPtr.offset, curLoc, - fatPtr.canNarrow, fatPtr.attributes); - branchOp->setOperand(operandNum, tensorPtr); - Block *dest = branchOp.getDest(); - - // Walk the destination block only if you don't have visited it yet - if (!pointers.contains(dest->getArgument(operandNum))) { - Value basePtrArg = dest->addArgument(basePtr.getType(), curLoc); - Value offsetArg = dest->addArgument(offset.getType(), curLoc); - nextPtr = dest->getArgument(operandNum); - pointers[nextPtr] = {basePtrArg, offsetArg, fatPtr.canNarrow}; - } - return success(); -} - -// Start from an argument of a function and propagate its -// fat pointers -LogicalResult PointerCanonicalizer::rewritePointer(Value argPtr) { - // Start the visit - for (OpOperand &use : argPtr.getUses()) - queue.push_back(&use); - - while (!queue.empty()) { - OpOperand *curOperand = queue.pop_back_val(); - Operation *curOp = curOperand->getOwner(); - Location curLoc = curOp->getLoc(); - - rewriter.setInsertionPoint(curOp); - LogicalResult res = success(); - Value nextPtr; - // We need to propagate the fat pointer throughout the IR - llvm::TypeSwitch(curOp) - .Case([&](auto splatOp) { - res = rewriteSplatOp(splatOp, curLoc, nextPtr); - }) - .Case([&](auto broadcastOp) { - res = 
rewriteBroadcastOp(broadcastOp, curLoc, nextPtr); - }) - .Case([&](auto addPtrOp) { - res = rewriteAddPtrOp(addPtrOp, curLoc, nextPtr); - }) - .Case([&](auto forOp) { - res = rewriteForOp(resolveOp(forOp, rewriteOpMap), curLoc, - curOperand, nextPtr); - }) - .Case([&](auto yieldOp) { - res = rewriteYieldOp(yieldOp, curLoc, curOperand, nextPtr); - }) - .Case([&](auto whileOp) { - res = rewriteWhileOp(resolveOp(whileOp, rewriteOpMap), - curLoc, curOperand, nextPtr); - }) - .Case([&](auto conditionOp) { - res = rewriteConditionOp(conditionOp, curLoc, curOperand, nextPtr); - }) - .Case([&](auto condBrOp) { - res = rewriteCondBranchOp(condBrOp, curLoc, curOperand, nextPtr); - }) - .Case([&](auto selectOp) { - res = rewriteSelectOp(selectOp, curLoc, curOperand, nextPtr); - }) - .Case([&](auto branchOp) { - res = rewriteBranchOp(branchOp, curLoc, curOperand, nextPtr); - }) - .Case([&](Operation *op) { - res = materializeFatPointer(curOp, curLoc, op->getOperand(0)); - }) - .Default([&](Operation *op) { - // If we meet an unsupported operation, materialize the fat pointer - // and continue. - LDBG("Unknown op during pointer canonicalization: " << *curOp); - res = materializeFatPointer(op, curLoc, curOperand->get()); - }); - - // Collect the attributes and Keep propagating the fat pointer down the IR - if (nextPtr) { - collectFatPointerAttributes(curOp, nextPtr); - for (OpOperand &use : nextPtr.getUses()) - if (!opToDelete.contains(use.getOwner())) - queue.push_back(&use); - } - } - return success(); -} - -LogicalResult PointerCanonicalizer::rewriteFunction(triton::FuncOp funcOp) { - Region ®ion = funcOp.getRegion(); - for (auto [idx, arg] : llvm::enumerate(region.getArguments())) { - // The pointer argument needs to be a scalar - if (!isa(arg.getType())) - continue; - int64_t bitness = 64; - if (IntegerAttr pointerRangeAttr = - funcOp.getArgAttrOfType(idx, "tt.pointer_range")) - bitness = pointerRangeAttr.getInt(); - - rewriter.setInsertionPointToStart(®ion.front()); - Value zeroOffset = - rewriter.create(region.getLoc(), 0, bitness); - - // Start the rewrite - clearFunctionState(); - pointers[arg] = FatPtr{arg, zeroOffset, true}; - if (failed(rewritePointer(arg))) - return failure(); - - // Clean-up: don't assume the operation to delete are in the correct order, - // but force dropping the reference of the ops before we delete them - for (Operation *op : opToDelete) { - op->dropAllReferences(); - op->dropAllDefinedValueUses(); - rewriter.eraseOp(op); - } - } - return success(); -} - -LogicalResult PointerCanonicalizer::run() { - llvm::SmallVector funcOps; - - // For now we don't cross function boundaries, but we should do that whenever - // is possible - mod.walk([&](triton::FuncOp funcOp) { funcOps.push_back(funcOp); }); - - for (triton::FuncOp funcOp : funcOps) { - if (failed(rewriteFunction(funcOp))) - return failure(); - } - return success(); -} -// This pass is calling the pointer canonicalization utility -// on the given MLIR module class TritonAMDGPUCanonicalizePointersPass : public TritonAMDGPUCanonicalizePointersBase< TritonAMDGPUCanonicalizePointersPass> { @@ -1047,23 +315,34 @@ class TritonAMDGPUCanonicalizePointersPass TritonAMDGPUCanonicalizePointersPass() = default; void runOnOperation() override; - void runOnOperationmine(); }; struct FatPointers { - struct FatPtr { + struct FatPtrAttrs { + FatPtrAttrs(const FatPtrAttrs &other) = default; + FatPtrAttrs &operator=(const FatPtrAttrs &other) = default; + // for map default insert + FatPtrAttrs() = default; bool canNarrow = false; 
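+    // Optional named attributes associated with this (base, offset) pair.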
llvm::SmallDenseMap attributes; + friend bool operator==(const FatPtrAttrs &lhs, const FatPtrAttrs &rhs) { + return lhs.canNarrow == rhs.canNarrow && lhs.attributes == rhs.attributes; + } + friend bool operator!=(const FatPtrAttrs &lhs, const FatPtrAttrs &rhs) { + return !(lhs == rhs); + } }; using KeyT = std::pair; - using ValueT = FatPtr; + using ValueT = FatPtrAttrs; using DenseMapT = DenseMap; - DenseMapT pointers; ValueT &operator[](const KeyT &k) { return pointers[k]; } ValueT &operator[](KeyT &&k) { return pointers[k]; } template using const_arg_type_t = typename llvm::const_pointer_or_const_ref::type; const ValueT &at(const_arg_type_t k) const { return pointers.at(k); } + +private: + DenseMapT pointers; }; std::optional getFatPtrCastOp(Value base, @@ -1125,69 +404,169 @@ static Value getSingleValue(ValueRange values) { } template -struct PointerCanonPattern : OpConversionPattern { - PointerCanonPattern(MLIRContext *context, FatPointers &fatPtrs, - PatternBenefit benefit = 1) +struct PointerCanonicalizationPattern : OpConversionPattern { + PointerCanonicalizationPattern(MLIRContext *context, FatPointers &fatPtrs, + PatternBenefit benefit = 1) : OpConversionPattern(context, benefit), fatPtrs(fatPtrs) {} FatPointers &fatPtrs; }; -class ConvertAddPtrOp : public PointerCanonPattern { +static void setLegalAttr(ConversionPatternRewriter &rewriter, + Operation *newOp) { + rewriter.modifyOpInPlace(newOp, [&] { + newOp->setDiscardableAttr(kLegalAttr, rewriter.getUnitAttr()); + }); +} + +static void setRewrittenAttr(ConversionPatternRewriter &rewriter, + Operation *origOp) { + rewriter.modifyOpInPlace(origOp, [&] { + origOp->setDiscardableAttr(kRewrittenAttr, rewriter.getUnitAttr()); + }); +} + +static void setRewrittenLegalAttrs(ConversionPatternRewriter &rewriter, + Operation *origOp, Operation *newOp) { + setRewrittenAttr(rewriter, origOp); + setLegalAttr(rewriter, newOp); +} + +/// splat integer offset, keep base +class ConvertSplatOp : public PointerCanonicalizationPattern { +public: + using PointerCanonicalizationPattern::PointerCanonicalizationPattern; + + LogicalResult + matchAndRewrite(triton::SplatOp splatOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ValueRange remappedOperands = adaptor.getSrc(); + if (remappedOperands.size() != 2) + return rewriter.notifyMatchFailure( + splatOp, "expected SplatOp src to have already been remapped"); + Value fatPtrBase = remappedOperands[0]; + Value fatPtrOffset = remappedOperands[1]; + if (!llvm::isa(fatPtrBase.getType())) + return rewriter.notifyMatchFailure(splatOp, + "non tt.ptr base unimplemented"); + if (!llvm::isa(fatPtrOffset.getType())) + return rewriter.notifyMatchFailure(splatOp, + "non-integer offset unimplemented"); + + RankedTensorType outType = splatOp.getResult().getType(); + auto newOffsetType = RankedTensorType::get( + outType.getShape(), fatPtrOffset.getType(), outType.getEncoding()); + triton::SplatOp offset = rewriter.create( + splatOp.getLoc(), newOffsetType, fatPtrOffset); + setRewrittenLegalAttrs(rewriter, splatOp, offset); + rewriter.replaceOpWithMultiple(splatOp, {{fatPtrBase, offset}}); + fatPtrs[{fatPtrBase, offset}] = fatPtrs[{fatPtrBase, fatPtrOffset}]; + return success(); + } +}; + +/// Broadcast offset, keep base. 
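+/// The scalar base pointer passes through unchanged; only the tensor offset is
+/// re-broadcast to the result shape, mirroring ConvertSplatOp above.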
+class ConvertBroadcastOp + : public PointerCanonicalizationPattern { +public: + using PointerCanonicalizationPattern::PointerCanonicalizationPattern; + + LogicalResult + matchAndRewrite(triton::BroadcastOp broadcastOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ValueRange remappedOperands = adaptor.getSrc(); + if (remappedOperands.size() != 2) + return rewriter.notifyMatchFailure( + broadcastOp, + "expected BroadcastOp src to have already been remapped"); + + Value fatPtrBase = remappedOperands[0]; + Value fatPtrOffset = remappedOperands[1]; + if (!llvm::isa(fatPtrBase.getType())) + return rewriter.notifyMatchFailure(broadcastOp, + "non tt.ptr base unimplemented"); + auto offsetType = dyn_cast(fatPtrOffset.getType()); + return rewriter.notifyMatchFailure(broadcastOp, + "non-tensor offset unimplemented"); + + auto outType = + dyn_cast(broadcastOp.getResult().getType()); + auto newOffsetType = RankedTensorType::get( + outType.getShape(), offsetType.getElementType(), outType.getEncoding()); + triton::BroadcastOp newOffset = rewriter.create( + broadcastOp.getLoc(), newOffsetType, fatPtrOffset); + setRewrittenLegalAttrs(rewriter, broadcastOp, newOffset); + rewriter.replaceOpWithMultiple(broadcastOp, {{fatPtrBase, newOffset}}); + fatPtrs[{fatPtrBase, newOffset}] = fatPtrs[{fatPtrBase, fatPtrOffset}]; + return success(); + } +}; + +/// Three cases: +/// 1. If it is a scalar pointer update -> bump only the base pointer; +/// 2. Constant tensor offset -> bump only the offset +/// 3. Non-constant tensor offset -> decompose parent(offset) into uniform and +/// non-uniform comop +class ConvertAddPtrOp + : public PointerCanonicalizationPattern { public: - using PointerCanonPattern::PointerCanonPattern; + using PointerCanonicalizationPattern::PointerCanonicalizationPattern; LogicalResult matchAndRewrite(triton::AddPtrOp addPtrOp, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + ValueRange remappedPtr = adaptor.getPtr(); + if (remappedPtr.size() != 2) + return rewriter.notifyMatchFailure( + addPtrOp, "expected AddPtrOp Ptr to have already been remapped"); + ValueRange nonRemappedOffset = adaptor.getOffset(); + if (nonRemappedOffset.size() != 1) + return rewriter.notifyMatchFailure( + addPtrOp, "expected AddPtrOp Offset to have not have been remapped"); + Value fatPtrBase = remappedPtr[0]; + Value fatPtrOffset = remappedPtr[1]; + Value origOffset = nonRemappedOffset[0]; + Location curLoc = addPtrOp.getLoc(); + RewriterBase::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(addPtrOp); - ArrayRef remappedOperands = adaptor.getOperands(); - assert(remappedOperands.size() == 2 && remappedOperands[0].size() == 2 && - "expected adaptor to have 2 remapped values"); - Value fatPtrBase = remappedOperands[0][0]; - Value fatPtrOffset = remappedOperands[0][1]; - Value origOffset = remappedOperands[1][0]; - Location curLoc = addPtrOp.getLoc(); - // If it is a scalar pointer update, simply bump the base pointer - if (!isa(addPtrOp.getPtr().getType())) { + if (llvm::isa(addPtrOp.getPtr().getType())) { + assert(llvm::isa(origOffset.getType()) && + "expected offset to be integer type"); auto newAddPtrOp = rewriter.create( - curLoc, TypeRange{fatPtrBase.getType()}, - ValueRange{fatPtrBase, origOffset}, - llvm::ArrayRef{ - rewriter.getNamedAttr("legal", rewriter.getUnitAttr())}); - rewriter.modifyOpInPlace(addPtrOp, [&] { - addPtrOp->setDiscardableAttr("rewritten", rewriter.getUnitAttr()); - }); + curLoc, fatPtrBase.getType(), fatPtrBase, 
origOffset); + setRewrittenLegalAttrs(rewriter, addPtrOp, newAddPtrOp); rewriter.replaceOpWithMultiple(addPtrOp, {{newAddPtrOp, fatPtrOffset}}); + fatPtrs[{newAddPtrOp, fatPtrOffset}].canNarrow = + fatPtrs[{fatPtrBase, fatPtrOffset}].canNarrow; return success(); } + assert(llvm::isa(addPtrOp.getPtr().getType()) && + "expected Ptr to be RankedTensorType type"); + // Early exit for the case of a constant tensor - if (Value scalarConst = getScalarConstant(rewriter, curLoc, origOffset)) { - auto newAddPtrOp = rewriter.create( - curLoc, TypeRange{fatPtrBase.getType()}, - ValueRange{fatPtrBase, scalarConst}, - llvm::ArrayRef{ - rewriter.getNamedAttr("legal", rewriter.getUnitAttr())}); - rewriter.modifyOpInPlace(addPtrOp, [&] { - addPtrOp->setDiscardableAttr("rewritten", rewriter.getUnitAttr()); - }); + if (auto scalarConst = + maybeGetOrCreateScalarConstant(rewriter, curLoc, origOffset)) { + triton::AddPtrOp newAddPtrOp = rewriter.create( + curLoc, fatPtrBase.getType(), fatPtrBase, *scalarConst); + setRewrittenLegalAttrs(rewriter, addPtrOp, newAddPtrOp); rewriter.replaceOpWithMultiple(addPtrOp, {{newAddPtrOp, fatPtrOffset}}); - // If we are updating the tensor pointer with a uniform value, we can + // If we are updating the tensor pointer with a constant value, we can // propagate the attributes of the tensor pointer to the fat pointer. - fatPtrs[{newAddPtrOp.getResult(), fatPtrOffset}].attributes = - fatPtrs[{fatPtrBase, fatPtrOffset}].attributes; - fatPtrs[{newAddPtrOp.getResult(), fatPtrOffset}].canNarrow = + fatPtrs[{newAddPtrOp, fatPtrOffset}].canNarrow = fatPtrs[{fatPtrBase, fatPtrOffset}].canNarrow; return success(); } - int64_t bitness = - cast(origOffset.getType()).getElementTypeBitWidth(); + int64_t bitness = llvm::cast(origOffset.getType()) + .getElementTypeBitWidth(); auto [uniformOffset, nonUniformOffset] = - decomposeOffsetFromExpr(rewriter, curLoc, origOffset, bitness); + createDecomposeOffsetFromExpr(rewriter, curLoc, origOffset, bitness); + + auto newAddPtrOp = rewriter.create( + curLoc, fatPtrBase.getType(), fatPtrBase, uniformOffset); // Vector offset update (if any): bump the tensor offset bool canNarrow = fatPtrs[{fatPtrBase, fatPtrOffset}].canNarrow; @@ -1196,8 +575,8 @@ class ConvertAddPtrOp : public PointerCanonPattern { if (!isZeroConst(nonUniformOffset)) { Type addPtrOffsetType = getElementTypeOrSelf(nonUniformOffset); Type fatPtrOffsetType = getElementTypeOrSelf(fatPtrOffset); + // TODO(max): why is this inside this condition? 
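+      // Note: when the non-uniform component is zero the existing offset is
+      // kept as-is, so its narrowability cannot change; presumably that is why
+      // canNarrowOffset is only re-checked on this path.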
canNarrow = canNarrow && canNarrowOffset(fatPtrOffset, nonUniformOffset); - // Upcast or downcast the offset accordingly if (addPtrOffsetType.isInteger(32) && fatPtrOffsetType.isInteger(64)) nonUniformOffset = @@ -1211,14 +590,7 @@ class ConvertAddPtrOp : public PointerCanonPattern { propagateAtrs = false; } - // Scalar pointer update: bump the scalar pointer - auto newAddPtrOp = rewriter.create( - curLoc, TypeRange{fatPtrBase.getType()}, - ValueRange{fatPtrBase, uniformOffset}, - llvm::ArrayRef{rewriter.getNamedAttr("legal", rewriter.getUnitAttr())}); - rewriter.modifyOpInPlace(addPtrOp, [&] { - addPtrOp->setDiscardableAttr("rewritten", rewriter.getUnitAttr()); - }); + setRewrittenLegalAttrs(rewriter, addPtrOp, newAddPtrOp); rewriter.replaceOpWithMultiple(addPtrOp, {{newAddPtrOp, newOffset}}); auto nextFatPtr = std::pair{newAddPtrOp.getResult(), newOffset}; fatPtrs[nextFatPtr].canNarrow = canNarrow; @@ -1230,350 +602,122 @@ class ConvertAddPtrOp : public PointerCanonPattern { } }; -class ConvertSplatOp : public PointerCanonPattern { -public: - using PointerCanonPattern::PointerCanonPattern; +/// Rewrite init args and result type and bb args. +class ConvertSCFForOp : public PointerCanonicalizationPattern { + using PointerCanonicalizationPattern::PointerCanonicalizationPattern; +public: LogicalResult - matchAndRewrite(triton::SplatOp splatOp, OneToNOpAdaptor adaptor, + matchAndRewrite(scf::ForOp forOp, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - ArrayRef remappedOperands = adaptor.getOperands(); - // see - // https://github.com/llvm/llvm-project/blob/58389b220a9354ed6c34bdb9310a35165579c5e3/mlir/lib/Transforms/Utils/DialectConversion.cpp#L1177 - assert(remappedOperands.size() == 1 && remappedOperands[0].size() == 2 && - "expected adaptor to have 2 remapped values"); - Value fatPtrBase = remappedOperands[0][0]; - Value fatPtrOffset = remappedOperands[0][1]; - assert(llvm::isa(fatPtrBase.getType()) && - "expected fatPtrBase to be a tt.ptr"); - assert(llvm::isa(fatPtrOffset.getType()) && - "expected fatPtrOffset to be an integer type"); + SmallVector valRangeLens; + ArrayRef remappedInits = adaptor.getInitArgs(); + for (ValueRange remappedInit : remappedInits) + valRangeLens.push_back(remappedInit.size()); - RankedTensorType outType = splatOp.getResult().getType(); - llvm::ArrayRef ptrShape = outType.getShape(); - auto newOffsetType = RankedTensorType::get(ptrShape, fatPtrOffset.getType(), - outType.getEncoding()); - Value offset = rewriter.create( - splatOp.getLoc(), newOffsetType, fatPtrOffset); - rewriter.modifyOpInPlace(splatOp, [&] { - splatOp->setDiscardableAttr("rewritten", rewriter.getUnitAttr()); - }); - rewriter.replaceOpWithMultiple(splatOp, {{splatOp.getSrc(), offset}}); - fatPtrs[{splatOp.getSrc(), offset}].canNarrow = - fatPtrs[{fatPtrBase, fatPtrOffset}].canNarrow; - return success(); - } -}; + // rewrite the body bb args + unsigned inputNo = 0; + TypeConverter hackTypeConverter; + hackTypeConverter.addConversion( + [&inputNo, remappedInits = adaptor.getInitArgs()]( + Type inputType, SmallVectorImpl &types) { + // handle the 0th iv + if (inputNo == 0) { + types.append({inputType}); + } else { + SmallVector remappedInitTypes = + llvm::to_vector(remappedInits[inputNo - 1].getTypes()); + types.append(remappedInitTypes); + } + inputNo++; + return success(); + }); + std::optional conversion = + hackTypeConverter.convertBlockSignature(forOp.getBody()); + if (!conversion) + return failure(); + rewriter.applySignatureConversion(forOp.getBody(), 
*conversion, + &hackTypeConverter); -class ConvertBroadcastOp : public PointerCanonPattern { -public: - using PointerCanonPattern::PointerCanonPattern; + SmallVector initArgs = flattenValues(adaptor.getInitArgs()); + auto newForOp = rewriter.create( + forOp.getLoc(), getSingleValue(adaptor.getLowerBound()), + getSingleValue(adaptor.getUpperBound()), + getSingleValue(adaptor.getStep()), initArgs); - LogicalResult - matchAndRewrite(triton::BroadcastOp broadcastOp, OneToNOpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - ArrayRef remappedOperands = adaptor.getOperands(); - // see - // https://github.com/llvm/llvm-project/blob/58389b220a9354ed6c34bdb9310a35165579c5e3/mlir/lib/Transforms/Utils/DialectConversion.cpp#L1177 - assert(remappedOperands.size() == 1 && remappedOperands[0].size() == 2 && - "expected adaptor to have 2 remapped values"); - Value fatPtrBase = remappedOperands[0][0]; - Value fatPtrOffset = remappedOperands[0][1]; - assert(llvm::isa(fatPtrBase.getType()) && - "expected fatPtrBase to be a tt.ptr"); - assert(llvm::isa(fatPtrOffset.getType()) && - "expected fatPtrOffset to be an integer type"); + newForOp->setAttrs(forOp->getAttrs()); + rewriter.eraseBlock(newForOp.getBody()); + rewriter.inlineRegionBefore(forOp.getRegion(), newForOp.getRegion(), + newForOp.getRegion().end()); - auto outType = - dyn_cast(broadcastOp.getResult().getType()); - auto ptrShape = outType.getShape(); - auto offsetType = dyn_cast(fatPtrOffset.getType()); - if (!offsetType) - return failure(); + SmallVector packedRets; + for (unsigned i = 0, offset = 0; i < valRangeLens.size(); i++) { + size_t len = valRangeLens[i]; + assert(offset < newForOp->getNumResults() && + "expected offset to be within bounds of results"); + ValueRange mappedValue = newForOp->getResults().slice(offset, len); + packedRets.push_back(mappedValue); + offset += len; + } + + setRewrittenLegalAttrs(rewriter, forOp, newForOp); + rewriter.replaceOpWithMultiple(forOp, packedRets); - auto newOffsetType = RankedTensorType::get( - ptrShape, offsetType.getElementType(), outType.getEncoding()); - Value offset = rewriter.create( - broadcastOp.getLoc(), newOffsetType, fatPtrOffset); - rewriter.modifyOpInPlace(broadcastOp, [&] { - broadcastOp->setDiscardableAttr("rewritten", rewriter.getUnitAttr()); - }); - rewriter.replaceOpWithMultiple(broadcastOp, - {{broadcastOp.getSrc(), offset}}); - fatPtrs[{broadcastOp.getSrc(), offset}].canNarrow = - fatPtrs[{fatPtrBase, fatPtrOffset}].canNarrow; return success(); } }; -class ConvertLoadOp : public PointerCanonPattern { +/// Rewrite with new remapped operands but also if the scf.yield is inside of +/// scf.if (possibly) annotate the scf.if. 
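+/// The yield is mutated in place rather than replaced so that the parent
+/// scf.for/scf.if/scf.while keeps a single terminator while the 1:N remapping
+/// is still in flight.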
+class ConvertSCFYieldOp : public PointerCanonicalizationPattern { public: - using PointerCanonPattern::PointerCanonPattern; + using PointerCanonicalizationPattern::PointerCanonicalizationPattern; LogicalResult - matchAndRewrite(triton::LoadOp loadOp, OneToNOpAdaptor adaptor, + matchAndRewrite(scf::YieldOp yieldOp, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - ValueRange fatPtr = *adaptor.getOperands().begin(); - Value fatPtrBase = fatPtr.front(); - Value fatPtrOffset = fatPtr.back(); - Location curLoc = loadOp.getLoc(); + SmallVector newYieldedValues = flattenValues(adaptor.getOperands()); + // have to mutate here because otherwise scf.if, scf.for, and scf.while will + // get confused about which yield is the "correct" yield (since there will + // be two of them before the rewriter DCEs) + rewriter.modifyOpInPlace(yieldOp, [&]() { + yieldOp.getResultsMutable().clear(); + yieldOp.getResultsMutable().append(newYieldedValues); + }); - llvm::SmallDenseMap attributes{ - {rewriter.getStringAttr("legal"), rewriter.getUnitAttr()}}; - Value newPtr = createTensorPointer( - rewriter, fatPtrBase, fatPtrOffset, curLoc, - fatPtrs[{fatPtrBase, fatPtrOffset}].canNarrow, attributes); - SmallVector operands = - loadOp.getOperands().take_back(loadOp.getNumOperands() - 1); - operands.insert(operands.begin(), newPtr); - SmallVector attrs = llvm::to_vector(loadOp->getAttrs()); - attrs.append({rewriter.getNamedAttr("legal", rewriter.getUnitAttr())}); - auto newLoadPtrOp = - rewriter.replaceOpWithNewOp(loadOp, operands, attrs); + // rewriting a parent op from a child op isn't a great idea but there's no + // other to indicate to the parent IfOp that the result type can now be + // rewritten and not before. + if (auto ifOp = dyn_cast(yieldOp->getParentOp())) { + if (ifOp.thenBlock() == yieldOp->getBlock()) + rewriter.modifyOpInPlace(ifOp, [&] { + ifOp->setDiscardableAttr(kSCFThenRewrittenAttr, + rewriter.getUnitAttr()); + }); + else + rewriter.modifyOpInPlace(ifOp, [&] { + ifOp->setDiscardableAttr(kSCFElseRewrittenAttr, + rewriter.getUnitAttr()); + }); + } + + setLegalAttr(rewriter, yieldOp); return success(); } }; -class ConvertStoreOp : public PointerCanonPattern { +/// Rewrite init_args, result type, before region bb args, after region bb args. 
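+/// Both regions are converted with the same 1:N expansion of the remapped init
+/// types, keeping the before-region args, after-region args, and while results
+/// consistent with each other.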
+class ConvertSCFWhileOp : public PointerCanonicalizationPattern { public: - using PointerCanonPattern::PointerCanonPattern; - + using PointerCanonicalizationPattern::PointerCanonicalizationPattern; LogicalResult - matchAndRewrite(triton::StoreOp storeOp, OneToNOpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - ValueRange fatPtr = *adaptor.getOperands().begin(); - Value fatPtrBase = fatPtr.front(); - Value fatPtrOffset = fatPtr.back(); - Location curLoc = storeOp.getLoc(); - - llvm::SmallDenseMap attributes{ - {rewriter.getStringAttr("legal"), rewriter.getUnitAttr()}}; - Value newPtr = createTensorPointer( - rewriter, fatPtrBase, fatPtrOffset, curLoc, - fatPtrs[{fatPtrBase, fatPtrOffset}].canNarrow, attributes); - SmallVector operands = - storeOp.getOperands().take_back(storeOp.getNumOperands() - 1); - operands.insert(operands.begin(), newPtr); - SmallVector attrs = llvm::to_vector(storeOp->getAttrs()); - attrs.append({rewriter.getNamedAttr("legal", rewriter.getUnitAttr())}); - auto newStoreOp = rewriter.replaceOpWithNewOp( - storeOp, TypeRange{}, ValueRange{operands}, attrs); - return success(); - } -}; - -class ConvertFuncOp : public PointerCanonPattern { -public: - using PointerCanonPattern::PointerCanonPattern; - - LogicalResult - matchAndRewrite(triton::FuncOp funcOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - int64_t bitness = 64; - rewriter.setInsertionPointToStart(&funcOp.getBody().front()); - rewriter.modifyOpInPlace(funcOp, [&] { - for (auto [idx, arg] : llvm::enumerate(funcOp.getArguments())) { - // The pointer argument needs to be a scalar - if (!isa(arg.getType())) - continue; - if (auto pointerRangeAttr = - funcOp.getArgAttrOfType(idx, "tt.pointer_range")) - bitness = pointerRangeAttr.getInt(); - Value zeroOffset = - rewriter.create(funcOp.getLoc(), 0, bitness); - auto dummyCast = rewriter.create( - arg.getLoc(), TypeRange{arg.getType()}, ValueRange{arg}); - rewriter.replaceUsesOfBlockArgument(arg, dummyCast.getResult(0)); - // TODO(max): why is this true? 
- fatPtrs[{arg, zeroOffset}].canNarrow = true; - rewriter.replaceOpWithMultiple(dummyCast, {{arg, zeroOffset}}); - } - funcOp->setDiscardableAttr("rewritten", rewriter.getUnitAttr()); - }); - - return success(); - } -}; - -class ConvertSCFYieldOp : public PointerCanonPattern { -public: - using PointerCanonPattern::PointerCanonPattern; - - LogicalResult - matchAndRewrite(scf::YieldOp yieldOp, OneToNOpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - SmallVector newYieldedValues = flattenValues(adaptor.getOperands()); - rewriter.modifyOpInPlace(yieldOp, [&]() { - yieldOp.getResultsMutable().clear(); - yieldOp.getResultsMutable().append(newYieldedValues); - }); - - // TODO(max): this is bad - if (auto ifOp = dyn_cast(yieldOp->getParentOp())) { - if (ifOp.thenBlock() == yieldOp->getBlock()) - rewriter.modifyOpInPlace(ifOp, [&] { - ifOp->setDiscardableAttr("then_rewritten", rewriter.getUnitAttr()); - }); - else - rewriter.modifyOpInPlace(ifOp, [&] { - ifOp->setDiscardableAttr("else_rewritten", rewriter.getUnitAttr()); - }); - } - - rewriter.modifyOpInPlace(yieldOp, [&] { - yieldOp->setDiscardableAttr("legal", rewriter.getUnitAttr()); - }); - - return success(); - } -}; - -class ConvertUnrealizedConversionCastOp - : public PointerCanonPattern { -public: - using PointerCanonPattern::PointerCanonPattern; - - LogicalResult - matchAndRewrite(UnrealizedConversionCastOp castOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - assert(std::distance(castOp->getUses().begin(), castOp->getUses().end()) > - 0 && - "expected at least 1 use of unrealized_cast"); - // dunno why but i get -Wdangling here... - ArrayRef remappedOperands = adaptor.getOperands(); - assert(remappedOperands.size() == 1 && remappedOperands[0].size() == 2 && - "expected adaptor to have 2 remapped values"); - Value fatPtrBase = remappedOperands[0][0]; - Value fatPtrOffset = remappedOperands[0][1]; - assert(llvm::isa(fatPtrBase.getType()) && - "expected fatPtrBase to be a tt.ptr"); - assert(llvm::isa(fatPtrOffset.getType()) && - "expected fatPtrOffset to be an integer type"); - OpFoldResult maybeScalar = getAsOpFoldResult(fatPtrOffset); - if (auto attr = llvm::dyn_cast(maybeScalar)) { - auto integerAttr = llvm::cast(attr); - if (integerAttr.getValue() == 0) { - rewriter.replaceAllUsesWith(castOp.getResult(0), fatPtrBase); - rewriter.eraseOp(castOp); - return success(); - } - } - return failure(); - } -}; - -class ConvertSCFForOp : public PointerCanonPattern { - using PointerCanonPattern::PointerCanonPattern; - -public: - LogicalResult - matchAndRewrite(scf::ForOp forOp, OneToNOpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - SmallVector valRangeLens; - ArrayRef remappedInits = adaptor.getInitArgs(); - for (ValueRange remappedInit : remappedInits) - valRangeLens.push_back(remappedInit.size()); - TypeConverter hackTypeConverter; - unsigned inputNo = 0; - hackTypeConverter.addConversion( - [&inputNo, &remappedInits = std::as_const(remappedInits)]( - Type inputType, SmallVectorImpl &types) { - // handle the 0th iv - if (inputNo == 0) { - types.append({inputType}); - } else { - SmallVector remappedInitTypes = - llvm::to_vector(remappedInits[inputNo - 1].getTypes()); - types.append(remappedInitTypes); - } - inputNo++; - return success(); - }); - if (failed( - rewriter.convertRegionTypes(&forOp.getRegion(), hackTypeConverter))) - return failure(); - SmallVector initArgs = flattenValues(adaptor.getInitArgs()); - auto newForOp = rewriter.create( - forOp.getLoc(), 
getSingleValue(adaptor.getLowerBound()), - getSingleValue(adaptor.getUpperBound()), - getSingleValue(adaptor.getStep()), initArgs); - - newForOp->setAttrs(forOp->getAttrs()); - rewriter.eraseBlock(newForOp.getBody(0)); - rewriter.inlineRegionBefore(forOp.getRegion(), newForOp.getRegion(), - newForOp.getRegion().end()); - - SmallVector packedRets; - for (unsigned i = 0, offset = 0; i < valRangeLens.size(); i++) { - size_t len = valRangeLens[i]; - assert(offset < newForOp->getNumResults() && - "expected offset to be within bounds of results"); - ValueRange mappedValue = newForOp->getResults().slice(offset, len); - packedRets.push_back(mappedValue); - offset += len; - } - - rewriter.modifyOpInPlace(forOp, [&] { - forOp->setDiscardableAttr("rewritten", rewriter.getUnitAttr()); - }); - rewriter.modifyOpInPlace(newForOp, [&] { - newForOp->setDiscardableAttr("legal", rewriter.getUnitAttr()); - }); - rewriter.replaceOpWithMultiple(forOp, packedRets); - - return success(); - } -}; - -class ConvertSCFIfOp : public PointerCanonPattern { -public: - using PointerCanonPattern::PointerCanonPattern; - // One of the two branches is responsible to rewrite the operation. The other - // branch only update the yieldOp with the right parameters - LogicalResult - matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - assert(ifOp.getNumResults() == 1 && - ifOp.thenYield().getOperandTypes().size() == 2 && - "only 1 -> 2 supported for scf::IfOp rewrite"); - bool withElseRegion = ifOp.getNumRegions() > 1; - if (withElseRegion) { - assert(ifOp.thenYield().getOperandTypes() == - ifOp.elseYield().getOperandTypes() && - "ifOp types must match in both arms"); - } - - auto newIfOp = rewriter.create( - ifOp.getLoc(), ifOp.thenYield().getOperandTypes(), ifOp.getCondition(), - withElseRegion); - rewriter.inlineBlockBefore(ifOp.thenBlock(), newIfOp.thenBlock(), - newIfOp.thenBlock()->begin()); - if (withElseRegion) - rewriter.inlineBlockBefore(ifOp.elseBlock(), newIfOp.elseBlock(), - newIfOp.elseBlock()->begin()); - - rewriter.modifyOpInPlace(ifOp, [&] { - ifOp->setDiscardableAttr("rewritten", rewriter.getUnitAttr()); - }); - rewriter.modifyOpInPlace(newIfOp, [&] { - newIfOp->setDiscardableAttr("legal", rewriter.getUnitAttr()); - }); - rewriter.replaceOpWithMultiple(ifOp, {newIfOp.getResults()}); - - return success(); - } -}; - -class ConvertSCFWhileOp : public PointerCanonPattern { -public: - using PointerCanonPattern::PointerCanonPattern; - LogicalResult - matchAndRewrite(scf::WhileOp whileOp, OneToNOpAdaptor adaptor, + matchAndRewrite(scf::WhileOp whileOp, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { SmallVector valRangeLens; ArrayRef remappedInits = adaptor.getInits(); for (ValueRange remappedInit : remappedInits) valRangeLens.push_back(remappedInit.size()); + // rewrite the "before" region (bb args) TypeConverter hackTypeConverter; unsigned inputNo = 0; hackTypeConverter.addConversion( @@ -1588,6 +732,7 @@ class ConvertSCFWhileOp : public PointerCanonPattern { if (failed(rewriter.convertRegionTypes(&whileOp.getBefore(), hackTypeConverter))) return failure(); + // rewrite the "after" region (bb args) inputNo = 0; if (failed(rewriter.convertRegionTypes(&whileOp.getAfter(), hackTypeConverter))) @@ -1615,41 +760,39 @@ class ConvertSCFWhileOp : public PointerCanonPattern { offset += len; } - rewriter.modifyOpInPlace(whileOp, [&] { - whileOp->setDiscardableAttr("rewritten", rewriter.getUnitAttr()); - }); - 
rewriter.modifyOpInPlace(newWhileOp, [&] { - newWhileOp->setDiscardableAttr("legal", rewriter.getUnitAttr()); - }); + setRewrittenLegalAttrs(rewriter, whileOp, newWhileOp); rewriter.replaceOpWithMultiple(whileOp, packedRets); return success(); } }; -class ConvertSCFConditionOp : public PointerCanonPattern { +/// Rewrite with new operands. +class ConvertSCFConditionOp + : public PointerCanonicalizationPattern { public: - using PointerCanonPattern::PointerCanonPattern; + using PointerCanonicalizationPattern::PointerCanonicalizationPattern; LogicalResult matchAndRewrite(scf::ConditionOp condOp, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { SmallVector newArgs = flattenValues(adaptor.getArgs()); + // have to mutate here because otherwise scf.while will + // get confused about which condition is the "correct" condition (since + // there will be two of them before the rewriter DCEs) rewriter.modifyOpInPlace(condOp, [&]() { condOp.getArgsMutable().clear(); condOp.getArgsMutable().append(newArgs); }); - - rewriter.modifyOpInPlace(condOp, [&] { - condOp->setDiscardableAttr("legal", rewriter.getUnitAttr()); - }); - + setLegalAttr(rewriter, condOp); return success(); } }; -class ConvertCFCondBranch : public PointerCanonPattern { +/// Rewrite operands for both true dest and false dest. +class ConvertCFCondBranch + : public PointerCanonicalizationPattern { public: - using PointerCanonPattern::PointerCanonPattern; + using PointerCanonicalizationPattern::PointerCanonicalizationPattern; LogicalResult matchAndRewrite(cf::CondBranchOp branchOp, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -1658,13 +801,12 @@ class ConvertCFCondBranch : public PointerCanonPattern { SmallVector falseOperands = flattenValues(adaptor.getFalseDestOperands()); - rewriter.modifyOpInPlace(branchOp, [&] { - branchOp->setDiscardableAttr("rewritten", rewriter.getUnitAttr()); - }); + setRewrittenAttr(rewriter, branchOp); auto newBrancOp = rewriter.replaceOpWithNewOp( branchOp, branchOp.getCondition(), branchOp.getTrueDest(), trueOperands, branchOp.getFalseDest(), falseOperands); + // convert the type signature of the true dest bb unsigned inputNo = 0; TypeConverter hackTypeConverterTrueDest; hackTypeConverterTrueDest.addConversion( @@ -1676,7 +818,6 @@ class ConvertCFCondBranch : public PointerCanonPattern { inputNo++; return success(); }); - std::optional conversion = hackTypeConverterTrueDest.convertBlockSignature(branchOp.getTrueDest()); if (!conversion) @@ -1684,6 +825,7 @@ class ConvertCFCondBranch : public PointerCanonPattern { rewriter.applySignatureConversion(branchOp.getTrueDest(), *conversion, &hackTypeConverterTrueDest); + // convert the type signature of the false dest bb inputNo = 0; TypeConverter hackTypeConverterFalseDest; hackTypeConverterFalseDest.addConversion( @@ -1702,24 +844,115 @@ class ConvertCFCondBranch : public PointerCanonPattern { return failure(); rewriter.applySignatureConversion(branchOp.getFalseDest(), *conversion, &hackTypeConverterFalseDest); - rewriter.modifyOpInPlace(newBrancOp, [&] { - newBrancOp->setDiscardableAttr("legal", rewriter.getUnitAttr()); - }); + setLegalAttr(rewriter, newBrancOp); + return success(); + } +}; + +/// Rewrite both operands. 
Note, this should only be reached after both +/// operands have already been rewritten because DialectConversion walks +/// PreOrder in order ForwardDominance order: see +/// https://github.com/llvm/llvm-project/blob/58389b220a9354ed6c34bdb9310a35165579c5e3/mlir/lib/Transforms/Utils/DialectConversion.cpp#L2702 +class ConvertArithSelectOp + : public PointerCanonicalizationPattern { +public: + using PointerCanonicalizationPattern::PointerCanonicalizationPattern; + LogicalResult + matchAndRewrite(arith::SelectOp selectOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ArrayRef remappedOperands = adaptor.getOperands(); + if (remappedOperands[1].size() != 2 || remappedOperands[2].size() != 2) + return rewriter.notifyMatchFailure( + selectOp, "expected adaptor to have had both true and false operands " + "already remapped"); + // If both have been traversed, then we can rewrite select of pointers as a + // select of base and offset + ValueRange fatPtrTrue = remappedOperands[1]; + ValueRange fatPtrFalse = remappedOperands[2]; + // Simple case of a scalar select: update the base pointer + if (!isa(selectOp.getType())) { + auto newSelectOp = rewriter.create( + selectOp.getLoc(), selectOp.getType(), + // TODO(max): why fatPtrTrue here? + selectOp.getCondition(), fatPtrTrue[0], selectOp.getFalseValue()); + setRewrittenLegalAttrs(rewriter, selectOp, newSelectOp); + rewriter.replaceOpWithMultiple(selectOp, {{newSelectOp, fatPtrTrue[1]}}); + return success(); + } + + // Rewrite to select(fatBaseT, fatBaseF) and select(fatOffsetT, fatOffsetF) + auto newBase = rewriter.create( + selectOp.getLoc(), selectOp.getCondition(), fatPtrTrue[0], + fatPtrFalse[0]); + auto newOffset = rewriter.create( + selectOp.getLoc(), selectOp.getCondition(), fatPtrTrue[1], + fatPtrFalse[1]); + + assert((fatPtrs[{fatPtrTrue[0], fatPtrTrue[1]}].canNarrow == + fatPtrs[{fatPtrFalse[0], fatPtrFalse[1]}].canNarrow) && + "expected can narrow to be the same for both fatPtrT and fatPtrF"); + + setRewrittenLegalAttrs(rewriter, selectOp, newBase); + setRewrittenLegalAttrs(rewriter, selectOp, newOffset); + rewriter.replaceOpWithMultiple(selectOp, {{newBase, newOffset}}); + fatPtrs[{newBase, newOffset}].canNarrow = + fatPtrs[{fatPtrTrue[0], fatPtrTrue[1]}].canNarrow; + + return success(); + } +}; + +/// Rewrite result type only after both arms have been visited. +/// We contrive this to happen, even though DialectConversion does a PreOrder +/// walk, by checking for two attributes in the ConversionTarget +/// ("then_rewritten", and "else_rewritten"). 
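A minimal IR-level sketch of the 1 -> 2 result rewrite this pattern performs, assuming a tensor-of-pointers result and i32 offsets (value names and types are illustrative, not taken from the patch):

    // before: a single fat-pointer result
    %p = scf.if %cond -> (tensor<1024x!tt.ptr<f32>>) {
      scf.yield %then_ptr : tensor<1024x!tt.ptr<f32>>
    } else {
      scf.yield %else_ptr : tensor<1024x!tt.ptr<f32>>
    }
    // after: ConvertSCFYieldOp has already expanded both yields to (base, offset),
    // so the new scf.if adopts the then-yield's operand types as its result types
    %base, %off = scf.if %cond -> (!tt.ptr<f32>, tensor<1024xi32>) {
      scf.yield %then_base, %then_off : !tt.ptr<f32>, tensor<1024xi32>
    } else {
      scf.yield %else_base, %else_off : !tt.ptr<f32>, tensor<1024xi32>
    }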
+class ConvertSCFIfOp : public PointerCanonicalizationPattern { +public: + using PointerCanonicalizationPattern::PointerCanonicalizationPattern; + LogicalResult + matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (ifOp.getNumResults() != 1 || + ifOp.thenYield().getOperandTypes().size() != 2) + return rewriter.notifyMatchFailure( + ifOp, "only 1 -> 2 supported for scf::IfOp rewrite"); + + bool withElseRegion = ifOp.getNumRegions() > 1; + +#ifndef NDEBUG + if (withElseRegion) { + assert(ifOp.thenYield().getOperandTypes() == + ifOp.elseYield().getOperandTypes() && + "ifOp types must match in both arms"); + } +#endif + + auto newIfOp = rewriter.create( + ifOp.getLoc(), ifOp.thenYield().getOperandTypes(), ifOp.getCondition(), + withElseRegion); + rewriter.inlineBlockBefore(ifOp.thenBlock(), newIfOp.thenBlock(), + newIfOp.thenBlock()->begin()); + if (withElseRegion) + rewriter.inlineBlockBefore(ifOp.elseBlock(), newIfOp.elseBlock(), + newIfOp.elseBlock()->begin()); + + setRewrittenLegalAttrs(rewriter, ifOp, newIfOp); + rewriter.replaceOpWithMultiple(ifOp, {newIfOp.getResults()}); + return success(); } }; -class ConvertCFBranch : public PointerCanonPattern { +/// Rewrite the non-cond operands and the signature of the dest bb. +class ConvertCFBranch : public PointerCanonicalizationPattern { public: - using PointerCanonPattern::PointerCanonPattern; + using PointerCanonicalizationPattern::PointerCanonicalizationPattern; LogicalResult matchAndRewrite(cf::BranchOp branchOp, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { SmallVector trueOperands = flattenValues(adaptor.getDestOperands()); - rewriter.modifyOpInPlace(branchOp, [&] { - branchOp->setDiscardableAttr("rewritten", rewriter.getUnitAttr()); - }); + setRewrittenAttr(rewriter, branchOp); auto newBrancOp = rewriter.replaceOpWithNewOp( branchOp, branchOp.getDest(), trueOperands); @@ -1727,7 +960,7 @@ class ConvertCFBranch : public PointerCanonPattern { TypeConverter hackTypeConverterTrueDest; hackTypeConverterTrueDest.addConversion( [&inputNo, remappedOperands = adaptor.getDestOperands()]( - Type inputType, SmallVectorImpl &types) { + Type _inputType, SmallVectorImpl &types) { SmallVector remappedInitTypes = llvm::to_vector(remappedOperands[inputNo].getTypes()); types.append(remappedInitTypes); @@ -1741,66 +974,144 @@ class ConvertCFBranch : public PointerCanonPattern { return failure(); rewriter.applySignatureConversion(branchOp.getDest(), *conversion, &hackTypeConverterTrueDest); + setLegalAttr(rewriter, newBrancOp); + return success(); + } +}; - rewriter.modifyOpInPlace(newBrancOp, [&] { - newBrancOp->setDiscardableAttr("legal", rewriter.getUnitAttr()); - }); +class ConvertLoadOp : public PointerCanonicalizationPattern { +public: + using PointerCanonicalizationPattern::PointerCanonicalizationPattern; + + LogicalResult + matchAndRewrite(triton::LoadOp loadOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ValueRange fatPtr = *adaptor.getOperands().begin(); + Value fatPtrBase = fatPtr.front(); + Value fatPtrOffset = fatPtr.back(); + Location curLoc = loadOp.getLoc(); + + llvm::SmallDenseMap attributes{ + {rewriter.getStringAttr(kLegalAttr), rewriter.getUnitAttr()}}; + Value newPtr = createTensorPointer( + rewriter, fatPtrBase, fatPtrOffset, curLoc, + fatPtrs[{fatPtrBase, fatPtrOffset}].canNarrow, attributes); + SmallVector operands = + loadOp.getOperands().take_back(loadOp.getNumOperands() - 1); + 
operands.insert(operands.begin(), newPtr); + SmallVector attrs = llvm::to_vector(loadOp->getAttrs()); + attrs.append({rewriter.getNamedAttr(kLegalAttr, rewriter.getUnitAttr())}); + auto newLoadPtrOp = + rewriter.replaceOpWithNewOp(loadOp, operands, attrs); + setLegalAttr(rewriter, newLoadPtrOp); return success(); } }; -class ConvertArithSelectOp : public PointerCanonPattern { +class ConvertStoreOp : public PointerCanonicalizationPattern { public: - using PointerCanonPattern::PointerCanonPattern; + using PointerCanonicalizationPattern::PointerCanonicalizationPattern; + LogicalResult - matchAndRewrite(arith::SelectOp selectOp, OneToNOpAdaptor adaptor, + matchAndRewrite(triton::StoreOp storeOp, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - ArrayRef remappedOperands = adaptor.getOperands(); - assert(remappedOperands.size() == 3 && remappedOperands[1].size() == 2 && - remappedOperands[2].size() == 2 && - "expected adaptor to have 3 remapped values, 1 for cond, 2 for each " - "arm"); - // If both have been traversed, then we can rewrite select of pointers as a - // select of base and offset - ValueRange fatPtrT = remappedOperands[1]; - ValueRange fatPtrF = remappedOperands[2]; - // Simple case of a scalar select: update the base pointer - if (!isa(selectOp.getType())) { - auto newSelectOp = rewriter.create( - selectOp.getLoc(), selectOp.getType(), - // TODO(max): why fatPtrTrue here? - selectOp.getCondition(), fatPtrT[0], selectOp.getFalseValue()); - rewriter.modifyOpInPlace(selectOp, [&] { - selectOp->setDiscardableAttr("rewritten", rewriter.getUnitAttr()); - }); - rewriter.modifyOpInPlace(newSelectOp, [&] { - newSelectOp->setDiscardableAttr("legal", rewriter.getUnitAttr()); - }); - rewriter.replaceOpWithMultiple(selectOp, {{newSelectOp, fatPtrT[1]}}); - return success(); - } + ValueRange fatPtr = *adaptor.getOperands().begin(); + Value fatPtrBase = fatPtr.front(); + Value fatPtrOffset = fatPtr.back(); + Location curLoc = storeOp.getLoc(); - // Rewrite `select` for base and offset - auto newBase = rewriter.create( - selectOp.getLoc(), selectOp.getCondition(), fatPtrT[0], fatPtrF[0]); - auto newOffset = rewriter.create( - selectOp.getLoc(), selectOp.getCondition(), fatPtrT[1], fatPtrF[1]); + llvm::SmallDenseMap attributes{ + {rewriter.getStringAttr(kLegalAttr), rewriter.getUnitAttr()}}; + Value newPtr = createTensorPointer( + rewriter, fatPtrBase, fatPtrOffset, curLoc, + fatPtrs[{fatPtrBase, fatPtrOffset}].canNarrow, attributes); + SmallVector operands = + storeOp.getOperands().take_back(storeOp.getNumOperands() - 1); + operands.insert(operands.begin(), newPtr); + SmallVector attrs = llvm::to_vector(storeOp->getAttrs()); + attrs.append({rewriter.getNamedAttr(kLegalAttr, rewriter.getUnitAttr())}); + auto newStoreOp = rewriter.replaceOpWithNewOp( + storeOp, TypeRange{}, ValueRange{operands}, attrs); + setLegalAttr(rewriter, newStoreOp); + return success(); + } +}; - assert((fatPtrs[{fatPtrT[0], fatPtrT[1]}].canNarrow == - fatPtrs[{fatPtrF[0], fatPtrF[1]}].canNarrow)); +/// tt.func gets rewritten differently from all of the other ops - the op itself +/// is not rewritten but all tt.ptr args are rewritten (all uses) to be +/// %1 = unrealize_cast(%arg0: tt.ptr, c0: i32) -> tt.ptr. +/// This unrealized_cast remains through out the first pass of the dialect +/// conversion and is then materialized in the second pass +/// (ConvertUnrealizedConversionCastOp). 
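A minimal sketch of what a rewritten pointer argument looks like after this pattern, assuming an f32 pointee and the default 64-bit offset width (overridden by the tt.pointer_range arg attribute when present); names are illustrative:

    // every use of %arg0 is redirected through the cast, and the cast result is
    // remapped 1 -> 2 to the pair (%arg0, %zero_off) for downstream patterns,
    // with canNarrow = true recorded for that pair
    %zero_off = arith.constant 0 : i64
    %0 = builtin.unrealized_conversion_cast %arg0 : !tt.ptr<f32> to !tt.ptr<f32>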
+class ConvertFuncOp : public PointerCanonicalizationPattern { +public: + using PointerCanonicalizationPattern::PointerCanonicalizationPattern; - rewriter.modifyOpInPlace(selectOp, [&] { - selectOp->setDiscardableAttr("rewritten", rewriter.getUnitAttr()); - }); - rewriter.modifyOpInPlace(newBase, [&] { - newBase->setDiscardableAttr("legal", rewriter.getUnitAttr()); - }); - rewriter.modifyOpInPlace(newOffset, [&] { - newOffset->setDiscardableAttr("legal", rewriter.getUnitAttr()); + LogicalResult + matchAndRewrite(triton::FuncOp funcOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + int64_t bitness = 64; + rewriter.setInsertionPointToStart(&funcOp.getBody().front()); + rewriter.modifyOpInPlace(funcOp, [&] { + for (auto [idx, arg] : llvm::enumerate(funcOp.getArguments())) { + // The pointer argument needs to be a scalar + if (!isa(arg.getType())) + continue; + if (auto pointerRangeAttr = + funcOp.getArgAttrOfType(idx, "tt.pointer_range")) + bitness = pointerRangeAttr.getInt(); + Value zeroOffset = + rewriter.create(funcOp.getLoc(), 0, bitness); + auto dummyCast = rewriter.create( + arg.getLoc(), TypeRange{arg.getType()}, ValueRange{arg}); + rewriter.replaceUsesOfBlockArgument(arg, dummyCast.getResult(0)); + // TODO(max): why is this true? + fatPtrs[{arg, zeroOffset}].canNarrow = true; + rewriter.replaceOpWithMultiple(dummyCast, {{arg, zeroOffset}}); + } }); + setRewrittenAttr(rewriter, funcOp); - rewriter.replaceOpWithMultiple(selectOp, {{newBase, newOffset}}); + return success(); + } +}; +/// Rewrite %1 = unrealize_cast(%arg0: tt.ptr, c0: i32) -> tt.ptr inserted by +/// ConvertFuncOp to be just %arg0: tt.ptr. +class ConvertUnrealizedConversionCastOp + : public PointerCanonicalizationPattern { +public: + using PointerCanonicalizationPattern::PointerCanonicalizationPattern; + + LogicalResult + matchAndRewrite(UnrealizedConversionCastOp castOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + assert(std::distance(castOp->getUses().begin(), castOp->getUses().end()) > + 0 && + "expected at least 1 use of unrealized_cast"); + // dunno why but i get -Wdangling here... 
+ ArrayRef remappedOperands = adaptor.getOperands(); + if (remappedOperands.size() != 1 || remappedOperands[0].size() != 2) + return rewriter.notifyMatchFailure( + castOp, "expected CastOp to have already been remapped"); + Value fatPtrBase = remappedOperands[0][0]; + Value fatPtrOffset = remappedOperands[0][1]; + if (!llvm::isa(fatPtrBase.getType())) + return rewriter.notifyMatchFailure(castOp, + "non tt.ptr base unimplemented"); + if (!llvm::isa(fatPtrOffset.getType())) + return rewriter.notifyMatchFailure(castOp, + "non-integer offset unimplemented"); + OpFoldResult maybeScalar = getAsOpFoldResult(fatPtrOffset); + auto integerAttr = llvm::dyn_cast(maybeScalar); + if (!integerAttr || !llvm::isa(integerAttr) || + llvm::cast(integerAttr).getValue() != 0) + return rewriter.notifyMatchFailure( + castOp, "CastOp should have been inserted by ConvertFuncOp and " + "should have constant integer offset=0"); + + rewriter.replaceAllUsesWith(castOp.getResult(0), fatPtrBase); + rewriter.eraseOp(castOp); return success(); } }; @@ -1811,7 +1122,7 @@ void TritonAMDGPUCanonicalizePointersPass::runOnOperation() { ConversionTarget target(*context); RewritePatternSet patterns(context); auto isLegal = [](Operation *op) { - if (op->hasAttr("rewritten") || op->hasAttr("legal")) + if (op->hasAttr(kRewrittenAttr) || op->hasAttr(kLegalAttr)) return true; for (OpOperand &operand : op->getOpOperands()) { if (auto arg = llvm::dyn_cast(operand.get())) { @@ -1819,22 +1130,22 @@ void TritonAMDGPUCanonicalizePointersPass::runOnOperation() { continue; return false; } - if (operand.get().getDefiningOp()->hasAttr("rewritten")) + if (operand.get().getDefiningOp()->hasAttr(kRewrittenAttr)) return false; } return true; }; target.addDynamicallyLegalDialect( [&isLegal](Operation *op) { - if (llvm::isa(op) && !op->hasAttr("rewritten")) + if (llvm::isa(op) && !op->hasAttr(kRewrittenAttr)) return false; return isLegal(op); }); target.addDynamicallyLegalDialect([&isLegal](Operation *op) { if (auto ifOp = llvm::dyn_cast(op)) - return !(ifOp->hasAttr("then_rewritten") and - ifOp->hasAttr("else_rewritten")); - if (llvm::isa(op) && !op->hasAttr("legal")) + return !(ifOp->hasAttr(kSCFThenRewrittenAttr) and + ifOp->hasAttr(kSCFElseRewrittenAttr)); + if (llvm::isa(op) && !op->hasAttr(kLegalAttr)) return false; return isLegal(op); }); @@ -1870,17 +1181,11 @@ void TritonAMDGPUCanonicalizePointersPass::runOnOperation() { return signalPassFailure(); module.walk([](Operation *op) { - op->removeDiscardableAttr("rewritten"); - op->removeDiscardableAttr("legal"); + op->removeDiscardableAttr(kRewrittenAttr); + op->removeDiscardableAttr(kLegalAttr); }); } -void TritonAMDGPUCanonicalizePointersPass::runOnOperationmine() { - ModuleOp m = getOperation(); - if (failed(PointerCanonicalizer(m).run())) - signalPassFailure(); -} - std::unique_ptr mlir::createTritonAMDGPUCanonicalizePointersPass() { return std::make_unique(); } From 55c3dd23bfbb45e28900d390664898c63a8851e2 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Tue, 17 Dec 2024 21:10:41 -0500 Subject: [PATCH 10/17] propagate canNarrow through scf.if, scf.forop, scf.while, cf.br --- .../CanonicalizePointers.cpp | 343 ++++++++++++------ 1 file changed, 240 insertions(+), 103 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp index f3902f770f55..c7965f5185dc 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp +++ 
b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp @@ -263,12 +263,33 @@ std::pair createDecomposeOffsetFromExpr(RewriterBase &rewriter, return offsets; } -static const char kLegalAttr[] = "__amd-pointer-canonicalize-legal__"; -static const char kRewrittenAttr[] = "__amd-pointer-canonicalize-rewritten__"; -static const char kSCFThenRewrittenAttr[] = - "__amd-pointer-canonicalize-scf-then-rewritten__"; -static const char kSCFElseRewrittenAttr[] = - "__amd-pointer-canonicalize-scf-else-rewritten__"; +static const std::string kPtrCanonPrefix = "__amdpointercanonicalize."; +static const std::string kLegalAttr = kPtrCanonPrefix + "legal__"; +static const std::string kRewrittenAttr = kPtrCanonPrefix + "rewritten__"; +static const std::string kSCFThenRewrittenAttr = + kPtrCanonPrefix + "scf-then-rewritten__"; +static const std::string kSCFElseRewrittenAttr = + kPtrCanonPrefix + "scf-else-rewritten__"; +static const std::string kSCFIfOpYieldFatPtrOffsets = + kPtrCanonPrefix + "scf-if-yield-fatptr-offsets__"; + +static void setLegalAttr(RewriterBase &rewriter, Operation *newOp) { + rewriter.modifyOpInPlace(newOp, [&] { + newOp->setDiscardableAttr(kLegalAttr, rewriter.getUnitAttr()); + }); +} + +static void setRewrittenAttr(RewriterBase &rewriter, Operation *origOp) { + rewriter.modifyOpInPlace(origOp, [&] { + origOp->setDiscardableAttr(kRewrittenAttr, rewriter.getUnitAttr()); + }); +} + +static void setRewrittenLegalAttrs(RewriterBase &rewriter, Operation *origOp, + Operation *newOp) { + setRewrittenAttr(rewriter, origOp); + setLegalAttr(rewriter, newOp); +} Value createTensorPointer( RewriterBase &rewriter, Value basePtr, Value offset, Location loc, @@ -296,12 +317,12 @@ Value createTensorPointer( if (canNarrow) offset = createNarrow64bitOffsetTo32bits(rewriter, loc, offset); - Value tensorPtr = rewriter.create( - loc, TypeRange{tensorPtrType}, ValueRange{basePtr}, - SmallVector{rewriter.getNamedAttr(kLegalAttr, rewriter.getUnitAttr())}); - - auto addPtrOp = + triton::SplatOp tensorPtr = + rewriter.create(loc, tensorPtrType, basePtr); + setLegalAttr(rewriter, tensorPtr); + triton::AddPtrOp addPtrOp = rewriter.create(loc, tensorPtrType, tensorPtr, offset); + setLegalAttr(rewriter, addPtrOp); for (auto attribute : attributes) addPtrOp->setAttr(attribute.getFirst(), attribute.getSecond()); @@ -340,6 +361,7 @@ struct FatPointers { template using const_arg_type_t = typename llvm::const_pointer_or_const_ref::type; const ValueT &at(const_arg_type_t k) const { return pointers.at(k); } + const bool contains(const KeyT &k) { return pointers.contains(k); } private: DenseMapT pointers; @@ -411,26 +433,6 @@ struct PointerCanonicalizationPattern : OpConversionPattern { FatPointers &fatPtrs; }; -static void setLegalAttr(ConversionPatternRewriter &rewriter, - Operation *newOp) { - rewriter.modifyOpInPlace(newOp, [&] { - newOp->setDiscardableAttr(kLegalAttr, rewriter.getUnitAttr()); - }); -} - -static void setRewrittenAttr(ConversionPatternRewriter &rewriter, - Operation *origOp) { - rewriter.modifyOpInPlace(origOp, [&] { - origOp->setDiscardableAttr(kRewrittenAttr, rewriter.getUnitAttr()); - }); -} - -static void setRewrittenLegalAttrs(ConversionPatternRewriter &rewriter, - Operation *origOp, Operation *newOp) { - setRewrittenAttr(rewriter, origOp); - setLegalAttr(rewriter, newOp); -} - /// splat integer offset, keep base class ConvertSplatOp : public PointerCanonicalizationPattern { public: @@ -617,8 +619,8 @@ class ConvertSCFForOp : public PointerCanonicalizationPattern { // rewrite the body 
bb args unsigned inputNo = 0; - TypeConverter hackTypeConverter; - hackTypeConverter.addConversion( + TypeConverter localTypeConverter; + localTypeConverter.addConversion( [&inputNo, remappedInits = adaptor.getInitArgs()]( Type inputType, SmallVectorImpl &types) { // handle the 0th iv @@ -633,11 +635,25 @@ class ConvertSCFForOp : public PointerCanonicalizationPattern { return success(); }); std::optional conversion = - hackTypeConverter.convertBlockSignature(forOp.getBody()); + localTypeConverter.convertBlockSignature(forOp.getBody()); if (!conversion) return failure(); - rewriter.applySignatureConversion(forOp.getBody(), *conversion, - &hackTypeConverter); + auto newBodyBlock = rewriter.applySignatureConversion( + forOp.getBody(), *conversion, &localTypeConverter); + + // propagate canNarrow to bb arg fatPtrs in for body bb + // skip iv at index 0 + int offset = 1; + for (auto operands : remappedInits) { + if (operands.size() == 2) { + assert(fatPtrs.contains({operands[0], operands[1]}) && + "expected fatPtrs to contain remapped fat pointer"); + fatPtrs[{newBodyBlock->getArgument(offset), + newBodyBlock->getArgument(offset + 1)}] + .canNarrow = fatPtrs[{operands[0], operands[1]}].canNarrow; + } + offset += operands.size(); + } SmallVector initArgs = flattenValues(adaptor.getInitArgs()); auto newForOp = rewriter.create( @@ -676,7 +692,8 @@ class ConvertSCFYieldOp : public PointerCanonicalizationPattern { LogicalResult matchAndRewrite(scf::YieldOp yieldOp, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - SmallVector newYieldedValues = flattenValues(adaptor.getOperands()); + ArrayRef remappedYields = adaptor.getOperands(); + SmallVector newYieldedValues = flattenValues(remappedYields); // have to mutate here because otherwise scf.if, scf.for, and scf.while will // get confused about which yield is the "correct" yield (since there will // be two of them before the rewriter DCEs) @@ -689,16 +706,33 @@ class ConvertSCFYieldOp : public PointerCanonicalizationPattern { // other to indicate to the parent IfOp that the result type can now be // rewritten and not before. 
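        // Editor's note (cross-reference, not part of the patch): the
        // ConversionTarget in runOnOperation() keeps an scf.if "legal" until it
        // carries both kSCFThenRewrittenAttr and kSCFElseRewrittenAttr, so
        // marking the parent below is what eventually lets ConvertSCFIfOp
        // rewrite its result types.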
if (auto ifOp = dyn_cast(yieldOp->getParentOp())) { - if (ifOp.thenBlock() == yieldOp->getBlock()) + if (ifOp.thenBlock() == yieldOp->getBlock()) { rewriter.modifyOpInPlace(ifOp, [&] { ifOp->setDiscardableAttr(kSCFThenRewrittenAttr, rewriter.getUnitAttr()); }); - else + } else { rewriter.modifyOpInPlace(ifOp, [&] { ifOp->setDiscardableAttr(kSCFElseRewrittenAttr, rewriter.getUnitAttr()); }); + } + // set indices of fatPtrs so that IfOp can propagate canNarrow to + // result users + int offset = 0; + SmallVector fatPtrOffsets; + for (auto operands : remappedYields) { + if (operands.size() == 2) { + assert(fatPtrs.contains({operands[0], operands[1]}) && + "expected fatPtrs to contain remapped fat pointer"); + fatPtrOffsets.push_back(offset); + } + offset += operands.size(); + } + if (!fatPtrOffsets.empty()) + yieldOp->setDiscardableAttr( + kSCFIfOpYieldFatPtrOffsets, + rewriter.getDenseI64ArrayAttr(fatPtrOffsets)); } setLegalAttr(rewriter, yieldOp); @@ -717,26 +751,52 @@ class ConvertSCFWhileOp : public PointerCanonicalizationPattern { ArrayRef remappedInits = adaptor.getInits(); for (ValueRange remappedInit : remappedInits) valRangeLens.push_back(remappedInit.size()); - // rewrite the "before" region (bb args) - TypeConverter hackTypeConverter; + + // rewrite the "before" block bb args unsigned inputNo = 0; - hackTypeConverter.addConversion( - [&inputNo, &remappedInits = std::as_const(remappedInits)]( - Type inputType, SmallVectorImpl &types) { + TypeConverter localTypeConverter; + localTypeConverter.addConversion( + [&inputNo, remappedInits](Type _inputType, + SmallVectorImpl &types) { SmallVector remappedInitTypes = llvm::to_vector(remappedInits[inputNo].getTypes()); types.append(remappedInitTypes); inputNo++; return success(); }); - if (failed(rewriter.convertRegionTypes(&whileOp.getBefore(), - hackTypeConverter))) + std::optional conversion = + localTypeConverter.convertBlockSignature(whileOp.getBeforeBody()); + if (!conversion) return failure(); - // rewrite the "after" region (bb args) - inputNo = 0; - if (failed(rewriter.convertRegionTypes(&whileOp.getAfter(), - hackTypeConverter))) + auto newBeforeBodyBlock = rewriter.applySignatureConversion( + whileOp.getBeforeBody(), *conversion, &localTypeConverter); + + auto propagateCanNarrowToBlock = [remappedInits, this](Block *block) { + int offset = 0; + for (auto operands : remappedInits) { + if (operands.size() == 2) { + assert(fatPtrs.contains({operands[0], operands[1]}) && + "expected fatPtrs to contain remapped fat pointer"); + fatPtrs[{block->getArgument(offset), block->getArgument(offset + 1)}] + .canNarrow = fatPtrs[{operands[0], operands[1]}].canNarrow; + } + offset += operands.size(); + } + }; + + // propagate canNarrow to bb arg fatPtrs in before bb + propagateCanNarrowToBlock(newBeforeBodyBlock); + + // rewrite the "after" block bb args + conversion = + localTypeConverter.convertBlockSignature(whileOp.getAfterBody()); + if (!conversion) return failure(); + auto newAfterBodyBlock = rewriter.applySignatureConversion( + whileOp.getAfterBody(), *conversion, &localTypeConverter); + + // propagate canNarrow to bb arg fatPtrs in after bb + propagateCanNarrowToBlock(newAfterBodyBlock); SmallVector initArgs = flattenValues(remappedInits); SmallVector resultTypes = @@ -796,54 +856,76 @@ class ConvertCFCondBranch LogicalResult matchAndRewrite(cf::CondBranchOp branchOp, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - SmallVector trueOperands = - flattenValues(adaptor.getTrueDestOperands()); - SmallVector 
falseOperands = - flattenValues(adaptor.getFalseDestOperands()); + ArrayRef remappedTrueOperands = adaptor.getTrueDestOperands(); + ArrayRef remappedFalseOperands = adaptor.getFalseDestOperands(); + SmallVector trueOperands = flattenValues(remappedTrueOperands); + SmallVector falseOperands = flattenValues(remappedFalseOperands); setRewrittenAttr(rewriter, branchOp); auto newBrancOp = rewriter.replaceOpWithNewOp( branchOp, branchOp.getCondition(), branchOp.getTrueDest(), trueOperands, branchOp.getFalseDest(), falseOperands); + // can't put inputNo inside because of limited lifetime (it'll be popped + // from stack memory after lambda returns...) + auto makeTypeConv = [](unsigned &inputNo, + ArrayRef remappedOperands) { + return [&inputNo, remappedOperands](Type inputType, + SmallVectorImpl &types) { + SmallVector remappedInitTypes = + llvm::to_vector(remappedOperands[inputNo].getTypes()); + types.append(remappedInitTypes); + inputNo++; + return success(); + }; + }; + + auto propagateCanNarrowToBlock = [this](Block *block, + ArrayRef + remappedOperands) { + int offset = 0; + for (auto operands : remappedOperands) { + if (operands.size() == 2) { + assert(fatPtrs.contains({operands[0], operands[1]}) && + "expected fatPtrs to contain remapped fat pointer"); + fatPtrs[{block->getArgument(offset), block->getArgument(offset + 1)}] + .canNarrow = fatPtrs[{operands[0], operands[1]}].canNarrow; + } + offset += operands.size(); + } + }; + // convert the type signature of the true dest bb + TypeConverter localTypeConverterTrueDest; unsigned inputNo = 0; - TypeConverter hackTypeConverterTrueDest; - hackTypeConverterTrueDest.addConversion( - [&inputNo, remappedOperands = adaptor.getTrueDestOperands()]( - Type inputType, SmallVectorImpl &types) { - SmallVector remappedInitTypes = - llvm::to_vector(remappedOperands[inputNo].getTypes()); - types.append(remappedInitTypes); - inputNo++; - return success(); - }); + localTypeConverterTrueDest.addConversion( + makeTypeConv(inputNo, remappedTrueOperands)); std::optional conversion = - hackTypeConverterTrueDest.convertBlockSignature(branchOp.getTrueDest()); + localTypeConverterTrueDest.convertBlockSignature( + branchOp.getTrueDest()); if (!conversion) return failure(); - rewriter.applySignatureConversion(branchOp.getTrueDest(), *conversion, - &hackTypeConverterTrueDest); + auto newTrueBlock = rewriter.applySignatureConversion( + branchOp.getTrueDest(), *conversion, &localTypeConverterTrueDest); + + // propagate canNarrow to bb arg fatPtrs in true bb + propagateCanNarrowToBlock(newTrueBlock, remappedTrueOperands); // convert the type signature of the false dest bb inputNo = 0; - TypeConverter hackTypeConverterFalseDest; - hackTypeConverterFalseDest.addConversion( - [&inputNo, remappedOperands = adaptor.getFalseDestOperands()]( - Type inputType, SmallVectorImpl &types) { - SmallVector remappedInitTypes = - llvm::to_vector(remappedOperands[inputNo].getTypes()); - types.append(remappedInitTypes); - inputNo++; - return success(); - }); - - conversion = hackTypeConverterFalseDest.convertBlockSignature( + TypeConverter localTypeConverterFalseDest; + localTypeConverterFalseDest.addConversion( + makeTypeConv(inputNo, remappedFalseOperands)); + conversion = localTypeConverterFalseDest.convertBlockSignature( branchOp.getFalseDest()); if (!conversion) return failure(); - rewriter.applySignatureConversion(branchOp.getFalseDest(), *conversion, - &hackTypeConverterFalseDest); + auto newFalseBlock = rewriter.applySignatureConversion( + branchOp.getFalseDest(), *conversion, 
&localTypeConverterFalseDest); + + // propagate canNarrow to bb arg fatPtrs in false bb + propagateCanNarrowToBlock(newFalseBlock, remappedFalseOperands); + setLegalAttr(rewriter, newBrancOp); return success(); } @@ -912,10 +994,8 @@ class ConvertSCFIfOp : public PointerCanonicalizationPattern { LogicalResult matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (ifOp.getNumResults() != 1 || - ifOp.thenYield().getOperandTypes().size() != 2) - return rewriter.notifyMatchFailure( - ifOp, "only 1 -> 2 supported for scf::IfOp rewrite"); + assert(ifOp.thenYield()->hasAttr(kSCFIfOpYieldFatPtrOffsets) && + "expected then yield to report fat ptr indices"); bool withElseRegion = ifOp.getNumRegions() > 1; @@ -924,6 +1004,32 @@ class ConvertSCFIfOp : public PointerCanonicalizationPattern { assert(ifOp.thenYield().getOperandTypes() == ifOp.elseYield().getOperandTypes() && "ifOp types must match in both arms"); + assert(ifOp.elseYield()->hasAttr(kSCFIfOpYieldFatPtrOffsets) && + "expected then yield to report fat ptr indices"); + if (auto thenFatPtrIndxs = ifOp.thenYield()->getDiscardableAttr( + kSCFIfOpYieldFatPtrOffsets)) { + auto elseFatPtrIndx = + ifOp.elseYield()->getDiscardableAttr(kSCFIfOpYieldFatPtrOffsets); + assert(elseFatPtrIndx && + "expected else fat ptr indices as well as then fat ptr indices"); + for (auto [i, j] : llvm::zip( + llvm::cast(thenFatPtrIndxs).asArrayRef(), + llvm::cast(elseFatPtrIndx).asArrayRef())) { + assert(i == j && + "expected thenFatPtrIndxs and elseFatPtrIndxs to agree"); + assert(i < ifOp.thenYield().getNumOperands() && + i + 1 < ifOp.thenYield().getNumOperands() && + "expected idx to be within bounds of IfOp's results"); + Value thenFatPtrBase = ifOp.thenYield().getOperand(i); + Value thenFatPtrOffset = ifOp.thenYield().getOperand(i + 1); + Value elseFatPtrBase = ifOp.elseYield().getOperand(i); + Value elseFatPtrOffset = ifOp.elseYield().getOperand(i + 1); + assert((fatPtrs[{thenFatPtrBase, thenFatPtrOffset}].canNarrow == + fatPtrs[{elseFatPtrBase, elseFatPtrOffset}].canNarrow) && + "expected then fat ptr canNarrow and else fat ptr canNarrow " + "to be equal"); + } + } } #endif @@ -939,6 +1045,16 @@ class ConvertSCFIfOp : public PointerCanonicalizationPattern { setRewrittenLegalAttrs(rewriter, ifOp, newIfOp); rewriter.replaceOpWithMultiple(ifOp, {newIfOp.getResults()}); + for (int64_t idx : + llvm::cast(newIfOp.thenYield()->getDiscardableAttr( + kSCFIfOpYieldFatPtrOffsets)) + .asArrayRef()) { + Value thenFatPtrBase = newIfOp.thenYield().getOperand(idx); + Value thenFatPtrOffset = newIfOp.thenYield().getOperand(idx + 1); + fatPtrs[{newIfOp.getResult(idx), newIfOp.getResult(idx + 1)}].canNarrow = + fatPtrs[{thenFatPtrBase, thenFatPtrOffset}].canNarrow; + } + return success(); } }; @@ -950,30 +1066,43 @@ class ConvertCFBranch : public PointerCanonicalizationPattern { LogicalResult matchAndRewrite(cf::BranchOp branchOp, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - SmallVector trueOperands = flattenValues(adaptor.getDestOperands()); + ArrayRef remappedDestOperands = adaptor.getDestOperands(); + SmallVector trueOperands = flattenValues(remappedDestOperands); setRewrittenAttr(rewriter, branchOp); auto newBrancOp = rewriter.replaceOpWithNewOp( branchOp, branchOp.getDest(), trueOperands); unsigned inputNo = 0; - TypeConverter hackTypeConverterTrueDest; - hackTypeConverterTrueDest.addConversion( - [&inputNo, remappedOperands = adaptor.getDestOperands()]( - Type _inputType, 
SmallVectorImpl &types) { + TypeConverter localTypeConverterTrueDest; + localTypeConverterTrueDest.addConversion( + [&inputNo, remappedDestOperands](Type _inputType, + SmallVectorImpl &types) { SmallVector remappedInitTypes = - llvm::to_vector(remappedOperands[inputNo].getTypes()); + llvm::to_vector(remappedDestOperands[inputNo].getTypes()); types.append(remappedInitTypes); inputNo++; return success(); }); - std::optional conversion = - hackTypeConverterTrueDest.convertBlockSignature(branchOp.getDest()); + localTypeConverterTrueDest.convertBlockSignature(branchOp.getDest()); if (!conversion) return failure(); - rewriter.applySignatureConversion(branchOp.getDest(), *conversion, - &hackTypeConverterTrueDest); + auto newDestBlock = rewriter.applySignatureConversion( + branchOp.getDest(), *conversion, &localTypeConverterTrueDest); + + int offset = 0; + for (auto operands : remappedDestOperands) { + if (operands.size() == 2) { + assert(fatPtrs.contains({operands[0], operands[1]}) && + "expected fatPtrs to contain remapped fat pointer"); + fatPtrs[{newDestBlock->getArgument(offset), + newDestBlock->getArgument(offset + 1)}] + .canNarrow = fatPtrs[{operands[0], operands[1]}].canNarrow; + } + offset += operands.size(); + } + setLegalAttr(rewriter, newBrancOp); return success(); } @@ -986,9 +1115,12 @@ class ConvertLoadOp : public PointerCanonicalizationPattern { LogicalResult matchAndRewrite(triton::LoadOp loadOp, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - ValueRange fatPtr = *adaptor.getOperands().begin(); - Value fatPtrBase = fatPtr.front(); - Value fatPtrOffset = fatPtr.back(); + ValueRange fatPtr = adaptor.getPtr(); + if (fatPtr.size() != 2) + return rewriter.notifyMatchFailure( + loadOp, "expected LoadOp ptr to have already been remapped"); + Value fatPtrBase = fatPtr[0]; + Value fatPtrOffset = fatPtr[1]; Location curLoc = loadOp.getLoc(); llvm::SmallDenseMap attributes{ @@ -1015,9 +1147,12 @@ class ConvertStoreOp : public PointerCanonicalizationPattern { LogicalResult matchAndRewrite(triton::StoreOp storeOp, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - ValueRange fatPtr = *adaptor.getOperands().begin(); - Value fatPtrBase = fatPtr.front(); - Value fatPtrOffset = fatPtr.back(); + ValueRange fatPtr = adaptor.getPtr(); + if (fatPtr.size() != 2) + return rewriter.notifyMatchFailure( + storeOp, "expected StoreOp ptr to have already been remapped"); + Value fatPtrBase = fatPtr[0]; + Value fatPtrOffset = fatPtr[1]; Location curLoc = storeOp.getLoc(); llvm::SmallDenseMap attributes{ @@ -1181,8 +1316,10 @@ void TritonAMDGPUCanonicalizePointersPass::runOnOperation() { return signalPassFailure(); module.walk([](Operation *op) { - op->removeDiscardableAttr(kRewrittenAttr); - op->removeDiscardableAttr(kLegalAttr); + for (auto attr : op->getDiscardableAttrs()) { + if (attr.getName().strref().starts_with(kPtrCanonPrefix)) + op->removeDiscardableAttr(attr.getName()); + } }); } From 634cff2997dc6cdffa1c6594ec843cbf1eb05868 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Tue, 17 Dec 2024 22:49:53 -0500 Subject: [PATCH 11/17] last test case --- .../CanonicalizePointers.cpp | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp index c7965f5185dc..12100ff888b6 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp +++ 
b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp @@ -1125,9 +1125,11 @@ class ConvertLoadOp : public PointerCanonicalizationPattern { llvm::SmallDenseMap attributes{ {rewriter.getStringAttr(kLegalAttr), rewriter.getUnitAttr()}}; - Value newPtr = createTensorPointer( - rewriter, fatPtrBase, fatPtrOffset, curLoc, - fatPtrs[{fatPtrBase, fatPtrOffset}].canNarrow, attributes); + Value newPtr = fatPtrBase; + if (llvm::isa(loadOp.getPtr().getType())) + newPtr = createTensorPointer( + rewriter, fatPtrBase, fatPtrOffset, curLoc, + fatPtrs[{fatPtrBase, fatPtrOffset}].canNarrow, attributes); SmallVector operands = loadOp.getOperands().take_back(loadOp.getNumOperands() - 1); operands.insert(operands.begin(), newPtr); @@ -1157,9 +1159,12 @@ class ConvertStoreOp : public PointerCanonicalizationPattern { llvm::SmallDenseMap attributes{ {rewriter.getStringAttr(kLegalAttr), rewriter.getUnitAttr()}}; - Value newPtr = createTensorPointer( - rewriter, fatPtrBase, fatPtrOffset, curLoc, - fatPtrs[{fatPtrBase, fatPtrOffset}].canNarrow, attributes); + + Value newPtr = fatPtrBase; + if (llvm::isa(storeOp.getPtr().getType())) + newPtr = createTensorPointer( + rewriter, fatPtrBase, fatPtrOffset, curLoc, + fatPtrs[{fatPtrBase, fatPtrOffset}].canNarrow, attributes); SmallVector operands = storeOp.getOperands().take_back(storeOp.getNumOperands() - 1); operands.insert(operands.begin(), newPtr); From e028676177561de19cdf827a5e17eeb7fd14a589 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Wed, 18 Dec 2024 14:09:07 -0500 Subject: [PATCH 12/17] use forwardslice instead of attributes to check legality --- .../CanonicalizePointers.cpp | 391 ++++++++++-------- 1 file changed, 215 insertions(+), 176 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp index 12100ff888b6..7f0ce0a1b2f7 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp @@ -19,6 +19,7 @@ #include "triton/Dialect/Triton/IR/Types.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetOperations.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/TypeSwitch.h" @@ -35,6 +36,7 @@ #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") using namespace mlir; +namespace tt = triton; // ----------------------------------------------------------------------------- // Pointer canonicalizer utility class @@ -119,7 +121,7 @@ std::optional maybeGetOrCreateScalarConstant(RewriterBase &rewriter, Operation *op = expr.getDefiningOp(); // Check for splatness - if (auto splatOp = dyn_cast_or_null(op)) + if (auto splatOp = dyn_cast_or_null(op)) return splatOp.getSrc(); // Check for constant @@ -147,7 +149,7 @@ std::optional maybeGetOrCreateScalarConstant(RewriterBase &rewriter, // TODO(max): is this correct? 
bool canNarrowOffset(Value baseOffset, Value addOffset) { Type addOffsetType = getElementTypeOrSelf(addOffset); - auto baseSplatOp = baseOffset.getDefiningOp(); + auto baseSplatOp = baseOffset.getDefiningOp(); return baseSplatOp && addOffsetType.isInteger(32); } @@ -193,9 +195,9 @@ std::pair createDecomposeOffsetFromMul(RewriterBase &rewriter, Value uniformMul = rewriter.create(loc, uniformOffsetL, uniformOffsetR); - Value uniformOffsetLSplat = rewriter.create( + Value uniformOffsetLSplat = rewriter.create( loc, nonUniformOffsetL.getType(), uniformOffsetL); - Value uniformOffsetRSplat = rewriter.create( + Value uniformOffsetRSplat = rewriter.create( loc, nonUniformOffsetR.getType(), uniformOffsetR); Value nonUNonU = @@ -232,17 +234,17 @@ std::pair createDecomposeOffsetFromExpr(RewriterBase &rewriter, auto offsets = llvm::TypeSwitch>( expr.getDefiningOp()) - .Case([&](auto broadcastOp) { + .Case([&](auto broadcastOp) { auto [uniform, nonUniform] = createDecomposeOffsetFromExpr( rewriter, loc, broadcastOp.getSrc(), bitness); - auto broadcastNonUniform = rewriter.create( + auto broadcastNonUniform = rewriter.create( loc, broadcastOp.getType(), nonUniform); return std::make_pair(uniform, broadcastNonUniform); }) - .Case([&](auto expandOp) { + .Case([&](auto expandOp) { auto [uniform, nonUniform] = createDecomposeOffsetFromExpr( rewriter, loc, expandOp.getSrc(), bitness); - auto expandNonUniform = rewriter.create( + auto expandNonUniform = rewriter.create( loc, nonUniform, expandOp.getAxis()); return std::make_pair(uniform, expandNonUniform); }) @@ -264,8 +266,6 @@ std::pair createDecomposeOffsetFromExpr(RewriterBase &rewriter, } static const std::string kPtrCanonPrefix = "__amdpointercanonicalize."; -static const std::string kLegalAttr = kPtrCanonPrefix + "legal__"; -static const std::string kRewrittenAttr = kPtrCanonPrefix + "rewritten__"; static const std::string kSCFThenRewrittenAttr = kPtrCanonPrefix + "scf-then-rewritten__"; static const std::string kSCFElseRewrittenAttr = @@ -273,24 +273,6 @@ static const std::string kSCFElseRewrittenAttr = static const std::string kSCFIfOpYieldFatPtrOffsets = kPtrCanonPrefix + "scf-if-yield-fatptr-offsets__"; -static void setLegalAttr(RewriterBase &rewriter, Operation *newOp) { - rewriter.modifyOpInPlace(newOp, [&] { - newOp->setDiscardableAttr(kLegalAttr, rewriter.getUnitAttr()); - }); -} - -static void setRewrittenAttr(RewriterBase &rewriter, Operation *origOp) { - rewriter.modifyOpInPlace(origOp, [&] { - origOp->setDiscardableAttr(kRewrittenAttr, rewriter.getUnitAttr()); - }); -} - -static void setRewrittenLegalAttrs(RewriterBase &rewriter, Operation *origOp, - Operation *newOp) { - setRewrittenAttr(rewriter, origOp); - setLegalAttr(rewriter, newOp); -} - Value createTensorPointer( RewriterBase &rewriter, Value basePtr, Value offset, Location loc, bool canNarrow, @@ -299,8 +281,8 @@ Value createTensorPointer( // Scalar case: we only need to `tt.addptr %basePtr, %offset` if (!tensorType) { - auto addPtrOp = rewriter.create(loc, basePtr.getType(), - basePtr, offset); + auto addPtrOp = + rewriter.create(loc, basePtr.getType(), basePtr, offset); for (auto attribute : attributes) addPtrOp->setAttr(attribute.getFirst(), attribute.getSecond()); return addPtrOp.getResult(); @@ -317,12 +299,10 @@ Value createTensorPointer( if (canNarrow) offset = createNarrow64bitOffsetTo32bits(rewriter, loc, offset); - triton::SplatOp tensorPtr = - rewriter.create(loc, tensorPtrType, basePtr); - setLegalAttr(rewriter, tensorPtr); - triton::AddPtrOp addPtrOp = - 
rewriter.create(loc, tensorPtrType, tensorPtr, offset); - setLegalAttr(rewriter, addPtrOp); + tt::SplatOp tensorPtr = + rewriter.create(loc, tensorPtrType, basePtr); + tt::AddPtrOp addPtrOp = + rewriter.create(loc, tensorPtrType, tensorPtr, offset); for (auto attribute : attributes) addPtrOp->setAttr(attribute.getFirst(), attribute.getSecond()); @@ -427,27 +407,49 @@ static Value getSingleValue(ValueRange values) { template struct PointerCanonicalizationPattern : OpConversionPattern { - PointerCanonicalizationPattern(MLIRContext *context, FatPointers &fatPtrs, + PointerCanonicalizationPattern(MLIRContext *context, + llvm::SetVector &opsToRewrite, + FatPointers &fatPtrs, PatternBenefit benefit = 1) - : OpConversionPattern(context, benefit), fatPtrs(fatPtrs) {} + : OpConversionPattern(context, benefit), fatPtrs(fatPtrs), + opToRewrite(opsToRewrite) {} + + virtual LogicalResult matchAndRewrite_( + SourceOp op, + typename OpConversionPattern::OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + llvm_unreachable("must override matchAndRewrite_"); + } + + LogicalResult matchAndRewrite( + SourceOp op, + typename OpConversionPattern::OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (failed(matchAndRewrite_(op, adaptor, rewriter))) + return failure(); + opToRewrite.remove(op); + return success(); + } + FatPointers &fatPtrs; + llvm::SetVector &opToRewrite; }; /// splat integer offset, keep base -class ConvertSplatOp : public PointerCanonicalizationPattern { +class ConvertSplatOp : public PointerCanonicalizationPattern { public: using PointerCanonicalizationPattern::PointerCanonicalizationPattern; LogicalResult - matchAndRewrite(triton::SplatOp splatOp, OneToNOpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + matchAndRewrite_(tt::SplatOp splatOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { ValueRange remappedOperands = adaptor.getSrc(); if (remappedOperands.size() != 2) return rewriter.notifyMatchFailure( splatOp, "expected SplatOp src to have already been remapped"); Value fatPtrBase = remappedOperands[0]; Value fatPtrOffset = remappedOperands[1]; - if (!llvm::isa(fatPtrBase.getType())) + if (!llvm::isa(fatPtrBase.getType())) return rewriter.notifyMatchFailure(splatOp, "non tt.ptr base unimplemented"); if (!llvm::isa(fatPtrOffset.getType())) @@ -457,24 +459,25 @@ class ConvertSplatOp : public PointerCanonicalizationPattern { RankedTensorType outType = splatOp.getResult().getType(); auto newOffsetType = RankedTensorType::get( outType.getShape(), fatPtrOffset.getType(), outType.getEncoding()); - triton::SplatOp offset = rewriter.create( + tt::SplatOp offset = rewriter.create( splatOp.getLoc(), newOffsetType, fatPtrOffset); - setRewrittenLegalAttrs(rewriter, splatOp, offset); rewriter.replaceOpWithMultiple(splatOp, {{fatPtrBase, offset}}); + opToRewrite.remove(splatOp); fatPtrs[{fatPtrBase, offset}] = fatPtrs[{fatPtrBase, fatPtrOffset}]; + return success(); } }; /// Broadcast offset, keep base. 
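A minimal before/after sketch for this pattern, with illustrative shapes and an f32 pointee (not taken from the patch):

    // before: broadcasting a tensor of pointers backed by the fat pointer
    //   (%base : !tt.ptr<f32>, %off : tensor<64x1xi32>)
    %b = tt.broadcast %ptrs : tensor<64x1x!tt.ptr<f32>> -> tensor<64x128x!tt.ptr<f32>>
    // after: only the offset is broadcast; %base is forwarded unchanged and the
    // canNarrow flag is copied to the new (base, offset) pair
    %new_off = tt.broadcast %off : tensor<64x1xi32> -> tensor<64x128xi32>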
class ConvertBroadcastOp - : public PointerCanonicalizationPattern { + : public PointerCanonicalizationPattern { public: using PointerCanonicalizationPattern::PointerCanonicalizationPattern; LogicalResult - matchAndRewrite(triton::BroadcastOp broadcastOp, OneToNOpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + matchAndRewrite_(tt::BroadcastOp broadcastOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { ValueRange remappedOperands = adaptor.getSrc(); if (remappedOperands.size() != 2) return rewriter.notifyMatchFailure( @@ -483,7 +486,7 @@ class ConvertBroadcastOp Value fatPtrBase = remappedOperands[0]; Value fatPtrOffset = remappedOperands[1]; - if (!llvm::isa(fatPtrBase.getType())) + if (!llvm::isa(fatPtrBase.getType())) return rewriter.notifyMatchFailure(broadcastOp, "non tt.ptr base unimplemented"); auto offsetType = dyn_cast(fatPtrOffset.getType()); @@ -494,10 +497,10 @@ class ConvertBroadcastOp dyn_cast(broadcastOp.getResult().getType()); auto newOffsetType = RankedTensorType::get( outType.getShape(), offsetType.getElementType(), outType.getEncoding()); - triton::BroadcastOp newOffset = rewriter.create( + tt::BroadcastOp newOffset = rewriter.create( broadcastOp.getLoc(), newOffsetType, fatPtrOffset); - setRewrittenLegalAttrs(rewriter, broadcastOp, newOffset); rewriter.replaceOpWithMultiple(broadcastOp, {{fatPtrBase, newOffset}}); + opToRewrite.remove(broadcastOp); fatPtrs[{fatPtrBase, newOffset}] = fatPtrs[{fatPtrBase, fatPtrOffset}]; return success(); } @@ -508,14 +511,13 @@ class ConvertBroadcastOp /// 2. Constant tensor offset -> bump only the offset /// 3. Non-constant tensor offset -> decompose parent(offset) into uniform and /// non-uniform comop -class ConvertAddPtrOp - : public PointerCanonicalizationPattern { +class ConvertAddPtrOp : public PointerCanonicalizationPattern { public: using PointerCanonicalizationPattern::PointerCanonicalizationPattern; LogicalResult - matchAndRewrite(triton::AddPtrOp addPtrOp, OneToNOpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + matchAndRewrite_(tt::AddPtrOp addPtrOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { ValueRange remappedPtr = adaptor.getPtr(); if (remappedPtr.size() != 2) return rewriter.notifyMatchFailure( @@ -533,12 +535,11 @@ class ConvertAddPtrOp rewriter.setInsertionPoint(addPtrOp); // If it is a scalar pointer update, simply bump the base pointer - if (llvm::isa(addPtrOp.getPtr().getType())) { + if (llvm::isa(addPtrOp.getPtr().getType())) { assert(llvm::isa(origOffset.getType()) && "expected offset to be integer type"); - auto newAddPtrOp = rewriter.create( + auto newAddPtrOp = rewriter.create( curLoc, fatPtrBase.getType(), fatPtrBase, origOffset); - setRewrittenLegalAttrs(rewriter, addPtrOp, newAddPtrOp); rewriter.replaceOpWithMultiple(addPtrOp, {{newAddPtrOp, fatPtrOffset}}); fatPtrs[{newAddPtrOp, fatPtrOffset}].canNarrow = fatPtrs[{fatPtrBase, fatPtrOffset}].canNarrow; @@ -551,9 +552,8 @@ class ConvertAddPtrOp // Early exit for the case of a constant tensor if (auto scalarConst = maybeGetOrCreateScalarConstant(rewriter, curLoc, origOffset)) { - triton::AddPtrOp newAddPtrOp = rewriter.create( + tt::AddPtrOp newAddPtrOp = rewriter.create( curLoc, fatPtrBase.getType(), fatPtrBase, *scalarConst); - setRewrittenLegalAttrs(rewriter, addPtrOp, newAddPtrOp); rewriter.replaceOpWithMultiple(addPtrOp, {{newAddPtrOp, fatPtrOffset}}); // If we are updating the tensor pointer with a constant value, we can // 
propagate the attributes of the tensor pointer to the fat pointer. @@ -567,7 +567,7 @@ class ConvertAddPtrOp auto [uniformOffset, nonUniformOffset] = createDecomposeOffsetFromExpr(rewriter, curLoc, origOffset, bitness); - auto newAddPtrOp = rewriter.create( + auto newAddPtrOp = rewriter.create( curLoc, fatPtrBase.getType(), fatPtrBase, uniformOffset); // Vector offset update (if any): bump the tensor offset @@ -592,7 +592,6 @@ class ConvertAddPtrOp propagateAtrs = false; } - setRewrittenLegalAttrs(rewriter, addPtrOp, newAddPtrOp); rewriter.replaceOpWithMultiple(addPtrOp, {{newAddPtrOp, newOffset}}); auto nextFatPtr = std::pair{newAddPtrOp.getResult(), newOffset}; fatPtrs[nextFatPtr].canNarrow = canNarrow; @@ -610,8 +609,8 @@ class ConvertSCFForOp : public PointerCanonicalizationPattern { public: LogicalResult - matchAndRewrite(scf::ForOp forOp, OneToNOpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + matchAndRewrite_(scf::ForOp forOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { SmallVector valRangeLens; ArrayRef remappedInits = adaptor.getInitArgs(); for (ValueRange remappedInit : remappedInits) @@ -676,7 +675,6 @@ class ConvertSCFForOp : public PointerCanonicalizationPattern { offset += len; } - setRewrittenLegalAttrs(rewriter, forOp, newForOp); rewriter.replaceOpWithMultiple(forOp, packedRets); return success(); @@ -690,8 +688,8 @@ class ConvertSCFYieldOp : public PointerCanonicalizationPattern { using PointerCanonicalizationPattern::PointerCanonicalizationPattern; LogicalResult - matchAndRewrite(scf::YieldOp yieldOp, OneToNOpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + matchAndRewrite_(scf::YieldOp yieldOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { ArrayRef remappedYields = adaptor.getOperands(); SmallVector newYieldedValues = flattenValues(remappedYields); // have to mutate here because otherwise scf.if, scf.for, and scf.while will @@ -735,7 +733,6 @@ class ConvertSCFYieldOp : public PointerCanonicalizationPattern { rewriter.getDenseI64ArrayAttr(fatPtrOffsets)); } - setLegalAttr(rewriter, yieldOp); return success(); } }; @@ -745,8 +742,8 @@ class ConvertSCFWhileOp : public PointerCanonicalizationPattern { public: using PointerCanonicalizationPattern::PointerCanonicalizationPattern; LogicalResult - matchAndRewrite(scf::WhileOp whileOp, OneToNOpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + matchAndRewrite_(scf::WhileOp whileOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { SmallVector valRangeLens; ArrayRef remappedInits = adaptor.getInits(); for (ValueRange remappedInit : remappedInits) @@ -820,7 +817,6 @@ class ConvertSCFWhileOp : public PointerCanonicalizationPattern { offset += len; } - setRewrittenLegalAttrs(rewriter, whileOp, newWhileOp); rewriter.replaceOpWithMultiple(whileOp, packedRets); return success(); @@ -833,8 +829,8 @@ class ConvertSCFConditionOp public: using PointerCanonicalizationPattern::PointerCanonicalizationPattern; LogicalResult - matchAndRewrite(scf::ConditionOp condOp, OneToNOpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + matchAndRewrite_(scf::ConditionOp condOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { SmallVector newArgs = flattenValues(adaptor.getArgs()); // have to mutate here because otherwise scf.while will // get confused about which condition is the "correct" condition (since @@ -843,7 +839,6 @@ 
class ConvertSCFConditionOp condOp.getArgsMutable().clear(); condOp.getArgsMutable().append(newArgs); }); - setLegalAttr(rewriter, condOp); return success(); } }; @@ -854,15 +849,14 @@ class ConvertCFCondBranch public: using PointerCanonicalizationPattern::PointerCanonicalizationPattern; LogicalResult - matchAndRewrite(cf::CondBranchOp branchOp, OneToNOpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + matchAndRewrite_(cf::CondBranchOp branchOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { ArrayRef remappedTrueOperands = adaptor.getTrueDestOperands(); ArrayRef remappedFalseOperands = adaptor.getFalseDestOperands(); SmallVector trueOperands = flattenValues(remappedTrueOperands); SmallVector falseOperands = flattenValues(remappedFalseOperands); - setRewrittenAttr(rewriter, branchOp); - auto newBrancOp = rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( branchOp, branchOp.getCondition(), branchOp.getTrueDest(), trueOperands, branchOp.getFalseDest(), falseOperands); @@ -926,12 +920,17 @@ class ConvertCFCondBranch // propagate canNarrow to bb arg fatPtrs in false bb propagateCanNarrowToBlock(newFalseBlock, remappedFalseOperands); - setLegalAttr(rewriter, newBrancOp); return success(); } }; -/// Rewrite both operands. Note, this should only be reached after both +/// Rewrite select(fatPtrTrue, fatPtrFalse) -> +/// ( +/// select(fatPtrTrueBase, fatPtrTrueOffset), +/// select(fatPtrFalseBase, fatPtrFalseOffset) +/// ) +/// +/// Note, this should only be reached after both /// operands have already been rewritten because DialectConversion walks /// PreOrder in order ForwardDominance order: see /// https://github.com/llvm/llvm-project/blob/58389b220a9354ed6c34bdb9310a35165579c5e3/mlir/lib/Transforms/Utils/DialectConversion.cpp#L2702 @@ -940,8 +939,8 @@ class ConvertArithSelectOp public: using PointerCanonicalizationPattern::PointerCanonicalizationPattern; LogicalResult - matchAndRewrite(arith::SelectOp selectOp, OneToNOpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + matchAndRewrite_(arith::SelectOp selectOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { ArrayRef remappedOperands = adaptor.getOperands(); if (remappedOperands[1].size() != 2 || remappedOperands[2].size() != 2) return rewriter.notifyMatchFailure( @@ -954,10 +953,8 @@ class ConvertArithSelectOp // Simple case of a scalar select: update the base pointer if (!isa(selectOp.getType())) { auto newSelectOp = rewriter.create( - selectOp.getLoc(), selectOp.getType(), - // TODO(max): why fatPtrTrue here? 
- selectOp.getCondition(), fatPtrTrue[0], selectOp.getFalseValue()); - setRewrittenLegalAttrs(rewriter, selectOp, newSelectOp); + selectOp.getLoc(), selectOp.getType(), selectOp.getCondition(), + fatPtrTrue[0], selectOp.getFalseValue()); rewriter.replaceOpWithMultiple(selectOp, {{newSelectOp, fatPtrTrue[1]}}); return success(); } @@ -974,8 +971,6 @@ class ConvertArithSelectOp fatPtrs[{fatPtrFalse[0], fatPtrFalse[1]}].canNarrow) && "expected can narrow to be the same for both fatPtrT and fatPtrF"); - setRewrittenLegalAttrs(rewriter, selectOp, newBase); - setRewrittenLegalAttrs(rewriter, selectOp, newOffset); rewriter.replaceOpWithMultiple(selectOp, {{newBase, newOffset}}); fatPtrs[{newBase, newOffset}].canNarrow = fatPtrs[{fatPtrTrue[0], fatPtrTrue[1]}].canNarrow; @@ -992,8 +987,8 @@ class ConvertSCFIfOp : public PointerCanonicalizationPattern { public: using PointerCanonicalizationPattern::PointerCanonicalizationPattern; LogicalResult - matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + matchAndRewrite_(scf::IfOp ifOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { assert(ifOp.thenYield()->hasAttr(kSCFIfOpYieldFatPtrOffsets) && "expected then yield to report fat ptr indices"); @@ -1004,17 +999,23 @@ class ConvertSCFIfOp : public PointerCanonicalizationPattern { assert(ifOp.thenYield().getOperandTypes() == ifOp.elseYield().getOperandTypes() && "ifOp types must match in both arms"); - assert(ifOp.elseYield()->hasAttr(kSCFIfOpYieldFatPtrOffsets) && - "expected then yield to report fat ptr indices"); if (auto thenFatPtrIndxs = ifOp.thenYield()->getDiscardableAttr( kSCFIfOpYieldFatPtrOffsets)) { - auto elseFatPtrIndx = + assert(ifOp.elseYield()->hasAttr(kSCFIfOpYieldFatPtrOffsets) && + "expected then yield to report fat ptr indices"); + auto elseFatPtrIndxs = ifOp.elseYield()->getDiscardableAttr(kSCFIfOpYieldFatPtrOffsets); - assert(elseFatPtrIndx && + assert(elseFatPtrIndxs && "expected else fat ptr indices as well as then fat ptr indices"); - for (auto [i, j] : llvm::zip( - llvm::cast(thenFatPtrIndxs).asArrayRef(), - llvm::cast(elseFatPtrIndx).asArrayRef())) { + + DenseI64ArrayAttr thenIdxs = + llvm::dyn_cast(thenFatPtrIndxs); + DenseI64ArrayAttr elseIdxs = + llvm::dyn_cast(elseFatPtrIndxs); + assert(bool(thenIdxs) && bool(elseIdxs) && + "expected else fat ptr index attrs to be DenseI64ArrayAttr"); + for (auto [i, j] : + llvm::zip(thenIdxs.asArrayRef(), elseIdxs.asArrayRef())) { assert(i == j && "expected thenFatPtrIndxs and elseFatPtrIndxs to agree"); assert(i < ifOp.thenYield().getNumOperands() && @@ -1042,7 +1043,6 @@ class ConvertSCFIfOp : public PointerCanonicalizationPattern { rewriter.inlineBlockBefore(ifOp.elseBlock(), newIfOp.elseBlock(), newIfOp.elseBlock()->begin()); - setRewrittenLegalAttrs(rewriter, ifOp, newIfOp); rewriter.replaceOpWithMultiple(ifOp, {newIfOp.getResults()}); for (int64_t idx : @@ -1064,14 +1064,13 @@ class ConvertCFBranch : public PointerCanonicalizationPattern { public: using PointerCanonicalizationPattern::PointerCanonicalizationPattern; LogicalResult - matchAndRewrite(cf::BranchOp branchOp, OneToNOpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + matchAndRewrite_(cf::BranchOp branchOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { ArrayRef remappedDestOperands = adaptor.getDestOperands(); SmallVector trueOperands = flattenValues(remappedDestOperands); - setRewrittenAttr(rewriter, branchOp); - auto newBrancOp 
= rewriter.replaceOpWithNewOp( - branchOp, branchOp.getDest(), trueOperands); + rewriter.replaceOpWithNewOp(branchOp, branchOp.getDest(), + trueOperands); unsigned inputNo = 0; TypeConverter localTypeConverterTrueDest; @@ -1103,18 +1102,17 @@ class ConvertCFBranch : public PointerCanonicalizationPattern { offset += operands.size(); } - setLegalAttr(rewriter, newBrancOp); return success(); } }; -class ConvertLoadOp : public PointerCanonicalizationPattern { +class ConvertLoadOp : public PointerCanonicalizationPattern { public: using PointerCanonicalizationPattern::PointerCanonicalizationPattern; LogicalResult - matchAndRewrite(triton::LoadOp loadOp, OneToNOpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + matchAndRewrite_(tt::LoadOp loadOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { ValueRange fatPtr = adaptor.getPtr(); if (fatPtr.size() != 2) return rewriter.notifyMatchFailure( @@ -1123,8 +1121,7 @@ class ConvertLoadOp : public PointerCanonicalizationPattern { Value fatPtrOffset = fatPtr[1]; Location curLoc = loadOp.getLoc(); - llvm::SmallDenseMap attributes{ - {rewriter.getStringAttr(kLegalAttr), rewriter.getUnitAttr()}}; + llvm::SmallDenseMap attributes{}; Value newPtr = fatPtrBase; if (llvm::isa(loadOp.getPtr().getType())) newPtr = createTensorPointer( @@ -1134,21 +1131,18 @@ class ConvertLoadOp : public PointerCanonicalizationPattern { loadOp.getOperands().take_back(loadOp.getNumOperands() - 1); operands.insert(operands.begin(), newPtr); SmallVector attrs = llvm::to_vector(loadOp->getAttrs()); - attrs.append({rewriter.getNamedAttr(kLegalAttr, rewriter.getUnitAttr())}); - auto newLoadPtrOp = - rewriter.replaceOpWithNewOp(loadOp, operands, attrs); - setLegalAttr(rewriter, newLoadPtrOp); + rewriter.replaceOpWithNewOp(loadOp, operands, attrs); return success(); } }; -class ConvertStoreOp : public PointerCanonicalizationPattern { +class ConvertStoreOp : public PointerCanonicalizationPattern { public: using PointerCanonicalizationPattern::PointerCanonicalizationPattern; LogicalResult - matchAndRewrite(triton::StoreOp storeOp, OneToNOpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + matchAndRewrite_(tt::StoreOp storeOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { ValueRange fatPtr = adaptor.getPtr(); if (fatPtr.size() != 2) return rewriter.notifyMatchFailure( @@ -1157,9 +1151,7 @@ class ConvertStoreOp : public PointerCanonicalizationPattern { Value fatPtrOffset = fatPtr[1]; Location curLoc = storeOp.getLoc(); - llvm::SmallDenseMap attributes{ - {rewriter.getStringAttr(kLegalAttr), rewriter.getUnitAttr()}}; - + llvm::SmallDenseMap attributes{}; Value newPtr = fatPtrBase; if (llvm::isa(storeOp.getPtr().getType())) newPtr = createTensorPointer( @@ -1169,10 +1161,8 @@ class ConvertStoreOp : public PointerCanonicalizationPattern { storeOp.getOperands().take_back(storeOp.getNumOperands() - 1); operands.insert(operands.begin(), newPtr); SmallVector attrs = llvm::to_vector(storeOp->getAttrs()); - attrs.append({rewriter.getNamedAttr(kLegalAttr, rewriter.getUnitAttr())}); - auto newStoreOp = rewriter.replaceOpWithNewOp( - storeOp, TypeRange{}, ValueRange{operands}, attrs); - setLegalAttr(rewriter, newStoreOp); + rewriter.replaceOpWithNewOp(storeOp, TypeRange{}, + ValueRange{operands}, attrs); return success(); } }; @@ -1183,19 +1173,19 @@ class ConvertStoreOp : public PointerCanonicalizationPattern { /// This unrealized_cast remains through out the first pass of the dialect /// 
conversion and is then materialized in the second pass /// (ConvertUnrealizedConversionCastOp). -class ConvertFuncOp : public PointerCanonicalizationPattern { +class ConvertFuncOp : public PointerCanonicalizationPattern { public: using PointerCanonicalizationPattern::PointerCanonicalizationPattern; LogicalResult - matchAndRewrite(triton::FuncOp funcOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + matchAndRewrite_(tt::FuncOp funcOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { int64_t bitness = 64; rewriter.setInsertionPointToStart(&funcOp.getBody().front()); rewriter.modifyOpInPlace(funcOp, [&] { for (auto [idx, arg] : llvm::enumerate(funcOp.getArguments())) { // The pointer argument needs to be a scalar - if (!isa(arg.getType())) + if (!isa(arg.getType())) continue; if (auto pointerRangeAttr = funcOp.getArgAttrOfType(idx, "tt.pointer_range")) @@ -1210,18 +1200,31 @@ class ConvertFuncOp : public PointerCanonicalizationPattern { rewriter.replaceOpWithMultiple(dummyCast, {{arg, zeroOffset}}); } }); - setRewrittenAttr(rewriter, funcOp); return success(); } }; +/// No-op to make conversion framework happy. +class ConvertReturnOp : public PointerCanonicalizationPattern { +public: + using PointerCanonicalizationPattern::PointerCanonicalizationPattern; + + LogicalResult + matchAndRewrite_(tt::ReturnOp returnOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto returns = flattenValues(adaptor.getSrcs()); + rewriter.replaceOpWithNewOp(returnOp, TypeRange{}, returns); + return success(); + } +}; + /// Rewrite %1 = unrealize_cast(%arg0: tt.ptr, c0: i32) -> tt.ptr inserted by /// ConvertFuncOp to be just %arg0: tt.ptr. class ConvertUnrealizedConversionCastOp - : public PointerCanonicalizationPattern { + : public OpConversionPattern { public: - using PointerCanonicalizationPattern::PointerCanonicalizationPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(UnrealizedConversionCastOp castOp, OpAdaptor adaptor, @@ -1229,14 +1232,14 @@ class ConvertUnrealizedConversionCastOp assert(std::distance(castOp->getUses().begin(), castOp->getUses().end()) > 0 && "expected at least 1 use of unrealized_cast"); - // dunno why but i get -Wdangling here... + // Don't know why but i get -Wdangling here... 
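+    // Sketch of the shape this pattern relies on (value names below are
+    // placeholders, not taken from a real test): the cast's single original
+    // operand has already been remapped by the first conversion pass to a
+    // {base, offset} pair, e.g. %arg0 : !tt.ptr<f32> plus %c0_i32 : i32, so
+    // adaptor.getOperands() is expected to yield exactly one ValueRange of
+    // length two, which the guard below checks.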
ArrayRef remappedOperands = adaptor.getOperands(); if (remappedOperands.size() != 1 || remappedOperands[0].size() != 2) return rewriter.notifyMatchFailure( castOp, "expected CastOp to have already been remapped"); Value fatPtrBase = remappedOperands[0][0]; Value fatPtrOffset = remappedOperands[0][1]; - if (!llvm::isa(fatPtrBase.getType())) + if (!llvm::isa(fatPtrBase.getType())) return rewriter.notifyMatchFailure(castOp, "non tt.ptr base unimplemented"); if (!llvm::isa(fatPtrOffset.getType())) @@ -1257,70 +1260,106 @@ class ConvertUnrealizedConversionCastOp }; void TritonAMDGPUCanonicalizePointersPass::runOnOperation() { - ModuleOp module = getOperation(); mlir::MLIRContext *context = &getContext(); ConversionTarget target(*context); RewritePatternSet patterns(context); - auto isLegal = [](Operation *op) { - if (op->hasAttr(kRewrittenAttr) || op->hasAttr(kLegalAttr)) + + tt::FuncOp func = getOperation(); + + // forward slice == all transitive uses + ForwardSliceOptions sliceOpts([](Operation *op) { + // scf.if isn't a direct user but could contain users + // NB: we need this here and the loop below because + // forward slice not propagate through the scf.if without the `true` here. + if (llvm::isa(op)) return true; - for (OpOperand &operand : op->getOpOperands()) { - if (auto arg = llvm::dyn_cast(operand.get())) { - if (!llvm::isa(getElementTypeOrSelf(arg))) - continue; - return false; + return llvm::any_of(op->getOperandTypes(), [](Type ty) { + return llvm::isa(getElementTypeOrSelf(ty)); + }); + }); + sliceOpts.inclusive = true; + llvm::SetVector opsToRewrite; + opsToRewrite.insert(func); + for (auto arg : func.getArguments()) + if (llvm::isa(arg.getType())) { + // NB: reusing the same SetVector invalidates the topo order implied by + // getForwardSlice + getForwardSlice(arg, &opsToRewrite, sliceOpts); + } + + // getForwardSlice doesn't check successorRegions + for (auto op : opsToRewrite) { + if (llvm::isa(op)) { + unsigned opOffset = llvm::isa(op) ? 
0 : 1; + for (auto successor : op->getSuccessors()) { + for (auto arg : successor->getArguments()) { + // if the bb arg corresponds to an op that will be rewritten + if (opsToRewrite.contains( + op->getOperand(opOffset + arg.getArgNumber()) + .getDefiningOp())) + getForwardSlice(arg, &opsToRewrite, sliceOpts); + } } - if (operand.get().getDefiningOp()->hasAttr(kRewrittenAttr)) - return false; } - return true; - }; - target.addDynamicallyLegalDialect( - [&isLegal](Operation *op) { - if (llvm::isa(op) && !op->hasAttr(kRewrittenAttr)) - return false; - return isLegal(op); - }); - target.addDynamicallyLegalDialect([&isLegal](Operation *op) { - if (auto ifOp = llvm::dyn_cast(op)) + } + + // NB: we need this here and the check in sliceOpts because without this + // loop, getForwardSlice never finds any scf.ifs at all (they have no + // operands) + for (auto op : opsToRewrite) { + if (auto parentIfOp = op->getParentOp()) { + if (llvm::isa(parentIfOp)) { + getForwardSlice(parentIfOp, &opsToRewrite, sliceOpts); + } + } + } + + // llvm::errs() << "ops to rewrite:\n"; + // for (auto ops_to_rewrite : opsToRewrite) + // llvm::errs() << ops_to_rewrite->getName() << "\n"; + // llvm::errs() << "\n"; + + auto isLegal = [&opsToRewrite](Operation *op) { + if (auto ifOp = llvm::dyn_cast(op)) { + // This is the only hack in the entire pass; on first traversal, + // `scf.if` will be walked over, but we do not want to rewrite it yet + // because the `yields` in the then/else regions haven't been rewritten + // yet (and those `yields` tell us the final result types of the + // `scf.if`). Therefore, we check for these attributes and if they're + // absent then the `scf.if` is legal. Once both `yields` have been + // rewritten (the corresponding attributes have been added), we report the + // `scf.if` as illegal, and it will be rewritten (the pattern will fire). 
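+      // Illustrative effect (types and names are placeholders): a yielded
+      // tensor<1024x!tt.ptr<f32>, #blocked> is carried through the rewritten
+      // scf.if as a scalar base !tt.ptr<f32> plus an offset tensor such as
+      // tensor<1024xi64, #blocked>, and those extra result types are only
+      // known once the yields in both arms have been rewritten.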
return !(ifOp->hasAttr(kSCFThenRewrittenAttr) and ifOp->hasAttr(kSCFElseRewrittenAttr)); - if (llvm::isa(op) && !op->hasAttr(kLegalAttr)) - return false; - return isLegal(op); - }); - target.addDynamicallyLegalDialect( - [&isLegal](Operation *op) { return isLegal(op); }); - target.addDynamicallyLegalDialect( - [&isLegal](Operation *op) { - if (llvm::isa(op)) - return isLegal(op); - return true; - }); + } + return !opsToRewrite.contains(op); + }; + + target.addDynamicallyLegalDialect(isLegal); + target.addDynamicallyLegalDialect(isLegal); + target.addDynamicallyLegalDialect(isLegal); + target.addDynamicallyLegalDialect(isLegal); FatPointers fatPrs; - patterns - .add( - patterns.getContext(), fatPrs); + patterns.add(patterns.getContext(), + opsToRewrite, fatPrs); ConversionConfig config; config.buildMaterializations = false; - if (failed( - applyPartialConversion(module, target, std::move(patterns), config))) + if (failed(applyPartialConversion(func, target, std::move(patterns), config))) return signalPassFailure(); patterns.clear(); target.addIllegalOp(); - patterns.add(patterns.getContext(), - fatPrs); - if (failed( - applyPartialConversion(module, target, std::move(patterns), config))) + patterns.add(patterns.getContext()); + if (failed(applyPartialConversion(func, target, std::move(patterns), config))) return signalPassFailure(); - module.walk([](Operation *op) { + func.walk([](Operation *op) { for (auto attr : op->getDiscardableAttrs()) { if (attr.getName().strref().starts_with(kPtrCanonPrefix)) op->removeDiscardableAttr(attr.getName()); From 92bdfe23a05c15adddd903afbec1f805b6138e10 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Wed, 18 Dec 2024 14:27:54 -0500 Subject: [PATCH 13/17] fix conds scf.while --- .../amd/amd-canonicalize-pointers.mlir | 35 ++++--------------- .../CanonicalizePointers.cpp | 10 ++++-- 2 files changed, 15 insertions(+), 30 deletions(-) diff --git a/test/TritonGPU/amd/amd-canonicalize-pointers.mlir b/test/TritonGPU/amd/amd-canonicalize-pointers.mlir index 8185758a43dc..2dd185961e8a 100644 --- a/test/TritonGPU/amd/amd-canonicalize-pointers.mlir +++ b/test/TritonGPU/amd/amd-canonicalize-pointers.mlir @@ -1,25 +1,4 @@ // RUN: triton-opt %s -split-input-file -tritonamdgpu-canonicalize-pointers | FileCheck %s - -module { - tt.func public @add_kernel( - %in_ptr0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} , - %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} , - %arg2: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} ) -> tensor<1024xf32> attributes {noinline = false} { - %c1024_i32 = arith.constant 1024 : i32 - %pid = tt.get_program_id x : i32 - %block_start = arith.muli %pid, %c1024_i32 : i32 - %make_range = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> - %block_start_splat = tt.splat %block_start : i32 -> tensor<1024xi32> - %offsets = arith.addi %block_start_splat, %make_range : tensor<1024xi32> - %in_ptr0_splat = tt.splat %in_ptr0 : !tt.ptr -> tensor<1024x!tt.ptr> - %addr = tt.addptr %in_ptr0_splat, %offsets : tensor<1024x!tt.ptr>, tensor<1024xi32> - %val = tt.load %addr : tensor<1024x!tt.ptr> - tt.return %val : tensor<1024xf32> - } -} - -// ----- - #blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: tt.func @conversion1 @@ -346,14 +325,14 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} %4 = 
arith.addi %3, %2 : tensor<1024xi32, #blocked> %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> // CHECK: %[[whileOut:.*]]:3 = scf.while ({{.*}}, %[[loopPtr:.*]] = %arg0, %[[loopOffset:.*]] = %[[base_offset]]) - %6 = scf.while (%arg1 = %5) : (tensor<1024x!tt.ptr, #blocked>) -> (tensor<1024x!tt.ptr, #blocked>) { + %6 = scf.while (%arg1 = %5, %arg2 = %cond) : (tensor<1024x!tt.ptr, #blocked>, i1) -> (tensor<1024x!tt.ptr, #blocked>) { // CHECK: scf.condition({{.*}}) %{{.*}}, %[[loopPtr]], %[[loopOffset]] - scf.condition(%cond) %arg1 : tensor<1024x!tt.ptr, #blocked> + scf.condition(%arg2) %arg1 : tensor<1024x!tt.ptr, #blocked> } do { // CHECK: ^bb{{.*}}(%{{.*}}, %[[blockPtr:.*]]: !tt.ptr, %[[blockOffset:.*]]: tensor<1024xi64, #blocked>): ^bb0(%arg1: tensor<1024x!tt.ptr, #blocked>): // CHECK: scf.yield {{.*}}, %[[blockPtr]], %[[blockOffset]] - scf.yield %arg1 : tensor<1024x!tt.ptr, #blocked> + scf.yield %arg1, %cond : tensor<1024x!tt.ptr, #blocked>, i1 } // CHECK: %[[trunc_offset:.*]] = arith.trunci %[[whileOut]]#2 // CHECK: %[[base_ptr:.*]] = tt.splat %[[whileOut]]#1 @@ -703,7 +682,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ #blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: tt.func @scalar_while - tt.func @scalar_while(%arg0: !tt.ptr, %init : f32, %cond : i1) -> f32 { + tt.func @scalar_while(%arg0: !tt.ptr, %init : f32, %cond : i1)->f32{ %c1024_i32 = arith.constant 1024 : i32 %c0 = arith.constant 0: index %c128 = arith.constant 128: index @@ -713,14 +692,14 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} // CHECK: %[[ptr0:.*]] = tt.addptr %arg0, %{{.*}} // CHECK: scf.while ({{.*}}, {{.*}} = %arg2, %[[ptr1:.*]] = %[[ptr0]], {{.*}}) %2 = tt.addptr %arg0, %0 : !tt.ptr, i32 - %6 = scf.while (%arg1 = %2) : (!tt.ptr) -> (!tt.ptr) { + %6 = scf.while (%arg1 = %2, %arg2 = %cond) : (!tt.ptr, i1) -> (!tt.ptr) { // CHECK: scf.condition({{.*}}) {{.*}}, %[[ptr1]] - scf.condition(%cond) %arg1 : !tt.ptr + scf.condition(%arg2) %arg1 : !tt.ptr } do { // CHECK: ^bb0({{.*}}: !tt.ptr, %[[ptr2:.*]]: !tt.ptr, {{.*}}) // CHECK: scf.yield %{{.*}}, {{.*}} %[[ptr2]], {{.*}}, {{.*}} ^bb0(%arg1: !tt.ptr): - scf.yield %arg1 : !tt.ptr + scf.yield %arg1, %cond : !tt.ptr, i1 } %11 = tt.load %6 : !tt.ptr tt.return %11 : f32 diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp index 7f0ce0a1b2f7..089248c76bf7 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp @@ -796,8 +796,10 @@ class ConvertSCFWhileOp : public PointerCanonicalizationPattern { propagateCanNarrowToBlock(newAfterBodyBlock); SmallVector initArgs = flattenValues(remappedInits); - SmallVector resultTypes = - llvm::map_to_vector(initArgs, [](Value v) { return v.getType(); }); + SmallVector resultTypes = llvm::map_to_vector( + llvm::make_filter_range( + initArgs, [](Value v) { return !v.getType().isInteger(1); }), + [](Value v) { return v.getType(); }); auto newWhileOp = rewriter.create(whileOp.getLoc(), resultTypes, initArgs); @@ -809,6 +811,10 @@ class ConvertSCFWhileOp : public PointerCanonicalizationPattern { SmallVector packedRets; for (unsigned i = 0, offset = 0; i < valRangeLens.size(); i++) { + // skip %cond + if 
(remappedInits[i].size() == 1 && + remappedInits[i].getType()[0].isInteger(1)) + continue; size_t len = valRangeLens[i]; assert(offset < newWhileOp->getNumResults() && "expected offset to be within bounds of results"); From e9022700594243d96bb1dbe170780342ab6c8d33 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Wed, 18 Dec 2024 15:12:44 -0500 Subject: [PATCH 14/17] propagate canNarrow correctly --- .../CanonicalizePointers.cpp | 127 +++++++----------- 1 file changed, 47 insertions(+), 80 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp index 089248c76bf7..e99229664b96 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp @@ -336,61 +336,23 @@ struct FatPointers { using KeyT = std::pair; using ValueT = FatPtrAttrs; using DenseMapT = DenseMap; - ValueT &operator[](const KeyT &k) { return pointers[k]; } - ValueT &operator[](KeyT &&k) { return pointers[k]; } + ValueT &operator[](const KeyT &k) { return pointerAttrs[k]; } + ValueT &operator[](KeyT &&k) { return pointerAttrs[k]; } template using const_arg_type_t = typename llvm::const_pointer_or_const_ref::type; - const ValueT &at(const_arg_type_t k) const { return pointers.at(k); } - const bool contains(const KeyT &k) { return pointers.contains(k); } + const ValueT &at(const_arg_type_t k) const { + // this is redundant - DenseMap will assert the same thing - but better to + // have our own message + assert(pointerAttrs.contains(k) && + "expected fatPtrs to contain remapped fat pointer"); + return pointerAttrs.at(k); + } + const bool contains(const KeyT &k) { return pointerAttrs.contains(k); } private: - DenseMapT pointers; + DenseMapT pointerAttrs; }; -std::optional getFatPtrCastOp(Value base, - Value offset) { - std::optional maybeCastOp; - for (Operation *user : base.getUsers()) { - if (auto castOp = llvm::dyn_cast(user)) { - if (castOp.getNumOperands() == 2 && castOp.getOperand(0) == base && - castOp.getOperand(1) == offset) { - maybeCastOp = castOp; - } - } - } -#ifndef NDEBUG - for (Operation *user : offset.getUsers()) { - if (auto castOp = llvm::dyn_cast(user)) { - if (castOp.getNumOperands() == 2 && castOp.getOperand(0) == base && - castOp.getOperand(1) == offset) { - assert( - castOp == *maybeCastOp && - "expected castop through base and castop through offset to match"); - } - } - } -#endif - return maybeCastOp; -} - -std::optional getFatPtrCastOp(OpOperand &operand) { - Value operandVal = operand.get(); - for (Operation *user : operandVal.getUsers()) { - if (auto castOp = llvm::dyn_cast(user)) { - if (castOp.getNumOperands() == 2 && - (castOp.getOperand(0) == operandVal || - castOp.getOperand(1) == operandVal) && - castOp.getNumResults() == 1 && - std::distance(castOp->getUsers().begin(), castOp->getUsers().end()) == - 1 && - *castOp->getUsers().begin() == operand.getOwner()) { - return castOp; - } - } - } - return {}; -} - /// Flatten the given value ranges into a single vector of values. 
static SmallVector flattenValues(ArrayRef values) { SmallVector result; @@ -462,8 +424,8 @@ class ConvertSplatOp : public PointerCanonicalizationPattern { tt::SplatOp offset = rewriter.create( splatOp.getLoc(), newOffsetType, fatPtrOffset); rewriter.replaceOpWithMultiple(splatOp, {{fatPtrBase, offset}}); - opToRewrite.remove(splatOp); - fatPtrs[{fatPtrBase, offset}] = fatPtrs[{fatPtrBase, fatPtrOffset}]; + fatPtrs[{fatPtrBase, offset}].canNarrow = + fatPtrs.at({fatPtrBase, fatPtrOffset}).canNarrow; return success(); } @@ -500,8 +462,8 @@ class ConvertBroadcastOp tt::BroadcastOp newOffset = rewriter.create( broadcastOp.getLoc(), newOffsetType, fatPtrOffset); rewriter.replaceOpWithMultiple(broadcastOp, {{fatPtrBase, newOffset}}); - opToRewrite.remove(broadcastOp); - fatPtrs[{fatPtrBase, newOffset}] = fatPtrs[{fatPtrBase, fatPtrOffset}]; + fatPtrs[{fatPtrBase, newOffset}].canNarrow = + fatPtrs.at({fatPtrBase, fatPtrOffset}).canNarrow; return success(); } }; @@ -542,7 +504,7 @@ class ConvertAddPtrOp : public PointerCanonicalizationPattern { curLoc, fatPtrBase.getType(), fatPtrBase, origOffset); rewriter.replaceOpWithMultiple(addPtrOp, {{newAddPtrOp, fatPtrOffset}}); fatPtrs[{newAddPtrOp, fatPtrOffset}].canNarrow = - fatPtrs[{fatPtrBase, fatPtrOffset}].canNarrow; + fatPtrs.at({fatPtrBase, fatPtrOffset}).canNarrow; return success(); } @@ -558,7 +520,7 @@ class ConvertAddPtrOp : public PointerCanonicalizationPattern { // If we are updating the tensor pointer with a constant value, we can // propagate the attributes of the tensor pointer to the fat pointer. fatPtrs[{newAddPtrOp, fatPtrOffset}].canNarrow = - fatPtrs[{fatPtrBase, fatPtrOffset}].canNarrow; + fatPtrs.at({fatPtrBase, fatPtrOffset}).canNarrow; return success(); } @@ -571,7 +533,7 @@ class ConvertAddPtrOp : public PointerCanonicalizationPattern { curLoc, fatPtrBase.getType(), fatPtrBase, uniformOffset); // Vector offset update (if any): bump the tensor offset - bool canNarrow = fatPtrs[{fatPtrBase, fatPtrOffset}].canNarrow; + bool canNarrow = fatPtrs.at({fatPtrBase, fatPtrOffset}).canNarrow; bool propagateAtrs = true; Value newOffset = fatPtrOffset; if (!isZeroConst(nonUniformOffset)) { @@ -645,11 +607,9 @@ class ConvertSCFForOp : public PointerCanonicalizationPattern { int offset = 1; for (auto operands : remappedInits) { if (operands.size() == 2) { - assert(fatPtrs.contains({operands[0], operands[1]}) && - "expected fatPtrs to contain remapped fat pointer"); fatPtrs[{newBodyBlock->getArgument(offset), newBodyBlock->getArgument(offset + 1)}] - .canNarrow = fatPtrs[{operands[0], operands[1]}].canNarrow; + .canNarrow = fatPtrs.at({operands[0], operands[1]}).canNarrow; } offset += operands.size(); } @@ -671,6 +631,13 @@ class ConvertSCFForOp : public PointerCanonicalizationPattern { assert(offset < newForOp->getNumResults() && "expected offset to be within bounds of results"); ValueRange mappedValue = newForOp->getResults().slice(offset, len); + // propagate fatPtrs + if (mappedValue.size() == 2) { + assert(remappedInits[i].size() == 2 && + "expected corresponding inits to be a remapped fat ptr"); + fatPtrs[{mappedValue[0], mappedValue[1]}] = + fatPtrs.at({remappedInits[i][0], remappedInits[i][1]}); + } packedRets.push_back(mappedValue); offset += len; } @@ -720,11 +687,8 @@ class ConvertSCFYieldOp : public PointerCanonicalizationPattern { int offset = 0; SmallVector fatPtrOffsets; for (auto operands : remappedYields) { - if (operands.size() == 2) { - assert(fatPtrs.contains({operands[0], operands[1]}) && - "expected fatPtrs to contain 
remapped fat pointer"); + if (operands.size() == 2) fatPtrOffsets.push_back(offset); - } offset += operands.size(); } if (!fatPtrOffsets.empty()) @@ -772,10 +736,8 @@ class ConvertSCFWhileOp : public PointerCanonicalizationPattern { int offset = 0; for (auto operands : remappedInits) { if (operands.size() == 2) { - assert(fatPtrs.contains({operands[0], operands[1]}) && - "expected fatPtrs to contain remapped fat pointer"); fatPtrs[{block->getArgument(offset), block->getArgument(offset + 1)}] - .canNarrow = fatPtrs[{operands[0], operands[1]}].canNarrow; + .canNarrow = fatPtrs.at({operands[0], operands[1]}).canNarrow; } offset += operands.size(); } @@ -819,6 +781,13 @@ class ConvertSCFWhileOp : public PointerCanonicalizationPattern { assert(offset < newWhileOp->getNumResults() && "expected offset to be within bounds of results"); ValueRange mappedValue = newWhileOp->getResults().slice(offset, len); + // propagate fatPtrs + if (mappedValue.size() == 2) { + assert(remappedInits[i].size() == 2 && + "expected corresponding inits to be a remapped fat ptr"); + fatPtrs[{mappedValue[0], mappedValue[1]}] = + fatPtrs.at({remappedInits[i][0], remappedInits[i][1]}); + } packedRets.push_back(mappedValue); offset += len; } @@ -886,10 +855,8 @@ class ConvertCFCondBranch int offset = 0; for (auto operands : remappedOperands) { if (operands.size() == 2) { - assert(fatPtrs.contains({operands[0], operands[1]}) && - "expected fatPtrs to contain remapped fat pointer"); fatPtrs[{block->getArgument(offset), block->getArgument(offset + 1)}] - .canNarrow = fatPtrs[{operands[0], operands[1]}].canNarrow; + .canNarrow = fatPtrs.at({operands[0], operands[1]}).canNarrow; } offset += operands.size(); } @@ -962,6 +929,8 @@ class ConvertArithSelectOp selectOp.getLoc(), selectOp.getType(), selectOp.getCondition(), fatPtrTrue[0], selectOp.getFalseValue()); rewriter.replaceOpWithMultiple(selectOp, {{newSelectOp, fatPtrTrue[1]}}); + fatPtrs[{newSelectOp, /*fatPtrOffset*/ fatPtrTrue[1]}].canNarrow = + fatPtrs.at({fatPtrTrue[0], fatPtrTrue[1]}).canNarrow; return success(); } @@ -973,13 +942,13 @@ class ConvertArithSelectOp selectOp.getLoc(), selectOp.getCondition(), fatPtrTrue[1], fatPtrFalse[1]); - assert((fatPtrs[{fatPtrTrue[0], fatPtrTrue[1]}].canNarrow == - fatPtrs[{fatPtrFalse[0], fatPtrFalse[1]}].canNarrow) && + assert((fatPtrs.at({fatPtrTrue[0], fatPtrTrue[1]}).canNarrow == + fatPtrs.at({fatPtrFalse[0], fatPtrFalse[1]}).canNarrow) && "expected can narrow to be the same for both fatPtrT and fatPtrF"); rewriter.replaceOpWithMultiple(selectOp, {{newBase, newOffset}}); fatPtrs[{newBase, newOffset}].canNarrow = - fatPtrs[{fatPtrTrue[0], fatPtrTrue[1]}].canNarrow; + fatPtrs.at({fatPtrTrue[0], fatPtrTrue[1]}).canNarrow; return success(); } @@ -1031,8 +1000,8 @@ class ConvertSCFIfOp : public PointerCanonicalizationPattern { Value thenFatPtrOffset = ifOp.thenYield().getOperand(i + 1); Value elseFatPtrBase = ifOp.elseYield().getOperand(i); Value elseFatPtrOffset = ifOp.elseYield().getOperand(i + 1); - assert((fatPtrs[{thenFatPtrBase, thenFatPtrOffset}].canNarrow == - fatPtrs[{elseFatPtrBase, elseFatPtrOffset}].canNarrow) && + assert((fatPtrs.at({thenFatPtrBase, thenFatPtrOffset}).canNarrow == + fatPtrs.at({elseFatPtrBase, elseFatPtrOffset}).canNarrow) && "expected then fat ptr canNarrow and else fat ptr canNarrow " "to be equal"); } @@ -1058,7 +1027,7 @@ class ConvertSCFIfOp : public PointerCanonicalizationPattern { Value thenFatPtrBase = newIfOp.thenYield().getOperand(idx); Value thenFatPtrOffset = 
newIfOp.thenYield().getOperand(idx + 1); fatPtrs[{newIfOp.getResult(idx), newIfOp.getResult(idx + 1)}].canNarrow = - fatPtrs[{thenFatPtrBase, thenFatPtrOffset}].canNarrow; + fatPtrs.at({thenFatPtrBase, thenFatPtrOffset}).canNarrow; } return success(); @@ -1099,11 +1068,9 @@ class ConvertCFBranch : public PointerCanonicalizationPattern { int offset = 0; for (auto operands : remappedDestOperands) { if (operands.size() == 2) { - assert(fatPtrs.contains({operands[0], operands[1]}) && - "expected fatPtrs to contain remapped fat pointer"); fatPtrs[{newDestBlock->getArgument(offset), newDestBlock->getArgument(offset + 1)}] - .canNarrow = fatPtrs[{operands[0], operands[1]}].canNarrow; + .canNarrow = fatPtrs.at({operands[0], operands[1]}).canNarrow; } offset += operands.size(); } @@ -1132,7 +1099,7 @@ class ConvertLoadOp : public PointerCanonicalizationPattern { if (llvm::isa(loadOp.getPtr().getType())) newPtr = createTensorPointer( rewriter, fatPtrBase, fatPtrOffset, curLoc, - fatPtrs[{fatPtrBase, fatPtrOffset}].canNarrow, attributes); + fatPtrs.at({fatPtrBase, fatPtrOffset}).canNarrow, attributes); SmallVector operands = loadOp.getOperands().take_back(loadOp.getNumOperands() - 1); operands.insert(operands.begin(), newPtr); @@ -1162,7 +1129,7 @@ class ConvertStoreOp : public PointerCanonicalizationPattern { if (llvm::isa(storeOp.getPtr().getType())) newPtr = createTensorPointer( rewriter, fatPtrBase, fatPtrOffset, curLoc, - fatPtrs[{fatPtrBase, fatPtrOffset}].canNarrow, attributes); + fatPtrs.at({fatPtrBase, fatPtrOffset}).canNarrow, attributes); SmallVector operands = storeOp.getOperands().take_back(storeOp.getNumOperands() - 1); operands.insert(operands.begin(), newPtr); From 6a721d25748535354516d1bc50a3feb634ba9e69 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Wed, 18 Dec 2024 16:03:49 -0500 Subject: [PATCH 15/17] collect fat pointers --- .../CanonicalizePointers.cpp | 85 ++++++++++++++++--- 1 file changed, 73 insertions(+), 12 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp index e99229664b96..13bd3d4e79ff 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp @@ -31,6 +31,9 @@ #define GEN_PASS_CLASSES #include "TritonAMDGPUTransforms/Passes.h.inc" +#include "llvm/Support/Format.h" +#include "llvm/Support/FormatVariadic.h" + #define DEBUG_TYPE "tritonamdgpu-canonicalize-pointers" #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") @@ -309,15 +312,6 @@ Value createTensorPointer( return addPtrOp.getResult(); } -class TritonAMDGPUCanonicalizePointersPass - : public TritonAMDGPUCanonicalizePointersBase< - TritonAMDGPUCanonicalizePointersPass> { -public: - TritonAMDGPUCanonicalizePointersPass() = default; - - void runOnOperation() override; -}; - struct FatPointers { struct FatPtrAttrs { FatPtrAttrs(const FatPtrAttrs &other) = default; @@ -325,7 +319,7 @@ struct FatPointers { // for map default insert FatPtrAttrs() = default; bool canNarrow = false; - llvm::SmallDenseMap attributes; + llvm::SmallDenseMap attributes; friend bool operator==(const FatPtrAttrs &lhs, const FatPtrAttrs &rhs) { return lhs.canNarrow == rhs.canNarrow && lhs.attributes == rhs.attributes; } @@ -336,8 +330,17 @@ struct FatPointers { using KeyT = std::pair; using ValueT = FatPtrAttrs; using DenseMapT = DenseMap; - ValueT &operator[](const KeyT &k) { return 
pointerAttrs[k]; } - ValueT &operator[](KeyT &&k) { return pointerAttrs[k]; } + void collectFatPointerAttributes(const KeyT &k); + ValueT &operator[](const KeyT &k) { + if (!pointerAttrs.contains(k)) + collectFatPointerAttributes(k); + return pointerAttrs[k]; + } + ValueT &operator[](KeyT &&k) { + if (!pointerAttrs.contains(k)) + collectFatPointerAttributes(k); + return pointerAttrs[k]; + } template using const_arg_type_t = typename llvm::const_pointer_or_const_ref::type; const ValueT &at(const_arg_type_t k) const { @@ -353,6 +356,55 @@ struct FatPointers { DenseMapT pointerAttrs; }; +// TODO(max): this is not a good way to do this... +void FatPointers::collectFatPointerAttributes(const KeyT &k) { + auto [base, offset] = k; + // If it is the i-th block argument, then look if the operation defined some + // _argi attribute and add it to the fat pointer attributes + if (auto arg = dyn_cast(base)) { + // If the value is a block parameter, the operation can specify + // an attribute for the given parameter by using `tt.property_argi` + // where `argi` refers to the arg number of the given parameter. + // So we need to iterate through the property, find the right one + // and push the property onto the pointers attributes. + auto op = arg.getOwner()->getParentOp(); + for (NamedAttribute namedAttr : op->getAttrs()) { + StringAttr attrName = namedAttr.getName(); + std::string argSuffix = + llvm::formatv("_arg{0}", arg.getArgNumber()).str(); + if (!attrName.strref().ends_with(argSuffix)) + continue; + + auto newAttrName = attrName.strref().drop_back(argSuffix.size()); + pointerAttrs[k].attributes[newAttrName] = namedAttr.getValue(); + // Propagate the argument to the offset if it is also a block + // argument + if (auto offsetArg = dyn_cast(offset)) + op->setAttr( + llvm::formatv("{0}_arg{1}", newAttrName, offsetArg.getArgNumber()) + .str(), + namedAttr.getValue()); + } + return; + } + + // TODO(max): this doesn't make sense - ops have all sorts of dialect + // attributes? + // + // Otherwise add the attributes of the operation to the fat + // pointer auto baseAttrs = base.getDefiningOp()->getAttrs(); auto offsetAttrs + // = offset.getDefiningOp()->getAttrs(); assert(baseAttrs.size() == + // offsetAttrs.size() && + // "expected base and offset attr dicts to be same size"); + // for (auto [baseAttr, offsetAttr] : llvm::zip(baseAttrs, offsetAttrs)) { + // assert(baseAttr.getName() == offsetAttr.getName() && + // "expected base attr name == offset attr name"); + // assert(baseAttr.getValue() == offsetAttr.getValue() && + // "expected base attr value == offset attr value"); + // pointerAttrs[k].attributes[baseAttr.getName()] = baseAttr.getValue(); + // } +} + /// Flatten the given value ranges into a single vector of values. 
static SmallVector flattenValues(ArrayRef values) { SmallVector result; @@ -1232,6 +1284,15 @@ class ConvertUnrealizedConversionCastOp } }; +class TritonAMDGPUCanonicalizePointersPass + : public TritonAMDGPUCanonicalizePointersBase< + TritonAMDGPUCanonicalizePointersPass> { +public: + TritonAMDGPUCanonicalizePointersPass() = default; + + void runOnOperation() override; +}; + void TritonAMDGPUCanonicalizePointersPass::runOnOperation() { mlir::MLIRContext *context = &getContext(); ConversionTarget target(*context); From e6ddf5ee2323ce77093b853cea014355d7a84eba Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Wed, 18 Dec 2024 19:44:03 -0500 Subject: [PATCH 16/17] normalize test --- .../amd/amd-canonicalize-pointers.mlir | 62 +++++++++---------- 1 file changed, 29 insertions(+), 33 deletions(-) diff --git a/test/TritonGPU/amd/amd-canonicalize-pointers.mlir b/test/TritonGPU/amd/amd-canonicalize-pointers.mlir index 2dd185961e8a..38fff4d59dde 100644 --- a/test/TritonGPU/amd/amd-canonicalize-pointers.mlir +++ b/test/TritonGPU/amd/amd-canonicalize-pointers.mlir @@ -311,8 +311,9 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} #blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func private @evaluate_condition() -> i1 // CHECK-LABEL: tt.func @whileOp - tt.func @whileOp(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>, %cond : i1)-> tensor<1024xf32, #blocked>{ + tt.func @whileOp(%arg0: !tt.ptr, %init : tensor<1024xf32, #blocked>)-> tensor<1024xf32, #blocked>{ %c1024_i32 = arith.constant 1024 : i32 %c0 = arith.constant 0: index %c128 = arith.constant 128: index @@ -325,14 +326,15 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> %5 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> // CHECK: %[[whileOut:.*]]:3 = scf.while ({{.*}}, %[[loopPtr:.*]] = %arg0, %[[loopOffset:.*]] = %[[base_offset]]) - %6 = scf.while (%arg1 = %5, %arg2 = %cond) : (tensor<1024x!tt.ptr, #blocked>, i1) -> (tensor<1024x!tt.ptr, #blocked>) { + %6 = scf.while (%arg1 = %5) : (tensor<1024x!tt.ptr, #blocked>) -> (tensor<1024x!tt.ptr, #blocked>) { // CHECK: scf.condition({{.*}}) %{{.*}}, %[[loopPtr]], %[[loopOffset]] - scf.condition(%arg2) %arg1 : tensor<1024x!tt.ptr, #blocked> + %cond = tt.call @evaluate_condition() : () -> i1 + scf.condition(%cond) %arg1 : tensor<1024x!tt.ptr, #blocked> } do { // CHECK: ^bb{{.*}}(%{{.*}}, %[[blockPtr:.*]]: !tt.ptr, %[[blockOffset:.*]]: tensor<1024xi64, #blocked>): ^bb0(%arg1: tensor<1024x!tt.ptr, #blocked>): // CHECK: scf.yield {{.*}}, %[[blockPtr]], %[[blockOffset]] - scf.yield %arg1, %cond : tensor<1024x!tt.ptr, #blocked>, i1 + scf.yield %arg1 : tensor<1024x!tt.ptr, #blocked> } // CHECK: %[[trunc_offset:.*]] = arith.trunci %[[whileOut]]#2 // CHECK: %[[base_ptr:.*]] = tt.splat %[[whileOut]]#1 @@ -392,7 +394,7 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} #blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: tt.func @branch - tt.func @branch(%arg0 : !tt.ptr, %i1 : i1) -> tensor<1024xf32, #blocked>{ + tt.func @branch(%arg0 : !tt.ptr, %i1 : i1) -> tensor<1024xf32, #blocked> { %c1024_i32 = arith.constant 1024 : i32 %c0 = arith.constant 0: index %c128 
= arith.constant 128: index @@ -431,7 +433,7 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} #blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: tt.func @tile_offset - tt.func @tile_offset(%arg1: !tt.ptr, %arg5: i32 , %arg7: i32 ) { + tt.func @tile_offset(%arg1: !tt.ptr, %arg5: i32 , %arg7: i32) -> tensor<16x256xf16, #blocked> { %c128_i32 = arith.constant 128 : i32 %c256_i32 = arith.constant 256 : i32 %1 = tt.get_program_id x : i32 @@ -465,7 +467,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr // CHECK: %[[ptr:.*]] = tt.splat %[[scalarPtr]] : !tt.ptr -> tensor<16x256x!tt.ptr, #blocked> // CHECK: tt.addptr %[[ptr]], %[[tensorOffset]] : tensor<16x256x!tt.ptr, #block %61 = tt.load %46 : tensor<16x256x!tt.ptr, #blocked> - tt.return + tt.return %61 : tensor<16x256xf16, #blocked> } } @@ -486,7 +488,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr #blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: tt.func public @matmul_kernel - tt.func public @matmul_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}) { + tt.func public @matmul_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}) -> tensor<128x16xf16, #blocked> { %c128_i32 = arith.constant 128 : i32 %0 = tt.get_program_id x : i32 %1 = arith.muli %0, %c128_i32 : i32 @@ -524,7 +526,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr // CHECK: %[[tensorOffsetTrunc:.*]] = arith.trunci %[[tensorOffset]] : tensor<128x16xi64, #blocked> to tensor<128x16xi32, #blocked> // CHECK: %[[ptr:.*]] = tt.splat %[[scalarPtr]] : !tt.ptr -> tensor<128x16x!tt.ptr, #blocked> // CHECK: tt.addptr %[[ptr]], %[[tensorOffsetTrunc]] : tensor<128x16x!tt.ptr, #blocked>, tensor<128x16xi32, #blocked> - tt.return + tt.return %15 : tensor<128x16xf16, #blocked> } } @@ -567,7 +569,7 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} #blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1100", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: tt.func @where_kernel - tt.func @where_kernel(%arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}){ + tt.func @where_kernel(%arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}) -> tensor<1024xi64, #blocked> { %c0_i8 = arith.constant 0 : i8 %c1024_i32 = arith.constant 1024 : i32 %0 = tt.get_program_id x : i32 @@ -584,7 +586,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK: %[[tensorPtr:.*]] = tt.splat %[[selectPtr0]] // CHECK: tt.addptr %[[tensorPtr]] %14 = tt.load %13 : tensor<1024x!tt.ptr, #blocked> - tt.return + tt.return %14 : tensor<1024xi64, #blocked> } } @@ -681,26 +683,28 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ #blocked = 
#ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func private @evaluate_condition() -> i1 // CHECK-LABEL: tt.func @scalar_while - tt.func @scalar_while(%arg0: !tt.ptr, %init : f32, %cond : i1)->f32{ + tt.func @scalar_while(%arg0: !tt.ptr, %init : f32) -> f32 { %c1024_i32 = arith.constant 1024 : i32 - %c0 = arith.constant 0: index - %c128 = arith.constant 128: index + %c128 = arith.constant 128: i32 %c1 = arith.constant 1 : index %0 = tt.get_program_id x : i32 %1 = arith.muli %0, %c1024_i32 : i32 // CHECK: %[[ptr0:.*]] = tt.addptr %arg0, %{{.*}} // CHECK: scf.while ({{.*}}, {{.*}} = %arg2, %[[ptr1:.*]] = %[[ptr0]], {{.*}}) %2 = tt.addptr %arg0, %0 : !tt.ptr, i32 - %6 = scf.while (%arg1 = %2, %arg2 = %cond) : (!tt.ptr, i1) -> (!tt.ptr) { + %6 = scf.while (%arg1 = %2) : (!tt.ptr) -> (!tt.ptr) { // CHECK: scf.condition({{.*}}) {{.*}}, %[[ptr1]] - scf.condition(%arg2) %arg1 : !tt.ptr - } do { + %cond = tt.call @evaluate_condition() : () -> i1 + scf.condition(%cond) %arg1 : !tt.ptr + } do { // CHECK: ^bb0({{.*}}: !tt.ptr, %[[ptr2:.*]]: !tt.ptr, {{.*}}) // CHECK: scf.yield %{{.*}}, {{.*}} %[[ptr2]], {{.*}}, {{.*}} ^bb0(%arg1: !tt.ptr): - scf.yield %arg1, %cond : !tt.ptr, i1 - } + %newptr = tt.addptr %arg1, %c128 : !tt.ptr, i32 + scf.yield %newptr : !tt.ptr + } %11 = tt.load %6 : !tt.ptr tt.return %11 : f32 } @@ -711,26 +715,18 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} #blocked = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}> module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { // CHECK-LABEL: tt.func @scalar_cond_branch - tt.func @scalar_cond_branch(%arg0 : !tt.ptr, %i1 : i1) -> f32{ - %c1024_i32 = arith.constant 1024 : i32 - %c0 = arith.constant 0: index - %c128 = arith.constant 128: index - %c1 = arith.constant 1 : index - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c1024_i32 : i32 - %6 = tt.addptr %arg0, %0 : !tt.ptr, i32 - // CHECK: %[[ptr0:.*]] = tt.addptr %arg0 + tt.func @scalar_cond_branch(%arg0 : !tt.ptr, %arg1 : !tt.ptr, %i1 : i1) -> f32{ // CHECK: cf.cond_br %arg1, ^bb1(%{{.*}}, %[[ptr0]], {{.*}}), ^bb2(%{{.*}}, %arg0, {{.*}}) - cf.cond_br %i1, ^bb1(%6 : !tt.ptr), ^bb2(%arg0 : !tt.ptr) + cf.cond_br %i1, ^bb1(%arg0 : !tt.ptr), ^bb2(%arg1 : !tt.ptr) // CHECK: ^bb1({{.*}}, %[[ptr1:.*]]: !tt.ptr, {{.*}}): - ^bb1(%arg1 : !tt.ptr): + ^bb1(%arg3 : !tt.ptr): // CHECK: tt.load %[[ptr1]] - %out1 = tt.load %arg1 : !tt.ptr + %out1 = tt.load %arg3 : !tt.ptr tt.return %out1 : f32 // CHECK: ^bb2({{.*}}, %[[ptr2:.*]]: !tt.ptr, {{.*}}): - ^bb2(%arg2 : !tt.ptr): // 2 preds: ^bb0, ^bb1 + ^bb2(%arg4 : !tt.ptr): // 2 preds: ^bb0, ^bb1 // CHECK: tt.load %[[ptr2]] - %out2 = tt.load %arg2 : !tt.ptr + %out2 = tt.load %arg4 : !tt.ptr tt.return %out2 : f32 } } From 21b6677efab35492070743b6b750b231b6ed7367 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Wed, 18 Dec 2024 20:58:10 -0500 Subject: [PATCH 17/17] fix cf.cond_bar propagation --- .../CanonicalizePointers.cpp | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp index 13bd3d4e79ff..6495c714b5bf 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp +++ 
b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp @@ -30,6 +30,7 @@ #define GEN_PASS_CLASSES #include "TritonAMDGPUTransforms/Passes.h.inc" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/Support/Format.h" #include "llvm/Support/FormatVariadic.h" @@ -1147,11 +1148,9 @@ class ConvertLoadOp : public PointerCanonicalizationPattern { Location curLoc = loadOp.getLoc(); llvm::SmallDenseMap attributes{}; - Value newPtr = fatPtrBase; - if (llvm::isa(loadOp.getPtr().getType())) - newPtr = createTensorPointer( - rewriter, fatPtrBase, fatPtrOffset, curLoc, - fatPtrs.at({fatPtrBase, fatPtrOffset}).canNarrow, attributes); + Value newPtr = createTensorPointer( + rewriter, fatPtrBase, fatPtrOffset, curLoc, + fatPtrs.at({fatPtrBase, fatPtrOffset}).canNarrow, attributes); SmallVector operands = loadOp.getOperands().take_back(loadOp.getNumOperands() - 1); operands.insert(operands.begin(), newPtr); @@ -1294,11 +1293,12 @@ class TritonAMDGPUCanonicalizePointersPass }; void TritonAMDGPUCanonicalizePointersPass::runOnOperation() { - mlir::MLIRContext *context = &getContext(); - ConversionTarget target(*context); - RewritePatternSet patterns(context); + ConversionTarget target(getContext()); + RewritePatternSet patterns(&getContext()); tt::FuncOp func = getOperation(); + if (func.isPrivate()) + return; // forward slice == all transitive uses ForwardSliceOptions sliceOpts([](Operation *op) { @@ -1327,10 +1327,10 @@ void TritonAMDGPUCanonicalizePointersPass::runOnOperation() { unsigned opOffset = llvm::isa(op) ? 0 : 1; for (auto successor : op->getSuccessors()) { for (auto arg : successor->getArguments()) { - // if the bb arg corresponds to an op that will be rewritten - if (opsToRewrite.contains( - op->getOperand(opOffset + arg.getArgNumber()) - .getDefiningOp())) + auto oper = op->getOperand(opOffset + arg.getArgNumber()); + // this is a heuristic - bb args with tt.ptr types need to be + // rewritten + if (llvm::isa(getElementTypeOrSelf(oper.getType()))) getForwardSlice(arg, &opsToRewrite, sliceOpts); } }
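+        // (Illustrative IR, not from an existing test.) The successor walk
+        // above is what catches cases like
+        //   cf.cond_br %cond, ^bb1(%ptr : !tt.ptr<f32>), ^bb2(%x : i32)
+        // where the pointer-typed block argument of ^bb1 has no defining op,
+        // so a forward slice rooted only at the function arguments would not
+        // reach its users.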