How to get a slice of a tensor from tl.load()? #1313
Noob question - say I
Replies: 2 comments
Unfortunately, Triton does not currently support indexing, so there is no good way to access the second row. For your test kernel, you could get around this by using `tl.load()` again to load the 2nd row independently. In general you can use a combination of `tl.store()` and `tl.load()` to perform indexing, however this will likely result in poor performance.

For some context around why indexing isn't supported: from what I understand, each kernel instance (i.e. each program id) will be automatically parallelized across multiple GPU threads. So you write your kernel to parallelize the given task between a grid of kernel instances, and Triton further parallelizes each kernel instance. My impression is that this automatic parallelization relies on each block/tensor of data being indivisible, and supporting indexing would make the parallelization process much more difficult.
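To make the first workaround concrete: because a loaded block can't be indexed, you recompute the addresses of the row you want from the base pointer and issue a second load for just that row. The sketch below emulates that pointer arithmetic with NumPy on the CPU so it runs without a GPU. `masked_load`, `load_tile` and `load_row` are hypothetical helpers, not Triton APIs. The Triton lines in the docstrings are an untested illustration. They assume a contiguous row-major matrix with row stride `stride` and power-of-two block sizes, as `tl.arange` expects.

```python
import numpy as np


def masked_load(buf, offsets, mask, other=0.0):
    """CPU stand-in for tl.load(ptr + offsets, mask=mask, other=other).

    Gathers buf[offsets] where mask is True and fills `other` elsewhere,
    so masked-off (out-of-bounds) lanes are never actually read.
    """
    safe = np.where(mask, offsets, 0)  # redirect masked lanes to a valid index
    return np.where(mask, buf[safe], other)


def load_tile(buf, n_rows, n_cols, stride, block_m, block_n):
    """Load a 2-D tile, like the original kernel's single tl.load.

    Assumed Triton analogue (inside a kernel):
        rows = tl.arange(0, BLOCK_M)[:, None]
        cols = tl.arange(0, BLOCK_N)[None, :]
        tile = tl.load(x_ptr + rows * stride + cols,
                       mask=(rows < n_rows) & (cols < n_cols), other=0.0)
    Triton won't let you write tile[1] afterwards.
    """
    rows = np.arange(block_m)[:, None]
    cols = np.arange(block_n)[None, :]
    return masked_load(buf, rows * stride + cols,
                       (rows < n_rows) & (cols < n_cols))


def load_row(buf, row, n_cols, stride, block_n):
    """The workaround: load one row as its own 1-D block.

    Assumed Triton analogue (inside a kernel):
        cols = tl.arange(0, BLOCK_N)
        x = tl.load(x_ptr + row * stride + cols,
                    mask=cols < n_cols, other=0.0)
    """
    cols = np.arange(block_n)
    return masked_load(buf, row * stride + cols, cols < n_cols)


# A 3x4 row-major matrix stored flat, row stride 4.
x = np.arange(12, dtype=np.float32)
tile = load_tile(x, n_rows=3, n_cols=4, stride=4, block_m=4, block_n=4)
row1 = load_row(x, row=1, n_cols=4, stride=4, block_n=4)
print(row1)  # the second row, fetched with its own load
```

The point of the design is that `load_row` never touches `tile`. It re-derives offsets from the base pointer, so it costs an extra read from memory instead of slicing data the kernel already holds.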
Hi, based on this reply, I believe there is no way of efficiently doing something like this, right?