From e37a3e1b256276f0af9cd484dd777b55d2761912 Mon Sep 17 00:00:00 2001 From: RdoubleA Date: Thu, 24 Oct 2024 18:14:28 -0700 Subject: [PATCH 01/11] first --- .../model_fusion/test_fusion_models.py | 130 +++++++++++++++++- torchtune/modules/model_fusion/__init__.py | 3 +- torchtune/modules/model_fusion/_fusion.py | 116 +++++++++++++++- torchtune/modules/peft/_utils.py | 4 +- 4 files changed, 240 insertions(+), 13 deletions(-) diff --git a/tests/torchtune/modules/model_fusion/test_fusion_models.py b/tests/torchtune/modules/model_fusion/test_fusion_models.py index 322616276e..1ee9c0b6c7 100644 --- a/tests/torchtune/modules/model_fusion/test_fusion_models.py +++ b/tests/torchtune/modules/model_fusion/test_fusion_models.py @@ -9,7 +9,7 @@ import torch from tests.test_utils import assert_expected, fixed_init_model from torch import nn -from torchtune.modules.model_fusion import DeepFusionModel, register_fusion_module +from torchtune.modules.model_fusion import DeepFusionModel, register_fusion_module, EarlyFusionModel from torchtune.training.seed import set_seed @@ -22,7 +22,7 @@ class DummyModel(nn.Module): def __init__(self, dim, vocab_size): super().__init__() self.cache_enabled = False - self.embed = nn.Embedding(vocab_size, dim) + self.tok_embeddings = nn.Embedding(vocab_size, dim) self.q = nn.Linear(dim, dim) self.k = nn.Linear(dim, dim) self.v = nn.Linear(dim, dim) @@ -39,7 +39,7 @@ def reset_caches(self): self.cache_enabled = False def forward(self, tokens, mask, encoder_input, encoder_mask, input_pos): - x = self.embed(tokens) + x = self.tok_embeddings(tokens) if encoder_input is not None: q = self.q(x) k = self.k(encoder_input) @@ -85,7 +85,7 @@ def fused_model(self, encoder, decoder) -> DeepFusionModel: return model @pytest.fixture - def inputs(self, dim, vocab_size): + def inputs(self, vocab_size): batch_size = 2 seq_len = 10 tokens = torch.randint(0, vocab_size, (batch_size, seq_len)) @@ -185,3 +185,125 @@ def test_set_trainable_params(self, fused_model, encoder, decoder): "decoder.v.bias", "decoder.embed.weight", } + +class TestEarlyFusionModel: + @pytest.fixture + def vocab_size(self) -> int: + return 100 + + @pytest.fixture + def dim(self) -> int: + return 64 + + @pytest.fixture + def encoder(self, dim, vocab_size) -> nn.Module: + encoder = nn.Embedding(vocab_size, dim) + fixed_init_model(encoder) + return encoder + + @pytest.fixture + def decoder(self, dim, vocab_size) -> nn.Module: + decoder = DummyModel(dim, vocab_size) + fixed_init_model(decoder, max_val=0.1) + return decoder + + @pytest.fixture + def fused_model(self, encoder, decoder) -> EarlyFusionModel: + model = EarlyFusionModel( + encoders={"red": encoder, "green": encoder, "blue": encoder}, + decoder=decoder, + encoder_tokens={"red": 0, "green": 1, "blue": 2}, + decoder_trainable=True, + encoders_trainable={"red": False, "green": True, "blue": False}, + ) + return model + + @pytest.fixture + def inputs(self, vocab_size): + batch_size = 2 + seq_len = 10 + tokens = torch.randint(0, vocab_size, (batch_size, seq_len)) + red_seq_len, green_seq_len, blue_seq_len = 1, 2, 3 + tokens[:, 0] = 0 + tokens[:, 3:5] = 1 + tokens[:, 8:] = 2 + encoder_input = { + "red": {"input": torch.randint(0, vocab_size, (batch_size, red_seq_len))}, + "green": {"input": torch.randint(0, vocab_size, (batch_size, green_seq_len))}, + "blue": {"input": torch.randint(0, vocab_size, (batch_size, blue_seq_len))}, + } + encoder_mask = torch.randint(0, 2, (batch_size, seq_len, seq_len)).bool() + input_pos = torch.Tensor([1]).int() + return tokens, 
encoder_input, encoder_mask, input_pos + + @torch.no_grad() + def test_forward(self, fused_model, inputs, vocab_size): + """ + Test that the forward pass of the EarlyFusionModel works as expected. + """ + tokens, encoder_input, encoder_mask, _ = inputs + batch_size, seq_len = tokens.shape + out = fused_model( + tokens, encoder_input=encoder_input, encoder_mask=encoder_mask + ) + + assert out.shape == (batch_size, seq_len, vocab_size) + assert_expected(out.mean(), torch.tensor(8.5584), atol=1e-3, rtol=1e-3) + + @torch.no_grad() + def test_forward_no_encoding(self, fused_model, inputs, vocab_size): + """ + Test that the forward pass of the EarlyFusionModel with no encoder input. + """ + tokens, *_ = inputs + batch_size, seq_len = tokens.shape + out = fused_model(tokens) + + assert out.shape == (batch_size, seq_len, vocab_size) + assert_expected(out.mean(), torch.tensor(0.2271), atol=1e-3, rtol=1e-3) + + @torch.no_grad() + def test_decoding_forward(self, fused_model, inputs, vocab_size): + """ + Test that the forward pass of the EarlyFusionModel works during decoding. + """ + tokens, encoder_input, encoder_mask, input_pos = inputs + tokens = tokens[:, input_pos] + encoder_mask = encoder_mask[:, input_pos] + batch_size, seq_len = tokens.shape + out = fused_model( + tokens, + encoder_input=encoder_input, + encoder_mask=encoder_mask, + input_pos=input_pos, + ) + + assert out.shape == (batch_size, seq_len, vocab_size) + assert_expected(out.mean(), torch.tensor(9.0072), atol=1e-3, rtol=1e-3) + + def test_setup_cache(self, fused_model): + """ + Test that the cache methods works as expected. + """ + fused_model.setup_caches(2, torch.float32) + assert fused_model.caches_are_setup() + fused_model.reset_caches() + assert not fused_model.caches_are_setup() + + def test_set_trainable_params(self, fused_model): + """ + Test that the trainable parameters are set correctly. + """ + trainable_params = { + n for n, p in fused_model.named_parameters() if p.requires_grad + } + assert trainable_params == { + "decoder.q.weight", + "decoder.q.bias", + "decoder.k.weight", + "decoder.k.bias", + "decoder.v.weight", + "decoder.v.bias", + "decoder.embed.weight", + "encoders.green.weight", + } diff --git a/torchtune/modules/model_fusion/__init__.py b/torchtune/modules/model_fusion/__init__.py index 7ad788bd57..6c4452e3f0 100644 --- a/torchtune/modules/model_fusion/__init__.py +++ b/torchtune/modules/model_fusion/__init__.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from ._fusion import DeepFusionModel, FusionEmbedding, FusionLayer +from ._fusion import DeepFusionModel, FusionEmbedding, FusionLayer, EarlyFusionModel from ._fusion_utils import get_fusion_params, register_fusion_module __all__ = [ @@ -13,4 +13,5 @@ "FusionEmbedding", "register_fusion_module", "get_fusion_params", + "EarlyFusionModel", ] diff --git a/torchtune/modules/model_fusion/_fusion.py b/torchtune/modules/model_fusion/_fusion.py index 1a5452daae..0ca8c3534a 100644 --- a/torchtune/modules/model_fusion/_fusion.py +++ b/torchtune/modules/model_fusion/_fusion.py @@ -4,10 +4,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Union, Any import torch -from torch import nn +from torch import nn, Tensor from torchtune.modules import TransformerDecoder from torchtune.modules.model_fusion._fusion_utils import get_fusion_params from torchtune.modules.peft._utils import set_trainable_params @@ -370,8 +370,8 @@ def setup_caches( batch_size: int, dtype: torch.dtype, *, - encoder_max_seq_len: int = None, - decoder_max_seq_len: int = None, + encoder_max_seq_len: Optional[int] = None, + decoder_max_seq_len: Optional[int] = None, ): """ Sets up key-value attention caches for inference for ``self.decoder``. @@ -421,7 +421,7 @@ def forward( tokens: torch.Tensor, *, mask: Optional[torch.Tensor] = None, - encoder_input: Optional[Dict] = None, + encoder_input: Optional[Dict[str, Dict[str, Any]]] = None, encoder_mask: Optional[torch.Tensor] = None, input_pos: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, List[torch.Tensor]]: @@ -433,7 +433,7 @@ def forward( before the softmax. A value of True in row i and column j means token i attends to token j. A value of False means token i does not attend to token j. If no mask is specified, a causal mask is used by default. Default is None. - encoder_input (Optional[Dict]): Optional input for the encoder. + encoder_input (Optional[Dict[str, Any]]): Optional input for the encoder. encoder_mask (Optional[torch.Tensor]): Boolean tensor defining a relational matrix between tokens and encoder embeddings. A True value at position i,j means token i can attend to embedding j in the decoder. Mask has shape ``[b x s x s_e]``. Default is None. @@ -477,3 +477,107 @@ def forward( input_pos=input_pos, ) return output + +class EarlyFusionModel(nn.Module): + def __init__( + self, + decoder: TransformerDecoder, + encoders: nn.ModuleDict, + encoder_tokens: Dict[str, int], + decoder_trainable: bool, + encoders_trainable: Dict[str, bool], + ): + super().__init__() + self.decoder = decoder + self.encoders = encoders + self.encoder_tokens = encoder_tokens + + # A little surgery in the decoder to give the + # fusion module access to control the embeddings + # The alternative is to pass a special tok_embeddings + # module into TransformerDecoder builder that does the + # merging there + self.tok_embeddings = decoder.tok_embeddings + decoder.tok_embeddings = nn.Identity() + + self.register_state_dict_post_hook(self._state_dict_hook) + self.register_load_state_dict_pre_hook( + self._load_state_dict_hook, + with_module=True + ) + + trainable_params = set() + for encoder, trainable in encoders_trainable.items(): + if trainable: + trainable_params |= { + f"encoders.{encoder}.{n}" for n, p in self.encoders[encoder].named_parameters() + } + if decoder_trainable: + trainable_params |= { + f"decoder.{n}" for n, p in self.decoder.named_parameters() + } + set_trainable_params(self, trainable_params) + + def _state_dict_hook(self, destination, prefix, keep_vars): + """ + Keep tok_embeddings inside of decoder state_dict + + [!Note] This update changes the order of the OrderedDict + """ + key = "tok_embeddings" + decoder_key = "decoder.tok_embeddings" + destination[decoder_key] = destination[key] + del destination[key] + + def _load_state_dict_hook(self, state_dict, *args, **kwargs): + """ Undo the change from _state_dict_hook """ + key = "tok_embeddings" + decoder_key = "decoder.tok_embeddings" + state_dict[key] = state_dict[decoder_key] + del state_dict[decoder_key] + + def set_num_output_chunks(self, num_output_chunks: int) -> 
None: + """Used to save memory in combination with :class:`~torchtune.modules.loss.CEWithChunkedOutputLoss`. + This should be called before the first forward pass, in the recipe.""" + self.decoder.set_num_output_chunks(num_output_chunks) + + def setup_caches(self, batch_size: int, dtype: torch.dtype) -> None: + """Setup key value caches for attention calculation. + + Args: + batch_size (int): batch size for the caches. + dtype (torch.dtype): dtype for the caches. + """ + self.decoder.setup_caches(batch_size, dtype) + + def caches_are_enabled(self) -> bool: + """Check if the key value caches are setup.""" + return self.decoder.caches_are_enabled() + + def reset_caches(self): + """Reset the key value caches.""" + self.decoder.reset_caches() + + def forward( + self, + tokens: Tensor, + *, + mask: Optional[Tensor] = None, + encoder_input: Optional[Dict[str, Dict[str, Any]]] = None, + input_pos: Optional[Tensor] = None, + **kwargs: Dict[str, Any], # no need for encoder_mask + ) -> Tensor: + """ + For the token IDs associated with each encoder, we are assuming that the number of tokens have already + been expanded to the number of tokens encoded for the given media. For example, if an image is tiled/patched + and tokenized to 100 tokens, we assume the text sequence already has 100 "image" tokens as placeholders. + """ + embeds = self.tok_embeddings(tokens) + bsz, seq_len, embed_dim = embeds.shape + for encoder, inp in (encoder_input or {}).items(): + encoder_embeds = self.encoders[encoder](**inp) + encoder_mask = (tokens == self.encoder_tokens[encoder]).expand(bsz, seq_len, embed_dim) + embeds[encoder_mask] = encoder_embeds + + output = self.decoder(embeds, mask, input_pos) + return output diff --git a/torchtune/modules/peft/_utils.py b/torchtune/modules/peft/_utils.py index 4768d77619..8f338750e8 100644 --- a/torchtune/modules/peft/_utils.py +++ b/torchtune/modules/peft/_utils.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. import contextlib -from typing import Any, Dict, Generator, List, Literal, Optional, Protocol, Set +from typing import Any, Dict, Generator, List, Literal, Optional, Protocol, Set, Union import torch from torch import nn @@ -62,7 +62,7 @@ def get_adapter_params(model: nn.Module) -> Dict[str, nn.Parameter]: return adapter_params -def set_trainable_params(model: nn.Module, adapter_params: Dict[str, Any]) -> None: +def set_trainable_params(model: nn.Module, adapter_params: Union[Dict[str, Any], Set]) -> None: """ Set trainable parameters for an nn.Module based on a state dict of adapter parameters. 
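
The EarlyFusionModel added in this first patch merges encoder outputs into the decoder's input embeddings: every occurrence of an encoder's special placeholder token is replaced by one embedding from that encoder's output. Below is a minimal standalone sketch of that merging step, with made-up shapes and token ids — it is not the torchtune API, and the later patches in this series refine the same idea (fixing broadcasting and switching to scatter_):

    import torch

    bsz, seq_len, embed_dim = 2, 6, 4
    image_token_id = 99                               # assumed placeholder id, outside the text vocab

    tokens = torch.tensor([[1, 99, 99, 2, 3, 4],
                           [99, 99, 5, 6, 7, 8]])     # [bsz, seq_len], two image tokens per sample
    embeds = torch.randn(bsz, seq_len, embed_dim)     # stand-in for decoder.tok_embeddings(tokens)
    encoder_embeds = torch.randn(bsz, 2, embed_dim)   # stand-in for the image encoder's output

    # Boolean-mask assignment: positions holding the placeholder token are
    # overwritten with the encoder embeddings, in row-major (batch, seq) order.
    is_image = tokens == image_token_id               # [bsz, seq_len]
    embeds[is_image] = encoder_embeds.reshape(-1, embed_dim)
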
From 282adacf73c2e9d9faec49dc0d6f3fbdea49629b Mon Sep 17 00:00:00 2001 From: RdoubleA Date: Mon, 28 Oct 2024 08:44:29 -0700 Subject: [PATCH 02/11] second --- .../modules/model_fusion/test_fusion_models.py | 11 +++++++++-- torchtune/modules/model_fusion/__init__.py | 2 +- torchtune/modules/model_fusion/_fusion.py | 15 +++++++++------ torchtune/modules/peft/_utils.py | 4 +++- 4 files changed, 22 insertions(+), 10 deletions(-) diff --git a/tests/torchtune/modules/model_fusion/test_fusion_models.py b/tests/torchtune/modules/model_fusion/test_fusion_models.py index 1ee9c0b6c7..97d7bbca4b 100644 --- a/tests/torchtune/modules/model_fusion/test_fusion_models.py +++ b/tests/torchtune/modules/model_fusion/test_fusion_models.py @@ -9,7 +9,11 @@ import torch from tests.test_utils import assert_expected, fixed_init_model from torch import nn -from torchtune.modules.model_fusion import DeepFusionModel, register_fusion_module, EarlyFusionModel +from torchtune.modules.model_fusion import ( + DeepFusionModel, + EarlyFusionModel, + register_fusion_module, +) from torchtune.training.seed import set_seed @@ -186,6 +190,7 @@ def test_set_trainable_params(self, fused_model, encoder, decoder): "decoder.embed.weight", } + class TestEarlyFusionModel: @pytest.fixture def vocab_size(self) -> int: @@ -229,7 +234,9 @@ def inputs(self, vocab_size): tokens[:, 8:] = 2 encoder_input = { "red": {"input": torch.randint(0, vocab_size, (batch_size, red_seq_len))}, - "green": {"input": torch.randint(0, vocab_size, (batch_size, green_seq_len))}, + "green": { + "input": torch.randint(0, vocab_size, (batch_size, green_seq_len)) + }, "blue": {"input": torch.randint(0, vocab_size, (batch_size, blue_seq_len))}, } encoder_mask = torch.randint(0, 2, (batch_size, seq_len, seq_len)).bool() diff --git a/torchtune/modules/model_fusion/__init__.py b/torchtune/modules/model_fusion/__init__.py index 6c4452e3f0..be12ac16c5 100644 --- a/torchtune/modules/model_fusion/__init__.py +++ b/torchtune/modules/model_fusion/__init__.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from ._fusion import DeepFusionModel, FusionEmbedding, FusionLayer, EarlyFusionModel +from ._fusion import DeepFusionModel, EarlyFusionModel, FusionEmbedding, FusionLayer from ._fusion_utils import get_fusion_params, register_fusion_module __all__ = [ diff --git a/torchtune/modules/model_fusion/_fusion.py b/torchtune/modules/model_fusion/_fusion.py index 0ca8c3534a..b73b87448a 100644 --- a/torchtune/modules/model_fusion/_fusion.py +++ b/torchtune/modules/model_fusion/_fusion.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from typing import Dict, List, Optional, Union, Any +from typing import Any, Dict, List, Optional, Union import torch from torch import nn, Tensor @@ -478,6 +478,7 @@ def forward( ) return output + class EarlyFusionModel(nn.Module): def __init__( self, @@ -502,15 +503,15 @@ def __init__( self.register_state_dict_post_hook(self._state_dict_hook) self.register_load_state_dict_pre_hook( - self._load_state_dict_hook, - with_module=True + self._load_state_dict_hook, with_module=True ) trainable_params = set() for encoder, trainable in encoders_trainable.items(): if trainable: trainable_params |= { - f"encoders.{encoder}.{n}" for n, p in self.encoders[encoder].named_parameters() + f"encoders.{encoder}.{n}" + for n, p in self.encoders[encoder].named_parameters() } if decoder_trainable: trainable_params |= { @@ -530,7 +531,7 @@ def _state_dict_hook(self, destination, prefix, keep_vars): del destination[key] def _load_state_dict_hook(self, state_dict, *args, **kwargs): - """ Undo the change from _state_dict_hook """ + """Undo the change from _state_dict_hook""" key = "tok_embeddings" decoder_key = "decoder.tok_embeddings" state_dict[key] = state_dict[decoder_key] @@ -576,7 +577,9 @@ def forward( bsz, seq_len, embed_dim = embeds.shape for encoder, inp in (encoder_input or {}).items(): encoder_embeds = self.encoders[encoder](**inp) - encoder_mask = (tokens == self.encoder_tokens[encoder]).expand(bsz, seq_len, embed_dim) + encoder_mask = (tokens == self.encoder_tokens[encoder]).expand( + bsz, seq_len, embed_dim + ) embeds[encoder_mask] = encoder_embeds output = self.decoder(embeds, mask, input_pos) diff --git a/torchtune/modules/peft/_utils.py b/torchtune/modules/peft/_utils.py index 8f338750e8..8e360eaf5f 100644 --- a/torchtune/modules/peft/_utils.py +++ b/torchtune/modules/peft/_utils.py @@ -62,7 +62,9 @@ def get_adapter_params(model: nn.Module) -> Dict[str, nn.Parameter]: return adapter_params -def set_trainable_params(model: nn.Module, adapter_params: Union[Dict[str, Any], Set]) -> None: +def set_trainable_params( + model: nn.Module, adapter_params: Union[Dict[str, Any], Set] +) -> None: """ Set trainable parameters for an nn.Module based on a state dict of adapter parameters. From a8666feb2dce08cf3fe15d5218fe7f1817c0783f Mon Sep 17 00:00:00 2001 From: RdoubleA Date: Mon, 28 Oct 2024 16:40:30 -0700 Subject: [PATCH 03/11] support multiple encoders, update DeepFusion, docstrings --- torchtune/modules/model_fusion/_fusion.py | 222 +++++++++++++++++++--- 1 file changed, 193 insertions(+), 29 deletions(-) diff --git a/torchtune/modules/model_fusion/_fusion.py b/torchtune/modules/model_fusion/_fusion.py index 7402ab69f5..38ec8d73aa 100644 --- a/torchtune/modules/model_fusion/_fusion.py +++ b/torchtune/modules/model_fusion/_fusion.py @@ -7,7 +7,7 @@ from typing import Any, Dict, List, Optional, Union import torch -from torch import nn, Tensor +from torch import nn from torchtune.modules import TransformerDecoder from torchtune.modules.model_fusion._fusion_utils import get_fusion_params from torchtune.modules.peft._utils import set_trainable_params @@ -288,8 +288,11 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: class DeepFusionModel(nn.Module): """DeepFusion is a type of fused model architecture where a pretrained encoder is combined - with a pretrained decoder (LLM). This is a popular architecture for multimodal models, with + with a pretrained decoder (LLM) in the internal decoder layers. 
This is a popular architecture for multimodal models, with a full overview available in `The Evolution of Multimodal Model Architectures `_. + A common deep fusion architecture is to fuse the encoder input into the decoder with interspersed cross-attention + layers. This module makes no assumptions on how the encoder and decoder are fused; it simply + passes in the encoder embeddings to the decoder and lets the decoder handle any fusion. This module has the same methods and forward signature as :class:`~torchtune.modules.TransformerDecoder` and can be used interchangeably where :class:`~torchtune.modules.TransformerDecoder` is. It combines the encoder with the decoder as a @@ -297,6 +300,8 @@ class DeepFusionModel(nn.Module): are already defined with any extra learnable ``fusion_params``: learnable parameters to help adapt the pre-trained encoder to the pre-trained decoder. + DeepFusionModel currently only supports a single encoder. + Example: >>> # decoder is a TransformerDecoder (e.g. llama3_8b) with fused cross attention layers >>> embed = FusionEmbedding(...) @@ -309,10 +314,10 @@ class DeepFusionModel(nn.Module): >>> # encoder is pre-trained encoder (e.g. clip_vit_224) with an added projection head >>> projection_head = FeedForward(...) >>> register_fusion_module(projection_head)) - >>> encoder = nn.Sequential(clip_vit_224(), projection_head) + >>> encoders = {"image": nn.Sequential(clip_vit_224(), projection_head)} >>> >>> # DeepFusionModel combines the encoder and decoder - >>> model = DeepFusionModel(decoder, encoder) + >>> model = DeepFusionModel(decoder, encoders) >>> >>> # Load full fused checkpoints (e.g. a Flamingo checkpoint) >>> model.load_state_dict(...) @@ -322,35 +327,63 @@ class DeepFusionModel(nn.Module): >>> model.decoder.load_state_dict(..., strict=False) >>> >>> # Forward pass + >>> encoder_input = {"image": {...}} >>> output = model(tokens, mask, encoder_input, encoder_mask, input_pos) Args: decoder (TransformerDecoder): decoder module - encoder (nn.Module): encoder module + encoders (Dict[str, nn.Module]): dictionary mapping encoder name as a string to the encoder module. decoder_trainable (bool): whether to train or freeze the decoder. Default is False. - encoder_trainable (bool): whether to train or freeze the encoder. Default is False. + encoders_trainable (Union[bool, Dict[str, bool]]): whether to train or freeze the encoder. Use a single + boolean to set trainable for all encoders or a dictionary keyed by encoder names to specify trainable + for each encoder individually. Encoder names should match with ``encoders``. Default is False. fusion_trainable (bool): whether to train the fusion parameters. Default is True. + Raises: + ValueError: if ``encoders`` and ``encoders_trainable`` keys do not match + ValueError: if ``len(encoders) != 1`` """ def __init__( self, decoder: TransformerDecoder, - encoder: nn.Module, + encoders: Dict[str, nn.Module], *, decoder_trainable: bool = False, - encoder_trainable: bool = False, + encoders_trainable: Union[bool, Dict[str, bool]] = False, fusion_trainable: bool = True, ): super().__init__() + if ( + not isinstance(encoders_trainable, bool) + and encoders.keys() != encoders_trainable.keys() + ): + raise ValueError( + f"Found mismatched keys in encoders and encoders_trainable. Got {encoders.keys()} and {encoders_trainable.keys()}." + ) + # Currently, only a single encoder is supported, so user can only + # pass in a single key. When multiple encoders are + # supported, this can be removed. 
+ if len(encoders.keys()) != 1: + raise ValueError( + f"DeepFusionModel only supports a single encoder. Got {len(encoders.keys())} encoders." + ) + self.decoder = decoder - self.encoder = encoder + self.encoders = nn.ModuleDict(encoders) + self.encoders_trainable = ( + {k: encoders_trainable for k in self.encoders.keys()} + if isinstance(encoders_trainable, bool) + else encoders_trainable + ) trainable_params = set() - if encoder_trainable: - trainable_params |= { - f"encoder.{n}" for n, p in self.encoder.named_parameters() - } + for encoder, trainable in self.encoders_trainable.items(): + if trainable: + trainable_params |= { + f"encoders.{encoder}.{n}" + for n, p in self.encoders[encoder].named_parameters() + } if decoder_trainable: trainable_params |= { f"decoder.{n}" for n, p in self.decoder.named_parameters() @@ -384,8 +417,8 @@ def setup_caches( Args: batch_size (int): batch size for the caches. dtype (torch.dtype): dtype for the caches. - encoder_max_seq_len (int): maximum encoder cache sequence length. - decoder_max_seq_len (int): maximum decoder cache sequence length. + encoder_max_seq_len (Optional[int]): maximum encoder cache sequence length. + decoder_max_seq_len (Optional[int]): maximum decoder cache sequence length. """ self.decoder.setup_caches( batch_size, @@ -434,7 +467,8 @@ def forward( before the softmax. A value of True in row i and column j means token i attends to token j. A value of False means token i does not attend to token j. If no mask is specified, a causal mask is used by default. Default is None. - encoder_input (Optional[Dict[str, Any]]): Optional input for the encoder. + encoder_input (Optional[Dict[str, Dict[str, Any]]]): Optional input kwargs for the encoders. Must be + keyed by encoder name and match the keys of ``encoders`` encoder_mask (Optional[torch.Tensor]): Boolean tensor defining a relational matrix between tokens and encoder embeddings. A True value at position i,j means token i can attend to embedding j in the decoder. Mask has shape ``[b x s x s_e]``. Default is None. @@ -454,6 +488,9 @@ def forward( output tensors defined by ``output_hidden_states`` with the \ final output tensor appended to the list. + Raises: + ValueError: if ``encoder_input`` keys do not match ``encoders`` keys + Notation used for tensor shapes: - b: batch size - s: token sequence length @@ -463,17 +500,32 @@ def forward( - d_e: encoder embed dim - m_s: max seq len """ + if encoder_input is not None and encoder_input.keys() != self.encoders.keys(): + raise ValueError( + f"Found mismatched keys in encoder_input and instantiated encoders. " + f"Got {encoder_input.keys()}, expected {self.encoders.keys()}." + ) # During decoding, encoder_input will only be provided # for new inputs. Previous encoder outputs are cached # in the decoder cache. encoder_embed = None if encoder_input is not None: - encoder_embed = self.encoder(**encoder_input) + encoder_embed = { + encoder: self.encoders[encoder](**encoder_input[encoder]) + for encoder in encoder_input + } + + # Currently, only a single encoder is supported, so we need + # to get the encoder key manually. When multiple encoders are + # supported, this can be removed. 
+ decoder_encoder_input = ( + list(encoder_embed.values())[0] if encoder_embed is not None else None + ) output = self.decoder( tokens=tokens, mask=mask, - encoder_input=encoder_embed, + encoder_input=decoder_encoder_input, encoder_mask=encoder_mask, input_pos=input_pos, ) @@ -481,18 +533,84 @@ def forward( class EarlyFusionModel(nn.Module): + """EarlyFusion is a type of fused model architecture where pretrained encoder(s) are combined + with a pretrained decoder (LLM) at the model input and not in internal layers. This is a popular architecture + for multimodal models, with a full overview available in `The Evolution of Multimodal Model Architectures + `_. This module assumes the decoder is trained to recognize tokens specific + to each encoder, which are then replaced in the sequence with the encoder embedding outputs. + + This module has the same methods and forward signature as :class:`~torchtune.modules.TransformerDecoder` and can be used + interchangeably where :class:`~torchtune.modules.TransformerDecoder` is. It combines the encoders with the decoder as a + single module for checkpointing and finetuning. It is expected that the encoders and decoder + are already defined with any extra learnable ``fusion_params``: learnable parameters to help + adapt the pre-trained encoders to the pre-trained decoder. + + You can pass in multiple encoders as a dictionary into ``encoders``. + + Example: + >>> # decoder is a text-only TransformerDecoder (e.g. llama3_8b) with no modifications + >>> decoder = llama3_8b() + >>> + >>> # encoder is pre-trained encoder (e.g. clip_vit_224) with an added projection head + >>> projection_head = FeedForward(...) + >>> register_fusion_module(projection_head)) + >>> encoders = {"image": nn.Sequential(clip_vit_224(), projection_head)} + >>> + >>> # EarlyFusionModel combines the encoder and decoder + >>> model = DeepFusionModel(decoder, encoders) + >>> + >>> # Load full fused checkpoints + >>> model.load_state_dict(...) + >>> + >>> # Or load pretrained individual models (fusion_params are not loaded) + >>> model.encoder.load_state_dict(..., strict=False) + >>> model.decoder.load_state_dict(..., strict=False) + >>> + >>> # Forward pass + >>> encoder_input = {"image": {...}} + >>> output = model(tokens, mask, encoder_input, encoder_mask, input_pos) + + Args: + decoder (TransformerDecoder): decoder module + encoders (Dict[str, nn.Module]): dictionary mapping encoder name as a string to the encoder module. + encoder_tokens (Dict[str, int]): dictionary mapping encoder name to special token ID indicating where + in the text sequence the encoder embedding outputs should be injected. + decoder_trainable (bool): whether to train or freeze the decoder. Default is False. + encoders_trainable (Union[bool, Dict[str, bool]]): whether to train or freeze the encoder. Use a single + boolean to set trainable for all encoders or a dictionary keyed by encoder names to specify trainable + for each encoder individually. Encoder names should match with ``encoders``. Default is False. + fusion_trainable (bool): whether to train the fusion parameters. Default is True. 
+ + Raises: + ValueError: if ``encoders`` and ``encoders_trainable`` keys do not match + """ + def __init__( self, decoder: TransformerDecoder, - encoders: nn.ModuleDict, + encoders: Dict[str, nn.Module], encoder_tokens: Dict[str, int], decoder_trainable: bool, - encoders_trainable: Dict[str, bool], + encoders_trainable: Union[bool, Dict[str, bool]] = False, + fusion_trainable: bool = True, ): super().__init__() + if encoders.keys() != encoder_tokens.keys() or ( + not isinstance(encoders_trainable, bool) + and encoders.keys() != encoders_trainable.keys() + ): + raise ValueError( + f"Found mismatched keys in encoders and encoders_trainable. Got {encoders.keys()} and {encoders_trainable.keys()}." + ) + self.decoder = decoder - self.encoders = encoders + self.encoders = nn.ModuleDict(encoders) self.encoder_tokens = encoder_tokens + self.encoders_trainable = ( + {k: encoders_trainable for k in self.encoders.keys()} + if isinstance(encoders_trainable, bool) + else encoders_trainable + ) # A little surgery in the decoder to give the # fusion module access to control the embeddings @@ -508,7 +626,7 @@ def __init__( ) trainable_params = set() - for encoder, trainable in encoders_trainable.items(): + for encoder, trainable in self.encoders_trainable.items(): if trainable: trainable_params |= { f"encoders.{encoder}.{n}" @@ -518,6 +636,11 @@ def __init__( trainable_params |= { f"decoder.{n}" for n, p in self.decoder.named_parameters() } + if fusion_trainable: + trainable_params |= set(get_fusion_params(self)) + else: + trainable_params -= set(get_fusion_params(self)) + set_trainable_params(self, trainable_params) def _state_dict_hook(self, destination, prefix, keep_vars): @@ -562,18 +685,59 @@ def reset_caches(self): def forward( self, - tokens: Tensor, + tokens: torch.Tensor, *, - mask: Optional[Tensor] = None, + mask: Optional[torch.Tensor] = None, encoder_input: Optional[Dict[str, Dict[str, Any]]] = None, - input_pos: Optional[Tensor] = None, + input_pos: Optional[torch.Tensor] = None, **kwargs: Dict[str, Any], # no need for encoder_mask - ) -> Tensor: + ) -> torch.Tensor: """ - For the token IDs associated with each encoder, we are assuming that the number of tokens have already - been expanded to the number of tokens encoded for the given media. For example, if an image is tiled/patched - and tokenized to 100 tokens, we assume the text sequence already has 100 "image" tokens as placeholders. + Args: + tokens (torch.Tensor): input tensor with shape ``[b x s]`` + mask (Optional[torch.Tensor]): Optional boolean tensor which contains the attention mask + with shape ``[b x s x s]``. This is applied after the query-key multiplication and + before the softmax. A value of True in row i and column j means token i attends + to token j. A value of False means token i does not attend to token j. If no + mask is specified, a causal mask is used by default. Default is None. + encoder_input (Optional[Dict[str, Dict[str, Any]]]): Optional input kwargs for the encoders. Must be + keyed by encoder name and match the keys of ``encoders`` + input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids + of each token. During training, this is used to indicate the positions + of each token relative to its sample when packed, shape ``[b x s]``. + During inference, this indicates the position of the current token. + If none, assume the index of the token is its position id. Default is None. + **kwargs (Dict[str, Any]): additional keyword arguments. 
This is solely used to match the + :class:`~torchtune.modules.TransformerDecoder` forward and does not have any effect. + + Note: At the very first step of inference, when the model is provided with a prompt, + ``input_pos`` would contain the positions of all of the tokens in the prompt + (eg: ``torch.arange(prompt_length)``). This is because we will need to compute the + KV values for each position. + + Returns: + torch.Tensor: output tensor with shape ``[b x s x v]`` or a list of layer \ + output tensors defined by ``output_hidden_states`` with the \ + final output tensor appended to the list. + + Raises: + ValueError: if ``encoder_input`` keys do not match ``encoders`` keys + + Notation used for tensor shapes: + - b: batch size + - s: token sequence length + - s_e: encoder sequence length + - v: vocab size + - d: token embed dim + - d_e: encoder embed dim + - m_s: max seq len """ + if encoder_input is not None and encoder_input.keys() != self.encoders.keys(): + raise ValueError( + f"Found mismatched keys in encoder_input and instantiated encoders. " + f"Got {encoder_input.keys()}, expected {self.encoders.keys()}." + ) + embeds = self.tok_embeddings(tokens) bsz, seq_len, embed_dim = embeds.shape for encoder, inp in (encoder_input or {}).items(): From 024bfc7fe1911dd3de74e02e82fa79c5550d14cd Mon Sep 17 00:00:00 2001 From: RdoubleA Date: Mon, 28 Oct 2024 17:20:39 -0700 Subject: [PATCH 04/11] update tests part 1 --- .../model_fusion/test_fusion_models.py | 61 +++++++++++++++++-- torchtune/modules/model_fusion/_fusion.py | 2 +- 2 files changed, 57 insertions(+), 6 deletions(-) diff --git a/tests/torchtune/modules/model_fusion/test_fusion_models.py b/tests/torchtune/modules/model_fusion/test_fusion_models.py index 97d7bbca4b..278fdd8ca7 100644 --- a/tests/torchtune/modules/model_fusion/test_fusion_models.py +++ b/tests/torchtune/modules/model_fusion/test_fusion_models.py @@ -83,7 +83,7 @@ def decoder(self, dim, vocab_size) -> nn.Module: @pytest.fixture def fused_model(self, encoder, decoder) -> DeepFusionModel: model = DeepFusionModel( - encoder=encoder, + encoders={"encoder": encoder}, decoder=decoder, ) return model @@ -93,7 +93,9 @@ def inputs(self, vocab_size): batch_size = 2 seq_len = 10 tokens = torch.randint(0, vocab_size, (batch_size, seq_len)) - encoder_input = {"input": torch.randint(0, vocab_size, (batch_size, seq_len))} + encoder_input = { + "encoder": {"input": torch.randint(0, vocab_size, (batch_size, seq_len))} + } encoder_mask = torch.randint(0, 2, (batch_size, seq_len, seq_len)).bool() input_pos = torch.Tensor([1]).int() return tokens, encoder_input, encoder_mask, input_pos @@ -164,9 +166,9 @@ def test_set_trainable_params(self, fused_model, encoder, decoder): # Test encoder only model = DeepFusionModel( - encoder=encoder, + encoders={"encoder": encoder}, decoder=decoder, - encoder_trainable=True, + encoders_trainable=True, fusion_trainable=False, ) trainable_params = {n for n, p in model.named_parameters() if p.requires_grad} @@ -174,7 +176,7 @@ def test_set_trainable_params(self, fused_model, encoder, decoder): # Test decoder only, and confirm fusion layers are removed independently model = DeepFusionModel( - encoder=encoder, + encoders={"encoder": encoder}, decoder=decoder, decoder_trainable=True, fusion_trainable=False, @@ -190,6 +192,29 @@ def test_set_trainable_params(self, fused_model, encoder, decoder): "decoder.embed.weight", } + def test_incorrect_number_of_encoders(self, decoder): + with pytest.raises(ValueError): + _ = DeepFusionModel( + encoders={"encoder": 
nn.Identity(), "encoder2": nn.Identity()}, + decoder=decoder, + ) + + def test_mismatched_encoder_keys(self, decoder): + with pytest.raises(ValueError): + _ = DeepFusionModel( + encoders={"encoder": nn.Identity()}, + decoder=decoder, + encoders_trainable={"encoder2": True}, + ) + + def test_mismatched_encoder_input(self, fused_model, inputs): + tokens, _, _, _ = inputs + with pytest.raises(ValueError): + _ = fused_model( + tokens, + encoder_input={"encoder2": {"input": torch.tensor([1])}}, + ) + class TestEarlyFusionModel: @pytest.fixture @@ -314,3 +339,29 @@ def test_set_trainable_params(self, fused_model): "decoder.embed.weight", "encoders.green.weight", } + + def test_mismatched_encoder_tokens(self, decoder): + with pytest.raises(ValueError): + _ = EarlyFusionModel( + encoders={"encoder": nn.Identity(), "encoder2": nn.Identity()}, + decoder=decoder, + encoder_tokens={"encoder": 0, "encoder3": 1}, + encoders_trainable=False, + ) + + def test_mismatched_encoder_trainable(self, decoder): + with pytest.raises(ValueError): + _ = EarlyFusionModel( + encoders={"encoder": nn.Identity(), "encoder2": nn.Identity()}, + decoder=decoder, + encoder_tokens={"encoder": 0, "encoder2": 1}, + encoders_trainable={"encoder": True, "encoder3": False}, + ) + + def test_mismatched_encoder_input(self, fused_model, inputs): + tokens, _, _, _ = inputs + with pytest.raises(ValueError): + _ = fused_model( + tokens, + encoder_input={"encoder": {"input": torch.tensor([1])}}, + ) diff --git a/torchtune/modules/model_fusion/_fusion.py b/torchtune/modules/model_fusion/_fusion.py index 38ec8d73aa..f2c311728f 100644 --- a/torchtune/modules/model_fusion/_fusion.py +++ b/torchtune/modules/model_fusion/_fusion.py @@ -590,7 +590,7 @@ def __init__( decoder: TransformerDecoder, encoders: Dict[str, nn.Module], encoder_tokens: Dict[str, int], - decoder_trainable: bool, + decoder_trainable: bool = False, encoders_trainable: Union[bool, Dict[str, bool]] = False, fusion_trainable: bool = True, ): From 347d602981c59992f6860fc792ffae0741dee2e6 Mon Sep 17 00:00:00 2001 From: RdoubleA Date: Thu, 31 Oct 2024 10:55:23 -0700 Subject: [PATCH 05/11] fix tests parts 2 --- .../modules/model_fusion/test_fusion_models.py | 4 ++-- torchtune/modules/model_fusion/_fusion.py | 13 +++++++------ 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/tests/torchtune/modules/model_fusion/test_fusion_models.py b/tests/torchtune/modules/model_fusion/test_fusion_models.py index 278fdd8ca7..f16f7b0929 100644 --- a/tests/torchtune/modules/model_fusion/test_fusion_models.py +++ b/tests/torchtune/modules/model_fusion/test_fusion_models.py @@ -172,7 +172,7 @@ def test_set_trainable_params(self, fused_model, encoder, decoder): fusion_trainable=False, ) trainable_params = {n for n, p in model.named_parameters() if p.requires_grad} - assert trainable_params == {"encoder.weight"} + assert trainable_params == {"encoders.encoder.weight"} # Test decoder only, and confirm fusion layers are removed independently model = DeepFusionModel( @@ -189,7 +189,7 @@ def test_set_trainable_params(self, fused_model, encoder, decoder): "decoder.k.bias", "decoder.v.weight", "decoder.v.bias", - "decoder.embed.weight", + "decoder.tok_embeddings.weight", } def test_incorrect_number_of_encoders(self, decoder): diff --git a/torchtune/modules/model_fusion/_fusion.py b/torchtune/modules/model_fusion/_fusion.py index f2c311728f..b6ad775c68 100644 --- a/torchtune/modules/model_fusion/_fusion.py +++ b/torchtune/modules/model_fusion/_fusion.py @@ -600,7 +600,7 @@ def __init__( and 
encoders.keys() != encoders_trainable.keys() ): raise ValueError( - f"Found mismatched keys in encoders and encoders_trainable. Got {encoders.keys()} and {encoders_trainable.keys()}." + f"Found mismatched keys in encoders, encoder_tokens, and/or encoders_trainable. Expected {encoders.keys()}" ) self.decoder = decoder @@ -620,10 +620,8 @@ def __init__( self.tok_embeddings = decoder.tok_embeddings decoder.tok_embeddings = nn.Identity() - self.register_state_dict_post_hook(self._state_dict_hook) - self.register_load_state_dict_pre_hook( - self._load_state_dict_hook, with_module=True - ) + self._register_state_dict_hook(self._state_dict_hook) + self.register_load_state_dict_pre_hook(self._load_state_dict_hook) trainable_params = set() for encoder, trainable in self.encoders_trainable.items(): @@ -636,6 +634,9 @@ def __init__( trainable_params |= { f"decoder.{n}" for n, p in self.decoder.named_parameters() } + trainable_params |= { + f"tok_embeddings.{n}" for n, p in self.tok_embeddings.named_parameters() + } if fusion_trainable: trainable_params |= set(get_fusion_params(self)) else: @@ -643,7 +644,7 @@ def __init__( set_trainable_params(self, trainable_params) - def _state_dict_hook(self, destination, prefix, keep_vars): + def _state_dict_hook(self, destination, *args, **kwargs): """ Keep tok_embeddings inside of decoder state_dict From 3033728298d38692a0f2065262423b9edf5675d6 Mon Sep 17 00:00:00 2001 From: RdoubleA Date: Fri, 1 Nov 2024 17:26:58 -0700 Subject: [PATCH 06/11] finally fix tests, but they shouldnt be passing --- .../model_fusion/test_fusion_models.py | 31 ++++++++++++----- torchtune/modules/model_fusion/_fusion.py | 34 ++++++++++++++----- 2 files changed, 47 insertions(+), 18 deletions(-) diff --git a/tests/torchtune/modules/model_fusion/test_fusion_models.py b/tests/torchtune/modules/model_fusion/test_fusion_models.py index f16f7b0929..23bf8437d4 100644 --- a/tests/torchtune/modules/model_fusion/test_fusion_models.py +++ b/tests/torchtune/modules/model_fusion/test_fusion_models.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import pdb + import pytest import torch @@ -42,14 +44,22 @@ def caches_are_setup(self): def reset_caches(self): self.cache_enabled = False - def forward(self, tokens, mask, encoder_input, encoder_mask, input_pos): + def forward( + self, + tokens, + *, + mask=None, + encoder_input=None, + encoder_mask=None, + input_pos=None + ): x = self.tok_embeddings(tokens) if encoder_input is not None: q = self.q(x) - k = self.k(encoder_input) - v = self.v(encoder_input) + k = self.k(encoder_input) if encoder_input is not None else self.k(x) + v = self.v(encoder_input) if encoder_input is not None else self.v(x) x += nn.functional.scaled_dot_product_attention( - q, k, v, attn_mask=encoder_mask + q, k, v, attn_mask=encoder_mask if encoder_mask is not None else mask ) x = self.output(x) return x @@ -245,6 +255,7 @@ def fused_model(self, encoder, decoder) -> EarlyFusionModel: encoder_tokens={"red": 0, "green": 1, "blue": 2}, decoder_trainable=True, encoders_trainable={"red": False, "green": True, "blue": False}, + fusion_trainable=False, ) return model @@ -273,17 +284,19 @@ def test_forward(self, fused_model, inputs, vocab_size): """ Test that the forward pass of the EarlyFusionModel works as expected. 
""" - tokens, encoder_input, encoder_mask, _ = inputs + tokens, encoder_input, *_ = inputs batch_size, seq_len = tokens.shape + pdb.set_trace() out = fused_model( - tokens, encoder_input=encoder_input, encoder_mask=encoder_mask + tokens, + encoder_input=encoder_input, ) assert out.shape == (batch_size, seq_len, vocab_size) assert_expected(out.mean(), torch.tensor(8.5584), atol=1e-3, rtol=1e-3) @torch.no_grad() - def test_forward_no_encoding(self, fused_model, inputs, vocab_size): + def test_forward_no_encoder(self, fused_model, inputs, vocab_size): """ Test that the forward pass of the EarlyFusionModel with no encoder input. """ @@ -295,7 +308,7 @@ def test_forward_no_encoding(self, fused_model, inputs, vocab_size): assert_expected(out.mean(), torch.tensor(0.2271), atol=1e-3, rtol=1e-3) @torch.no_grad() - def test_decoding_forward(self, fused_model, inputs, vocab_size): + def test_decoder_forward(self, fused_model, inputs, vocab_size): """ Test that the forward pass of the EarlyFusionModel works during decoding. """ @@ -336,7 +349,7 @@ def test_set_trainable_params(self, fused_model): "decoder.k.bias", "decoder.v.weight", "decoder.v.bias", - "decoder.embed.weight", + "tok_embeddings.weight", "encoders.green.weight", } diff --git a/torchtune/modules/model_fusion/_fusion.py b/torchtune/modules/model_fusion/_fusion.py index b6ad775c68..571bbc508a 100644 --- a/torchtune/modules/model_fusion/_fusion.py +++ b/torchtune/modules/model_fusion/_fusion.py @@ -644,7 +644,8 @@ def __init__( set_trainable_params(self, trainable_params) - def _state_dict_hook(self, destination, *args, **kwargs): + @staticmethod + def _state_dict_hook(module, state_dict, *args, **kwargs): """ Keep tok_embeddings inside of decoder state_dict @@ -652,10 +653,11 @@ def _state_dict_hook(self, destination, *args, **kwargs): """ key = "tok_embeddings" decoder_key = "decoder.tok_embeddings" - destination[decoder_key] = destination[key] - del destination[key] + state_dict[decoder_key] = state_dict[key] + del state_dict[key] - def _load_state_dict_hook(self, state_dict, *args, **kwargs): + @staticmethod + def _load_state_dict_hook(module, state_dict, *args, **kwargs): """Undo the change from _state_dict_hook""" key = "tok_embeddings" decoder_key = "decoder.tok_embeddings" @@ -676,8 +678,20 @@ def setup_caches(self, batch_size: int, dtype: torch.dtype) -> None: """ self.decoder.setup_caches(batch_size, dtype) + def caches_are_setup(self) -> bool: + """ + Check if the key value caches are setup. This means ``setup_caches`` has been called, and + the relevant attention modules in the model have created their ``KVCache``. + """ + return self.decoder.caches_are_setup() + def caches_are_enabled(self) -> bool: - """Check if the key value caches are setup.""" + """ + Checks if the key value caches are enabled. Once KV-caches have been setup, the relevant + attention modules will be "enabled" and all forward passes will update the caches. This behaviour + can be disabled without altering the state of the KV-caches by "disabling" the KV-caches + using :func:`~torchtune.modules.common_utils.disable_kv_cache`, upon which ``caches_are_enabled`` would return False. 
+ """ return self.decoder.caches_are_enabled() def reset_caches(self): @@ -743,10 +757,12 @@ def forward( bsz, seq_len, embed_dim = embeds.shape for encoder, inp in (encoder_input or {}).items(): encoder_embeds = self.encoders[encoder](**inp) - encoder_mask = (tokens == self.encoder_tokens[encoder]).expand( - bsz, seq_len, embed_dim + encoder_mask = ( + (tokens == self.encoder_tokens[encoder]) + .unsqueeze(-1) + .expand(bsz, seq_len, embed_dim) ) - embeds[encoder_mask] = encoder_embeds + embeds = torch.where(encoder_mask, encoder_embeds, embeds) - output = self.decoder(embeds, mask, input_pos) + output = self.decoder(embeds, mask=mask, input_pos=input_pos) return output From 342613bc2737c8b692d0991b28289c93d2496eea Mon Sep 17 00:00:00 2001 From: RdoubleA Date: Sun, 3 Nov 2024 14:30:53 -0800 Subject: [PATCH 07/11] final test fix --- .../model_fusion/test_fusion_models.py | 107 ++++++++++++++---- torchtune/modules/model_fusion/_fusion.py | 43 +++---- 2 files changed, 106 insertions(+), 44 deletions(-) diff --git a/tests/torchtune/modules/model_fusion/test_fusion_models.py b/tests/torchtune/modules/model_fusion/test_fusion_models.py index 23bf8437d4..d1d7b89533 100644 --- a/tests/torchtune/modules/model_fusion/test_fusion_models.py +++ b/tests/torchtune/modules/model_fusion/test_fusion_models.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import pdb +from collections import OrderedDict import pytest @@ -51,7 +51,7 @@ def forward( mask=None, encoder_input=None, encoder_mask=None, - input_pos=None + input_pos=None, ): x = self.tok_embeddings(tokens) if encoder_input is not None: @@ -235,24 +235,29 @@ def vocab_size(self) -> int: def dim(self) -> int: return 64 - @pytest.fixture - def encoder(self, dim, vocab_size) -> nn.Module: - encoder = nn.Embedding(vocab_size, dim) - fixed_init_model(encoder) - return encoder - @pytest.fixture def decoder(self, dim, vocab_size) -> nn.Module: - decoder = DummyModel(dim, vocab_size) + decoder = DummyModel(dim, vocab_size + 3) fixed_init_model(decoder, max_val=0.1) return decoder @pytest.fixture - def fused_model(self, encoder, decoder) -> EarlyFusionModel: + def fused_model(self, vocab_size, dim, decoder) -> EarlyFusionModel: + red = nn.Embedding(vocab_size, dim) + fixed_init_model(red) + green = nn.Embedding(vocab_size, dim) + fixed_init_model(green) + blue = nn.Embedding(vocab_size, dim) + fixed_init_model(blue) + model = EarlyFusionModel( - encoders={"red": encoder, "green": encoder, "blue": encoder}, + encoders={"red": red, "green": green, "blue": blue}, decoder=decoder, - encoder_tokens={"red": 0, "green": 1, "blue": 2}, + encoder_tokens={ + "red": vocab_size, + "green": vocab_size + 1, + "blue": vocab_size + 2, + }, decoder_trainable=True, encoders_trainable={"red": False, "green": True, "blue": False}, fusion_trainable=False, @@ -265,9 +270,9 @@ def inputs(self, vocab_size): seq_len = 10 tokens = torch.randint(0, vocab_size, (batch_size, seq_len)) red_seq_len, green_seq_len, blue_seq_len = 1, 2, 3 - tokens[:, 0] = 0 - tokens[:, 3:5] = 1 - tokens[:, 8:] = 2 + tokens[:, 0] = vocab_size + tokens[:, 3:5] = vocab_size + 1 + tokens[:, 7:] = vocab_size + 2 encoder_input = { "red": {"input": torch.randint(0, vocab_size, (batch_size, red_seq_len))}, "green": { @@ -279,6 +284,25 @@ def inputs(self, vocab_size): input_pos = torch.Tensor([1]).int() return tokens, encoder_input, encoder_mask, input_pos + @pytest.fixture + def state_dict(self, dim, vocab_size): + 
return OrderedDict( + { + "decoder.q.weight": torch.randn((dim, dim)), + "decoder.q.bias": torch.randn((dim,)), + "decoder.k.weight": torch.randn((dim, dim)), + "decoder.k.bias": torch.randn((dim,)), + "decoder.v.weight": torch.randn((dim, dim)), + "decoder.v.bias": torch.randn((dim,)), + "decoder.output.weight": torch.randn((vocab_size + 3, dim)), + "decoder.output.bias": torch.randn((vocab_size + 3,)), + "decoder.tok_embeddings.weight": torch.randn((vocab_size + 3, dim)), + "encoders.red.weight": torch.randn((vocab_size, dim)), + "encoders.green.weight": torch.randn((vocab_size, dim)), + "encoders.blue.weight": torch.randn((vocab_size, dim)), + } + ) + @torch.no_grad() def test_forward(self, fused_model, inputs, vocab_size): """ @@ -286,26 +310,53 @@ def test_forward(self, fused_model, inputs, vocab_size): """ tokens, encoder_input, *_ = inputs batch_size, seq_len = tokens.shape - pdb.set_trace() + out = fused_model( tokens, encoder_input=encoder_input, ) - assert out.shape == (batch_size, seq_len, vocab_size) - assert_expected(out.mean(), torch.tensor(8.5584), atol=1e-3, rtol=1e-3) + assert out.shape == (batch_size, seq_len, vocab_size + 3) + assert_expected(out.mean(), torch.tensor(1.1828), atol=1e-3, rtol=1e-3) @torch.no_grad() - def test_forward_no_encoder(self, fused_model, inputs, vocab_size): + def test_forward_no_decoder(self, fused_model, inputs, dim): + """ + Test that the forward pass of the EarlyFusionModel works as expected. + """ + tokens, encoder_input, *_ = inputs + batch_size, seq_len = tokens.shape + + class DummyModule(nn.Module): + def forward(self, x, **kwargs): + return x + + fused_model.decoder = DummyModule() + + out = fused_model( + tokens, + encoder_input=encoder_input, + ) + + assert out.shape == (batch_size, seq_len, dim) + # Check that each encoder output is placed correctly in the fused output + red = fused_model.encoders["red"](**encoder_input["red"]) + assert_expected(out[:, :1, :], red, atol=1e-3, rtol=1e-3) + green = fused_model.encoders["green"](**encoder_input["green"]) + assert_expected(out[:, 3:5, :], green, atol=1e-3, rtol=1e-3) + blue = fused_model.encoders["blue"](**encoder_input["blue"]) + assert_expected(out[:, 7:, :], blue, atol=1e-3, rtol=1e-3) + + @torch.no_grad() + def test_forward_no_encoder(self, fused_model, inputs): """ Test that the forward pass of the EarlyFusionModel with no encoder input. 
""" tokens, *_ = inputs - batch_size, seq_len = tokens.shape - out = fused_model(tokens) + actual = fused_model(tokens) + expected = fused_model.decoder(fused_model.tok_embeddings(tokens)) - assert out.shape == (batch_size, seq_len, vocab_size) - assert_expected(out.mean(), torch.tensor(0.2271), atol=1e-3, rtol=1e-3) + assert_expected(actual, expected, atol=1e-3, rtol=1e-3) @torch.no_grad() def test_decoder_forward(self, fused_model, inputs, vocab_size): @@ -323,8 +374,8 @@ def test_decoder_forward(self, fused_model, inputs, vocab_size): input_pos=input_pos, ) - assert out.shape == (batch_size, seq_len, vocab_size) - assert_expected(out.mean(), torch.tensor(9.0072), atol=1e-3, rtol=1e-3) + assert out.shape == (batch_size, seq_len, vocab_size + 3) + assert_expected(out.mean(), torch.tensor(0.200152), atol=1e-3, rtol=1e-3) def test_setup_cache(self, fused_model): """ @@ -378,3 +429,9 @@ def test_mismatched_encoder_input(self, fused_model, inputs): tokens, encoder_input={"encoder": {"input": torch.tensor([1])}}, ) + + def test_state_dict_hooks(self, fused_model, state_dict): + fused_model.load_state_dict(state_dict) + actual = fused_model.state_dict() + expected = state_dict + assert_expected(actual, expected) diff --git a/torchtune/modules/model_fusion/_fusion.py b/torchtune/modules/model_fusion/_fusion.py index 571bbc508a..e32db97928 100644 --- a/torchtune/modules/model_fusion/_fusion.py +++ b/torchtune/modules/model_fusion/_fusion.py @@ -547,6 +547,11 @@ class EarlyFusionModel(nn.Module): You can pass in multiple encoders as a dictionary into ``encoders``. + Note: Once the decoder is wrapped in this module, the decoder's ``tok_embeddings`` module is moved + to the parent EarlyFusionModel's ``tok_embeddings``. You should not forward pass the decoder individually. + Instead, use EarlyFusionModel's forward pass with ``encoder_input=None`` to get decoder-only outputs. + State dicts will automatically be updated on save and load to account for this change. + Example: >>> # decoder is a text-only TransformerDecoder (e.g. llama3_8b) with no modifications >>> decoder = llama3_8b() @@ -557,18 +562,17 @@ class EarlyFusionModel(nn.Module): >>> encoders = {"image": nn.Sequential(clip_vit_224(), projection_head)} >>> >>> # EarlyFusionModel combines the encoder and decoder - >>> model = DeepFusionModel(decoder, encoders) + >>> model = EarlyFusionModel(decoder, encoders, encoder_tokens={"image": 128256}) >>> >>> # Load full fused checkpoints >>> model.load_state_dict(...) 
>>> - >>> # Or load pretrained individual models (fusion_params are not loaded) - >>> model.encoder.load_state_dict(..., strict=False) - >>> model.decoder.load_state_dict(..., strict=False) - >>> >>> # Forward pass >>> encoder_input = {"image": {...}} - >>> output = model(tokens, mask, encoder_input, encoder_mask, input_pos) + >>> output = model(tokens, mask=mask, encoder_input=encoder_input, encoder_mask=encoder_mask, input_pos=input_pos) + >>> + >>> # Forward pass decoder only + >>> output = model(tokens, mask=mask, input_pos=input_pos) Args: decoder (TransformerDecoder): decoder module @@ -651,18 +655,18 @@ def _state_dict_hook(module, state_dict, *args, **kwargs): [!Note] This update changes the order of the OrderedDict """ - key = "tok_embeddings" - decoder_key = "decoder.tok_embeddings" - state_dict[decoder_key] = state_dict[key] - del state_dict[key] + for n, p in module.tok_embeddings.named_parameters(): + state_dict[f"decoder.tok_embeddings.{n}"] = p + del state_dict[f"tok_embeddings.{n}"] @staticmethod def _load_state_dict_hook(module, state_dict, *args, **kwargs): """Undo the change from _state_dict_hook""" - key = "tok_embeddings" - decoder_key = "decoder.tok_embeddings" - state_dict[key] = state_dict[decoder_key] - del state_dict[decoder_key] + old_keys = list(state_dict.keys()) + for key in old_keys: + if key.startswith("decoder.tok_embeddings"): + state_dict[key[len("decoder.") :]] = state_dict[key] + del state_dict[key] def set_num_output_chunks(self, num_output_chunks: int) -> None: """Used to save memory in combination with :class:`~torchtune.modules.loss.CEWithChunkedOutputLoss`. @@ -754,15 +758,16 @@ def forward( ) embeds = self.tok_embeddings(tokens) - bsz, seq_len, embed_dim = embeds.shape + bsz, _, embed_dim = embeds.shape for encoder, inp in (encoder_input or {}).items(): encoder_embeds = self.encoders[encoder](**inp) encoder_mask = ( - (tokens == self.encoder_tokens[encoder]) - .unsqueeze(-1) - .expand(bsz, seq_len, embed_dim) + torch.where(tokens == self.encoder_tokens[encoder])[1] + .view(bsz, -1, 1) + .expand(bsz, -1, embed_dim) # shape: [bsz, num_values, embed_dim] ) - embeds = torch.where(encoder_mask, encoder_embeds, embeds) + # At locations where encoder token is found, replace with encoder embedding + embeds.scatter_(1, encoder_mask, encoder_embeds) output = self.decoder(embeds, mask=mask, input_pos=input_pos) return output From 80b0c83fb11103d29410d5713b94a5276dae0667 Mon Sep 17 00:00:00 2001 From: RdoubleA Date: Sun, 3 Nov 2024 14:36:17 -0800 Subject: [PATCH 08/11] fix lint --- torchtune/modules/peft/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtune/modules/peft/_utils.py b/torchtune/modules/peft/_utils.py index 8e360eaf5f..318ab4136a 100644 --- a/torchtune/modules/peft/_utils.py +++ b/torchtune/modules/peft/_utils.py @@ -70,7 +70,7 @@ def set_trainable_params( Args: model (nn.Module): Instance of model class containing some adapter params. - adapter_params (Dict[str, Any]): State dict mapping adapter key names to their + adapter_params (Union[Dict[str, Any], Set]): State dict mapping adapter key names to their respective nn.Parameters (i.e. outputs of :func:`~torchtune.modules.peft.get_adapter_params`.) 
Returns: From 7fc42bba77f6e005c4ab8f84cf49640e3b9fe09c Mon Sep 17 00:00:00 2001 From: RdoubleA Date: Sun, 3 Nov 2024 14:42:24 -0800 Subject: [PATCH 09/11] update DeepFusion callsites --- tests/torchtune/modules/test_common_utils.py | 4 ++-- .../models/llama3_2_vision/_model_builders.py | 16 ++++++++-------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/torchtune/modules/test_common_utils.py b/tests/torchtune/modules/test_common_utils.py index 41dc472f00..a83da5c743 100644 --- a/tests/torchtune/modules/test_common_utils.py +++ b/tests/torchtune/modules/test_common_utils.py @@ -45,9 +45,9 @@ def llama_vision_model(): fixed_init_model(vision_encoder, min_val=-1, max_val=1) fixed_init_model(vision_decoder, min_val=-1, max_val=1) model = DeepFusionModel( - encoder=vision_encoder, + encoders={"image": vision_encoder}, decoder=vision_decoder, - encoder_trainable=False, + encoders_trainable=False, decoder_trainable=False, fusion_trainable=False, ) diff --git a/torchtune/models/llama3_2_vision/_model_builders.py b/torchtune/models/llama3_2_vision/_model_builders.py index ed83760004..6b83721960 100644 --- a/torchtune/models/llama3_2_vision/_model_builders.py +++ b/torchtune/models/llama3_2_vision/_model_builders.py @@ -115,9 +115,9 @@ def llama3_2_vision_11b( intermediate_dim=14336, ) return DeepFusionModel( - encoder=encoder, + encoders={"image": encoder}, decoder=decoder, - encoder_trainable=encoder_trainable, + encoders_trainable=encoder_trainable, decoder_trainable=decoder_trainable, fusion_trainable=fusion_trainable, ) @@ -218,9 +218,9 @@ def lora_llama3_2_vision_11b( quantize_base=quantize_base, ) return DeepFusionModel( - encoder=encoder, + encoders={"image": encoder}, decoder=decoder, - encoder_trainable=encoder_type != LoRATrainable.FROZEN, + encoders_trainable=encoder_type != LoRATrainable.FROZEN, decoder_trainable=decoder_type != LoRATrainable.FROZEN, fusion_trainable=fusion_type != LoRATrainable.FROZEN, ) @@ -270,9 +270,9 @@ def llama3_2_vision_90b( intermediate_dim=28672, ) return DeepFusionModel( - encoder=encoder, + encoders={"image": encoder}, decoder=decoder, - encoder_trainable=encoder_trainable, + encoders_trainable=encoder_trainable, decoder_trainable=decoder_trainable, fusion_trainable=fusion_trainable, ) @@ -373,9 +373,9 @@ def lora_llama3_2_vision_90b( quantize_base=quantize_base, ) return DeepFusionModel( - encoder=encoder, + encoders={"image": encoder}, decoder=decoder, - encoder_trainable=encoder_type != LoRATrainable.FROZEN, + encoders_trainable=encoder_type != LoRATrainable.FROZEN, decoder_trainable=decoder_type != LoRATrainable.FROZEN, fusion_trainable=fusion_type != LoRATrainable.FROZEN, ) From 67fbc721d0f182ab814e859b9add41b0db7680e6 Mon Sep 17 00:00:00 2001 From: RdoubleA Date: Mon, 4 Nov 2024 12:29:29 -0800 Subject: [PATCH 10/11] revamp forward to support uneven encoder tokens in batch --- .../model_fusion/test_fusion_models.py | 92 ++++++++++++++++--- torchtune/modules/model_fusion/_fusion.py | 48 +++++++--- 2 files changed, 112 insertions(+), 28 deletions(-) diff --git a/tests/torchtune/modules/model_fusion/test_fusion_models.py b/tests/torchtune/modules/model_fusion/test_fusion_models.py index d1d7b89533..6e8b287169 100644 --- a/tests/torchtune/modules/model_fusion/test_fusion_models.py +++ b/tests/torchtune/modules/model_fusion/test_fusion_models.py @@ -235,9 +235,17 @@ def vocab_size(self) -> int: def dim(self) -> int: return 64 + @pytest.fixture + def batch_size(self) -> int: + return 2 + + @pytest.fixture + def seq_len(self) -> int: + 
return 10 + @pytest.fixture def decoder(self, dim, vocab_size) -> nn.Module: - decoder = DummyModel(dim, vocab_size + 3) + decoder = DummyModel(dim, vocab_size) fixed_init_model(decoder, max_val=0.1) return decoder @@ -253,6 +261,7 @@ def fused_model(self, vocab_size, dim, decoder) -> EarlyFusionModel: model = EarlyFusionModel( encoders={"red": red, "green": green, "blue": blue}, decoder=decoder, + # These are IDs that are out of vocab in the decoder encoder_tokens={ "red": vocab_size, "green": vocab_size + 1, @@ -265,9 +274,7 @@ def fused_model(self, vocab_size, dim, decoder) -> EarlyFusionModel: return model @pytest.fixture - def inputs(self, vocab_size): - batch_size = 2 - seq_len = 10 + def inputs(self, batch_size, seq_len, vocab_size): tokens = torch.randint(0, vocab_size, (batch_size, seq_len)) red_seq_len, green_seq_len, blue_seq_len = 1, 2, 3 tokens[:, 0] = vocab_size @@ -294,9 +301,9 @@ def state_dict(self, dim, vocab_size): "decoder.k.bias": torch.randn((dim,)), "decoder.v.weight": torch.randn((dim, dim)), "decoder.v.bias": torch.randn((dim,)), - "decoder.output.weight": torch.randn((vocab_size + 3, dim)), - "decoder.output.bias": torch.randn((vocab_size + 3,)), - "decoder.tok_embeddings.weight": torch.randn((vocab_size + 3, dim)), + "decoder.output.weight": torch.randn((vocab_size, dim)), + "decoder.output.bias": torch.randn((vocab_size,)), + "decoder.tok_embeddings.weight": torch.randn((vocab_size, dim)), "encoders.red.weight": torch.randn((vocab_size, dim)), "encoders.green.weight": torch.randn((vocab_size, dim)), "encoders.blue.weight": torch.randn((vocab_size, dim)), @@ -316,8 +323,8 @@ def test_forward(self, fused_model, inputs, vocab_size): encoder_input=encoder_input, ) - assert out.shape == (batch_size, seq_len, vocab_size + 3) - assert_expected(out.mean(), torch.tensor(1.1828), atol=1e-3, rtol=1e-3) + assert out.shape == (batch_size, seq_len, vocab_size) + assert_expected(out.mean(), torch.tensor(0.5647), atol=1e-3, rtol=1e-3) @torch.no_grad() def test_forward_no_decoder(self, fused_model, inputs, dim): @@ -327,6 +334,7 @@ def test_forward_no_decoder(self, fused_model, inputs, dim): tokens, encoder_input, *_ = inputs batch_size, seq_len = tokens.shape + # No-op for the decoder class DummyModule(nn.Module): def forward(self, x, **kwargs): return x @@ -348,16 +356,72 @@ def forward(self, x, **kwargs): assert_expected(out[:, 7:, :], blue, atol=1e-3, rtol=1e-3) @torch.no_grad() - def test_forward_no_encoder(self, fused_model, inputs): + def test_forward_no_encoder(self, fused_model, batch_size, seq_len, vocab_size): """ - Test that the forward pass of the EarlyFusionModel with no encoder input. + Test the forward pass of the EarlyFusionModel with no encoder input. """ - tokens, *_ = inputs + tokens = torch.randint(0, vocab_size, (batch_size, seq_len)) + actual = fused_model(tokens) expected = fused_model.decoder(fused_model.tok_embeddings(tokens)) assert_expected(actual, expected, atol=1e-3, rtol=1e-3) + @torch.no_grad() + def test_forward_no_decoder_uneven_encoder_tokens( + self, fused_model, dim, batch_size, seq_len, vocab_size + ): + """ + If each sample has a different number of encoder tokens in the sequence, test that mask scatter + of embeds still works as expected: + + This is a dog. + My dog is better than yours. + """ + red_seq_len, green_seq_len, blue_seq_len = 1, 2, 3 + # In a real encoder input, it would be padded to max number of media in the batch, so we don't + # make these test inputs uneven. 
The forward pass should still be able to take the number of embeddings + # it needs and ignore the rest, which would be pad embeddings. + encoder_input = { + "red": {"input": torch.randint(0, vocab_size, (batch_size, red_seq_len))}, + "green": { + "input": torch.randint(0, vocab_size, (batch_size, green_seq_len)) + }, + "blue": {"input": torch.randint(0, vocab_size, (batch_size, blue_seq_len))}, + } + tokens = torch.randint(0, vocab_size, (batch_size, seq_len)) + # For red encoder, only the first sample has a token + tokens[0, 0] = vocab_size + # For green encoder, first sample has 2 tokens, second sample has 1 token + tokens[0, 3:5] = vocab_size + 1 + tokens[1, 4] = vocab_size + 1 + # For blue encoder, first sample has 3 tokens, second sample has 2 tokens + tokens[0, 7:] = vocab_size + 2 + tokens[1, 8:] = vocab_size + 2 + + # No-op for the decoder + class DummyModule(nn.Module): + def forward(self, x, **kwargs): + return x + + fused_model.decoder = DummyModule() + + out = fused_model( + tokens, + encoder_input=encoder_input, + ) + + assert out.shape == (batch_size, seq_len, dim) + # Check that each encoder output is placed correctly in the fused output + red = fused_model.encoders["red"](**encoder_input["red"]) + assert_expected(out[0, 0, :], red[0, 0, :], atol=1e-3, rtol=1e-3) + green = fused_model.encoders["green"](**encoder_input["green"]) + assert_expected(out[0, 3:5, :], green[0, :, :], atol=1e-3, rtol=1e-3) + assert_expected(out[1, 4, :], green[1, 0, :], atol=1e-3, rtol=1e-3) + blue = fused_model.encoders["blue"](**encoder_input["blue"]) + assert_expected(out[0, 7:, :], blue[0, :, :], atol=1e-3, rtol=1e-3) + assert_expected(out[1, 8:, :], blue[1, :2, :], atol=1e-3, rtol=1e-3) + @torch.no_grad() def test_decoder_forward(self, fused_model, inputs, vocab_size): """ @@ -374,8 +438,8 @@ def test_decoder_forward(self, fused_model, inputs, vocab_size): input_pos=input_pos, ) - assert out.shape == (batch_size, seq_len, vocab_size + 3) - assert_expected(out.mean(), torch.tensor(0.200152), atol=1e-3, rtol=1e-3) + assert out.shape == (batch_size, seq_len, vocab_size) + assert_expected(out.mean(), torch.tensor(0.2383), atol=1e-3, rtol=1e-3) def test_setup_cache(self, fused_model): """ diff --git a/torchtune/modules/model_fusion/_fusion.py b/torchtune/modules/model_fusion/_fusion.py index e32db97928..a4056329e8 100644 --- a/torchtune/modules/model_fusion/_fusion.py +++ b/torchtune/modules/model_fusion/_fusion.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch from torch import nn @@ -273,9 +273,9 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: # num_fusion_tokens = (input >= vocab_size).sum() fusion_tokens = torch.masked_select(input, ~mask) - vocab_size - # [batch_size x num_tokens x embed_dim] + # [batch_size * num_tokens, embed_dim] embeds = self.embedding(tokens) - # [batch_size x num_fusion_tokens x embed_dim] + # [batch_size * num_fusion_tokens, embed_dim] fusion_embeds = self.fusion_embedding(fusion_tokens) # [batch_size x seq_length x embed_dim] @@ -536,8 +536,8 @@ class EarlyFusionModel(nn.Module): """EarlyFusion is a type of fused model architecture where pretrained encoder(s) are combined with a pretrained decoder (LLM) at the model input and not in internal layers. 
This is a popular architecture for multimodal models, with a full overview available in `The Evolution of Multimodal Model Architectures - `_. This module assumes the decoder is trained to recognize tokens specific - to each encoder, which are then replaced in the sequence with the encoder embedding outputs. + `_. This module works both for decoders in which the encoder tokens are + inside the vocab and outside the vocab. This module has the same methods and forward signature as :class:`~torchtune.modules.TransformerDecoder` and can be used interchangeably where :class:`~torchtune.modules.TransformerDecoder` is. It combines the encoders with the decoder as a @@ -702,6 +702,16 @@ def reset_caches(self): """Reset the key value caches.""" self.decoder.reset_caches() + def _decoder_embed(self, tokens) -> Tuple[torch.Tensor, torch.Tensor]: + """Embed the text-only tokens with the decoder's tok_embeddings""" + encoder_token_ids = torch.tensor(list(self.encoder_tokens.values())) + # [bsz, seq_len], True indicates the token is not an encoder special token + is_text = ~torch.isin(tokens, encoder_token_ids) + text_tokens = torch.masked_select(tokens, is_text) + # [num_text, embed_dim] + text_embeds = self.tok_embeddings(text_tokens) + return is_text, text_embeds + def forward( self, tokens: torch.Tensor, @@ -757,17 +767,27 @@ def forward( f"Got {encoder_input.keys()}, expected {self.encoders.keys()}." ) - embeds = self.tok_embeddings(tokens) - bsz, _, embed_dim = embeds.shape + bsz, seq_len = tokens.shape + # is_text: [bsz, seq_len], text_embeds: [num_text, embed_dim] + is_text, text_embeds = self._decoder_embed(tokens) + embed_dim = text_embeds.shape[-1] + + # Holds the final embedding vector + fused_embeds = torch.empty( + bsz, seq_len, embed_dim, dtype=text_embeds.dtype, device=text_embeds.device + ) + # Place the text-only embeddings + fused_embeds = fused_embeds.masked_scatter(is_text.unsqueeze(-1), text_embeds) + for encoder, inp in (encoder_input or {}).items(): + # [bsz, num_encoder_tokens, embed_dim] encoder_embeds = self.encoders[encoder](**inp) - encoder_mask = ( - torch.where(tokens == self.encoder_tokens[encoder])[1] - .view(bsz, -1, 1) - .expand(bsz, -1, embed_dim) # shape: [bsz, num_values, embed_dim] - ) + # [bsz * num_encoder_tokens, embed_dim] + encoder_embeds = encoder_embeds.view(-1, embed_dim) + # [bsz, seq_len, 1] + encoder_mask = (tokens == self.encoder_tokens[encoder]).unsqueeze(-1) # At locations where encoder token is found, replace with encoder embedding - embeds.scatter_(1, encoder_mask, encoder_embeds) + fused_embeds = fused_embeds.masked_scatter(encoder_mask, encoder_embeds) - output = self.decoder(embeds, mask=mask, input_pos=input_pos) + output = self.decoder(fused_embeds, mask=mask, input_pos=input_pos) return output From d0b1ab07486de8ae68b510610e2cd8f77be1a930 Mon Sep 17 00:00:00 2001 From: RdoubleA Date: Fri, 8 Nov 2024 15:55:26 -0800 Subject: [PATCH 11/11] address comments --- .../modules/model_fusion/test_deep_fusion.py | 195 +++++ ..._fusion_models.py => test_early_fusion.py} | 167 +--- .../modules/model_fusion/test_fusion_embed.py | 90 -- ..._fusion_layer.py => test_fusion_layers.py} | 75 +- tests/torchtune/modules/test_common_utils.py | 4 +- .../models/llama3_2_vision/_model_builders.py | 16 +- torchtune/modules/model_fusion/__init__.py | 4 +- .../modules/model_fusion/_deep_fusion.py | 212 +++++ .../modules/model_fusion/_early_fusion.py | 278 ++++++ torchtune/modules/model_fusion/_fusion.py | 793 ------------------ .../modules/model_fusion/_fusion_layers.py | 
283 +++++++ .../modules/model_fusion/_fusion_utils.py | 2 +- 12 files changed, 1057 insertions(+), 1062 deletions(-) create mode 100644 tests/torchtune/modules/model_fusion/test_deep_fusion.py rename tests/torchtune/modules/model_fusion/{test_fusion_models.py => test_early_fusion.py} (68%) delete mode 100644 tests/torchtune/modules/model_fusion/test_fusion_embed.py rename tests/torchtune/modules/model_fusion/{test_fusion_layer.py => test_fusion_layers.py} (67%) create mode 100644 torchtune/modules/model_fusion/_deep_fusion.py create mode 100644 torchtune/modules/model_fusion/_early_fusion.py delete mode 100644 torchtune/modules/model_fusion/_fusion.py create mode 100644 torchtune/modules/model_fusion/_fusion_layers.py diff --git a/tests/torchtune/modules/model_fusion/test_deep_fusion.py b/tests/torchtune/modules/model_fusion/test_deep_fusion.py new file mode 100644 index 0000000000..79b2f9ab3d --- /dev/null +++ b/tests/torchtune/modules/model_fusion/test_deep_fusion.py @@ -0,0 +1,195 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest + +import torch +from tests.test_utils import assert_expected, fixed_init_model +from torch import nn +from torchtune.modules.model_fusion import DeepFusionModel, register_fusion_module +from torchtune.training.seed import set_seed + + +@pytest.fixture(autouse=True) +def random(): + set_seed(1) + + +class DummyModel(nn.Module): + def __init__(self, dim, vocab_size): + super().__init__() + self.cache_enabled = False + self.tok_embeddings = nn.Embedding(vocab_size, dim) + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.output = nn.Linear(dim, vocab_size) + register_fusion_module(self.output) + + def setup_caches(self, batch_size, dtype, *args, **kwargs): + self.cache_enabled = True + + def caches_are_setup(self): + return self.cache_enabled + + def reset_caches(self): + self.cache_enabled = False + + def forward( + self, + tokens, + *, + mask=None, + encoder_input=None, + encoder_mask=None, + input_pos=None, + ): + x = self.tok_embeddings(tokens) + if encoder_input is not None: + q = self.q(x) + k = self.k(encoder_input) if encoder_input is not None else self.k(x) + v = self.v(encoder_input) if encoder_input is not None else self.v(x) + x += nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=encoder_mask if encoder_mask is not None else mask + ) + x = self.output(x) + return x + + +class TestDeepFusionModel: + """ + Class for testing our DeepFusionModel wrapper. 
+ """ + + @pytest.fixture + def vocab_size(self) -> int: + return 100 + + @pytest.fixture + def dim(self) -> int: + return 64 + + @pytest.fixture + def encoder(self, dim, vocab_size) -> nn.Module: + encoder = nn.Embedding(vocab_size, dim) + fixed_init_model(encoder) + return encoder + + @pytest.fixture + def decoder(self, dim, vocab_size) -> nn.Module: + decoder = DummyModel(dim, vocab_size) + fixed_init_model(decoder, max_val=0.1) + return decoder + + @pytest.fixture + def fused_model(self, encoder, decoder) -> DeepFusionModel: + model = DeepFusionModel( + encoder=encoder, + decoder=decoder, + ) + return model + + @pytest.fixture + def inputs(self, vocab_size): + batch_size = 2 + seq_len = 10 + tokens = torch.randint(0, vocab_size, (batch_size, seq_len)) + encoder_input = {"input": torch.randint(0, vocab_size, (batch_size, seq_len))} + encoder_mask = torch.randint(0, 2, (batch_size, seq_len, seq_len)).bool() + input_pos = torch.Tensor([1]).int() + return tokens, encoder_input, encoder_mask, input_pos + + @torch.no_grad() + def test_forward(self, fused_model, inputs, vocab_size): + """ + Test that the forward pass of the DeepFusionModel works as expected. + """ + tokens, encoder_input, encoder_mask, _ = inputs + batch_size, seq_len = tokens.shape + out = fused_model( + tokens, encoder_input=encoder_input, encoder_mask=encoder_mask + ) + + assert out.shape == (batch_size, seq_len, vocab_size) + assert_expected(out.mean(), torch.tensor(8.5584), atol=1e-3, rtol=1e-3) + + @torch.no_grad() + def test_forward_no_encoding(self, fused_model, inputs, vocab_size): + """ + Test that the forward pass of the DeepFusionModel with no encoder input. + """ + tokens, *_ = inputs + batch_size, seq_len = tokens.shape + out = fused_model(tokens) + + assert out.shape == (batch_size, seq_len, vocab_size) + assert_expected(out.mean(), torch.tensor(0.2271), atol=1e-3, rtol=1e-3) + + @torch.no_grad() + def test_decoding_forward(self, fused_model, inputs, vocab_size): + """ + Test that the forward pass of the DeepFusionModel works during decoding. + """ + tokens, encoder_input, encoder_mask, input_pos = inputs + tokens = tokens[:, input_pos] + encoder_mask = encoder_mask[:, input_pos] + batch_size, seq_len = tokens.shape + out = fused_model( + tokens, + encoder_input=encoder_input, + encoder_mask=encoder_mask, + input_pos=input_pos, + ) + + assert out.shape == (batch_size, seq_len, vocab_size) + assert_expected(out.mean(), torch.tensor(9.0072), atol=1e-3, rtol=1e-3) + + def test_setup_cache(self, fused_model): + """ + Test that the cache methods works as expected. + """ + fused_model.setup_caches(2, torch.float32) + assert fused_model.caches_are_setup() + fused_model.reset_caches() + assert not fused_model.caches_are_setup() + + def test_set_trainable_params(self, fused_model, encoder, decoder): + """ + Test that the trainable parameters are set correctly. 
+ """ + # Test default case + trainable_params = { + n for n, p in fused_model.named_parameters() if p.requires_grad + } + assert trainable_params == {"decoder.output.weight", "decoder.output.bias"} + + # Test encoder only + model = DeepFusionModel( + encoder=encoder, + decoder=decoder, + encoder_trainable=True, + fusion_trainable=False, + ) + trainable_params = {n for n, p in model.named_parameters() if p.requires_grad} + assert trainable_params == {"encoder.weight"} + + # Test decoder only, and confirm fusion layers are removed independently + model = DeepFusionModel( + encoder=encoder, + decoder=decoder, + decoder_trainable=True, + fusion_trainable=False, + ) + trainable_params = {n for n, p in model.named_parameters() if p.requires_grad} + assert trainable_params == { + "decoder.q.weight", + "decoder.q.bias", + "decoder.k.weight", + "decoder.k.bias", + "decoder.v.weight", + "decoder.v.bias", + "decoder.tok_embeddings.weight", + } diff --git a/tests/torchtune/modules/model_fusion/test_fusion_models.py b/tests/torchtune/modules/model_fusion/test_early_fusion.py similarity index 68% rename from tests/torchtune/modules/model_fusion/test_fusion_models.py rename to tests/torchtune/modules/model_fusion/test_early_fusion.py index 6e8b287169..d7ff407289 100644 --- a/tests/torchtune/modules/model_fusion/test_fusion_models.py +++ b/tests/torchtune/modules/model_fusion/test_early_fusion.py @@ -11,11 +11,7 @@ import torch from tests.test_utils import assert_expected, fixed_init_model from torch import nn -from torchtune.modules.model_fusion import ( - DeepFusionModel, - EarlyFusionModel, - register_fusion_module, -) +from torchtune.modules.model_fusion import EarlyFusionModel, register_fusion_module from torchtune.training.seed import set_seed @@ -65,167 +61,6 @@ def forward( return x -class TestDeepFusionModel: - """ - Class for testing our DeepFusionModel wrapper. - """ - - @pytest.fixture - def vocab_size(self) -> int: - return 100 - - @pytest.fixture - def dim(self) -> int: - return 64 - - @pytest.fixture - def encoder(self, dim, vocab_size) -> nn.Module: - encoder = nn.Embedding(vocab_size, dim) - fixed_init_model(encoder) - return encoder - - @pytest.fixture - def decoder(self, dim, vocab_size) -> nn.Module: - decoder = DummyModel(dim, vocab_size) - fixed_init_model(decoder, max_val=0.1) - return decoder - - @pytest.fixture - def fused_model(self, encoder, decoder) -> DeepFusionModel: - model = DeepFusionModel( - encoders={"encoder": encoder}, - decoder=decoder, - ) - return model - - @pytest.fixture - def inputs(self, vocab_size): - batch_size = 2 - seq_len = 10 - tokens = torch.randint(0, vocab_size, (batch_size, seq_len)) - encoder_input = { - "encoder": {"input": torch.randint(0, vocab_size, (batch_size, seq_len))} - } - encoder_mask = torch.randint(0, 2, (batch_size, seq_len, seq_len)).bool() - input_pos = torch.Tensor([1]).int() - return tokens, encoder_input, encoder_mask, input_pos - - @torch.no_grad() - def test_forward(self, fused_model, inputs, vocab_size): - """ - Test that the forward pass of the DeepFusionModel works as expected. 
- """ - tokens, encoder_input, encoder_mask, _ = inputs - batch_size, seq_len = tokens.shape - out = fused_model( - tokens, encoder_input=encoder_input, encoder_mask=encoder_mask - ) - - assert out.shape == (batch_size, seq_len, vocab_size) - assert_expected(out.mean(), torch.tensor(8.5584), atol=1e-3, rtol=1e-3) - - @torch.no_grad() - def test_forward_no_encoding(self, fused_model, inputs, vocab_size): - """ - Test that the forward pass of the DeepFusionModel with no encoder input. - """ - tokens, *_ = inputs - batch_size, seq_len = tokens.shape - out = fused_model(tokens) - - assert out.shape == (batch_size, seq_len, vocab_size) - assert_expected(out.mean(), torch.tensor(0.2271), atol=1e-3, rtol=1e-3) - - @torch.no_grad() - def test_decoding_forward(self, fused_model, inputs, vocab_size): - """ - Test that the forward pass of the DeepFusionModel works during decoding. - """ - tokens, encoder_input, encoder_mask, input_pos = inputs - tokens = tokens[:, input_pos] - encoder_mask = encoder_mask[:, input_pos] - batch_size, seq_len = tokens.shape - out = fused_model( - tokens, - encoder_input=encoder_input, - encoder_mask=encoder_mask, - input_pos=input_pos, - ) - - assert out.shape == (batch_size, seq_len, vocab_size) - assert_expected(out.mean(), torch.tensor(9.0072), atol=1e-3, rtol=1e-3) - - def test_setup_cache(self, fused_model): - """ - Test that the cache methods works as expected. - """ - fused_model.setup_caches(2, torch.float32) - assert fused_model.caches_are_setup() - fused_model.reset_caches() - assert not fused_model.caches_are_setup() - - def test_set_trainable_params(self, fused_model, encoder, decoder): - """ - Test that the trainable parameters are set correctly. - """ - # Test default case - trainable_params = { - n for n, p in fused_model.named_parameters() if p.requires_grad - } - assert trainable_params == {"decoder.output.weight", "decoder.output.bias"} - - # Test encoder only - model = DeepFusionModel( - encoders={"encoder": encoder}, - decoder=decoder, - encoders_trainable=True, - fusion_trainable=False, - ) - trainable_params = {n for n, p in model.named_parameters() if p.requires_grad} - assert trainable_params == {"encoders.encoder.weight"} - - # Test decoder only, and confirm fusion layers are removed independently - model = DeepFusionModel( - encoders={"encoder": encoder}, - decoder=decoder, - decoder_trainable=True, - fusion_trainable=False, - ) - trainable_params = {n for n, p in model.named_parameters() if p.requires_grad} - assert trainable_params == { - "decoder.q.weight", - "decoder.q.bias", - "decoder.k.weight", - "decoder.k.bias", - "decoder.v.weight", - "decoder.v.bias", - "decoder.tok_embeddings.weight", - } - - def test_incorrect_number_of_encoders(self, decoder): - with pytest.raises(ValueError): - _ = DeepFusionModel( - encoders={"encoder": nn.Identity(), "encoder2": nn.Identity()}, - decoder=decoder, - ) - - def test_mismatched_encoder_keys(self, decoder): - with pytest.raises(ValueError): - _ = DeepFusionModel( - encoders={"encoder": nn.Identity()}, - decoder=decoder, - encoders_trainable={"encoder2": True}, - ) - - def test_mismatched_encoder_input(self, fused_model, inputs): - tokens, _, _, _ = inputs - with pytest.raises(ValueError): - _ = fused_model( - tokens, - encoder_input={"encoder2": {"input": torch.tensor([1])}}, - ) - - class TestEarlyFusionModel: @pytest.fixture def vocab_size(self) -> int: diff --git a/tests/torchtune/modules/model_fusion/test_fusion_embed.py b/tests/torchtune/modules/model_fusion/test_fusion_embed.py deleted file 
mode 100644 index 35ef5c0e87..0000000000 --- a/tests/torchtune/modules/model_fusion/test_fusion_embed.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import pytest - -import torch -from tests.test_utils import assert_expected, fixed_init_model -from torchtune.modules.model_fusion import FusionEmbedding -from torchtune.training.seed import set_seed - - -@pytest.fixture(autouse=True) -def random(): - set_seed(1) - - -class TestFusionEmbedding: - """ - Class for testing our FusionEmbedding. - """ - - @pytest.fixture - def dim(self) -> int: - return 2 - - @pytest.fixture - def vocab_size(self) -> int: - return 10 - - @pytest.fixture - def fusion_vocab_size(self) -> int: - return 5 - - @pytest.fixture - def embed(self, dim, vocab_size, fusion_vocab_size) -> FusionEmbedding: - embeds = FusionEmbedding( - vocab_size=vocab_size, fusion_vocab_size=fusion_vocab_size, embed_dim=dim - ) - fixed_init_model(embeds.embedding, min_val=0, max_val=0.5) - fixed_init_model(embeds.fusion_embedding, min_val=0.51, max_val=1) - return embeds - - @torch.no_grad() - def test_forward(self, embed, vocab_size, fusion_vocab_size, dim): - """ - Test that the forward pass of the FusionEmbedding works as expected. - """ - tokens = torch.randint(0, vocab_size + fusion_vocab_size, (2, 10)) - out = embed(tokens) - - assert out.shape == (2, 10, dim) - assert_expected(out.mean(), torch.tensor(0.3409), atol=1e-3, rtol=1e-3) - - # Only new tokens, embeddings should be > 0.5 - tokens = torch.randint(vocab_size, vocab_size + fusion_vocab_size, (2, 10)) - out = embed(tokens) - - assert out.min() > 0.5 - - # Only old tokens, embeddings should be < 0.5 - tokens = torch.randint(0, vocab_size, (2, 10)) - out = embed(tokens) - - assert out.max() < 0.5 - - def test_fusion_params(self, embed): - """ - Test that the currect fusion params are returned. - """ - fusion_params = set(embed.fusion_params()) - - assert fusion_params == {"fusion_embedding.weight"} - - def test_get_and_load_state_dict(self, embed): - """ - Test that the state dict hooks work in removing the "layer" variable - """ - state_dict = embed.state_dict() - state_keys = set(state_dict.keys()) - - assert state_keys == { - "weight", - "fusion_embedding.weight", - } - - # Check that the state_dict can be loaded back into the model - embed.load_state_dict(state_dict) diff --git a/tests/torchtune/modules/model_fusion/test_fusion_layer.py b/tests/torchtune/modules/model_fusion/test_fusion_layers.py similarity index 67% rename from tests/torchtune/modules/model_fusion/test_fusion_layer.py rename to tests/torchtune/modules/model_fusion/test_fusion_layers.py index a2fc0715eb..da8fdb4b1f 100644 --- a/tests/torchtune/modules/model_fusion/test_fusion_layer.py +++ b/tests/torchtune/modules/model_fusion/test_fusion_layers.py @@ -9,7 +9,7 @@ import torch from tests.test_utils import assert_expected, fixed_init_model from torch import nn -from torchtune.modules.model_fusion import FusionLayer +from torchtune.modules.model_fusion import FusionEmbedding, FusionLayer from torchtune.training.seed import set_seed @@ -60,6 +60,79 @@ def forward(self, x): return self.linear(x) +class TestFusionEmbedding: + """ + Class for testing our FusionEmbedding. 
+ """ + + @pytest.fixture + def dim(self) -> int: + return 2 + + @pytest.fixture + def vocab_size(self) -> int: + return 10 + + @pytest.fixture + def fusion_vocab_size(self) -> int: + return 5 + + @pytest.fixture + def embed(self, dim, vocab_size, fusion_vocab_size) -> FusionEmbedding: + embeds = FusionEmbedding( + vocab_size=vocab_size, fusion_vocab_size=fusion_vocab_size, embed_dim=dim + ) + fixed_init_model(embeds.embedding, min_val=0, max_val=0.5) + fixed_init_model(embeds.fusion_embedding, min_val=0.51, max_val=1) + return embeds + + @torch.no_grad() + def test_forward(self, embed, vocab_size, fusion_vocab_size, dim): + """ + Test that the forward pass of the FusionEmbedding works as expected. + """ + tokens = torch.randint(0, vocab_size + fusion_vocab_size, (2, 10)) + out = embed(tokens) + + assert out.shape == (2, 10, dim) + assert_expected(out.mean(), torch.tensor(0.3409), atol=1e-3, rtol=1e-3) + + # Only new tokens, embeddings should be > 0.5 + tokens = torch.randint(vocab_size, vocab_size + fusion_vocab_size, (2, 10)) + out = embed(tokens) + + assert out.min() > 0.5 + + # Only old tokens, embeddings should be < 0.5 + tokens = torch.randint(0, vocab_size, (2, 10)) + out = embed(tokens) + + assert out.max() < 0.5 + + def test_fusion_params(self, embed): + """ + Test that the currect fusion params are returned. + """ + fusion_params = set(embed.fusion_params()) + + assert fusion_params == {"fusion_embedding.weight"} + + def test_get_and_load_state_dict(self, embed): + """ + Test that the state dict hooks work in removing the "layer" variable + """ + state_dict = embed.state_dict() + state_keys = set(state_dict.keys()) + + assert state_keys == { + "weight", + "fusion_embedding.weight", + } + + # Check that the state_dict can be loaded back into the model + embed.load_state_dict(state_dict) + + class TestFusionLayer: """ Class for testing our FusionLayer wrapper. 
diff --git a/tests/torchtune/modules/test_common_utils.py b/tests/torchtune/modules/test_common_utils.py index a83da5c743..41dc472f00 100644 --- a/tests/torchtune/modules/test_common_utils.py +++ b/tests/torchtune/modules/test_common_utils.py @@ -45,9 +45,9 @@ def llama_vision_model(): fixed_init_model(vision_encoder, min_val=-1, max_val=1) fixed_init_model(vision_decoder, min_val=-1, max_val=1) model = DeepFusionModel( - encoders={"image": vision_encoder}, + encoder=vision_encoder, decoder=vision_decoder, - encoders_trainable=False, + encoder_trainable=False, decoder_trainable=False, fusion_trainable=False, ) diff --git a/torchtune/models/llama3_2_vision/_model_builders.py b/torchtune/models/llama3_2_vision/_model_builders.py index 6b83721960..ed83760004 100644 --- a/torchtune/models/llama3_2_vision/_model_builders.py +++ b/torchtune/models/llama3_2_vision/_model_builders.py @@ -115,9 +115,9 @@ def llama3_2_vision_11b( intermediate_dim=14336, ) return DeepFusionModel( - encoders={"image": encoder}, + encoder=encoder, decoder=decoder, - encoders_trainable=encoder_trainable, + encoder_trainable=encoder_trainable, decoder_trainable=decoder_trainable, fusion_trainable=fusion_trainable, ) @@ -218,9 +218,9 @@ def lora_llama3_2_vision_11b( quantize_base=quantize_base, ) return DeepFusionModel( - encoders={"image": encoder}, + encoder=encoder, decoder=decoder, - encoders_trainable=encoder_type != LoRATrainable.FROZEN, + encoder_trainable=encoder_type != LoRATrainable.FROZEN, decoder_trainable=decoder_type != LoRATrainable.FROZEN, fusion_trainable=fusion_type != LoRATrainable.FROZEN, ) @@ -270,9 +270,9 @@ def llama3_2_vision_90b( intermediate_dim=28672, ) return DeepFusionModel( - encoders={"image": encoder}, + encoder=encoder, decoder=decoder, - encoders_trainable=encoder_trainable, + encoder_trainable=encoder_trainable, decoder_trainable=decoder_trainable, fusion_trainable=fusion_trainable, ) @@ -373,9 +373,9 @@ def lora_llama3_2_vision_90b( quantize_base=quantize_base, ) return DeepFusionModel( - encoders={"image": encoder}, + encoder=encoder, decoder=decoder, - encoders_trainable=encoder_type != LoRATrainable.FROZEN, + encoder_trainable=encoder_type != LoRATrainable.FROZEN, decoder_trainable=decoder_type != LoRATrainable.FROZEN, fusion_trainable=fusion_type != LoRATrainable.FROZEN, ) diff --git a/torchtune/modules/model_fusion/__init__.py b/torchtune/modules/model_fusion/__init__.py index be12ac16c5..21a3c1c063 100644 --- a/torchtune/modules/model_fusion/__init__.py +++ b/torchtune/modules/model_fusion/__init__.py @@ -4,7 +4,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from ._fusion import DeepFusionModel, EarlyFusionModel, FusionEmbedding, FusionLayer +from ._deep_fusion import DeepFusionModel +from ._early_fusion import EarlyFusionModel +from ._fusion_layers import FusionEmbedding, FusionLayer from ._fusion_utils import get_fusion_params, register_fusion_module __all__ = [ diff --git a/torchtune/modules/model_fusion/_deep_fusion.py b/torchtune/modules/model_fusion/_deep_fusion.py new file mode 100644 index 0000000000..6a61c43744 --- /dev/null +++ b/torchtune/modules/model_fusion/_deep_fusion.py @@ -0,0 +1,212 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Dict, List, Optional, Union + +import torch +from torch import nn +from torchtune.modules import TransformerDecoder +from torchtune.modules.model_fusion._fusion_utils import get_fusion_params +from torchtune.modules.peft._utils import set_trainable_params + + +class DeepFusionModel(nn.Module): + """DeepFusion is a type of fused model architecture where a pretrained encoder is combined + with a pretrained decoder (LLM) in the internal decoder layers. This is a popular architecture for multimodal models, with + a full overview available in `The Evolution of Multimodal Model Architectures `_. + A common deep fusion architecture is to fuse the encoder input into the decoder with interspersed cross-attention + layers. This module makes no assumptions on how the encoder and decoder are fused; it simply + passes in the encoder embeddings to the decoder and lets the decoder handle any fusion. + + This module has the same methods and forward signature as :class:`~torchtune.modules.TransformerDecoder` and can be used + interchangeably where :class:`~torchtune.modules.TransformerDecoder` is. It combines the encoder with the decoder as a + single module for checkpointing and finetuning. It is expected that the encoder and decoder + are already defined with any extra learnable ``fusion_params``: learnable parameters to help + adapt the pre-trained encoder to the pre-trained decoder. + + DeepFusionModel currently only supports a single encoder. + + Example: + >>> # decoder is a TransformerDecoder (e.g. llama3_8b) with fused cross attention layers + >>> embed = FusionEmbedding(...) + >>> layer = FusionLayer( + ... layer=TransformerSelfAttentionLayer(...), + ... fusion_layer=TransformerCrossAttentionLayer(...), + ... ) + >>> decoder = TransformerDecoder(tok_embeddings=embed, layers=layer, num_layers=32, ...) + >>> + >>> # encoder is pre-trained encoder (e.g. clip_vit_224) with an added projection head + >>> projection_head = FeedForward(...) + >>> register_fusion_module(projection_head)) + >>> encoder = nn.Sequential(clip_vit_224(), projection_head) + >>> + >>> # DeepFusionModel combines the encoder and decoder + >>> model = DeepFusionModel(decoder, encoder) + >>> + >>> # Load full fused checkpoints (e.g. a Flamingo checkpoint) + >>> model.load_state_dict(...) + >>> + >>> # Or load pretrained individual models (fusion_params are not loaded) + >>> model.encoder.load_state_dict(..., strict=False) + >>> model.decoder.load_state_dict(..., strict=False) + >>> + >>> # Forward pass + >>> output = model(tokens, mask, encoder_input, encoder_mask, input_pos) + + Args: + decoder (TransformerDecoder): decoder module + encoder (nn.Module): encoder module + decoder_trainable (bool): whether to train or freeze the decoder. Default is False. + encoder_trainable (bool): whether to train or freeze the encoder. Default is False. + fusion_trainable (bool): whether to train the fusion parameters. Default is True. 
+ + """ + + def __init__( + self, + decoder: TransformerDecoder, + encoder: nn.Module, + *, + decoder_trainable: bool = False, + encoder_trainable: bool = False, + fusion_trainable: bool = True, + ): + super().__init__() + self.decoder = decoder + self.encoder = encoder + + trainable_params = set() + if encoder_trainable: + trainable_params |= { + f"encoder.{n}" for n, p in self.encoder.named_parameters() + } + if decoder_trainable: + trainable_params |= { + f"decoder.{n}" for n, p in self.decoder.named_parameters() + } + if fusion_trainable: + trainable_params |= set(get_fusion_params(self)) + else: + trainable_params -= set(get_fusion_params(self)) + set_trainable_params(self, trainable_params) + + def set_num_output_chunks(self, num_output_chunks: int) -> None: + """Used to save memory in combination with :class:`~torchtune.modules.loss.CEWithChunkedOutputLoss`. + This should be called before the first forward pass, in the recipe.""" + self.decoder.set_num_output_chunks(num_output_chunks) + + def setup_caches( + self, + batch_size: int, + dtype: torch.dtype, + *, + encoder_max_seq_len: Optional[int] = None, + decoder_max_seq_len: Optional[int] = None, + ): + """ + Sets up key-value attention caches for inference for ``self.decoder``. + For each layer in ``self.decoder.layers``: + - :class:`torchtune.modules.TransformerSelfAttentionLayer` will use ``decoder_max_seq_len``. + - :class:`torchtune.modules.TransformerCrossAttentionLayer` will use ``encoder_max_seq_len``. + - :class:`torchtune.modules.fusion.FusionLayer` will use both ``decoder_max_seq_len`` and ``encoder_max_seq_len``. + + Args: + batch_size (int): batch size for the caches. + dtype (torch.dtype): dtype for the caches. + encoder_max_seq_len (Optional[int]): maximum encoder cache sequence length. + decoder_max_seq_len (Optional[int]): maximum decoder cache sequence length. + """ + self.decoder.setup_caches( + batch_size, + dtype, + encoder_max_seq_len=encoder_max_seq_len, + decoder_max_seq_len=decoder_max_seq_len, + ) + + def caches_are_setup(self) -> bool: + """ + Check if the key value caches are setup. This means ``setup_caches`` has been called, and + the relevant attention modules in the model have created their ``KVCache``. + """ + return self.decoder.caches_are_setup() + + def caches_are_enabled(self) -> bool: + """ + Checks if the key value caches are enabled. Once KV-caches have been setup, the relevant + attention modules will be "enabled" and all forward passes will update the caches. This behaviour + can be disabled without altering the state of the KV-caches by "disabling" the KV-caches + using :func:`~torchtune.modules.common_utils.disable_kv_cache`, upon which ``caches_are_enabled`` would return False. + """ + return self.decoder.caches_are_enabled() + + def reset_caches(self): + """ + Resets KV-cache buffers on relevant attention modules to zero, and reset cache positions to zero, + without deleting or reallocating cache tensors. + """ + self.decoder.reset_caches() + + def forward( + self, + tokens: torch.Tensor, + *, + mask: Optional[torch.Tensor] = None, + encoder_input: Optional[Dict] = None, + encoder_mask: Optional[torch.Tensor] = None, + input_pos: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """ + Args: + tokens (torch.Tensor): input tensor with shape ``[b x s]`` + mask (Optional[torch.Tensor]): Optional boolean tensor which contains the attention mask + with shape ``[b x s x s]``. This is applied after the query-key multiplication and + before the softmax. 
A value of True in row i and column j means token i attends + to token j. A value of False means token i does not attend to token j. If no + mask is specified, a causal mask is used by default. Default is None. + encoder_input (Optional[Dict]): Optional input for the encoder. + encoder_mask (Optional[torch.Tensor]): Boolean tensor defining a relational matrix between + tokens and encoder embeddings. A True value at position i,j means token i can attend + to embedding j in the decoder. Mask has shape ``[b x s x s_e]``. Default is None. + input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids + of each token. During training, this is used to indicate the positions + of each token relative to its sample when packed, shape ``[b x s]``. + During inference, this indicates the position of the current token. + If none, assume the index of the token is its position id. Default is None. + + Note: At the very first step of inference, when the model is provided with a prompt, + ``input_pos`` would contain the positions of all of the tokens in the prompt + (eg: ``torch.arange(prompt_length)``). This is because we will need to compute the + KV values for each position. + + Returns: + Tensor: output tensor with shape ``[b x s x v]`` or a list of layer \ + output tensors defined by ``output_hidden_states`` with the \ + final output tensor appended to the list. + + Notation used for tensor shapes: + - b: batch size + - s: token sequence length + - s_e: encoder sequence length + - v: vocab size + - d: token embed dim + - d_e: encoder embed dim + - m_s: max seq len + """ + # During decoding, encoder_input will only be provided + # for new inputs. Previous encoder outputs are cached + # in the decoder cache. + encoder_embed = None + if encoder_input is not None: + encoder_embed = self.encoder(**encoder_input) + + output = self.decoder( + tokens=tokens, + mask=mask, + encoder_input=encoder_embed, + encoder_mask=encoder_mask, + input_pos=input_pos, + ) + return output diff --git a/torchtune/modules/model_fusion/_early_fusion.py b/torchtune/modules/model_fusion/_early_fusion.py new file mode 100644 index 0000000000..d20b2d119f --- /dev/null +++ b/torchtune/modules/model_fusion/_early_fusion.py @@ -0,0 +1,278 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Dict, Optional, Tuple, Union + +import torch +from torch import nn +from torchtune.modules import TransformerDecoder +from torchtune.modules.model_fusion._fusion_utils import get_fusion_params +from torchtune.modules.peft._utils import set_trainable_params + + +class EarlyFusionModel(nn.Module): + """EarlyFusion is a type of fused model architecture where pretrained encoder(s) are combined + with a pretrained decoder (LLM) at the model input and not in internal layers. This is a popular architecture + for multimodal models, with a full overview available in `The Evolution of Multimodal Model Architectures + `_. This module works both for decoders in which the encoder tokens are + inside the vocab and outside the vocab. + + This module has the same methods and forward signature as :class:`~torchtune.modules.TransformerDecoder` and can be used + interchangeably where :class:`~torchtune.modules.TransformerDecoder` is. It combines the encoders with the decoder as a + single module for checkpointing and finetuning. 
It is expected that the encoders and decoder + are already defined with any extra learnable ``fusion_params``: learnable parameters to help + adapt the pre-trained encoders to the pre-trained decoder. + + You can pass in multiple encoders as a dictionary into ``encoders``. + + Note: Once the decoder is wrapped in this module, the decoder's ``tok_embeddings`` module is moved + to the parent EarlyFusionModel's ``tok_embeddings``. You should not forward pass the decoder individually. + Instead, use EarlyFusionModel's forward pass with ``encoder_input=None`` to get decoder-only outputs. + State dicts will automatically be updated on save and load to account for this change. + + Example: + >>> # decoder is a text-only TransformerDecoder (e.g. llama3_8b) with no modifications + >>> decoder = llama3_8b() + >>> + >>> # encoder is pre-trained encoder (e.g. clip_vit_224) with an added projection head + >>> projection_head = FeedForward(...) + >>> register_fusion_module(projection_head)) + >>> encoders = {"image": nn.Sequential(clip_vit_224(), projection_head)} + >>> + >>> # EarlyFusionModel combines the encoder and decoder + >>> model = EarlyFusionModel(decoder, encoders, encoder_tokens={"image": 128256}) + >>> + >>> # Load full fused checkpoints + >>> model.load_state_dict(...) + >>> + >>> # Forward pass + >>> encoder_input = {"image": {...}} + >>> output = model(tokens, mask=mask, encoder_input=encoder_input, encoder_mask=encoder_mask, input_pos=input_pos) + >>> + >>> # Forward pass decoder only + >>> output = model(tokens, mask=mask, input_pos=input_pos) + + Args: + decoder (TransformerDecoder): decoder module + encoders (Dict[str, nn.Module]): dictionary mapping encoder name as a string to the encoder module. + encoder_tokens (Dict[str, int]): dictionary mapping encoder name to special token ID indicating where + in the text sequence the encoder embedding outputs should be injected. + decoder_trainable (bool): whether to train or freeze the decoder. Default is False. + encoders_trainable (Union[bool, Dict[str, bool]]): whether to train or freeze the encoder. Use a single + boolean to set trainable for all encoders or a dictionary keyed by encoder names to specify trainable + for each encoder individually. Encoder names should match with ``encoders``. Default is False. + fusion_trainable (bool): whether to train the fusion parameters. Default is True. + + Raises: + ValueError: if ``encoders`` and ``encoders_trainable`` keys do not match + """ + + def __init__( + self, + decoder: TransformerDecoder, + encoders: Dict[str, nn.Module], + encoder_tokens: Dict[str, int], + decoder_trainable: bool = False, + encoders_trainable: Union[bool, Dict[str, bool]] = False, + fusion_trainable: bool = True, + ): + super().__init__() + if encoders.keys() != encoder_tokens.keys() or ( + not isinstance(encoders_trainable, bool) + and encoders.keys() != encoders_trainable.keys() + ): + raise ValueError( + f"Found mismatched keys in encoders, encoder_tokens, and/or encoders_trainable. 
Expected {encoders.keys()}" + ) + + self.decoder = decoder + self.encoders = nn.ModuleDict(encoders) + self.encoder_tokens = encoder_tokens + self.encoders_trainable = ( + {k: encoders_trainable for k in self.encoders.keys()} + if isinstance(encoders_trainable, bool) + else encoders_trainable + ) + + # A little surgery in the decoder to give the + # fusion module access to control the embeddings + # The alternative is to pass a special tok_embeddings + # module into TransformerDecoder builder that does the + # merging there + self.tok_embeddings = decoder.tok_embeddings + decoder.tok_embeddings = nn.Identity() + + self._register_state_dict_hook(self._state_dict_hook) + self.register_load_state_dict_pre_hook(self._load_state_dict_hook) + + trainable_params = set() + for encoder, trainable in self.encoders_trainable.items(): + if trainable: + trainable_params |= { + f"encoders.{encoder}.{n}" + for n, p in self.encoders[encoder].named_parameters() + } + if decoder_trainable: + trainable_params |= { + f"decoder.{n}" for n, p in self.decoder.named_parameters() + } + trainable_params |= { + f"tok_embeddings.{n}" for n, p in self.tok_embeddings.named_parameters() + } + if fusion_trainable: + trainable_params |= set(get_fusion_params(self)) + else: + trainable_params -= set(get_fusion_params(self)) + + set_trainable_params(self, trainable_params) + + @staticmethod + def _state_dict_hook(module, state_dict, *args, **kwargs): + """ + Keep tok_embeddings inside of decoder state_dict + + [!Note] This update changes the order of the OrderedDict + """ + for n, p in module.tok_embeddings.named_parameters(): + state_dict[f"decoder.tok_embeddings.{n}"] = p + del state_dict[f"tok_embeddings.{n}"] + + @staticmethod + def _load_state_dict_hook(module, state_dict, *args, **kwargs): + """Undo the change from _state_dict_hook""" + old_keys = list(state_dict.keys()) + for key in old_keys: + if key.startswith("decoder.tok_embeddings"): + state_dict[key[len("decoder.") :]] = state_dict[key] + del state_dict[key] + + def set_num_output_chunks(self, num_output_chunks: int) -> None: + """Used to save memory in combination with :class:`~torchtune.modules.loss.CEWithChunkedOutputLoss`. + This should be called before the first forward pass, in the recipe.""" + self.decoder.set_num_output_chunks(num_output_chunks) + + def setup_caches(self, batch_size: int, dtype: torch.dtype) -> None: + """Setup key value caches for attention calculation. + + Args: + batch_size (int): batch size for the caches. + dtype (torch.dtype): dtype for the caches. + """ + self.decoder.setup_caches(batch_size, dtype) + + def caches_are_setup(self) -> bool: + """ + Check if the key value caches are setup. This means ``setup_caches`` has been called, and + the relevant attention modules in the model have created their ``KVCache``. + """ + return self.decoder.caches_are_setup() + + def caches_are_enabled(self) -> bool: + """ + Checks if the key value caches are enabled. Once KV-caches have been setup, the relevant + attention modules will be "enabled" and all forward passes will update the caches. This behaviour + can be disabled without altering the state of the KV-caches by "disabling" the KV-caches + using :func:`~torchtune.modules.common_utils.disable_kv_cache`, upon which ``caches_are_enabled`` would return False. 
+ """ + return self.decoder.caches_are_enabled() + + def reset_caches(self): + """Reset the key value caches.""" + self.decoder.reset_caches() + + def _decoder_embed(self, tokens) -> Tuple[torch.Tensor, torch.Tensor]: + """Embed the text-only tokens with the decoder's tok_embeddings""" + encoder_token_ids = torch.tensor(list(self.encoder_tokens.values())) + # [bsz, seq_len], True indicates the token is not an encoder special token + is_text = ~torch.isin(tokens, encoder_token_ids) + text_tokens = torch.masked_select(tokens, is_text) + # [num_text, embed_dim] + text_embeds = self.tok_embeddings(text_tokens) + return is_text, text_embeds + + def forward( + self, + tokens: torch.Tensor, + *, + mask: Optional[torch.Tensor] = None, + encoder_input: Optional[Dict[str, Dict[str, Any]]] = None, + input_pos: Optional[torch.Tensor] = None, + **kwargs: Dict[str, Any], # no need for encoder_mask + ) -> torch.Tensor: + """ + Note: This module assumes that there will be enough encoder inputs (i.e., total number of images in the batch) + for the number of encoder tokens in the batch. + + Args: + tokens (torch.Tensor): input tensor with shape ``[b x s]`` + mask (Optional[torch.Tensor]): Optional boolean tensor which contains the attention mask + with shape ``[b x s x s]``. This is applied after the query-key multiplication and + before the softmax. A value of True in row i and column j means token i attends + to token j. A value of False means token i does not attend to token j. If no + mask is specified, a causal mask is used by default. Default is None. + encoder_input (Optional[Dict[str, Dict[str, Any]]]): Optional input kwargs for the encoders. Must be + keyed by encoder name and match the keys of ``encoders`` + input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids + of each token. During training, this is used to indicate the positions + of each token relative to its sample when packed, shape ``[b x s]``. + During inference, this indicates the position of the current token. + If none, assume the index of the token is its position id. Default is None. + **kwargs (Dict[str, Any]): additional keyword arguments. This is solely used to match the + :class:`~torchtune.modules.TransformerDecoder` forward and does not have any effect. + + Note: At the very first step of inference, when the model is provided with a prompt, + ``input_pos`` would contain the positions of all of the tokens in the prompt + (eg: ``torch.arange(prompt_length)``). This is because we will need to compute the + KV values for each position. + + Returns: + torch.Tensor: output tensor with shape ``[b x s x v]`` or a list of layer \ + output tensors defined by ``output_hidden_states`` with the \ + final output tensor appended to the list. + + Raises: + ValueError: if ``encoder_input`` keys do not match ``encoders`` keys + + Notation used for tensor shapes: + - b: batch size + - s: token sequence length + - s_e: encoder sequence length + - v: vocab size + - d: token embed dim + - d_e: encoder embed dim + - m_s: max seq len + """ + if encoder_input is not None and encoder_input.keys() != self.encoders.keys(): + raise ValueError( + f"Found mismatched keys in encoder_input and instantiated encoders. " + f"Got {encoder_input.keys()}, expected {self.encoders.keys()}." 
+ ) + + bsz, seq_len = tokens.shape + # is_text: [bsz, seq_len], text_embeds: [num_text, embed_dim] + is_text, text_embeds = self._decoder_embed(tokens) + embed_dim = text_embeds.shape[-1] + + # Holds the final embedding vector + fused_embeds = torch.empty( + bsz, seq_len, embed_dim, dtype=text_embeds.dtype, device=text_embeds.device + ) + # Place the text-only embeddings + fused_embeds = fused_embeds.masked_scatter(is_text.unsqueeze(-1), text_embeds) + + encoder_input = encoder_input or {} + for encoder, inp in encoder_input.items(): + # [bsz, num_encoder_tokens, embed_dim] + encoder_embeds = self.encoders[encoder](**inp) + # [bsz * num_encoder_tokens, embed_dim] + encoder_embeds = encoder_embeds.view(-1, embed_dim) + # [bsz, seq_len, 1] + encoder_mask = (tokens == self.encoder_tokens[encoder]).unsqueeze(-1) + # At locations where encoder token is found, replace with encoder embedding + fused_embeds = fused_embeds.masked_scatter(encoder_mask, encoder_embeds) + + output = self.decoder(fused_embeds, mask=mask, input_pos=input_pos) + return output diff --git a/torchtune/modules/model_fusion/_fusion.py b/torchtune/modules/model_fusion/_fusion.py deleted file mode 100644 index a4056329e8..0000000000 --- a/torchtune/modules/model_fusion/_fusion.py +++ /dev/null @@ -1,793 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from typing import Any, Dict, List, Optional, Tuple, Union - -import torch -from torch import nn -from torchtune.modules import TransformerDecoder -from torchtune.modules.model_fusion._fusion_utils import get_fusion_params -from torchtune.modules.peft._utils import set_trainable_params - - -class FusionLayer(nn.Module): - """Fusion layer as introduced in `Flamingo: a Visual Language Model for Few-Shot Learning `_. - - Deep Fusion model architectures combine pretrained encoder models with pretrained - language models by infusing the encoder outputs into the middle layers of the LLM. - This allows the language model to interpret the enocder outputs as text and - "understand" any modality for which you can train an encoder. To enable the language model - to adapt to the encoder outputs, the FusionLayer fuses a new learnable layer to an existing - decoder (language model) layer. This additional layer can take the encoder embeddings and - learn to combine them with the token embeddings from the decoder. The module supports fusing - the new layer before or after the original, in Flamingo the new layer is fused before the original. - - The original layer is wrapped in FusionLayer such that it maintains its original state_dict - key and the pre-trained checkpoint isn't broken. The new layer parameters are available - through ``fusion_params`` to separately control if they're trainable or not. - - Example: - >>> # Original decoder style transformer - >>> layer = nn.TransformerSelfAttentionLayer(...) - >>> model = TransformerDecoder(layers=layer, num_layers=32, ...) - >>> - >>> # Fuse a cross attention layer to each self attention layer to adapt for the encoder - >>> fusion_layer = nn.TransformerCrossAttentionLayer(...) - >>> fused_layer = FusionLayer(layer, fusion_layer) - >>> model = TransformerDecoder(layers=fused_layer, num_layers=32, ...) 
- >>> - >>> # Original decoder state_dict still works - >>> model.load_state_dict(..., strict=False) - - Args: - layer (nn.Module): original decoder layer - fusion_layer (nn.Module): new fusion layer - fusion_first (bool): boolean to insert fusion layer before or after the decoder layer. - """ - - def __init__( - self, layer: nn.Module, fusion_layer: nn.Module, fusion_first: bool = True - ): - super().__init__() - self.layer = layer - self.fusion_layer = fusion_layer - self.fusion_first = fusion_first - - # Keep FusionLayer wrappings out of the state_dict - self._register_state_dict_hook(FusionLayer._state_dict_hook) - self._register_load_state_dict_pre_hook( - FusionLayer._load_state_dict_hook, with_module=True - ) - # TODO: Switch to register_load_state_dict_pre_hook and - # register_state_dict_pre_hook after PyTorch v2.5 - - def _state_dict_hook(self, state_dict, prefix, *args, **kwargs): - """Remove "layer" from the original layer in the state_dict - name. This keeps the orginal state dict name for the layer - from before fusing with the FusionLayer. - - [!Note] This update changes the order of the OrderedDict - """ - keys = list(state_dict.keys()) - for key in keys: - local_key = key[len(prefix) :] - if local_key.startswith("layer"): - new_key = prefix + local_key.replace("layer.", "") - state_dict[new_key] = state_dict[key] - del state_dict[key] - - def _load_state_dict_hook(self, state_dict, prefix, *args, **kwargs): - """Apply extra "layer" prefix to the state_dict key to - account for the FusionLayer wrapping. - """ - keys = list(state_dict.keys()) - for key in keys: - local_key = key[len(prefix) :] - if not local_key.startswith("fusion_layer"): - new_key = prefix + "layer." + local_key - state_dict[new_key] = state_dict[key] - del state_dict[key] - - def setup_caches( - self, - batch_size: int, - dtype: torch.dtype, - *, - encoder_max_seq_len: int, - decoder_max_seq_len: int, - ) -> None: - """Setup key value cache for both layers. - - Args: - batch_size (int): batch size for the caches. - dtype (torch.dtype): dtype for the caches. - encoder_max_seq_len (int): maximum cache sequence length for cross-attention layer. - decoder_max_seq_len (int): maximum cache sequence length for self-attention layer. - """ - self.layer.setup_caches( - batch_size, - dtype, - encoder_max_seq_len=encoder_max_seq_len, - decoder_max_seq_len=decoder_max_seq_len, - ) - - self.fusion_layer.setup_caches( - batch_size, - dtype, - encoder_max_seq_len=encoder_max_seq_len, - decoder_max_seq_len=decoder_max_seq_len, - ) - - def caches_are_setup(self) -> bool: - """ - Check if the key value caches are setup on ``self.layer``. - See :func:~torchtune.modules.TransformerDecoder.caches_are_setup`. - """ - return self.layer.caches_are_setup() - - def caches_are_enabled(self) -> bool: - """ - Checks if the key value caches on ``self.layer`` are enabled. - See :func:~torchtune.modules.TransformerDecoder.caches_are_enabled`. - """ - return self.layer.caches_are_enabled() - - def reset_cache(self): - """Reset both layers' key value caches.""" - self.layer.reset_cache() - self.fusion_layer.reset_cache() - - def fusion_params(self) -> List[str]: - """ - Return parameters of fusion layer. 
- """ - fusion_params = [ - f"fusion_layer.{k}" for k, v in self.fusion_layer.named_parameters() - ] - return fusion_params - - def forward(self, x: torch.Tensor, **kwargs: Dict) -> torch.Tensor: - """ - Args: - x (torch.Tensor): input tensor with shape - [batch_size x seq_length x embed_dim] - **kwargs (Dict): all additional layer args - - Returns: - Tensor: output tensor with same shape as input - [batch_size x seq_length x embed_dim]` - - """ - if self.fusion_first: - x = self.fusion_layer(x, **kwargs) - x = self.layer(x, **kwargs) - else: - x = self.layer(x, **kwargs) - x = self.fusion_layer(x, **kwargs) - return x - - -class FusionEmbedding(nn.Module): - """Fusion embedding supports training additional special tokens while keeping - the original embedding frozen. When fusing new models with a language model, - there may be some additional tokens needed to support the fused language model. For - example, adding a vision encoder might necessitate additional tokens like ``<|image|>`` - to indicate an images position in text and require learning an embedding for this token. - The FusionEmbedding keeps the original embeddings frozen while learning a much smaller - second embedding for the additional tokens. During forward this module routes - the tokens to the appropriate embedding table. - - Use this as a drop-in replacement for :class:`torch.nn.Embedding` in your model. - - Example: - >>> embedding = FusionEmbedding(vocab_size=100, fusion_vocab_size=10, embed_dim=128) - >>> model = TransformerDecoder(tok_embeddings=embedding, ...) - >>> - >>> # Original model state_dict still works - >>> model.load_state_dict(..., strict=False) - - .. note:: - This module assumes all tokens in the range [0, vocab_size) are part of the - original embedding table and all new tokens in the range - [vocab_size, vocab_size + fusion_vocab_size) - - Args: - vocab_size (int): language model vocab size - fusion_vocab_size (int): additional tokens for the fused model - embed_dim (int): embedding dimension of the two embedding tables - """ - - def __init__(self, vocab_size: int, fusion_vocab_size: int, embed_dim: int) -> None: - super().__init__() - self.embedding = nn.Embedding(vocab_size, embed_dim) - self.fusion_embedding = nn.Embedding(fusion_vocab_size, embed_dim) - self.dim = embed_dim - self.num_embeddings = vocab_size + fusion_vocab_size - # TODO: Support merging the embeddings after finetuning - - # Keep FusionLayer wrappings out of the state_dict - self._register_state_dict_hook(FusionEmbedding._state_dict_hook) - self._register_load_state_dict_pre_hook( - FusionEmbedding._load_state_dict_hook, with_module=True - ) - # TODO: Switch to register_load_state_dict_pre_hook and - # register_state_dict_pre_hook after PyTorch v2.5 - - def _state_dict_hook(self, destination, prefix, keep_vars): - """Remove "embedding" from the original embedding in the state_dict - name. This keeps the orginal state dict name for the embedding - from before fusing with the FusionEmbedding. - - [!Note] This update changes the order of the OrderedDict - """ - key = prefix + "embedding.weight" - new_key = prefix + "weight" - destination[new_key] = destination[key] - del destination[key] - - def _load_state_dict_hook(self, state_dict, prefix, *args, **kwargs): - """Apply extra "embedding" prefix to the state_dict key to - account for the FusionEmbedding wrapping. 
- """ - if state_dict: - key = prefix + "weight" - new_key = prefix + "embedding.weight" - state_dict[new_key] = state_dict[key] - del state_dict[key] - - def fusion_params(self) -> List[str]: - """ - Return fusion embedding parameters. - """ - fusion_params = ["fusion_embedding.weight"] - return fusion_params - - def _fused_embed(self, bs, seq_len): - """ - Return an empty tensor the shape of the combined embedding. - """ - device = self.embedding.weight.device - dtype = self.embedding.weight.dtype - return torch.empty(bs, seq_len, self.dim, device=device, dtype=dtype) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - """ - Args: - input (torch.Tensor): input integer tensor with shape - [batch_size x seq_length] - - Returns: - Tensor: output tensor embedding with shape - [batch_size x seq_length x embed_dim]` - - """ - bs, seq_len = input.size() - vocab_size = self.embedding.num_embeddings - - mask = input < vocab_size - # num_tokens = (input < vocab_size).sum() - tokens = torch.masked_select(input, mask) - # num_fusion_tokens = (input >= vocab_size).sum() - fusion_tokens = torch.masked_select(input, ~mask) - vocab_size - - # [batch_size * num_tokens, embed_dim] - embeds = self.embedding(tokens) - # [batch_size * num_fusion_tokens, embed_dim] - fusion_embeds = self.fusion_embedding(fusion_tokens) - - # [batch_size x seq_length x embed_dim] - out = self._fused_embed(bs, seq_len) - mask = mask.unsqueeze(-1).expand(bs, seq_len, self.dim) - out = out.masked_scatter(mask, embeds) - out = out.masked_scatter(~mask, fusion_embeds) - return out - - -class DeepFusionModel(nn.Module): - """DeepFusion is a type of fused model architecture where a pretrained encoder is combined - with a pretrained decoder (LLM) in the internal decoder layers. This is a popular architecture for multimodal models, with - a full overview available in `The Evolution of Multimodal Model Architectures `_. - A common deep fusion architecture is to fuse the encoder input into the decoder with interspersed cross-attention - layers. This module makes no assumptions on how the encoder and decoder are fused; it simply - passes in the encoder embeddings to the decoder and lets the decoder handle any fusion. - - This module has the same methods and forward signature as :class:`~torchtune.modules.TransformerDecoder` and can be used - interchangeably where :class:`~torchtune.modules.TransformerDecoder` is. It combines the encoder with the decoder as a - single module for checkpointing and finetuning. It is expected that the encoder and decoder - are already defined with any extra learnable ``fusion_params``: learnable parameters to help - adapt the pre-trained encoder to the pre-trained decoder. - - DeepFusionModel currently only supports a single encoder. - - Example: - >>> # decoder is a TransformerDecoder (e.g. llama3_8b) with fused cross attention layers - >>> embed = FusionEmbedding(...) - >>> layer = FusionLayer( - ... layer=TransformerSelfAttentionLayer(...), - ... fusion_layer=TransformerCrossAttentionLayer(...), - ... ) - >>> decoder = TransformerDecoder(tok_embeddings=embed, layers=layer, num_layers=32, ...) - >>> - >>> # encoder is pre-trained encoder (e.g. clip_vit_224) with an added projection head - >>> projection_head = FeedForward(...) 
- >>> register_fusion_module(projection_head)) - >>> encoders = {"image": nn.Sequential(clip_vit_224(), projection_head)} - >>> - >>> # DeepFusionModel combines the encoder and decoder - >>> model = DeepFusionModel(decoder, encoders) - >>> - >>> # Load full fused checkpoints (e.g. a Flamingo checkpoint) - >>> model.load_state_dict(...) - >>> - >>> # Or load pretrained individual models (fusion_params are not loaded) - >>> model.encoder.load_state_dict(..., strict=False) - >>> model.decoder.load_state_dict(..., strict=False) - >>> - >>> # Forward pass - >>> encoder_input = {"image": {...}} - >>> output = model(tokens, mask, encoder_input, encoder_mask, input_pos) - - Args: - decoder (TransformerDecoder): decoder module - encoders (Dict[str, nn.Module]): dictionary mapping encoder name as a string to the encoder module. - decoder_trainable (bool): whether to train or freeze the decoder. Default is False. - encoders_trainable (Union[bool, Dict[str, bool]]): whether to train or freeze the encoder. Use a single - boolean to set trainable for all encoders or a dictionary keyed by encoder names to specify trainable - for each encoder individually. Encoder names should match with ``encoders``. Default is False. - fusion_trainable (bool): whether to train the fusion parameters. Default is True. - - Raises: - ValueError: if ``encoders`` and ``encoders_trainable`` keys do not match - ValueError: if ``len(encoders) != 1`` - """ - - def __init__( - self, - decoder: TransformerDecoder, - encoders: Dict[str, nn.Module], - *, - decoder_trainable: bool = False, - encoders_trainable: Union[bool, Dict[str, bool]] = False, - fusion_trainable: bool = True, - ): - super().__init__() - if ( - not isinstance(encoders_trainable, bool) - and encoders.keys() != encoders_trainable.keys() - ): - raise ValueError( - f"Found mismatched keys in encoders and encoders_trainable. Got {encoders.keys()} and {encoders_trainable.keys()}." - ) - # Currently, only a single encoder is supported, so user can only - # pass in a single key. When multiple encoders are - # supported, this can be removed. - if len(encoders.keys()) != 1: - raise ValueError( - f"DeepFusionModel only supports a single encoder. Got {len(encoders.keys())} encoders." - ) - - self.decoder = decoder - self.encoders = nn.ModuleDict(encoders) - self.encoders_trainable = ( - {k: encoders_trainable for k in self.encoders.keys()} - if isinstance(encoders_trainable, bool) - else encoders_trainable - ) - - trainable_params = set() - for encoder, trainable in self.encoders_trainable.items(): - if trainable: - trainable_params |= { - f"encoders.{encoder}.{n}" - for n, p in self.encoders[encoder].named_parameters() - } - if decoder_trainable: - trainable_params |= { - f"decoder.{n}" for n, p in self.decoder.named_parameters() - } - if fusion_trainable: - trainable_params |= set(get_fusion_params(self)) - else: - trainable_params -= set(get_fusion_params(self)) - set_trainable_params(self, trainable_params) - - def set_num_output_chunks(self, num_output_chunks: int) -> None: - """Used to save memory in combination with :class:`~torchtune.modules.loss.CEWithChunkedOutputLoss`. - This should be called before the first forward pass, in the recipe.""" - self.decoder.set_num_output_chunks(num_output_chunks) - - def setup_caches( - self, - batch_size: int, - dtype: torch.dtype, - *, - encoder_max_seq_len: Optional[int] = None, - decoder_max_seq_len: Optional[int] = None, - ): - """ - Sets up key-value attention caches for inference for ``self.decoder``. 
- For each layer in ``self.decoder.layers``: - - :class:`torchtune.modules.TransformerSelfAttentionLayer` will use ``decoder_max_seq_len``. - - :class:`torchtune.modules.TransformerCrossAttentionLayer` will use ``encoder_max_seq_len``. - - :class:`torchtune.modules.fusion.FusionLayer` will use both ``decoder_max_seq_len`` and ``encoder_max_seq_len``. - - Args: - batch_size (int): batch size for the caches. - dtype (torch.dtype): dtype for the caches. - encoder_max_seq_len (Optional[int]): maximum encoder cache sequence length. - decoder_max_seq_len (Optional[int]): maximum decoder cache sequence length. - """ - self.decoder.setup_caches( - batch_size, - dtype, - encoder_max_seq_len=encoder_max_seq_len, - decoder_max_seq_len=decoder_max_seq_len, - ) - - def caches_are_setup(self) -> bool: - """ - Check if the key value caches are setup. This means ``setup_caches`` has been called, and - the relevant attention modules in the model have created their ``KVCache``. - """ - return self.decoder.caches_are_setup() - - def caches_are_enabled(self) -> bool: - """ - Checks if the key value caches are enabled. Once KV-caches have been setup, the relevant - attention modules will be "enabled" and all forward passes will update the caches. This behaviour - can be disabled without altering the state of the KV-caches by "disabling" the KV-caches - using :func:`~torchtune.modules.common_utils.disable_kv_cache`, upon which ``caches_are_enabled`` would return False. - """ - return self.decoder.caches_are_enabled() - - def reset_caches(self): - """ - Resets KV-cache buffers on relevant attention modules to zero, and reset cache positions to zero, - without deleting or reallocating cache tensors. - """ - self.decoder.reset_caches() - - def forward( - self, - tokens: torch.Tensor, - *, - mask: Optional[torch.Tensor] = None, - encoder_input: Optional[Dict[str, Dict[str, Any]]] = None, - encoder_mask: Optional[torch.Tensor] = None, - input_pos: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, List[torch.Tensor]]: - """ - Args: - tokens (torch.Tensor): input tensor with shape ``[b x s]`` - mask (Optional[torch.Tensor]): Optional boolean tensor which contains the attention mask - with shape ``[b x s x s]``. This is applied after the query-key multiplication and - before the softmax. A value of True in row i and column j means token i attends - to token j. A value of False means token i does not attend to token j. If no - mask is specified, a causal mask is used by default. Default is None. - encoder_input (Optional[Dict[str, Dict[str, Any]]]): Optional input kwargs for the encoders. Must be - keyed by encoder name and match the keys of ``encoders`` - encoder_mask (Optional[torch.Tensor]): Boolean tensor defining a relational matrix between - tokens and encoder embeddings. A True value at position i,j means token i can attend - to embedding j in the decoder. Mask has shape ``[b x s x s_e]``. Default is None. - input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids - of each token. During training, this is used to indicate the positions - of each token relative to its sample when packed, shape ``[b x s]``. - During inference, this indicates the position of the current token. - If none, assume the index of the token is its position id. Default is None. - - Note: At the very first step of inference, when the model is provided with a prompt, - ``input_pos`` would contain the positions of all of the tokens in the prompt - (eg: ``torch.arange(prompt_length)``). 
This is because we will need to compute the - KV values for each position. - - Returns: - Tensor: output tensor with shape ``[b x s x v]`` or a list of layer \ - output tensors defined by ``output_hidden_states`` with the \ - final output tensor appended to the list. - - Raises: - ValueError: if ``encoder_input`` keys do not match ``encoders`` keys - - Notation used for tensor shapes: - - b: batch size - - s: token sequence length - - s_e: encoder sequence length - - v: vocab size - - d: token embed dim - - d_e: encoder embed dim - - m_s: max seq len - """ - if encoder_input is not None and encoder_input.keys() != self.encoders.keys(): - raise ValueError( - f"Found mismatched keys in encoder_input and instantiated encoders. " - f"Got {encoder_input.keys()}, expected {self.encoders.keys()}." - ) - # During decoding, encoder_input will only be provided - # for new inputs. Previous encoder outputs are cached - # in the decoder cache. - encoder_embed = None - if encoder_input is not None: - encoder_embed = { - encoder: self.encoders[encoder](**encoder_input[encoder]) - for encoder in encoder_input - } - - # Currently, only a single encoder is supported, so we need - # to get the encoder key manually. When multiple encoders are - # supported, this can be removed. - decoder_encoder_input = ( - list(encoder_embed.values())[0] if encoder_embed is not None else None - ) - - output = self.decoder( - tokens=tokens, - mask=mask, - encoder_input=decoder_encoder_input, - encoder_mask=encoder_mask, - input_pos=input_pos, - ) - return output - - -class EarlyFusionModel(nn.Module): - """EarlyFusion is a type of fused model architecture where pretrained encoder(s) are combined - with a pretrained decoder (LLM) at the model input and not in internal layers. This is a popular architecture - for multimodal models, with a full overview available in `The Evolution of Multimodal Model Architectures - `_. This module works both for decoders in which the encoder tokens are - inside the vocab and outside the vocab. - - This module has the same methods and forward signature as :class:`~torchtune.modules.TransformerDecoder` and can be used - interchangeably where :class:`~torchtune.modules.TransformerDecoder` is. It combines the encoders with the decoder as a - single module for checkpointing and finetuning. It is expected that the encoders and decoder - are already defined with any extra learnable ``fusion_params``: learnable parameters to help - adapt the pre-trained encoders to the pre-trained decoder. - - You can pass in multiple encoders as a dictionary into ``encoders``. - - Note: Once the decoder is wrapped in this module, the decoder's ``tok_embeddings`` module is moved - to the parent EarlyFusionModel's ``tok_embeddings``. You should not forward pass the decoder individually. - Instead, use EarlyFusionModel's forward pass with ``encoder_input=None`` to get decoder-only outputs. - State dicts will automatically be updated on save and load to account for this change. - - Example: - >>> # decoder is a text-only TransformerDecoder (e.g. llama3_8b) with no modifications - >>> decoder = llama3_8b() - >>> - >>> # encoder is pre-trained encoder (e.g. clip_vit_224) with an added projection head - >>> projection_head = FeedForward(...) 
- >>> register_fusion_module(projection_head)) - >>> encoders = {"image": nn.Sequential(clip_vit_224(), projection_head)} - >>> - >>> # EarlyFusionModel combines the encoder and decoder - >>> model = EarlyFusionModel(decoder, encoders, encoder_tokens={"image": 128256}) - >>> - >>> # Load full fused checkpoints - >>> model.load_state_dict(...) - >>> - >>> # Forward pass - >>> encoder_input = {"image": {...}} - >>> output = model(tokens, mask=mask, encoder_input=encoder_input, encoder_mask=encoder_mask, input_pos=input_pos) - >>> - >>> # Forward pass decoder only - >>> output = model(tokens, mask=mask, input_pos=input_pos) - - Args: - decoder (TransformerDecoder): decoder module - encoders (Dict[str, nn.Module]): dictionary mapping encoder name as a string to the encoder module. - encoder_tokens (Dict[str, int]): dictionary mapping encoder name to special token ID indicating where - in the text sequence the encoder embedding outputs should be injected. - decoder_trainable (bool): whether to train or freeze the decoder. Default is False. - encoders_trainable (Union[bool, Dict[str, bool]]): whether to train or freeze the encoder. Use a single - boolean to set trainable for all encoders or a dictionary keyed by encoder names to specify trainable - for each encoder individually. Encoder names should match with ``encoders``. Default is False. - fusion_trainable (bool): whether to train the fusion parameters. Default is True. - - Raises: - ValueError: if ``encoders`` and ``encoders_trainable`` keys do not match - """ - - def __init__( - self, - decoder: TransformerDecoder, - encoders: Dict[str, nn.Module], - encoder_tokens: Dict[str, int], - decoder_trainable: bool = False, - encoders_trainable: Union[bool, Dict[str, bool]] = False, - fusion_trainable: bool = True, - ): - super().__init__() - if encoders.keys() != encoder_tokens.keys() or ( - not isinstance(encoders_trainable, bool) - and encoders.keys() != encoders_trainable.keys() - ): - raise ValueError( - f"Found mismatched keys in encoders, encoder_tokens, and/or encoders_trainable. 
Expected {encoders.keys()}" - ) - - self.decoder = decoder - self.encoders = nn.ModuleDict(encoders) - self.encoder_tokens = encoder_tokens - self.encoders_trainable = ( - {k: encoders_trainable for k in self.encoders.keys()} - if isinstance(encoders_trainable, bool) - else encoders_trainable - ) - - # A little surgery in the decoder to give the - # fusion module access to control the embeddings - # The alternative is to pass a special tok_embeddings - # module into TransformerDecoder builder that does the - # merging there - self.tok_embeddings = decoder.tok_embeddings - decoder.tok_embeddings = nn.Identity() - - self._register_state_dict_hook(self._state_dict_hook) - self.register_load_state_dict_pre_hook(self._load_state_dict_hook) - - trainable_params = set() - for encoder, trainable in self.encoders_trainable.items(): - if trainable: - trainable_params |= { - f"encoders.{encoder}.{n}" - for n, p in self.encoders[encoder].named_parameters() - } - if decoder_trainable: - trainable_params |= { - f"decoder.{n}" for n, p in self.decoder.named_parameters() - } - trainable_params |= { - f"tok_embeddings.{n}" for n, p in self.tok_embeddings.named_parameters() - } - if fusion_trainable: - trainable_params |= set(get_fusion_params(self)) - else: - trainable_params -= set(get_fusion_params(self)) - - set_trainable_params(self, trainable_params) - - @staticmethod - def _state_dict_hook(module, state_dict, *args, **kwargs): - """ - Keep tok_embeddings inside of decoder state_dict - - [!Note] This update changes the order of the OrderedDict - """ - for n, p in module.tok_embeddings.named_parameters(): - state_dict[f"decoder.tok_embeddings.{n}"] = p - del state_dict[f"tok_embeddings.{n}"] - - @staticmethod - def _load_state_dict_hook(module, state_dict, *args, **kwargs): - """Undo the change from _state_dict_hook""" - old_keys = list(state_dict.keys()) - for key in old_keys: - if key.startswith("decoder.tok_embeddings"): - state_dict[key[len("decoder.") :]] = state_dict[key] - del state_dict[key] - - def set_num_output_chunks(self, num_output_chunks: int) -> None: - """Used to save memory in combination with :class:`~torchtune.modules.loss.CEWithChunkedOutputLoss`. - This should be called before the first forward pass, in the recipe.""" - self.decoder.set_num_output_chunks(num_output_chunks) - - def setup_caches(self, batch_size: int, dtype: torch.dtype) -> None: - """Setup key value caches for attention calculation. - - Args: - batch_size (int): batch size for the caches. - dtype (torch.dtype): dtype for the caches. - """ - self.decoder.setup_caches(batch_size, dtype) - - def caches_are_setup(self) -> bool: - """ - Check if the key value caches are setup. This means ``setup_caches`` has been called, and - the relevant attention modules in the model have created their ``KVCache``. - """ - return self.decoder.caches_are_setup() - - def caches_are_enabled(self) -> bool: - """ - Checks if the key value caches are enabled. Once KV-caches have been setup, the relevant - attention modules will be "enabled" and all forward passes will update the caches. This behaviour - can be disabled without altering the state of the KV-caches by "disabling" the KV-caches - using :func:`~torchtune.modules.common_utils.disable_kv_cache`, upon which ``caches_are_enabled`` would return False. 
- """ - return self.decoder.caches_are_enabled() - - def reset_caches(self): - """Reset the key value caches.""" - self.decoder.reset_caches() - - def _decoder_embed(self, tokens) -> Tuple[torch.Tensor, torch.Tensor]: - """Embed the text-only tokens with the decoder's tok_embeddings""" - encoder_token_ids = torch.tensor(list(self.encoder_tokens.values())) - # [bsz, seq_len], True indicates the token is not an encoder special token - is_text = ~torch.isin(tokens, encoder_token_ids) - text_tokens = torch.masked_select(tokens, is_text) - # [num_text, embed_dim] - text_embeds = self.tok_embeddings(text_tokens) - return is_text, text_embeds - - def forward( - self, - tokens: torch.Tensor, - *, - mask: Optional[torch.Tensor] = None, - encoder_input: Optional[Dict[str, Dict[str, Any]]] = None, - input_pos: Optional[torch.Tensor] = None, - **kwargs: Dict[str, Any], # no need for encoder_mask - ) -> torch.Tensor: - """ - Args: - tokens (torch.Tensor): input tensor with shape ``[b x s]`` - mask (Optional[torch.Tensor]): Optional boolean tensor which contains the attention mask - with shape ``[b x s x s]``. This is applied after the query-key multiplication and - before the softmax. A value of True in row i and column j means token i attends - to token j. A value of False means token i does not attend to token j. If no - mask is specified, a causal mask is used by default. Default is None. - encoder_input (Optional[Dict[str, Dict[str, Any]]]): Optional input kwargs for the encoders. Must be - keyed by encoder name and match the keys of ``encoders`` - input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids - of each token. During training, this is used to indicate the positions - of each token relative to its sample when packed, shape ``[b x s]``. - During inference, this indicates the position of the current token. - If none, assume the index of the token is its position id. Default is None. - **kwargs (Dict[str, Any]): additional keyword arguments. This is solely used to match the - :class:`~torchtune.modules.TransformerDecoder` forward and does not have any effect. - - Note: At the very first step of inference, when the model is provided with a prompt, - ``input_pos`` would contain the positions of all of the tokens in the prompt - (eg: ``torch.arange(prompt_length)``). This is because we will need to compute the - KV values for each position. - - Returns: - torch.Tensor: output tensor with shape ``[b x s x v]`` or a list of layer \ - output tensors defined by ``output_hidden_states`` with the \ - final output tensor appended to the list. - - Raises: - ValueError: if ``encoder_input`` keys do not match ``encoders`` keys - - Notation used for tensor shapes: - - b: batch size - - s: token sequence length - - s_e: encoder sequence length - - v: vocab size - - d: token embed dim - - d_e: encoder embed dim - - m_s: max seq len - """ - if encoder_input is not None and encoder_input.keys() != self.encoders.keys(): - raise ValueError( - f"Found mismatched keys in encoder_input and instantiated encoders. " - f"Got {encoder_input.keys()}, expected {self.encoders.keys()}." 
- )
-
- bsz, seq_len = tokens.shape
- # is_text: [bsz, seq_len], text_embeds: [num_text, embed_dim]
- is_text, text_embeds = self._decoder_embed(tokens)
- embed_dim = text_embeds.shape[-1]
-
- # Holds the final embedding vector
- fused_embeds = torch.empty(
- bsz, seq_len, embed_dim, dtype=text_embeds.dtype, device=text_embeds.device
- )
- # Place the text-only embeddings
- fused_embeds = fused_embeds.masked_scatter(is_text.unsqueeze(-1), text_embeds)
-
- for encoder, inp in (encoder_input or {}).items():
- # [bsz, num_encoder_tokens, embed_dim]
- encoder_embeds = self.encoders[encoder](**inp)
- # [bsz * num_encoder_tokens, embed_dim]
- encoder_embeds = encoder_embeds.view(-1, embed_dim)
- # [bsz, seq_len, 1]
- encoder_mask = (tokens == self.encoder_tokens[encoder]).unsqueeze(-1)
- # At locations where encoder token is found, replace with encoder embedding
- fused_embeds = fused_embeds.masked_scatter(encoder_mask, encoder_embeds)
-
- output = self.decoder(fused_embeds, mask=mask, input_pos=input_pos)
- return output
diff --git a/torchtune/modules/model_fusion/_fusion_layers.py b/torchtune/modules/model_fusion/_fusion_layers.py
new file mode 100644
index 0000000000..7fe3939ec4
--- /dev/null
+++ b/torchtune/modules/model_fusion/_fusion_layers.py
@@ -0,0 +1,283 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Dict, List
+
+import torch
+from torch import nn
+
+
+class FusionLayer(nn.Module):
+ """Fusion layer as introduced in `Flamingo: a Visual Language Model for Few-Shot Learning `_.
+
+ Deep Fusion model architectures combine pretrained encoder models with pretrained
+ language models by infusing the encoder outputs into the middle layers of the LLM.
+ This allows the language model to interpret the encoder outputs as text and
+ "understand" any modality for which you can train an encoder. To enable the language model
+ to adapt to the encoder outputs, the FusionLayer fuses a new learnable layer to an existing
+ decoder (language model) layer. This additional layer can take the encoder embeddings and
+ learn to combine them with the token embeddings from the decoder. The module supports fusing
+ the new layer before or after the original; in Flamingo, the new layer is fused before the original.
+
+ The original layer is wrapped in FusionLayer such that it maintains its original state_dict
+ key and the pre-trained checkpoint isn't broken. The new layer parameters are available
+ through ``fusion_params`` to separately control if they're trainable or not.
+
+ Example:
+ >>> # Original decoder style transformer
+ >>> layer = nn.TransformerSelfAttentionLayer(...)
+ >>> model = TransformerDecoder(layers=layer, num_layers=32, ...)
+ >>>
+ >>> # Fuse a cross attention layer to each self attention layer to adapt for the encoder
+ >>> fusion_layer = nn.TransformerCrossAttentionLayer(...)
+ >>> fused_layer = FusionLayer(layer, fusion_layer)
+ >>> model = TransformerDecoder(layers=fused_layer, num_layers=32, ...)
+ >>>
+ >>> # Original decoder state_dict still works
+ >>> model.load_state_dict(..., strict=False)
+
+ Args:
+ layer (nn.Module): original decoder layer
+ fusion_layer (nn.Module): new fusion layer
+ fusion_first (bool): boolean to insert fusion layer before or after the decoder layer.
+ """
+
+ def __init__(
+ self, layer: nn.Module, fusion_layer: nn.Module, fusion_first: bool = True
+ ):
+ super().__init__()
+ self.layer = layer
+ self.fusion_layer = fusion_layer
+ self.fusion_first = fusion_first
+
+ # Keep FusionLayer wrappings out of the state_dict
+ self._register_state_dict_hook(FusionLayer._state_dict_hook)
+ self._register_load_state_dict_pre_hook(
+ FusionLayer._load_state_dict_hook, with_module=True
+ )
+ # TODO: Switch to register_load_state_dict_pre_hook and
+ # register_state_dict_pre_hook after PyTorch v2.5
+
+ def _state_dict_hook(self, state_dict, prefix, *args, **kwargs):
+ """Remove "layer" from the original layer in the state_dict
+ name. This keeps the original state dict name for the layer
+ from before fusing with the FusionLayer.
+
+ [!Note] This update changes the order of the OrderedDict
+ """
+ keys = list(state_dict.keys())
+ for key in keys:
+ local_key = key[len(prefix) :]
+ if local_key.startswith("layer"):
+ new_key = prefix + local_key.replace("layer.", "")
+ state_dict[new_key] = state_dict[key]
+ del state_dict[key]
+
+ def _load_state_dict_hook(self, state_dict, prefix, *args, **kwargs):
+ """Apply extra "layer" prefix to the state_dict key to
+ account for the FusionLayer wrapping.
+ """
+ keys = list(state_dict.keys())
+ for key in keys:
+ local_key = key[len(prefix) :]
+ if not local_key.startswith("fusion_layer"):
+ new_key = prefix + "layer." + local_key
+ state_dict[new_key] = state_dict[key]
+ del state_dict[key]
+
+ def setup_caches(
+ self,
+ batch_size: int,
+ dtype: torch.dtype,
+ *,
+ encoder_max_seq_len: int,
+ decoder_max_seq_len: int,
+ ) -> None:
+ """Setup key value cache for both layers.
+
+ Args:
+ batch_size (int): batch size for the caches.
+ dtype (torch.dtype): dtype for the caches.
+ encoder_max_seq_len (int): maximum cache sequence length for cross-attention layer.
+ decoder_max_seq_len (int): maximum cache sequence length for self-attention layer.
+ """
+ self.layer.setup_caches(
+ batch_size,
+ dtype,
+ encoder_max_seq_len=encoder_max_seq_len,
+ decoder_max_seq_len=decoder_max_seq_len,
+ )
+
+ self.fusion_layer.setup_caches(
+ batch_size,
+ dtype,
+ encoder_max_seq_len=encoder_max_seq_len,
+ decoder_max_seq_len=decoder_max_seq_len,
+ )
+
+ def caches_are_setup(self) -> bool:
+ """
+ Check if the key value caches are setup on ``self.layer``.
+ See :func:`~torchtune.modules.TransformerDecoder.caches_are_setup`.
+ """
+ return self.layer.caches_are_setup()
+
+ def caches_are_enabled(self) -> bool:
+ """
+ Checks if the key value caches on ``self.layer`` are enabled.
+ See :func:`~torchtune.modules.TransformerDecoder.caches_are_enabled`.
+ """
+ return self.layer.caches_are_enabled()
+
+ def reset_cache(self):
+ """Reset both layers' key value caches."""
+ self.layer.reset_cache()
+ self.fusion_layer.reset_cache()
+
+ def fusion_params(self) -> List[str]:
+ """
+ Return parameters of fusion layer.
+ """
+ fusion_params = [
+ f"fusion_layer.{k}" for k, v in self.fusion_layer.named_parameters()
+ ]
+ return fusion_params
+
+ def forward(self, x: torch.Tensor, **kwargs: Dict) -> torch.Tensor:
+ """
+ Args:
+ x (torch.Tensor): input tensor with shape
+ [batch_size x seq_length x embed_dim]
+ **kwargs (Dict): all additional layer args
+
+ Returns:
+ Tensor: output tensor with the same shape as the input
+ [batch_size x seq_length x embed_dim]
+
+ """
+ if self.fusion_first:
+ x = self.fusion_layer(x, **kwargs)
+ x = self.layer(x, **kwargs)
+ else:
+ x = self.layer(x, **kwargs)
+ x = self.fusion_layer(x, **kwargs)
+ return x
+
+
+class FusionEmbedding(nn.Module):
+ """Fusion embedding supports training additional special tokens while keeping
+ the original embedding frozen. When fusing new models with a language model,
+ there may be some additional tokens needed to support the fused language model. For
+ example, adding a vision encoder might necessitate additional tokens like ``<|image|>``
+ to indicate an image's position in text and require learning an embedding for this token.
+ The FusionEmbedding keeps the original embeddings frozen while learning a much smaller
+ second embedding for the additional tokens. During forward, this module routes
+ the tokens to the appropriate embedding table.
+
+ Use this as a drop-in replacement for :class:`torch.nn.Embedding` in your model.
+
+ Example:
+ >>> embedding = FusionEmbedding(vocab_size=100, fusion_vocab_size=10, embed_dim=128)
+ >>> model = TransformerDecoder(tok_embeddings=embedding, ...)
+ >>>
+ >>> # Original model state_dict still works
+ >>> model.load_state_dict(..., strict=False)
+
+ .. note::
+ This module assumes all tokens in the range [0, vocab_size) are part of the
+ original embedding table and all new tokens in the range
+ [vocab_size, vocab_size + fusion_vocab_size) are part of the fusion embedding table.
+
+ Args:
+ vocab_size (int): language model vocab size
+ fusion_vocab_size (int): additional tokens for the fused model
+ embed_dim (int): embedding dimension of the two embedding tables
+ """
+
+ def __init__(self, vocab_size: int, fusion_vocab_size: int, embed_dim: int) -> None:
+ super().__init__()
+ self.embedding = nn.Embedding(vocab_size, embed_dim)
+ self.fusion_embedding = nn.Embedding(fusion_vocab_size, embed_dim)
+ self.dim = embed_dim
+ self.num_embeddings = vocab_size + fusion_vocab_size
+ # TODO: Support merging the embeddings after finetuning
+
+ # Keep FusionEmbedding wrappings out of the state_dict
+ self._register_state_dict_hook(FusionEmbedding._state_dict_hook)
+ self._register_load_state_dict_pre_hook(
+ FusionEmbedding._load_state_dict_hook, with_module=True
+ )
+ # TODO: Switch to register_load_state_dict_pre_hook and
+ # register_state_dict_pre_hook after PyTorch v2.5
+
+ def _state_dict_hook(self, destination, prefix, keep_vars):
+ """Remove "embedding" from the original embedding in the state_dict
+ name. This keeps the original state dict name for the embedding
+ from before fusing with the FusionEmbedding.
+
+ [!Note] This update changes the order of the OrderedDict
+ """
+ key = prefix + "embedding.weight"
+ new_key = prefix + "weight"
+ destination[new_key] = destination[key]
+ del destination[key]
+
+ def _load_state_dict_hook(self, state_dict, prefix, *args, **kwargs):
+ """Apply extra "embedding" prefix to the state_dict key to
+ account for the FusionEmbedding wrapping.
+ """
+ if state_dict:
+ key = prefix + "weight"
+ new_key = prefix + "embedding.weight"
+ state_dict[new_key] = state_dict[key]
+ del state_dict[key]
+
+ def fusion_params(self) -> List[str]:
+ """
+ Return fusion embedding parameters.
+ """
+ fusion_params = ["fusion_embedding.weight"]
+ return fusion_params
+
+ def _fused_embed(self, bs, seq_len):
+ """
+ Return an empty tensor with the shape of the combined embedding.
+ """
+ device = self.embedding.weight.device
+ dtype = self.embedding.weight.dtype
+ return torch.empty(bs, seq_len, self.dim, device=device, dtype=dtype)
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ input (torch.Tensor): input integer tensor with shape
+ [batch_size x seq_length]
+
+ Returns:
+ Tensor: output tensor embedding with shape
+ [batch_size x seq_length x embed_dim]
+
+ """
+ bs, seq_len = input.size()
+ vocab_size = self.embedding.num_embeddings
+
+ mask = input < vocab_size
+ # num_tokens = (input < vocab_size).sum()
+ tokens = torch.masked_select(input, mask)
+ # num_fusion_tokens = (input >= vocab_size).sum()
+ fusion_tokens = torch.masked_select(input, ~mask) - vocab_size
+
+ # [batch_size * num_tokens, embed_dim]
+ embeds = self.embedding(tokens)
+ # [batch_size * num_fusion_tokens, embed_dim]
+ fusion_embeds = self.fusion_embedding(fusion_tokens)
+
+ # [batch_size x seq_length x embed_dim]
+ out = self._fused_embed(bs, seq_len)
+ mask = mask.unsqueeze(-1).expand(bs, seq_len, self.dim)
+ out = out.masked_scatter(mask, embeds)
+ out = out.masked_scatter(~mask, fusion_embeds)
+ return out
diff --git a/torchtune/modules/model_fusion/_fusion_utils.py b/torchtune/modules/model_fusion/_fusion_utils.py
index c22cc03549..e10bfcb3e5 100644
--- a/torchtune/modules/model_fusion/_fusion_utils.py
+++ b/torchtune/modules/model_fusion/_fusion_utils.py
@@ -65,5 +65,5 @@ def get_fusion_params(model: nn.Module) -> Dict[str, nn.Parameter]:
current_fusion_params.remove(n)
assert (
current_fusion_params == []
- ), f"Fusion params {current_adapter_params} not converted"
+ ), f"Fusion params {current_fusion_params} not converted"
return fusion_params
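For reviewers who want to sanity-check the early-fusion merge performed by the new ``EarlyFusionModel.forward``, here is a minimal standalone sketch in plain PyTorch. The sizes, the ``image_token_id``, and the ``nn.Linear`` stand-in encoder are illustrative assumptions rather than torchtune API; only the ``masked_scatter`` pattern mirrors the patch: embed the text tokens, then overwrite the positions holding an encoder's special token with that encoder's output embeddings.

```python
import torch
from torch import nn

# Illustrative sizes and ids -- assumptions, not torchtune defaults.
vocab_size, embed_dim, image_token_id = 100, 16, 0
tok_embeddings = nn.Embedding(vocab_size, embed_dim)
image_encoder = nn.Linear(32, embed_dim)  # stand-in for a real vision encoder

tokens = torch.randint(1, vocab_size, (2, 8))  # [bsz, seq_len]
tokens[:, 2] = image_token_id                  # one image placeholder per sample
images = torch.randn(2, 1, 32)                 # [bsz, num_image_tokens, feat_dim]

# Embed only the text positions with the decoder's token embedding.
is_text = tokens != image_token_id             # [bsz, seq_len]
text_embeds = tok_embeddings(tokens[is_text])  # [num_text, embed_dim]

# Scatter text embeddings, then encoder embeddings, into one fused tensor.
fused = torch.empty(2, 8, embed_dim, dtype=text_embeds.dtype)
fused = fused.masked_scatter(is_text.unsqueeze(-1), text_embeds)
image_embeds = image_encoder(images).view(-1, embed_dim)  # [bsz * num_image_tokens, embed_dim]
fused = fused.masked_scatter(~is_text.unsqueeze(-1), image_embeds)

# `fused` now plays the role of tok_embeddings(tokens) for the decoder.
assert fused.shape == (2, 8, embed_dim)
```

Because the merge happens purely at the embedding level, the decoder needs no architectural changes, which is the main contrast with the DeepFusion/cross-attention path.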
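The state_dict hooks moved into ``_fusion_layers.py`` exist so that pretrained checkpoints keep loading after wrapping. A small, hedged check of that behavior, using ``nn.Linear`` stand-ins instead of real transformer layers and assuming ``FusionLayer`` remains importable from ``torchtune.modules.model_fusion``:

```python
from torch import nn
from torchtune.modules.model_fusion import FusionLayer

# Stand-in modules; real usage wraps decoder self-attention / cross-attention layers.
fused = FusionLayer(layer=nn.Linear(8, 8), fusion_layer=nn.Linear(8, 8))

keys = set(fused.state_dict().keys())
# The wrapped layer keeps its original parameter names ("weight", "bias"),
# while only the new fusion parameters are namespaced under "fusion_layer.".
assert keys == {"weight", "bias", "fusion_layer.weight", "fusion_layer.bias"}

# fusion_params() exposes exactly the new parameters for trainability control.
assert set(fused.fusion_params()) == {"fusion_layer.weight", "fusion_layer.bias"}
```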
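Similarly, ``FusionEmbedding`` routes ids below ``vocab_size`` to the frozen base table and ids in ``[vocab_size, vocab_size + fusion_vocab_size)`` to the small trainable table, then scatters both back into a single output. A minimal sketch, assuming the import path exposed by this patch's ``__init__.py``:

```python
import torch
from torchtune.modules.model_fusion import FusionEmbedding

embed = FusionEmbedding(vocab_size=100, fusion_vocab_size=2, embed_dim=16)

# ids 0-99 hit the (typically frozen) base table; 100 and 101 hit the fusion table.
tokens = torch.tensor([[3, 7, 100], [101, 5, 9]])
out = embed(tokens)
assert out.shape == (2, 3, 16)

# Only the fusion table is reported as a fusion (trainable) parameter.
assert embed.fusion_params() == ["fusion_embedding.weight"]
```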