How to get a slice of a tensor from tl.load()? #1313

Answered by tristanheywood
zw2326 asked this question in Q&A

Unfortunately, Triton does not currently support indexing into a loaded tensor, so there is no good way to access the second row. For your test kernel, you could work around this by calling tl.load() again to load the 2nd row independently. In general, you can use a combination of tl.store() and tl.load() to perform indexing; however, this will likely result in poor performance.

For some context around why indexing isn't supported: from what I understand, each kernel instance (i.e. each program id) will be automatically parallelized across multiple GPU threads. So you write your kernel to parallelize the given task between a grid of kernel instances, and Triton further parallelizes each kernel instance. My impr…

Answer selected by zw2326