Skip to content

Commit

Permalink
[Text Generation] Detect dtype of kv cache (float32/uint8) for text generation models (#1123)
Browse files Browse the repository at this point in the history

* initial implementation

* initial commit
  • Loading branch information
dbogunowicz authored Jul 18, 2023
1 parent dc788db commit ad998df
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions src/deepsparse/transformers/engines/nl_decoder_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from deepsparse.pipeline import DEEPSPARSE_ENGINE, create_engine
from deepsparse.transformers.utils.decoder_kv_cache import DecoderKVCache
from deepsparse.transformers.utils.helpers import generate_session_id, softmax
from deepsparse.utils.onnx import translate_onnx_type_to_numpy
from sparsezoo.utils.onnx import save_onnx


Expand Down Expand Up @@ -66,6 +67,8 @@ def __init__(
engine_context: Optional[Context] = None,
use_deepsparse_cache=False,
):
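# numpy dtype of the kv cache inputs (e.g. float32 or uint8, the latter for
# quantized caches); populated by overwrite_onnx_model_inputs from the ONNX graph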
self.kv_cache_data_type = None

onnx_file_path, output_indices_to_be_cached = self.overwrite_onnx_model_inputs(
onnx_file_path=onnx_file_path,
Expand Down Expand Up @@ -173,8 +176,8 @@ def transfer_cache_state(self, cache: DecoderKVCache):
"""
self.kv_cache = copy.deepcopy(cache)

@staticmethod
def overwrite_onnx_model_inputs(
self,
onnx_file_path: str,
sequence_length: int,
input_ids_length: int,
Expand Down Expand Up @@ -227,6 +230,11 @@ def overwrite_onnx_model_inputs(
1 if inp.name.startswith("present") else 0 for inp in model.graph.output
]

kv_cache_elem_type = next(
inp for inp in model.graph.input if inp.name.startswith(_CACHE_INPUT_NAME)
).type.tensor_type.elem_type
self.kv_cache_data_type = translate_onnx_type_to_numpy(kv_cache_elem_type)

return onnx_file_path, output_indices_to_be_cached

def generate_token(self, logits: numpy.ndarray) -> numpy.ndarray:
Expand Down Expand Up @@ -319,7 +327,7 @@ def _initialize_kv_cache_state(self, length: int) -> Dict[str, numpy.ndarray]:

empty_kv_cache_tensor = numpy.zeros(
(batch_size, num_attention_heads, length, hidden_dims),
dtype=numpy.float32,
dtype=self.kv_cache_data_type,
)

cache_keys = [
Expand Down

0 comments on commit ad998df

Please sign in to comment.