Early fusion multimodal models #1904

Merged: 15 commits into pytorch:main on Nov 9, 2024
Conversation

@RdoubleA (Contributor) commented Oct 25, 2024

Context

This enables Early Fusion models based on @pbontrager's excellent original RFC on multimodal fusion models #1283. Since the RFC, we have already landed the Deep Fusion model components. This PR discusses and implements the EarlyFusionModel component, along with testing and some lint updates.

Early fusion is simply a decoder with one or more extra encoders whose outputs are merged with the output of the token embedding table. The challenge lies in how we merge the embeddings and pass them into the decoder.

Changelog

  • Added EarlyFusionModel and tests
  • Updated the DeepFusionModel docstring; none of the code was touched except for fixing an incorrect type annotation
  • Split _fusion.py into _fusion_layers.py, _early_fusion.py, and _deep_fusion.py
  • Small update in peft utils

Design

There is one design consideration I am seeking feedback on: the EarlyFusionModel's usage of self.decoder.tok_embeddings. It accesses the decoder's token embedding table outside of the decoder forward because we need to merge the image encoder's (and any other modality encoder's) output embeddings with the text embeddings (in this case, by placing them along the sequence dimension at the encoder special-token positions):

embeds = self.tok_embeddings(tokens)
bsz, seq_len, embed_dim = embeds.shape
for encoder, inp in (encoder_input or {}).items():
    encoder_embeds = self.encoders[encoder](**inp)
    # Boolean mask of the encoder special-token positions, broadcast to embed_dim
    encoder_mask = (tokens == self.encoder_tokens[encoder]).unsqueeze(-1).expand(bsz, seq_len, embed_dim)
    # Overwrite the text embeddings at those positions with the encoder outputs
    embeds[encoder_mask] = encoder_embeds.reshape(-1)

output = self.decoder(embeds, mask, input_pos)
return output

Now, instead of token ids, we pass the merged embeddings directly into the decoder. But since we have already applied the decoder's text-only tok_embeddings, we need to skip that module when the merged embeddings are passed in for the final decoder output. There are two ways we can do this.

State dict surgery

As implemented in the current code changes, and as suggested by the original RFC, we can manually set self.decoder.tok_embeddings = nn.Identity() so that it becomes a no-op when you forward pass with merged embeddings.

  • This will require additional state dict hooks to make sure checkpoint saving and loading still work despite the module change (see the sketch after this list)
  • If a user wants to use the decoder outside of the EarlyFusionModel in the same script, they will need to restore the original tok_embeddings module from nn.Identity
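
For concreteness, here is a rough, hypothetical sketch of what such hooks could look like. It uses PyTorch's private nn.Module hook APIs (_register_state_dict_hook and _register_load_state_dict_pre_hook); the class skeleton and key names are illustrative assumptions, not the exact torchtune implementation:

import torch.nn as nn

class EarlyFusionModel(nn.Module):
    """Illustrative skeleton only; assumes the decoder exposes a tok_embeddings module."""

    def __init__(self, decoder: nn.Module, encoders: dict):
        super().__init__()
        self.decoder = decoder
        self.encoders = nn.ModuleDict(encoders)
        # Keep a live reference to the real embedding table, then no-op it
        # inside the decoder so the decoder forward skips token embedding.
        self.tok_embeddings = decoder.tok_embeddings
        self.decoder.tok_embeddings = nn.Identity()
        # Re-map keys on save/load so checkpoints keep the original
        # "decoder.tok_embeddings.weight" key despite the module swap.
        self._register_state_dict_hook(self._save_hook)
        self._register_load_state_dict_pre_hook(self._load_hook)

    @staticmethod
    def _save_hook(module, state_dict, prefix, *args):
        # Rename the wrapper's embedding weight back to the decoder's key on save.
        weight = state_dict.pop(prefix + "tok_embeddings.weight")
        state_dict[prefix + "decoder.tok_embeddings.weight"] = weight
        return state_dict

    def _load_hook(self, state_dict, prefix, *args):
        # Route a text-only "decoder.tok_embeddings.weight" key to the wrapper on load.
        key = prefix + "decoder.tok_embeddings.weight"
        if key in state_dict:
            state_dict[prefix + "tok_embeddings.weight"] = state_dict.pop(key)

With hooks along these lines, a checkpoint saved before and after the module swap keeps identical keys, which is what the test plan below exercises.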

Additional input_embeds kwarg

We could add a new keyword argument in TransformerDecoder forward for input embeddings. If this is passed in, we automatically skip the token embeddings:

h = self.tok_embeddings(tokens) if input_embeds is None else input_embeds

This way we don't need any state dict hooks or module swapping on the decoder. However, we are polluting the decoder model forward with more arguments.
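
A minimal sketch of what this option could look like, assuming a torchtune-style TransformerDecoder with layers, norm, and output attributes; the argument names and keyword-only layout here are illustrative assumptions, not a final signature:

from typing import Optional

import torch

def forward(
    self,
    tokens: torch.Tensor,
    *,
    mask: Optional[torch.Tensor] = None,
    input_pos: Optional[torch.Tensor] = None,
    input_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # If the caller already merged encoder outputs into the embeddings,
    # skip the token embedding lookup entirely.
    h = self.tok_embeddings(tokens) if input_embeds is None else input_embeds
    for layer in self.layers:
        h = layer(h, mask=mask, input_pos=input_pos)
    return self.output(self.norm(h))

The EarlyFusionModel would then call self.decoder(tokens, mask=mask, input_pos=input_pos, input_embeds=fused_embeds) and leave the decoder's modules untouched.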

Test plan

  • Add forward pass decoder + encoders test
  • Add forward pass decoder only test
  • Add forward pass encoders-only test; check that embeddings are merged at the correct sequence positions
  • Verify state dict hooks by loading a state dict with "decoder" labels then saving the same state dict
  • Verify that variable number of encoder special tokens in each sample still works as expected


pytorch-bot bot commented Oct 25, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1904

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit d0b1ab0 with merge base eb67cc5:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Oct 25, 2024
@joecummings joecummings added the rfc Request for comments label Oct 25, 2024
@pbontrager (Contributor) left a comment

Thanks for putting this up Rafi! I left some comments on the implementation, but I'll leave the state dict discussion to others as we've already chatted on this.

@acisseJZhong

Thanks for the RFC, you made it very clear what the difference is between early fusion and late fusion!

About the design choice, I personally prefer Option 2 for the same reasons you mentioned. I think it's fine to "pollute" the decoder model forward a bit with some optional arguments for each modality. We might need something like:

h = self.tok_embeddings(tokens)
if speech:
    h[speech_mask] += speech_encoder(input)
if image:
    h[image_mask] += image_encoder(input)

@ebsmothers (Contributor)

11th hour comment on the open design question: in my mind there are nonzero UX costs to either approach. If we patch the decoder embeddings to nn.Identity we introduce additional indirection that is pretty trivial but also pretty non-obvious (I claim any state dict hook is non-obvious when first debugging the inevitable key mismatch error until you find the actual code pointer). On the plus side, we fully contain the blast radius to multimodal model code, and text-only users do not have to worry about it. Conversely, I know we don't want to just add a bunch of random arguments to TransformerDecoder forward, especially ones that are very specific to multimodal models.

Personally I really don't like state dict hooks for the reason I described above. As soon as something (inevitably) goes wrong, it will take a lot more debugging and head-banging-against-the-wall before the user realizes that things are being swapped out under the hood. So perhaps it's no surprise, but I vote for the simple and dumb thing: just add an extra parameter to TransformerDecoder forward. I know that may be controversial, but I like doing the obvious thing, and I like to think our users would appreciate that as well.

@RdoubleA RdoubleA changed the title [RFC] Early fusion multimodal models Early fusion multimodal models Nov 3, 2024
@RdoubleA (Contributor, Author) commented Nov 3, 2024

After extensive discussion offline, we decided to move ahead with the state dict hook approach. All current changes reflect this.

@SalmanMohammadi (Collaborator)

After extensive discussion offline, we decided to move ahead with the state dict hook approach. All current changes reflect this.

What were the high level reasons for this, if I may ask?

@RdoubleA (Contributor, Author) commented Nov 4, 2024

What were the high level reasons for this, if I may ask?

  1. Other fusion layers already use state dict hooks extensively, so there is some precedent
  2. Reluctance to add more bloat to the TransformerDecoder forward
  3. Other solutions would require decoders to be built differently to account for a fused embedding (i.e., a FusedEmbedding layer instead of nn.Embedding for tok_embeddings; see the sketch below), which somewhat defeats the purpose of a separate EarlyFusionModel class

cc @ebsmothers @pbontrager
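
For context on reason 3, a FusedEmbedding along those lines might look roughly like the following. This is a purely hypothetical sketch (no such layer exists in torchtune), shown only to illustrate why decoders would have to be built against it instead of nn.Embedding:

import torch
import torch.nn as nn

class FusedEmbedding(nn.Module):
    """Hypothetical replacement for nn.Embedding that accepts pre-computed encoder embeddings."""

    def __init__(self, vocab_size: int, embed_dim: int, encoder_token_id: int):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, embed_dim)
        self.encoder_token_id = encoder_token_id

    def forward(self, tokens: torch.Tensor, encoder_embeds: torch.Tensor = None) -> torch.Tensor:
        embeds = self.tok_embeddings(tokens)
        if encoder_embeds is not None:
            # Replace the rows at encoder special-token positions, in order.
            mask = (tokens == self.encoder_token_id).unsqueeze(-1)
            embeds = embeds.masked_scatter(mask, encoder_embeds.reshape(-1, embeds.size(-1)))
        return embeds

Every decoder builder would need to construct tok_embeddings as this layer and thread encoder outputs through to it, pushing fusion logic into the decoder itself rather than keeping it contained in EarlyFusionModel.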

@RdoubleA RdoubleA mentioned this pull request Nov 6, 2024
Comment on lines 367 to 370
if len(encoders.keys()) != 1:
    raise ValueError(
        f"DeepFusionModel only supports a single encoder. Got {len(encoders.keys())} encoders."
    )
Contributor:

Just wondering: why do we generalize encoder -> encoders now if we aren't ready to support multiple encoders yet anyways? Seems to me it'd be better to just make that move all at once in a separate PR. I would think we're not strictly required to have matching signatures for DeepFusion and EarlyFusion classes, is that incorrect?

Contributor Author:

It was mainly to maintain a consistent API between the two, but I don't have strong opinions here. I don't see any hard requirement to make the signatures match.

Comment on lines 567 to 568
>>> # Load full fused checkpoints
>>> model.load_state_dict(...)
Contributor:

nit: I'm not sure this is especially helpful (maybe I'm missing the point though)

Contributor Author:

If the checkpoint is a single file, you can load the entire model with encoders all at once. But I'm not sure if this will be the case for a model with multiple encoders, or what the checkpoint UX would look like. I'm OK with removing this until we know for sure.

# [bsz, seq_len, 1]
encoder_mask = (tokens == self.encoder_tokens[encoder]).unsqueeze(-1)
# At locations where encoder token is found, replace with encoder embedding
fused_embeds = fused_embeds.masked_scatter(encoder_mask, encoder_embeds)
Contributor:

Sorry, I might be missing the point here... is this changing the shape of fused_embeds?

@RdoubleA (Contributor, Author) commented Nov 8, 2024:

It is not; it just places the encoder embedding vectors at each instance of that encoder token in fused_embeds. The embeddings have the same hidden dim, but the number of encoder embeddings is less than the number of fused embeddings.
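
A tiny toy check of that behavior (illustrative shapes only, not torchtune code):

import torch

fused_embeds = torch.zeros(1, 4, 3)          # [bsz=1, seq_len=4, embed_dim=3]
tokens = torch.tensor([[5, 99, 7, 99]])      # 99 stands in for the encoder special token
encoder_embeds = torch.ones(1, 2, 3)         # two encoder embeddings for the two 99s
encoder_mask = (tokens == 99).unsqueeze(-1)  # [bsz, seq_len, 1]

out = fused_embeds.masked_scatter(encoder_mask, encoder_embeds)
assert out.shape == fused_embeds.shape        # shape is unchanged
assert torch.equal(out[0, 1], torch.ones(3))  # rows at encoder tokens are replaced
assert torch.equal(out[0, 0], torch.zeros(3)) # other rows are untouched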

# [bsz * num_encoder_tokens, embed_dim]
encoder_embeds = encoder_embeds.view(-1, embed_dim)
# [bsz, seq_len, 1]
encoder_mask = (tokens == self.encoder_tokens[encoder]).unsqueeze(-1)
Contributor:

Do we need to do any validation on encoder_tokens in the model? E.g. what if we have image embeddings but there is no image token in the token sequence? Do we expect that to be handled in the dataset? If so, we should probably call it out in the documentation somewhere

Contributor Author:

Good point, need to think about where this should be asserted. I would say probably in the transform or the dataset. We wouldn't want to forward pass the encoder if there's nowhere to use it. Although, within a batch you can have variable number of images per sample, so one sample may have zero images and another may have two.

Contributor Author:

I've added a ValueError just in case

Contributor Author:

Actually... sometimes the number of encoder embeddings and the number of encoder tokens will not match, because there could be padding images. So the dataset is probably the best place to assert this. I will update the docstring here to call this out, though.
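
A hypothetical dataset- or transform-side check along these lines (the function name and placement are assumptions, not the actual torchtune code):

def validate_image_tokens(tokens: list, num_real_images: int, image_token_id: int) -> None:
    """Raise if the count of image special tokens doesn't match the non-padding images."""
    num_image_tokens = sum(1 for t in tokens if t == image_token_id)
    if num_image_tokens != num_real_images:
        raise ValueError(
            f"Sample has {num_real_images} image(s) but {num_image_tokens} image token(s); "
            "each non-padding image must correspond to exactly one image special token."
        )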

@RdoubleA RdoubleA merged commit 550163b into pytorch:main Nov 9, 2024
17 checks passed
@RdoubleA RdoubleA deleted the early_fusion branch November 9, 2024 02:28
@ebsmothers ebsmothers mentioned this pull request Nov 26, 2024