Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

T5 Encoder #2069

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open

T5 Encoder #2069

wants to merge 9 commits into from

Conversation

calvinpelletier
Copy link
Contributor

@calvinpelletier calvinpelletier commented Nov 25, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.

Changelog

What are the changes made in this PR?

  • T5 tokenizer
  • T5 encoder
  • convert weights from HF's T5 to ours
  • unit tests

Analysis

Comparison to HF's implemention (batch of text -> encoder output):

  • 6.7e-5 MSE output difference
  • ours is ~5% faster

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example

  • I did not change any public API
  • I have added an example to docs or docstrings

Minimal test code

import torch

from torchtune.models.t5 import t5_tokenizer, t5_v1p1_xxl_encoder
from torchtune.training.checkpointing._checkpointer import FullModelHFCheckpointer

MAX_SEQ_LEN = 512

# tune download google/t5-v1_1-xxl --output-dir /tmp/t5-hf
tokenizer = t5_tokenizer("/tmp/t5-hf/spiece.model", max_seq_len=MAX_SEQ_LEN)
checkpointer = FullModelHFCheckpointer(
    "/tmp/t5-hf",
    ["pytorch_model.bin"],
    "T5_ENCODER",
    "/tmp/t5-tt",
)

model = t5_v1p1_xxl_encoder(max_seq_len=MAX_SEQ_LEN)
model.load_state_dict(checkpointer.load_checkpoint()["model"])
model = model.to(device="cuda", dtype=torch.bfloat16).eval().requires_grad_(False)


def tokenize(texts):
    result = torch.full(
        (len(texts), tokenizer.max_seq_len),
        tokenizer.pad_id,
        dtype=torch.int,
    )
    for i, text in enumerate(texts):
        tokens = tokenizer.encode(text)
        result[i, : len(tokens)] = torch.tensor(tokens)
    return result


tokens = tokenize(
    [
        "a cow jumping over the moon",
        "a helpful AI assistant",
    ]
)
encoding = model(tokens.to("cuda"))

Copy link

pytorch-bot bot commented Nov 25, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2069

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 17caaf7 with merge base 32e265d (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Nov 25, 2024
Comment on lines +205 to +209
# attention with relative position bias
attn_score = torch.matmul(q, k.transpose(-2, -1))
attn_score += rel_pos_bias
attn_weight = F.softmax(attn_score.float(), dim=-1).to(attn_score.dtype)
attn_out = torch.matmul(attn_weight, v)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part could be simplified by using F.scaled_dot_product_attention by repurposing the mask argument for rel_pos_bias (because scaled_dot_product_attention simply adds the mask to the attention score when the mask is a float tensor). However, when I tried this it was significantly slower for some reason

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder whether this is a case where we could benefit from using flex attention?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return x.permute([2, 0, 1]).unsqueeze(0)


def _calc_birectional_rel_pos_to_bucket(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should I add more comments in this function explaining each operation? or is it fine to just leave it a bit opaque

from torchtune.models.t5._tokenizer import T5Tokenizer


def t5_v1p1_xxl_encoder(max_seq_len: int = 512) -> T5Encoder:
Copy link
Contributor Author

@calvinpelletier calvinpelletier Nov 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thoughts on writing decimal points as p instead of _ in snake case? IMO it's hard to read the _ decimals when _ is also being used as a word separator. Like in t5_v1_1_xxl_encoder, to my eyes it looks like it's version 1 not 1.1. Plus it's ambiguous: if one day we have a "Qwen3 1.5B" and "Qwen3.1 5B", they're both gonna be named qwen3_1_5b.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the p could be better. But I'd want to switch everything to this notation then

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll make a separate PR for this

Copy link
Contributor

@pbontrager pbontrager left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I gave this a high level pass so far. This looks really clean and good. My only concern is wish us having to have a custom T5 layer and attention module. If flex attention would let us use our existing modules I'd prefer to go down that route.

torchtune/models/t5/__init__.py Show resolved Hide resolved
self.sa_norm = sa_norm
self.mlp_norm = mlp_norm

def forward(self, x: Tensor, rel_pos_bias: Tensor) -> Tensor:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The rel_pos_bias is just a mask no? Couldn't we use all our standard modules here? Is the only reason we have these custom layers because this needs flex attention to be fast?

Copy link
Contributor Author

@calvinpelletier calvinpelletier Dec 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, kind of... it's a float tensor that gets added to the attention scores (which is how scaled_dot_product_attention deals with float masks, so yeah we could think of it as a mask).

I didn't use our modules because:

  • our MultiHeadAttention/TransformerSelfAttentionLayer modules expect the masks to be boolean tensors (according to the docstring at least)
  • MultiHeadAttention uses the default attention scaling (1/sqrt(dim)), but T5 doesn't scale it at all
  • we can't use flex attention with float masks. MultiHeadAttention would use F.scaled_dot_product_attention, which is much slower for float masks than the manual implementation I went with

We could switch to our modules with a couple small changes (update the docstrings to clarify that the mask can also be a float tensor and add an argument for disabling attention scaling), but given how little code is required to just implement separate versions for T5, I thought it was cleaner to leave our attention/transformer modules alone (especially since this implementation is faster).

@@ -7,6 +7,7 @@
from typing import List, Optional

from sentencepiece import SentencePieceProcessor

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove? Or was this from the linter?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

its from the linter

self.max_seq_len = max_seq_len
self.truncate = truncate

def encode(self, text: str) -> List[int]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't need decode like the CLIP tokenizer?

Copy link
Contributor Author

@calvinpelletier calvinpelletier Dec 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It has decode (I test it in the unit test). It's in the base tokenizer class: https://github.com/pytorch/torchtune/blob/main/torchtune/modules/tokenizers/_sentencepiece.py#L102

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants