Skip to content

Commit

Permalink
Deploy option to use local chains code. (#1246)
Browse files Browse the repository at this point in the history
  • Loading branch information
marius-baseten authored Nov 18, 2024
1 parent 79e1841 commit 9e8c458
Show file tree
Hide file tree
Showing 11 changed files with 48 additions and 13 deletions.
4 changes: 3 additions & 1 deletion truss-chains/tests/chains_e2e_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ def test_chain():
root = Path(__file__).parent.resolve()
chain_root = root / "itest_chain" / "itest_chain.py"
with framework.import_target(chain_root, "ItestChain") as entrypoint:
options = definitions.PushOptionsLocalDocker(chain_name="integration-test")
options = definitions.PushOptionsLocalDocker(
chain_name="integration-test", use_local_chains_src=True
)
service = remote.push(entrypoint, options)

url = service.run_remote_url.replace("host.docker.internal", "localhost")
Expand Down
8 changes: 7 additions & 1 deletion truss-chains/tests/itest_chain/itest_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,13 @@ def run_remote(self, length: int) -> str:
class TextReplicator(chains.ChainletBase):
remote_config = chains.RemoteConfig(docker_image=IMAGE_CUSTOM)

def __init__(self, context=chains.depends_context()):
def __init__(self):
try:
import pytzdata

print(f"Could import {pytzdata} is present")
except ModuleNotFoundError:
print("Could not import pytzdata is present")
self.multiplier = 2

def run_remote(self, data: str) -> str:
Expand Down
2 changes: 1 addition & 1 deletion truss-chains/tests/itest_chain/requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
git+https://github.com/basetenlabs/truss.git
pytzdata
4 changes: 4 additions & 0 deletions truss-chains/truss_chains/code_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,7 @@ def _make_truss_config(
chains_config: definitions.RemoteConfig,
chainlet_to_service: Mapping[str, definitions.ServiceDescriptor],
model_name: str,
use_local_chains_src: bool,
) -> truss_config.TrussConfig:
"""Generate a truss config for a Chainlet."""
config = truss_config.TrussConfig()
Expand Down Expand Up @@ -592,6 +593,7 @@ def _make_truss_config(
if chains_config.docker_image.external_package_dirs:
for ext_dir in chains_config.docker_image.external_package_dirs:
config.external_package_dirs.append(ext_dir.abs_path)
config.use_local_chains_src = use_local_chains_src
# Assets.
assets = chains_config.get_asset_spec()
config.secrets = assets.secrets
Expand Down Expand Up @@ -623,6 +625,7 @@ def gen_truss_chainlet(
chain_name: str,
chainlet_descriptor: definitions.ChainletAPIDescriptor,
model_name: str,
use_local_chains_src: bool,
) -> pathlib.Path:
# Filter needed services and customize options.
dep_services = {}
Expand All @@ -641,6 +644,7 @@ def gen_truss_chainlet(
chainlet_descriptor.chainlet_cls.remote_config,
dep_services,
model_name,
use_local_chains_src,
)
# TODO This assumes all imports are absolute w.r.t chain root (or site-packages).
truss_path.copy_tree_path(
Expand Down
4 changes: 4 additions & 0 deletions truss-chains/truss_chains/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,3 +661,7 @@ class PushOptionsLocalDocker(PushOptions):
# is unset. Additionally, if local docker containers make calls to models deployed
# on baseten, a real API key must be provided (i.e. the default must be overridden).
baseten_chain_api_key: str = "docker_dummy_key"
# If enabled, chains code is copied from the local package into `/app/truss_chains`
# in the docker image (which takes precedence over potential pip/site-packages).
# This should be used for integration tests or quick local dev loops.
use_local_chains_src: bool = False
18 changes: 12 additions & 6 deletions truss-chains/truss_chains/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from truss.remote.baseten import custom_types as b10_types
from truss.remote.baseten import remote as b10_remote
from truss.remote.baseten import service as b10_service
from truss.truss_handle import build as truss_build
from truss.truss_handle import truss_handle
from truss.util import log_utils
from truss.util import path as truss_path

Expand Down Expand Up @@ -87,11 +87,9 @@ def _push_service_docker(
options: definitions.PushOptionsLocalDocker,
port: int,
) -> None:
truss_handle = truss_build.load(str(truss_dir))
truss_handle.add_secret(
definitions.BASETEN_API_SECRET_NAME, options.baseten_chain_api_key
)
truss_handle.docker_run(
th = truss_handle.TrussHandle(truss_dir)
th.add_secret(definitions.BASETEN_API_SECRET_NAME, options.baseten_chain_api_key)
th.docker_run(
local_port=port,
detach=True,
wait_for_server_ready=True,
Expand Down Expand Up @@ -319,6 +317,12 @@ def __init__(
self._options = options
self._gen_root = gen_root or pathlib.Path(tempfile.gettempdir())

@property
def _use_local_chains_src(self) -> bool:
if isinstance(self._options, definitions.PushOptionsLocalDocker):
return self._options.use_local_chains_src
return False

def generate_chainlet_artifacts(
self,
entrypoint: Type[definitions.ABCChainlet],
Expand All @@ -340,6 +344,7 @@ def generate_chainlet_artifacts(
self._options.chain_name,
chainlet_descriptor,
model_name,
self._use_local_chains_src,
)
artifact = b10_types.ChainletArtifact(
truss_dir=chainlet_dir,
Expand Down Expand Up @@ -547,6 +552,7 @@ def _code_gen_and_patch_thread(
self._deployed_chain_name,
descr,
self._chainlet_data[descr.display_name].oracle_name,
use_local_chains_src=False,
)
patch_result = self._remote_provider.patch_for_chainlet(
chainlet_dir, self._ignore_patterns
Expand Down
1 change: 1 addition & 0 deletions truss/base/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
TEMPLATES_DIR / SHARED_SERVING_AND_TRAINING_CODE_DIR_NAME
)
CONTROL_SERVER_CODE_DIR: pathlib.Path = TEMPLATES_DIR / "control"
CHAINS_CODE_DIR: pathlib.Path = _TRUSS_ROOT.parent / "truss-chains" / "truss_chains"

SUPPORTED_PYTHON_VERSIONS = {"3.8", "3.9", "3.10", "3.11"}
MAX_SUPPORTED_PYTHON_VERSION_IN_CUSTOM_BASE_IMAGE = "3.12"
Expand Down
2 changes: 2 additions & 0 deletions truss/base/truss_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,7 @@ class TrussConfig:
model_cache: ModelCache = field(default_factory=ModelCache)
trt_llm: Optional[TRTLLMConfiguration] = None
build_commands: List[str] = field(default_factory=list)
use_local_chains_src: bool = False

@property
def canonical_python_version(self) -> str:
Expand Down Expand Up @@ -619,6 +620,7 @@ def from_dict(d):
d.get("trt_llm"), lambda x: TRTLLMConfiguration(**x)
),
build_commands=d.get("build_commands", []),
use_local_chains_src=d.get("use_local_chains_src", False),
)
config.validate()
return config
Expand Down
8 changes: 6 additions & 2 deletions truss/contexts/image_builder/serving_image_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
AUDIO_MODEL_TRTLLM_TRUSS_DIR,
BASE_SERVER_REQUIREMENTS_TXT_FILENAME,
BASE_TRTLLM_REQUIREMENTS,
CHAINS_CODE_DIR,
CONTROL_SERVER_CODE_DIR,
DOCKER_SERVER_TEMPLATES_DIR,
FILENAME_CONSTANTS_MAP,
Expand Down Expand Up @@ -68,6 +69,7 @@
BUILD_SERVER_DIR_NAME = "server"
BUILD_CONTROL_SERVER_DIR_NAME = "control"
BUILD_SERVER_EXTENSIONS_PATH = "extensions"
BUILD_CHAINS_DIR_NAME = "truss_chains"

CONFIG_FILE = "config.yaml"
USER_TRUSS_IGNORE_FILE = ".truss_ignore"
Expand Down Expand Up @@ -356,8 +358,6 @@ def prepare_image_build_dir(
# TODO(pankaj) We probably don't need model framework specific directory.
build_dir = build_truss_target_directory(model_framework_name)

data_dir = build_dir / config.data_dir # type: ignore[operator]

def copy_into_build_dir(from_path: Path, path_in_build_dir: str):
copy_tree_or_file(from_path, build_dir / path_in_build_dir) # type: ignore[operator]

Expand Down Expand Up @@ -464,6 +464,9 @@ def copy_into_build_dir(from_path: Path, path_in_build_dir: str):
+ SHARED_SERVING_AND_TRAINING_CODE_DIR_NAME,
)

if config.use_local_chains_src:
copy_into_build_dir(CHAINS_CODE_DIR, BUILD_CHAINS_DIR_NAME)

# Copy base TrussServer requirements if supplied custom base image
base_truss_server_reqs_filepath = SERVER_CODE_DIR / REQUIREMENTS_TXT_FILENAME
if config.base_image:
Expand Down Expand Up @@ -604,6 +607,7 @@ def _render_dockerfile(
hf_access_token_file_name=HF_ACCESS_TOKEN_FILE_NAME,
external_data_files=external_data_files,
build_commands=build_commands,
use_local_chains_src=config.use_local_chains_src,
**FILENAME_CONSTANTS_MAP,
)
docker_file_path = build_dir / MODEL_DOCKERFILE_NAME
Expand Down
5 changes: 5 additions & 0 deletions truss/templates/server.Dockerfile.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ COPY ./{{config.data_dir}} /app/data

COPY ./server /app

{%- if use_local_chains_src %}
{# This path takes precedence over site-packages. #}
COPY ./truss_chains /app/truss_chains
{%- endif %}

COPY ./config.yaml /app/config.yaml
{%- if config.live_reload and not config.docker_server%}
COPY ./control /control
Expand Down
5 changes: 3 additions & 2 deletions truss/truss_handle/truss_handle.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import uuid
from dataclasses import replace
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from urllib.error import HTTPError

import requests
Expand Down Expand Up @@ -46,6 +46,7 @@
ServingImageBuilderContext,
)
from truss.contexts.local_loader.load_model_local import LoadModelLocal
from truss.contexts.truss_context import TrussContext
from truss.local.local_config_handler import LocalConfigHandler
from truss.templates.shared.serialization import (
truss_msgpack_deserialize,
Expand Down Expand Up @@ -950,7 +951,7 @@ def _get_serving_lookup_labels(self) -> Dict[str, Any]:

def _build_image(
self,
builder_context,
builder_context: Type[TrussContext],
labels: Dict[str, str],
build_dir: Optional[Path] = None,
tag: Optional[str] = None,
Expand Down

0 comments on commit 9e8c458

Please sign in to comment.