Skip to content

Commit

Permalink
initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
dbogunowicz committed Jul 17, 2023
1 parent b319fa3 commit 5da22c6
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 15 deletions.
12 changes: 10 additions & 2 deletions src/deepsparse/transformers/engines/nl_decoder_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from deepsparse.pipeline import DEEPSPARSE_ENGINE, create_engine
from deepsparse.transformers.utils.decoder_kv_cache import DecoderKVCache
from deepsparse.transformers.utils.helpers import generate_session_id, softmax
from deepsparse.utils.onnx import translate_onnx_type_to_numpy
from sparsezoo.utils.onnx import save_onnx


Expand Down Expand Up @@ -66,6 +67,8 @@ def __init__(
engine_context: Optional[Context] = None,
use_deepsparse_cache=False,
):
# flag to indicate if the model is quantized or not
self.kv_cache_data_type = None

onnx_file_path, output_indices_to_be_cached = self.overwrite_onnx_model_inputs(
onnx_file_path=onnx_file_path,
Expand Down Expand Up @@ -173,8 +176,8 @@ def transfer_cache_state(self, cache: DecoderKVCache):
"""
self.kv_cache = copy.deepcopy(cache)

@staticmethod
def overwrite_onnx_model_inputs(
self,
onnx_file_path: str,
sequence_length: int,
input_ids_length: int,
Expand Down Expand Up @@ -227,6 +230,11 @@ def overwrite_onnx_model_inputs(
1 if inp.name.startswith("present") else 0 for inp in model.graph.output
]

kv_cache_elem_type = next(
inp for inp in model.graph.input if inp.name.startswith(_CACHE_INPUT_NAME)
).type.tensor_type.elem_type
self.kv_cache_data_type = translate_onnx_type_to_numpy(kv_cache_elem_type)

return onnx_file_path, output_indices_to_be_cached

def generate_token(self, logits: numpy.ndarray) -> numpy.ndarray:
Expand Down Expand Up @@ -319,7 +327,7 @@ def _initialize_kv_cache_state(self, length: int) -> Dict[str, numpy.ndarray]:

empty_kv_cache_tensor = numpy.zeros(
(batch_size, num_attention_heads, length, hidden_dims),
dtype=numpy.float32,
dtype=self.kv_cache_data_type,
)

cache_keys = [
Expand Down
18 changes: 5 additions & 13 deletions src/deepsparse/transformers/pipelines/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,12 @@
# limitations under the License.

import logging
import warnings
from typing import List, Optional, Tuple, Type, Union

import numpy
from pydantic import BaseModel, Field

from deepsparse import Pipeline
from deepsparse.cpu import cpu_avx512_compatible
from deepsparse.pipeline import DEEPSPARSE_ENGINE
from deepsparse.transformers.engines import NLDecoderEngine
from deepsparse.transformers.pipelines import TransformersPipeline
Expand Down Expand Up @@ -117,19 +115,9 @@ def __init__(
# TODO: Set this to 64 once we modify the OPT injection logic
prompt_processing_sequence_length: int = 128,
force_max_tokens: bool = False,
use_deepsparse_cache: bool = True,
use_deepsparse_cache: bool = False,
**kwargs,
):
print(cpu_avx512_compatible())
if not cpu_avx512_compatible() and kwargs["engine_type"] == DEEPSPARSE_ENGINE:
warnings.warn(
"Detected CPU is not AVX512 compatible. "
"The kv cache management will not be supported "
"by the optimized engine. The user may experience "
"non optimal performance."
)
use_deepsparse_cache = False

if use_deepsparse_cache:
if kwargs["engine_type"] != DEEPSPARSE_ENGINE:
raise ValueError(
Expand All @@ -138,6 +126,10 @@ def __init__(
f"is {kwargs['engine_type']}. "
f"Make sure to set `engine_type` to {DEEPSPARSE_ENGINE}"
)
raise NotImplementedError(
"The deepsparse kv cache is not yet "
"supported for text generation pipelines"
)

super().__init__(
**kwargs, _delay_engine_initialize=True, _delay_overwriting_inputs=True
Expand Down

0 comments on commit 5da22c6

Please sign in to comment.