diff --git a/src/deepsparse/benchmark/benchmark_model.py b/src/deepsparse/benchmark/benchmark_model.py
index 8c978ce87c..aa350fb474 100644
--- a/src/deepsparse/benchmark/benchmark_model.py
+++ b/src/deepsparse/benchmark/benchmark_model.py
@@ -81,6 +81,14 @@ zoo:nlp/question_answering/bert-base/pytorch/huggingface/squad/base-none \
     --input_shapes "[1,512],[1,512],[1,512]"
 
+##########
+Example on a CodeGen model (with KV cache support)
+from SparseZoo, with input_ids_length 10 and sequence_length 256:
+deepsparse.benchmark \
+    zoo:nlg/text_generation/codegen_mono-350m/pytorch/huggingface/bigpython_bigquery_thepile/pruned50-none \
+    --input_ids_length 10 \
+    --sequence_length 256
+
 ##########
 Example on local ONNX model:
 deepsparse.benchmark /PATH/TO/model.onnx
@@ -110,8 +118,10 @@ from deepsparse.log import set_logging_level
 from deepsparse.utils import (
     generate_random_inputs,
+    has_model_kv_cache,
     model_to_path,
     override_onnx_input_shapes,
+    overwrite_cache_model_inputs,
     parse_input_shapes,
 )
@@ -143,6 +153,26 @@ def parse_args():
         default=1,
         help="The batch size to run the analysis for. Must be greater than 0",
     )
+
+    parser.add_argument(
+        "-seq_len",
+        "--sequence_length",
+        type=int,
+        default=2048,
+        help="The sequence length to benchmark "
+        "KV-cache-supported models with. "
+        "Must be greater than 0; default is 2048",
+    )
+
+    parser.add_argument(
+        "-input_ids_len",
+        "--input_ids_length",
+        type=int,
+        default=1,
+        help="The input_ids length to benchmark "
+        "KV-cache-supported models with. "
+        "Must be greater than 0; default is 1",
+    )
     parser.add_argument(
         "-i",
         "-shapes",
@@ -265,6 +295,8 @@ def load_custom_engine(custom_engine_identifier: str):
 def benchmark_model(
     model_path: str,
     batch_size: int = 1,
+    sequence_length: int = 2048,
+    input_ids_length: int = 1,
     input_shapes: str = "",
     num_cores: int = None,
     scenario: str = "sync",
@@ -290,6 +322,28 @@
     orig_model_path = model_path
     model_path = model_to_path(model_path)
+
+    if has_model_kv_cache(model_path):
+        if batch_size != 1:
+            raise ValueError(
+                "Unable to run models with KV cache support "
+                "at a batch size other than one. "
+                "Please set batch size to 1 and try again."
+            )
+
+        _LOGGER.info(
+            "Found model with KV cache support. "
+            "Benchmarking the autoregressive model with "
+            f"input_ids_length: {input_ids_length} and "
+            f"sequence_length: {sequence_length}."
+        )
+
+        model_path, _, _ = overwrite_cache_model_inputs(
+            model_path=model_path,
+            input_ids_length=input_ids_length,
+            sequence_length=sequence_length,
+        )
+
     num_streams = parse_num_streams(num_streams, num_cores, scenario)
 
     # Compile the ONNX into a runnable model
@@ -351,6 +405,8 @@
         "orig_model_path": orig_model_path,
         "model_path": model_path,
         "batch_size": batch_size,
+        "sequence_length": sequence_length,
+        "input_ids_length": input_ids_length,
         "input_shapes": input_shapes,
         "num_cores": num_cores,
         "scenario": scenario,
@@ -376,6 +432,8 @@ def main():
 
     result = benchmark_model(
         model_path=args.model_path,
+        sequence_length=args.sequence_length,
+        input_ids_length=args.input_ids_length,
         batch_size=args.batch_size,
         input_shapes=args.input_shapes,
         num_cores=args.num_cores,
@@ -392,6 +450,10 @@
     # Results summary
     print("Original Model Path: {}".format(args.model_path))
     print("Batch Size: {}".format(args.batch_size))
+    if args.sequence_length is not None:
+        print("Sequence Length: {}".format(args.sequence_length))
+    if args.input_ids_length is not None:
+        print("Input IDs Length: {}".format(args.input_ids_length))
     print("Scenario: {}".format(args.scenario))
     print(
         "Throughput (items/sec): {:.4f}".format(
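For orientation, here is a minimal sketch (not part of this diff) of driving the new KV-cache benchmark path from Python instead of the CLI. The zoo stub is the one from the docstring example above; everything else follows the `benchmark_model` signature added in this file.

```python
# Sketch only: exercising the new KV-cache benchmark path programmatically.
# The zoo stub comes from the docstring example above; any ONNX model whose
# outputs carry the cache prefix would take the same branch.
from deepsparse.benchmark.benchmark_model import benchmark_model

results = benchmark_model(
    model_path=(
        "zoo:nlg/text_generation/codegen_mono-350m/pytorch/"
        "huggingface/bigpython_bigquery_thepile/pruned50-none"
    ),
    batch_size=1,         # KV cache models are rejected at any other batch size
    sequence_length=256,  # static length the cache tensors are sized for
    input_ids_length=10,  # tokens fed to the engine per forward pass
)
# Per the export-dict hunk above, the result now also carries
# "sequence_length" and "input_ids_length" entries.
print(results)
```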
diff --git a/src/deepsparse/transformers/engines/nl_decoder_engine.py b/src/deepsparse/transformers/engines/nl_decoder_engine.py
index d76c7fb8d7..30176b3b10 100644
--- a/src/deepsparse/transformers/engines/nl_decoder_engine.py
+++ b/src/deepsparse/transformers/engines/nl_decoder_engine.py
@@ -22,17 +22,16 @@
 from deepsparse.transformers.utils.decoder_kv_cache import DecoderKVCache
 from deepsparse.transformers.utils.helpers import (
     generate_session_id,
-    overwrite_onnx_model_inputs,
+    overwrite_onnx_model_inputs_for_kv_cache_models,
 )
 from deepsparse.utils.data import numpy_softmax
+from deepsparse.utils.onnx import CACHE_INPUT_PREFIX, CACHE_OUTPUT_PREFIX
 
 _LOGGER = logging.getLogger(__name__)
 
 __all__ = ["NLDecoderEngine"]
 
-_CACHE_INPUT_NAME = "past_key_values"
-
 
 class NLDecoderEngine:
     """
@@ -69,17 +68,17 @@ def __init__(
     ):
         # flag to indicate if the model is quantized or not
         self.kv_cache_data_type = None
-
         (
             onnx_file_path,
             output_indices_to_be_cached,
             kv_cache_data_type,
-        ) = overwrite_onnx_model_inputs(
+        ) = overwrite_onnx_model_inputs_for_kv_cache_models(
             onnx_file_path=onnx_file_path,
             batch_size=engine_args.get("batch_size", 1),
             sequence_length=sequence_length,
             input_ids_length=input_ids_length,
         )
+
         kv_cache_enabled = False
         if sum(output_indices_to_be_cached):
             kv_cache_enabled = True
@@ -129,7 +128,7 @@ def onnx_input_names_no_cache(self) -> List[str]:
         return [
             name
             for name in self.engine.input_names
-            if not name.startswith(_CACHE_INPUT_NAME)
+            if not name.startswith(CACHE_INPUT_PREFIX)
         ]
 
     @property
@@ -284,7 +283,7 @@ def update_kv_cache(
         cache_onnx_names = [
             name
             for name in self.engine.input_names
-            if name.startswith(_CACHE_INPUT_NAME)
+            if name.startswith(CACHE_INPUT_PREFIX)
         ]
         kv_cache_state = {
             name: array for name, array in zip(cache_onnx_names, kv_cache_state)
         }
@@ -302,7 +301,7 @@ def _initialize_kv_cache_state(self, length: int) -> Dict[str, numpy.ndarray]:
         cache_engine_input_index = next(
             i
             for i, name in enumerate(self.engine.input_names)
-            if _CACHE_INPUT_NAME in name
+            if CACHE_INPUT_PREFIX in name
         )
         batch_size, num_attention_heads, _, hidden_dims = self.engine.input_shapes[
             cache_engine_input_index
         ]
@@ -314,9 +313,9 @@ def _initialize_kv_cache_state(self, length: int) -> Dict[str, numpy.ndarray]:
         )
 
         cache_keys = [
-            output_name.replace("present", _CACHE_INPUT_NAME)
+            output_name.replace(CACHE_OUTPUT_PREFIX, CACHE_INPUT_PREFIX)
             for output_name in self.engine.output_names
-            if output_name.startswith("present")
+            if output_name.startswith(CACHE_OUTPUT_PREFIX)
         ]
         return {key: empty_kv_cache_tensor for key in cache_keys}
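Since the engine now leans entirely on the shared prefixes, a small standalone sketch of the name handling above may help. The prefix values mirror the constants in this diff ("past_key_values" and "present"); the concrete tensor names are an assumed, conventional ONNX export layout, not taken from this PR.

```python
# Standalone illustration of the prefix-based cache-name handling above.
CACHE_INPUT_PREFIX = "past_key_values"
CACHE_OUTPUT_PREFIX = "present"

input_names = [
    "input_ids",
    "attention_mask",
    "past_key_values.0.key",   # assumed conventional cache input names
    "past_key_values.0.value",
]
output_names = ["logits", "present.0.key", "present.0.value"]

# What onnx_input_names_no_cache computes: inputs that are not cache tensors.
non_cache = [n for n in input_names if not n.startswith(CACHE_INPUT_PREFIX)]
assert non_cache == ["input_ids", "attention_mask"]

# What _initialize_kv_cache_state computes: cache outputs renamed to the
# matching cache inputs, so fresh cache tensors can be fed back in.
cache_keys = [
    n.replace(CACHE_OUTPUT_PREFIX, CACHE_INPUT_PREFIX)
    for n in output_names
    if n.startswith(CACHE_OUTPUT_PREFIX)
]
assert cache_keys == ["past_key_values.0.key", "past_key_values.0.value"]
```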
diff --git a/src/deepsparse/transformers/utils/helpers.py b/src/deepsparse/transformers/utils/helpers.py
index f6a6e02155..5fb0f3c1c5 100644
--- a/src/deepsparse/transformers/utils/helpers.py
+++ b/src/deepsparse/transformers/utils/helpers.py
@@ -11,10 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
 import uuid
-from typing import List, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy
 import onnx
@@ -24,21 +23,21 @@
 
 __all__ = [
+    "overwrite_onnx_model_inputs_for_kv_cache_models",
     "generate_session_id",
     "pad_to_fixed_length",
     "create_causal_mask",
-    "overwrite_onnx_model_inputs",
 ]
 
 _LOGGER = logging.getLogger(__name__)
 
 
-def overwrite_onnx_model_inputs(
+def overwrite_onnx_model_inputs_for_kv_cache_models(
     onnx_file_path: str,
     sequence_length: int,
     input_ids_length: int,
     batch_size: int = 1,
-) -> Tuple[str, List[int]]:
+) -> Tuple[str, List[int], Optional[int]]:
     """
     Enforces the appropriate input shapes for the onnx model, as well
     as checks whether kv cache is enabled or not.
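To make the widened contract concrete, here is a hedged sketch of calling the renamed helper; "model.onnx" is a placeholder path, and the unpacking follows the new three-element return annotation in this diff.

```python
# Sketch, assuming "model.onnx" is a local decoder export with KV cache I/O.
from deepsparse.transformers.utils.helpers import (
    overwrite_onnx_model_inputs_for_kv_cache_models,
)

(
    onnx_path,           # path to the model rewritten with static input shapes
    cached_output_mask,  # per-output flags; truthy -> output should be cached
    kv_cache_dtype,      # data type of the kv cache, or None without KV cache
) = overwrite_onnx_model_inputs_for_kv_cache_models(
    onnx_file_path="model.onnx",  # placeholder path
    sequence_length=256,
    input_ids_length=10,
    batch_size=1,
)
```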
diff --git a/src/deepsparse/utils/onnx.py b/src/deepsparse/utils/onnx.py
index f442b45dce..24d2734d73 100644
--- a/src/deepsparse/utils/onnx.py
+++ b/src/deepsparse/utils/onnx.py
@@ -50,6 +50,8 @@
     "truncate_onnx_model",
     "truncate_onnx_embedding_model",
     "default_cached_outputs",
+    "has_model_kv_cache",
+    "overwrite_cache_model_inputs",
     "CACHE_INPUT_PREFIX",
     "CACHE_OUTPUT_PREFIX",
 ]
@@ -494,3 +496,56 @@ def default_cached_outputs(model_path: str) -> List[bool]:
     assert len(output_names) > 0
 
     return [name.startswith(CACHE_OUTPUT_PREFIX) for name in output_names]
+
+
+def has_model_kv_cache(model: Union[str, ModelProto]) -> bool:
+    """
+    Check whether a model has KV cache support.
+
+    :param model: Path to a model or a model proto.
+    :return: True if the model has KV cache support, False otherwise.
+    """
+    return bool(any(default_cached_outputs(model)))
+
+
+def overwrite_cache_model_inputs(
+    model_path: str,
+    input_ids_length: int,
+    sequence_length: int,
+) -> Tuple[str, List[int], Optional[int]]:
+    """
+    Takes a path to an onnx model and enforces that it has
+    static input dimensions.
+
+    :param model_path: Path to a model.
+    :param input_ids_length: The input_ids length to overwrite the model with.
+    :param sequence_length: The sequence length to overwrite the model with.
+    :return: A tuple that contains:
+        - the path to the onnx model file that has been overwritten
+          with the new input shapes
+        - a boolean list, where an element is True if the
+          corresponding model output should be cached and
+          False otherwise
+        - the data type of the kv cache; if the model does not
+          use kv cache, the data type is None
+    """
+    from deepsparse.transformers.utils.helpers import (
+        overwrite_onnx_model_inputs_for_kv_cache_models,
+    )
+
+    assert input_ids_length < sequence_length, (
+        f"input_ids_length {input_ids_length} "
+        f"must be less than sequence_length {sequence_length}"
+    )
+
+    (
+        onnx_file_path,
+        output_indices_to_be_cached,
+        kv_cache_data_type,
+    ) = overwrite_onnx_model_inputs_for_kv_cache_models(
+        onnx_file_path=model_path,
+        sequence_length=sequence_length,
+        input_ids_length=input_ids_length,
+    )
+
+    return onnx_file_path, output_indices_to_be_cached, kv_cache_data_type
diff --git a/tests/deepsparse/transformers/pipelines/test_text_generation.py b/tests/deepsparse/transformers/pipelines/test_text_generation.py
index b9569d9f0d..1be380542a 100644
--- a/tests/deepsparse/transformers/pipelines/test_text_generation.py
+++ b/tests/deepsparse/transformers/pipelines/test_text_generation.py
@@ -22,7 +22,7 @@
 from deepsparse import Pipeline
 from deepsparse.transformers.utils.helpers import (
     create_causal_mask,
-    overwrite_onnx_model_inputs,
+    overwrite_onnx_model_inputs_for_kv_cache_models,
 )
 from deepsparse.utils.onnx import CACHE_INPUT_PREFIX
 from sparsezoo import Model
@@ -216,7 +216,7 @@ def _get_cache_state_ort_kv_cache(model_onnx_path, sequence, model_name):
 
     # setup model and session
     # (run full sequence inference)
-    overwrite_onnx_model_inputs(
+    overwrite_onnx_model_inputs_for_kv_cache_models(
        model_onnx_path, sequence_length=128, input_ids_length=128
     )
     sess = onnxruntime.InferenceSession(model_onnx_path)
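Finally, a sketch tying the two new onnx utils together the same way the updated `benchmark_model` does; "model.onnx" is again a placeholder, and the import path matches the one added in the benchmark diff above.

```python
# Sketch mirroring the benchmark_model flow: detect KV cache support, then
# freeze the cache-related input shapes before compiling the engine.
from deepsparse.utils import has_model_kv_cache, overwrite_cache_model_inputs

model_path = "model.onnx"  # placeholder for a real decoder export
if has_model_kv_cache(model_path):
    # input_ids_length must be strictly less than sequence_length,
    # otherwise overwrite_cache_model_inputs raises an AssertionError.
    model_path, cached_output_mask, kv_cache_dtype = overwrite_cache_model_inputs(
        model_path=model_path,
        input_ids_length=10,
        sequence_length=256,
    )
```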