feedback
msaroufim committed Dec 5, 2024
1 parent c86a20c commit 65348c6
Showing 4 changed files with 15 additions and 15 deletions.
README.md: 2 changes (1 addition & 1 deletion)
@@ -58,7 +58,7 @@ You may want to see how the model is defined or how parallelism techniques are a
4. `torch.compile` support
5. [Float8](https://discuss.pytorch.org/t/distributed-w-torchtitan-enabling-float8-all-gather-in-fsdp2/209323) support ([how-to](docs/float8.md))
6. DDP and HSDP
- 7. Checkpointable data-loading, with the C4 dataset pre-configured (144M entries)
+ 7. Checkpointable data-loading, with the C4 dataset pre-configured (144M entries) and support for [custom datasets](docs/datasets.md)
8. Learning rate scheduler, meta-init, (optional) fused RMSNorm kernel
9. Loss, GPU memory, throughput (tokens/sec), and MFU displayed and logged via [Tensorboard or Weights & Biases](/docs/metrics.md)
10. Debugging tools including CPU/GPU profiling, [memory profiling](docs/memory_profiler.md), [Flight Recorder](#debugging), etc.
docs/datasets.md: 2 changes (1 addition & 1 deletion)
@@ -69,6 +69,6 @@ That's it! Your custom dataset is now ready to use with TorchTitan.
- `text_processor`: Function to process individual samples
- The loader function should return a HuggingFace dataset object
- The processor function should return a string that combines the relevant fields from your dataset
- - Use streaming=True for large datasets to manage memory efficiently
+ - Use `streaming=True` for large datasets to manage memory efficiently

Now you can start training with your custom dataset!
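
For illustration only (not part of this commit): a minimal sketch of the loader and text-processor contract described above, assuming a hypothetical dataset whose samples carry `title` and `body` fields.

```python
from typing import Any, Dict

from datasets import load_dataset


def load_my_dataset(dataset_path: str):
    """Loader: must return a HuggingFace dataset object.

    streaming=True keeps a large dataset from being fully materialized in memory.
    """
    return load_dataset(dataset_path, split="train", streaming=True)


def process_my_sample(sample: Dict[str, Any]) -> str:
    """Processor: must return a single string combining the relevant fields."""
    return f"{sample['title']}\n\n{sample['body']}"
```

Wiring two callables like these into the dataset registry is what the hf_datasets.py diff below shows.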
torchtitan/datasets/hf_datasets.py: 22 changes (11 additions & 11 deletions)
@@ -9,23 +9,23 @@
from typing import Any, Callable, Dict, List, Optional

import torch

+ from datasets import Dataset, load_dataset
+ from datasets.distributed import split_dataset_by_node
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset
from torchdata.stateful_dataloader import StatefulDataLoader

from torchtitan.datasets.tokenizer import Tokenizer
from torchtitan.logging import logger

- from datasets import Dataset, load_dataset
- from datasets.distributed import split_dataset_by_node


- def load_c4_dataset(dataset_path: str):
+ def _load_c4_dataset(dataset_path: str):
"""Load C4 dataset with default configuration."""
return load_dataset(dataset_path, name="en", split="train", streaming=True)


- def process_c4_text(sample: Dict[str, Any]) -> str:
+ def _process_c4_text(sample: Dict[str, Any]) -> str:
"""Process C4 dataset sample text."""
return sample["text"]

@@ -41,18 +41,18 @@ class DatasetConfig:
DATASETS = {
"c4": DatasetConfig(
path="allenai/c4",
- loader=load_c4_dataset,
- text_processor=process_c4_text,
+ loader=_load_c4_dataset,
+ text_processor=_process_c4_text,
),
"c4_test": DatasetConfig(
path="test/assets/c4_test",
- loader=lambda path, **kwargs: load_dataset(path, split="train"),
- text_processor=process_c4_text,
+ loader=lambda path: load_dataset(path, split="train"),
+ text_processor=_process_c4_text,
),
}


- def validate_dataset(
+ def _validate_dataset(
dataset_name: str, dataset_path: str = None
) -> tuple[str, Callable, Callable]:
"""Validate dataset name and path."""
@@ -82,7 +82,7 @@ def __init__(
# Force lowercase for consistent comparison
dataset_name = dataset_name.lower()

- path, dataset_loader, text_processor = validate_dataset(
+ path, dataset_loader, text_processor = _validate_dataset(
dataset_name, dataset_path
)
ds = dataset_loader(path)
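Not part of this commit, but to make the registry pattern above concrete: a hypothetical dataset could be registered by adding another `DatasetConfig` entry, assuming `DatasetConfig` takes only the `path`, `loader`, and `text_processor` fields shown in this diff (the dataset name, Hub path, and config name below are illustrative).

```python
def _load_wikitext(dataset_path: str):
    """Stream the (assumed) WikiText dataset so it is not materialized in memory."""
    return load_dataset(
        dataset_path, name="wikitext-103-raw-v1", split="train", streaming=True
    )


def _process_wikitext(sample: Dict[str, Any]) -> str:
    """Return the raw text field of a sample."""
    return sample["text"]


# Hypothetical registry entry, mirroring the "c4" entry above.
DATASETS["wikitext"] = DatasetConfig(
    path="Salesforce/wikitext",
    loader=_load_wikitext,
    text_processor=_process_wikitext,
)
```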
train_configs/llama3_8b.toml: 4 changes (2 additions & 2 deletions)
@@ -42,10 +42,10 @@ context_parallel_degree = 1
pipeline_parallel_degree = 1

[checkpoint]
- enable_checkpoint = true
+ enable_checkpoint = false
folder = "checkpoint"
interval_type = "steps"
- interval = 10
+ interval = 500
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
