diff --git a/truss/base/truss_config.py b/truss/base/truss_config.py index b670669b3..5c712974a 100644 --- a/truss/base/truss_config.py +++ b/truss/base/truss_config.py @@ -9,6 +9,7 @@ from truss.base.constants import ( HTTP_PUBLIC_BLOB_BACKEND, + TRTLLM_SPEC_DEC_TARGET_MODEL_NAME, ) from truss.base.custom_types import ModelFrameworkType from truss.base.errors import ValidationError @@ -633,7 +634,7 @@ def from_dict(d): trt_llm=transform_optional( d.get("trt_llm"), lambda x: (TRTLLMConfiguration(**x)) - if "target" not in d.get("trt_llm") + if TRTLLM_SPEC_DEC_TARGET_MODEL_NAME not in d.get("trt_llm") else (TRTLLMSpeculativeDecodingConfiguration(**x)), ), build_commands=d.get("build_commands", []),