Automatic migration of TRTLLM runtime configuration (#1279)
* Migrate runtime configs from the build key to the runtime key

* Bump version to 0.9.56rc2

* Update

* Update migration logic
joostinyi authored Dec 11, 2024
1 parent 7ea6f17 commit 90169ed
Showing 5 changed files with 203 additions and 118 deletions.
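In effect, this change lets TRTLLMConfiguration accept runtime-tuning fields that older configs placed under trt_llm.build (kv_cache_free_gpu_mem_fraction, enable_chunked_context, batch_scheduler_policy, request_default_max_tokens, total_token_limit) and migrates them to trt_llm.runtime with a deprecation warning. A minimal sketch of the intended behavior, with illustrative values modeled on the new test fixtures below:

from truss.base.trt_llm_config import TRTLLMConfiguration

# Deprecated layout: a runtime field placed under `build`.
config = TRTLLMConfiguration(
    build={
        "base_model": "llama",
        "max_seq_len": 2048,
        "max_batch_size": 512,
        "checkpoint_repository": {"source": "HF", "repo": "meta/llama4-500B"},
        "kv_cache_free_gpu_mem_fraction": 0.1,  # runtime field in the old location
    },
    runtime={},
)
# The before-validator moves the deprecated field into the runtime config.
assert config.runtime.kv_cache_free_gpu_mem_fraction == 0.1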
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "truss"
version = "0.9.56rc1"
version = "0.9.56rc2"
description = "A seamless bridge from model development to model delivery"
license = "MIT"
readme = "README.md"
37 changes: 35 additions & 2 deletions truss/base/trt_llm_config.py
@@ -1,14 +1,16 @@
from __future__ import annotations

import json
+import logging
import warnings
from enum import Enum
-from typing import Optional
+from typing import Any, Optional

from huggingface_hub.errors import HFValidationError
from huggingface_hub.utils import validate_repo_id
-from pydantic import BaseModel, PydanticDeprecatedSince20, validator
+from pydantic import BaseModel, PydanticDeprecatedSince20, model_validator, validator

+logger = logging.getLogger(__name__)
# Suppress Pydantic V1 warnings, because we have to use it for backwards compat.
warnings.filterwarnings("ignore", category=PydanticDeprecatedSince20)

@@ -107,6 +109,9 @@ class TrussTRTLLMBuildConfiguration(BaseModel):
    num_builder_gpus: Optional[int] = None
    speculator: Optional[TrussSpeculatorConfiguration] = None

+    class Config:
+        extra = "forbid"
+
    @validator("max_beam_width")
    def check_max_beam_width(cls, v: int):
        if isinstance(v, int):
@@ -198,6 +203,34 @@ class TRTLLMConfiguration(BaseModel):
    runtime: TrussTRTLLMRuntimeConfiguration = TrussTRTLLMRuntimeConfiguration()
    build: TrussTRTLLMBuildConfiguration

+    @model_validator(mode="before")
+    @classmethod
+    def migrate_runtime_fields(cls, data: Any) -> Any:
+        extra_runtime_fields = {}
+        valid_build_fields = {}
+        for key, value in data.get("build").items():
+            if key in TrussTRTLLMBuildConfiguration.__annotations__:
+                valid_build_fields[key] = value
+            else:
+                if key in TrussTRTLLMRuntimeConfiguration.__annotations__:
+                    logger.warning(f"Found runtime.{key}: {value} in build config")
+                    extra_runtime_fields[key] = value
+        if extra_runtime_fields:
+            logger.warning(
+                f"Found extra fields {list(extra_runtime_fields.keys())} in build configuration, unspecified runtime fields will be configured using these values."
+                " This configuration of deprecated fields is scheduled for removal, please upgrade to the latest truss version and update configs according to https://docs.baseten.co/performance/engine-builder-config."
+            )
+            data.get("runtime").update(
+                {
+                    k: v
+                    for k, v in extra_runtime_fields.items()
+                    if k not in data.get("runtime")
+                }
+            )
+
+        data.update({"build": valid_build_fields})
+        return data
+
    @property
    def requires_build(self):
        return self.build is not None
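A usage sketch of the validator above (illustrative values, mirroring the new test in truss/tests/trt_llm/test_trt_llm_config.py below): deprecated fields found under build are copied into runtime only when the corresponding runtime key is unset, so an explicitly configured runtime value takes precedence. Note also that TrussTRTLLMBuildConfiguration now sets extra = "forbid", so unknown fields are rejected when that model is constructed directly.

from truss.base.trt_llm_config import TRTLLMConfiguration

data = {
    "build": {
        "base_model": "llama",
        "max_seq_len": 2048,
        "max_batch_size": 512,
        "checkpoint_repository": {"source": "HF", "repo": "meta/llama4-500B"},
        "total_token_limit": 50,  # deprecated placement, candidate for migration
    },
    "runtime": {"total_token_limit": 100},  # explicitly set in the new location
}
config = TRTLLMConfiguration(**data)
assert config.runtime.total_token_limit == 100  # the explicit runtime value wins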
152 changes: 150 additions & 2 deletions truss/tests/conftest.py
@@ -1,19 +1,24 @@
import contextlib
+import copy
import importlib
import os
import shutil
import subprocess
import sys
import time
from pathlib import Path
+from typing import Any, Dict

import pytest
import requests
import yaml

from truss.base.custom_types import Example
-from truss.base.trt_llm_config import TrussTRTLLMBatchSchedulerPolicy
-from truss.base.truss_config import DEFAULT_BUNDLED_PACKAGES_DIR
+from truss.base.trt_llm_config import (
+    TrussSpecDecMode,
+    TrussTRTLLMBatchSchedulerPolicy,
+)
+from truss.base.truss_config import DEFAULT_BUNDLED_PACKAGES_DIR, Accelerator
from truss.contexts.image_builder.serving_image_builder import (
    ServingImageBuilderContext,
)
@@ -736,3 +741,146 @@ def _modify_yaml(yaml_path: Path):
    yield content
    with yaml_path.open("w") as yaml_file:
        yaml.dump(content, yaml_file)


+@pytest.fixture
+def default_config() -> Dict[str, Any]:
+    return {
+        "build_commands": [],
+        "environment_variables": {},
+        "external_package_dirs": [],
+        "model_metadata": {},
+        "model_name": None,
+        "python_version": "py39",
+        "requirements": [],
+        "resources": {
+            "accelerator": None,
+            "cpu": "1",
+            "memory": "2Gi",
+            "use_gpu": False,
+        },
+        "secrets": {},
+        "system_packages": [],
+    }
+
+
+@pytest.fixture
+def trtllm_config(default_config) -> Dict[str, Any]:
+    trtllm_config = default_config
+    trtllm_config["resources"] = {
+        "accelerator": Accelerator.L4.value,
+        "cpu": "1",
+        "memory": "24Gi",
+        "use_gpu": True,
+    }
+    trtllm_config["trt_llm"] = {
+        "build": {
+            "base_model": "llama",
+            "max_seq_len": 2048,
+            "max_batch_size": 512,
+            "checkpoint_repository": {
+                "source": "HF",
+                "repo": "meta/llama4-500B",
+            },
+            "gather_all_token_logits": False,
+        },
+        "runtime": {},
+    }
+    return trtllm_config
+
+
+@pytest.fixture
+def deprecated_trtllm_config(default_config) -> Dict[str, Any]:
+    trtllm_config = default_config
+    trtllm_config["resources"] = {
+        "accelerator": Accelerator.L4.value,
+        "cpu": "1",
+        "memory": "24Gi",
+        "use_gpu": True,
+    }
+    trtllm_config["trt_llm"] = {
+        "build": {
+            "base_model": "llama",
+            "max_seq_len": 2048,
+            "max_batch_size": 512,
+            # start deprecated fields
+            "kv_cache_free_gpu_mem_fraction": 0.1,
+            "enable_chunked_context": True,
+            "batch_scheduler_policy": TrussTRTLLMBatchSchedulerPolicy.MAX_UTILIZATION.value,
+            "request_default_max_tokens": 10,
+            "total_token_limit": 50,
+            # end deprecated fields
+            "checkpoint_repository": {
+                "source": "HF",
+                "repo": "meta/llama4-500B",
+            },
+            "gather_all_token_logits": False,
+        },
+        "runtime": {"total_token_limit": 100},
+    }
+    return trtllm_config
+
+
+@pytest.fixture
+def trtllm_spec_dec_config_full(trtllm_config) -> Dict[str, Any]:
+    spec_dec_config = copy.deepcopy(trtllm_config)
+    spec_dec_config["trt_llm"] = {
+        "build": {
+            "base_model": "llama",
+            "max_seq_len": 2048,
+            "max_batch_size": 512,
+            "checkpoint_repository": {
+                "source": "HF",
+                "repo": "meta/llama4-500B",
+            },
+            "plugin_configuration": {
+                "paged_kv_cache": True,
+                "gemm_plugin": "auto",
+                "use_paged_context_fmha": True,
+            },
+            "speculator": {
+                "speculative_decoding_mode": TrussSpecDecMode.DRAFT_EXTERNAL.value,
+                "num_draft_tokens": 4,
+                "build": {
+                    "base_model": "llama",
+                    "max_seq_len": 2048,
+                    "max_batch_size": 512,
+                    "checkpoint_repository": {
+                        "source": "HF",
+                        "repo": "meta/llama4-500B",
+                    },
+                },
+            },
+        },
+    }
+    return spec_dec_config
+
+
+@pytest.fixture
+def trtllm_spec_dec_config(trtllm_config) -> Dict[str, Any]:
+    spec_dec_config = copy.deepcopy(trtllm_config)
+    spec_dec_config["trt_llm"] = {
+        "build": {
+            "base_model": "llama",
+            "max_seq_len": 2048,
+            "max_batch_size": 512,
+            "checkpoint_repository": {
+                "source": "HF",
+                "repo": "meta/llama4-500B",
+            },
+            "plugin_configuration": {
+                "paged_kv_cache": True,
+                "gemm_plugin": "auto",
+                "use_paged_context_fmha": True,
+            },
+            "speculator": {
+                "speculative_decoding_mode": TrussSpecDecMode.DRAFT_EXTERNAL.value,
+                "num_draft_tokens": 4,
+                "checkpoint_repository": {
+                    "source": "HF",
+                    "repo": "meta/llama4-500B",
+                },
+            },
+        },
+    }
+    return spec_dec_config
113 changes: 0 additions & 113 deletions truss/tests/test_config.py
@@ -2,14 +2,12 @@
import tempfile
from contextlib import nullcontext as does_not_raise
from pathlib import Path
-from typing import Any, Dict

import pytest
import yaml

from truss.base.custom_types import ModelFrameworkType
from truss.base.trt_llm_config import (
-    TrussSpecDecMode,
    TrussTRTLLMQuantizationType,
)
from truss.base.truss_config import (
@@ -29,117 +27,6 @@
from truss.truss_handle.truss_handle import TrussHandle


-@pytest.fixture
-def default_config() -> Dict[str, Any]:
-    return {
-        "build_commands": [],
-        "environment_variables": {},
-        "external_package_dirs": [],
-        "model_metadata": {},
-        "model_name": None,
-        "python_version": "py39",
-        "requirements": [],
-        "resources": {
-            "accelerator": None,
-            "cpu": "1",
-            "memory": "2Gi",
-            "use_gpu": False,
-        },
-        "secrets": {},
-        "system_packages": [],
-    }
-
-
-@pytest.fixture
-def trtllm_config(default_config) -> Dict[str, Any]:
-    trtllm_config = default_config
-    trtllm_config["resources"] = {
-        "accelerator": Accelerator.L4.value,
-        "cpu": "1",
-        "memory": "24Gi",
-        "use_gpu": True,
-    }
-    trtllm_config["trt_llm"] = {
-        "build": {
-            "base_model": "llama",
-            "max_seq_len": 2048,
-            "max_batch_size": 512,
-            "checkpoint_repository": {
-                "source": "HF",
-                "repo": "meta/llama4-500B",
-            },
-            "gather_all_token_logits": False,
-        },
-        "runtime": {},
-    }
-    return trtllm_config
-
-
-@pytest.fixture
-def trtllm_spec_dec_config_full(trtllm_config) -> Dict[str, Any]:
-    spec_dec_config = copy.deepcopy(trtllm_config)
-    spec_dec_config["trt_llm"] = {
-        "build": {
-            "base_model": "llama",
-            "max_seq_len": 2048,
-            "max_batch_size": 512,
-            "checkpoint_repository": {
-                "source": "HF",
-                "repo": "meta/llama4-500B",
-            },
-            "plugin_configuration": {
-                "paged_kv_cache": True,
-                "gemm_plugin": "auto",
-                "use_paged_context_fmha": True,
-            },
-            "speculator": {
-                "speculative_decoding_mode": TrussSpecDecMode.DRAFT_EXTERNAL,
-                "num_draft_tokens": 4,
-                "build": {
-                    "base_model": "llama",
-                    "max_seq_len": 2048,
-                    "max_batch_size": 512,
-                    "checkpoint_repository": {
-                        "source": "HF",
-                        "repo": "meta/llama4-500B",
-                    },
-                },
-            },
-        },
-    }
-    return spec_dec_config
-
-
-@pytest.fixture
-def trtllm_spec_dec_config(trtllm_config) -> Dict[str, Any]:
-    spec_dec_config = copy.deepcopy(trtllm_config)
-    spec_dec_config["trt_llm"] = {
-        "build": {
-            "base_model": "llama",
-            "max_seq_len": 2048,
-            "max_batch_size": 512,
-            "checkpoint_repository": {
-                "source": "HF",
-                "repo": "meta/llama4-500B",
-            },
-            "plugin_configuration": {
-                "paged_kv_cache": True,
-                "gemm_plugin": "auto",
-                "use_paged_context_fmha": True,
-            },
-            "speculator": {
-                "speculative_decoding_mode": TrussSpecDecMode.DRAFT_EXTERNAL,
-                "num_draft_tokens": 4,
-                "checkpoint_repository": {
-                    "source": "HF",
-                    "repo": "meta/llama4-500B",
-                },
-            },
-        },
-    }
-    return spec_dec_config


@pytest.mark.parametrize(
"input_dict, expect_resources, output_dict",
[
17 changes: 17 additions & 0 deletions truss/tests/trt_llm/test_trt_llm_config.py
@@ -0,0 +1,17 @@
+from truss.base.trt_llm_config import (
+    TRTLLMConfiguration,
+    TrussTRTLLMBatchSchedulerPolicy,
+)
+
+
+def test_trt_llm_configuration_init_and_migrate_deprecated_runtime_fields(
+    deprecated_trtllm_config,
+):
+    trt_llm_config = TRTLLMConfiguration(**deprecated_trtllm_config["trt_llm"])
+    assert trt_llm_config.runtime.model_dump() == {
+        "kv_cache_free_gpu_mem_fraction": 0.1,
+        "enable_chunked_context": True,
+        "batch_scheduler_policy": TrussTRTLLMBatchSchedulerPolicy.MAX_UTILIZATION.value,
+        "request_default_max_tokens": 10,
+        "total_token_limit": 100,
+    }
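The test relies on the deprecated_trtllm_config fixture added to truss/tests/conftest.py above. A standard pytest invocation along these lines should exercise it (exact environment setup may vary):

pytest truss/tests/trt_llm/test_trt_llm_config.py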
