-
Notifications
You must be signed in to change notification settings - Fork 228
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Use stateful dataloader to checkpoint data iteration order and token buffer #279
Conversation
Looks like adding the index url (for torchdata) is causing other dependencies to not get installed. Will figure out how to fix this |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
First pass looks awesome already. Let me see if I can run some large scale experiments. What is the earliest nightly that's OK to run with this PR?
@gokulavasan, @tianyu-l pytorch/pytorch#125335 and pytorch/pytorch#125334 should unblock this PR. |
d09fcfb
to
75cb1d9
Compare
I've been testing this out, and ran into an issue with resuming from a checkpoint. I suspect it's because of how That is, a freshly initialized Edit: investigated a bit further, and indeed I get that |
I had already added that in my version. I can't get it to load the state_dict, unless I first call If I call |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great! Had some minor comments.
For next steps, we can discuss if we should add a unit test to guard the correctness of checkpointable data loading, and the plan to migrate to DTensor-based checkpointing.
requirements.txt
Outdated
@@ -1,5 +1,5 @@ | |||
torch >= 2.2.0.dev |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we need to put the torchdata
dependency here, maybe try
torch >= 2.2.0.dev | |
torch >= 2.2.0.dev | |
--find-links https://download.pytorch.org/whl/nightly/cpu/ | |
torchdata >= 0.7.1.dev20240426+cpu |
If works, we can remove the dependency in unit test workflow.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I attempted thus but it fails to find the nightly, instead installs the latest released version. We do plan to do a torchdata release soon (which will contain the latest StatefulDataLoader) and once that happens, we can remove the explicit pip installs in unit test workflows and main README.
ae5f139
to
4f7c08c
Compare
@tianyu-l Addressed PR comments (thank you!), added unit test, and made changes to the github workflows to allow running those unit tests. Let me know if the changes look okay. Regarding the move to DTensor, I think this requires analysis of what the benefits are (especially for storing unstructured state dict of dataloader). If it is purely to reduce the replication of state across tensor/pipeline parallel groups, I think we can store the dataloader state just for the dp worker ranks (by using key as the dp_rank_id) and load it back instead of storing it for all global ranks. For now, with just the text tokens, this might not even be necessary as the state is not that big. Let me know how you would like to proceed. |
@rlrs Thank you for your great analysis here (#279 (comment)). Helped us narrow down the issue which basically boiled down to in-place loading of checkpoint of DCP. StatefulDataLoader doesn't currently return no state if dataloader iterator is not created while DCP expects the module it let it know what the keys the module is expecting. In order to get around this, I serialized the state of the dataloader and in this case there is only one key to load that is communicated by the DataLoaderWrapper to DCP - "<rank_id>". |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can store the dataloader state just for the dp worker ranks (by using key as the dp_rank_id) and load it back instead of storing it for all global ranks.
This sounds quite good! @fegin would checkpointing behave in this expected way? I.e. if we use the same key for the same TP ranks, but different keys for different DP ranks, would it avoid saving extra copies and load correctly? If that's the case I agree we don't have to use DTensor for now.
344a48d
to
8a217b6
Compare
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
8a217b6
to
9b07bb9
Compare
80cefc0
to
5f825c7
Compare
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
5f825c7
to
c1a49fb
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks awesome! Thanks for the beautiful work!
Please address inline comments before merging.
@@ -31,5 +31,6 @@ jobs: | |||
pip config --user set global.progress_bar off | |||
|
|||
python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 | |||
python -m pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we need this in .github/workflows/integration_test_periodic.yaml
as well.
Let's create an issue, tracking that we need to put torchdata in requirements.txt
and pyproject.toml
after the needed changes ship in an official release.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@tianyu-l Should pyproject.toml datasets dependency also enforce >= 2.19.0 version requirement?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ah yes I think so
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Created #351 to remove torchdata nightly pip install
…buffer (#279) Summary: Use the stateful_dataloader from torchdata (https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader) for storing the token buffer and iteration data order. It requires a dependency on the nightly build of torchdata >= 20240426. Also make sure the dataloader state has a different key per rank. Test Plan: Tested locally by first running 30 steps (checkpointing every 5 steps) and capturing all the loss values. Then deleting the last 3 checkpoints and then re-run the training and the loss values from step 16-30 match with what we had earlier in the first run. Note that this requires changes in the train.py to enable a deterministic run. Reviewers: @tianyu-l Subscribers: @andrewkho Tasks: Tags:
Hi ! I'm Quentin from HF :) |
…buffer (pytorch#279) Summary: Use the stateful_dataloader from torchdata (https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader) for storing the token buffer and iteration data order. It requires a dependency on the nightly build of torchdata >= 20240426. Also make sure the dataloader state has a different key per rank. Test Plan: Tested locally by first running 30 steps (checkpointing every 5 steps) and capturing all the loss values. Then deleting the last 3 checkpoints and then re-run the training and the loss values from step 16-30 match with what we had earlier in the first run. Note that this requires changes in the train.py to enable a deterministic run. Reviewers: @tianyu-l Subscribers: @andrewkho Tasks: Tags:
…buffer (#279) Summary: Use the stateful_dataloader from torchdata (https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader) for storing the token buffer and iteration data order. It requires a dependency on the nightly build of torchdata >= 20240426. Also make sure the dataloader state has a different key per rank. Test Plan: Tested locally by first running 30 steps (checkpointing every 5 steps) and capturing all the loss values. Then deleting the last 3 checkpoints and then re-run the training and the loss values from step 16-30 match with what we had earlier in the first run. Note that this requires changes in the train.py to enable a deterministic run. Reviewers: @tianyu-l Subscribers: @andrewkho Tasks: Tags:
…buffer (pytorch#279) Summary: Use the stateful_dataloader from torchdata (https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader) for storing the token buffer and iteration data order. It requires a dependency on the nightly build of torchdata >= 20240426. Also make sure the dataloader state has a different key per rank. Test Plan: Tested locally by first running 30 steps (checkpointing every 5 steps) and capturing all the loss values. Then deleting the last 3 checkpoints and then re-run the training and the loss values from step 16-30 match with what we had earlier in the first run. Note that this requires changes in the train.py to enable a deterministic run. Reviewers: @tianyu-l Subscribers: @andrewkho Tasks: Tags:
Summary:
Use the stateful_dataloader from torchdata (https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader) for storing the token buffer and iteration data order. It requires a dependency on the nightly build of torchdata >= 20240426.
Also make sure the dataloader state has a different key per rank.
Test Plan:
Tested locally by first running 30 steps (checkpointing every 5 steps) and capturing all the loss values. Then deleting the last 3 checkpoints and then re-run the training and the loss values from step 16-30 match with what we had earlier in the first run. Note that this requires changes in the train.py to enable a deterministic run.
Reviewers: @tianyu-l
Subscribers: @andrewkho
Tasks:
Tags: