[DeepSparse Evaluation API] Perplexity (#1555)
* initial commit

* Update src/deepsparse/evaluation/integrations/__init__.py

* design ready, time to define additional features

* split prep_for_generation operator

* fix logits

* update non-kv cache pipeline and tests

* add tests to address edge cases

* add condition to check if kv_cache is full during prompt inference, add a test to cover this case, revert debugging changes

* fix typing

* remove commented code

* remove irrelevant condition

* perplexity for non-kv cache pipelines works!

* logic is working

* ready for review

* [DeepSparse Evaluation API] Perplexity eval support for `openai_humaneval`, `c4`, `wikitext2` (#1586)

* fix tests 2

* initial commit

* add return to a function

* make script more robust

---------

Co-authored-by: Dipika Sikka <dipikasikka1@gmail.com>
dbogunowicz and dsikka authored Feb 9, 2024
1 parent e0b4f36 commit b82b49b
Showing 9 changed files with 448 additions and 9 deletions.
1 change: 1 addition & 0 deletions setup.py
@@ -149,6 +149,7 @@ def _parse_requirements_file(file_path):
"datasets<2.16",
"accelerate<0.26",
"seqeval",
"evaluate",
]
_sentence_transformers_integration_deps = ["optimum-deepsparse"] + _torch_deps

3 changes: 3 additions & 0 deletions src/deepsparse/evaluation/evaluator.py
@@ -16,6 +16,9 @@
from typing import List, Optional, Union

from deepsparse import Pipeline
from deepsparse.evaluation.integrations.perplexity import (  # noqa
    integration_eval as integration_eval_perplexity,
)
from deepsparse.evaluation.registry import EvaluationRegistry
from deepsparse.evaluation.results import Result
from deepsparse.evaluation.utils import create_pipeline
1 change: 1 addition & 0 deletions src/deepsparse/evaluation/integrations/__init__.py
@@ -31,3 +31,4 @@ def try_import_lm_evaluation_harness(raise_error=True):

if try_import_lm_evaluation_harness(raise_error=False):
    from .lm_evaluation_harness import *
from .perplexity import *
278 changes: 278 additions & 0 deletions src/deepsparse/evaluation/integrations/perplexity.py
@@ -0,0 +1,278 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
from collections import defaultdict
from typing import Any, Dict, List, Optional, Union

import numpy
from tqdm import tqdm

from datasets import load_dataset
from deepsparse import Pipeline
from deepsparse.evaluation.registry import EvaluationRegistry
from deepsparse.evaluation.results import Dataset, Evaluation, Metric, Result
from deepsparse.evaluation.utils import PERPLEXITY
from deepsparse.transformers.metrics import Perplexity
from deepsparse.transformers.pipelines.text_generation import TextGenerationPipeline
from deepsparse.transformers.pipelines.text_generation.pipeline_no_kv_cache import (
    TextGenerationPipelineNoCache,
)
from deepsparse.transformers.utils.eval_helpers import (
    HumanEvalIteratorWrapper,
    process_concatenated_datasets,
)


"""
Integration for the evaluation module
that computes the perplexity of a model on a dataset
"""
_LOGGER = logging.getLogger(__name__)


@EvaluationRegistry.register(name=PERPLEXITY)
def integration_eval(
    pipeline: Pipeline,
    datasets: Union[List[str], str] = "openai_humaneval",
    batch_size: int = 1,
    limit: Optional[int] = None,
    accumulate: Optional[bool] = None,
    splits: Union[List[str], str, None] = "test",
    metrics: Union[List[str], str, None] = None,
    **kwargs,
) -> Result:
    """
    A function that computes the perplexity of a pipeline given a set
    of dataset names.
    :param pipeline: the pipeline to evaluate. The assumed pipeline
        is a TextGenerationPipeline, either with or without KV
        cache support
    :param datasets: the names of dataset(s) to evaluate on
    :param batch_size: the batch size to use for evaluation
    :param splits: the split of the dataset to evaluate on. Default is "test"
    :param metrics: the metrics to compute. Default is None
    :param limit: the number of batches to evaluate on. Default is None
        (evaluates on the entire dataset)
    :param accumulate: whether the perplexity computation should
        accumulate negative log-likelihood over samples. Defaults to
        the accumulate value inferred from the dataset in
        `datasets`. If not None, it overrides the inferred value.
    :return: a Result object containing the raw and formatted results
    """
    metrics = metrics or PERPLEXITY
    if metrics != PERPLEXITY:
        raise ValueError(f"Invalid metric {metrics} for perplexity evaluation")
    if splits is None:
        splits = "test"
        _LOGGER.info("Argument `splits` is None. Defaulting to `test` split.")
    datasets = datasets if isinstance(datasets, list) else [datasets]
    results_raw = defaultdict(str)
    for dataset_name in datasets:
        results_raw[dataset_name] = defaultdict()
        dataset, _accumulate = load_perplexity_dataset(
            dataset_name=dataset_name, splits=splits, pipeline=pipeline, **kwargs
        )
        if accumulate is None:
            accumulate = _accumulate
        else:
            _LOGGER.info(
                f"Argument `accumulate` set to {accumulate}. "
                "Overriding the inferred accumulate variable from the dataset."
            )

        perplexity = run_perplexity(
            pipeline=pipeline,
            dataset=dataset,
            batch_size=batch_size,
            accumulate=accumulate,
            limit=limit,
        )

        results_raw[dataset_name] = defaultdict()
        results_raw[dataset_name]["results"] = perplexity
        results_raw[dataset_name]["split"] = splits

    results = Result(
        # omit storing raw results. they can potentially
        # contain numpy arrays that are not serializable.
        # all the information is stored in the formatted results
        raw=None,
        formatted=format_raw_results(results_raw),
    )

    return results


def run_perplexity(
    pipeline: Union[TextGenerationPipelineNoCache, TextGenerationPipeline],
    dataset: "Dataset",
    batch_size: int,
    accumulate: bool,
    limit: Optional[int] = None,
) -> Dict[str, Any]:
    """
    Compute the perplexity of a pipeline given a dataset.
    :param pipeline: the pipeline to evaluate. The assumed pipeline
        is a TextGenerationPipeline, either with or without KV
        cache support
    :param dataset: the dataset to evaluate on
    :param batch_size: the batch size to use for evaluation
    :param accumulate: whether the perplexity computation should
        accumulate negative log-likelihood over samples
    :param limit: the number of batches to evaluate on. Default is None
        (evaluates on the entire dataset)
    :return: a dictionary containing the perplexity results
    """

    perplexity = Perplexity(accumulate=accumulate)

    batch = []
    for idx, sample in _enumerate_progress(
        dataset, max_steps=None if limit is None else limit * batch_size
    ):

        if limit is not None:
            # stop if we have reached the #limit
            # number of batches to be processed
            if idx >= limit * batch_size:
                break

        batch.append(sample)

        if len(batch) == batch_size:
            if isinstance(pipeline, TextGenerationPipelineNoCache):
                out = pipeline(
                    prompt=batch,
                    output_scores=True,
                    include_prompt_logits=True,
                    return_input_tokens=True,
                )
            else:
                out = pipeline(
                    prompt=batch,
                    output_scores=True,
                    max_new_tokens=0,
                    include_prompt_logits=True,
                    return_input_tokens=True,
                )

            for s in range(batch_size):
                # Need to remove tokens that were masked
                input_ids = out.input_tokens["input_ids"][s].flatten()
                attention_mask = out.input_tokens["attention_mask"][s].flatten()
                logits = out.generations[s].score
                if batch_size > 1 and isinstance(
                    pipeline, TextGenerationPipelineNoCache
                ):
                    logits = logits[-attention_mask.sum() :, :]

                logits = numpy.compress(attention_mask, logits, axis=0)[:-1, :]
                input_ids = numpy.compress(attention_mask, input_ids)[1:]

                # Add predictions (logits) and targets (input_ids) to metric
                perplexity.add_batch(logits, input_ids)

            batch.clear()

    return perplexity.compute()


def format_raw_results(results: Dict[str, Any]) -> List[Evaluation]:
    """
    Format the raw perplexity results into a list of
    Evaluation objects.
    :param results: the raw results from perplexity computation
    :return: the formatted results as a list of Evaluation objects
    """
    formatted_results = []
    for dataset_name, dataset_result in results.items():
        metrics = []
        for metric_name, metric_value in dataset_result["results"].items():
            if isinstance(metric_value, numpy.ndarray):
                metric_value = metric_value.tolist()
            metric = Metric(name=metric_name, value=metric_value)
            metrics.append(metric)
        dataset = Dataset(type=None, name=dataset_name, split=dataset_result["split"])
        evaluation = Evaluation(
            task="perplexity",
            dataset=dataset,
            metrics=metrics,
            samples=None,
        )
        formatted_results.append(evaluation)
    return formatted_results


def load_perplexity_dataset(
    dataset_name: str,
    splits: Union[List[str], str] = "test",
    pipeline: Optional[Pipeline] = None,
    **kwargs,
):
    """
    Function to load the dataset for perplexity computation.
    Eventually we want to load the dataset from the nm_utils
    :param dataset_name: the name of the dataset to load
    :param splits: the splits to load from the dataset. Default is "test"
    :param pipeline: the pipeline to use for loading the dataset. The pipeline
        is used to infer the model path and sequence length to use for loading
        the dataset. This argument can be omitted if the appropriate kwargs
        are provided, or if the dataset does not require a
        process_concatenated_datasets function to load the dataset.
    :param kwargs: additional keyword arguments to pass to the dataset loading function
    :return: the dataset and whether to accumulate perplexity over samples
    """
    if isinstance(splits, list):
        raise NotImplementedError("Evaluation on multiple splits not implemented")

    if dataset_name == "openai_humaneval":
        dataset = load_dataset(dataset_name, split=splits)
        dataset = HumanEvalIteratorWrapper(dataset)
        accumulate = False
    elif dataset_name in {"wikitext2", "c4"}:
        # fetch max_sequence_length from pipeline if not provided
        max_sequence_length = kwargs.pop("max_sequence_length", None)
        if max_sequence_length is None and pipeline is not None:
            max_sequence_length = pipeline.sequence_length

        # fetch model_path from pipeline if not provided
        model_path = kwargs.pop("model_path", None)
        if model_path is None and pipeline is not None:
            model_path = os.path.dirname(pipeline.model_path)

        dataset = process_concatenated_datasets(
            dataset_name,
            model_path=model_path,
            max_sequence_length=max_sequence_length,
            split=splits,
            **kwargs,
        )
        accumulate = True
    else:
        raise NotImplementedError(f"Dataset {dataset_name} not implemented")

    return dataset, accumulate


def _enumerate_progress(dataset, max_steps):
    progress_bar = tqdm(dataset, total=max_steps) if max_steps else tqdm(dataset)
    return enumerate(progress_bar)
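
For reference, a minimal sketch of how the new integration can be invoked end to end. The model path is a placeholder and the dataset/limit choices are illustrative, not part of this commit:

# Hedged usage sketch for the new perplexity integration.
# "path/to/deployment" is a hypothetical local model directory.
from deepsparse import Pipeline
from deepsparse.evaluation.integrations.perplexity import integration_eval

pipeline = Pipeline.create(
    task="text-generation",
    model_path="path/to/deployment",  # hypothetical ONNX deployment directory
)

result = integration_eval(
    pipeline=pipeline,
    datasets="wikitext2",  # also supports "c4" and "openai_humaneval"
    batch_size=1,
    limit=2,  # evaluate only the first 2 batches; omit to run the full split
)
print(result.formatted[0].metrics)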
4 changes: 2 additions & 2 deletions src/deepsparse/evaluation/results.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List, Optional
from typing import Any, List, Optional, Union

import yaml
from pydantic import BaseModel, Field
@@ -32,7 +32,7 @@

class Metric(BaseModel):
    name: str = Field(description="Name of the metric")
    value: float = Field(description="Value of the metric")
    value: Union[float, List[float]] = Field(description="Value of the metric")


class Dataset(BaseModel):
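The widened Metric.value type matters because format_raw_results in the new perplexity integration converts numpy arrays to Python lists via tolist(). A small sketch of both accepted forms (the metric names and values here are illustrative):

from deepsparse.evaluation.results import Metric

# scalar value, as before
Metric(name="mean_perplexity", value=12.3)

# per-sample values, enabled by Union[float, List[float]]
Metric(name="perplexities", value=[11.8, 12.7, 12.4])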
2 changes: 2 additions & 0 deletions src/deepsparse/evaluation/utils.py
@@ -27,7 +27,9 @@
"resolve_integration",
]
_LOGGER = logging.getLogger(__name__)

LM_EVALUATION_HARNESS = "lm-evaluation-harness"
PERPLEXITY = "perplexity"


def potentially_check_dependency_import(integration_name: str) -> bool:
2 changes: 1 addition & 1 deletion src/deepsparse/transformers/metrics.py
@@ -20,7 +20,7 @@

import numpy

from deepsparse.utils import numpy_log_softmax
from deepsparse.utils.data import numpy_log_softmax


__all__ = [
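For context, run_perplexity above drives the Perplexity metric defined in this module roughly as follows; the shapes and values are made up, and the exact keys of the returned dict depend on Perplexity.compute():

import numpy
from deepsparse.transformers.metrics import Perplexity

perplexity = Perplexity(accumulate=True)

# toy batch: logits for 4 tokens over a 10-token vocabulary,
# paired with the 4 target token ids (values are illustrative)
logits = numpy.random.rand(4, 10).astype(numpy.float32)
targets = numpy.array([1, 4, 2, 7])

perplexity.add_batch(logits, targets)  # predictions first, targets second
results = perplexity.compute()         # dict of perplexity results
print(results)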