Skip to content

Commit

Permalink
error loudly for development pushes for TRT-LLM built trusses (#1050)
Browse files Browse the repository at this point in the history
* add live_reload config override for TRT-LLM built trusses

* add unary default for max_beam_width

* fail loudly instead

* add field validator

* fix tests

* add for missing case

* remove redundant helper

* bump
  • Loading branch information
joostinyi authored Jul 30, 2024
1 parent 5314598 commit 4d251c3
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 7 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "truss"
version = "0.9.25rc9"
version = "0.9.25rc10"
description = "A seamless bridge from model development to model delivery"
license = "MIT"
readme = "README.md"
Expand Down
5 changes: 5 additions & 0 deletions truss/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from truss.util.config_checks import (
check_and_update_memory_for_trt_llm_builder,
check_secrets_for_trt_llm_builder,
uses_trt_llm_builder,
)
from truss.util.errors import RemoteNetworkError

Expand Down Expand Up @@ -860,6 +861,10 @@ def push(
console.print(
f"Automatically increasing memory for trt-llm builder to {TRTLLM_MIN_MEMORY_REQUEST_GI}Gi."
)
if uses_trt_llm_builder(tr) and not publish:
live_reload_disabled_text = "Development mode is currently not supported for trusses using TRT-LLM build flow, push as a published model using --publish"
console.print(live_reload_disabled_text, style="red")
sys.exit(1)

# TODO(Abu): This needs to be refactored to be more generic
service = remote_provider.push(
Expand Down
10 changes: 8 additions & 2 deletions truss/config/trt_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from enum import Enum
from typing import Optional

from pydantic import BaseModel
from pydantic import BaseModel, field_validator
from rich.console import Console

logging.basicConfig(level=logging.INFO)
Expand Down Expand Up @@ -53,7 +53,7 @@ class TrussTRTLLMBuildConfiguration(BaseModel):
max_input_len: int
max_output_len: int
max_batch_size: int
max_beam_width: int
max_beam_width: Optional[int] = 1
max_prompt_embedding_table_size: int = 0
checkpoint_repository: CheckpointRepository
gather_all_token_logits: bool = False
Expand All @@ -70,6 +70,12 @@ class TrussTRTLLMBuildConfiguration(BaseModel):
kv_cache_free_gpu_mem_fraction: float = 0.9
num_builder_gpus: Optional[int] = None

@field_validator("max_beam_width", mode="after")
@classmethod
def ensure_unary_max_beam_width(cls, value):
if value and value != 1:
raise ValueError("Non-unary max_beam_width not supported")


class TrussTRTLLMServingConfiguration(BaseModel):
engine_repository: str
Expand Down
3 changes: 0 additions & 3 deletions truss/contexts/image_builder/serving_image_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,9 +339,6 @@ def copy_into_build_dir(from_path: Path, path_in_build_dir: str):
# Copy over truss
copy_tree_path(truss_dir, build_dir, ignore_patterns=truss_ignore_patterns)

# Copy over template truss for TRT-LLM (we overwrite the model and packages dir)
# Most of the code is pulled from upstream triton-inference-server tensorrtllm_backend
# https://github.com/triton-inference-server/tensorrtllm_backend/tree/v0.9.0/all_models/inflight_batcher_llm
if config.trt_llm is not None:
is_audio_model = (
config.trt_llm.build.base_model == TrussTRTLLMModel.WHISPER
Expand Down
2 changes: 1 addition & 1 deletion truss/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def trtllm_config(default_config) -> Dict[str, Any]:
"max_input_len": 1024,
"max_output_len": 1024,
"max_batch_size": 512,
"max_beam_width": 1,
"max_beam_width": None,
"checkpoint_repository": {
"source": "HF",
"repo": "meta/llama4-500B",
Expand Down
6 changes: 6 additions & 0 deletions truss/util/config_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@
from truss.truss_handle import TrussHandle


def uses_trt_llm_builder(tr: TrussHandle) -> bool:
return (
tr.spec.config.trt_llm is not None and tr.spec.config.trt_llm.build is not None
)


def check_secrets_for_trt_llm_builder(tr: TrussHandle) -> bool:
if tr.spec.config.trt_llm and tr.spec.config.trt_llm.build:
source = tr.spec.config.trt_llm.build.checkpoint_repository.source
Expand Down

0 comments on commit 4d251c3

Please sign in to comment.