-
Notifications
You must be signed in to change notification settings - Fork 465
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
T5 Encoder #2069
base: main
Are you sure you want to change the base?
T5 Encoder #2069
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2069
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 17caaf7 with merge base 32e265d (): This comment was automatically generated by Dr. CI and updates every 15 minutes. |
# attention with relative position bias | ||
attn_score = torch.matmul(q, k.transpose(-2, -1)) | ||
attn_score += rel_pos_bias | ||
attn_weight = F.softmax(attn_score.float(), dim=-1).to(attn_score.dtype) | ||
attn_out = torch.matmul(attn_weight, v) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This part could be simplified by using F.scaled_dot_product_attention
by repurposing the mask
argument for rel_pos_bias
(because scaled_dot_product_attention
simply adds the mask to the attention score when the mask is a float tensor). However, when I tried this it was significantly slower for some reason
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder whether this is a case where we could benefit from using flex attention?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Flex attention requires boolean masks AFAICT: https://github.com/pytorch/pytorch/blob/main/torch/nn/attention/flex_attention.py#L826-L827
return x.permute([2, 0, 1]).unsqueeze(0) | ||
|
||
|
||
def _calc_birectional_rel_pos_to_bucket( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should I add more comments in this function explaining each operation? or is it fine to just leave it a bit opaque
from torchtune.models.t5._tokenizer import T5Tokenizer | ||
|
||
|
||
def t5_v1p1_xxl_encoder(max_seq_len: int = 512) -> T5Encoder: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thoughts on writing decimal points as p
instead of _
in snake case? IMO it's hard to read the _
decimals when _
is also being used as a word separator. Like in t5_v1_1_xxl_encoder
, to my eyes it looks like it's version 1 not 1.1. Plus it's ambiguous: if one day we have a "Qwen3 1.5B" and "Qwen3.1 5B", they're both gonna be named qwen3_1_5b
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the p could be better. But I'd want to switch everything to this notation then
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll make a separate PR for this
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I gave this a high level pass so far. This looks really clean and good. My only concern is wish us having to have a custom T5 layer and attention module. If flex attention would let us use our existing modules I'd prefer to go down that route.
self.sa_norm = sa_norm | ||
self.mlp_norm = mlp_norm | ||
|
||
def forward(self, x: Tensor, rel_pos_bias: Tensor) -> Tensor: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The rel_pos_bias is just a mask no? Couldn't we use all our standard modules here? Is the only reason we have these custom layers because this needs flex attention to be fast?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well, kind of... it's a float tensor that gets added to the attention scores (which is how scaled_dot_product_attention
deals with float masks, so yeah we could think of it as a mask).
I didn't use our modules because:
- our
MultiHeadAttention
/TransformerSelfAttentionLayer
modules expect the masks to be boolean tensors (according to the docstring at least) MultiHeadAttention
uses the default attention scaling (1/sqrt(dim)
), but T5 doesn't scale it at all- we can't use flex attention with float masks.
MultiHeadAttention
would useF.scaled_dot_product_attention
, which is much slower for float masks than the manual implementation I went with
We could switch to our modules with a couple small changes (update the docstrings to clarify that the mask can also be a float tensor and add an argument for disabling attention scaling), but given how little code is required to just implement separate versions for T5, I thought it was cleaner to leave our attention/transformer modules alone (especially since this implementation is faster).
@@ -7,6 +7,7 @@ | |||
from typing import List, Optional | |||
|
|||
from sentencepiece import SentencePieceProcessor | |||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove? Or was this from the linter?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
its from the linter
self.max_seq_len = max_seq_len | ||
self.truncate = truncate | ||
|
||
def encode(self, text: str) -> List[int]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This doesn't need decode like the CLIP tokenizer?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It has decode (I test it in the unit test). It's in the base tokenizer class: https://github.com/pytorch/torchtune/blob/main/torchtune/modules/tokenizers/_sentencepiece.py#L102
Context
What is the purpose of this PR? Is it to
Please link to any issues this PR addresses.
Changelog
What are the changes made in this PR?
Analysis
Comparison to HF's implemention (batch of text -> encoder output):
Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.
pre-commit install
)pytest tests
pytest tests -m integration_test
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example
Minimal test code