diff --git a/tests/torchtune/modules/model_fusion/test_fusion_models.py b/tests/torchtune/modules/model_fusion/test_deep_fusion.py similarity index 90% rename from tests/torchtune/modules/model_fusion/test_fusion_models.py rename to tests/torchtune/modules/model_fusion/test_deep_fusion.py index 322616276e..79b2f9ab3d 100644 --- a/tests/torchtune/modules/model_fusion/test_fusion_models.py +++ b/tests/torchtune/modules/model_fusion/test_deep_fusion.py @@ -22,7 +22,7 @@ class DummyModel(nn.Module): def __init__(self, dim, vocab_size): super().__init__() self.cache_enabled = False - self.embed = nn.Embedding(vocab_size, dim) + self.tok_embeddings = nn.Embedding(vocab_size, dim) self.q = nn.Linear(dim, dim) self.k = nn.Linear(dim, dim) self.v = nn.Linear(dim, dim) @@ -38,14 +38,22 @@ def caches_are_setup(self): def reset_caches(self): self.cache_enabled = False - def forward(self, tokens, mask, encoder_input, encoder_mask, input_pos): - x = self.embed(tokens) + def forward( + self, + tokens, + *, + mask=None, + encoder_input=None, + encoder_mask=None, + input_pos=None, + ): + x = self.tok_embeddings(tokens) if encoder_input is not None: q = self.q(x) - k = self.k(encoder_input) - v = self.v(encoder_input) + k = self.k(encoder_input) if encoder_input is not None else self.k(x) + v = self.v(encoder_input) if encoder_input is not None else self.v(x) x += nn.functional.scaled_dot_product_attention( - q, k, v, attn_mask=encoder_mask + q, k, v, attn_mask=encoder_mask if encoder_mask is not None else mask ) x = self.output(x) return x @@ -85,7 +93,7 @@ def fused_model(self, encoder, decoder) -> DeepFusionModel: return model @pytest.fixture - def inputs(self, dim, vocab_size): + def inputs(self, vocab_size): batch_size = 2 seq_len = 10 tokens = torch.randint(0, vocab_size, (batch_size, seq_len)) @@ -183,5 +191,5 @@ def test_set_trainable_params(self, fused_model, encoder, decoder): "decoder.k.bias", "decoder.v.weight", "decoder.v.bias", - "decoder.embed.weight", + "decoder.tok_embeddings.weight", } diff --git a/tests/torchtune/modules/model_fusion/test_early_fusion.py b/tests/torchtune/modules/model_fusion/test_early_fusion.py new file mode 100644 index 0000000000..d7ff407289 --- /dev/null +++ b/tests/torchtune/modules/model_fusion/test_early_fusion.py @@ -0,0 +1,336 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from collections import OrderedDict + +import pytest + +import torch +from tests.test_utils import assert_expected, fixed_init_model +from torch import nn +from torchtune.modules.model_fusion import EarlyFusionModel, register_fusion_module +from torchtune.training.seed import set_seed + + +@pytest.fixture(autouse=True) +def random(): + set_seed(1) + + +class DummyModel(nn.Module): + def __init__(self, dim, vocab_size): + super().__init__() + self.cache_enabled = False + self.tok_embeddings = nn.Embedding(vocab_size, dim) + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.output = nn.Linear(dim, vocab_size) + register_fusion_module(self.output) + + def setup_caches(self, batch_size, dtype, *args, **kwargs): + self.cache_enabled = True + + def caches_are_setup(self): + return self.cache_enabled + + def reset_caches(self): + self.cache_enabled = False + + def forward( + self, + tokens, + *, + mask=None, + encoder_input=None, + encoder_mask=None, + input_pos=None, + ): + x = self.tok_embeddings(tokens) + if encoder_input is not None: + q = self.q(x) + k = self.k(encoder_input) if encoder_input is not None else self.k(x) + v = self.v(encoder_input) if encoder_input is not None else self.v(x) + x += nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=encoder_mask if encoder_mask is not None else mask + ) + x = self.output(x) + return x + + +class TestEarlyFusionModel: + @pytest.fixture + def vocab_size(self) -> int: + return 100 + + @pytest.fixture + def dim(self) -> int: + return 64 + + @pytest.fixture + def batch_size(self) -> int: + return 2 + + @pytest.fixture + def seq_len(self) -> int: + return 10 + + @pytest.fixture + def decoder(self, dim, vocab_size) -> nn.Module: + decoder = DummyModel(dim, vocab_size) + fixed_init_model(decoder, max_val=0.1) + return decoder + + @pytest.fixture + def fused_model(self, vocab_size, dim, decoder) -> EarlyFusionModel: + red = nn.Embedding(vocab_size, dim) + fixed_init_model(red) + green = nn.Embedding(vocab_size, dim) + fixed_init_model(green) + blue = nn.Embedding(vocab_size, dim) + fixed_init_model(blue) + + model = EarlyFusionModel( + encoders={"red": red, "green": green, "blue": blue}, + decoder=decoder, + # These are IDs that are out of vocab in the decoder + encoder_tokens={ + "red": vocab_size, + "green": vocab_size + 1, + "blue": vocab_size + 2, + }, + decoder_trainable=True, + encoders_trainable={"red": False, "green": True, "blue": False}, + fusion_trainable=False, + ) + return model + + @pytest.fixture + def inputs(self, batch_size, seq_len, vocab_size): + tokens = torch.randint(0, vocab_size, (batch_size, seq_len)) + red_seq_len, green_seq_len, blue_seq_len = 1, 2, 3 + tokens[:, 0] = vocab_size + tokens[:, 3:5] = vocab_size + 1 + tokens[:, 7:] = vocab_size + 2 + encoder_input = { + "red": {"input": torch.randint(0, vocab_size, (batch_size, red_seq_len))}, + "green": { + "input": torch.randint(0, vocab_size, (batch_size, green_seq_len)) + }, + "blue": {"input": torch.randint(0, vocab_size, (batch_size, blue_seq_len))}, + } + encoder_mask = torch.randint(0, 2, (batch_size, seq_len, seq_len)).bool() + input_pos = torch.Tensor([1]).int() + return tokens, encoder_input, encoder_mask, input_pos + + @pytest.fixture + def state_dict(self, dim, vocab_size): + return OrderedDict( + { + "decoder.q.weight": torch.randn((dim, dim)), + "decoder.q.bias": torch.randn((dim,)), + "decoder.k.weight": torch.randn((dim, dim)), + "decoder.k.bias": torch.randn((dim,)), + "decoder.v.weight": torch.randn((dim, dim)), + "decoder.v.bias": torch.randn((dim,)), + "decoder.output.weight": torch.randn((vocab_size, dim)), + "decoder.output.bias": torch.randn((vocab_size,)), + "decoder.tok_embeddings.weight": torch.randn((vocab_size, dim)), + "encoders.red.weight": torch.randn((vocab_size, dim)), + "encoders.green.weight": torch.randn((vocab_size, dim)), + "encoders.blue.weight": torch.randn((vocab_size, dim)), + } + ) + + @torch.no_grad() + def test_forward(self, fused_model, inputs, vocab_size): + """ + Test that the forward pass of the EarlyFusionModel works as expected. + """ + tokens, encoder_input, *_ = inputs + batch_size, seq_len = tokens.shape + + out = fused_model( + tokens, + encoder_input=encoder_input, + ) + + assert out.shape == (batch_size, seq_len, vocab_size) + assert_expected(out.mean(), torch.tensor(0.5647), atol=1e-3, rtol=1e-3) + + @torch.no_grad() + def test_forward_no_decoder(self, fused_model, inputs, dim): + """ + Test that the forward pass of the EarlyFusionModel works as expected. + """ + tokens, encoder_input, *_ = inputs + batch_size, seq_len = tokens.shape + + # No-op for the decoder + class DummyModule(nn.Module): + def forward(self, x, **kwargs): + return x + + fused_model.decoder = DummyModule() + + out = fused_model( + tokens, + encoder_input=encoder_input, + ) + + assert out.shape == (batch_size, seq_len, dim) + # Check that each encoder output is placed correctly in the fused output + red = fused_model.encoders["red"](**encoder_input["red"]) + assert_expected(out[:, :1, :], red, atol=1e-3, rtol=1e-3) + green = fused_model.encoders["green"](**encoder_input["green"]) + assert_expected(out[:, 3:5, :], green, atol=1e-3, rtol=1e-3) + blue = fused_model.encoders["blue"](**encoder_input["blue"]) + assert_expected(out[:, 7:, :], blue, atol=1e-3, rtol=1e-3) + + @torch.no_grad() + def test_forward_no_encoder(self, fused_model, batch_size, seq_len, vocab_size): + """ + Test the forward pass of the EarlyFusionModel with no encoder input. + """ + tokens = torch.randint(0, vocab_size, (batch_size, seq_len)) + + actual = fused_model(tokens) + expected = fused_model.decoder(fused_model.tok_embeddings(tokens)) + + assert_expected(actual, expected, atol=1e-3, rtol=1e-3) + + @torch.no_grad() + def test_forward_no_decoder_uneven_encoder_tokens( + self, fused_model, dim, batch_size, seq_len, vocab_size + ): + """ + If each sample has a different number of encoder tokens in the sequence, test that mask scatter + of embeds still works as expected: + + This is a dog. + My dog is better than yours. + """ + red_seq_len, green_seq_len, blue_seq_len = 1, 2, 3 + # In a real encoder input, it would be padded to max number of media in the batch, so we don't + # make these test inputs uneven. The forward pass should still be able to take the number of embeddings + # it needs and ignore the rest, which would be pad embeddings. + encoder_input = { + "red": {"input": torch.randint(0, vocab_size, (batch_size, red_seq_len))}, + "green": { + "input": torch.randint(0, vocab_size, (batch_size, green_seq_len)) + }, + "blue": {"input": torch.randint(0, vocab_size, (batch_size, blue_seq_len))}, + } + tokens = torch.randint(0, vocab_size, (batch_size, seq_len)) + # For red encoder, only the first sample has a token + tokens[0, 0] = vocab_size + # For green encoder, first sample has 2 tokens, second sample has 1 token + tokens[0, 3:5] = vocab_size + 1 + tokens[1, 4] = vocab_size + 1 + # For blue encoder, first sample has 3 tokens, second sample has 2 tokens + tokens[0, 7:] = vocab_size + 2 + tokens[1, 8:] = vocab_size + 2 + + # No-op for the decoder + class DummyModule(nn.Module): + def forward(self, x, **kwargs): + return x + + fused_model.decoder = DummyModule() + + out = fused_model( + tokens, + encoder_input=encoder_input, + ) + + assert out.shape == (batch_size, seq_len, dim) + # Check that each encoder output is placed correctly in the fused output + red = fused_model.encoders["red"](**encoder_input["red"]) + assert_expected(out[0, 0, :], red[0, 0, :], atol=1e-3, rtol=1e-3) + green = fused_model.encoders["green"](**encoder_input["green"]) + assert_expected(out[0, 3:5, :], green[0, :, :], atol=1e-3, rtol=1e-3) + assert_expected(out[1, 4, :], green[1, 0, :], atol=1e-3, rtol=1e-3) + blue = fused_model.encoders["blue"](**encoder_input["blue"]) + assert_expected(out[0, 7:, :], blue[0, :, :], atol=1e-3, rtol=1e-3) + assert_expected(out[1, 8:, :], blue[1, :2, :], atol=1e-3, rtol=1e-3) + + @torch.no_grad() + def test_decoder_forward(self, fused_model, inputs, vocab_size): + """ + Test that the forward pass of the EarlyFusionModel works during decoding. + """ + tokens, encoder_input, encoder_mask, input_pos = inputs + tokens = tokens[:, input_pos] + encoder_mask = encoder_mask[:, input_pos] + batch_size, seq_len = tokens.shape + out = fused_model( + tokens, + encoder_input=encoder_input, + encoder_mask=encoder_mask, + input_pos=input_pos, + ) + + assert out.shape == (batch_size, seq_len, vocab_size) + assert_expected(out.mean(), torch.tensor(0.2383), atol=1e-3, rtol=1e-3) + + def test_setup_cache(self, fused_model): + """ + Test that the cache methods works as expected. + """ + fused_model.setup_caches(2, torch.float32) + assert fused_model.caches_are_setup() + fused_model.reset_caches() + assert not fused_model.caches_are_setup() + + def test_set_trainable_params(self, fused_model): + """ + Test that the trainable parameters are set correctly. + """ + trainable_params = { + n for n, p in fused_model.named_parameters() if p.requires_grad + } + assert trainable_params == { + "decoder.q.weight", + "decoder.q.bias", + "decoder.k.weight", + "decoder.k.bias", + "decoder.v.weight", + "decoder.v.bias", + "tok_embeddings.weight", + "encoders.green.weight", + } + + def test_mismatched_encoder_tokens(self, decoder): + with pytest.raises(ValueError): + _ = EarlyFusionModel( + encoders={"encoder": nn.Identity(), "encoder2": nn.Identity()}, + decoder=decoder, + encoder_tokens={"encoder": 0, "encoder3": 1}, + encoders_trainable=False, + ) + + def test_mismatched_encoder_trainable(self, decoder): + with pytest.raises(ValueError): + _ = EarlyFusionModel( + encoders={"encoder": nn.Identity(), "encoder2": nn.Identity()}, + decoder=decoder, + encoder_tokens={"encoder": 0, "encoder2": 1}, + encoders_trainable={"encoder": True, "encoder3": False}, + ) + + def test_mismatched_encoder_input(self, fused_model, inputs): + tokens, _, _, _ = inputs + with pytest.raises(ValueError): + _ = fused_model( + tokens, + encoder_input={"encoder": {"input": torch.tensor([1])}}, + ) + + def test_state_dict_hooks(self, fused_model, state_dict): + fused_model.load_state_dict(state_dict) + actual = fused_model.state_dict() + expected = state_dict + assert_expected(actual, expected) diff --git a/tests/torchtune/modules/model_fusion/test_fusion_embed.py b/tests/torchtune/modules/model_fusion/test_fusion_embed.py deleted file mode 100644 index 35ef5c0e87..0000000000 --- a/tests/torchtune/modules/model_fusion/test_fusion_embed.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import pytest - -import torch -from tests.test_utils import assert_expected, fixed_init_model -from torchtune.modules.model_fusion import FusionEmbedding -from torchtune.training.seed import set_seed - - -@pytest.fixture(autouse=True) -def random(): - set_seed(1) - - -class TestFusionEmbedding: - """ - Class for testing our FusionEmbedding. - """ - - @pytest.fixture - def dim(self) -> int: - return 2 - - @pytest.fixture - def vocab_size(self) -> int: - return 10 - - @pytest.fixture - def fusion_vocab_size(self) -> int: - return 5 - - @pytest.fixture - def embed(self, dim, vocab_size, fusion_vocab_size) -> FusionEmbedding: - embeds = FusionEmbedding( - vocab_size=vocab_size, fusion_vocab_size=fusion_vocab_size, embed_dim=dim - ) - fixed_init_model(embeds.embedding, min_val=0, max_val=0.5) - fixed_init_model(embeds.fusion_embedding, min_val=0.51, max_val=1) - return embeds - - @torch.no_grad() - def test_forward(self, embed, vocab_size, fusion_vocab_size, dim): - """ - Test that the forward pass of the FusionEmbedding works as expected. - """ - tokens = torch.randint(0, vocab_size + fusion_vocab_size, (2, 10)) - out = embed(tokens) - - assert out.shape == (2, 10, dim) - assert_expected(out.mean(), torch.tensor(0.3409), atol=1e-3, rtol=1e-3) - - # Only new tokens, embeddings should be > 0.5 - tokens = torch.randint(vocab_size, vocab_size + fusion_vocab_size, (2, 10)) - out = embed(tokens) - - assert out.min() > 0.5 - - # Only old tokens, embeddings should be < 0.5 - tokens = torch.randint(0, vocab_size, (2, 10)) - out = embed(tokens) - - assert out.max() < 0.5 - - def test_fusion_params(self, embed): - """ - Test that the currect fusion params are returned. - """ - fusion_params = set(embed.fusion_params()) - - assert fusion_params == {"fusion_embedding.weight"} - - def test_get_and_load_state_dict(self, embed): - """ - Test that the state dict hooks work in removing the "layer" variable - """ - state_dict = embed.state_dict() - state_keys = set(state_dict.keys()) - - assert state_keys == { - "weight", - "fusion_embedding.weight", - } - - # Check that the state_dict can be loaded back into the model - embed.load_state_dict(state_dict) diff --git a/tests/torchtune/modules/model_fusion/test_fusion_layer.py b/tests/torchtune/modules/model_fusion/test_fusion_layers.py similarity index 67% rename from tests/torchtune/modules/model_fusion/test_fusion_layer.py rename to tests/torchtune/modules/model_fusion/test_fusion_layers.py index a2fc0715eb..da8fdb4b1f 100644 --- a/tests/torchtune/modules/model_fusion/test_fusion_layer.py +++ b/tests/torchtune/modules/model_fusion/test_fusion_layers.py @@ -9,7 +9,7 @@ import torch from tests.test_utils import assert_expected, fixed_init_model from torch import nn -from torchtune.modules.model_fusion import FusionLayer +from torchtune.modules.model_fusion import FusionEmbedding, FusionLayer from torchtune.training.seed import set_seed @@ -60,6 +60,79 @@ def forward(self, x): return self.linear(x) +class TestFusionEmbedding: + """ + Class for testing our FusionEmbedding. + """ + + @pytest.fixture + def dim(self) -> int: + return 2 + + @pytest.fixture + def vocab_size(self) -> int: + return 10 + + @pytest.fixture + def fusion_vocab_size(self) -> int: + return 5 + + @pytest.fixture + def embed(self, dim, vocab_size, fusion_vocab_size) -> FusionEmbedding: + embeds = FusionEmbedding( + vocab_size=vocab_size, fusion_vocab_size=fusion_vocab_size, embed_dim=dim + ) + fixed_init_model(embeds.embedding, min_val=0, max_val=0.5) + fixed_init_model(embeds.fusion_embedding, min_val=0.51, max_val=1) + return embeds + + @torch.no_grad() + def test_forward(self, embed, vocab_size, fusion_vocab_size, dim): + """ + Test that the forward pass of the FusionEmbedding works as expected. + """ + tokens = torch.randint(0, vocab_size + fusion_vocab_size, (2, 10)) + out = embed(tokens) + + assert out.shape == (2, 10, dim) + assert_expected(out.mean(), torch.tensor(0.3409), atol=1e-3, rtol=1e-3) + + # Only new tokens, embeddings should be > 0.5 + tokens = torch.randint(vocab_size, vocab_size + fusion_vocab_size, (2, 10)) + out = embed(tokens) + + assert out.min() > 0.5 + + # Only old tokens, embeddings should be < 0.5 + tokens = torch.randint(0, vocab_size, (2, 10)) + out = embed(tokens) + + assert out.max() < 0.5 + + def test_fusion_params(self, embed): + """ + Test that the currect fusion params are returned. + """ + fusion_params = set(embed.fusion_params()) + + assert fusion_params == {"fusion_embedding.weight"} + + def test_get_and_load_state_dict(self, embed): + """ + Test that the state dict hooks work in removing the "layer" variable + """ + state_dict = embed.state_dict() + state_keys = set(state_dict.keys()) + + assert state_keys == { + "weight", + "fusion_embedding.weight", + } + + # Check that the state_dict can be loaded back into the model + embed.load_state_dict(state_dict) + + class TestFusionLayer: """ Class for testing our FusionLayer wrapper. diff --git a/torchtune/modules/model_fusion/__init__.py b/torchtune/modules/model_fusion/__init__.py index 7ad788bd57..21a3c1c063 100644 --- a/torchtune/modules/model_fusion/__init__.py +++ b/torchtune/modules/model_fusion/__init__.py @@ -4,7 +4,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from ._fusion import DeepFusionModel, FusionEmbedding, FusionLayer +from ._deep_fusion import DeepFusionModel +from ._early_fusion import EarlyFusionModel +from ._fusion_layers import FusionEmbedding, FusionLayer from ._fusion_utils import get_fusion_params, register_fusion_module __all__ = [ @@ -13,4 +15,5 @@ "FusionEmbedding", "register_fusion_module", "get_fusion_params", + "EarlyFusionModel", ] diff --git a/torchtune/modules/model_fusion/_deep_fusion.py b/torchtune/modules/model_fusion/_deep_fusion.py new file mode 100644 index 0000000000..6a61c43744 --- /dev/null +++ b/torchtune/modules/model_fusion/_deep_fusion.py @@ -0,0 +1,212 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict, List, Optional, Union + +import torch +from torch import nn +from torchtune.modules import TransformerDecoder +from torchtune.modules.model_fusion._fusion_utils import get_fusion_params +from torchtune.modules.peft._utils import set_trainable_params + + +class DeepFusionModel(nn.Module): + """DeepFusion is a type of fused model architecture where a pretrained encoder is combined + with a pretrained decoder (LLM) in the internal decoder layers. This is a popular architecture for multimodal models, with + a full overview available in `The Evolution of Multimodal Model Architectures `_. + A common deep fusion architecture is to fuse the encoder input into the decoder with interspersed cross-attention + layers. This module makes no assumptions on how the encoder and decoder are fused; it simply + passes in the encoder embeddings to the decoder and lets the decoder handle any fusion. + + This module has the same methods and forward signature as :class:`~torchtune.modules.TransformerDecoder` and can be used + interchangeably where :class:`~torchtune.modules.TransformerDecoder` is. It combines the encoder with the decoder as a + single module for checkpointing and finetuning. It is expected that the encoder and decoder + are already defined with any extra learnable ``fusion_params``: learnable parameters to help + adapt the pre-trained encoder to the pre-trained decoder. + + DeepFusionModel currently only supports a single encoder. + + Example: + >>> # decoder is a TransformerDecoder (e.g. llama3_8b) with fused cross attention layers + >>> embed = FusionEmbedding(...) + >>> layer = FusionLayer( + ... layer=TransformerSelfAttentionLayer(...), + ... fusion_layer=TransformerCrossAttentionLayer(...), + ... ) + >>> decoder = TransformerDecoder(tok_embeddings=embed, layers=layer, num_layers=32, ...) + >>> + >>> # encoder is pre-trained encoder (e.g. clip_vit_224) with an added projection head + >>> projection_head = FeedForward(...) + >>> register_fusion_module(projection_head)) + >>> encoder = nn.Sequential(clip_vit_224(), projection_head) + >>> + >>> # DeepFusionModel combines the encoder and decoder + >>> model = DeepFusionModel(decoder, encoder) + >>> + >>> # Load full fused checkpoints (e.g. a Flamingo checkpoint) + >>> model.load_state_dict(...) + >>> + >>> # Or load pretrained individual models (fusion_params are not loaded) + >>> model.encoder.load_state_dict(..., strict=False) + >>> model.decoder.load_state_dict(..., strict=False) + >>> + >>> # Forward pass + >>> output = model(tokens, mask, encoder_input, encoder_mask, input_pos) + + Args: + decoder (TransformerDecoder): decoder module + encoder (nn.Module): encoder module + decoder_trainable (bool): whether to train or freeze the decoder. Default is False. + encoder_trainable (bool): whether to train or freeze the encoder. Default is False. + fusion_trainable (bool): whether to train the fusion parameters. Default is True. + + """ + + def __init__( + self, + decoder: TransformerDecoder, + encoder: nn.Module, + *, + decoder_trainable: bool = False, + encoder_trainable: bool = False, + fusion_trainable: bool = True, + ): + super().__init__() + self.decoder = decoder + self.encoder = encoder + + trainable_params = set() + if encoder_trainable: + trainable_params |= { + f"encoder.{n}" for n, p in self.encoder.named_parameters() + } + if decoder_trainable: + trainable_params |= { + f"decoder.{n}" for n, p in self.decoder.named_parameters() + } + if fusion_trainable: + trainable_params |= set(get_fusion_params(self)) + else: + trainable_params -= set(get_fusion_params(self)) + set_trainable_params(self, trainable_params) + + def set_num_output_chunks(self, num_output_chunks: int) -> None: + """Used to save memory in combination with :class:`~torchtune.modules.loss.CEWithChunkedOutputLoss`. + This should be called before the first forward pass, in the recipe.""" + self.decoder.set_num_output_chunks(num_output_chunks) + + def setup_caches( + self, + batch_size: int, + dtype: torch.dtype, + *, + encoder_max_seq_len: Optional[int] = None, + decoder_max_seq_len: Optional[int] = None, + ): + """ + Sets up key-value attention caches for inference for ``self.decoder``. + For each layer in ``self.decoder.layers``: + - :class:`torchtune.modules.TransformerSelfAttentionLayer` will use ``decoder_max_seq_len``. + - :class:`torchtune.modules.TransformerCrossAttentionLayer` will use ``encoder_max_seq_len``. + - :class:`torchtune.modules.fusion.FusionLayer` will use both ``decoder_max_seq_len`` and ``encoder_max_seq_len``. + + Args: + batch_size (int): batch size for the caches. + dtype (torch.dtype): dtype for the caches. + encoder_max_seq_len (Optional[int]): maximum encoder cache sequence length. + decoder_max_seq_len (Optional[int]): maximum decoder cache sequence length. + """ + self.decoder.setup_caches( + batch_size, + dtype, + encoder_max_seq_len=encoder_max_seq_len, + decoder_max_seq_len=decoder_max_seq_len, + ) + + def caches_are_setup(self) -> bool: + """ + Check if the key value caches are setup. This means ``setup_caches`` has been called, and + the relevant attention modules in the model have created their ``KVCache``. + """ + return self.decoder.caches_are_setup() + + def caches_are_enabled(self) -> bool: + """ + Checks if the key value caches are enabled. Once KV-caches have been setup, the relevant + attention modules will be "enabled" and all forward passes will update the caches. This behaviour + can be disabled without altering the state of the KV-caches by "disabling" the KV-caches + using :func:`~torchtune.modules.common_utils.disable_kv_cache`, upon which ``caches_are_enabled`` would return False. + """ + return self.decoder.caches_are_enabled() + + def reset_caches(self): + """ + Resets KV-cache buffers on relevant attention modules to zero, and reset cache positions to zero, + without deleting or reallocating cache tensors. + """ + self.decoder.reset_caches() + + def forward( + self, + tokens: torch.Tensor, + *, + mask: Optional[torch.Tensor] = None, + encoder_input: Optional[Dict] = None, + encoder_mask: Optional[torch.Tensor] = None, + input_pos: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """ + Args: + tokens (torch.Tensor): input tensor with shape ``[b x s]`` + mask (Optional[torch.Tensor]): Optional boolean tensor which contains the attention mask + with shape ``[b x s x s]``. This is applied after the query-key multiplication and + before the softmax. A value of True in row i and column j means token i attends + to token j. A value of False means token i does not attend to token j. If no + mask is specified, a causal mask is used by default. Default is None. + encoder_input (Optional[Dict]): Optional input for the encoder. + encoder_mask (Optional[torch.Tensor]): Boolean tensor defining a relational matrix between + tokens and encoder embeddings. A True value at position i,j means token i can attend + to embedding j in the decoder. Mask has shape ``[b x s x s_e]``. Default is None. + input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids + of each token. During training, this is used to indicate the positions + of each token relative to its sample when packed, shape ``[b x s]``. + During inference, this indicates the position of the current token. + If none, assume the index of the token is its position id. Default is None. + + Note: At the very first step of inference, when the model is provided with a prompt, + ``input_pos`` would contain the positions of all of the tokens in the prompt + (eg: ``torch.arange(prompt_length)``). This is because we will need to compute the + KV values for each position. + + Returns: + Tensor: output tensor with shape ``[b x s x v]`` or a list of layer \ + output tensors defined by ``output_hidden_states`` with the \ + final output tensor appended to the list. + + Notation used for tensor shapes: + - b: batch size + - s: token sequence length + - s_e: encoder sequence length + - v: vocab size + - d: token embed dim + - d_e: encoder embed dim + - m_s: max seq len + """ + # During decoding, encoder_input will only be provided + # for new inputs. Previous encoder outputs are cached + # in the decoder cache. + encoder_embed = None + if encoder_input is not None: + encoder_embed = self.encoder(**encoder_input) + + output = self.decoder( + tokens=tokens, + mask=mask, + encoder_input=encoder_embed, + encoder_mask=encoder_mask, + input_pos=input_pos, + ) + return output diff --git a/torchtune/modules/model_fusion/_early_fusion.py b/torchtune/modules/model_fusion/_early_fusion.py new file mode 100644 index 0000000000..d20b2d119f --- /dev/null +++ b/torchtune/modules/model_fusion/_early_fusion.py @@ -0,0 +1,278 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Dict, Optional, Tuple, Union + +import torch +from torch import nn +from torchtune.modules import TransformerDecoder +from torchtune.modules.model_fusion._fusion_utils import get_fusion_params +from torchtune.modules.peft._utils import set_trainable_params + + +class EarlyFusionModel(nn.Module): + """EarlyFusion is a type of fused model architecture where pretrained encoder(s) are combined + with a pretrained decoder (LLM) at the model input and not in internal layers. This is a popular architecture + for multimodal models, with a full overview available in `The Evolution of Multimodal Model Architectures + `_. This module works both for decoders in which the encoder tokens are + inside the vocab and outside the vocab. + + This module has the same methods and forward signature as :class:`~torchtune.modules.TransformerDecoder` and can be used + interchangeably where :class:`~torchtune.modules.TransformerDecoder` is. It combines the encoders with the decoder as a + single module for checkpointing and finetuning. It is expected that the encoders and decoder + are already defined with any extra learnable ``fusion_params``: learnable parameters to help + adapt the pre-trained encoders to the pre-trained decoder. + + You can pass in multiple encoders as a dictionary into ``encoders``. + + Note: Once the decoder is wrapped in this module, the decoder's ``tok_embeddings`` module is moved + to the parent EarlyFusionModel's ``tok_embeddings``. You should not forward pass the decoder individually. + Instead, use EarlyFusionModel's forward pass with ``encoder_input=None`` to get decoder-only outputs. + State dicts will automatically be updated on save and load to account for this change. + + Example: + >>> # decoder is a text-only TransformerDecoder (e.g. llama3_8b) with no modifications + >>> decoder = llama3_8b() + >>> + >>> # encoder is pre-trained encoder (e.g. clip_vit_224) with an added projection head + >>> projection_head = FeedForward(...) + >>> register_fusion_module(projection_head)) + >>> encoders = {"image": nn.Sequential(clip_vit_224(), projection_head)} + >>> + >>> # EarlyFusionModel combines the encoder and decoder + >>> model = EarlyFusionModel(decoder, encoders, encoder_tokens={"image": 128256}) + >>> + >>> # Load full fused checkpoints + >>> model.load_state_dict(...) + >>> + >>> # Forward pass + >>> encoder_input = {"image": {...}} + >>> output = model(tokens, mask=mask, encoder_input=encoder_input, encoder_mask=encoder_mask, input_pos=input_pos) + >>> + >>> # Forward pass decoder only + >>> output = model(tokens, mask=mask, input_pos=input_pos) + + Args: + decoder (TransformerDecoder): decoder module + encoders (Dict[str, nn.Module]): dictionary mapping encoder name as a string to the encoder module. + encoder_tokens (Dict[str, int]): dictionary mapping encoder name to special token ID indicating where + in the text sequence the encoder embedding outputs should be injected. + decoder_trainable (bool): whether to train or freeze the decoder. Default is False. + encoders_trainable (Union[bool, Dict[str, bool]]): whether to train or freeze the encoder. Use a single + boolean to set trainable for all encoders or a dictionary keyed by encoder names to specify trainable + for each encoder individually. Encoder names should match with ``encoders``. Default is False. + fusion_trainable (bool): whether to train the fusion parameters. Default is True. + + Raises: + ValueError: if ``encoders`` and ``encoders_trainable`` keys do not match + """ + + def __init__( + self, + decoder: TransformerDecoder, + encoders: Dict[str, nn.Module], + encoder_tokens: Dict[str, int], + decoder_trainable: bool = False, + encoders_trainable: Union[bool, Dict[str, bool]] = False, + fusion_trainable: bool = True, + ): + super().__init__() + if encoders.keys() != encoder_tokens.keys() or ( + not isinstance(encoders_trainable, bool) + and encoders.keys() != encoders_trainable.keys() + ): + raise ValueError( + f"Found mismatched keys in encoders, encoder_tokens, and/or encoders_trainable. Expected {encoders.keys()}" + ) + + self.decoder = decoder + self.encoders = nn.ModuleDict(encoders) + self.encoder_tokens = encoder_tokens + self.encoders_trainable = ( + {k: encoders_trainable for k in self.encoders.keys()} + if isinstance(encoders_trainable, bool) + else encoders_trainable + ) + + # A little surgery in the decoder to give the + # fusion module access to control the embeddings + # The alternative is to pass a special tok_embeddings + # module into TransformerDecoder builder that does the + # merging there + self.tok_embeddings = decoder.tok_embeddings + decoder.tok_embeddings = nn.Identity() + + self._register_state_dict_hook(self._state_dict_hook) + self.register_load_state_dict_pre_hook(self._load_state_dict_hook) + + trainable_params = set() + for encoder, trainable in self.encoders_trainable.items(): + if trainable: + trainable_params |= { + f"encoders.{encoder}.{n}" + for n, p in self.encoders[encoder].named_parameters() + } + if decoder_trainable: + trainable_params |= { + f"decoder.{n}" for n, p in self.decoder.named_parameters() + } + trainable_params |= { + f"tok_embeddings.{n}" for n, p in self.tok_embeddings.named_parameters() + } + if fusion_trainable: + trainable_params |= set(get_fusion_params(self)) + else: + trainable_params -= set(get_fusion_params(self)) + + set_trainable_params(self, trainable_params) + + @staticmethod + def _state_dict_hook(module, state_dict, *args, **kwargs): + """ + Keep tok_embeddings inside of decoder state_dict + + [!Note] This update changes the order of the OrderedDict + """ + for n, p in module.tok_embeddings.named_parameters(): + state_dict[f"decoder.tok_embeddings.{n}"] = p + del state_dict[f"tok_embeddings.{n}"] + + @staticmethod + def _load_state_dict_hook(module, state_dict, *args, **kwargs): + """Undo the change from _state_dict_hook""" + old_keys = list(state_dict.keys()) + for key in old_keys: + if key.startswith("decoder.tok_embeddings"): + state_dict[key[len("decoder.") :]] = state_dict[key] + del state_dict[key] + + def set_num_output_chunks(self, num_output_chunks: int) -> None: + """Used to save memory in combination with :class:`~torchtune.modules.loss.CEWithChunkedOutputLoss`. + This should be called before the first forward pass, in the recipe.""" + self.decoder.set_num_output_chunks(num_output_chunks) + + def setup_caches(self, batch_size: int, dtype: torch.dtype) -> None: + """Setup key value caches for attention calculation. + + Args: + batch_size (int): batch size for the caches. + dtype (torch.dtype): dtype for the caches. + """ + self.decoder.setup_caches(batch_size, dtype) + + def caches_are_setup(self) -> bool: + """ + Check if the key value caches are setup. This means ``setup_caches`` has been called, and + the relevant attention modules in the model have created their ``KVCache``. + """ + return self.decoder.caches_are_setup() + + def caches_are_enabled(self) -> bool: + """ + Checks if the key value caches are enabled. Once KV-caches have been setup, the relevant + attention modules will be "enabled" and all forward passes will update the caches. This behaviour + can be disabled without altering the state of the KV-caches by "disabling" the KV-caches + using :func:`~torchtune.modules.common_utils.disable_kv_cache`, upon which ``caches_are_enabled`` would return False. + """ + return self.decoder.caches_are_enabled() + + def reset_caches(self): + """Reset the key value caches.""" + self.decoder.reset_caches() + + def _decoder_embed(self, tokens) -> Tuple[torch.Tensor, torch.Tensor]: + """Embed the text-only tokens with the decoder's tok_embeddings""" + encoder_token_ids = torch.tensor(list(self.encoder_tokens.values())) + # [bsz, seq_len], True indicates the token is not an encoder special token + is_text = ~torch.isin(tokens, encoder_token_ids) + text_tokens = torch.masked_select(tokens, is_text) + # [num_text, embed_dim] + text_embeds = self.tok_embeddings(text_tokens) + return is_text, text_embeds + + def forward( + self, + tokens: torch.Tensor, + *, + mask: Optional[torch.Tensor] = None, + encoder_input: Optional[Dict[str, Dict[str, Any]]] = None, + input_pos: Optional[torch.Tensor] = None, + **kwargs: Dict[str, Any], # no need for encoder_mask + ) -> torch.Tensor: + """ + Note: This module assumes that there will be enough encoder inputs (i.e., total number of images in the batch) + for the number of encoder tokens in the batch. + + Args: + tokens (torch.Tensor): input tensor with shape ``[b x s]`` + mask (Optional[torch.Tensor]): Optional boolean tensor which contains the attention mask + with shape ``[b x s x s]``. This is applied after the query-key multiplication and + before the softmax. A value of True in row i and column j means token i attends + to token j. A value of False means token i does not attend to token j. If no + mask is specified, a causal mask is used by default. Default is None. + encoder_input (Optional[Dict[str, Dict[str, Any]]]): Optional input kwargs for the encoders. Must be + keyed by encoder name and match the keys of ``encoders`` + input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids + of each token. During training, this is used to indicate the positions + of each token relative to its sample when packed, shape ``[b x s]``. + During inference, this indicates the position of the current token. + If none, assume the index of the token is its position id. Default is None. + **kwargs (Dict[str, Any]): additional keyword arguments. This is solely used to match the + :class:`~torchtune.modules.TransformerDecoder` forward and does not have any effect. + + Note: At the very first step of inference, when the model is provided with a prompt, + ``input_pos`` would contain the positions of all of the tokens in the prompt + (eg: ``torch.arange(prompt_length)``). This is because we will need to compute the + KV values for each position. + + Returns: + torch.Tensor: output tensor with shape ``[b x s x v]`` or a list of layer \ + output tensors defined by ``output_hidden_states`` with the \ + final output tensor appended to the list. + + Raises: + ValueError: if ``encoder_input`` keys do not match ``encoders`` keys + + Notation used for tensor shapes: + - b: batch size + - s: token sequence length + - s_e: encoder sequence length + - v: vocab size + - d: token embed dim + - d_e: encoder embed dim + - m_s: max seq len + """ + if encoder_input is not None and encoder_input.keys() != self.encoders.keys(): + raise ValueError( + f"Found mismatched keys in encoder_input and instantiated encoders. " + f"Got {encoder_input.keys()}, expected {self.encoders.keys()}." + ) + + bsz, seq_len = tokens.shape + # is_text: [bsz, seq_len], text_embeds: [num_text, embed_dim] + is_text, text_embeds = self._decoder_embed(tokens) + embed_dim = text_embeds.shape[-1] + + # Holds the final embedding vector + fused_embeds = torch.empty( + bsz, seq_len, embed_dim, dtype=text_embeds.dtype, device=text_embeds.device + ) + # Place the text-only embeddings + fused_embeds = fused_embeds.masked_scatter(is_text.unsqueeze(-1), text_embeds) + + encoder_input = encoder_input or {} + for encoder, inp in encoder_input.items(): + # [bsz, num_encoder_tokens, embed_dim] + encoder_embeds = self.encoders[encoder](**inp) + # [bsz * num_encoder_tokens, embed_dim] + encoder_embeds = encoder_embeds.view(-1, embed_dim) + # [bsz, seq_len, 1] + encoder_mask = (tokens == self.encoder_tokens[encoder]).unsqueeze(-1) + # At locations where encoder token is found, replace with encoder embedding + fused_embeds = fused_embeds.masked_scatter(encoder_mask, encoder_embeds) + + output = self.decoder(fused_embeds, mask=mask, input_pos=input_pos) + return output diff --git a/torchtune/modules/model_fusion/_fusion.py b/torchtune/modules/model_fusion/_fusion_layers.py similarity index 54% rename from torchtune/modules/model_fusion/_fusion.py rename to torchtune/modules/model_fusion/_fusion_layers.py index 907c7a2ed0..7fe3939ec4 100644 --- a/torchtune/modules/model_fusion/_fusion.py +++ b/torchtune/modules/model_fusion/_fusion_layers.py @@ -4,13 +4,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Dict, List, Optional, Union +from typing import Dict, List import torch from torch import nn -from torchtune.modules import TransformerDecoder -from torchtune.modules.model_fusion._fusion_utils import get_fusion_params -from torchtune.modules.peft._utils import set_trainable_params class FusionLayer(nn.Module): @@ -273,9 +270,9 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: # num_fusion_tokens = (input >= vocab_size).sum() fusion_tokens = torch.masked_select(input, ~mask) - vocab_size - # [batch_size x num_tokens x embed_dim] + # [batch_size * num_tokens, embed_dim] embeds = self.embedding(tokens) - # [batch_size x num_fusion_tokens x embed_dim] + # [batch_size * num_fusion_tokens, embed_dim] fusion_embeds = self.fusion_embedding(fusion_tokens) # [batch_size x seq_length x embed_dim] @@ -284,197 +281,3 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: out = out.masked_scatter(mask, embeds) out = out.masked_scatter(~mask, fusion_embeds) return out - - -class DeepFusionModel(nn.Module): - """DeepFusion is a type of fused model architecture where a pretrained encoder is combined - with a pretrained decoder (LLM). This is a popular architecture for multimodal models, with - a full overview available in `The Evolution of Multimodal Model Architectures `_. - - This module has the same methods and forward signature as :class:`~torchtune.modules.TransformerDecoder` and can be used - interchangeably where :class:`~torchtune.modules.TransformerDecoder` is. It combines the encoder with the decoder as a - single module for checkpointing and finetuning. It is expected that the encoder and decoder - are already defined with any extra learnable ``fusion_params``: learnable parameters to help - adapt the pre-trained encoder to the pre-trained decoder. - - Example: - >>> # decoder is a TransformerDecoder (e.g. llama3_8b) with fused cross attention layers - >>> embed = FusionEmbedding(...) - >>> layer = FusionLayer( - ... layer=TransformerSelfAttentionLayer(...), - ... fusion_layer=TransformerCrossAttentionLayer(...), - ... ) - >>> decoder = TransformerDecoder(tok_embeddings=embed, layers=layer, num_layers=32, ...) - >>> - >>> # encoder is pre-trained encoder (e.g. clip_vit_224) with an added projection head - >>> projection_head = FeedForward(...) - >>> register_fusion_module(projection_head)) - >>> encoder = nn.Sequential(clip_vit_224(), projection_head) - >>> - >>> # DeepFusionModel combines the encoder and decoder - >>> model = DeepFusionModel(decoder, encoder) - >>> - >>> # Load full fused checkpoints (e.g. a Flamingo checkpoint) - >>> model.load_state_dict(...) - >>> - >>> # Or load pretrained individual models (fusion_params are not loaded) - >>> model.encoder.load_state_dict(..., strict=False) - >>> model.decoder.load_state_dict(..., strict=False) - >>> - >>> # Forward pass - >>> output = model(tokens, mask, encoder_input, encoder_mask, input_pos) - - Args: - decoder (TransformerDecoder): decoder module - encoder (nn.Module): encoder module - decoder_trainable (bool): whether to train or freeze the decoder. Default is False. - encoder_trainable (bool): whether to train or freeze the encoder. Default is False. - fusion_trainable (bool): whether to train the fusion parameters. Default is True. - - """ - - def __init__( - self, - decoder: TransformerDecoder, - encoder: nn.Module, - *, - decoder_trainable: bool = False, - encoder_trainable: bool = False, - fusion_trainable: bool = True, - ): - super().__init__() - self.decoder = decoder - self.encoder = encoder - - trainable_params = set() - if encoder_trainable: - trainable_params |= { - f"encoder.{n}" for n, p in self.encoder.named_parameters() - } - if decoder_trainable: - trainable_params |= { - f"decoder.{n}" for n, p in self.decoder.named_parameters() - } - if fusion_trainable: - trainable_params |= set(get_fusion_params(self)) - else: - trainable_params -= set(get_fusion_params(self)) - set_trainable_params(self, trainable_params) - - def set_num_output_chunks(self, num_output_chunks: int) -> None: - """Used to save memory in combination with :class:`~torchtune.modules.loss.CEWithChunkedOutputLoss`. - This should be called before the first forward pass, in the recipe.""" - self.decoder.set_num_output_chunks(num_output_chunks) - - def setup_caches( - self, - batch_size: int, - dtype: torch.dtype, - *, - encoder_max_seq_len: int = None, - decoder_max_seq_len: int = None, - ): - """ - Sets up key-value attention caches for inference for ``self.decoder``. - For each layer in ``self.decoder.layers``: - - :class:`torchtune.modules.TransformerSelfAttentionLayer` will use ``decoder_max_seq_len``. - - :class:`torchtune.modules.TransformerCrossAttentionLayer` will use ``encoder_max_seq_len``. - - :class:`torchtune.modules.fusion.FusionLayer` will use both ``decoder_max_seq_len`` and ``encoder_max_seq_len``. - - Args: - batch_size (int): batch size for the caches. - dtype (torch.dtype): dtype for the caches. - encoder_max_seq_len (int): maximum encoder cache sequence length. - decoder_max_seq_len (int): maximum decoder cache sequence length. - """ - self.decoder.setup_caches( - batch_size, - dtype, - encoder_max_seq_len=encoder_max_seq_len, - decoder_max_seq_len=decoder_max_seq_len, - ) - - def caches_are_setup(self) -> bool: - """ - Check if the key value caches are setup. This means ``setup_caches`` has been called, and - the relevant attention modules in the model have created their ``KVCache``. - """ - return self.decoder.caches_are_setup() - - def caches_are_enabled(self) -> bool: - """ - Checks if the key value caches are enabled. Once KV-caches have been setup, the relevant - attention modules will be "enabled" and all forward passes will update the caches. This behaviour - can be disabled without altering the state of the KV-caches by "disabling" the KV-caches - using :func:`~torchtune.modules.common_utils.disable_kv_cache`, upon which ``caches_are_enabled`` would return False. - """ - return self.decoder.caches_are_enabled() - - def reset_caches(self): - """ - Resets KV-cache buffers on relevant attention modules to zero, and reset cache positions to zero, - without deleting or reallocating cache tensors. - """ - self.decoder.reset_caches() - - def forward( - self, - tokens: torch.Tensor, - *, - mask: Optional[torch.Tensor] = None, - encoder_input: Optional[Dict] = None, - encoder_mask: Optional[torch.Tensor] = None, - input_pos: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, List[torch.Tensor]]: - """ - Args: - tokens (torch.Tensor): input tensor with shape ``[b x s]`` - mask (Optional[torch.Tensor]): Optional boolean tensor which contains the attention mask - with shape ``[b x s x s]``. This is applied after the query-key multiplication and - before the softmax. A value of True in row i and column j means token i attends - to token j. A value of False means token i does not attend to token j. If no - mask is specified, a causal mask is used by default. Default is None. - encoder_input (Optional[Dict]): Optional input for the encoder. - encoder_mask (Optional[torch.Tensor]): Boolean tensor defining a relational matrix between - tokens and encoder embeddings. A True value at position i,j means token i can attend - to embedding j in the decoder. Mask has shape ``[b x s x s_e]``. Default is None. - input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids - of each token. During training, this is used to indicate the positions - of each token relative to its sample when packed, shape ``[b x s]``. - During inference, this indicates the position of the current token. - If none, assume the index of the token is its position id. Default is None. - - Note: At the very first step of inference, when the model is provided with a prompt, - ``input_pos`` would contain the positions of all of the tokens in the prompt - (eg: ``torch.arange(prompt_length)``). This is because we will need to compute the - KV values for each position. - - Returns: - Tensor: output tensor with shape ``[b x s x v]`` or a list of layer \ - output tensors defined by ``output_hidden_states`` with the \ - final output tensor appended to the list. - - Notation used for tensor shapes: - - b: batch size - - s: token sequence length - - s_e: encoder sequence length - - v: vocab size - - d: token embed dim - - d_e: encoder embed dim - - m_s: max seq len - """ - # During decoding, encoder_input will only be provided - # for new inputs. Previous encoder outputs are cached - # in the decoder cache. - encoder_embed = None - if encoder_input is not None: - encoder_embed = self.encoder(**encoder_input) - - output = self.decoder( - tokens=tokens, - mask=mask, - encoder_input=encoder_embed, - encoder_mask=encoder_mask, - input_pos=input_pos, - ) - return output diff --git a/torchtune/modules/model_fusion/_fusion_utils.py b/torchtune/modules/model_fusion/_fusion_utils.py index c22cc03549..e10bfcb3e5 100644 --- a/torchtune/modules/model_fusion/_fusion_utils.py +++ b/torchtune/modules/model_fusion/_fusion_utils.py @@ -65,5 +65,5 @@ def get_fusion_params(model: nn.Module) -> Dict[str, nn.Parameter]: current_fusion_params.remove(n) assert ( current_fusion_params == [] - ), f"Fusion params {current_adapter_params} not converted" + ), f"Fusion params {current_fusion_params} not converted" return fusion_params diff --git a/torchtune/modules/peft/_utils.py b/torchtune/modules/peft/_utils.py index 4768d77619..318ab4136a 100644 --- a/torchtune/modules/peft/_utils.py +++ b/torchtune/modules/peft/_utils.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. import contextlib -from typing import Any, Dict, Generator, List, Literal, Optional, Protocol, Set +from typing import Any, Dict, Generator, List, Literal, Optional, Protocol, Set, Union import torch from torch import nn @@ -62,13 +62,15 @@ def get_adapter_params(model: nn.Module) -> Dict[str, nn.Parameter]: return adapter_params -def set_trainable_params(model: nn.Module, adapter_params: Dict[str, Any]) -> None: +def set_trainable_params( + model: nn.Module, adapter_params: Union[Dict[str, Any], Set] +) -> None: """ Set trainable parameters for an nn.Module based on a state dict of adapter parameters. Args: model (nn.Module): Instance of model class containing some adapter params. - adapter_params (Dict[str, Any]): State dict mapping adapter key names to their + adapter_params (Union[Dict[str, Any], Set]): State dict mapping adapter key names to their respective nn.Parameters (i.e. outputs of :func:`~torchtune.modules.peft.get_adapter_params`.) Returns: