From ad998dfea683a46827dbc41374baedb04f4b046a Mon Sep 17 00:00:00 2001 From: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com> Date: Tue, 18 Jul 2023 15:42:01 +0200 Subject: [PATCH] [Text Generation] Detect dtype of kv cache (float32/uint8) for text generation models (#1123) * initial implementation * initial commit --- .../transformers/engines/nl_decoder_engine.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/deepsparse/transformers/engines/nl_decoder_engine.py b/src/deepsparse/transformers/engines/nl_decoder_engine.py index 6ca2b81dc7..d75d051e56 100644 --- a/src/deepsparse/transformers/engines/nl_decoder_engine.py +++ b/src/deepsparse/transformers/engines/nl_decoder_engine.py @@ -23,6 +23,7 @@ from deepsparse.pipeline import DEEPSPARSE_ENGINE, create_engine from deepsparse.transformers.utils.decoder_kv_cache import DecoderKVCache from deepsparse.transformers.utils.helpers import generate_session_id, softmax +from deepsparse.utils.onnx import translate_onnx_type_to_numpy from sparsezoo.utils.onnx import save_onnx @@ -66,6 +67,8 @@ def __init__( engine_context: Optional[Context] = None, use_deepsparse_cache=False, ): + # flag to indicate if the model is quantized or not + self.kv_cache_data_type = None onnx_file_path, output_indices_to_be_cached = self.overwrite_onnx_model_inputs( onnx_file_path=onnx_file_path, @@ -173,8 +176,8 @@ def transfer_cache_state(self, cache: DecoderKVCache): """ self.kv_cache = copy.deepcopy(cache) - @staticmethod def overwrite_onnx_model_inputs( + self, onnx_file_path: str, sequence_length: int, input_ids_length: int, @@ -227,6 +230,11 @@ def overwrite_onnx_model_inputs( 1 if inp.name.startswith("present") else 0 for inp in model.graph.output ] + kv_cache_elem_type = next( + inp for inp in model.graph.input if inp.name.startswith(_CACHE_INPUT_NAME) + ).type.tensor_type.elem_type + self.kv_cache_data_type = translate_onnx_type_to_numpy(kv_cache_elem_type) + return onnx_file_path, output_indices_to_be_cached def generate_token(self, logits: numpy.ndarray) -> numpy.ndarray: @@ -319,7 +327,7 @@ def _initialize_kv_cache_state(self, length: int) -> Dict[str, numpy.ndarray]: empty_kv_cache_tensor = numpy.zeros( (batch_size, num_attention_heads, length, hidden_dims), - dtype=numpy.float32, + dtype=self.kv_cache_data_type, ) cache_keys = [