Commit: updates
msaroufim committed Dec 5, 2024
1 parent 8ec6a2e commit 13e5313
Showing 2 changed files with 12 additions and 36 deletions.
46 changes: 11 additions & 35 deletions torchtitan/datasets/hf_datasets.py
@@ -8,58 +8,30 @@
 from typing import Any, Dict, List, Optional
 
 import torch
 
-from datasets import Dataset, load_dataset
-from datasets.distributed import split_dataset_by_node
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.utils.data import IterableDataset
 from torchdata.stateful_dataloader import StatefulDataLoader
 
 from torchtitan.datasets.tokenizer import Tokenizer
 from torchtitan.logging import logger
 
+from datasets import Dataset, load_dataset
+from datasets.distributed import split_dataset_by_node
+# To create your own custom dataset, add
 
 
-def load_c4_dataset(dataset_path: str, **kwargs):
-    """Load C4 dataset with specific configuration."""
+def load_c4_dataset(dataset_path: str):
+    """Load C4 dataset with default configuration."""
     logger.info("Loading C4 dataset...")
-
-    # Default settings for C4
-    default_config = {"name": "en", "split": "train", "streaming": True}
-
-    # kwargs override defaults
-    config = {**default_config, **kwargs}
-
-    return load_dataset(dataset_path, **config)
-
-
-def load_wikipedia_dataset(dataset_path: str, **kwargs):
-    """Load Wikipedia dataset with specific configuration."""
-    logger.info("Loading Wikipedia dataset...")
-
-    # Default settings for Wikipedia
-    default_config = {
-        "name": "20220301.en",
-        "split": "train",
-        "streaming": True,
-        "trust_remote_code": True,
-    }
-
-    # kwargs override defaults
-    config = {**default_config, **kwargs}
-
-    return load_dataset(dataset_path, **config)
+    return load_dataset(dataset_path, name="en", split="train", streaming=True)
 
 
 def process_c4_text(sample: Dict[str, Any]) -> str:
     """Process C4 dataset sample text."""
     return sample["text"]
-
-
-def process_wikipedia_text(sample: Dict[str, Any]) -> str:
-    """Process Wikipedia dataset sample text."""
-    return f"{sample['title']}\n\n{sample['text']}"
 
 
 # Map from dataset name to a local directory or dataset repository
 _supported_datasets = {
     "c4_test": "test/assets/c4_test",
@@ -148,6 +120,7 @@ def __iter__(self):

         while len(self._all_tokens) >= max_buffer_token_len:
             x = torch.LongTensor(self._all_tokens[:max_buffer_token_len])
+            # update tokens to the remaining tokens
             self._all_tokens = self._all_tokens[max_buffer_token_len:]
             input = x[:-1]
             label = x[1:]
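
The loop above slices fixed windows of max_buffer_token_len tokens off the front of the buffer and keeps the remainder for the next pass; input and label are the same window offset by one token. A self-contained toy run of that slicing, assuming the buffer length is seq_len + 1 (that detail is not visible in this hunk) and using made-up token values:

import torch

seq_len = 4
max_buffer_token_len = seq_len + 1  # assumed: one extra token so input and label can overlap
all_tokens = [10, 11, 12, 13, 14, 15, 16]

while len(all_tokens) >= max_buffer_token_len:
    x = torch.LongTensor(all_tokens[:max_buffer_token_len])
    all_tokens = all_tokens[max_buffer_token_len:]  # keep the leftover tokens
    inp, label = x[:-1], x[1:]  # label is inp shifted left by one position
    print(inp.tolist(), label.tolist())  # prints [10, 11, 12, 13] [11, 12, 13, 14]
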
@@ -157,6 +130,7 @@ def __iter__(self):
logger.warning(f"Dataset {self.dataset_name} has run out of data")
break
else:
# Reset offset for the next iteration
self._sample_idx = 0
logger.warning(f"Dataset {self.dataset_name} is being re-looped")

@@ -179,9 +153,11 @@ def __init__(self, dp_rank: int, hf_ds: IterableDataset, batch_size: int):
         self._rank_id = f"dp_rank_{dp_rank}"
 
     def state_dict(self) -> Dict[str, Any]:
+        # Store state only for dp rank to avoid replicating the same state across other dimensions
         return {self._rank_id: pickle.dumps(super().state_dict())}
 
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        # State being empty is valid
         if not state_dict:
             return
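
The two comments above spell out the checkpointing contract: each data-parallel rank pickles only its own dataloader state under its dp_rank_{n} key, and an empty state dict on load is a legitimate fresh start. A standalone sketch of that pattern, with a plain dict standing in for StatefulDataLoader's real state (an assumption for illustration):

import pickle
from typing import Any, Dict

class RankAwareState:
    """Toy stand-in for the per-rank state handling above."""

    def __init__(self, dp_rank: int):
        self._rank_id = f"dp_rank_{dp_rank}"
        self._state = {"sample_idx": 0}  # placeholder for the real loader state

    def state_dict(self) -> Dict[str, Any]:
        # Key the pickled state by dp rank so each rank stores only its own shard
        return {self._rank_id: pickle.dumps(self._state)}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # An empty state is valid: nothing to restore
        if not state_dict:
            return
        self._state = pickle.loads(state_dict[self._rank_id])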

2 changes: 1 addition & 1 deletion train_configs/debug_model.toml
@@ -40,7 +40,7 @@ data_parallel_replicate_degree = 1
 data_parallel_shard_degree = -1
 tensor_parallel_degree = 1
 compile = false
-dataset = "wikipedia" # supported datasets: c4_test (2K), c4 (177M)
+dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)
 
 [experimental]
 context_parallel_degree = 1
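
The dataset string set here is resolved through the _supported_datasets map shown in hf_datasets.py above, which is why the debug config can only name a registered dataset. A toy illustration of that lookup (the error message is invented for the example):

_supported_datasets = {
    "c4_test": "test/assets/c4_test",
}

dataset = "c4_test"  # value from train_configs/debug_model.toml
if dataset not in _supported_datasets:
    raise ValueError(f"Dataset {dataset} is not supported")
print(_supported_datasets[dataset])  # test/assets/c4_test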
