Skip to content

Commit

Permalink
initial implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
dbogunowicz committed Jul 17, 2023
1 parent 8a26435 commit b319fa3
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions src/deepsparse/transformers/pipelines/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@
# limitations under the License.

import logging
import warnings
from typing import List, Optional, Tuple, Type, Union

import numpy
from pydantic import BaseModel, Field

from deepsparse import Pipeline
from deepsparse.cpu import cpu_avx512_compatible
from deepsparse.pipeline import DEEPSPARSE_ENGINE
from deepsparse.transformers.engines import NLDecoderEngine
from deepsparse.transformers.pipelines import TransformersPipeline
Expand Down Expand Up @@ -115,9 +117,19 @@ def __init__(
# TODO: Set this to 64 once we modify the OPT injection logic
prompt_processing_sequence_length: int = 128,
force_max_tokens: bool = False,
use_deepsparse_cache: bool = False,
use_deepsparse_cache: bool = True,
**kwargs,
):
print(cpu_avx512_compatible())
if not cpu_avx512_compatible() and kwargs["engine_type"] == DEEPSPARSE_ENGINE:
warnings.warn(
"Detected CPU is not AVX512 compatible. "
"The kv cache management will not be supported "
"by the optimized engine. The user may experience "
"non optimal performance."
)
use_deepsparse_cache = False

if use_deepsparse_cache:
if kwargs["engine_type"] != DEEPSPARSE_ENGINE:
raise ValueError(
Expand All @@ -126,10 +138,6 @@ def __init__(
f"is {kwargs['engine_type']}. "
f"Make sure to set `engine_type` to {DEEPSPARSE_ENGINE}"
)
raise NotImplementedError(
"The deepsparse kv cache is not yet "
"supported for text generation pipelines"
)

super().__init__(
**kwargs, _delay_engine_initialize=True, _delay_overwriting_inputs=True
Expand Down

0 comments on commit b319fa3

Please sign in to comment.