diff --git a/truss-chains/tests/chains_e2e_test.py b/truss-chains/tests/chains_e2e_test.py index f10d90cbe..a64adc6f1 100644 --- a/truss-chains/tests/chains_e2e_test.py +++ b/truss-chains/tests/chains_e2e_test.py @@ -16,7 +16,9 @@ def test_chain(): root = Path(__file__).parent.resolve() chain_root = root / "itest_chain" / "itest_chain.py" with framework.import_target(chain_root, "ItestChain") as entrypoint: - options = definitions.PushOptionsLocalDocker(chain_name="integration-test") + options = definitions.PushOptionsLocalDocker( + chain_name="integration-test", use_local_chains_src=True + ) service = remote.push(entrypoint, options) url = service.run_remote_url.replace("host.docker.internal", "localhost") diff --git a/truss-chains/tests/itest_chain/itest_chain.py b/truss-chains/tests/itest_chain/itest_chain.py index 5881ad4f3..cdff415f0 100644 --- a/truss-chains/tests/itest_chain/itest_chain.py +++ b/truss-chains/tests/itest_chain/itest_chain.py @@ -34,7 +34,13 @@ def run_remote(self, length: int) -> str: class TextReplicator(chains.ChainletBase): remote_config = chains.RemoteConfig(docker_image=IMAGE_CUSTOM) - def __init__(self, context=chains.depends_context()): + def __init__(self): + try: + import pytzdata + + print(f"Could import {pytzdata} is present") + except ModuleNotFoundError: + print("Could not import pytzdata is present") self.multiplier = 2 def run_remote(self, data: str) -> str: diff --git a/truss-chains/tests/itest_chain/requirements.txt b/truss-chains/tests/itest_chain/requirements.txt index 4ba9dc6e1..8f9daaa9d 100644 --- a/truss-chains/tests/itest_chain/requirements.txt +++ b/truss-chains/tests/itest_chain/requirements.txt @@ -1 +1 @@ -git+https://github.com/basetenlabs/truss.git +pytzdata diff --git a/truss-chains/truss_chains/code_gen.py b/truss-chains/truss_chains/code_gen.py index 526d3deea..62f820425 100644 --- a/truss-chains/truss_chains/code_gen.py +++ b/truss-chains/truss_chains/code_gen.py @@ -563,6 +563,7 @@ def _make_truss_config( chains_config: definitions.RemoteConfig, chainlet_to_service: Mapping[str, definitions.ServiceDescriptor], model_name: str, + use_local_chains_src: bool, ) -> truss_config.TrussConfig: """Generate a truss config for a Chainlet.""" config = truss_config.TrussConfig() @@ -592,6 +593,7 @@ def _make_truss_config( if chains_config.docker_image.external_package_dirs: for ext_dir in chains_config.docker_image.external_package_dirs: config.external_package_dirs.append(ext_dir.abs_path) + config.use_local_chains_src = use_local_chains_src # Assets. assets = chains_config.get_asset_spec() config.secrets = assets.secrets @@ -623,6 +625,7 @@ def gen_truss_chainlet( chain_name: str, chainlet_descriptor: definitions.ChainletAPIDescriptor, model_name: str, + use_local_chains_src: bool, ) -> pathlib.Path: # Filter needed services and customize options. dep_services = {} @@ -641,6 +644,7 @@ def gen_truss_chainlet( chainlet_descriptor.chainlet_cls.remote_config, dep_services, model_name, + use_local_chains_src, ) # TODO This assumes all imports are absolute w.r.t chain root (or site-packages). truss_path.copy_tree_path( diff --git a/truss-chains/truss_chains/definitions.py b/truss-chains/truss_chains/definitions.py index 3acf81e5e..239147057 100644 --- a/truss-chains/truss_chains/definitions.py +++ b/truss-chains/truss_chains/definitions.py @@ -661,3 +661,7 @@ class PushOptionsLocalDocker(PushOptions): # is unset. Additionally, if local docker containers make calls to models deployed # on baseten, a real API key must be provided (i.e. the default must be overridden). baseten_chain_api_key: str = "docker_dummy_key" + # If enabled, chains code is copied from the local package into `/app/truss_chains` + # in the docker image (which takes precedence over potential pip/site-packages). + # This should be used for integration tests or quick local dev loops. + use_local_chains_src: bool = False diff --git a/truss-chains/truss_chains/remote.py b/truss-chains/truss_chains/remote.py index 0ab513aa9..5a99bb6b9 100644 --- a/truss-chains/truss_chains/remote.py +++ b/truss-chains/truss_chains/remote.py @@ -33,7 +33,7 @@ from truss.remote.baseten import custom_types as b10_types from truss.remote.baseten import remote as b10_remote from truss.remote.baseten import service as b10_service -from truss.truss_handle import build as truss_build +from truss.truss_handle import truss_handle from truss.util import log_utils from truss.util import path as truss_path @@ -87,11 +87,9 @@ def _push_service_docker( options: definitions.PushOptionsLocalDocker, port: int, ) -> None: - truss_handle = truss_build.load(str(truss_dir)) - truss_handle.add_secret( - definitions.BASETEN_API_SECRET_NAME, options.baseten_chain_api_key - ) - truss_handle.docker_run( + th = truss_handle.TrussHandle(truss_dir) + th.add_secret(definitions.BASETEN_API_SECRET_NAME, options.baseten_chain_api_key) + th.docker_run( local_port=port, detach=True, wait_for_server_ready=True, @@ -319,6 +317,12 @@ def __init__( self._options = options self._gen_root = gen_root or pathlib.Path(tempfile.gettempdir()) + @property + def _use_local_chains_src(self) -> bool: + if isinstance(self._options, definitions.PushOptionsLocalDocker): + return self._options.use_local_chains_src + return False + def generate_chainlet_artifacts( self, entrypoint: Type[definitions.ABCChainlet], @@ -340,6 +344,7 @@ def generate_chainlet_artifacts( self._options.chain_name, chainlet_descriptor, model_name, + self._use_local_chains_src, ) artifact = b10_types.ChainletArtifact( truss_dir=chainlet_dir, @@ -547,6 +552,7 @@ def _code_gen_and_patch_thread( self._deployed_chain_name, descr, self._chainlet_data[descr.display_name].oracle_name, + use_local_chains_src=False, ) patch_result = self._remote_provider.patch_for_chainlet( chainlet_dir, self._ignore_patterns diff --git a/truss/base/constants.py b/truss/base/constants.py index 94dce03a8..2ffc69518 100644 --- a/truss/base/constants.py +++ b/truss/base/constants.py @@ -23,6 +23,7 @@ TEMPLATES_DIR / SHARED_SERVING_AND_TRAINING_CODE_DIR_NAME ) CONTROL_SERVER_CODE_DIR: pathlib.Path = TEMPLATES_DIR / "control" +CHAINS_CODE_DIR: pathlib.Path = _TRUSS_ROOT.parent / "truss-chains" / "truss_chains" SUPPORTED_PYTHON_VERSIONS = {"3.8", "3.9", "3.10", "3.11"} MAX_SUPPORTED_PYTHON_VERSION_IN_CUSTOM_BASE_IMAGE = "3.12" diff --git a/truss/base/truss_config.py b/truss/base/truss_config.py index da0af2645..21eac56f9 100644 --- a/truss/base/truss_config.py +++ b/truss/base/truss_config.py @@ -560,6 +560,7 @@ class TrussConfig: model_cache: ModelCache = field(default_factory=ModelCache) trt_llm: Optional[TRTLLMConfiguration] = None build_commands: List[str] = field(default_factory=list) + use_local_chains_src: bool = False @property def canonical_python_version(self) -> str: @@ -619,6 +620,7 @@ def from_dict(d): d.get("trt_llm"), lambda x: TRTLLMConfiguration(**x) ), build_commands=d.get("build_commands", []), + use_local_chains_src=d.get("use_local_chains_src", False), ) config.validate() return config diff --git a/truss/contexts/image_builder/serving_image_builder.py b/truss/contexts/image_builder/serving_image_builder.py index 698180337..19c209c38 100644 --- a/truss/contexts/image_builder/serving_image_builder.py +++ b/truss/contexts/image_builder/serving_image_builder.py @@ -19,6 +19,7 @@ AUDIO_MODEL_TRTLLM_TRUSS_DIR, BASE_SERVER_REQUIREMENTS_TXT_FILENAME, BASE_TRTLLM_REQUIREMENTS, + CHAINS_CODE_DIR, CONTROL_SERVER_CODE_DIR, DOCKER_SERVER_TEMPLATES_DIR, FILENAME_CONSTANTS_MAP, @@ -68,6 +69,7 @@ BUILD_SERVER_DIR_NAME = "server" BUILD_CONTROL_SERVER_DIR_NAME = "control" BUILD_SERVER_EXTENSIONS_PATH = "extensions" +BUILD_CHAINS_DIR_NAME = "truss_chains" CONFIG_FILE = "config.yaml" USER_TRUSS_IGNORE_FILE = ".truss_ignore" @@ -356,8 +358,6 @@ def prepare_image_build_dir( # TODO(pankaj) We probably don't need model framework specific directory. build_dir = build_truss_target_directory(model_framework_name) - data_dir = build_dir / config.data_dir # type: ignore[operator] - def copy_into_build_dir(from_path: Path, path_in_build_dir: str): copy_tree_or_file(from_path, build_dir / path_in_build_dir) # type: ignore[operator] @@ -464,6 +464,9 @@ def copy_into_build_dir(from_path: Path, path_in_build_dir: str): + SHARED_SERVING_AND_TRAINING_CODE_DIR_NAME, ) + if config.use_local_chains_src: + copy_into_build_dir(CHAINS_CODE_DIR, BUILD_CHAINS_DIR_NAME) + # Copy base TrussServer requirements if supplied custom base image base_truss_server_reqs_filepath = SERVER_CODE_DIR / REQUIREMENTS_TXT_FILENAME if config.base_image: @@ -604,6 +607,7 @@ def _render_dockerfile( hf_access_token_file_name=HF_ACCESS_TOKEN_FILE_NAME, external_data_files=external_data_files, build_commands=build_commands, + use_local_chains_src=config.use_local_chains_src, **FILENAME_CONSTANTS_MAP, ) docker_file_path = build_dir / MODEL_DOCKERFILE_NAME diff --git a/truss/templates/server.Dockerfile.jinja b/truss/templates/server.Dockerfile.jinja index a7ac1032c..49b7d4a14 100644 --- a/truss/templates/server.Dockerfile.jinja +++ b/truss/templates/server.Dockerfile.jinja @@ -78,6 +78,11 @@ COPY ./{{config.data_dir}} /app/data COPY ./server /app +{%- if use_local_chains_src %} +{# This path takes precedence over site-packages. #} +COPY ./truss_chains /app/truss_chains +{%- endif %} + COPY ./config.yaml /app/config.yaml {%- if config.live_reload and not config.docker_server%} COPY ./control /control diff --git a/truss/truss_handle/truss_handle.py b/truss/truss_handle/truss_handle.py index 3ce1bad90..f4cdb2f14 100644 --- a/truss/truss_handle/truss_handle.py +++ b/truss/truss_handle/truss_handle.py @@ -6,7 +6,7 @@ import uuid from dataclasses import replace from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union from urllib.error import HTTPError import requests @@ -46,6 +46,7 @@ ServingImageBuilderContext, ) from truss.contexts.local_loader.load_model_local import LoadModelLocal +from truss.contexts.truss_context import TrussContext from truss.local.local_config_handler import LocalConfigHandler from truss.templates.shared.serialization import ( truss_msgpack_deserialize, @@ -950,7 +951,7 @@ def _get_serving_lookup_labels(self) -> Dict[str, Any]: def _build_image( self, - builder_context, + builder_context: Type[TrussContext], labels: Dict[str, str], build_dir: Optional[Path] = None, tag: Optional[str] = None,