-
Hello, this RFC is interesting, but I think I disagree with its entire premise. A couple of specific points:
- I don't think this is true. While TritonGPU -> LLVM is indeed abrupt when it comes to shared memory management, the rest actually just implements the semantics of
- I don't think that this is a reasonable assumption. The Triton memory model opens up a lot of interesting possibilities (e.g., linked lists, tensors of trees, hashmaps, etc.) that are just fundamentally incompatible with

I can see how a Triton -> Linalg conversion pass makes sense if you have some hardware whose compiler already relies on
-
@sethbrin you may be interested in #1797. There is no plan to merge this into the main branch soon (it would probably live in a fork), but it should solve a lot of your problems.
-
Hi all,
This is an RFC for adding a conversion from the Triton dialect to the Linalg dialect.
Background
The current lowering of TritonGPU -> LLVM is too direct and abrupt, which may make it hard to debug or analyze when there is a codegen bug or a performance issue.
The MLIR community provides a progressive lowering path for codegen, which decomposes the full codegen pipeline into the
Linalg -> Vector -> GPU -> LLVM
dialects, and defines various transform dialect mechanisms for codegen on the Linalg dialect, including tiling, reduction, promotion, padding, interchange, fusion, etc. If we can build a bridge between the Triton dialect and the Linalg dialect, we can make use of the MLIR community infrastructure, reduce the complexity of adding a new backend, and let every backend reuse the backend-independent optimization passes.
Proposal
Triton currently supports very flexible pointer operations, but as far as we can tell, the Triton repository only uses pointers to element types (please correct us if this is wrong). If we encounter pointers of pointers, it is hard for the backend to do contiguity analysis, and performance is difficult to optimize.
We think pointer-of-pointer use cases are rare, so the following design only targets pointers to element types. The last section also briefly discusses how a Linalg program can deal with pointers of pointers.
Auxiliary Dialect
A pointer type can be thought of as an offset relative to a base address, so it can be expressed on tensors. For this purpose we introduce an auxiliary dialect, which contains the `tensor_view` and `store` operators.
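As a minimal sketch of what we have in mind (the op syntax and the `aux.` prefix below are illustrative, not finalized definitions):

```mlir
// Illustrative only: reinterpret the memory behind %base, starting at element
// offset %off, as a 1-D tensor of 128 contiguous f32 elements.
%view = aux.tensor_view %base[%off] : (!tt.ptr<f32>, index) -> tensor<128xf32>

// Illustrative only: write %result back to the same memory region.
aux.store %result, %base[%off] : tensor<128xf32>, !tt.ptr<f32>
```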
LinalgExt Dialect
Given the offset corresponding to a pointer, the `tt.load` operator semantically takes data from a set of pointer addresses and assembles it into a tensor, which is essentially `tensor.gather`. However, the official `tensor.gather` does not support mask/other, so we define a LinalgExt dialect as a supplement to the Linalg dialect. (The main reason we do not introduce this operator in the Tensor dialect is that the infrastructure on Linalg, such as promotion, fusion, etc., is currently not available on the Tensor dialect.)
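As an illustration (the exact syntax below is a sketch, not a finalized op definition), the gather we have in mind carries the mask and the `other` value directly on the op:

```mlir
// Illustrative sketch: gather elements of %src at %indices; %mask selects the
// valid lanes and %other supplies the value for masked-off lanes.
%res = linalg_ext.gather
         ins(%src, %indices, %mask, %other
             : tensor<1024xf32>, tensor<128xi32>, tensor<128xi1>, f32)
         outs(%init : tensor<128xf32>) -> tensor<128xf32>
```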
Conversion from Triton to Linalg
We give a simple example to describe the conversion algorithm.
Dim Mapping
A Triton kernel expresses the computational logic inside a single block and uses `tt.get_program_id`/`tt.get_num_programs` to obtain the index information. In order to have the same expression on all backends and to reuse the kernel outlining pass in the GPU backend, we recover this structure by carrying the dim mapping information on an `scf.forall` op.
For example:
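The sketch below illustrates the intended structure (op syntax abbreviated; the `mapping` attribute shown is one possible choice):

```mlir
// Before: the block index is read explicitly inside the kernel body.
%pid = tt.get_program_id x : i32

// After (sketch): the kernel body is wrapped in an scf.forall whose induction
// variable plays the role of the program id, so the GPU outlining pass can
// later turn it back into a launch dimension.
scf.forall (%pid) in (%num_programs) {
  // ... converted kernel body, indexed by %pid ...
} {mapping = [#gpu.block<x>]}
```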
Optimization
In the previous example, we can reuse `AxisInfoAnalysis` to easily determine that the access is contiguous along the whole dim, in which case we can use `tensor.extract` instead of `linalg_ext.gather` and simply compute the offset of the first address as a scalar.
After conversion:
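A rough sketch of the intended result (not exact IR; `tensor.extract_slice` here stands in for whatever contiguous view the backend prefers):

```mlir
// Sketch: once AxisInfoAnalysis proves the offsets are contiguous, only the
// first offset is needed as a scalar ...
%first = tensor.extract %offsets[%c0] : tensor<128xi32>
%off   = arith.index_cast %first : i32 to index
// ... and the gather collapses into a contiguous slice of the base view.
%vals  = tensor.extract_slice %base_view[%off] [128] [1]
           : tensor<?xf32> to tensor<128xf32>
```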
At this point, we are done converting Triton to the Linalg dialect, and we can use the transform ops defined on Linalg to optimize the operator.
The previous example deals with the one-dimensional contiguous case. For matrix multiplication, which actually has two dimensions, we need to enhance `AxisInfoAnalysis` by introducing strides information, mapping the `d`-th dimension to the length of the shortest sequence with the same stride. Below is part of a matmul's tt IR snippet.
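The snippet below is a hand-abbreviated reconstruction of the usual 2-D pointer arithmetic for one operand, not an exact listing (types and attributes shortened):

```mlir
// Row offsets are scaled by %stride_am, column offsets have stride 1, so the
// innermost dimension is contiguous while the outer one is strided.
%rm    = tt.make_range {start = 0 : i32, end = 128 : i32} : tensor<128xi32>
%rk    = tt.make_range {start = 0 : i32, end = 64 : i32}  : tensor<64xi32>
%rm2   = tt.expand_dims %rm {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32>
%sam   = tt.splat %stride_am : i32 -> tensor<128x1xi32>
%rowo  = arith.muli %rm2, %sam : tensor<128x1xi32>
%rk2   = tt.expand_dims %rk {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32>
%rowb  = tt.broadcast %rowo : tensor<128x1xi32> -> tensor<128x64xi32>
%colb  = tt.broadcast %rk2  : tensor<1x64xi32>  -> tensor<128x64xi32>
%offs  = arith.addi %rowb, %colb : tensor<128x64xi32>
%abase = tt.splat %a_ptr : !tt.ptr<f16> -> tensor<128x64x!tt.ptr<f16>>
%aptrs = tt.addptr %abase, %offs : tensor<128x64x!tt.ptr<f16>>, tensor<128x64xi32>
%a     = tt.load %aptrs : tensor<128x64xf16>
```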
Pointer of Pointer Discussion
Since each pointer actually corresponds to a large global tensor, we represent the second-level pointer with an attribute, so that the specific value can be indexed directly according to the offset.
Take the following code as an example:
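A minimal, hypothetical sketch of such a kernel, using a tensor of second-level pointers (the nested `!tt.ptr` form here is our illustration, not code from the repository):

```mlir
// Sketch: %tbl_base points into a table of pointers; each loaded value is
// itself a pointer that is dereferenced again to fetch the actual data.
%tbl_ptrs = tt.addptr %tbl_base, %offs : tensor<128x!tt.ptr<!tt.ptr<f32>>>, tensor<128xi32>
%ptrs     = tt.load %tbl_ptrs : tensor<128x!tt.ptr<f32>>
%vals     = tt.load %ptrs     : tensor<128xf32>
```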
A possible form of `linalg_ext.gather` for this case is sketched below.
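The attribute name `indirect` below is only an illustration of how the second level of indirection could be recorded on the op:

```mlir
// Sketch: %tbl holds the second-level offsets; the gather first reads %tbl at
// %indices and then follows the resulting offsets into the global data tensor
// %data. Masked-off lanes take %other, as before.
%res = linalg_ext.gather {indirect = true}
         ins(%data, %tbl, %indices, %mask, %other
             : tensor<?xf32>, tensor<?xi64>, tensor<128xi32>, tensor<128xi1>, f32)
         outs(%init : tensor<128xf32>) -> tensor<128xf32>
```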