Refactor Recipe State Dict Code #1964
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1964
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 5a18094 with merge base 24d3579. This comment was automatically generated by Dr. CI and updates every 15 minutes.
This looks really awesome overall.
My main concern is whether this could cause any differences in memory usage and/or speed. Can we confirm that it doesn't?
"Please call get_full_model_state_dict(..., device=self._device)," | ||
" so DTensor can communicate over NCCL." | ||
for param_name, sharded_param in sharded_sd.items(): | ||
if sharded_param.is_cpu: |
What's happening here?
I assume this is for CPU offload or something?
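For context, a rough sketch of what that CPU-offload path looks like, pieced together from the quoted snippets (`device`, `is_rank_zero`, and `cpu_state_dict` are assumed to be the surrounding variables, so treat this as an illustration rather than the exact PR code):

```python
# Sketch: with CPU offload enabled, each FSDP shard lives on CPU, but NCCL can
# only all-gather CUDA tensors. Move the shard to `device` first, materialize
# the full tensor, then copy the result back to CPU on rank zero only.
for param_name, sharded_param in sharded_sd.items():
    if sharded_param.is_cpu:
        sharded_param = sharded_param.to(device)  # hence the error message above
    full_param = sharded_param.full_tensor()      # DTensor all-gather
    if is_rank_zero:
        cpu_state_dict[param_name] = full_param.cpu()
    else:
        del full_param
```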
    # skip non-trainable params when trainable_only is True
    continue
if isinstance(sharded_param._local_tensor, NF4Tensor):
    # NF4Tensor does not support all_gather from DTensor
Lol why does it not support all_gather? Can't we ask AO to support that?
Yeah let's open an issue there
We want AO to support it, but it would take too long to get on stable so we have to do it ourselves in the meantime.
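Until then, the interim branch presumably has roughly this shape (`manually_gather_nf4` is a hypothetical helper name used only for illustration; the actual reassembly appears to gather each NF4 inner tensor, as in the dim-0 sketch further down):

```python
# Sketch of the workaround: DTensor's full_tensor() cannot all-gather an
# NF4Tensor, so that case is handled manually instead.
if isinstance(sharded_param._local_tensor, NF4Tensor):
    full_param = manually_gather_nf4(sharded_param)  # hypothetical helper
else:
    full_param = sharded_param.full_tensor()
```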
    cpu_state_dict[param_name] = full_param.cpu()
else:
    del full_param
torch.distributed.barrier()
yay
@@ -38,30 +39,30 @@
 class DummyAdapterModule(nn.Module, AdapterModule):
     def __init__(self, in_dim, out_dim):
         super().__init__()
-        self.adapter = nn.Linear(in_dim, out_dim, bias=False)
+        self.lora = nn.Linear(in_dim, out_dim, bias=False)
So as of now, what is the actual value of AdapterModule? Is it just for setting trainable params?
I think AdapterModule is the right way to go, but since we are already ignoring it for ckpt merging, I'm not really changing anything by not using it for get_adapter_state_dict. I think we should move to using AdapterModule for all of these functions but that doesn't need to be solved in this PR.
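For reference, a minimal sketch of what an AdapterModule-based version might look like (this is not what the PR does; the function name and key construction here are assumptions, and state_dict hooks could still rename keys, which is exactly the problem this PR works around):

```python
# Sketch only: derive adapter keys from the AdapterModule protocol instead of
# matching "lora"/"magnitude" substrings in the key names.
def get_adapter_state_dict_via_protocol(model, state_dict):
    adapter_keys = set()
    for module_name, module in model.named_modules():
        if hasattr(module, "adapter_params"):  # AdapterModule protocol
            for param_name in module.adapter_params():
                key = f"{module_name}.{param_name}" if module_name else param_name
                adapter_keys.add(key)
    return {k: v.cpu() for k, v in state_dict.items() if k in adapter_keys}
```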
torchtune/modules/peft/_utils.py (Outdated)
""" | ||
adapter_key_filter = lambda x: "lora" in x or "magnitude" in x | ||
return {k: v.cpu() for k, v in state_dict.items() if adapter_key_filter(k)} |
Do we want to make the move to CPU optional here?
If it's already on CPU this is a no-op.
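If the CPU move ever did need to be configurable, a small `device` argument would cover it; the following is just a sketch, not part of this PR:

```python
def get_adapter_state_dict(state_dict, device="cpu"):
    # Same filter as the quoted snippet; device=None leaves tensors where they are.
    adapter_key_filter = lambda x: "lora" in x or "magnitude" in x
    return {
        k: (v if device is None else v.to(device))
        for k, v in state_dict.items()
        if adapter_key_filter(k)
    }
```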
d0, *dn = quant_param.shape
shape = (d0 * mesh.get_group().size(), *dn)
So this means sharding is always along the first dimension? If so, it might be worth leaving a comment to that effect.
Wei said FSDP always shards on dim 0
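That matches the quoted lines: the full shape is recovered by scaling dim 0 of the local shard by the process-group size. A generic sketch of that gather, assuming even dim-0 sharding (`gather_dim0_sharded` is an assumed helper name):

```python
import torch
import torch.distributed as dist

def gather_dim0_sharded(local: torch.Tensor, group) -> torch.Tensor:
    # Each rank holds an equally sized dim-0 slice, so the full dim 0 is
    # local_dim0 * world_size; all_gather_into_tensor concatenates along dim 0.
    d0, *dn = local.shape
    out = torch.empty((d0 * group.size(), *dn), dtype=local.dtype, device=local.device)
    dist.all_gather_into_tensor(out, local.contiguous(), group=group)
    return out
```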
Thank you for wading through this minefield and fixing this! And thanks for adding the new test case as well.
Context
What is the purpose of this PR?
In our recipes we have four different ways of collecting model/adapter state_dicts.
This causes issues when models use state_dict hooks, which change the expected parameter names. It came up before with activation checkpointing, where it was sidestepped; with fusion models it is an issue again, since they rely on state dict hooks to operate. To address these known issues and make future support easier, this PR consolidates all of our checkpoint code on the state_dict API, so we always get a consistent set of names. The PR takes two primary approaches.
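To make the naming problem concrete, here is a minimal, self-contained illustration (not torchtune code) of how a state_dict hook makes `state_dict()` keys diverge from `named_parameters()` keys, which is why filtering on parameter names breaks:

```python
import torch.nn as nn

class Wrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = nn.Linear(4, 4)
        # Rename keys on the way out, the way fusion models remap their modules.
        self._register_state_dict_hook(self._rename_hook)

    @staticmethod
    def _rename_hook(module, state_dict, prefix, local_metadata):
        for key in list(state_dict.keys()):
            state_dict[key.replace("inner.", "layer.")] = state_dict.pop(key)
        return state_dict

m = Wrapper()
print([name for name, _ in m.named_parameters()])  # ['inner.weight', 'inner.bias']
print(list(m.state_dict().keys()))                 # ['layer.weight', 'layer.bias']
```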
Changelog
What are the changes made in this PR?
Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.
- `pre-commit install`
- `pytest tests`
- `pytest tests -m integration_test`
I will update here with an overview of all the updated recipes, showing that memory and save time don't change with the checkpoint update.
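A sketch of how that comparison could be collected per recipe (the `checkpointer.save_checkpoint(...)` call is a placeholder for whatever the recipe actually invokes):

```python
import time
import torch

torch.cuda.reset_peak_memory_stats()
start = time.perf_counter()
checkpointer.save_checkpoint(state_dict, epoch=0)  # placeholder for the recipe's save call
elapsed = time.perf_counter() - start
peak_gib = torch.cuda.max_memory_allocated() / 1024**3
print(f"save time: {elapsed:.1f}s, peak CUDA memory: {peak_gib:.2f} GiB")
```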