Commit

fixed the logic to assert correct multibatch inference
dbogunowicz committed Jul 7, 2023
1 parent e81c327 commit 042cb79
Showing 2 changed files with 20 additions and 15 deletions.
2 changes: 1 addition & 1 deletion src/deepsparse/transformers/metrics.py
@@ -83,7 +83,7 @@ def add_batch(self, predictions: List[str]):
             attention_mask = attention_masks[start_index:end_index]
 
             out = self._pipeline(
-                sequences=predictions, return_logits=True, truncate=True
+                sequences=predictions, return_logits=True, fixed_sequences_length=True
             )
             logits = out.logits
 
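For context on the metrics change above: with `fixed_sequences_length=True`, every sequence in a batch is padded or truncated to the pipeline's `sequence_length`, so the logits returned for perplexity computation share one shape. A minimal usage sketch (the task name and model stub below are illustrative, not taken from the commit):

```python
from deepsparse import Pipeline

# hypothetical model stub -- substitute any real text-generation model
pipeline = Pipeline.create(
    task="text_generation",
    model_path="zoo:some/text-generation-model",
    sequence_length=128,
)

predictions = ["a short sequence", "a noticeably longer input sequence"]
out = pipeline(
    sequences=predictions,
    return_logits=True,
    fixed_sequences_length=True,  # pad/truncate both inputs to 128 tokens
)
logits = out.logits  # one row per sequence, identical sequence dimension
```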
33 changes: 19 additions & 14 deletions src/deepsparse/transformers/pipelines/text_generation.py
@@ -46,11 +46,13 @@ class TextGenerationInput(BaseModel):
"and the model is using kv cache, it "
"will be set to a random uuid.",
)
truncate: bool = Field(
fixed_sequences_length: bool = Field(
default=False,
description="A flag that indicates whether to truncate "
"the input text sequence. Useful, when a batch of "
"predictions needs to have consistent length so one"
description="A flag that indicates whether to modify "
"(pad or truncate) each input text sequence, so that "
"its tokenized length is equal to `sequence_length` "
"of tokens. Useful, when a batch of predictions needs "
"to have consistent length so one "
"can compute metric in a batched fashion. ",
)
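A short sketch of the new schema field in use, assuming the `TextGenerationInput` model shown above (the import path mirrors the file being edited):

```python
from deepsparse.transformers.pipelines.text_generation import TextGenerationInput

batch = TextGenerationInput(
    sequences=["first prompt", "a much longer second prompt"],
    # pad or truncate every prompt to exactly `sequence_length` tokens
    fixed_sequences_length=True,
)
```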

@@ -235,17 +237,26 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:
         :return: the inputs for the engine
         """
 
-        if self._are_sequences_of_different_lengths(inputs.sequences):
-            padding = "longest"
-        else:
+        if inputs.fixed_sequences_length:
+            # to enforce a fixed sequence length, we need to
+            # truncate the input to the maximum sequence length
+            # and/or pad it to the maximum sequence length
+            truncate = True
             padding = "max_length"
+        else:
+            # otherwise, we do not need to truncate the input
+            # and we can pad it to the longest sequence
+            # in the batch (so that the engine can process multiple inputs
+            # at once)
+            truncate = False
+            padding = "longest"
 
         input_tokens = self.tokenizer(
             inputs.sequences,
             return_tensors="np",
             max_length=self.sequence_length,
             padding=padding,
-            truncation=inputs.truncate,
+            truncation=truncate,
         )
 
         attention_mask = input_tokens["attention_mask"]
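The two tokenizer configurations that `process_inputs` now switches between can be illustrated standalone with any Hugging Face tokenizer (this sketch assumes the `transformers` package and a GPT-2 tokenizer; it is not part of the commit):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token
sequences = ["short", "a somewhat longer input sequence"]

# fixed_sequences_length=True: every row is exactly max_length tokens
fixed = tokenizer(
    sequences,
    return_tensors="np",
    max_length=8,
    padding="max_length",
    truncation=True,
)

# fixed_sequences_length=False: rows are padded only to the longest input
dynamic = tokenizer(sequences, return_tensors="np", padding="longest")

print(fixed["input_ids"].shape)    # (2, 8) -- always max_length
print(dynamic["input_ids"].shape)  # (2, n), n = longest tokenized input
```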
@@ -429,9 +440,3 @@ def autoregressive_inference(
     def _reset_engines_cache(self):
         self.engine.reset_kv_cache()
         self.multitoken_engine.reset_kv_cache()
-
-    @staticmethod
-    def _are_sequences_of_different_lengths(sequences: Union[str, List[str]]) -> bool:
-        if isinstance(sequences, str):
-            return False
-        return len(set([len(sequence) for sequence in sequences])) > 1
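As for why the `_are_sequences_of_different_lengths` helper was dropped: the commit message only says the multibatch logic was wrong, but one concrete failure mode is that the helper compared character lengths, while batched inference needs equal token lengths. A sketch of the mismatch, assuming a GPT-2 style BPE tokenizer:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
a, b = "hello there", "xqzjvwkpyfb"  # both 11 characters long

# the old helper would have treated these as "same length" ...
print(len(a) == len(b))  # True

# ... yet their token counts typically differ, leaving the batch ragged
print(len(tokenizer(a)["input_ids"]), len(tokenizer(b)["input_ids"]))
```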
