From 3959eaccba53c444c5705d600d333cc3d47bc06c Mon Sep 17 00:00:00 2001
From: Abhi Venigalla <77638579+abhi-mosaic@users.noreply.github.com>
Date: Fri, 5 May 2023 18:00:45 -0700
Subject: [PATCH] Update inference benchmarking script (#55)

---
 scripts/inference/benchmarking/README.md     |   4 +-
 scripts/inference/benchmarking/benchmark.py  | 119 ++++++++-----------
 scripts/inference/benchmarking/yamls/1b.yaml |  10 +-
 scripts/inference/benchmarking/yamls/7b.yaml |  12 +-
 4 files changed, 60 insertions(+), 85 deletions(-)

diff --git a/scripts/inference/benchmarking/README.md b/scripts/inference/benchmarking/README.md
index a33290ded3..ad49697d21 100644
--- a/scripts/inference/benchmarking/README.md
+++ b/scripts/inference/benchmarking/README.md
@@ -2,13 +2,13 @@
 This folder provides scripts for benchmarking the inference performance of deep learning models. Currently, we support benchmarking with Deepspeed and Huggingface generate.
 
 ## Scripts
-The repository includes the benchmark.py script, along with associated `.yaml files,` to run benchmarking. The script takes a `.yaml` file as input and outputs the latency (in seconds) and tokens per second for each run. We average over `num_runs=5`, which is defined in the `.yaml` file. Additionally, we iterate over various `batch_sizes`, `input_lengths`, and `output_lengths` to produce varying throughput metrics.
+The repository includes the benchmark.py script, along with associated `.yaml files,` to run benchmarking. The script takes a `.yaml` file as input and outputs the latency (in seconds) and tokens per second for each run. We average over `num_batches=5`, which is defined in the `.yaml` file. Additionally, we iterate over various `batch_sizes`, `input_lengths`, and `output_lengths` to produce varying throughput metrics.
 
 ## Usage
 To use the `benchmark.py` script, you need to provide a `.yaml` file that specifies the model configuration and other parameters such as the path to the model checkpoint and the input data. You can modify the default `.yaml` files provided in the repository or create your own `.yaml` file.
 
 To run the benchmarking script, use the following command:
 
-`python benchmark.py config.yaml`
+`python benchmark.py yamls/1b.yaml`
 
 To run the scripts on [The MosaicML platform](https://www.mosaicml.com/blog/mosaicml-cloud-demo) we've also included scripts and associated `.yaml files` in the `mcloud` folder.
diff --git a/scripts/inference/benchmarking/benchmark.py b/scripts/inference/benchmarking/benchmark.py
index d94a6c1f0d..8209fe21bd 100644
--- a/scripts/inference/benchmarking/benchmark.py
+++ b/scripts/inference/benchmarking/benchmark.py
@@ -1,11 +1,10 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
-import contextlib
 import sys
 import time
+from contextlib import nullcontext
 
-import numpy as np
 import torch
 # You can use this to load the model weights
 from omegaconf import OmegaConf as om
@@ -13,31 +12,41 @@
 from llmfoundry import COMPOSER_MODEL_REGISTRY
 
 
-def get_precision(precision):
-    if precision == 'fp32':
+def get_dtype(dtype):
+    if dtype == 'fp32':
         return torch.float32
-    elif precision == 'fp16':
+    elif dtype == 'fp16':
         return torch.float16
-    elif precision == 'bf16':
+    elif dtype == 'bf16':
         return torch.bfloat16
     else:
         raise NotImplementedError(
-            f'Precision of type {precision} is not supported. '
-            f'We only support fp32, amp_fp16, and amp_bf16 currently')
+            f'dtype {dtype} is not supported. '
+            f'We only support fp32, fp16, and bf16 currently')
 
 
-def compare_precision(precision, param_dtype):
-    if precision != param_dtype:
+def compare_dtype(dtype, param_dtype):
+    if dtype != param_dtype:
         raise ValueError(
-            f'Precision type is: {precision} but model dtype is: {param_dtype}. '
-            f"The expected precision and model precision don't match.")
+            f'dtype type is: {dtype} but model dtype is: {param_dtype}. '
+            f"The expected dtype and model dtype don't match.")
 
 
 def main(config):
-    model_dtype = get_precision(config.model_dtype)
-    autocast_precision = None
-    if config.autocast_precision is not None:
-        autocast_precision = get_precision(config.autocast_precision)
+    if config.device is not None:
+        device = config.device
+    else:
+        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+    model_dtype = get_dtype(config.model_dtype)
+    print(f'Using device={device} and dtype={model_dtype}...')
+
+    if config.autocast_dtype is not None:
+        autocast_dtype = get_dtype(config.autocast_dtype)
+        autocast_context = torch.autocast(device, autocast_dtype)
+        print(f'Using autocast with dtype={autocast_dtype}...')
+    else:
+        autocast_context = nullcontext()
+        print('NOT using autocast...')
 
     inference_config = {
         'replace_with_kernel_inject': True,
@@ -51,9 +60,7 @@ def main(config):
 
     composer_model = COMPOSER_MODEL_REGISTRY[config.model.name](
         config.model, config.tokenizer)
-
     model = composer_model.model
-
     model.eval()
 
     if config.use_deepspeed:
@@ -62,90 +69,58 @@
 
         # Checking if deepspeed casts dtypes correctly
         for _, p in model.named_parameters():
-            compare_precision(model_dtype, p.dtype)
+            compare_dtype(model_dtype, p.dtype)
             break
     else:
-        model.to(torch.cuda.current_device())
-        model.to(model_dtype)
+        model.to(device=device, dtype=model_dtype)
 
     n_params = sum(p.numel() for p in model.parameters())
     print('n_params is: ', n_params)
 
-    print('name, latency (s), tokens / s, output token time (ms)')
+    print(
+        'name, latency (s), throughput (tokens/s), latency_per_sequence_output_token (ms)'
+    )
     print('=' * 75)
 
-    stats = []
     for batch_size in config.batch_sizes:
         for input_length in config.input_lengths:
             for output_length in config.output_lengths:
-                times = []
-
-                batch = torch.randint(
-                    0,
-                    config.model.vocab_size - 1,
-                    size=(
-                        batch_size,
-                        input_length)).to(f'cuda:{torch.cuda.current_device()}')
+                batch = torch.randint(0,
+                                      config.model.vocab_size - 1,
+                                      size=(batch_size,
+                                            input_length)).to(device)
 
                 # We're just going to have generate eos, padding tokens be
                 # ignored by HF generate
                 batch = batch.to(torch.long)
 
                 attention_mask = torch.ones_like(batch)
 
-                torch.cuda.synchronize()
-
-                for i in range(config.num_runs + 1):
-                    start_time = time.time()
+                start_time = 0
+                for i in range(config.num_batches + config.num_warmup_batches):
+                    if i == config.num_warmup_batches:
+                        torch.cuda.synchronize()
+                        start_time = time.time()
                     with torch.no_grad():
-                        precision_context = contextlib.nullcontext()
-                        if autocast_precision is not None and autocast_precision in [
-                                'fp16', 'bf16'
-                        ]:
-                            precision_context = torch.cuda.amp.autocast(
-                                True, dtype=autocast_precision)
-
-                        with precision_context:
+                        with autocast_context:
                             model.generate(batch,
                                            max_new_tokens=output_length,
-                                           use_cache=True,
+                                           use_cache=config.use_cache,
                                            attention_mask=attention_mask,
                                            eos_token_id=None,
                                            pad_token_id=None)
 
-                    torch.cuda.synchronize()
-
-                    # We noticed there sometimes might be a small bit of startup time
-                    # so we only start to benchmark after some number of batches
-                    if i >= config.num_warmup_batches:
-                        times.append(time.time() - start_time)
+                torch.cuda.synchronize()
+                mean_time = (time.time() - start_time) / config.num_batches
 
                 num_output_tokens = output_length * batch_size
-                mean_time = np.mean(times)
-                tokens_per_second = num_output_tokens / float(mean_time)
-                ms_per_seq_output_token = float(
-                    mean_time) * 1000 / num_output_tokens
-
-                result = (
-                    f'{config.benchmark_name}_{batch_size}_{input_length}_{output_length}',
-                    f'{mean_time:.3f}', f'{tokens_per_second:.3f}',
-                    f'{ms_per_seq_output_token:.3f}')
-
-                run_name, latency, tokens_per_second, ms_per_seq_output_token = result
+                tokens_per_second = num_output_tokens / mean_time
+                ms_per_seq_output_token = mean_time * 1000 / output_length
 
+                run_name = f'{config.benchmark_name}_{batch_size}_{input_length}_{output_length}'
                 print(
-                    f'{run_name}, {latency}, {tokens_per_second}, {ms_per_seq_output_token}'
+                    f'{run_name}, {mean_time:.3f}, {tokens_per_second:.3f}, {ms_per_seq_output_token:.3f}'
                 )
 
-                stats.append(result)
-
-    print('=' * 75)
-    print('name, latency (s), tokens / s, output token time (ms)')
-    for val in stats:
-        run_name, latency, tokens_per_second, ms_per_seq_output_token = val
-        print(
-            f'{run_name}, latency (s) {latency}, tokens per second {tokens_per_second}, output token time (ms) {ms_per_seq_output_token}'
-        )
-
 
 if __name__ == '__main__':
     yaml_path, args_list = sys.argv[1], sys.argv[2:]
diff --git a/scripts/inference/benchmarking/yamls/1b.yaml b/scripts/inference/benchmarking/yamls/1b.yaml
index c63552871c..f94aa3d806 100644
--- a/scripts/inference/benchmarking/yamls/1b.yaml
+++ b/scripts/inference/benchmarking/yamls/1b.yaml
@@ -8,7 +8,6 @@ tokenizer:
   name: ${tokenizer_name}
   kwargs:
     model_max_length: ${max_seq_len}
-    non_eos_token_id: 17
 
 model:
   name: mpt_causal_lm
@@ -27,14 +26,15 @@ model:
 
   attn_config:
     attn_impl: triton
 
-autocast_precision: bf16
+device: null
 model_dtype: bf16
+autocast_dtype: null
+use_deepspeed: false
 
 batch_sizes: [1, 2, 4, 8, 16, 32, 64]
 input_lengths: [128]
 output_lengths: [8]
-num_runs: 5
+use_cache: true
+num_batches: 5
 num_warmup_batches: 3
-
-use_deepspeed: false
diff --git a/scripts/inference/benchmarking/yamls/7b.yaml b/scripts/inference/benchmarking/yamls/7b.yaml
index ee6ff34f15..55e9ae8413 100644
--- a/scripts/inference/benchmarking/yamls/7b.yaml
+++ b/scripts/inference/benchmarking/yamls/7b.yaml
@@ -8,7 +8,6 @@ tokenizer:
   name: ${tokenizer_name}
   kwargs:
     model_max_length: ${max_seq_len}
-    non_eos_token_id: 17
 
 model:
   name: mpt_causal_lm
@@ -27,14 +26,15 @@ model:
 
   attn_config:
     attn_impl: triton
 
-autocast_precision: bf16
-model_dtype: fp32
+device: null
+model_dtype: bf16
+autocast_dtype: null
+use_deepspeed: false
 
 batch_sizes: [1, 2, 4, 8, 16, 32, 64]
 input_lengths: [128]
 output_lengths: [8]
-num_runs: 5
+use_cache: true
+num_batches: 5
 num_warmup_batches: 3
-
-use_deepspeed: false
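
Usage note (not part of the patch itself): the updated benchmark.py times generation by running config.num_warmup_batches untimed generate calls, synchronizing CUDA, and then averaging wall-clock time over config.num_batches timed calls. A minimal standalone sketch of that timing pattern follows; run_generate is a hypothetical placeholder for the real model.generate(...) call, and the function name benchmark_generate is illustrative only.

# Sketch of the warmup-then-average timing pattern used above (illustrative only).
import time

import torch


def benchmark_generate(run_generate, num_warmup_batches=3, num_batches=5):
    start_time = 0.0
    for i in range(num_batches + num_warmup_batches):
        if i == num_warmup_batches:
            # Start the clock only once the warmup batches have finished.
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            start_time = time.time()
        with torch.no_grad():
            run_generate()  # stand-in for model.generate(batch, ...)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    # Mean latency per timed batch, matching the updated benchmark.py.
    return (time.time() - start_time) / num_batches

Throughput in tokens/s then follows as batch_size * output_length / mean_time, and per-output-token latency as mean_time * 1000 / output_length, which are the columns the script prints.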