Support for Passing in Tokenized Data to One-Shot #2202

Merged: 7 commits, merged on Apr 10, 2024
6 changes: 6 additions & 0 deletions src/sparseml/transformers/finetune/data/data_args.py
@@ -118,6 +118,12 @@ class DataTrainingArguments(CustomDataTrainingArguments):
default=512,
metadata={"help": "Number of samples to use for one-shot calibration"},
)
shuffle_calibration_samples: Optional[bool] = field(
default=True,
metadata={
"help": "whether to shuffle the dataset before selecting calibration data"
},
)
streaming: Optional[bool] = field(
default=False,
metadata={"help": "True to stream data from a cloud dataset"},
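The new `shuffle_calibration_samples` flag sits alongside the existing calibration arguments; a minimal sketch of setting it when building data args (the dataset value is illustrative, not taken from this PR):

```python
from sparseml.transformers.finetune.data.data_args import DataTrainingArguments

# Illustrative values only; "open_platypus" stands in for whatever dataset
# (or pre-tokenized datasets.Dataset) the caller actually provides.
data_args = DataTrainingArguments(
    dataset="open_platypus",
    num_calibration_samples=512,
    shuffle_calibration_samples=False,  # take calibration samples in dataset order
)
```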
16 changes: 11 additions & 5 deletions src/sparseml/transformers/finetune/data/data_helpers.py
@@ -18,7 +18,7 @@

import torch
from datasets import Dataset, load_dataset
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers.data import default_data_collator


@@ -36,6 +36,7 @@
def format_calibration_data(
tokenized_dataset: Dataset,
num_calibration_samples: Optional[int] = None,
do_shuffle: bool = True,
collate_fn: Callable = default_data_collator,
accelerator: Optional[Any] = None,
) -> List[torch.Tensor]:
@@ -45,6 +46,8 @@ def format_calibration_data(

:param tokenized_dataset: dataset to convert to dataloader
:param num_calibration_samples: number of data samples to convert
:param do_shuffle: whether to shuffle the dataset before selecting calibration
samples, true by default
:param collate_fn: optional custom collate function, or use default
:param accelerator: optional accelerator to use when preparing in FSDP mode
:return: list of trimmed calibration data tensors
@@ -58,17 +61,20 @@
f"the provided dataset only has {safe_calibration_samples}. "
)

shuffled_calibration = tokenized_dataset.shuffle()
shuffled_calibration = shuffled_calibration.select(range(safe_calibration_samples))
if do_shuffle:
tokenized_dataset = tokenized_dataset.shuffle()
tokenized_calibration = tokenized_dataset.select(range(safe_calibration_samples))

dataloader_params = {
"batch_size": 1,
"sampler": RandomSampler(shuffled_calibration),
"sampler": RandomSampler(tokenized_calibration)
if do_shuffle
else SequentialSampler(tokenized_calibration),
"collate_fn": collate_fn,
"pin_memory": True,
}

calib_dataloader = DataLoader(shuffled_calibration, **dataloader_params)
calib_dataloader = DataLoader(tokenized_calibration, **dataloader_params)
if accelerator:
calib_dataloader = accelerator.prepare(calib_dataloader)

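With `do_shuffle=False`, `format_calibration_data` now uses a `SequentialSampler`, so batches come back in dataset order. A sketch of calling the helper directly, assuming `tokenized_calib` is a `datasets.Dataset` that already contains an `input_ids` column:

```python
from sparseml.transformers.finetune.data.data_helpers import format_calibration_data

# `tokenized_calib` is assumed to be a pre-tokenized datasets.Dataset
# (e.g. a tokenizer mapped over a raw dataset).
calib_dataloader = format_calibration_data(
    tokenized_dataset=tokenized_calib,
    num_calibration_samples=256,
    do_shuffle=False,  # SequentialSampler: first batch matches the first dataset row
)
first_batch = next(iter(calib_dataloader))
print(first_batch["input_ids"].shape)  # batch_size is fixed to 1 in this helper
```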
15 changes: 10 additions & 5 deletions src/sparseml/transformers/finetune/runner.py
@@ -19,7 +19,6 @@
from typing import List, Optional

import torch
from torch.nn import Module
from torch.utils.data import Dataset
from transformers import AutoTokenizer

@@ -72,7 +71,6 @@ def __init__(
data_args: "DataTrainingArguments",
model_args: "ModelArguments",
training_args: "TrainingArguments",
model: Module,
):
self._data_args = data_args
self._model_args = model_args
@@ -121,9 +119,15 @@ def _get_split_name(inp_str):
tokenizer=tokenizer,
)

raw_dataset = dataset_manager.get_raw_dataset(self._model_args.cache_dir)
tokenized_dataset = dataset_manager.tokenize_and_process(raw_dataset)
tokenized_datasets[split_name] = tokenized_dataset
dataset = self._data_args.dataset
if hasattr(dataset, "column_names") and "input_ids" in dataset.column_names:
# dataset is already tokenized
tokenized_datasets[split_name] = dataset
else:
# dataset needs to be tokenized
raw_dataset = dataset_manager.get_raw_dataset()
tokenized_dataset = dataset_manager.tokenize_and_process(raw_dataset)
tokenized_datasets[split_name] = tokenized_dataset

self.datasets = make_dataset_splits(
tokenized_datasets,
@@ -154,6 +158,7 @@ def one_shot(self, stage: Optional[str] = None):
calib_data = format_calibration_data(
tokenized_dataset=self.get_dataset_split("calibration"),
num_calibration_samples=self._data_args.num_calibration_samples,
do_shuffle=self._data_args.shuffle_calibration_samples,
accelerator=self.trainer.accelerator,
)

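The branch added to `populate_datasets` keys off the presence of an `input_ids` column; a self-contained sketch of that check (the helper name is hypothetical, only the condition mirrors the PR):

```python
from datasets import Dataset

def looks_pretokenized(dataset) -> bool:
    # Same condition as the runner: anything exposing an "input_ids" column is
    # treated as already tokenized and skips get_raw_dataset/tokenize_and_process.
    return hasattr(dataset, "column_names") and "input_ids" in dataset.column_names

ds = Dataset.from_dict({"input_ids": [[1, 2, 3]], "attention_mask": [[1, 1, 1]]})
assert looks_pretokenized(ds)
assert not looks_pretokenized("wikitext")  # a dataset name string still gets tokenized
```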
5 changes: 1 addition & 4 deletions src/sparseml/transformers/finetune/text_generation.py
@@ -319,10 +319,7 @@ def main(

# Load datasets
stage_runner = StageRunner(
model_args=model_args,
data_args=data_args,
training_args=training_args,
model=model,
model_args=model_args, data_args=data_args, training_args=training_args
)
stage_runner.populate_datasets(tokenizer=tokenizer)
train_dataset = stage_runner.get_dataset_split("train")
53 changes: 48 additions & 5 deletions tests/sparseml/transformers/finetune/data/test_dataset_loading.py
@@ -14,10 +14,12 @@
# limitations under the License.

import pytest
from datasets import IterableDataset
import torch
from datasets import IterableDataset, load_dataset

from sparseml.transformers.finetune.data import TextGenerationDataset
from sparseml.transformers.finetune.data.data_args import DataTrainingArguments
from sparseml.transformers.finetune.data.data_helpers import format_calibration_data
from sparseml.transformers.finetune.model_args import ModelArguments
from sparseml.transformers.finetune.runner import StageRunner
from sparseml.transformers.finetune.training_args import TrainingArguments
@@ -229,13 +231,54 @@ def test_split_loading(split_def, tiny_llama_tokenizer):
training_args = TrainingArguments(do_train=True, output_dir="dummy")
model_args = ModelArguments(model=None)
stage_runner = StageRunner(
model_args=model_args,
data_args=data_args,
training_args=training_args,
model=None,
model_args=model_args, data_args=data_args, training_args=training_args
)
stage_runner.populate_datasets(tokenizer=tiny_llama_tokenizer)

train_dataset = stage_runner.get_dataset_split("train")
assert train_dataset is not None
assert isinstance(train_dataset[0], dict)


def test_load_tokenized_data(tiny_llama_tokenizer):
dataset = load_dataset("garage-bAInd/Open-Platypus")["train"]
NUM_CALIB_SAMPS = 256
MAX_SEQ_LEN = 512
dataset = dataset.shuffle(seed=42).select(range(NUM_CALIB_SAMPS))

def preprocess(sample):
concat_text = "INPUT: " + sample.get("input", "")
concat_text += "INSTRUCTIONS: " + sample.get("instruction", "")
concat_text += "OUTPUT: " + sample.get("output", "")

return tiny_llama_tokenizer(
concat_text, padding=False, max_length=MAX_SEQ_LEN, truncation=True
)

tokenized_dataset = dataset.map(
preprocess, remove_columns=["input", "output", "instruction", "data_source"]
)
stage_runner = StageRunner(
model_args=None,
data_args=DataTrainingArguments(
dataset=tokenized_dataset, shuffle_calibration_samples=False
),
training_args=TrainingArguments(do_oneshot=True),
)
stage_runner.populate_datasets(tokenizer=None)
calib_dataset = stage_runner.get_dataset_split("calibration")
assert len(calib_dataset) == NUM_CALIB_SAMPS
data_cols = calib_dataset.column_names
assert len(data_cols) == 2
assert "input_ids" in data_cols and "attention_mask" in data_cols

# confirm turning shuffle off works
calib_dataloader = format_calibration_data(
tokenized_dataset=calib_dataset,
num_calibration_samples=NUM_CALIB_SAMPS,
do_shuffle=stage_runner._data_args.shuffle_calibration_samples,
)
assert len(calib_dataloader) == NUM_CALIB_SAMPS
dataloader_sample = next(iter(calib_dataloader))["input_ids"]
diff = dataloader_sample - torch.Tensor(calib_dataset[0]["input_ids"])
assert torch.sum(diff) == 0
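Put together, the pattern this test exercises, tokenizing externally and handing the dataset straight to the runner, looks roughly like the sketch below (model and tokenizer setup omitted; argument values are illustrative and `tokenizer` is assumed to be any pre-loaded Hugging Face tokenizer):

```python
from datasets import load_dataset
from sparseml.transformers.finetune.data.data_args import DataTrainingArguments
from sparseml.transformers.finetune.runner import StageRunner
from sparseml.transformers.finetune.training_args import TrainingArguments

raw = load_dataset("garage-bAInd/Open-Platypus", split="train")
tokenized = raw.map(
    lambda sample: tokenizer(sample["instruction"], max_length=512, truncation=True),
    remove_columns=raw.column_names,
)

runner = StageRunner(
    model_args=None,
    data_args=DataTrainingArguments(
        dataset=tokenized, shuffle_calibration_samples=False
    ),
    training_args=TrainingArguments(do_oneshot=True),
)
runner.populate_datasets(tokenizer=None)  # no re-tokenization; dataset is used as-is
calib = runner.get_dataset_split("calibration")
```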