From 042cb79fd78cbcb3d565b535ccbde39d34d2c1b1 Mon Sep 17 00:00:00 2001
From: Damian
Date: Fri, 7 Jul 2023 13:49:31 +0000
Subject: [PATCH] fixed the logic to assert correct multibatch inference

---
 src/deepsparse/transformers/metrics.py             |  2 +-
 .../transformers/pipelines/text_generation.py      | 33 +++++++++++--------
 2 files changed, 20 insertions(+), 15 deletions(-)

diff --git a/src/deepsparse/transformers/metrics.py b/src/deepsparse/transformers/metrics.py
index ef5dd521eb..76fd00b471 100644
--- a/src/deepsparse/transformers/metrics.py
+++ b/src/deepsparse/transformers/metrics.py
@@ -83,7 +83,7 @@ def add_batch(self, predictions: List[str]):
             attention_mask = attention_masks[start_index:end_index]

             out = self._pipeline(
-                sequences=predictions, return_logits=True, truncate=True
+                sequences=predictions, return_logits=True, fixed_sequences_length=True
             )
             logits = out.logits

diff --git a/src/deepsparse/transformers/pipelines/text_generation.py b/src/deepsparse/transformers/pipelines/text_generation.py
index 6e6f42d625..b0547992d7 100644
--- a/src/deepsparse/transformers/pipelines/text_generation.py
+++ b/src/deepsparse/transformers/pipelines/text_generation.py
@@ -46,11 +46,13 @@ class TextGenerationInput(BaseModel):
         "and the model is using kv cache, it "
         "will be set to a random uuid.",
     )
-    truncate: bool = Field(
+    fixed_sequences_length: bool = Field(
         default=False,
-        description="A flag that indicates whether to truncate "
-        "the input text sequence. Useful, when a batch of "
-        "predictions needs to have consistent length so one"
+        description="A flag that indicates whether to modify "
+        "(pad or truncate) each input text sequence so that "
+        "its tokenized length is equal to `sequence_length` "
+        "tokens. Useful when a batch of predictions needs "
+        "to have a consistent length so one "
         "can compute metric in a batched fashion. ",
     )

@@ -235,17 +237,26 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:

         :return: the inputs for the engine
         """
-        if self._are_sequences_of_different_lengths(inputs.sequences):
-            padding = "longest"
-        else:
+        if inputs.fixed_sequences_length:
+            # to enforce a fixed sequence length, we need to
+            # truncate the input to the maximum sequence length
+            # and/or pad it up to the maximum sequence length
+            truncate = True
             padding = "max_length"
+        else:
+            # otherwise, we do not need to truncate the input
+            # and we pad it to the longest sequence in the batch
+            # (so that the engine can process multiple inputs
+            # at once)
+            truncate = False
+            padding = "longest"

         input_tokens = self.tokenizer(
             inputs.sequences,
             return_tensors="np",
             max_length=self.sequence_length,
             padding=padding,
-            truncation=inputs.truncate,
+            truncation=truncate,
         )

         attention_mask = input_tokens["attention_mask"]
@@ -429,9 +440,3 @@ def autoregressive_inference(
     def _reset_engines_cache(self):
         self.engine.reset_kv_cache()
         self.multitoken_engine.reset_kv_cache()
-
-    @staticmethod
-    def _are_sequences_of_different_lengths(sequences: Union[str, List[str]]) -> bool:
-        if isinstance(sequences, str):
-            return False
-        return len(set([len(sequence) for sequence in sequences])) > 1
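
As a rough illustration of the tokenizer behavior that the new `fixed_sequences_length` flag toggles in `process_inputs`, the sketch below calls a Hugging Face tokenizer directly; the "gpt2" tokenizer, the example strings, and the sequence length are illustrative stand-ins only and are not part of the patch.

    # Illustrative sketch, not part of the patch: compares the two
    # padding/truncation modes selected in process_inputs().
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in tokenizer
    tokenizer.pad_token = tokenizer.eos_token          # gpt2 defines no pad token

    sequences = ["a short prompt", "a considerably longer prompt about padding"]
    sequence_length = 16                               # stand-in for self.sequence_length

    # fixed_sequences_length=True: pad and/or truncate every input to sequence_length
    fixed = tokenizer(
        sequences,
        return_tensors="np",
        max_length=sequence_length,
        padding="max_length",
        truncation=True,
    )

    # fixed_sequences_length=False: pad only up to the longest sequence in the batch
    dynamic = tokenizer(
        sequences,
        return_tensors="np",
        max_length=sequence_length,
        padding="longest",
        truncation=False,
    )

    print(fixed["input_ids"].shape)    # (2, 16): every row is exactly sequence_length tokens
    print(dynamic["input_ids"].shape)  # (2, n): n is the token length of the longest input

Padding every input to the same fixed `sequence_length` is what lets the evaluation code in metrics.py push several predictions through the pipeline and receive identically shaped logits, i.e. compute metrics in a batched fashion.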