Skip to content

Commit

Permalink
Update TensorRT-LLM (#1554)
Browse files Browse the repository at this point in the history
  • Loading branch information
kaiyux authored May 7, 2024
1 parent 06c0e9b commit 89ba1b1
Show file tree
Hide file tree
Showing 270 changed files with 10,605 additions and 3,111 deletions.
2 changes: 2 additions & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
*.a filter=lfs diff=lfs merge=lfs -text
*.lib filter=lfs diff=lfs merge=lfs -text
*.so filter=lfs diff=lfs merge=lfs -text
*.dll filter=lfs diff=lfs merge=lfs -text
2 changes: 1 addition & 1 deletion .github/ISSUE_TEMPLATE/bug_report.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ body:
- Libraries
- TensorRT-LLM branch or tag (e.g., main, v0.7.1)
- TensorRT-LLM commit (if known)
- Versions of TensorRT, AMMO, CUDA, cuBLAS, etc. used
- Versions of TensorRT, Modelopt, CUDA, cuBLAS, etc. used
- Container used (if running TensorRT-LLM in a container)
- NVIDIA driver version
- OS (Ubuntu 22.04, CentOS 7, Windows 10)
Expand Down
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ __pycache__/
*.nsys-rep
.VSCodeCounter
build*/
*.so
*.egg-info/
.coverage
*.csv
Expand Down Expand Up @@ -34,6 +33,7 @@ tensorrt_llm/bindings.pyi
tensorrt_llm/bindings/*.pyi
*docs/cpp_docs*
*docs/source/_cpp_gen*
*.swp

# Testing
.coverage.*
Expand Down
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ TensorRT-LLM

[![Documentation](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](https://nvidia.github.io/TensorRT-LLM/)
[![python](https://img.shields.io/badge/python-3.10.12-green)](https://www.python.org/downloads/release/python-31012/)
[![cuda](https://img.shields.io/badge/cuda-12.3-green)](https://developer.nvidia.com/cuda-downloads)
[![trt](https://img.shields.io/badge/TRT-9.3-green)](https://developer.nvidia.com/tensorrt)
[![version](https://img.shields.io/badge/release-0.9.0-green)](./setup.py)
[![cuda](https://img.shields.io/badge/cuda-12.4.0-green)](https://developer.nvidia.com/cuda-downloads)
[![trt](https://img.shields.io/badge/TRT-10.0.1-green)](https://developer.nvidia.com/tensorrt)
[![version](https://img.shields.io/badge/release-0.10.0.dev-green)](./setup.py)
[![license](https://img.shields.io/badge/license-Apache%202-blue)](./LICENSE)

[Architecture](./docs/source/architecture/overview.md)   |   [Results](./docs/source/performance/perf-overview.md)   |   [Examples](./examples/)   |   [Documentation](./docs/source/)
Expand Down
7 changes: 3 additions & 4 deletions benchmarks/cpp/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,6 @@ Given a `static_emulated_batch_size` of `n` the server will wait for `n` request
```
python prepare_dataset.py \
--output tokens-fixed-lengths.json \
--request-rate -1 \
--time-delay-dist constant \
--tokenizer <path/to/tokenizer> \
token-norm-dist \
--num-requests 128 \
Expand All @@ -184,6 +182,7 @@ Take GPT-350M as an example for single GPU with static batching
./benchmarks/gptManagerBenchmark \
--engine_dir ../../examples/gpt/trt_engine/gpt2/fp16/1-gpu/ \
--type IFB \
--request-rate -1 \
--static_emulated_batch_size 32 \
--static_emulated_timeout 100 \
--dataset ../../benchmarks/cpp/tokens-fixed-lengths.json
Expand Down Expand Up @@ -212,6 +211,7 @@ PP=1
MAX_LEN=1024
MAX_BATCH=32
MAX_LORA_RANK=32
NUM_LORA_MODS=7
SOURCE_LORA=chinese-llama-2-lora-13b
CPP_LORA=chinese-llama-2-lora-13b-cpp
Expand Down Expand Up @@ -241,10 +241,9 @@ NUM_LORAS=(8 16 24 32 64 128 256)
NUM_REQUESTS=1024
# Convert LoRA to cpp format
python examples/gpt/nemo_lora_convert.py \
python examples/hf_lora_convert.py \
-i $SOURCE_LORA \
--storage-type $DTYPE \
--write-cpp-runtime-tensors \
-o $CPP_LORA
# Prepare datasets
Expand Down
52 changes: 48 additions & 4 deletions benchmarks/cpp/gptManagerBenchmark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ struct BenchmarkParams
bool enableExpDelays{false};
std::optional<float> requestRate{std::nullopt};
int randomSeed = 430;
std::optional<int> maxAttentionWindow{std::nullopt};

// lora / peft params
std::optional<std::string> loraDir{std::nullopt};
Expand Down Expand Up @@ -746,8 +747,8 @@ class ExecutorServer

texec::SchedulerConfig schedulerConfig(batch_scheduler::batchManagerToExecSchedPolicy(schedulerPolicy));
texec::KvCacheConfig kvCacheConfig(benchmarkParams.enableBlockReuse, benchmarkParams.maxTokensInPagedKvCache,
std::nullopt, std::nullopt, benchmarkParams.freeGpuMemoryFraction, benchmarkParams.kvHostCacheSize,
benchmarkParams.kvOnboardBlocks);
benchmarkParams.maxAttentionWindow, std::nullopt, benchmarkParams.freeGpuMemoryFraction,
benchmarkParams.kvHostCacheSize, benchmarkParams.kvOnboardBlocks);
texec::PeftCacheConfig peftCacheConfig(0, benchmarkParams.loraDeviceNumModLayers, 8, 64, 4, 4, 4, 24, 8,
std::nullopt, benchmarkParams.loraHostCacheSize);
texec::ExecutorConfig executorConfig(
Expand Down Expand Up @@ -909,6 +910,16 @@ class GptServer
mWorkItemsQueue.clear();
}

std::string getLayerProfileInfo()
{
return mBatchManager->getLayerProfileInfo();
}

void setLayerProfiler()
{
return mBatchManager->setLayerProfiler();
}

void enqueue(std::shared_ptr<InferenceRequest> const& request)
{
TLLM_CHECK(request != nullptr);
Expand Down Expand Up @@ -1267,7 +1278,7 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
std::chrono::milliseconds waitSleep, bool returnContextLogits, bool returnGenerationLogits,
std::optional<SizeType> const staticEmulatedBatchSize, std::optional<std::chrono::milliseconds> const batchTimeout,
bool logIterationData, bool excludeInputInOutput, std::string const& responsesJsonFile,
std::optional<SizeType> const maxPromptLen)
std::optional<SizeType> const maxPromptLen, bool dumpProfile)
{
TrtGptModelOptionalParams optionalParams;

Expand All @@ -1279,6 +1290,10 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
{
optionalParams.kvCacheConfig.freeGpuMemoryFraction = benchmarkParams.freeGpuMemoryFraction;
}
if (benchmarkParams.maxAttentionWindow)
{
optionalParams.kvCacheConfig.maxAttentionWindow = benchmarkParams.maxAttentionWindow;
}
optionalParams.kvCacheConfig.enableBlockReuse = benchmarkParams.enableBlockReuse;
optionalParams.enableChunkedContext = benchmarkParams.enableChunkedContext;
optionalParams.enableTrtOverlap = benchmarkParams.enableTrtOverlap;
Expand Down Expand Up @@ -1391,6 +1406,23 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
recorder->report();
recorder->writeOpMetricsToCsv();
recorder->dumpResponseSeqs();
if (dumpProfile)
{
// Do per-layer profiling after normal benchmarking to avoid introducing perf overhead.
gptServer->resetBatchDeadline();
gptServer->setLayerProfiler();
for (std::size_t i = 0; i < numSamples; ++i)
{
auto request = makeRequest(i + 1, samples[i], benchmarkParams.streaming, beamWidthTensor, eosIdTensor,
padIdTensor, bufferManager, returnContextLogitsFlagTensor, returnGenerationLogitsFlagTensor);
gptServer->enqueue(request);
}
gptServer->waitForEmpty();
if (worldConfig.getRank() == 0)
{
printf("[BENCHMARK] Per layer performance profile\n%s\n", gptServer->getLayerProfileInfo().c_str());
}
}
// Send terminateReqId to terminate servers on all ranks
// Server on rank 0 will broadcast the terminate signal to other servers on multi-GPU cases
gptServer->enqueue(std::make_shared<InferenceRequest>(terminateReqId));
Expand Down Expand Up @@ -1554,6 +1586,7 @@ int main(int argc, char* argv[])
"eos_id", "Specify the end-of-sequence token id.", cxxopts::value<TokenIdType>()->default_value("-1"));
options.add_options()("pad_id", "Specify the padding token id.", cxxopts::value<TokenIdType>());
options.add_options()("max_tokens_in_paged_kvcache", "Max tokens in paged K-V Cache.", cxxopts::value<int>());
options.add_options()("max_attention_window", "Max KV cache length per sequence", cxxopts::value<int>());
options.add_options()(
"random_seed", "integer random seed for exponential time delays.", cxxopts::value<int>()->default_value("420"));
options.add_options()(
Expand Down Expand Up @@ -1614,6 +1647,8 @@ int main(int argc, char* argv[])
options.add_options()(
"max_prompt_len", "Truncate all prompts from dataset to the length specified.", cxxopts::value<SizeType>());

options.add_options()("dump_profile", "Print profile information per layer.", cxxopts::value<bool>());

auto result = options.parse(argc, argv);

if (result.count("help"))
Expand Down Expand Up @@ -1674,6 +1709,12 @@ int main(int argc, char* argv[])
benchmarkParams.maxTokensInPagedKvCache = result["max_tokens_in_paged_kvcache"].as<int>();
}

// Argument: Max KV cache length
if (result.count("max_attention_window"))
{
benchmarkParams.maxAttentionWindow = result["max_attention_window"].as<int>();
}

if (result.count("random_seed"))
{
benchmarkParams.randomSeed = result["random_seed"].as<int>();
Expand Down Expand Up @@ -1811,6 +1852,9 @@ int main(int argc, char* argv[])
return 1;
}

// Argument: dump profile
bool dumpProfile = result["dump_profile"].as<bool>();

initTrtLlmPlugins(logger.get());

if (api == "gptManager")
Expand All @@ -1821,7 +1865,7 @@ int main(int argc, char* argv[])
maxNumSamples, beamWidth, result["warm_up"].as<int>(), eosId, padId, benchmarkParams, schedulerPolicy,
waitSleep, returnContextLogits, returnGenerationLogits, staticEmulatedBatchSize, batchTimeout,
logIterationData, result["exclude_input_in_output_seq"].as<bool>(),
result["responses_json_file"].as<std::string>(), maxPromptLen);
result["responses_json_file"].as<std::string>(), maxPromptLen, dumpProfile);
}
catch (std::exception const& e)
{
Expand Down
46 changes: 44 additions & 2 deletions benchmarks/cpp/gptSessionBenchmark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ size_t monitorMemory(std::atomic_bool& done)
void benchmarkGptSession(std::filesystem::path const& dataPath, std::vector<int> const& batchSizes, int beamWidth,
std::vector<std::vector<int>> const& inOutLen, std::shared_ptr<nvinfer1::ILogger> const& logger, int warmUp,
int numRuns, int duration, GptSession::Config& sessionConfig, bool cudaGraphMode, bool printAllLogits,
bool disableForceMaxTokens, bool dumpLayerInfo)
bool disableForceMaxTokens, bool dumpLayerInfo, bool dumpProfile)
{
std::filesystem::path jsonFileName = dataPath / "config.json";
auto const json = GptJsonConfig::parse(jsonFileName);
Expand Down Expand Up @@ -298,6 +298,46 @@ void benchmarkGptSession(std::filesystem::path const& dataPath, std::vector<int>
<< std::endl;
}
}
// Do per-layer profiling after normal benchmarking to avoid introducing perf overhead.
if (dumpProfile)
{
session.setLayerProfiler();
iterIdx = 0;

while (iterIdx < numRuns)
{
auto const start = std::chrono::steady_clock::now();
SizeType numSteps = 0;
generationOutput.onTokenGenerated
= [&numSteps, maxNewTokens](GenerationOutput::TensorPtr const& outputIds, SizeType step,
bool finished) { ++numSteps; };
session.generate(generationOutput, generationInput, samplingConfig, generationProfiler);
bufferManager.getStream().synchronize();
auto const end = std::chrono::steady_clock::now();

iterIdx += 1;
float latency = std::chrono::duration<float, std::milli>(end - start).count();
curDuration += latency;
latencies.emplace_back(latency);
generationTimes.emplace_back(generationProfiler->getElapsedTimeMs());

bool durationLimitReached{curDuration / 1000 >= duration};
if (worldConfig.getSize() > 1)
{
bool result{false};
comm.allreduce(&durationLimitReached, &result, 1, tmpi::MpiType::kBOOL, tmpi::MpiOp::LOR);
durationLimitReached = result;
}
if (durationLimitReached)
{
break;
}
}
if (worldConfig.getRank() == 0)
{
printf("%s\n", session.getLayerProfileInfo().c_str());
}
}
}
catch (std::runtime_error& e)
{
Expand Down Expand Up @@ -377,6 +417,7 @@ int main(int argc, char* argv[])
options.add_options()("print_all_logits", "Print all context and generation logits.");
options.add_options()("disable_force_max_tokens", "Disable force the engine generating new max_tokens.");
options.add_options()("dump_layer_info", "Print layer information of the engine to console.");
options.add_options()("dump_profile", "Print profile information per layer.");

auto result = options.parse(argc, argv);

Expand Down Expand Up @@ -487,14 +528,15 @@ int main(int argc, char* argv[])
auto printAllLogits = result.count("print_all_logits") > 0;
auto disableForceMaxTokens = result.count("disable_force_max_tokens") > 0;
auto dumpLayerInfo = result.count("dump_layer_info") > 0;
auto dumpProfile = result.count("dump_profile") > 0;

initTrtLlmPlugins(logger.get());

try
{
benchmarkGptSession(result["engine_dir"].as<std::string>(), batchSizes, beamWidth, inOutLen, logger,
result["warm_up"].as<int>(), result["num_runs"].as<int>(), result["duration"].as<int>(), sessionConfig,
enableCudaGraph, printAllLogits, disableForceMaxTokens, dumpLayerInfo);
enableCudaGraph, printAllLogits, disableForceMaxTokens, dumpLayerInfo, dumpProfile);
}
catch (std::exception const& e)
{
Expand Down
8 changes: 8 additions & 0 deletions benchmarks/python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,11 @@ mpirun -n 8 python benchmark.py \
--batch_size "1;8;64" \
--input_output_len "60,20;128,20"
```

Note: Building multi-GPU engines in parallel could be a heavy workload for the CPU system. Tuning `mpirun --map-by <XXX>` option on your system may achieve significant boost in build time, for example:
```
mpirun --map-by socket -n 8 python build.py \
--model gpt_175b \
--mode ootb \
--quantization fp8
```
2 changes: 2 additions & 0 deletions benchmarks/python/allowed_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ class BuildConfig:
layer_types: List[str] = field(default_factory=list)
rnn_hidden_size: int = 0
logits_soft_cap: float = 0.0
opt_batch_size: Optional[int] = None
opt_num_tokens: Optional[int] = None


@dataclass
Expand Down
55 changes: 52 additions & 3 deletions benchmarks/python/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,25 @@ def parse_arguments():
help=
"Print layer information of the engine to console (default = disabled)")

parser.add_argument(
'--opt_batch_size',
type=int,
default=None,
help=
"If opt_batch_size option is specified, it will override the opt batch size."
"This flag only takes effect when `--mode=ootb` is added. For other modes, please use --opt_num_tokens to replace it."
)

parser.add_argument(
'--opt_num_tokens',
type=int,
default=None,
help="It equals to max_batch_size*max_beam_width by default, set this "
"value as close as possible to the actual number of tokens on your workload. "
"Note that this argument might be removed in the future."
"This flag only takes effect when `--mode` is not `ootb`. For ootb mode, please use --opt_batch_size to replace it."
)

return parser.parse_args()


Expand Down Expand Up @@ -334,9 +353,6 @@ def main(args):
if args.build_only:
return

if args.dump_profile and benchmark_profiler is not None:
benchmark_profiler.set_recording_perf_profile(True)

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
benchmarker.print_report_header(args.csv,
Expand Down Expand Up @@ -432,6 +448,39 @@ def main(args):
csv=args.csv,
benchmark_profiler=benchmark_profiler)

# Rerun for dumping profile per layer.
if args.dump_profile and benchmark_profiler is not None:
benchmark_profiler.set_recording_perf_profile(True)
logger.info(f'Dump profile information per layer')
iter_idx = 0
try:
# Warm up
for _ in range(args.warm_up):
benchmarker.run(inputs, config)
if benchmark_profiler is not None:
benchmark_profiler.clean()
benchmark_profiler.start()
cur_duration = 0
start_time = time()
while iter_idx < args.num_runs or cur_duration < args.duration:
start.record()
benchmarker.run(inputs,
config,
benchmark_profiler=benchmark_profiler)
end.record()
torch.cuda.synchronize()
latencies.append(start.elapsed_time(end))
iter_idx += 1
cur_duration = round(time() - start_time, 3)
benchmarker.report_profiler(
benchmark_profiler=benchmark_profiler)
except Exception as e:
logger.error("Found exception during benchmarking",
e.with_traceback())
if not disable_mem_monitor:
memory_monitor.kill()
raise e


if __name__ == '__main__':
mp.set_start_method('spawn')
Expand Down
Loading

0 comments on commit 89ba1b1

Please sign in to comment.