
Refactor Recipe State Dict Code #1964

Merged: 3 commits into pytorch:main on Nov 9, 2024

Conversation

pbontrager (Contributor)

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

In our recipes we have 4 different ways that we collect model/adapter state_dicts:

  1. single device: model.state_dict
  2. lora single device: {k: v.cpu() for k,v in self.adapter_params}
  3. distributed: get_full_model_state_dict using sharded_sd
  4. distributed qlora: get_full_model_state_dict using model.named_modules
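
To illustrate why collecting params outside of state_dict() can diverge from the names state_dict() produces (the hook issue described below), here is a toy sketch, not code from this PR, where a state_dict hook renames keys; it uses PyTorch's private _register_state_dict_hook purely for demonstration:

from torch import nn

class Wrapper(nn.Module):
    """Toy module whose state_dict hook strips the 'inner.' prefix,
    mimicking how fusion models remap key names on save."""

    def __init__(self):
        super().__init__()
        self.inner = nn.Linear(4, 4)
        # private API, used here only to demonstrate the renaming behavior
        self._register_state_dict_hook(self._strip_prefix)

    @staticmethod
    def _strip_prefix(module, state_dict, prefix, local_metadata):
        for key in list(state_dict.keys()):
            state_dict[key.replace("inner.", "")] = state_dict.pop(key)
        return state_dict

m = Wrapper()
print([k for k, _ in m.named_parameters()])  # ['inner.weight', 'inner.bias']
print(list(m.state_dict().keys()))           # ['weight', 'bias']: names differ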

Where this causes issues is when models use state_dict hooks, which change the expected names. This came up as an issue with activation checkpointing in the past but was sidestepped. Now, with fusion models, it's an issue again, since they rely on state dict hooks to operate. To address these known issues and make future support easier, this PR consolidates all of our checkpoint code to use the state_dict API so we'll always have a consistent set of names. This PR takes two primary approaches:

  1. Introduces get_adapter_state_dict to filter a full model state_dict using the same pattern-matching logic as get_merged_lora_ckpt. This incurs no extra cost even if save_adapter_only=True, since calling model.state_dict() doesn't copy params and is almost free.
  2. Replaces get_full_model_state_dict with gather_cpu_state_dict for distributed recipes. This takes a sharded state_dict as input, gathers each param, and copies it to CPU (see the sketch after this list). Notably, when there are NF4Tensors it still uses the state_dict (unlike the old function) but does a manual all_gather instead of calling full_tensor.
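
As a rough illustration of approach 2, here is a minimal sketch of the gather-to-CPU idea; the function name, signature, and NF4 handling are simplifications, not the actual gather_cpu_state_dict from this PR:

import torch.distributed as dist
from torch.distributed.tensor import DTensor  # torch.distributed._tensor on older releases

def gather_cpu_state_dict_sketch(sharded_sd, is_rank_zero, device):
    """Gather each (possibly sharded) param to full size on `device`,
    then keep a CPU copy on rank 0 only."""
    cpu_state_dict = {}
    for param_name, sharded_param in sharded_sd.items():
        if sharded_param.is_cpu:
            # params may be CPU-offloaded; move to device so NCCL can gather them
            sharded_param = sharded_param.to(device)
        if isinstance(sharded_param, DTensor):
            full_param = sharded_param.full_tensor()  # all-gather across the device mesh
        else:
            full_param = sharded_param
        if is_rank_zero:
            cpu_state_dict[param_name] = full_param.cpu()
        else:
            del full_param
        dist.barrier()
    return cpu_state_dict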

Changelog

What are the changes made in this PR?

  • Added get_adapter_state_dict
  • Replaced get_full_model_state_dict with gather_cpu_state_dict
  • Updated every recipe that uses peft or get_full_model_state_dict
  • Updated docs and tests

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

I will update here with an overview of all the updated recipes showing that memory and save time don't change with the checkpoint update.


pytorch-bot (bot) commented Nov 7, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1964

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 5a18094 with merge base 24d3579:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

facebook-github-bot added the CLA Signed label (authors need to sign the CLA before a PR can be reviewed) on Nov 7, 2024
joecummings (Contributor) left a comment:

This looks really awesome overall.

My main concern is around any differences in memory and/or speed this could cause. Can we confirm?

"Please call get_full_model_state_dict(..., device=self._device),"
" so DTensor can communicate over NCCL."
for param_name, sharded_param in sharded_sd.items():
    if sharded_param.is_cpu:
Contributor:
What's happening here?

Contributor:
I assume this is for CPU offload or something?

    # skip non-trainable params when trainable_only is True
    continue
if isinstance(sharded_param._local_tensor, NF4Tensor):
    # NF4Tensor does not support all_gather from DTensor
Contributor:
Lol why does it not support all_gather? Can't we ask AO to support that?

Contributor:
Yeah let's open an issue there

pbontrager (Author):
We want AO to support it, but it would take too long to get on stable so we have to do it ourselves in the meantime.

    cpu_state_dict[param_name] = full_param.cpu()
else:
    del full_param
torch.distributed.barrier()
Contributor:
yay

@@ -38,30 +39,30 @@
class DummyAdapterModule(nn.Module, AdapterModule):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.adapter = nn.Linear(in_dim, out_dim, bias=False)
        self.lora = nn.Linear(in_dim, out_dim, bias=False)
Contributor:
So as of now what is the actual value of AdapterModule? Is it just for setting trainable params?

pbontrager (Author):
I think AdapterModule is the right way to go, but since we are already ignoring it for ckpt merging, I'm not really changing anything by not using it for get_adapter_state_dict. I think we should move to using AdapterModule for all of these functions but that doesn't need to be solved in this PR.
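
For reference, a minimal sketch of the protocol as I read it (the adapter_params method name is taken from torchtune's peft utilities; the rest is illustrative): an AdapterModule advertises which of its parameters are adapter weights, and helpers like get_adapter_params use that to decide what is trainable.

from torch import nn

class SketchAdapterModule(nn.Module):
    """Hypothetical adapter module: adapter_params() names the params
    that should be treated as trainable adapter weights."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.adapter = nn.Linear(in_dim, out_dim, bias=False)
        self.base = nn.Linear(in_dim, out_dim, bias=False)

    def adapter_params(self):
        # only the adapter's weight is flagged; self.base stays frozen
        return ["adapter.weight"]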


"""
adapter_key_filter = lambda x: "lora" in x or "magnitude" in x
return {k: v.cpu() for k, v in state_dict.items() if adapter_key_filter(k)}
Contributor:
Do we want to make the move to CPU optional here?

pbontrager (Author):
If it's already on CPU this is a no-op.
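
For what it's worth, a quick check that Tensor.cpu() returns the tensor itself when it already lives on CPU, so the comprehension above doesn't copy anything in the single-device case:

import torch

t = torch.ones(2)      # already on CPU
assert t.cpu() is t    # returns the same tensor object, no copy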

Comment on lines +380 to +381
d0, *dn = quant_param.shape
shape = (d0 * mesh.get_group().size(), *dn)
Contributor:
So this means sharding is always along the first dimension? If so, we might just leave a comment to that effect.

pbontrager (Author):
Wei said FSDP always shards on dim 0
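
A minimal sketch of the manual dim-0 gather this implies; the helper name and exact call pattern are assumptions, not the PR's code. Since FSDP shards along dim 0, the full size is the local dim-0 size times the process-group size, and the shards can be collected with a plain all_gather_into_tensor:

import torch
import torch.distributed as dist

def gather_dim0_sketch(local_shard: torch.Tensor, mesh) -> torch.Tensor:
    """Manually all-gather a dim-0-sharded tensor (e.g. an NF4Tensor's
    quantized data) without going through DTensor.full_tensor()."""
    pg = mesh.get_group()
    d0, *dn = local_shard.shape
    full = local_shard.new_empty((d0 * pg.size(), *dn))
    dist.all_gather_into_tensor(full, local_shard.contiguous(), group=pg)
    return full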

ebsmothers (Contributor) left a comment:
Thank you for wading through this minefield and fixing this! And thanks for adding the new test case as well.

pbontrager merged commit 08efaed into pytorch:main on Nov 9, 2024
17 checks passed
joecummings pushed a commit that referenced this pull request Nov 11, 2024
ebsmothers mentioned this pull request on Nov 26, 2024