
Use stateful dataloader to checkpoint data iteration order and token buffer #279

Merged: 5 commits into main from stateful_dataloader_integration on May 21, 2024

Conversation

@gokulavasan (Contributor)

Summary:

Use the stateful_dataloader from torchdata (https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader) to store the token buffer and iteration data order. This adds a dependency on the torchdata nightly build (>= 20240426).

Also make sure the dataloader state has a different key per rank.
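
For reference, a minimal sketch of the behavior this relies on (assuming torchdata's StatefulDataLoader API; the dataset and sizes are illustrative, not torchtitan code): the loader can snapshot its mid-epoch position and restore it in a fresh instance.

```python
# Minimal sketch, not torchtitan code: StatefulDataLoader can snapshot its
# mid-epoch position and restore it in a fresh instance.
from torchdata.stateful_dataloader import StatefulDataLoader

loader = StatefulDataLoader(list(range(100)), batch_size=4)

it = iter(loader)
for _ in range(3):
    next(it)  # consume batches 0..2

snapshot = loader.state_dict()  # captures iteration position (and any worker state)

# After a restart, a fresh loader resumes from batch 3 instead of batch 0.
resumed = StatefulDataLoader(list(range(100)), batch_size=4)
resumed.load_state_dict(snapshot)
print(next(iter(resumed)))  # resumes at batch 3, e.g. tensor([12, 13, 14, 15])
```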

Test Plan:

Tested locally by first running 30 steps (checkpointing every 5 steps) and capturing all the loss values, then deleting the last 3 checkpoints and re-running training: the loss values from steps 16-30 match those from the first run. Note that this requires changes in train.py to enable a deterministic run.
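
A hypothetical sketch of that check (train_run and its arguments are made up for illustration, not torchtitan APIs):

```python
# Hypothetical sketch of the resume-determinism check described above;
# train_run() and its arguments are illustrative, not torchtitan APIs.
def check_deterministic_resume(train_run):
    # Run 1: 30 steps, checkpointing every 5 steps, recording each step's loss.
    losses_full = train_run(steps=30, ckpt_interval=5)

    # Delete the last 3 checkpoints, then resume from the step-15 checkpoint.
    losses_resumed = train_run(steps=30, resume_from_step=15)

    # With deterministic training and a stateful dataloader, steps 16-30
    # must reproduce the original losses exactly.
    assert losses_full[15:] == losses_resumed[15:]
```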

Reviewers: @tianyu-l

Subscribers: @andrewkho

Tasks:

Tags:

@facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) on Apr 26, 2024
@gokulavasan (Contributor, Author)

Looks like adding the index URL (for torchdata) is causing other dependencies to not get installed. Will figure out how to fix this.

@tianyu-l (Contributor) left a comment

First pass looks awesome already. Let me see if I can run some large scale experiments. What is the earliest nightly that's OK to run with this PR?

(Inline review threads on torchtitan/datasets/hf_datasets.py and requirements.txt, all resolved)
@tianyu-l linked an issue (Make dataloader stateful?) on May 1, 2024 that may be closed by this pull request
@fegin (Contributor) commented May 1, 2024

@gokulavasan, @tianyu-l pytorch/pytorch#125335 and pytorch/pytorch#125334 should unblock this PR.

(Inline review thread on torchtitan/checkpoint.py, resolved)
@gokulavasan marked this pull request as draft on May 9, 2024 21:43
@gokulavasan force-pushed the stateful_dataloader_integration branch from d09fcfb to 75cb1d9 on May 9, 2024 21:43
@rlrs (Contributor) commented May 10, 2024

I've been testing this out, and ran into an issue with resuming from a checkpoint. I suspect it's because of how StatefulDataLoader handles the state dict: https://github.com/pytorch/data/blob/11e16da61d7f5f587627c75e99ea664efef3e0f8/torchdata/stateful_dataloader/stateful_dataloader.py#L249

That is, a freshly initialized StatefulDataLoader does not have a state dict to load into? I'm not very familiar with how DCP works, so please correct me if I'm wrong.

Edit: I investigated a bit further, and indeed the state_dict for the data loader in DCP.load() is, for example, {'0': {}}, which causes it to be discarded by DefaultLoadPlanner.set_up_planner.
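
A small sketch of the failure mode described here, assuming the StatefulDataLoader behavior at the time of this thread (the dataset is illustrative):

```python
# Sketch of the failure mode above, as of the StatefulDataLoader version
# discussed in this thread.
from torchdata.stateful_dataloader import StatefulDataLoader

loader = StatefulDataLoader(list(range(10)), batch_size=2)

# No iterator has been created yet, so there is no state to report;
# DCP's DefaultLoadPlanner then drops the empty entry during load.
print(loader.state_dict())  # {} (empty at the time of this discussion)

iter(loader)                # force self._iterator to be created
print(loader.state_dict())  # now contains a real snapshot
```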

@gokulavasan (Contributor, Author) commented May 10, 2024

@rlrs Would it be possible to test it after my latest commit (b9b045d)? I missed adding that part.

@gokulavasan marked this pull request as ready for review on May 10, 2024 14:37
@rlrs (Contributor) commented May 10, 2024

> @rlrs Would it be possible to test it after my latest commit (b9b045d)? I missed adding that part.

I had already added that in my version. I can't get it to load the state_dict, unless I first call iter(dataloader) so that self._iterator is not None.

If I call iter before DCP.load, and then set self._first_iter = True in HuggingFaceDataset.load_state_dict, everything seems to work!
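
A sketch of that workaround, following the names in the comment (the _first_iter flag and the function's parameters mirror the discussion, not the final PR code):

```python
# Sketch of the workaround above; parameter names and the _first_iter flag
# follow the discussion, not the final PR code.
import torch.distributed.checkpoint as dcp


def load_with_dataloader_state(dataloader, state_dict, ckpt_path):
    # 1. Create the iterator first so self._iterator is not None and the
    #    dataloader reports a non-empty state dict for DCP to load into.
    iter(dataloader)

    # 2. Now DCP can restore the dataloader entry along with everything else.
    dcp.load(state_dict, checkpoint_id=ckpt_path)

    # 3. Reset the dataset's first-iteration flag so the next iter() starts
    #    from the restored position rather than the throwaway iterator.
    dataloader.dataset._first_iter = True
```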

@tianyu-l (Contributor) left a comment

Looks great! Had some minor comments.
For next steps, we can discuss if we should add a unit test to guard the correctness of checkpointable data loading, and the plan to migrate to DTensor-based checkpointing.

requirements.txt (outdated)
@@ -1,5 +1,5 @@
torch >= 2.2.0.dev
Contributor:

I think we need to put the torchdata dependency here, maybe try

Suggested change:
-torch >= 2.2.0.dev
+torch >= 2.2.0.dev
+--find-links https://download.pytorch.org/whl/nightly/cpu/
+torchdata >= 0.7.1.dev20240426+cpu

If it works, we can remove the dependency in the unit test workflow.

Contributor Author:

I attempted this, but it fails to find the nightly and instead installs the latest released version. We do plan to do a torchdata release soon (which will contain the latest StatefulDataLoader); once that happens, we can remove the explicit pip installs in the unit test workflows and the main README.

(Inline review threads on torchtitan/datasets/hf_datasets.py, resolved)
@gokulavasan force-pushed the stateful_dataloader_integration branch 5 times, most recently from ae5f139 to 4f7c08c on May 17, 2024 01:42
@gokulavasan (Contributor, Author)

@tianyu-l Addressed the PR comments (thank you!), added a unit test, and made changes to the GitHub workflows to allow running those unit tests. Let me know if the changes look okay. Regarding the move to DTensor, I think this requires analysis of what the benefits are (especially for storing the unstructured state dict of the dataloader).

If it is purely to reduce the replication of state across tensor/pipeline parallel groups, I think we can store the dataloader state just for the DP worker ranks (using the dp_rank as the key) and load it back, instead of storing it for all global ranks. For now, with just the text tokens, this might not even be necessary, as the state is not that big. Let me know how you would like to proceed.
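
A sketch of what that per-DP-rank keying could look like (the mesh shape and key format are illustrative assumptions, not torchtitan code):

```python
# Sketch of the idea above: derive the data-parallel rank from the device
# mesh and key the dataloader state by it, so all TP ranks in the same DP
# group share one entry. Mesh shape and key format are illustrative.
from torch.distributed.device_mesh import init_device_mesh

mesh = init_device_mesh("cuda", (4, 2), mesh_dim_names=("dp", "tp"))
dp_rank = mesh["dp"].get_local_rank()


def dataloader_state_dict(dataloader):
    # Same dp_rank -> same key, so the checkpoint holds one copy per DP
    # group instead of one per global rank.
    return {f"dp_rank_{dp_rank}": dataloader.state_dict()}
```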

@gokulavasan (Contributor, Author)

@rlrs Thank you for your great analysis here (#279 (comment)).

It helped us narrow down the issue, which boiled down to DCP's in-place loading of checkpoints: StatefulDataLoader currently returns no state if the dataloader iterator has not been created, while DCP expects each module to tell it which keys the module expects.

To get around this, I serialized the state of the dataloader; in this case there is only one key to load, which the DataLoaderWrapper communicates to DCP: "<rank_id>".
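
A condensed sketch of that workaround, modeled on the PR's DataLoaderWrapper but simplified (the Stateful base class and the pickle serialization are assumptions on my part, not verbatim PR code):

```python
# Condensed sketch modeled on this PR's approach (details simplified):
# serializing the whole dataloader state to bytes gives DCP a single,
# always-present key per rank, sidestepping the empty-state problem above.
import pickle

from torch.distributed.checkpoint.stateful import Stateful
from torchdata.stateful_dataloader import StatefulDataLoader


class DataLoaderWrapper(Stateful):
    def __init__(self, dataloader: StatefulDataLoader, rank: int):
        self.dataloader = dataloader
        self.rank_id = str(rank)  # the "<rank_id>" key mentioned above

    def state_dict(self) -> dict:
        # Always returns exactly one key, even before iteration starts.
        return {self.rank_id: pickle.dumps(self.dataloader.state_dict())}

    def load_state_dict(self, state_dict: dict) -> None:
        if state_dict.get(self.rank_id):
            self.dataloader.load_state_dict(
                pickle.loads(state_dict[self.rank_id])
            )
```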

@tianyu-l (Contributor) left a comment

> I think we can store the dataloader state just for the DP worker ranks (using the dp_rank as the key) and load it back instead of storing it for all global ranks.

This sounds quite good! @fegin, would checkpointing behave in this expected way? I.e., if we use the same key for the same TP ranks but different keys for different DP ranks, would it avoid saving extra copies and load correctly? If so, I agree we don't have to use DTensor for now.

(Inline review threads on torchtitan/datasets/hf_datasets.py, .github/workflows/unit_test_4gpu.yaml, .github/workflows/unit_test_cpu.yaml, torchtitan/checkpoint.py, and test/datasets/test_dataset_checkpoint.py, resolved)
@gokulavasan force-pushed the stateful_dataloader_integration branch from 344a48d to 8a217b6 on May 21, 2024 17:45
@gokulavasan force-pushed the stateful_dataloader_integration branch from 8a217b6 to 9b07bb9 on May 21, 2024 17:47
@gokulavasan force-pushed the stateful_dataloader_integration branch from 80cefc0 to 5f825c7 on May 21, 2024 19:56
@gokulavasan force-pushed the stateful_dataloader_integration branch from 5f825c7 to c1a49fb on May 21, 2024 20:00
@tianyu-l (Contributor) left a comment

Looks awesome! Thanks for the beautiful work!

Please address inline comments before merging.

@@ -31,5 +31,6 @@ jobs:
pip config --user set global.progress_bar off

python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
python -m pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/
Contributor:

I think we need this in .github/workflows/integration_test_periodic.yaml as well.

Let's create an issue tracking that we need to put torchdata in requirements.txt and pyproject.toml once the needed changes ship in an official release.

Contributor Author:

@tianyu-l Should the pyproject.toml datasets dependency also enforce a >= 2.19.0 version requirement?

Contributor:

ah yes I think so

Contributor Author:

Created #351 to remove the torchdata nightly pip install.

(Inline review thread on torchtitan/datasets/hf_datasets.py, resolved)
@gokulavasan merged commit 99a73dd into main on May 21, 2024 (5 checks passed)
@gokulavasan deleted the stateful_dataloader_integration branch on May 21, 2024 23:05
tianyu-l pushed a commit that referenced this pull request on May 28, 2024
@lhoestq commented Jul 13, 2024

Hi! I'm Quentin from HF :)
FYI, we just added state_dict() and load_state_dict() to datasets.IterableDataset, which can resume iteration faster than just skipping samples!
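
A sketch of that API as I understand it from recent `datasets` releases (the dataset name is illustrative):

```python
# Sketch of the datasets.IterableDataset checkpointing API mentioned above;
# the dataset name is illustrative.
from datasets import load_dataset

ds = load_dataset("allenai/c4", "en", split="train", streaming=True)

it = iter(ds)
for _ in range(100):
    next(it)

state = ds.state_dict()  # records shard/position, not the consumed samples

# Later: resume without re-iterating from scratch. Resumption restarts from
# the latest checkpointable position, which may be slightly before sample 100.
ds2 = load_dataset("allenai/c4", "en", split="train", streaming=True)
ds2.load_state_dict(state)
```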

tianyu-l pushed a commit to tianyu-l/torchtitan_intern24 that referenced this pull request on Aug 16, 2024
tianyu-l pushed a commit that referenced this pull request on Aug 16, 2024
philippguevorguian pushed a commit to YerevaNN/YNNtitan that referenced this pull request on Aug 17, 2024
Labels: CLA Signed (managed by the Meta Open Source bot)
Linked issue that may be closed by this pull request: Make dataloader stateful?
7 participants