diff --git a/README.md b/README.md index 2468d4a0..89060cbe 100644 --- a/README.md +++ b/README.md @@ -58,7 +58,7 @@ You may want to see how the model is defined or how parallelism techniques are a 4. `torch.compile` support 5. [Float8](https://discuss.pytorch.org/t/distributed-w-torchtitan-enabling-float8-all-gather-in-fsdp2/209323) support ([how-to](docs/float8.md)) 6. DDP and HSDP -7. Checkpointable data-loading, with the C4 dataset pre-configured (144M entries) +7. Checkpointable data-loading, with the C4 dataset pre-configured (144M entries) and support for [custom datasets](docs/datasets.md) 8. Learning rate scheduler, meta-init, (optional) fused RMSNorm kernel 9. Loss, GPU memory, throughput (tokens/sec), and MFU displayed and logged via [Tensorboard or Weights & Biases](/docs/metrics.md) 10. Debugging tools including CPU/GPU profiling, [memory profiling](docs/memory_profiler.md), [Flight Recorder](#debugging), etc. diff --git a/docs/datasets.md b/docs/datasets.md index 75832101..e13da2dd 100644 --- a/docs/datasets.md +++ b/docs/datasets.md @@ -69,6 +69,6 @@ That's it! Your custom dataset is now ready to use with TorchTitan. - `text_processor`: Function to process individual samples - The loader function should return a HuggingFace dataset object - The processor function should return a string that combines the relevant fields from your dataset -- Use streaming=True for large datasets to manage memory efficiently +- Use `streaming=True` for large datasets to manage memory efficiently Now you can start training with your custom dataset! diff --git a/torchtitan/datasets/hf_datasets.py b/torchtitan/datasets/hf_datasets.py index d0cf7587..201f0b48 100644 --- a/torchtitan/datasets/hf_datasets.py +++ b/torchtitan/datasets/hf_datasets.py @@ -9,6 +9,9 @@ from typing import Any, Callable, Dict, List, Optional import torch + +from datasets import Dataset, load_dataset +from datasets.distributed import split_dataset_by_node from torch.distributed.checkpoint.stateful import Stateful from torch.utils.data import IterableDataset from torchdata.stateful_dataloader import StatefulDataLoader @@ -16,16 +19,13 @@ from torchtitan.datasets.tokenizer import Tokenizer from torchtitan.logging import logger -from datasets import Dataset, load_dataset -from datasets.distributed import split_dataset_by_node - -def load_c4_dataset(dataset_path: str): +def _load_c4_dataset(dataset_path: str): """Load C4 dataset with default configuration.""" return load_dataset(dataset_path, name="en", split="train", streaming=True) -def process_c4_text(sample: Dict[str, Any]) -> str: +def _process_c4_text(sample: Dict[str, Any]) -> str: """Process C4 dataset sample text.""" return sample["text"] @@ -41,18 +41,18 @@ class DatasetConfig: DATASETS = { "c4": DatasetConfig( path="allenai/c4", - loader=load_c4_dataset, - text_processor=process_c4_text, + loader=_load_c4_dataset, + text_processor=_process_c4_text, ), "c4_test": DatasetConfig( path="test/assets/c4_test", - loader=lambda path, **kwargs: load_dataset(path, split="train"), - text_processor=process_c4_text, + loader=lambda path: load_dataset(path, split="train"), + text_processor=_process_c4_text, ), } -def validate_dataset( +def _validate_dataset( dataset_name: str, dataset_path: str = None ) -> tuple[str, Callable, Callable]: """Validate dataset name and path.""" @@ -82,7 +82,7 @@ def __init__( # Force lowercase for consistent comparison dataset_name = dataset_name.lower() - path, dataset_loader, text_processor = validate_dataset( + path, dataset_loader, text_processor = _validate_dataset( dataset_name, dataset_path ) ds = dataset_loader(path) diff --git a/train_configs/llama3_8b.toml b/train_configs/llama3_8b.toml index 9e83594f..3001ec74 100644 --- a/train_configs/llama3_8b.toml +++ b/train_configs/llama3_8b.toml @@ -42,10 +42,10 @@ context_parallel_degree = 1 pipeline_parallel_degree = 1 [checkpoint] -enable_checkpoint = true +enable_checkpoint = false folder = "checkpoint" interval_type = "steps" -interval = 10 +interval = 500 model_weights_only = false export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]