Custom Dataset docs
msaroufim committed Dec 4, 2024
1 parent 93ba30f commit 049b755
Showing 4 changed files with 160 additions and 80 deletions.
80 changes: 80 additions & 0 deletions docs/datasets.md
@@ -0,0 +1,80 @@
# Custom Datasets in TorchTitan

TorchTitan is designed to work seamlessly with most HuggingFace datasets. While we provide the C4 dataset for numerics and convergence testing, you can easily add support for your own datasets. Here's how to do it using Wikipedia as an example.

## Quick Start

1. Install TorchTitan from source:
```bash
pip install -e .
```

2. Locate the dataset configuration file:
```
torchtitan/datasets/hf_datasets.py
```

## Adding Your Dataset

You'll need to add two main components:

1. A dataset loader function
2. A sample processor function

### 1. Define Dataset Loader

Add a function that specifies how to load your dataset:

```python
def load_wikipedia_dataset(dataset_path: str, **kwargs):
"""Load Wikipedia dataset with specific configuration."""
logger.info("Loading Wikipedia dataset...")
return load_dataset(
dataset_path,
name="20220301.en",
split="train",
streaming=True,
trust_remote_code=True,
)

# Register your loader in DATASET_LOADERS
DATASET_LOADERS = {
# ... existing loaders ...
"wikipedia": load_wikipedia_dataset,
}
```

### 2. Define Sample Processor

Add a function that processes individual samples from your dataset:

```python
def process_wikipedia_text(sample: Dict[str, Any]) -> str:
"""Process Wikipedia dataset sample text."""
return f"{sample['title']}\n\n{sample['text']}"

# Register your processor in DATASET_TEXT_PROCESSORS
DATASET_TEXT_PROCESSORS = {
# ... existing processors ...
"wikipedia": process_wikipedia_text,
}
```
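
Before wiring the dataset into a training run, it can help to sanity-check the two functions directly. The sketch below is illustrative: it assumes the Wikipedia loader and processor above have been added and registered in `torchtitan/datasets/hf_datasets.py`, and it simply streams a few samples and prints the processed text.

```python
# Illustrative sanity check for a newly registered dataset.
# Assumes the loader/processor above are registered in
# torchtitan/datasets/hf_datasets.py and that the `datasets` package is installed.
from torchtitan.datasets.hf_datasets import DATASET_LOADERS, DATASET_TEXT_PROCESSORS

ds = DATASET_LOADERS["wikipedia"]("wikipedia")  # streaming, so no full download
process = DATASET_TEXT_PROCESSORS["wikipedia"]

for i, sample in enumerate(ds):
    print(process(sample)[:200])  # preview the first 200 characters of each sample
    if i >= 2:
        break
```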

### 3. Configure Your Training

In your training configuration file (`.toml`), set your dataset:

```toml
dataset = "wikipedia"
```

That's it! Your custom dataset is now ready to use with TorchTitan.

## Key Points

- The loader function should return a HuggingFace dataset object
- The processor function should return a string that combines the relevant fields from your dataset
- Make sure your dataset name matches exactly in both the loader and processor registrations
- Use `streaming=True` for large datasets to manage memory efficiently

Now you can start training with your custom dataset!
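
If you want to exercise TorchTitan's data path outside of a full `train.py` run, a sketch like the following can be useful. It is a minimal, illustrative example: `build_hf_data_loader` and its parameters come from `torchtitan/datasets/hf_datasets.py`, but the tokenizer is assumed to be supplied by you (TorchTitan builds one via `build_tokenizer` in `train.py`), and the batch size, sequence length, and parallel settings are placeholders.

```python
# Illustrative sketch: build TorchTitan's dataloader for the registered dataset.
from torchtitan.datasets.hf_datasets import build_hf_data_loader


def preview_wikipedia_batch(tokenizer):
    """Fetch a single batch from the "wikipedia" dataset.

    `tokenizer` is assumed to be a TorchTitan Tokenizer (e.g. built via
    build_tokenizer in train.py) implementing `encode` and `decode`.
    """
    data_loader = build_hf_data_loader(
        "wikipedia",  # dataset name; must match the loader/processor registrations
        None,         # dataset path; None falls back to the default for this name
        tokenizer,
        batch_size=8,  # placeholder values
        seq_len=2048,
        world_size=1,
        rank=0,
    )
    # Each batch is expected to be an (input, label) pair of LongTensors of
    # shape (batch_size, seq_len).
    return next(iter(data_loader))
```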
157 changes: 79 additions & 78 deletions torchtitan/datasets/hf_datasets.py
@@ -5,63 +5,69 @@
# LICENSE file in the root directory of this source tree.

import pickle
from typing import Any, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional

import torch
from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset

from torchdata.stateful_dataloader import StatefulDataLoader

from torchtitan.datasets.tokenizer import Tokenizer
from torchtitan.logging import logger

from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node

# map from dataset name to a local directory, or
# a dataset repository on the HF hub
def load_c4_dataset(dataset_path: str, **kwargs):
"""Load C4 dataset with specific configuration."""
logger.info("Loading C4 dataset...")
return load_dataset(dataset_path, name="en", split="train", streaming=True)


def load_wikipedia_dataset(dataset_path: str, **kwargs):
"""Load Wikipedia dataset with specific configuration."""
logger.info("Loading Wikipedia dataset...")
return load_dataset(
dataset_path,
name="20220301.en",
split="train",
streaming=True,
trust_remote_code=True,
)


def process_c4_text(sample: Dict[str, Any]) -> str:
"""Process C4 dataset sample text."""
return sample["text"]


def process_wikipedia_text(sample: Dict[str, Any]) -> str:
"""Process Wikipedia dataset sample text."""
return f"{sample['title']}\n\n{sample['text']}"


# Map from dataset name to a local directory or dataset repository
_supported_datasets = {
"c4_test": "test/assets/c4_test",
"c4": "allenai/c4",
"wikipedia": "wikipedia",
}

DATASET_LOADERS = {
"c4": load_c4_dataset,
"c4_test": lambda path, **kwargs: load_dataset(path, split="train"),
"wikipedia": load_wikipedia_dataset,
}

class HuggingFaceDataset(IterableDataset, Stateful):
"""PyTorch Representation of the HuggingFace Dataset.
Args:
dataset_name (str): name of the dataset to load
dataset_path (Optional[str]):
Path to the dataset in the file system. If provided, data will be loaded
from this path instead of downloaded.
tokenizer (Tokenizer):
Tokenizer used to encode data. Tokenizer must implement an `encode` and `decode` method.
seq_len (int): max sequence length
world_size (int): number of data parallel processes participating in training
rank (int): rank of the current data parallel process
infinite (bool): whether to loop infinitely over the dataset
We currently support the c4 dataset, and a subset of it for testing purposes:
c4_test (2K training entries)
c4 (177M training entries - this dataset is streamed due to the size)
>> c4 (EN) <<:
c4 cleaned, English version
Data input format (c4):
{
'url': 'https://klyq.com/beginners-bbq-class-taking-place-in-missoula/',
'text': 'Beginners BBQ Class Taking Place in Missoula!\nDo you want to get better at ...',
'timestamp': '2019-04-25T12:57:54Z'
}
Example use (c4):
>>> ds = HuggingFaceDataset(dataset_name="c4", dataset_path=None, tokenizer=tokenizer)
>>> for batch in Dataloader(ds, batch_size=8):
print(f"Batch size: {len(batch)}")
Batch size: 8
"""

DATASET_TEXT_PROCESSORS = {
"c4": process_c4_text,
"c4_test": process_c4_text,
"wikipedia": process_wikipedia_text,
}


class HuggingFaceDataset(IterableDataset, Stateful):
def __init__(
self,
dataset_name: str,
@@ -72,54 +78,62 @@ def __init__(
rank: int = 0,
infinite: bool = False,
) -> None:
# allow user to pass in a (local or HF hub) path to use unsupported datasets
# Force lowercase for consistent comparison
dataset_name = dataset_name.lower()

if dataset_name not in _supported_datasets:
if dataset_path:
logger.warning(
f"Dataset {dataset_name} is not tested or verfied. "
f"Recommended datasets are: {list(_supported_datasets.keys())}"
)
else:
raise ValueError(
f"Dataset {dataset_name} is not supported. "
f"Supported datasets are: {list(_supported_datasets.keys())}"
)
raise ValueError(
f"Dataset {dataset_name} is not supported. "
f"Supported datasets are: {list(_supported_datasets.keys())}"
)

if not dataset_path:
dataset_path = _supported_datasets[dataset_name]
logger.info(f"Preparing {dataset_name} dataset from {dataset_path}")

if dataset_name == "c4":
# c4 is huge, and requires both streaming and language selection
# (we default to en)
ds = load_dataset(dataset_path, name="en", split="train", streaming=True)
else:
ds = load_dataset(dataset_path, split="train")
if dataset_name not in DATASET_LOADERS:
raise ValueError(f"No loader found for dataset {dataset_name}")

dataset_loader = DATASET_LOADERS[dataset_name]
logger.info(f"Using dataset loader for {dataset_name}")
ds = dataset_loader(dataset_path)

# TODO: support shuffling
self.dataset_name = dataset_name
self._data = split_dataset_by_node(ds, rank, world_size)
self._tokenizer = tokenizer
self.seq_len = seq_len
self.infinite = infinite

# variables for checkpointing
if dataset_name not in DATASET_TEXT_PROCESSORS:
raise ValueError(f"No text processor found for dataset {dataset_name}")

self._text_processor = DATASET_TEXT_PROCESSORS[dataset_name]

# Variables for checkpointing
self._sample_idx = 0
self._all_tokens: List[int] = []

def _get_data_iter(self):
if self._sample_idx == 0:
return iter(self._data)

if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
return iter([])

return iter(self._data.skip(self._sample_idx))

def __iter__(self):
max_buffer_token_len = 1 + self.seq_len

while True:
for sample in self._get_data_iter():
sample_text = sample["text"]
# Use the dataset-specific text processor
sample_text = self._text_processor(sample)
sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True)
self._all_tokens.extend(sample_tokens)
self._sample_idx += 1

while len(self._all_tokens) >= max_buffer_token_len:
x = torch.LongTensor(self._all_tokens[:max_buffer_token_len])
# update tokens to the remaining tokens
self._all_tokens = self._all_tokens[max_buffer_token_len:]
input = x[:-1]
label = x[1:]
@@ -129,20 +143,9 @@ def __iter__(self):
logger.warning(f"Dataset {self.dataset_name} has run out of data")
break
else:
# Reset offset for the next iteration
self._sample_idx = 0
logger.warning(f"Dataset {self.dataset_name} is being re-looped")

def _get_data_iter(self):
if self._sample_idx == 0:
return iter(self._data)

# As skipping to the end throws an error in case of map-style dataset, return an empty iterator
if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
return iter([])

return iter(self._data.skip(self._sample_idx))

def load_state_dict(self, state_dict):
self._sample_idx = state_dict["sample_idx"]
self._all_tokens = state_dict["token_buffer"]
@@ -162,11 +165,9 @@ def __init__(self, dp_rank: int, hf_ds: IterableDataset, batch_size: int):
self._rank_id = f"dp_rank_{dp_rank}"

def state_dict(self) -> Dict[str, Any]:
# Store state only for dp rank to avoid replicating the same state across other dimensions
return {self._rank_id: pickle.dumps(super().state_dict())}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
# State being empty is valid
if not state_dict:
return

@@ -184,12 +185,12 @@ def build_hf_data_loader(
tokenizer: Tokenizer,
batch_size: int,
seq_len: int,
world_size,
rank,
world_size: int,
rank: int,
infinite: bool = True,
):
"""Build a data loader for HuggingFace datasets."""
hf_ds = HuggingFaceDataset(
dataset_name, dataset_path, tokenizer, seq_len, world_size, rank, infinite
)

return DPAwareDataLoader(rank, hf_ds, batch_size=batch_size)
1 change: 0 additions & 1 deletion train.py
@@ -86,7 +86,6 @@ def main(job_config: JobConfig):
# build tokenizer
tokenizer_type = model_name_to_tokenizer[model_name]
tokenizer = build_tokenizer(tokenizer_type, job_config.model.tokenizer_path)

# build dataloader
data_loader = build_hf_data_loader(
job_config.training.dataset,
2 changes: 1 addition & 1 deletion train_configs/debug_model.toml
@@ -40,7 +40,7 @@ data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
tensor_parallel_degree = 1
compile = false
dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)
dataset = "wikipedia" # supported datasets: c4_test (2K), c4 (177M)

[experimental]
context_parallel_degree = 1