
Commit

[TextGeneration][Timer] text gen specific timings + improved timing tooling
Benjamin committed Jul 13, 2023
1 parent 8a26435 commit 83b1412
Showing 2 changed files with 76 additions and 36 deletions.
77 changes: 45 additions & 32 deletions src/deepsparse/transformers/pipelines/text_generation.py
@@ -30,6 +30,12 @@
__all__ = ["TextGenerationPipeline"]


PROMPT_PREFILL = "engine_prompt_prefill"
PROMPT_PREFILL_SINGLE = "engine_prompt_prefill_single"
TOKEN_GENERATION = "engine_token_generation"
TOKEN_GENERATION_SINGLE = "engine_token_generation_single"


class TextGenerationInput(BaseModel):
    sequences: Union[str, List[str]] = Field(
        description="The input sequences to generate the text from.",
@@ -304,35 +310,41 @@ def engine_forward(
            sequence of generated tokens and a sequence
            of logits for each generated token
        """
        if not self.multitoken_engine.kv_cache_enabled:
            tokens, prompt_logits = self.multitoken_engine(engine_inputs)
            return numpy.array([tokens]), prompt_logits

        else:
            # run the prompt through
            tokens, prompt_logits = self.prompt_inference(engine_inputs)

            # create the generated output
            max_tokens = (
                self.max_generated_tokens
                if self.max_generated_tokens and self.max_generated_tokens > 0
                else 100 * self.sequence_length
            )  # set safety for absolute max generation

            generated_tokens = [tokens[-1]]
            generated_logits = prompt_logits

            while len(generated_tokens) < max_tokens:
                (
                    token,
                    logits,
                ) = self.autoregressive_inference(tokens)
                tokens.append(token)
                generated_tokens.append(token)
                generated_logits.append(logits)

                if token == self.tokenizer.eos_token_id and not self.force_max_tokens:
                    break
        # engine_forward is always called in a threadpool due to batch splitting
        # as such, a new context needs to be created since we are no longer in the
        # main thread. That is why `engine_` is prepended to each of the timer phase
        # names in this context
        with self.timer_manager.new_timer_context(total_inference=False) as timer:
            if not self.multitoken_engine.kv_cache_enabled:
                tokens, prompt_logits = self.multitoken_engine(engine_inputs)
                return numpy.array([tokens]), prompt_logits

            else:
                # run the prompt through
                with timer.time(PROMPT_PREFILL):
                    tokens, prompt_logits = self.prompt_inference(engine_inputs)

                # create the generated output
                max_tokens = (
                    self.max_generated_tokens
                    if self.max_generated_tokens and self.max_generated_tokens > 0
                    else 100 * self.sequence_length
                )  # set safety for absolute max generation

                generated_tokens = [tokens[-1]]
                generated_logits = prompt_logits

                timer.start(TOKEN_GENERATION)
                while len(generated_tokens) < max_tokens:
                    with timer.time(TOKEN_GENERATION_SINGLE):
                        token, logits = self.autoregressive_inference(tokens)
                    tokens.append(token)
                    generated_tokens.append(token)
                    generated_logits.append(logits)

                    if token == self.tokenizer.eos_token_id and not self.force_max_tokens:
                        break
                timer.stop(TOKEN_GENERATION)

        return numpy.array([generated_tokens]), numpy.concatenate(
            generated_logits, axis=1
@@ -388,9 +400,10 @@ def prompt_inference(

        for token in tokens[num_tokens_processed:]:
            run_tokens.append(token)
            new_token, new_logits = self.autoregressive_inference(
                run_tokens, shift_positions_by_one=not bool(num_tokens_processed)
            )
            with self.timer_manager.current.time(PROMPT_PREFILL_SINGLE):
                new_token, new_logits = self.autoregressive_inference(
                    run_tokens, shift_positions_by_one=not bool(num_tokens_processed)
                )
            prompt_logits.append(new_logits)

            tokens.append(new_token)
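Taken together, the new phases nest per request: engine_prompt_prefill wraps the per-token prefill steps, and engine_token_generation brackets the per-token decode loop. Below is a minimal, self-contained sketch of that layout — illustrative only, with a simple timed() helper standing in for StagedTimer rather than the actual DeepSparse implementation:

```python
import time
from contextlib import contextmanager

times = {}  # stage name -> list of recorded durations


@contextmanager
def timed(stage):
    # stand-in for StagedTimer.time(): record elapsed wall time for a stage
    start = time.perf_counter()
    try:
        yield
    finally:
        times.setdefault(stage, []).append(time.perf_counter() - start)


with timed("engine_prompt_prefill"):        # whole prompt pass
    for _ in range(4):                      # e.g. 4 prompt tokens, run one-by-one
        with timed("engine_prompt_prefill_single"):
            time.sleep(0.001)               # stand-in for autoregressive_inference

with timed("engine_token_generation"):      # whole decode loop
    for _ in range(8):                      # e.g. 8 generated tokens
        with timed("engine_token_generation_single"):
            time.sleep(0.001)

# times now holds 1 prefill total, 4 prefill steps, 1 generation total,
# and 8 per-token generation entries
```

This gives both aggregate timings (how long prefill and decode took overall) and per-token timings (latency distribution across steps) from the same instrumentation.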
35 changes: 31 additions & 4 deletions src/deepsparse/utils/timer.py
@@ -114,6 +114,26 @@ def has_stage(self, stage: str) -> bool:
"""
return stage in self.stages

    @contextmanager
    def time(self, stage: str):
        """
        Context manager to record the time for a stage in the given context.
        example:
        ```
        with timer.time(STAGE_NAME):
            # do something...
        ```

        :param stage: the name of the stage to time
        """
        self.start(stage)

        try:
            yield
        finally:
            self.stop(stage)

    def start(self, stage: str):
        """
        Start the timer for a specific stage. If the stage doesn't exist,
@@ -322,16 +342,22 @@ def all_times(self) -> Dict[str, List[float]]:
        return all_times

    @contextmanager
    def new_timer_context(self) -> StagedTimer:
    def new_timer_context(self, total_inference: bool = True) -> StagedTimer:
        """
        Create a new StagedTimer object and set it as the current context.

        :param total_inference: If True, measures the entire context as total inference
            automatically and assumes this is the main inference thread. If False,
            assumes this is not the main inference thread and will not overwrite
            any other timers in non-multi/benchmark mode. Default True
        :return: the new StagedTimer object.
        """
        timer = StagedTimer(enabled=self.enabled)
        timer.start(InferenceStages.TOTAL_INFERENCE)

        if self.multi:
        if total_inference:
            timer.start(InferenceStages.TOTAL_INFERENCE)

        if self.multi or not total_inference:
            self._timers.append(timer)
        else:
            self._timers = [timer]
@@ -341,4 +367,5 @@ def new_timer_context(self) -> StagedTimer:
        try:
            yield timer
        finally:
            timer.stop(InferenceStages.TOTAL_INFERENCE)
            if total_inference:
                timer.stop(InferenceStages.TOTAL_INFERENCE)

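To illustrate the total_inference flag, here is a minimal, self-contained sketch of the semantics above. MiniStagedTimer and MiniTimerManager are hypothetical stand-ins, not DeepSparse internals: a context created with total_inference=False is appended alongside existing timers (rather than replacing them in single-stream mode) and is never bracketed with the TOTAL_INFERENCE stage.

```python
import time
from contextlib import contextmanager


class MiniStagedTimer:
    """Stand-in for StagedTimer: start/stop accumulate per-stage durations."""

    def __init__(self):
        self._starts = {}
        self.times = {}

    def start(self, stage):
        self._starts[stage] = time.perf_counter()

    def stop(self, stage):
        elapsed = time.perf_counter() - self._starts.pop(stage)
        self.times.setdefault(stage, []).append(elapsed)


class MiniTimerManager:
    """Stand-in for TimerManager, showing only the new_timer_context logic."""

    def __init__(self, multi=False):
        self.multi = multi
        self._timers = []

    @contextmanager
    def new_timer_context(self, total_inference=True):
        timer = MiniStagedTimer()
        if total_inference:
            timer.start("total_inference")
        if self.multi or not total_inference:
            self._timers.append(timer)  # keep other threads' timers intact
        else:
            self._timers = [timer]      # single-stream: replace previous timer
        try:
            yield timer
        finally:
            if total_inference:
                timer.stop("total_inference")


manager = MiniTimerManager()
with manager.new_timer_context() as main_timer:  # main inference thread
    # worker thread (e.g. engine_forward under batch splitting) opens its own
    # context without clobbering the main thread's timer
    with manager.new_timer_context(total_inference=False) as worker_timer:
        worker_timer.start("engine_token_generation")
        time.sleep(0.01)
        worker_timer.stop("engine_token_generation")

print(len(manager._timers))                 # 2 -- the worker appended
print(main_timer.times["total_inference"])  # only the outer context was bracketed
```

This is why engine_forward above passes total_inference=False: it runs in a threadpool, so each call gets its own appended StagedTimer instead of overwriting the main thread's total-inference measurement.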