Skip to content

Commit

Permalink
pass in kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
msaroufim committed Dec 4, 2024
1 parent 049b755 commit 8ec6a2e
Showing 1 changed file with 25 additions and 11 deletions.
36 changes: 25 additions & 11 deletions torchtitan/datasets/hf_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,35 +5,49 @@
# LICENSE file in the root directory of this source tree.

import pickle
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Dict, List, Optional

import torch
from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset
from torchdata.stateful_dataloader import StatefulDataLoader

from torchtitan.datasets.tokenizer import Tokenizer
from torchtitan.logging import logger

from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node


def load_c4_dataset(dataset_path: str, **kwargs):
    """Load the C4 dataset, letting callers override the default settings.

    Args:
        dataset_path: Path or identifier forwarded as the first positional
            argument to ``load_dataset``.
        **kwargs: Extra keyword arguments forwarded to ``load_dataset``.
            Any key that matches a default (``name``, ``split``,
            ``streaming``) replaces that default.

    Returns:
        Whatever ``load_dataset`` returns for the merged configuration.
    """
    logger.info("Loading C4 dataset...")

    # Default settings for C4
    default_config = {"name": "en", "split": "train", "streaming": True}

    # Later entries win in a dict merge, so caller kwargs override defaults.
    config = {**default_config, **kwargs}

    return load_dataset(dataset_path, **config)


def load_wikipedia_dataset(dataset_path: str, **kwargs):
    """Load the Wikipedia dataset, letting callers override the defaults.

    Args:
        dataset_path: Path or identifier forwarded as the first positional
            argument to ``load_dataset``.
        **kwargs: Extra keyword arguments forwarded to ``load_dataset``.
            Any key that matches a default (``name``, ``split``,
            ``streaming``, ``trust_remote_code``) replaces that default.

    Returns:
        Whatever ``load_dataset`` returns for the merged configuration.
    """
    logger.info("Loading Wikipedia dataset...")

    # Default settings for Wikipedia
    default_config = {
        "name": "20220301.en",
        "split": "train",
        "streaming": True,
        "trust_remote_code": True,
    }

    # Later entries win in a dict merge, so caller kwargs override defaults.
    config = {**default_config, **kwargs}

    return load_dataset(dataset_path, **config)


def process_c4_text(sample: Dict[str, Any]) -> str:
Expand Down

0 comments on commit 8ec6a2e

Please sign in to comment.