Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Text Generation] Detect dtype of kv cache (float32/uint8) for text generation models #1123

Merged
merged 3 commits into the base branch from the source branch on
Jul 18, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions src/deepsparse/transformers/engines/nl_decoder_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from deepsparse.pipeline import DEEPSPARSE_ENGINE, create_engine
from deepsparse.transformers.utils.decoder_kv_cache import DecoderKVCache
from deepsparse.transformers.utils.helpers import generate_session_id, softmax
from deepsparse.utils.onnx import translate_onnx_type_to_numpy
from sparsezoo.utils.onnx import save_onnx


Expand Down Expand Up @@ -66,6 +67,8 @@ def __init__(
engine_context: Optional[Context] = None,
use_deepsparse_cache=False,
):
# flag to indicate if the model is quantized or not
self.kv_cache_data_type = None
dbogunowicz marked this conversation as resolved.
Show resolved Hide resolved

onnx_file_path, output_indices_to_be_cached = self.overwrite_onnx_model_inputs(
onnx_file_path=onnx_file_path,
Expand Down Expand Up @@ -173,8 +176,8 @@ def transfer_cache_state(self, cache: DecoderKVCache):
"""
self.kv_cache = copy.deepcopy(cache)

@staticmethod
def overwrite_onnx_model_inputs(
self,
onnx_file_path: str,
sequence_length: int,
input_ids_length: int,
Expand Down Expand Up @@ -227,6 +230,11 @@ def overwrite_onnx_model_inputs(
1 if inp.name.startswith("present") else 0 for inp in model.graph.output
]

kv_cache_elem_type = next(
inp for inp in model.graph.input if inp.name.startswith(_CACHE_INPUT_NAME)
).type.tensor_type.elem_type
dbogunowicz marked this conversation as resolved.
Show resolved Hide resolved
self.kv_cache_data_type = translate_onnx_type_to_numpy(kv_cache_elem_type)

return onnx_file_path, output_indices_to_be_cached

def generate_token(self, logits: numpy.ndarray) -> numpy.ndarray:
Expand Down Expand Up @@ -319,7 +327,7 @@ def _initialize_kv_cache_state(self, length: int) -> Dict[str, numpy.ndarray]:

empty_kv_cache_tensor = numpy.zeros(
(batch_size, num_attention_heads, length, hidden_dims),
dtype=numpy.float32,
dtype=self.kv_cache_data_type,
)

cache_keys = [
Expand Down
Loading