diff --git a/setup.py b/setup.py index ff8269257f..d9c8dffd7d 100644 --- a/setup.py +++ b/setup.py @@ -149,6 +149,7 @@ def _parse_requirements_file(file_path): "datasets<2.16", "accelerate<0.26", "seqeval", + "evaluate", ] _sentence_transformers_integration_deps = ["optimum-deepsparse"] + _torch_deps diff --git a/src/deepsparse/evaluation/evaluator.py b/src/deepsparse/evaluation/evaluator.py index 3d18f8489f..3926b78a2a 100644 --- a/src/deepsparse/evaluation/evaluator.py +++ b/src/deepsparse/evaluation/evaluator.py @@ -16,6 +16,9 @@ from typing import List, Optional, Union from deepsparse import Pipeline +from deepsparse.evaluation.integrations.perplexity import ( # noqa + integration_eval as integration_eval_perplexity, +) from deepsparse.evaluation.registry import EvaluationRegistry from deepsparse.evaluation.results import Result from deepsparse.evaluation.utils import create_pipeline diff --git a/src/deepsparse/evaluation/integrations/__init__.py b/src/deepsparse/evaluation/integrations/__init__.py index 15eeee7d8d..f0871f135a 100644 --- a/src/deepsparse/evaluation/integrations/__init__.py +++ b/src/deepsparse/evaluation/integrations/__init__.py @@ -31,3 +31,4 @@ def try_import_lm_evaluation_harness(raise_error=True): if try_import_lm_evaluation_harness(raise_error=False): from .lm_evaluation_harness import * +from .perplexity import * diff --git a/src/deepsparse/evaluation/integrations/perplexity.py b/src/deepsparse/evaluation/integrations/perplexity.py new file mode 100644 index 0000000000..a9a3f3d8a3 --- /dev/null +++ b/src/deepsparse/evaluation/integrations/perplexity.py @@ -0,0 +1,278 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import logging
+import os
+from collections import defaultdict
+from typing import Any, Dict, List, Optional, Union
+
+import numpy
+from tqdm import tqdm
+
+from datasets import load_dataset
+from deepsparse import Pipeline
+from deepsparse.evaluation.registry import EvaluationRegistry
+from deepsparse.evaluation.results import Dataset, Evaluation, Metric, Result
+from deepsparse.evaluation.utils import PERPLEXITY
+from deepsparse.transformers.metrics import Perplexity
+from deepsparse.transformers.pipelines.text_generation import TextGenerationPipeline
+from deepsparse.transformers.pipelines.text_generation.pipeline_no_kv_cache import (
+    TextGenerationPipelineNoCache,
+)
+from deepsparse.transformers.utils.eval_helpers import (
+    HumanEvalIteratorWrapper,
+    process_concatenated_datasets,
+)
+
+
+"""
+Integration for the evaluation module
+that computes the perplexity of a model on a dataset
+"""
+_LOGGER = logging.getLogger(__name__)
+
+
+@EvaluationRegistry.register(name=PERPLEXITY)
+def integration_eval(
+    pipeline: Pipeline,
+    datasets: Union[List[str], str] = "openai_humaneval",
+    batch_size: int = 1,
+    limit: Optional[int] = None,
+    accumulate: Optional[bool] = None,
+    splits: Union[List[str], str, None] = "test",
+    metrics: Union[List[str], str, None] = None,
+    **kwargs,
+) -> Result:
+    """
+    A function that computes the perplexity of a pipeline given a set
+    of dataset names.
+
+    :param pipeline: the pipeline to evaluate. The assumed pipeline
+        is a TextGenerationPipeline, either with or without KV
+        cache support
+    :param datasets: the names of dataset(s) to evaluate on
+    :param batch_size: the batch size to use for evaluation
+    :param splits: the split of the dataset to evaluate on. Default is "test"
+    :param metrics: the metrics to compute. Default is None
+    :param limit: the number of batches to evaluate on. Default is None
+        (evaluates on the entire dataset)
+    :param accumulate: whether the perplexity computation should
+        accumulate negative log-likelihood over samples. Defaults to
+        the accumulate setting inferred from the dataset in
+        `datasets`. If not None, it overrides the inferred
+        setting.
+    :return: a Result object containing the raw and formatted results
+    """
+    metrics = metrics or PERPLEXITY
+    if metrics != PERPLEXITY:
+        raise ValueError(f"Invalid metric {metrics} for perplexity evaluation")
+    if splits is None:
+        splits = "test"
+        _LOGGER.info("Argument `splits` is None. Defaulting to `test` split.")
+    datasets = datasets if isinstance(datasets, list) else [datasets]
+    results_raw = defaultdict(str)
+    for dataset_name in datasets:
+        results_raw[dataset_name] = defaultdict()
+        dataset, _accumulate = load_perplexity_dataset(
+            dataset_name=dataset_name, splits=splits, pipeline=pipeline, **kwargs
+        )
+        if accumulate is None:
+            accumulate = _accumulate
+        else:
+            _LOGGER.info(
+                f"Argument `accumulate` set to {accumulate}. "
+                "Overriding the inferred accumulate variable from the dataset."
+            )
+
+        perplexity = run_perplexity(
+            pipeline=pipeline,
+            dataset=dataset,
+            batch_size=batch_size,
+            accumulate=accumulate,
+            limit=limit,
+        )
+
+        results_raw[dataset_name] = defaultdict()
+        results_raw[dataset_name]["results"] = perplexity
+        results_raw[dataset_name]["split"] = splits
+
+    results = Result(
+        # omit storing raw results. they can potentially
+        # contain numpy arrays that are not serializable.
+        # all the information is stored in the formatted results
+        raw=None,
+        formatted=format_raw_results(results_raw),
+    )
+
+    return results
+
+
+def run_perplexity(
+    pipeline: Union[TextGenerationPipelineNoCache, TextGenerationPipeline],
+    dataset: "Dataset",
+    batch_size: int,
+    accumulate: bool,
+    limit: Optional[int] = None,
+) -> Dict[str, Any]:
+    """
+    Compute the perplexity of a pipeline given a dataset.
+
+    :param pipeline: the pipeline to evaluate. The assumed pipeline
+        is a TextGenerationPipeline, either with or without KV
+        cache support
+    :param dataset: the dataset to evaluate on
+    :param batch_size: the batch size to use for evaluation
+    :param accumulate: whether the perplexity computation should
+        accumulate negative log-likelihood over samples
+    :param limit: the number of batches to evaluate on. Default is None
+        (evaluates on the entire dataset)
+
+    :return: a dictionary containing the perplexity results
+    """
+
+    perplexity = Perplexity(accumulate=accumulate)
+
+    batch = []
+    for idx, sample in _enumerate_progress(
+        dataset, max_steps=None if limit is None else limit * batch_size
+    ):
+
+        if limit is not None:
+            # stop once the requested number of
+            # batches (`limit`) has been processed
+            if idx >= limit * batch_size:
+                break
+
+        batch.append(sample)
+
+        if len(batch) == batch_size:
+            if isinstance(pipeline, TextGenerationPipelineNoCache):
+                out = pipeline(
+                    prompt=batch,
+                    output_scores=True,
+                    include_prompt_logits=True,
+                    return_input_tokens=True,
+                )
+            else:
+                out = pipeline(
+                    prompt=batch,
+                    output_scores=True,
+                    max_new_tokens=0,
+                    include_prompt_logits=True,
+                    return_input_tokens=True,
+                )
+
+            for s in range(batch_size):
+                # Need to remove tokens that were masked
+                input_ids = out.input_tokens["input_ids"][s].flatten()
+                attention_mask = out.input_tokens["attention_mask"][s].flatten()
+                logits = out.generations[s].score
+                if batch_size > 1 and isinstance(
+                    pipeline, TextGenerationPipelineNoCache
+                ):
+                    logits = logits[-attention_mask.sum() :, :]
+
+                logits = numpy.compress(attention_mask, logits, axis=0)[:-1, :]
+                input_ids = numpy.compress(attention_mask, input_ids)[1:]
+
+                # Add predictions (logits) and targets (input_ids) to metric
+                perplexity.add_batch(logits, input_ids)
+
+            batch.clear()
+
+    return perplexity.compute()
+
+
+def format_raw_results(results: Dict[str, Any]) -> List[Evaluation]:
+    """
+    Format the raw perplexity results into a list of
+    Evaluation objects.
+
+    :param results: the raw results from perplexity computation
+    :return: the formatted results as a list of Evaluation objects
+    """
+    formatted_results = []
+    for dataset_name, dataset_result in results.items():
+        metrics = []
+        for metric_name, metric_value in dataset_result["results"].items():
+            if isinstance(metric_value, numpy.ndarray):
+                metric_value = metric_value.tolist()
+            metric = Metric(name=metric_name, value=metric_value)
+            metrics.append(metric)
+        dataset = Dataset(type=None, name=dataset_name, split=dataset_result["split"])
+        evaluation = Evaluation(
+            task="perplexity",
+            dataset=dataset,
+            metrics=metrics,
+            samples=None,
+        )
+        formatted_results.append(evaluation)
+    return formatted_results
+
+
+def load_perplexity_dataset(
+    dataset_name: str,
+    splits: Union[List[str], str] = "test",
+    pipeline: Optional[Pipeline] = None,
+    **kwargs,
+):
+    """
+    Function to load the dataset for perplexity computation.
+    Eventually we want to load the dataset from nm_utils.
+
+    :param dataset_name: the name of the dataset to load
+    :param splits: the splits to load from the dataset. Default is "test"
+    :param pipeline: the pipeline to use for loading the dataset. It is
+        used to infer the model path and sequence length. This argument
+        can be omitted if the appropriate kwargs are provided, or if
+        the dataset does not require process_concatenated_datasets
+        in order to be loaded.
+    :param kwargs: additional keyword arguments to pass to the dataset loading function
+    :return: the dataset and whether to accumulate perplexity over samples
+    """
+    if isinstance(splits, list):
+        raise NotImplementedError("Evaluation on multiple splits not implemented")
+
+    if dataset_name == "openai_humaneval":
+        dataset = load_dataset(dataset_name, split=splits)
+        dataset = HumanEvalIteratorWrapper(dataset)
+        accumulate = False
+    elif dataset_name in {"wikitext2", "c4"}:
+        # fetch max_sequence_length from pipeline if not provided
+        max_sequence_length = kwargs.pop("max_sequence_length", None)
+        if max_sequence_length is None and pipeline is not None:
+            max_sequence_length = pipeline.sequence_length
+
+        # fetch model_path from pipeline if not provided
+        model_path = kwargs.pop("model_path", None)
+        if model_path is None and pipeline is not None:
+            model_path = os.path.dirname(pipeline.model_path)
+
+        dataset = process_concatenated_datasets(
+            dataset_name,
+            model_path=model_path,
+            max_sequence_length=max_sequence_length,
+            split=splits,
+            **kwargs,
+        )
+        accumulate = True
+    else:
+        raise NotImplementedError(f"Dataset {dataset_name} not implemented")
+
+    return dataset, accumulate
+
+
+def _enumerate_progress(dataset, max_steps):
+    progress_bar = tqdm(dataset, total=max_steps) if max_steps else tqdm(dataset)
+    return enumerate(progress_bar)
diff --git a/src/deepsparse/evaluation/results.py b/src/deepsparse/evaluation/results.py
index 00212d0a1e..78c4bbd501 100644
--- a/src/deepsparse/evaluation/results.py
+++ b/src/deepsparse/evaluation/results.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License. 
-from typing import Any, List, Optional +from typing import Any, List, Optional, Union import yaml from pydantic import BaseModel, Field @@ -32,7 +32,7 @@ class Metric(BaseModel): name: str = Field(description="Name of the metric") - value: float = Field(description="Value of the metric") + value: Union[float, List[float]] = Field(description="Value of the metric") class Dataset(BaseModel): diff --git a/src/deepsparse/evaluation/utils.py b/src/deepsparse/evaluation/utils.py index ff2619315b..a5dc460596 100644 --- a/src/deepsparse/evaluation/utils.py +++ b/src/deepsparse/evaluation/utils.py @@ -27,7 +27,9 @@ "resolve_integration", ] _LOGGER = logging.getLogger(__name__) + LM_EVALUATION_HARNESS = "lm-evaluation-harness" +PERPLEXITY = "perplexity" def potentially_check_dependency_import(integration_name: str) -> bool: diff --git a/src/deepsparse/transformers/metrics.py b/src/deepsparse/transformers/metrics.py index b90c4dd744..0e7c24c8b6 100644 --- a/src/deepsparse/transformers/metrics.py +++ b/src/deepsparse/transformers/metrics.py @@ -20,7 +20,7 @@ import numpy -from deepsparse.utils import numpy_log_softmax +from deepsparse.utils.data import numpy_log_softmax __all__ = [ diff --git a/src/deepsparse/transformers/utils/eval_helpers.py b/src/deepsparse/transformers/utils/eval_helpers.py index 4c0e68b9de..012520b9b5 100644 --- a/src/deepsparse/transformers/utils/eval_helpers.py +++ b/src/deepsparse/transformers/utils/eval_helpers.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Mapping, Union +from typing import List, Union import numpy from transformers import AutoTokenizer, PreTrainedTokenizerFast @@ -27,7 +27,8 @@ def process_concatenated_datasets( dataset_name: str, model_path: str, max_sequence_length: int, - kwargs: Mapping, + split: str = "test", + **kwargs, ) -> list: """ Concatenate text datasets and split them into chunks text that, after @@ -38,6 +39,8 @@ def process_concatenated_datasets( Options: "wikitext2" or "c4". model_path (str): The path to a pretrained transformer model for tokenization. max_sequence_length (int): The maximum number of tokens in each sequence. + split (str, optional): The split of the dataset to use. + Default is "test". kwargs (mapping): Additional keyword arguments. - eos (str, optional): The end-of-sentence token. Default is "\n\n" for wikitext2 and "" for c4. 
@@ -65,13 +68,13 @@ def process_concatenated_datasets( eos = kwargs.get("eos", "\n\n") bos = kwargs.get("bos", "") - raw_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") + raw_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split=split) raw_text = raw_dataset["text"] elif dataset_name == "c4": eos = kwargs.get("eos", "<|endoftext|>") bos = kwargs.get("bos", "") raw_samples = kwargs.get("raw_samples", None) - data_file = kwargs.get("data_file", 0) + data_file = kwargs.get("data_file", None) if data_file is not None: raw_dataset = load_dataset( "allenai/c4", @@ -79,13 +82,13 @@ def process_concatenated_datasets( data_files={ "validation": f"en/c4-validation.{data_file:05d}-of-00008.json.gz" }, - split="validation", + split=split, ) else: raw_dataset = load_dataset( "allenai/c4", "allenai--c4", - split="validation", + split=split, ) if raw_samples is not None: raw_dataset = raw_dataset[:raw_samples] @@ -181,3 +184,22 @@ def _split_text_by_tokens( ) return split_text + + +class HumanEvalIteratorWrapper: + """ + Wrapper around the `openai_humaneval` dataset, + that joins the prompt and the canonical solution + into a single string during iteration. + """ + + def __init__(self, dataset): + self.iterator = iter(dataset) + + def __iter__(self): + return self + + def __next__(self): + # Get the next sample from the original iterator + sample = next(self.iterator) + return sample["prompt"] + sample["canonical_solution"] diff --git a/tests/deepsparse/evaluation/integrations/test_perplexity.py b/tests/deepsparse/evaluation/integrations/test_perplexity.py new file mode 100644 index 0000000000..b156e5b9a4 --- /dev/null +++ b/tests/deepsparse/evaluation/integrations/test_perplexity.py @@ -0,0 +1,132 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from copy import copy
+
+import numpy as np
+
+import pytest
+from deepsparse.evaluation.integrations.perplexity import (
+    integration_eval,
+    load_perplexity_dataset,
+)
+from deepsparse.transformers.pipelines.text_generation import TextGenerationPipeline
+from evaluate import load
+
+
+@pytest.fixture()
+def model_path():
+    return "hf:mgoin/TinyStories-1M-deepsparse"
+
+
+@pytest.fixture()
+def model_id():
+    return "roneneldan/TinyStories-1M"
+
+
+@pytest.mark.parametrize(
+    "datasets",
+    [
+        "openai_humaneval",
+        "wikitext2",
+    ],
+)
+@pytest.mark.parametrize("batch_size", [1, 2])
+class TestPerplexity:
+    limit = 2
+
+    def test_perplexity_ground_truth_equal_pipeline(
+        self, model_path, model_id, datasets, batch_size
+    ):
+        # setting max_sequence_length to 16 to speed up the test
+        kwargs_ground_truth = (
+            dict(max_sequence_length=16) if datasets in {"c4", "wikitext2"} else {}
+        )
+        kwargs = copy(kwargs_ground_truth)
+
+        result_gt = self._get_ground_truth(
+            datasets=datasets,
+            batch_size=batch_size,
+            limit=self.limit,
+            model_id=model_id,
+            kwargs=kwargs_ground_truth,
+        )
+
+        result = integration_eval(
+            pipeline=TextGenerationPipeline(
+                model_path=model_path,
+                engine_type="onnxruntime",
+            ),
+            datasets=datasets,
+            batch_size=batch_size,
+            limit=self.limit,
+            # we are setting accumulate=False to compare
+            # with the torch ground truth apples to apples
+            accumulate=False,
+            **kwargs,
+        )
+        perplexities = result.formatted[0].metrics[0].value
+        perplexities_gt = result_gt["perplexities"]
+        assert np.allclose(perplexities, perplexities_gt, rtol=0.1)
+
+    def test_perplexity_kv_cache_pipeline_equal_no_kv_cache_pipeline(
+        self, model_path, model_id, datasets, batch_size
+    ):
+
+        kwargs_ground_truth = (
+            dict(max_sequence_length=16) if datasets in {"c4", "wikitext2"} else {}
+        )
+        kwargs = copy(kwargs_ground_truth)
+
+        result_kv_cache = integration_eval(
+            pipeline=TextGenerationPipeline(
+                model_path=model_path,
+                engine_type="onnxruntime",
+            ),
+            datasets=datasets,
+            model_path=model_id,
+            batch_size=batch_size,
+            limit=self.limit,
+            **kwargs,
+        )
+
+        result_non_kv_cache = integration_eval(
+            pipeline=TextGenerationPipeline(
+                model_path=model_path,
+                engine_type="onnxruntime",
+                onnx_model_name="model-orig.onnx",
+            ),
+            datasets=datasets,
+            batch_size=batch_size,
+            limit=self.limit,
+            **kwargs,
+        )
+
+        perplexities_kv_cache = result_kv_cache.formatted[0].metrics[0].value
+        perplexities_non_kv_cache = result_non_kv_cache.formatted[0].metrics[0].value
+        assert np.allclose(perplexities_kv_cache, perplexities_non_kv_cache, rtol=0.1)
+
+    @staticmethod
+    def _get_ground_truth(datasets, batch_size, limit, model_id, kwargs={}):
+        perplexity = load("perplexity", module_type="metric")
+        kwargs["model_path"] = model_id
+        dataset, *_ = load_perplexity_dataset(dataset_name=datasets, **kwargs)
+        predictions = []
+        for i, sample in enumerate(dataset):
+            if i == batch_size * limit:
+                break
+            predictions.append(sample)
+        return perplexity.compute(
+            predictions=predictions, add_start_token=False, model_id=model_id
+        )
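
Note (not part of the diff): below is a minimal usage sketch of the new perplexity integration, assuming the same model stub and engine type used in the tests above. Running it downloads the model and the openai_humaneval dataset; the result access pattern mirrors test_perplexity_ground_truth_equal_pipeline.

from deepsparse.evaluation.integrations.perplexity import integration_eval
from deepsparse.transformers.pipelines.text_generation import TextGenerationPipeline

# model stub and engine type taken from the tests in this PR
pipeline = TextGenerationPipeline(
    model_path="hf:mgoin/TinyStories-1M-deepsparse",
    engine_type="onnxruntime",
)

# evaluate two batches of the HumanEval test split; accumulate=False keeps
# per-sample perplexities rather than accumulating negative log-likelihood
result = integration_eval(
    pipeline=pipeline,
    datasets="openai_humaneval",
    batch_size=1,
    limit=2,
    accumulate=False,
)

# raw results are intentionally omitted (raw=None), so read the formatted ones
perplexities = result.formatted[0].metrics[0].value
print(perplexities)

The same integration is reachable through the EvaluationRegistry under the "perplexity" name (the PERPLEXITY constant added in utils.py), which is how the evaluator resolves it.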
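
For context, the Perplexity metric that run_perplexity feeds reduces, per sample (or accumulated over samples when accumulate=True), to the exponential of the mean negative log-likelihood of the target tokens. The sketch below illustrates that computation for a single sample with plain numpy; it is not the deepsparse.transformers.metrics.Perplexity implementation, and the log-softmax helper is redefined locally to keep the example self-contained.

import numpy


def log_softmax(x: numpy.ndarray, axis: int = -1) -> numpy.ndarray:
    # numerically stable log-softmax
    x_max = numpy.max(x, axis=axis, keepdims=True)
    log_sum_exp = numpy.log(numpy.sum(numpy.exp(x - x_max), axis=axis, keepdims=True))
    return x - x_max - log_sum_exp


def sample_perplexity(logits: numpy.ndarray, input_ids: numpy.ndarray) -> float:
    # logits: [num_tokens, vocab_size] scores predicting each target token;
    # input_ids: [num_tokens] targets, already shifted the way run_perplexity
    # slices them (logits[:-1, :] against input_ids[1:])
    log_probs = log_softmax(logits, axis=-1)
    token_nll = -log_probs[numpy.arange(len(input_ids)), input_ids]
    return float(numpy.exp(token_nll.mean()))


# toy example: 5 tokens over a 10-symbol vocabulary
rng = numpy.random.default_rng(0)
print(sample_perplexity(rng.normal(size=(5, 10)), rng.integers(0, 10, size=5)))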