diff --git a/setup.py b/setup.py index 31edd30b01..d37bb1eb42 100644 --- a/setup.py +++ b/setup.py @@ -150,6 +150,7 @@ def _parse_requirements_file(file_path): "accelerate<0.26", "scikit-learn", "seqeval", + "evaluate", ] _sentence_transformers_integration_deps = ["optimum-deepsparse"] + _torch_deps @@ -310,7 +311,7 @@ def _setup_entry_points() -> Dict: f"deepsparse.image_classification.eval={ic_eval}", "deepsparse.license=deepsparse.license:main", "deepsparse.validate_license=deepsparse.license:validate_license_cli", - "deepsparse.eval=deepsparse.evaluation.cli:main", + "deepsparse.evaluate=deepsparse.evaluation.cli:main", ] } diff --git a/src/deepsparse/evaluation/cli.py b/src/deepsparse/evaluation/cli.py index b68d32d4e5..d192dd67a1 100644 --- a/src/deepsparse/evaluation/cli.py +++ b/src/deepsparse/evaluation/cli.py @@ -20,7 +20,7 @@ Module for evaluating models on the various evaluation integrations OPTIONS: - --model_path MODEL_PATH + MODEL_PATH A path to an ONNX model, local directory containing ONNX model (including all the auxiliary files) or a SparseZoo stub -d DATASET, --dataset DATASET @@ -72,7 +72,7 @@ from deepsparse.evaluation.evaluator import evaluate from deepsparse.evaluation.results import Result, save_result -from deepsparse.evaluation.utils import args_to_dict, get_save_path +from deepsparse.evaluation.utils import get_save_path, parse_kwarg_tuples from deepsparse.operators.engine_operator import ( DEEPSPARSE_ENGINE, ORT_ENGINE, @@ -88,12 +88,10 @@ ignore_unknown_options=True, ) ) -@click.option( - "--model_path", +@click.argument( + "model_path", type=click.Path(dir_okay=True, file_okay=True), required=True, - help="A path to an ONNX model, local directory containing ONNX model" - "(including all the auxiliary files) or a SparseZoo stub", ) @click.option( "-d", @@ -178,7 +176,7 @@ def main( # join datasets to a list if multiple datasets are passed datasets = list(dataset) if not isinstance(dataset, str) else dataset # format kwargs to a dict - integration_args = args_to_dict(integration_args) + integration_args = parse_kwarg_tuples(integration_args) _LOGGER.info( f"Creating {engine_type} pipeline to evaluate from model path: {model_path}" @@ -203,7 +201,7 @@ def main( **integration_args, ) - _LOGGER.info(f"Evaluation done. Results:\n{result}") + _LOGGER.info(f"Evaluation done. Results:\n{result.formatted}") save_path = get_save_path( save_path=save_path, diff --git a/src/deepsparse/evaluation/evaluator.py b/src/deepsparse/evaluation/evaluator.py index b513f07563..3d18f8489f 100644 --- a/src/deepsparse/evaluation/evaluator.py +++ b/src/deepsparse/evaluation/evaluator.py @@ -65,7 +65,6 @@ def evaluate( return eval_integration( pipeline=pipeline, datasets=datasets, - engine_type=engine_type, batch_size=batch_size, splits=splits, metrics=metrics, diff --git a/src/deepsparse/evaluation/integrations/__init__.py b/src/deepsparse/evaluation/integrations/__init__.py index 1cc3bfacf0..f0871f135a 100644 --- a/src/deepsparse/evaluation/integrations/__init__.py +++ b/src/deepsparse/evaluation/integrations/__init__.py @@ -15,7 +15,7 @@ # flake8: noqa: F401 -def try_import_lm_evaluation_harness(raise_error=False): +def try_import_lm_evaluation_harness(raise_error=True): try: import lm_eval @@ -24,11 +24,11 @@ def try_import_lm_evaluation_harness(raise_error=False): if raise_error: raise ImportError( "Unable to import lm_eval. 
" - "To install run 'pip install " - "git+https://github.com/EleutherAI/lm-evaluation-harness@b018a7d51'" + "To install run 'pip install lm-eval==0.4.0'" ) return False if try_import_lm_evaluation_harness(raise_error=False): from .lm_evaluation_harness import * +from .perplexity import * diff --git a/src/deepsparse/evaluation/integrations/lm_evaluation_harness.py b/src/deepsparse/evaluation/integrations/lm_evaluation_harness.py index 2f8c7b8cef..69934af37a 100644 --- a/src/deepsparse/evaluation/integrations/lm_evaluation_harness.py +++ b/src/deepsparse/evaluation/integrations/lm_evaluation_harness.py @@ -13,35 +13,39 @@ # limitations under the License. """ -Integration of the `lm_evaluation_harness`: +Integration of the `lm-evaluation-harness`: https://github.com/EleutherAI/lm-evaluation-harness """ - -import json import logging from typing import Any, Dict, List, Optional, Tuple, Union import numpy -from pydantic import BaseModel, Field from tqdm import tqdm -import torch from deepsparse import Pipeline from deepsparse.evaluation.registry import EvaluationRegistry from deepsparse.evaluation.results import Dataset, Evaluation, Metric, Result -from lm_eval import base, evaluator, tasks, utils +from deepsparse.evaluation.utils import LM_EVALUATION_HARNESS +from deepsparse.utils.data import numpy_log_softmax +from lm_eval import evaluator, tasks, utils +from lm_eval.api.instance import Instance +from lm_eval.api.model import LM + +tasks.initialize_tasks("INFO") _LOGGER = logging.getLogger(__name__) __all__ = ["integration_eval"] -@EvaluationRegistry.register(name="lm-evaluation-harness") +@EvaluationRegistry.register(name=LM_EVALUATION_HARNESS, alias="lm-eval-harness") def integration_eval( - model: Any, + pipeline: Pipeline, datasets: Union[List[str], str], - batch_size: int, + batch_size: int = 1, + splits: Union[List[str], str, None] = None, + metrics: Union[List[str], str, None] = None, **kwargs, ) -> Result: """ @@ -49,101 +53,53 @@ def integration_eval( https://github.com/EleutherAI/lm-evaluation-harness/blob/master/main.py that is compatible with deepsparse.evaluator.py - :param model: the model/pipeline to evaluate + :param pipeline: the model/pipeline to evaluate :param datasets: the datasets to evaluate on :param batch_size: the batch size to use for evaluation :param kwargs: additional arguments to alter the behavior of the evaluation :return the evaluation results """ - # [START] - # The code that sets up the interface between deepsparse and lm_evaluation_harness - if isinstance(model, Pipeline): - # If the model is a Pipeline, we need to wrap - # it in a DeepSparseLM object - model = DeepSparseLM( - pipeline=model, - batch_size=batch_size, - max_gen_toks=kwargs.get("max_gen_toks"), - ) + pipeline = DeepSparseLM(pipeline=pipeline, batch_size=batch_size) datasets = (",").join(datasets) if isinstance(datasets, list) else datasets - # [END] - - # [START] - # The code below is being adapted from: - # https://github.com/EleutherAI/lm-evaluation-harness/blob/master/main.py - if kwargs.get("limit"): - _LOGGER.warning( - "WARNING: --limit SHOULD ONLY BE USED FOR TESTING. " - "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT." 
- ) - - if datasets is None: - task_names = tasks.ALL_TASKS - else: - task_names = utils.pattern_match(datasets.split(","), tasks.ALL_TASKS) + task_names = utils.pattern_match(datasets.split(","), tasks.ALL_TASKS) _LOGGER.info(f"Selected Tasks: {task_names}") - description_dict = {} - if kwargs.get("description_dict_path"): - with open(kwargs.get("description_dict_path"), "r") as f: - description_dict = json.load(f) - - evaluator_input = EvaluatorInputSchema( - model=model, - tasks=task_names, - description_dict=description_dict, - batch_size=batch_size, - **kwargs, + results_raw = evaluator.simple_evaluate( + model=pipeline, tasks=task_names, batch_size=batch_size, **kwargs ) - results_raw = evaluator.simple_evaluate(**evaluator_input.dict()) - results = Result( - raw=dict(output=results_raw, input=filter_evaluator_input(evaluator_input)), + raw=results_raw, formatted=format_raw_results(results_raw), ) return results -def filter_evaluator_input( - evaluator_input: "EvaluatorInputSchema", -) -> Dict[str, Any]: # noqa: F821 - """ - Filter the evaluator input to remove the model field. - The model field is a complex object that cannot be serialized. - - :param evaluator_input: the evaluator input to filter - :return: the filtered evaluator input - """ - evaluator = evaluator_input.dict() - del evaluator["model"] - - return evaluator - - def format_raw_results(results: Dict[str, Any]) -> List[Evaluation]: """ Format the raw results from lm_evaluation_harness into a list of Evaluation objects. - :param results: the raw results from lm_evaluation_harness + :param results: the raw results from lm-evaluation-harness :return: the formatted results as a list of Evaluation objects """ formatted_results = [] for dataset_name, dataset_result in results["results"].items(): metrics = [] for metric_name, metric_value in dataset_result.items(): + if isinstance(metric_value, str): + continue metric = Metric(name=metric_name, value=metric_value) metrics.append(metric) dataset = Dataset( type=None, name=dataset_name, config=results["config"], split=None ) evaluation = Evaluation( - task="lm_evaluation_harness", + task=LM_EVALUATION_HARNESS, dataset=dataset, metrics=metrics, samples=None, @@ -152,177 +108,241 @@ def format_raw_results(results: Dict[str, Any]) -> List[Evaluation]: return formatted_results -class EvaluatorInputSchema(BaseModel): - model: Any = Field(description="The name of the model.") - tasks: List[str] = Field( - description="The task (or multiple tasks) to evaluate the target on." - ) - description_dict: Optional[Dict[str, Any]] = Field( - None, description="Description dict." - ) - batch_size: int = Field(description="The batch size to use for evaluation.") - model_args: str = Field( - "", description="Additional arguments for the evaluated model." - ) - num_fewshot: int = Field(0, description="The number of few shots to use.") - max_batch_size: Optional[int] = Field( - None, description="Maximal batch size to try with --batch_size auto." - ) - device: Optional[str] = Field(None, description="Device to use for evaluation.") - no_cache: bool = Field(False, description="Include this flag to prevent caching.") - limit: Optional[float] = Field( - None, - description="Limit the number of examples per task. If <1, " - "limit is a percentage of the total number of " - "examples.", - ) - decontamination_ngrams_path: Optional[str] = Field( - None, description="Specify the path for decontamination n-grams." 
- ) - check_integrity: bool = Field( - False, description="Include this flag to check integrity." - ) - write_out: bool = Field(False, description="Include this flag to write out.") - output_base_path: Optional[str] = Field( - None, description="Specify the output base path." - ) - - -class DeepSparseLM(base.BaseLM): +class DeepSparseLM(LM): def __init__( self, pipeline: Pipeline, - tokenizer: Optional[str] = None, batch_size: int = 1, - max_gen_toks: Optional[int] = None, + max_gen_toks: int = 256, + tokenizer: Optional["AutoTokenizer"] = None, # noqa: F821 ): """ Wrapper around the DeepSparse pipeline to make it compatible with the llm-evaluation-harness. + + :param pipeline: the pipeline object to wrap + :param batch_size: the batch size to use for evaluation + :param max_gen_toks: the maximum number of tokens to generate + when using the model for generation (see: greed_until method) + :param tokenizer: the tokenizer to use for encoding and decoding + strings and tokens. By default, the tokenizer from the pipeline """ super().__init__() - # Initialize new model and tokenizer instances - self.model = pipeline - self.tokenizer = tokenizer if tokenizer else self.model.tokenizer - - self._batch_size = batch_size + self.pipeline = pipeline + self.batch_size = batch_size + self.tokenizer = tokenizer or pipeline.tokenizer self._max_length = pipeline.sequence_length - self._max_gen_toks = max_gen_toks or 256 + self._max_gen_toks = max_gen_toks + self.batch_sizes = {} - self.vocab_size = self.tokenizer.vocab_size + def tok_encode(self, string: str) -> List[int]: + return self.tokenizer.encode(string) - def _model_call(self, inps) -> torch.Tensor: + def tok_decode(self, tokens: List[int]) -> str: + return self.tokenizer.decode(tokens) + + @property + def max_length(self) -> int: + return self._max_length + + @property + def max_gen_toks(self) -> int: + return self._max_gen_toks + + def loglikelihood(self, requests) -> List[Tuple[float, bool]]: + """ + Copied directly from + https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/models/huggingface.py + """ + new_reqs = [] + for context, continuation in [req.args for req in requests]: + if context == "": + raise NotImplementedError( + "Implementing empty context is not supported yet" + ) + context_enc, continuation_enc = self._encode_pair(context, continuation) + + new_reqs.append(((context, continuation), context_enc, continuation_enc)) + + return self._loglikelihood_tokens(new_reqs) + + def _loglikelihood_tokens( + self, + requests: List[Tuple[Tuple[str, str], List[int], List[int]]], + disable_tqdm: bool = False, + ) -> List[Tuple[float, bool]]: """ - Override the _model_call method to use the DeepSparse pipeline for - logits generation. + The function to compute the loglikelihood of the continuation + tokens given the context tokens. 
- inps: a torch tensor of shape [batch, sequence] - the size of sequence may vary from call to call - returns: a torch tensor of shape [batch, sequence, vocab] with the - logits returned from the model + This function is an adapted version of the original function from + https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/models/huggingface.py """ - # Encode the tokens to strings - prompt = self.model.tokenizer.batch_decode(inps.numpy()) - - # Run the model to map the prompt to logits - out = self.model( - prompt=prompt, - max_new_tokens=0, - include_prompt_logits=True, - output_scores=True, - ) - logits_numpy = numpy.stack([generation.score for generation in out.generations]) - return torch.from_numpy(logits_numpy) + res = [] - def greedy_until( - self, requests: List[Tuple[str, Union[List[str], str]]] - ) -> List[str]: def _collate(x): - tokens = self.tok_encode(x[0]) - return len(tokens), x[0] + """Defines the key for the sorted method""" + toks = x[1] + x[2] + return -len(toks), tuple(toks) - results = [] - reorder = utils.Reorderer(requests, _collate) + re_ord = utils.Reorderer(requests, _collate) - for chunk in utils.chunks( - tqdm(reorder.get_reordered(), disable=False), - self.batch_size, + for chunk in tqdm( + list(utils.chunks(re_ord.get_reordered(), self.batch_size)), + disable=disable_tqdm, ): - context = [c[0] for c in chunk] - request_args = chunk[0][1] - stop = request_args.get("until", None) - stop_sequences = stop if isinstance(stop, list) else [stop] - max_generation_length = request_args.get("max_length", None) - - assert ( - isinstance(max_generation_length, int) or max_generation_length is None - ) - assert isinstance(stop_sequences, list) or stop_sequences is None - - # TODO: Find a better way to handle stop sequences for 0-shot. 
- if stop_sequences is None: - until = [self.eot_token] - else: - until = stop_sequences + [self.eot_token] - - if max_generation_length is None: - max_tokens = self.max_gen_toks - else: - max_tokens = max_generation_length - - responses = self.model( - sequences=context, - max_new_tokens=max_tokens, - stop=until, - do_sample=False, + batch_inp = [] + batch_cache_key = [] + batch_continuation_enc = [] + # len(chunk) is the batch_size + for cache_key, context_enc, continuation_enc in chunk: + # how this all works (illustrated on a causal decoder-only setup): + # CTX CONT + # inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1] + # model \ \ + # logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the + # cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice # noqa: E501 + + inp = (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1] + + batch_inp.append(self.tokenizer.decode(inp)) + batch_cache_key.append(cache_key) + batch_continuation_enc.append(continuation_enc) + + response = self.pipeline( + prompt=batch_inp, + max_new_tokens=0, + output_scores=True, + include_prompt_logits=True, ) - responses = responses if type(responses) is list else [responses] + for resp, continuation_enc, cache_key in zip( + response.generations, batch_continuation_enc, batch_cache_key + ): + # (seq_len, vocab_size) + multi_scores = resp.score + # (seq_len, vocab_size) but with softmax applied + multi_logits = numpy_log_softmax(multi_scores, axis=1) + # toss out the context half of the sequence + # (cont_len, vocab_size) + continuation_multi_logits = multi_logits[-len(continuation_enc) :] + + # pick out the logits for the continuation tokens + # (cont_len,) + continuation_logits = continuation_multi_logits[ + numpy.arange(len(continuation_enc)), continuation_enc + ] + # check if the tokens generated greedly are the same + # as the expected continuation + greedy_tokens = continuation_multi_logits.argmax(axis=1) + max_equal = greedy_tokens.tolist() == continuation_enc + + # Answer: (log prob, is-exact-match) + answer = (float(continuation_logits.sum()), bool(max_equal)) + + res.append(answer) + + if cache_key is not None: + self.cache_hook.add_partial("loglikelihood", cache_key, answer) + + return re_ord.get_original(res) + + def loglikelihood_rolling( + self, requests: list[Instance] + ) -> list[tuple[float, bool]]: + raise NotImplementedError( + "The method not required by any of our " "current task integrations so far" + ) + + def generate_until(self, requests: list[Instance]) -> list[str]: + """ + The function to generate a certain number of new tokens + given a context. + + This function is an adapted version of the original function from + https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/models/openai_completions.py + """ + if not requests: + return [] + res = [] + requests = [req.args for req in requests] - for response in responses: - response = response.generations[0].text - # Ensure the generated responses do not contain the stop sequences. 
- for term in until: - response = response.split(term)[0] - # partial caching - self.cache_hook.add_partial("greedy_until", (context, until), response) - results.append(response) + def _collate(x): + toks = self.tok_encode(x[0]) + return len(toks), x[0] + + re_ord = utils.Reorderer(requests, _collate) + + def sameuntil_chunks(xs, size): + ret = [] + lastuntil = xs[0][1] + for x in xs: + if len(ret) >= size or x[1] != lastuntil: + yield ret, lastuntil + ret = [] + lastuntil = x[1] + ret.append(x) + + if ret: + yield ret, lastuntil + + pbar = tqdm(total=len(requests)) + for chunk, request_args in tqdm( + list(sameuntil_chunks(re_ord.get_reordered(), self.batch_size)) + ): + inps = [] - return reorder.get_original(results) + self._max_gen_toks = request_args.pop("max_gen_toks", self.max_gen_toks) - def _model_generate(self, context, max_length, eos_token_id): - # Isn't used because we override greedy_until - raise NotImplementedError() + for context, _ in chunk: + # add context (prompts) to the list + inps.append(context) - @property - def eot_token(self) -> str: - return self.tokenizer.eos_token + until = request_args.pop("until", ["<|endoftext|>"]) + request_args.pop("do_sample", None) + request_args["temperature"] = request_args.get("temperature", 0) - @property - def eot_token_id(self) -> int: - return self.tokenizer.eos_token_id + # run inference (generate max_gen_toks tokens) + out = self.pipeline( + sequences=inps, + max_new_tokens=self.max_gen_toks - 1, + stop=until, + **request_args, + ) - @property - def max_length(self): - return self._max_length + for resp, (context, args_) in zip(out.generations, chunk): + text = resp.text + until_ = until + # split the text at the first occurrence of any of the until tokens + for term in until_: + if len(term) > 0: + text = text.split(term)[0] - @property - def max_gen_toks(self): - return self._max_gen_toks + res.append(text) - @property - def batch_size(self): - # should return self._batch_size but the - # TextGeneration model does not support batch_size > 1 - return 1 + self.cache_hook.add_partial( + "generate_until", (context, {"until": until_}), text + ) + pbar.update(1) - @property - def device(self): - pass + pbar.close() - def tok_encode(self, string: str): - return self.tokenizer.encode(string, add_special_tokens=False) + return re_ord.get_original(res) - def tok_decode(self, tokens): - return self.tokenizer.decode(tokens) + def _encode_pair( + self, context: str, continuation: str + ) -> Tuple[List[int], List[int]]: + """ + Copied directly from + https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/models/huggingface.py + """ + n_spaces = len(context) - len(context.rstrip()) + if n_spaces > 0: + continuation = context[-n_spaces:] + continuation + context = context[:-n_spaces] + whole_enc = self.tok_encode(context + continuation) + context_enc = self.tok_encode(context) + context_enc_len = len(context_enc) + continuation_enc = whole_enc[context_enc_len:] + return context_enc, continuation_enc diff --git a/src/deepsparse/evaluation/integrations/perplexity.py b/src/deepsparse/evaluation/integrations/perplexity.py new file mode 100644 index 0000000000..a9a3f3d8a3 --- /dev/null +++ b/src/deepsparse/evaluation/integrations/perplexity.py @@ -0,0 +1,278 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +from collections import defaultdict +from typing import Any, Dict, List, Optional, Union + +import numpy +from tqdm import tqdm + +from datasets import load_dataset +from deepsparse import Pipeline +from deepsparse.evaluation.registry import EvaluationRegistry +from deepsparse.evaluation.results import Dataset, Evaluation, Metric, Result +from deepsparse.evaluation.utils import PERPLEXITY +from deepsparse.transformers.metrics import Perplexity +from deepsparse.transformers.pipelines.text_generation import TextGenerationPipeline +from deepsparse.transformers.pipelines.text_generation.pipeline_no_kv_cache import ( + TextGenerationPipelineNoCache, +) +from deepsparse.transformers.utils.eval_helpers import ( + HumanEvalIteratorWrapper, + process_concatenated_datasets, +) + + +""" +Integration for the evaluation module +that computes the perplexity of a model on a dataset +""" +_LOGGER = logging.getLogger(__name__) + + +@EvaluationRegistry.register(name=PERPLEXITY) +def integration_eval( + pipeline: Pipeline, + datasets: Union[List[str], str] = "openai_humaneval", + batch_size: int = 1, + limit: Optional[int] = None, + accumulate: Optional[bool] = None, + splits: Union[List[str], str, None] = "test", + metrics: Union[List[str], str, None] = None, + **kwargs, +) -> Result: + """ + A function that computes the perplexity of a pipeline given a set + of dataset names. + + :param pipeline: the pipeline to evaluate. The assumed pipeline + is a TextGenerationPipeline, either with or without the KV + cache support + :param datasets: the names of dataset(s) to evaluate on + :param batch_size: the batch size to use for evaluation + :param splits: the split of the dataset to evaluate on. Default is "test" + :param metrics: the metrics to compute. Default is None + :param limit: the number of batches to evaluate on. Default is None + (evaluates on entire dataset) + :param accumulate: whether to perplexity computation should + accumulate negative log-likelihood over samples. Defaults to + the default accumulate variable inferred from the dataset in + `datasets`. If not None, it will override the inferred accumulate + variable. + :return: a Result object containing the raw and formatted results + """ + metrics = metrics or PERPLEXITY + if metrics != PERPLEXITY: + raise ValueError(f"Invalid metric {metrics} for perplexity evaluation") + if splits is None: + splits = "test" + _LOGGER.info("Argument `splits` is None. Defaulting to `test` split.") + datasets = datasets if isinstance(datasets, list) else [datasets] + results_raw = defaultdict(str) + for dataset_name in datasets: + results_raw[dataset_name] = defaultdict() + dataset, _accumulate = load_perplexity_dataset( + dataset_name=dataset_name, splits=splits, pipeline=pipeline, **kwargs + ) + if accumulate is None: + accumulate = _accumulate + else: + _LOGGER.info( + f"Argument `accumulate` set to {accumulate}. " + "Overriding the inferred accumulate variable from the dataset." 
+ ) + + perplexity = run_perplexity( + pipeline=pipeline, + dataset=dataset, + batch_size=batch_size, + accumulate=accumulate, + limit=limit, + ) + + results_raw[dataset_name] = defaultdict() + results_raw[dataset_name]["results"] = perplexity + results_raw[dataset_name]["split"] = splits + + results = Result( + # omit storing raw results. they can potentially + # contain numpy arrays that are not serializable. + # all the information is stored in the formatted results + raw=None, + formatted=format_raw_results(results_raw), + ) + + return results + + +def run_perplexity( + pipeline: Union[TextGenerationPipelineNoCache, TextGenerationPipeline], + dataset: "Dataset", + batch_size: int, + accumulate: bool, + limit: Optional[int] = None, +) -> Dict[str, Any]: + """ + Compute the perplexity of a pipeline given a dataset. + + :param pipeline: the pipeline to evaluate. The assumed pipeline + is a TextGenerationPipeline, either with or without the KV + cache support + :param dataset: the dataset to evaluate on + :param batch_size: the batch size to use for evaluation + :param accumulate: whether to perplexity computation should + accumulate negative log-likelihood over samples + :param limit: the number of batches to evaluate on. Default is None + (evaluates on entire dataset) + + :return: a dictionary containing the perplexity results + """ + + perplexity = Perplexity(accumulate=accumulate) + + batch = [] + for idx, sample in _enumerate_progress( + dataset, max_steps=None if limit is None else limit * batch_size + ): + + if limit is not None: + # stop if we have reached the #limit + # number of batches to be processed + if idx >= limit * batch_size: + break + + batch.append(sample) + + if len(batch) == batch_size: + if isinstance(pipeline, TextGenerationPipelineNoCache): + out = pipeline( + prompt=batch, + output_scores=True, + include_prompt_logits=True, + return_input_tokens=True, + ) + else: + out = pipeline( + prompt=batch, + output_scores=True, + max_new_tokens=0, + include_prompt_logits=True, + return_input_tokens=True, + ) + + for s in range(batch_size): + # Need to remove tokens that were masked + input_ids = out.input_tokens["input_ids"][s].flatten() + attention_mask = out.input_tokens["attention_mask"][s].flatten() + logits = out.generations[s].score + if batch_size > 1 and isinstance( + pipeline, TextGenerationPipelineNoCache + ): + logits = logits[-attention_mask.sum() :, :] + + logits = numpy.compress(attention_mask, logits, axis=0)[:-1, :] + input_ids = numpy.compress(attention_mask, input_ids)[1:] + + # Add predictions (logits) and targets (input_ids) to metric + perplexity.add_batch(logits, input_ids) + + batch.clear() + + return perplexity.compute() + + +def format_raw_results(results: Dict[str, Any]) -> List[Evaluation]: + """ + Format the raw perplexity results into a list of + Evaluation objects. 
+ + :param results: the raw results from perplexity computation + :return: the formatted results as a list of Evaluation objects + """ + formatted_results = [] + for dataset_name, dataset_result in results.items(): + metrics = [] + for metric_name, metric_value in dataset_result["results"].items(): + if isinstance(metric_value, numpy.ndarray): + metric_value = metric_value.tolist() + metric = Metric(name=metric_name, value=metric_value) + metrics.append(metric) + dataset = Dataset(type=None, name=dataset_name, split=dataset_result["split"]) + evaluation = Evaluation( + task="perplexity", + dataset=dataset, + metrics=metrics, + samples=None, + ) + formatted_results.append(evaluation) + return formatted_results + + +def load_perplexity_dataset( + dataset_name: str, + splits: Union[List[str], str] = "test", + pipeline: Optional[Pipeline] = None, + **kwargs, +): + """ + Function to load the dataset for perplexity computation. + Eventually we want to load the dataset from the nm_utils + + :param dataset_name: the name of the dataset to load + :param splits: the splits to load from the dataset. Default is "test" + :param pipeline: the pipeline to use for loading the dataset. The pipeline + is used to infer the model path and sequence length to use for loading + the dataset. This argument can be omitted if the appropriate kwargs + are provided, or if the dataset does not require a process_concatenated_datasets + function to load the dataset. + :param kwargs: additional keyword arguments to pass to the dataset loading function + :return: the dataset and whether to accumulate perplexity over samples + """ + if isinstance(splits, list): + raise NotImplementedError("Evaluation on multiple splits not implemented") + + if dataset_name == "openai_humaneval": + dataset = load_dataset(dataset_name, split=splits) + dataset = HumanEvalIteratorWrapper(dataset) + accumulate = False + elif dataset_name in {"wikitext2", "c4"}: + # fetch max_sequence_length from pipeline if not provided + max_sequence_length = kwargs.pop("max_sequence_length", None) + if max_sequence_length is None and pipeline is not None: + max_sequence_length = pipeline.sequence_length + + # fetch model_path from pipeline if not provided + model_path = kwargs.pop("model_path", None) + if model_path is None and pipeline is not None: + model_path = os.path.dirname(pipeline.model_path) + + dataset = process_concatenated_datasets( + dataset_name, + model_path=model_path, + max_sequence_length=max_sequence_length, + split=splits, + **kwargs, + ) + accumulate = True + else: + raise NotImplementedError(f"Dataset {dataset_name} not implemented") + + return dataset, accumulate + + +def _enumerate_progress(dataset, max_steps): + progress_bar = tqdm(dataset, total=max_steps) if max_steps else tqdm(dataset) + return enumerate(progress_bar) diff --git a/src/deepsparse/evaluation/registry.py b/src/deepsparse/evaluation/registry.py index 2daabb69cc..343cd9786c 100644 --- a/src/deepsparse/evaluation/registry.py +++ b/src/deepsparse/evaluation/registry.py @@ -57,7 +57,7 @@ def resolve( if integration is None: _LOGGER.info( - "No integration specified, inferring the evaluation" + "No integration specified, inferring the evaluation " "function from the input arguments..." 
) integration = resolve_integration(pipeline, datasets) diff --git a/src/deepsparse/evaluation/results.py b/src/deepsparse/evaluation/results.py index 00212d0a1e..78c4bbd501 100644 --- a/src/deepsparse/evaluation/results.py +++ b/src/deepsparse/evaluation/results.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional +from typing import Any, List, Optional, Union import yaml from pydantic import BaseModel, Field @@ -32,7 +32,7 @@ class Metric(BaseModel): name: str = Field(description="Name of the metric") - value: float = Field(description="Value of the metric") + value: Union[float, List[float]] = Field(description="Value of the metric") class Dataset(BaseModel): diff --git a/src/deepsparse/evaluation/utils.py b/src/deepsparse/evaluation/utils.py index 87475dd5d2..6e5ade9344 100644 --- a/src/deepsparse/evaluation/utils.py +++ b/src/deepsparse/evaluation/utils.py @@ -11,21 +11,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import ast +import logging import os -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Union from deepsparse import Pipeline +from deepsparse.operators.engine_operator import DEEPSPARSE_ENGINE __all__ = [ "create_pipeline", "get_save_path", - "args_to_dict", + "parse_kwarg_tuples", "resolve_integration", ] +_LOGGER = logging.getLogger(__name__) LM_EVALUATION_HARNESS = "lm-evaluation-harness" +PERPLEXITY = "perplexity" def potentially_check_dependency_import(integration_name: str) -> bool: @@ -38,10 +42,14 @@ def potentially_check_dependency_import(integration_name: str) -> bool: :return: True if the dependency is installed, False otherwise """ - if integration_name.replace("_", "-") == LM_EVALUATION_HARNESS: + if integration_name == LM_EVALUATION_HARNESS: from deepsparse.evaluation.integrations import try_import_lm_evaluation_harness - try_import_lm_evaluation_harness(raise_error=True) + try_import_lm_evaluation_harness() + if integration_name == PERPLEXITY: + from deepsparse.evaluation.integrations.perplexity import ( # noqa F401 + integration_eval, + ) return True @@ -79,24 +87,66 @@ def if_generative_language_model(pipeline: Pipeline) -> bool: return False -def args_to_dict(args: Tuple[Any, ...]) -> Dict[str, Any]: +def parse_kwarg_tuples(kwargs: tuple) -> Dict: """ - Convert a tuple of args to a dict of args. - - :param args: The args to convert. Should be a tuple of alternating - arg names and arg values e.g.('--arg1', 1, 'arg2', 2, -arg3', 3). + Convert a tuple of kwargs to a dict of kwargs. + This function is used to enable the click parsing of kwargs. + + Example use: + ``` + @click.command( + context_settings=dict( + ignore_unknown_options=True) + ) + @click.argument(...) + @click.option(...) + ... + @click.argument("kwargs", nargs=-1, type=click.UNPROCESSED) + def main(..., kwargs): + ... + kwargs: Dict[str, Any] = parse_kwarg_tuples(kwargs: Tuple) + ``` + + Example inputs, outputs: + ``` + input = ('--arg1', 1, 'arg2', 2, '-arg3', 3) + output = parse_kwarg_tuples(input) + output = {'arg1': 1, 'arg2': 2, 'arg3': 3} + ``` + + :param kwargs: The kwargs to convert. Should be a tuple of alternating + kwargs names and kwargs values e.g.('--arg1', 1, 'arg2', 2, -arg3', 3). The names can optionally have a '-' or `--` in front of them. 
- :return: The converted args as a dict. + :return: The converted kwargs as a dict. """ - if len(args) == 0: + if len(kwargs) == 0: return {} + if len(kwargs) % 2 != 0: + raise ValueError( + "kwargs must be a tuple of alternating names and values " + "i.e. the length of kwargs tuple must be even. Received " + f"kwargs: {kwargs}" + ) # names are uneven indices, values are even indices - args_names = args[0::2] - args_values = args[1::2] + kwargs_names = kwargs[0::2] + kwargs_values = kwargs[1::2] + # by default kwargs values are strings, so convert them + # to the appropriate type if possible + kwargs_values = list(kwargs_values) + for i, value in enumerate(kwargs_values): + try: + kwargs_values[i] = ast.literal_eval(value) + except Exception as e: # noqa E841 + _LOGGER.debug( + f"Failed to infer non-string type" + f"from kwarg value: {value}. It will" + f"be left as a string." + ) + # remove any '-' or '--' from the names - args_names = [name.lstrip("-") for name in args_names] + kwargs_names = [name.lstrip("-") for name in kwargs_names] - return dict(zip(args_names, args_values)) + return dict(zip(kwargs_names, kwargs_values)) def get_save_path( @@ -143,6 +193,7 @@ def create_pipeline( :param engine_type: The engine type to initialize the model with. :return: The initialized pipeline """ + engine_type = engine_type or DEEPSPARSE_ENGINE return Pipeline.create( task=kwargs.pop("task", "text-generation"), model_path=model_path, diff --git a/src/deepsparse/transformers/metrics.py b/src/deepsparse/transformers/metrics.py index 1952ec2155..acfe2e846b 100644 --- a/src/deepsparse/transformers/metrics.py +++ b/src/deepsparse/transformers/metrics.py @@ -20,6 +20,7 @@ import numpy +from deepsparse.utils.data import numpy_log_softmax from scipy.special import log_softmax from sklearn.metrics import precision_recall_fscore_support diff --git a/src/deepsparse/transformers/utils/eval_helpers.py b/src/deepsparse/transformers/utils/eval_helpers.py index 4c0e68b9de..012520b9b5 100644 --- a/src/deepsparse/transformers/utils/eval_helpers.py +++ b/src/deepsparse/transformers/utils/eval_helpers.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Mapping, Union +from typing import List, Union import numpy from transformers import AutoTokenizer, PreTrainedTokenizerFast @@ -27,7 +27,8 @@ def process_concatenated_datasets( dataset_name: str, model_path: str, max_sequence_length: int, - kwargs: Mapping, + split: str = "test", + **kwargs, ) -> list: """ Concatenate text datasets and split them into chunks text that, after @@ -38,6 +39,8 @@ def process_concatenated_datasets( Options: "wikitext2" or "c4". model_path (str): The path to a pretrained transformer model for tokenization. max_sequence_length (int): The maximum number of tokens in each sequence. + split (str, optional): The split of the dataset to use. + Default is "test". kwargs (mapping): Additional keyword arguments. - eos (str, optional): The end-of-sentence token. Default is "\n\n" for wikitext2 and "" for c4. 
@@ -65,13 +68,13 @@ def process_concatenated_datasets( eos = kwargs.get("eos", "\n\n") bos = kwargs.get("bos", "") - raw_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") + raw_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split=split) raw_text = raw_dataset["text"] elif dataset_name == "c4": eos = kwargs.get("eos", "<|endoftext|>") bos = kwargs.get("bos", "") raw_samples = kwargs.get("raw_samples", None) - data_file = kwargs.get("data_file", 0) + data_file = kwargs.get("data_file", None) if data_file is not None: raw_dataset = load_dataset( "allenai/c4", @@ -79,13 +82,13 @@ def process_concatenated_datasets( data_files={ "validation": f"en/c4-validation.{data_file:05d}-of-00008.json.gz" }, - split="validation", + split=split, ) else: raw_dataset = load_dataset( "allenai/c4", "allenai--c4", - split="validation", + split=split, ) if raw_samples is not None: raw_dataset = raw_dataset[:raw_samples] @@ -181,3 +184,22 @@ def _split_text_by_tokens( ) return split_text + + +class HumanEvalIteratorWrapper: + """ + Wrapper around the `openai_humaneval` dataset, + that joins the prompt and the canonical solution + into a single string during iteration. + """ + + def __init__(self, dataset): + self.iterator = iter(dataset) + + def __iter__(self): + return self + + def __next__(self): + # Get the next sample from the original iterator + sample = next(self.iterator) + return sample["prompt"] + sample["canonical_solution"] diff --git a/tests/deepsparse/evaluation/integrations/test_lm_evaluation_harness.py b/tests/deepsparse/evaluation/integrations/test_lm_evaluation_harness.py index 3b9016294f..8d8b343dd5 100644 --- a/tests/deepsparse/evaluation/integrations/test_lm_evaluation_harness.py +++ b/tests/deepsparse/evaluation/integrations/test_lm_evaluation_harness.py @@ -12,64 +12,118 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from transformers import AutoModelForCausalLM - import pytest from deepsparse.evaluation.integrations import try_import_lm_evaluation_harness from deepsparse.evaluation.utils import create_pipeline -@pytest.mark.parametrize( - "pipeline, model_torch", - [ - ( - create_pipeline( - "hf:mgoin/TinyStories-1M-deepsparse", engine_type="onnxruntime" - ), - AutoModelForCausalLM.from_pretrained("roneneldan/TinyStories-1M"), - ) - ], -) -@pytest.mark.parametrize( - "datasets", - [ - ["hellaswag"], - ["hellaswag", "gsm8k"], - "gsm8k", - "arc_challenge", - ], -) @pytest.mark.parametrize( "batch_size", [1, 3], ) -class TestLMEvaluationHarness: - @pytest.mark.skipif( - not try_import_lm_evaluation_harness(raise_error=False), - reason="lm_evaluation_harness not installed", - ) - def test_integration_eval_onnx_matches_torch( - self, pipeline, model_torch, datasets, batch_size - ): +@pytest.mark.skipif( + not try_import_lm_evaluation_harness(raise_error=False), + reason="lm_evaluation_harness not installed", +) +class TestLMEval: + @pytest.fixture() + def integration_eval(self): from deepsparse.evaluation.integrations.lm_evaluation_harness import ( - integration_eval, + integration_eval as eval_fn, ) - out_torch = integration_eval( - model=model_torch, + return eval_fn + + @pytest.mark.parametrize( + "datasets", + [ + "hellaswag", + ["arc_challenge"], + ["hellaswag", "arc_challenge"], + ], + ) + def test_likelihood_scenario(self, batch_size, datasets, integration_eval): + + model_path_ds = "hf:mgoin/TinyStories-1M-ds" + model_path_hf = "roneneldan/TinyStories-1M" + limit = 2 + + out_onnx = integration_eval( + create_pipeline( + model_path_ds, + engine_type="onnxruntime", + ), datasets=datasets, batch_size=batch_size, - limit=5, - no_cache=True, # avoid saving files when running tests + limit=limit, + use_cache=None, # avoid saving files when running tests + ) + + from lm_eval import evaluator, tasks, utils + + datasets_ = (",").join(datasets) if isinstance(datasets, list) else datasets + out_torch = evaluator.simple_evaluate( + model="hf", + model_args=f"pretrained={model_path_hf}", + tasks=utils.pattern_match(datasets_.split(","), tasks.ALL_TASKS), + batch_size=batch_size, + limit=limit, + use_cache=None, # avoid saving files when running tests ) + self._test_same(out_onnx.raw, out_torch, datasets) + + @pytest.mark.parametrize( + "datasets", + [ + "gsm8k", + ], + ) + def test_greedy_until_scenario(self, batch_size, datasets, integration_eval): + model_path_ds = "hf:mgoin/TinyLlama-1.1B-step-50K-105b-ONNX" + model_path_hf = "TinyLlama/TinyLlama-1.1B-step-50K-105b" + limit = 2 + # compute until 16 new tokens + # so that tests are faster + gen_kwargs = "max_gen_toks=16" + out_onnx = integration_eval( - model=pipeline, + create_pipeline(model_path_ds, engine_type="onnxruntime"), datasets=datasets, batch_size=batch_size, - limit=5, - no_cache=True, # avoid saving files when running tests + limit=limit, + gen_kwargs=gen_kwargs, + use_cache=None, # avoid saving files when running tests + ) + + from lm_eval import evaluator, tasks, utils + + datasets_ = (",").join(datasets) if isinstance(datasets, list) else datasets + out_torch = evaluator.simple_evaluate( + model="hf", + model_args=f"pretrained={model_path_hf}", + tasks=utils.pattern_match(datasets_.split(","), tasks.ALL_TASKS), + batch_size=batch_size, + limit=limit, + gen_kwargs=gen_kwargs, + use_cache=None, # avoid saving files when running tests ) - out_onnx = out_onnx.raw["output"] - out_torch = out_torch.raw["output"] + self._test_same(out_onnx.raw, 
out_torch, datasets) - assert out_onnx["results"] == out_torch["results"] + @staticmethod + def _test_same(out_onnx, out_torch, datasets, greedy=False): + datasets = datasets if isinstance(datasets, list) else [datasets] + for dataset in datasets: + torch_samples = out_torch["samples"][dataset] + onnx_samples = out_onnx["samples"][dataset] + for torch_sample, onnx_sample in zip(torch_samples, onnx_samples): + if greedy: + # for datasets that validate greedy generation + # make sure that generated sequences are the same + assert torch_sample["resps"] == onnx_sample["resps"] + else: + # for datasets that validate likelihood + # make sure that likelihoods are the same + assert ( + pytest.approx(torch_sample["resps"][0][0], 0.0001) + == onnx_sample["resps"][0][0] + ) diff --git a/tests/deepsparse/evaluation/integrations/test_perplexity.py b/tests/deepsparse/evaluation/integrations/test_perplexity.py new file mode 100644 index 0000000000..b156e5b9a4 --- /dev/null +++ b/tests/deepsparse/evaluation/integrations/test_perplexity.py @@ -0,0 +1,132 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from copy import copy + +import numpy as np + +import pytest +from deepsparse.evaluation.integrations.perplexity import ( + integration_eval, + load_perplexity_dataset, +) +from deepsparse.transformers.pipelines.text_generation import TextGenerationPipeline +from evaluate import load + + +@pytest.fixture() +def model_path(): + return "hf:mgoin/TinyStories-1M-deepsparse" + + +@pytest.fixture() +def model_id(): + return "roneneldan/TinyStories-1M" + + +@pytest.mark.parametrize( + "datasets", + [ + "openai_humaneval", + "wikitext2", + ], +) +@pytest.mark.parametrize("batch_size", [1, 2]) +class TestPerplexity: + limit = 2 + + def test_perplexity_ground_truth_equal_pipeline( + self, model_path, model_id, datasets, batch_size + ): + # setting max_sequence_length to 16 to speed up the test + kwargs_ground_truth = ( + dict(max_sequence_length=16) if datasets in {"c4", "wikitext2"} else {} + ) + kwargs = copy(kwargs_ground_truth) + + result_gt = self._get_ground_truth( + datasets=datasets, + batch_size=batch_size, + limit=self.limit, + model_id=model_id, + kwargs=kwargs_ground_truth, + ) + + result = integration_eval( + pipeline=TextGenerationPipeline( + model_path="hf:mgoin/TinyStories-1M-deepsparse", + engine_type="onnxruntime", + ), + datasets=datasets, + batch_size=batch_size, + limit=self.limit, + # we are setting accumulate=False to compare + # with the torch ground truth apples to apples + accumulate=False, + **kwargs, + ) + perplexities = result.formatted[0].metrics[0].value + perplexities_gt = result_gt["perplexities"] + assert np.allclose(perplexities, perplexities_gt, rtol=0.1) + + def test_perplexity_kv_cache_pipeline_equal_no_kv_cache_pipeline( + self, model_path, model_id, datasets, batch_size + ): + + kwargs_ground_truth = ( + dict(max_sequence_length=16) if datasets in {"c4", "wikitext2"} else {} + ) + kwargs = 
copy(kwargs_ground_truth) + + result_kv_cache = integration_eval( + pipeline=TextGenerationPipeline( + model_path="hf:mgoin/TinyStories-1M-deepsparse", + engine_type="onnxruntime", + ), + datasets=datasets, + model_path=model_id, + batch_size=batch_size, + limit=self.limit, + **kwargs, + ) + + result_non_kv_cache = integration_eval( + pipeline=TextGenerationPipeline( + model_path="hf:mgoin/TinyStories-1M-deepsparse", + engine_type="onnxruntime", + onnx_model_name="model-orig.onnx", + ), + datasets=datasets, + batch_size=batch_size, + limit=self.limit, + **kwargs, + ) + + perplexities_kv_cache = result_kv_cache.formatted[0].metrics[0].value + perplexities_non_kv_cache = result_non_kv_cache.formatted[0].metrics[0].value + np.allclose(perplexities_kv_cache, perplexities_non_kv_cache, rtol=0.1) + + @staticmethod + def _get_ground_truth(datasets, batch_size, limit, model_id, kwargs={}): + perplexity = load("perplexity", module_type="metric") + kwargs["model_path"] = model_id + dataset, *_ = load_perplexity_dataset(dataset_name=datasets, **kwargs) + predictions = [] + for i, sample in enumerate(dataset): + if i == batch_size * limit: + break + predictions.append(sample) + return perplexity.compute( + predictions=predictions, add_start_token=False, model_id=model_id + ) diff --git a/tests/deepsparse/evaluation/test_evaluator.py b/tests/deepsparse/evaluation/test_evaluator.py index 816ad075e0..58eedff836 100644 --- a/tests/deepsparse/evaluation/test_evaluator.py +++ b/tests/deepsparse/evaluation/test_evaluator.py @@ -115,19 +115,25 @@ def test_evaluate_pipeline_without_kv_cache( not try_import_lm_evaluation_harness(raise_error=False), reason="lm_evaluation_harness not installed", ) -def test_evaluation_llm_evaluation_harness_integration_name( +def test_evaluation_llm_evaluation_harness( model_path, - datasets, ): assert evaluate( model=model_path, - datasets=datasets, - limit=2, - no_cache=True, + # testing only on hellaswag dataset + # to avoid long running time + datasets="hellaswag", + limit=1, integration="lm_evaluation_harness", ) +def test_evaluation_perplexity(model_path): + assert evaluate( + model=model_path, datasets="openai_humaneval", limit=1, integration="perplexity" + ) + + @pytest.mark.parametrize("type_serialization", ["json", "yaml"]) @pytest.mark.skipif( tuple(map(int, sys.version.split(".")[:2])) < (3, 10), @@ -144,7 +150,6 @@ def test_cli( runner.invoke( main, [ - "--model_path", model_path, "--dataset", datasets[0],
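
The renamed `deepsparse.evaluate` entry point takes the model path as a positional `MODEL_PATH` argument, and any trailing options that click does not recognize are forwarded to the selected integration through `parse_kwarg_tuples`. A minimal sketch of that parsing step, based on the helper added in `src/deepsparse/evaluation/utils.py` (the option names shown are illustrative):

```
from deepsparse.evaluation.utils import parse_kwarg_tuples

# click collects unknown trailing options as an alternating
# (name, value, name, value, ...) tuple of strings; values are
# converted with ast.literal_eval where possible and leading
# dashes are stripped from the names.
integration_args = parse_kwarg_tuples(("--limit", "2", "--max_gen_toks", "16"))
assert integration_args == {"limit": 2, "max_gen_toks": 16}
```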
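
Against lm-eval 0.4.0, `integration_eval` in `lm_evaluation_harness.py` wraps the DeepSparse pipeline in the `DeepSparseLM` adapter and hands it to `lm_eval.evaluator.simple_evaluate`, so it can be driven directly from Python as well as through the CLI. A rough sketch assuming the model stub used in the updated tests; `limit` and `use_cache` are passed through to `simple_evaluate` via `**kwargs`:

```
from deepsparse.evaluation.integrations.lm_evaluation_harness import integration_eval
from deepsparse.evaluation.utils import create_pipeline

# small causal LM run through the onnxruntime engine, as in the tests
pipeline = create_pipeline("hf:mgoin/TinyStories-1M-ds", engine_type="onnxruntime")

result = integration_eval(
    pipeline=pipeline,
    datasets=["hellaswag", "arc_challenge"],  # matched via utils.pattern_match
    batch_size=1,
    limit=2,         # keep the run small; not for real metrics
    use_cache=None,  # avoid writing cache files
)
print(result.formatted)  # one Evaluation object per selected task
```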
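
The new perplexity integration accepts a text-generation pipeline (with or without KV cache) and one of the currently supported datasets (`openai_humaneval`, `wikitext2`, `c4`). A minimal sketch mirroring `tests/deepsparse/evaluation/integrations/test_perplexity.py`; the model stub and the `onnxruntime` engine are simply the ones used there:

```
from deepsparse.evaluation.integrations.perplexity import integration_eval
from deepsparse.transformers.pipelines.text_generation import TextGenerationPipeline

pipeline = TextGenerationPipeline(
    model_path="hf:mgoin/TinyStories-1M-deepsparse",
    engine_type="onnxruntime",
)

result = integration_eval(
    pipeline=pipeline,
    datasets="openai_humaneval",
    batch_size=2,
    limit=2,           # only evaluate limit * batch_size samples
    accumulate=False,  # per-sample perplexities rather than accumulated NLL
)
# the formatted results expose the computed perplexities as Metric values
print(result.formatted[0].metrics[0].value)
```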