From 13e531317ffccd03d2b1e1e635f6e279017044b5 Mon Sep 17 00:00:00 2001
From: Mark Saroufim
Date: Thu, 5 Dec 2024 10:46:06 -0800
Subject: [PATCH] updates

---
 torchtitan/datasets/hf_datasets.py | 46 +++++++-----------------------
 train_configs/debug_model.toml     |  2 +-
 2 files changed, 12 insertions(+), 36 deletions(-)

diff --git a/torchtitan/datasets/hf_datasets.py b/torchtitan/datasets/hf_datasets.py
index 6a83c1c0..8bebab13 100644
--- a/torchtitan/datasets/hf_datasets.py
+++ b/torchtitan/datasets/hf_datasets.py
@@ -8,6 +8,9 @@
 from typing import Any, Dict, List, Optional
 
 import torch
+
+from datasets import Dataset, load_dataset
+from datasets.distributed import split_dataset_by_node
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.utils.data import IterableDataset
 from torchdata.stateful_dataloader import StatefulDataLoader
@@ -15,39 +18,13 @@
 from torchtitan.datasets.tokenizer import Tokenizer
 from torchtitan.logging import logger
 
-from datasets import Dataset, load_dataset
-from datasets.distributed import split_dataset_by_node
+# To create your own custom dataset, add a load function, a text processor, and an entry in _supported_datasets below.
 
 
-def load_c4_dataset(dataset_path: str, **kwargs):
-    """Load C4 dataset with specific configuration."""
+def load_c4_dataset(dataset_path: str):
+    """Load C4 dataset with default configuration."""
     logger.info("Loading C4 dataset...")
-
-    # Default settings for C4
-    default_config = {"name": "en", "split": "train", "streaming": True}
-
-    # kwargs override defaults
-    config = {**default_config, **kwargs}
-
-    return load_dataset(dataset_path, **config)
-
-
-def load_wikipedia_dataset(dataset_path: str, **kwargs):
-    """Load Wikipedia dataset with specific configuration."""
-    logger.info("Loading Wikipedia dataset...")
-
-    # Default settings for Wikipedia
-    default_config = {
-        "name": "20220301.en",
-        "split": "train",
-        "streaming": True,
-        "trust_remote_code": True,
-    }
-
-    # kwargs override defaults
-    config = {**default_config, **kwargs}
-
-    return load_dataset(dataset_path, **config)
+    return load_dataset(dataset_path, name="en", split="train", streaming=True)
 
 
 def process_c4_text(sample: Dict[str, Any]) -> str:
@@ -55,11 +32,6 @@ def process_c4_text(sample: Dict[str, Any]) -> str:
     return sample["text"]
 
 
-def process_wikipedia_text(sample: Dict[str, Any]) -> str:
-    """Process Wikipedia dataset sample text."""
-    return f"{sample['title']}\n\n{sample['text']}"
-
-
 # Map from dataset name to a local directory or dataset repository
 _supported_datasets = {
     "c4_test": "test/assets/c4_test",
@@ -148,6 +120,7 @@ def __iter__(self):
 
                 while len(self._all_tokens) >= max_buffer_token_len:
                     x = torch.LongTensor(self._all_tokens[:max_buffer_token_len])
+                    # update tokens to the remaining tokens
                     self._all_tokens = self._all_tokens[max_buffer_token_len:]
                     input = x[:-1]
                     label = x[1:]
@@ -157,6 +130,7 @@
                 logger.warning(f"Dataset {self.dataset_name} has run out of data")
                 break
             else:
+                # Reset offset for the next iteration
                 self._sample_idx = 0
                 logger.warning(f"Dataset {self.dataset_name} is being re-looped")
 
@@ -179,9 +153,11 @@ def __init__(self, dp_rank: int, hf_ds: IterableDataset, batch_size: int):
         self._rank_id = f"dp_rank_{dp_rank}"
 
     def state_dict(self) -> Dict[str, Any]:
+        # Store state only for dp rank to avoid replicating the same state across other dimensions
         return {self._rank_id: pickle.dumps(super().state_dict())}
 
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        # State being empty is valid
         if not state_dict:
             return
 

diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml
index 8a2f8268..f681cdba 100644
--- a/train_configs/debug_model.toml
+++ b/train_configs/debug_model.toml
@@ -40,7 +40,7 @@ data_parallel_replicate_degree = 1
 data_parallel_shard_degree = -1
 tensor_parallel_degree = 1
 compile = false
-dataset = "wikipedia"  # supported datasets: c4_test (2K), c4 (177M)
+dataset = "c4_test"  # supported datasets: c4_test (2K), c4 (177M)
 
 [experimental]
 context_parallel_degree = 1
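
Reviewer note: after this patch a dataset is defined by a load function plus a
text-processing function, registered by name. Below is a minimal sketch of what
adding a custom dataset could look like under this pattern; the "my_corpus"
name, the "your-org/my_corpus" path, and both helper functions are
hypothetical, and the diff does not show how names are dispatched to loaders,
so treat this as an illustration rather than the file's final API.

    from typing import Any, Dict

    from datasets import load_dataset


    def load_my_corpus_dataset(dataset_path: str):
        """Hypothetical loader: stream the train split, mirroring load_c4_dataset."""
        return load_dataset(dataset_path, split="train", streaming=True)


    def process_my_corpus_text(sample: Dict[str, Any]) -> str:
        """Hypothetical processor: pull the raw text field out of one sample."""
        return sample["text"]


    # Register the name -> path mapping alongside the existing entries.
    _supported_datasets = {
        "c4_test": "test/assets/c4_test",
        # existing entries elided; add your own name -> path mapping:
        "my_corpus": "your-org/my_corpus",  # hypothetical HF Hub repository
    }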
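
Reviewer note: the new "update tokens to the remaining tokens" comment sits on
the token-buffer logic in __iter__. Whenever _all_tokens holds at least
max_buffer_token_len tokens, a chunk is sliced off the front and split into an
input/label pair shifted by one position, which is why the buffer length is
presumably seq_len + 1. A standalone illustration with made-up token ids:

    import torch

    seq_len = 4
    max_buffer_token_len = seq_len + 1  # one extra token so input and label can overlap

    all_tokens = [10, 11, 12, 13, 14, 15, 16]  # made-up token ids
    x = torch.LongTensor(all_tokens[:max_buffer_token_len])  # tensor([10, 11, 12, 13, 14])
    all_tokens = all_tokens[max_buffer_token_len:]           # remaining buffer: [15, 16]

    input, label = x[:-1], x[1:]
    # input: tensor([10, 11, 12, 13])  -> tokens the model sees
    # label: tensor([11, 12, 13, 14])  -> next-token targets, shifted by one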
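
Reviewer note: the two comments added to state_dict and load_state_dict
describe the checkpoint pattern: each data-parallel rank pickles its own loader
state under a "dp_rank_{dp_rank}" key, so ranks along other parallel dimensions
neither replicate nor clobber it, and an empty state dict on restore simply
means there is nothing to load. A self-contained sketch of that pattern;
RankScopedState and its payload are illustrative stand-ins, not the patched
dataloader class:

    import pickle
    from typing import Any, Dict


    class RankScopedState:
        """Illustrative stand-in for a dataloader with dp-rank-keyed state."""

        def __init__(self, dp_rank: int):
            self._rank_id = f"dp_rank_{dp_rank}"
            self._inner: Dict[str, Any] = {"sample_idx": 0}  # stand-in payload

        def state_dict(self) -> Dict[str, Any]:
            # Store state only under this dp rank's key so other dimensions
            # do not overwrite it in a shared checkpoint.
            return {self._rank_id: pickle.dumps(self._inner)}

        def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
            # An empty state dict is valid: there is nothing to restore.
            if not state_dict:
                return
            if self._rank_id in state_dict:
                self._inner = pickle.loads(state_dict[self._rank_id])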