Skip to content

Commit

Permalink
Integrate stateful dataloader to torchtitan
Browse files Browse the repository at this point in the history
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
gokulavasan committed May 17, 2024
1 parent 847189d commit 4f7c08c
Show file tree
Hide file tree
Showing 6 changed files with 137 additions and 9 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ git clone https://github.com/pytorch/torchtitan
cd torchtitan
pip install -r requirements.txt
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 # or cu118
pip3 install --pre torchdata --index-url https://download.pytorch.org/whl/nightly
```

### Downloading a tokenizer
Expand Down
5 changes: 5 additions & 0 deletions test/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
56 changes: 56 additions & 0 deletions test/datasets/test_dataset_checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torchtitan.checkpoint import DataLoaderWrapper
from torchtitan.datasets.hf_datasets import build_hf_data_loader
from torchtitan.datasets.tokenizer import create_tokenizer


class TestDatasetCheckpoint:
def test_c4_resumption(self):
dataset_name = "c4_mini"
dataset_path = "./torchtitan/datasets/c4_mini"
batch_size = 1
seq_len = 1024
world_size = 4
rank = 0

dl_wrapper = self._create_dataloader_wrapper(
dataset_name, dataset_path, batch_size, seq_len, world_size, rank
)

it = iter(dl_wrapper.dataloader)
for _ in range(250):
next(it)
state = dl_wrapper.state_dict()
expected_input_ids, expected_labels = next(it)

# Create new dataloader, restore checkpoint, and check if next data yielded is the same as above
dl_wrapper = self._create_dataloader_wrapper(
dataset_name, dataset_path, batch_size, seq_len, world_size, rank
)
dl_wrapper.load_state_dict(state)
input_ids, labels = next(iter(dl_wrapper.dataloader))

assert torch.equal(input_ids, expected_input_ids)
assert torch.equal(labels, expected_labels)

def _create_dataloader_wrapper(
self, dataset_name, dataset_path, batch_size, seq_len, world_size, rank
):
tokenizer_type = "tiktoken"
tokenizer = create_tokenizer("tiktoken", "./test/assets/test_tiktoken.model")
dataloader = build_hf_data_loader(
dataset_name=dataset_name,
dataset_path=dataset_path,
tokenizer=tokenizer,
batch_size=1,
seq_len=1024,
world_size=4,
rank=0,
)
return DataLoaderWrapper(dataloader)
32 changes: 32 additions & 0 deletions torchtitan/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import enum
import os
import pickle
import re
import time
from multiprocessing import get_context
Expand All @@ -22,6 +23,8 @@
set_optimizer_state_dict,
)
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import DataLoader
from torchdata.stateful_dataloader import StatefulDataLoader
from torchtitan.config_manager import JobConfig
from torchtitan.logging_utils import init_logger, logger

Expand Down Expand Up @@ -67,6 +70,33 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
set_optimizer_state_dict(self.model, self.optim, optim_state_dict=state_dict)


class DataLoaderWrapper(Stateful):
def __init__(self, dataloader: DataLoader) -> None:
self.dataloader = dataloader
# Use global rank for now even though dataloader state could be same across dp groups
self.rank_id = str(
dist.get_rank() if (dist.is_available() and dist.is_initialized()) else 0
)

def state_dict(self) -> Dict[str, Any]:
if isinstance(self.dataloader, StatefulDataLoader):
return {self.rank_id: pickle.dumps(self.dataloader.state_dict())}
return {}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
if isinstance(self.dataloader, StatefulDataLoader):
# State is empty
if not state_dict:
return

if self.rank_id not in state_dict:
logger.warning(f"DataLoader state is empty for rank {self.rank_id}. ")
return

# Load state for the current rank
self.dataloader.load_state_dict(pickle.loads(state_dict[self.rank_id]))


class Terminate:
pass

Expand Down Expand Up @@ -110,6 +140,7 @@ def __init__(
model: nn.Module,
optimizer: torch.optim.Optimizer,
lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
dataloader: DataLoader,
states: Dict[str, Any],
job_config: JobConfig,
) -> None:
Expand All @@ -125,6 +156,7 @@ def __init__(
"model": ModelWrapper(model),
"optimizer": OptimizerWrapper(model, optimizer),
"lr_scheduler": lr_scheduler,
"dataloader": DataLoaderWrapper(dataloader),
}
)

Expand Down
51 changes: 42 additions & 9 deletions torchtitan/datasets/hf_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
from typing import List, Optional

import torch
from torch.utils.data import DataLoader, IterableDataset
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset
from torchdata.stateful_dataloader import StatefulDataLoader

from torchtitan.datasets.tokenizer import Tokenizer
from torchtitan.logging_utils import logger
Expand All @@ -23,7 +25,7 @@
}


class HuggingFaceDataset(IterableDataset):
class HuggingFaceDataset(IterableDataset, Stateful):
"""PyTorch Representation of the HuggingFace Dataset.
Args:
Expand Down Expand Up @@ -99,32 +101,63 @@ def __init__(
self.seq_len = seq_len
self.infinite = infinite

# variables for checkpointing
self._sample_idx = 0
self._all_tokens: List[int] = []

def __iter__(self):
max_buffer_token_len = 1 + self.seq_len
all_tokens: List[int] = []

while True:
for sample in iter(self._data):
for sample in self._get_data_iter():
sample_text = sample["text"]
sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True)
all_tokens.extend(sample_tokens)
self._all_tokens.extend(sample_tokens)
self._sample_idx += 1

while len(all_tokens) >= max_buffer_token_len:
x = torch.LongTensor(all_tokens[:max_buffer_token_len])
while len(self._all_tokens) >= max_buffer_token_len:
x = torch.LongTensor(self._all_tokens[:max_buffer_token_len])
# update tokens to the remaining tokens
all_tokens = all_tokens[max_buffer_token_len:]
self._all_tokens = self._all_tokens[max_buffer_token_len:]
input = x[:-1]
label = x[1:]
yield input, label

if not self.infinite:
logger.warning(f"Dataset {self.dataset_name} has run out of data.")
break
else:
# Reset offset for the next iteration
self._sample_idx = 0
logger.warning(
f"Dataset {self.dataset_name} is being re-looped. "
"Loss related metrics might be misleading."
)

def _get_data_iter(self):
if self._sample_idx == 0:
return iter(self._data)

# Skip samples
if isinstance(self._data, IterableDataset):
it = iter(self._data)
# Naively iterate through the samples as skip may not be supported
for _ in range(self._sample_idx):
next(it)
return it

# As skipping to the end throws an error in case of map-style dataset, return an empty iterator
if self._sample_idx == len(self._data):
return iter([])
return iter(self._data.skip(self._sample_idx))

def load_state_dict(self, state_dict):
self._sample_idx = state_dict["sample_idx"]
self._all_tokens = state_dict["token_buffer"]

def state_dict(self):
return {"token_buffer": self._all_tokens, "sample_idx": self._sample_idx}


def build_hf_data_loader(
dataset_name: str,
Expand All @@ -140,4 +173,4 @@ def build_hf_data_loader(
dataset_name, dataset_path, tokenizer, seq_len, world_size, rank, infinite
)

return DataLoader(hf_ds, batch_size=batch_size)
return StatefulDataLoader(hf_ds, batch_size=batch_size)
1 change: 1 addition & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ def loss_fn(pred, labels):
model=model,
optimizer=optimizer,
lr_scheduler=scheduler,
dataloader=data_loader,
states={"train_state": train_state},
job_config=job_config,
)
Expand Down

0 comments on commit 4f7c08c

Please sign in to comment.