Commit
bump rc6 for briton 0.3.12.dev3
joostinyi committed Dec 3, 2024
1 parent 567a644 commit c995633
Showing 5 changed files with 23 additions and 13 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "truss"
-version = "0.9.54rc5"
+version = "0.9.54rc6"
description = "A seamless bridge from model development to model delivery"
license = "MIT"
readme = "README.md"
2 changes: 1 addition & 1 deletion truss/base/constants.py
@@ -109,7 +109,7 @@
TRTLLM_SPEC_DEC_DRAFT_MODEL_NAME = "draft"
TRTLLM_BASE_IMAGE = "baseten/briton-server:v0.13.0-4fd8a10-5e5c3d7"
TRTLLM_PYTHON_EXECUTABLE = "/usr/bin/python3"
-BASE_TRTLLM_REQUIREMENTS = ["briton==0.3.10"]
+BASE_TRTLLM_REQUIREMENTS = ["briton==0.3.12.dev3"]
AUDIO_MODEL_TRTLLM_REQUIREMENTS = [
"--extra-index-url https://pypi.nvidia.com",
"tensorrt_cu12_bindings==10.2.0.post1",
14 changes: 8 additions & 6 deletions truss/base/trt_llm_config.py
@@ -1,4 +1,3 @@
-import json
import warnings
from enum import Enum
from typing import Optional
@@ -94,11 +93,14 @@ def check_max_beam_width(cls, v: int):
class TrussTRTLLMRuntimeConfiguration(BaseModel):
    kv_cache_free_gpu_mem_fraction: float = 0.9
    enable_chunked_context: bool = False
-    num_draft_tokens: Optional[int] = None
    batch_scheduler_policy: TrussTRTLLMBatchSchedulerPolicy = (
        TrussTRTLLMBatchSchedulerPolicy.GUARANTEED_NO_EVICT
    )
    request_default_max_tokens: Optional[int] = None
+    # Speculative Decoding runtime configuration, ignored for non spec dec configurations
+    num_draft_tokens: Optional[int] = (
+        None  # number of draft tokens to be sampled from draft model in speculative decoding scheme
+    )


class TRTLLMConfiguration(BaseModel):
@@ -144,8 +146,8 @@ def requires_build(self):

    # TODO(Abu): Replace this with model_dump(json=True)
    # when pydantic v2 is used here
-    def to_json_dict(self, verbose=True):
-        return json.loads(self.json(exclude_unset=not verbose))
+    def to_dict(self, verbose=True):
+        return self.dict(exclude_unset=not verbose)


class TRTLLMSpeculativeDecodingConfiguration(BaseModel):
@@ -182,5 +184,5 @@ def _validate_spec_dec(self):
"Speculative decoding requires the same tensor parallelism for target and draft models."
)

-    def to_json_dict(self, verbose=True):
-        return json.loads(self.json(exclude_unset=not verbose))
+    def to_dict(self, verbose=True):
+        return self.dict(exclude_unset=not verbose)
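
The renamed to_dict helper leans on pydantic v1's BaseModel.dict(exclude_unset=...) instead of round-tripping through json.loads(self.json(...)). Below is a minimal sketch of the verbose flag's effect, using a cut-down stand-in for TrussTRTLLMRuntimeConfiguration (only two of its fields, purely for illustration):

from typing import Optional

from pydantic import BaseModel  # pydantic v1 API, per the TODO above


class RuntimeConfigSketch(BaseModel):
    # Stand-in carrying just two of the real model's fields.
    kv_cache_free_gpu_mem_fraction: float = 0.9
    num_draft_tokens: Optional[int] = None


cfg = RuntimeConfigSketch(num_draft_tokens=4)

# verbose=True  -> exclude_unset=False: full dict, defaults included
print(cfg.dict(exclude_unset=False))
# {'kv_cache_free_gpu_mem_fraction': 0.9, 'num_draft_tokens': 4}

# verbose=False -> exclude_unset=True: only fields the caller explicitly set
print(cfg.dict(exclude_unset=True))
# {'num_draft_tokens': 4}

One nuance of dropping the JSON round trip: .dict() keeps values such as enum members as Python objects, whereas json.loads(self.json(...)) produced their serialized forms, which is worth noting if downstream code expected JSON-native types.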
4 changes: 2 additions & 2 deletions truss/base/truss_config.py
@@ -814,11 +814,11 @@ def obj_to_dict(obj, verbose: bool = False):
            )
        elif isinstance(field_curr_value, TRTLLMConfiguration):
            d["trt_llm"] = transform_optional(
-                field_curr_value, lambda data: data.to_json_dict(verbose=verbose)
+                field_curr_value, lambda data: data.to_dict(verbose=verbose)
            )
        elif isinstance(field_curr_value, TRTLLMSpeculativeDecodingConfiguration):
            d["trt_llm"] = transform_optional(
-                field_curr_value, lambda data: data.to_json_dict(verbose=verbose)
+                field_curr_value, lambda data: data.to_dict(verbose=verbose)
            )
        elif isinstance(field_curr_value, BaseImage):
            d["base_image"] = transform_optional(
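
With both config classes now exposing the same to_dict(verbose=...) signature, obj_to_dict serializes either variant through an identical callback. A rough sketch of that pattern follows; transform_optional is defined locally with assumed behavior (apply the function only when the value is not None), and serialize_trt_llm is a hypothetical wrapper, not a truss function:

from typing import Callable, Optional, TypeVar

X = TypeVar("X")
Y = TypeVar("Y")


def transform_optional(value: Optional[X], fn: Callable[[X], Y]) -> Optional[Y]:
    # Assumed behavior: skip None, otherwise apply the transform.
    return None if value is None else fn(value)


def serialize_trt_llm(field_curr_value, verbose: bool = False) -> Optional[dict]:
    # Works for either TRTLLMConfiguration or TRTLLMSpeculativeDecodingConfiguration,
    # since both now implement to_dict(verbose=...).
    return transform_optional(
        field_curr_value, lambda data: data.to_dict(verbose=verbose)
    )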
14 changes: 11 additions & 3 deletions truss/templates/trtllm-briton/src/extension.py
@@ -1,5 +1,10 @@
from briton.spec_dec_truss_model import Model as SpecDecModel
+from briton.trtllm_config import (
+    TRTLLMConfiguration,
+    TRTLLMSpeculativeDecodingConfiguration,
+)
from briton.truss_model import Model
+from pydantic import ValidationError

TRTLLM_SPEC_DEC_TARGET_MODEL_NAME = "target"

@@ -37,10 +42,13 @@ class Extension:

    def __init__(self, *args, **kwargs):
        self._config = kwargs["config"]
-        if TRTLLM_SPEC_DEC_TARGET_MODEL_NAME in self._config.get("trt_llm"):
-            self._model = SpecDecModel(*args, **kwargs)
-        else:
+        trt_llm_config = self._config.get("trt_llm")
+        try:
+            TRTLLMConfiguration(**trt_llm_config)
            self._model = Model(*args, **kwargs)
+        except ValidationError as _:
+            TRTLLMSpeculativeDecodingConfiguration(**trt_llm_config)
+            self._model = SpecDecModel(*args, **kwargs)

    def model_override(self):
        """Return a model object.
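
The extension now dispatches on schema validation rather than checking for a "target" key: if the trt_llm block parses as a plain TRTLLMConfiguration, the regular Model serves it, and a ValidationError routes it to the speculative-decoding path. A minimal sketch of that dispatch pattern with toy pydantic models (PlainConfig, SpecDecConfig, and their fields are illustrative stand-ins, not briton's actual schemas):

from pydantic import BaseModel, ValidationError


class PlainConfig(BaseModel):
    # Toy stand-in for TRTLLMConfiguration; the real schema has many more fields.
    build: dict


class SpecDecConfig(BaseModel):
    # Toy stand-in for TRTLLMSpeculativeDecodingConfiguration.
    target: dict
    draft: dict


def pick_model(trt_llm_config: dict) -> str:
    # Mirrors Extension.__init__: validate against the plain schema first,
    # and fall back to the speculative-decoding schema on ValidationError.
    try:
        PlainConfig(**trt_llm_config)
        return "Model"
    except ValidationError:
        SpecDecConfig(**trt_llm_config)  # raises again if neither schema fits
        return "SpecDecModel"


print(pick_model({"build": {}}))                # -> Model
print(pick_model({"target": {}, "draft": {}}))  # -> SpecDecModel

Validation-based dispatch keeps the extension agnostic to exactly which keys distinguish the two shapes; that knowledge lives in the schemas themselves.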
