From 83b14128c4bec1e93f854327280f2eb4527c972e Mon Sep 17 00:00:00 2001
From: Benjamin
Date: Thu, 13 Jul 2023 17:04:34 -0400
Subject: [PATCH] [TextGeneration][Timer] text gen specific timings + improved timing tooling

---
 .../transformers/pipelines/text_generation.py | 77 +++++++++++--------
 src/deepsparse/utils/timer.py                 | 35 ++++++++-
 2 files changed, 76 insertions(+), 36 deletions(-)

diff --git a/src/deepsparse/transformers/pipelines/text_generation.py b/src/deepsparse/transformers/pipelines/text_generation.py
index f74696d37a3..e05fd15908e 100644
--- a/src/deepsparse/transformers/pipelines/text_generation.py
+++ b/src/deepsparse/transformers/pipelines/text_generation.py
@@ -30,6 +30,12 @@
 __all__ = ["TextGenerationPipeline"]


+PROMPT_PREFILL = "engine_prompt_prefill"
+PROMPT_PREFILL_SINGLE = "engine_prompt_prefill_single"
+TOKEN_GENERATION = "engine_token_generation"
+TOKEN_GENERATION_SINGLE = "engine_token_generation_single"
+
+
 class TextGenerationInput(BaseModel):
     sequences: Union[str, List[str]] = Field(
         description="The input sequences to generate the text from.",
@@ -304,35 +310,41 @@ def engine_forward(
             sequence of generated tokens and a sequence
             of logits for each generated token
         """
-        if not self.multitoken_engine.kv_cache_enabled:
-            tokens, prompt_logits = self.multitoken_engine(engine_inputs)
-            return numpy.array([tokens]), prompt_logits
-
-        else:
-            # run the prompt through
-            tokens, prompt_logits = self.prompt_inference(engine_inputs)
-
-            # create the generated output
-            max_tokens = (
-                self.max_generated_tokens
-                if self.max_generated_tokens and self.max_generated_tokens > 0
-                else 100 * self.sequence_length
-            )  # set safety for absolute max generation
-
-            generated_tokens = [tokens[-1]]
-            generated_logits = prompt_logits
-
-            while len(generated_tokens) < max_tokens:
-                (
-                    token,
-                    logits,
-                ) = self.autoregressive_inference(tokens)
-                tokens.append(token)
-                generated_tokens.append(token)
-                generated_logits.append(logits)
-
-                if token == self.tokenizer.eos_token_id and not self.force_max_tokens:
-                    break
+        # engine_forward is always called in a threadpool due to batch splitting;
+        # as such, a new context needs to be created since we are no longer in the
+        # main thread. That is why `engine_` is prepended to each of the timer phase
+        # names in this context.
+        with self.timer_manager.new_timer_context(total_inference=False) as timer:
+            if not self.multitoken_engine.kv_cache_enabled:
+                tokens, prompt_logits = self.multitoken_engine(engine_inputs)
+                return numpy.array([tokens]), prompt_logits
+
+            else:
+                # run the prompt through
+                with timer.time(PROMPT_PREFILL):
+                    tokens, prompt_logits = self.prompt_inference(engine_inputs)
+
+                # create the generated output
+                max_tokens = (
+                    self.max_generated_tokens
+                    if self.max_generated_tokens and self.max_generated_tokens > 0
+                    else 100 * self.sequence_length
+                )  # set safety for absolute max generation
+
+                generated_tokens = [tokens[-1]]
+                generated_logits = prompt_logits
+
+                timer.start(TOKEN_GENERATION)
+                while len(generated_tokens) < max_tokens:
+                    with timer.time(TOKEN_GENERATION_SINGLE):
+                        token, logits = self.autoregressive_inference(tokens)
+                    tokens.append(token)
+                    generated_tokens.append(token)
+                    generated_logits.append(logits)
+
+                    if token == self.tokenizer.eos_token_id and not self.force_max_tokens:
+                        break
+                timer.stop(TOKEN_GENERATION)

         return numpy.array([generated_tokens]), numpy.concatenate(
             generated_logits, axis=1
@@ -388,9 +400,10 @@ def prompt_inference(

         for token in tokens[num_tokens_processed:]:
             run_tokens.append(token)
-            new_token, new_logits = self.autoregressive_inference(
-                run_tokens, shift_positions_by_one=not bool(num_tokens_processed)
-            )
+            with self.timer_manager.current.time(PROMPT_PREFILL_SINGLE):
+                new_token, new_logits = self.autoregressive_inference(
+                    run_tokens, shift_positions_by_one=not bool(num_tokens_processed)
+                )

             prompt_logits.append(new_logits)
             tokens.append(new_token)
diff --git a/src/deepsparse/utils/timer.py b/src/deepsparse/utils/timer.py
index b29bc350658..90b915f3bda 100644
--- a/src/deepsparse/utils/timer.py
+++ b/src/deepsparse/utils/timer.py
@@ -114,6 +114,26 @@ def has_stage(self, stage: str) -> bool:
         """
         return stage in self.stages

+    @contextmanager
+    def time(self, stage: str):
+        """
+        Context manager to record the time for a stage within the given context
+
+        example:
+        ```
+        with timer.time(STAGE_NAME):
+            # do something...
+        ```
+
+        :param stage: the name of the stage to time
+        """
+        self.start(stage)
+
+        try:
+            yield
+        finally:
+            self.stop(stage)
+
     def start(self, stage: str):
         """
         Start the timer for a specific stage. If the stage doesn't exist,
@@ -322,16 +342,22 @@ def all_times(self) -> Dict[str, List[float]]:
         return all_times

     @contextmanager
-    def new_timer_context(self) -> StagedTimer:
+    def new_timer_context(self, total_inference: bool = True) -> StagedTimer:
         """
         Create a new StagedTimer object and set it as the current context.

+        :param total_inference: If True, measures the entire context as total inference
+            automatically and assumes this is the main inference thread. If False,
+            assumes this is not the main inference thread and will not overwrite
+            any other timers in non-multi/benchmark mode. Default: True
         :return: the new StagedTimer object.
         """
         timer = StagedTimer(enabled=self.enabled)

-        timer.start(InferenceStages.TOTAL_INFERENCE)
-        if self.multi:
+        if total_inference:
+            timer.start(InferenceStages.TOTAL_INFERENCE)
+
+        if self.multi or not total_inference:
             self._timers.append(timer)
         else:
             self._timers = [timer]
@@ -341,4 +367,5 @@ def new_timer_context(self) -> StagedTimer:
         try:
             yield timer
         finally:
-            timer.stop(InferenceStages.TOTAL_INFERENCE)
+            if total_inference:
+                timer.stop(InferenceStages.TOTAL_INFERENCE)
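
Reviewer note (not part of the patch): a minimal sketch of how the new `StagedTimer.time()` context manager composes with the explicit `start()`/`stop()` calls for the stage names introduced above. The stage bodies are placeholders; only the `StagedTimer` API visible in this diff (the `enabled` kwarg, `start`, `stop`, `stages`, and the new `time` context manager) is assumed.

```
from deepsparse.utils.timer import StagedTimer

timer = StagedTimer(enabled=True)

# time() wraps start()/stop(), so the stage is recorded even if the body raises
with timer.time("engine_prompt_prefill"):
    pass  # placeholder for the multi-token prefill work

# explicit start()/stop() still covers a stage that spans an entire loop,
# while each iteration is timed individually with the context manager
timer.start("engine_token_generation")
for _ in range(3):
    with timer.time("engine_token_generation_single"):
        pass  # placeholder for a single autoregressive decode step
timer.stop("engine_token_generation")

print(timer.stages)  # the recorded stage names
```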
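
A second, heavily hedged sketch of how the new engine stages might be read back after a pipeline run. The `text_generation` task name, the `MODEL_STUB` placeholder, and the keyword calling convention are assumptions not confirmed by this diff; only the pipeline's `timer_manager` attribute and `TimerManager.all_times` (a `Dict[str, List[float]]`) are taken from it.

```
from deepsparse import Pipeline

# hypothetical task name and model stub; substitute a real text generation model
pipeline = Pipeline.create(task="text_generation", model_path="MODEL_STUB")
pipeline(sequences="def fibonacci(n):")

# all_times maps stage name -> list of recorded durations in seconds
timings = pipeline.timer_manager.all_times
for stage in (
    "engine_prompt_prefill",
    "engine_prompt_prefill_single",
    "engine_token_generation",
    "engine_token_generation_single",
):
    print(stage, timings.get(stage, []))
```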