diff --git a/poetry.lock b/poetry.lock index e8318ef2f..a774d2c94 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. [[package]] name = "annotated-types" @@ -3227,4 +3227,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.8,<3.12" -content-hash = "ae427d2b3d167981a690496ec8f4ee6bfb17d268d872f7cda2bfd93cf1883c82" +content-hash = "54dbe4a5d6765fd101e79912ec8fd6f4241ed6e2621f678661110035a329e1b9" diff --git a/pyproject.toml b/pyproject.toml index 8e3950ae4..c090123fe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "truss" -version = "0.10.0rc1" +version = "0.9.3" description = "A seamless bridge from model development to model delivery" license = "MIT" readme = "README.md" @@ -8,7 +8,6 @@ authors = ["Pankaj Gupta ", "Phil Howes "] include = ["*.txt", "*.Dockerfile", "*.md"] repository = "https://github.com/basetenlabs/truss" keywords = ["MLOps", "AI", "Model Serving", "Model Deployment", "Machine Learning"] -packages = [{include = "truss", format = "wheel"}] [tool.poetry.urls] "Homepage" = "https://truss.baseten.co" @@ -43,7 +42,6 @@ inquirerpy = "^0.3.4" google-cloud-storage = "2.10.0" loguru = ">=0.7.2" uvloop = "^0.19.0" -pathspec = ">=0.9.0" [tool.poetry.group.builder.dependencies] @@ -65,8 +63,6 @@ huggingface_hub = ">=0.19.4" google-cloud-storage = "2.10.0" boto3 = "^1.26.157" loguru = ">=0.7.2" -pathspec = ">=0.9.0" -typer = "^0.9.0" [tool.poetry.dev-dependencies] ipython = "^7.16" diff --git a/truss/build/build.py b/truss/build.py similarity index 100% rename from truss/build/build.py rename to truss/build.py diff --git a/truss/build/__init__.py b/truss/build/__init__.py deleted file mode 100644 index 3ce02a979..000000000 --- a/truss/build/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# flake8: noqa -from .build import * diff --git a/truss/build/configure.py b/truss/build/configure.py deleted file mode 100644 index c4984b4c4..000000000 --- a/truss/build/configure.py +++ /dev/null @@ -1,63 +0,0 @@ -import json -import logging -import os -import sys -from pathlib import Path -from typing import Optional - -import typer -from truss import load -from truss.patch.hash import directory_content_hash -from truss.patch.signature import calc_truss_signature -from truss.patch.truss_dir_patch_applier import TrussDirPatchApplier -from truss.server.control.patch.types import Patch - -logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) - - -app = typer.Typer() - - -@app.command() -def configure_truss_for_build( - truss_dir: str, - build_context_path: str, - use_control_server: bool = False, - patches_path: Optional[str] = None, - hash_file_path: Optional[str] = None, - signature_file_path: Optional[str] = None, -): - tr = load(truss_dir) - - if patches_path is not None: - logging.info("Applying patches") - logger = logging.getLogger("patch_applier") - patch_applier = TrussDirPatchApplier(Path(truss_dir), logger) - patches = json.loads(Path(patches_path).read_text()) - patch_applier([Patch.from_dict(patch) for patch in patches]) - - # Important to do this before making changes to truss, we want - # to capture hash of original truss. - if hash_file_path is not None: - logging.info("Recording truss hash") - Path(hash_file_path).write_text(directory_content_hash(Path(truss_dir))) - - if signature_file_path is not None: - logging.info("Recording truss signature") - signature_str = json.dumps(calc_truss_signature(Path(truss_dir)).to_dict()) - Path(signature_file_path).write_text(signature_str) - - tr.live_reload(enable=use_control_server) - - logging.debug("Setting up docker build context for truss") - - # check if we have a hf_secret - tr.docker_build_setup( - Path(build_context_path), use_hf_secret="HUGGING_FACE_HUB_TOKEN" in os.environ - ) - logging.info("docker build context is set up for the truss") - - -if __name__ == "__main__": - # parse the things - app() diff --git a/truss/cli/cli.py b/truss/cli/cli.py index 74f46b2ee..25fabe1d7 100644 --- a/truss/cli/cli.py +++ b/truss/cli/cli.py @@ -22,7 +22,6 @@ from truss.remote.remote_cli import inquire_model_name, inquire_remote_name from truss.remote.remote_factory import USER_TRUSSRC_PATH, RemoteFactory from truss.truss_config import Build, ModelServer -from truss.truss_handle import TrussHandle logging.basicConfig(level=logging.INFO) @@ -167,14 +166,8 @@ def build(target_directory: str, build_dir: Path, tag) -> None: @click.option( "--attach", is_flag=True, default=False, help="Flag for attaching the process" ) -@click.option( - "--cache/--no-cache", - is_flag=True, - default=True, - help="Flag for caching build or not", -) @error_handling -def run(target_directory: str, build_dir: Path, tag, port, attach, cache) -> None: +def run(target_directory: str, build_dir: Path, tag, port, attach) -> None: """ Runs the docker image for a Truss. @@ -190,9 +183,7 @@ def run(target_directory: str, build_dir: Path, tag, port, attach, cache) -> Non click.confirm( f"Container already exists at {urls}. Are you sure you want to continue?" ) - tr.docker_run( - build_dir=build_dir, tag=tag, local_port=port, detach=not attach, cache=cache - ) + tr.docker_run(build_dir=build_dir, tag=tag, local_port=port, detach=not attach) @truss_cli.command() @@ -455,7 +446,7 @@ def predict( def push( target_directory: str, remote: str, - model_name: Optional[str], + model_name: str, publish: bool = False, trusted: bool = False, promote: bool = False, @@ -593,7 +584,7 @@ def cleanup() -> None: truss.build.cleanup() -def _get_truss_from_directory(target_directory: Optional[str] = None) -> TrussHandle: +def _get_truss_from_directory(target_directory: Optional[str] = None): """Gets Truss from directory. If none, use the current directory""" if target_directory is None: target_directory = os.getcwd() diff --git a/truss/constants.py b/truss/constants.py index 06f5ebcd3..dc0764ff0 100644 --- a/truss/constants.py +++ b/truss/constants.py @@ -1,7 +1,6 @@ import os import pathlib - -TRUSS_PACKAGE_DIR = pathlib.Path(__file__).resolve().parent +from typing import Set SKLEARN = "sklearn" TENSORFLOW = "tensorflow" @@ -16,18 +15,32 @@ CODE_DIR = pathlib.Path(BASE_DIR, "truss") TEMPLATES_DIR = pathlib.Path(CODE_DIR, "templates") +SERVER_CODE_DIR: pathlib.Path = TEMPLATES_DIR / "server" +TRITON_SERVER_CODE_DIR: pathlib.Path = TEMPLATES_DIR / "triton" TRTLLM_TRUSS_DIR: pathlib.Path = TEMPLATES_DIR / "trtllm" +SHARED_SERVING_AND_TRAINING_CODE_DIR_NAME = "shared" +SHARED_SERVING_AND_TRAINING_CODE_DIR: pathlib.Path = ( + TEMPLATES_DIR / SHARED_SERVING_AND_TRAINING_CODE_DIR_NAME +) +CONTROL_SERVER_CODE_DIR: pathlib.Path = TEMPLATES_DIR / "control" SUPPORTED_PYTHON_VERSIONS = {"3.8", "3.9", "3.10", "3.11"} +# Alias for TEMPLATES_DIR +SERVING_DIR: pathlib.Path = TEMPLATES_DIR + REQUIREMENTS_TXT_FILENAME = "requirements.txt" USER_SUPPLIED_REQUIREMENTS_TXT_FILENAME = "user_requirements.txt" +BASE_SERVER_REQUIREMENTS_TXT_FILENAME = "base_server_requirements.txt" +SERVER_REQUIREMENTS_TXT_FILENAME = "server_requirements.txt" SYSTEM_PACKAGES_TXT_FILENAME = "system_packages.txt" FILENAME_CONSTANTS_MAP = { "config_requirements_filename": REQUIREMENTS_TXT_FILENAME, "user_supplied_requirements_filename": USER_SUPPLIED_REQUIREMENTS_TXT_FILENAME, + "base_server_requirements_filename": BASE_SERVER_REQUIREMENTS_TXT_FILENAME, + "server_requirements_filename": SERVER_REQUIREMENTS_TXT_FILENAME, "system_packages_filename": SYSTEM_PACKAGES_TXT_FILENAME, } @@ -47,6 +60,38 @@ TRUSS_DIR = "truss_dir" TRUSS_HASH = "truss_hash" +HUGGINGFACE_TRANSFORMER_MODULE_NAME: Set[str] = set({}) + +# list from https://scikit-learn.org/stable/developers/advanced_installation.html +SKLEARN_REQ_MODULE_NAMES: Set[str] = { + "numpy", + "scipy", + "joblib", + "scikit-learn", + "threadpoolctl", +} + +XGBOOST_REQ_MODULE_NAMES: Set[str] = {"xgboost"} + +# list from https://www.tensorflow.org/install/pip +# if problematic, lets look to https://www.tensorflow.org/install/source +TENSORFLOW_REQ_MODULE_NAMES: Set[str] = { + "tensorflow", +} + +LIGHTGBM_REQ_MODULE_NAMES: Set[str] = { + "lightgbm", +} + +# list from https://pytorch.org/get-started/locally/ +PYTORCH_REQ_MODULE_NAMES: Set[str] = { + "torch", + "torchvision", + "torchaudio", +} + +MLFLOW_REQ_MODULE_NAMES: Set[str] = {"mlflow"} + INFERENCE_SERVER_PORT = 8080 HTTP_PUBLIC_BLOB_BACKEND = "http_public" diff --git a/truss/contexts/image_builder/serving_image_builder.py b/truss/contexts/image_builder/serving_image_builder.py index 5adce5a38..e678911c9 100644 --- a/truss/contexts/image_builder/serving_image_builder.py +++ b/truss/contexts/image_builder/serving_image_builder.py @@ -6,22 +6,28 @@ from typing import Any, Dict, List, Optional, Tuple, Type import boto3 +import yaml from botocore import UNSIGNED from botocore.client import Config from google.cloud import storage from huggingface_hub import get_hf_file_metadata, hf_hub_url, list_repo_files from huggingface_hub.utils import filter_repo_objects from truss.constants import ( + BASE_SERVER_REQUIREMENTS_TXT_FILENAME, BASE_TRTLLM_REQUIREMENTS, + CONTROL_SERVER_CODE_DIR, FILENAME_CONSTANTS_MAP, MODEL_DOCKERFILE_NAME, REQUIREMENTS_TXT_FILENAME, + SERVER_CODE_DIR, SERVER_DOCKERFILE_TEMPLATE_NAME, + SERVER_REQUIREMENTS_TXT_FILENAME, + SHARED_SERVING_AND_TRAINING_CODE_DIR, + SHARED_SERVING_AND_TRAINING_CODE_DIR_NAME, SYSTEM_PACKAGES_TXT_FILENAME, TEMPLATES_DIR, TRTLLM_BASE_IMAGE, TRTLLM_TRUSS_DIR, - TRUSS_PACKAGE_DIR, USER_SUPPLIED_REQUIREMENTS_TXT_FILENAME, ) from truss.contexts.image_builder.cache_warmer import ( @@ -48,6 +54,9 @@ load_trussignore_patterns, ) +BUILD_SERVER_DIR_NAME = "server" +BUILD_CONTROL_SERVER_DIR_NAME = "control" + CONFIG_FILE = "config.yaml" USER_TRUSS_IGNORE_FILE = ".truss_ignore" GCS_CREDENTIALS = "service_account.json" @@ -311,15 +320,7 @@ def prepare_image_build_dir( data_dir = build_dir / config.data_dir # type: ignore[operator] def copy_into_build_dir(from_path: Path, path_in_build_dir: str): - # using default ignore patterns ignores the `build` dir in truss - copy_tree_or_file(from_path, build_dir / path_in_build_dir, ignore_files=False) # type: ignore[operator] - - # Copy truss package from the context builder image to build dir - copy_into_build_dir(TRUSS_PACKAGE_DIR, "truss/") - copy_into_build_dir( - TRUSS_PACKAGE_DIR.parent / "pyproject.toml", "./pyproject.toml" - ) - copy_into_build_dir(TRUSS_PACKAGE_DIR.parent / "README.md", "./README.md") + copy_tree_or_file(from_path, build_dir / path_in_build_dir) # type: ignore[operator] truss_ignore_patterns = [] if (truss_dir / USER_TRUSS_IGNORE_FILE).exists(): @@ -350,6 +351,10 @@ def copy_into_build_dir(from_path: Path, path_in_build_dir: str): ) config.requirements.extend(BASE_TRTLLM_REQUIREMENTS) + # Override config.yml + with (build_dir / CONFIG_FILE).open("w") as config_file: + yaml.dump(config.to_dict(verbose=True), config_file) + external_data_files: list = [] data_dir = Path("/app/data/") if self._spec.external_data is not None: @@ -366,8 +371,50 @@ def copy_into_build_dir(from_path: Path, path_in_build_dir: str): config, truss_dir, build_dir ) + # Copy inference server code + copy_into_build_dir(SERVER_CODE_DIR, BUILD_SERVER_DIR_NAME) + copy_into_build_dir( + SHARED_SERVING_AND_TRAINING_CODE_DIR, + BUILD_SERVER_DIR_NAME + "/" + SHARED_SERVING_AND_TRAINING_CODE_DIR_NAME, + ) + + # Copy control server code + if config.live_reload: + copy_into_build_dir(CONTROL_SERVER_CODE_DIR, BUILD_CONTROL_SERVER_DIR_NAME) + copy_into_build_dir( + SHARED_SERVING_AND_TRAINING_CODE_DIR, + BUILD_CONTROL_SERVER_DIR_NAME + + "/control/" + + SHARED_SERVING_AND_TRAINING_CODE_DIR_NAME, + ) + + # Copy base TrussServer requirements if supplied custom base image + base_truss_server_reqs_filepath = SERVER_CODE_DIR / REQUIREMENTS_TXT_FILENAME + if config.base_image: + copy_into_build_dir( + base_truss_server_reqs_filepath, BASE_SERVER_REQUIREMENTS_TXT_FILENAME + ) + + # Copy model framework specific requirements file + server_reqs_filepath = ( + TEMPLATES_DIR / model_framework_name / REQUIREMENTS_TXT_FILENAME + ) + should_install_server_requirements = file_is_not_empty(server_reqs_filepath) + if should_install_server_requirements: + copy_into_build_dir(server_reqs_filepath, SERVER_REQUIREMENTS_TXT_FILENAME) + + with open(base_truss_server_reqs_filepath, "r") as f: + base_server_requirements = f.read() + + # If the user has provided python requirements, + # append the truss server requirements, so that any conflicts + # are detected and cause a build failure. If there are no + # requirements provided, we just pass an empty string, + # as there's no need to install anything. user_provided_python_requirements = ( - spec.requirements_txt if spec.requirements else "" + base_server_requirements + spec.requirements_txt + if spec.requirements + else "" ) if spec.requirements_file is not None: copy_into_build_dir( @@ -381,6 +428,7 @@ def copy_into_build_dir(from_path: Path, path_in_build_dir: str): self._render_dockerfile( build_dir, + should_install_server_requirements, model_files, use_hf_secret, cached_files, @@ -390,6 +438,7 @@ def copy_into_build_dir(from_path: Path, path_in_build_dir: str): def _render_dockerfile( self, build_dir: Path, + should_install_server_requirements: bool, model_files: Dict[str, Any], use_hf_secret: bool, cached_files: List[str], @@ -425,6 +474,7 @@ def _render_dockerfile( hf_access_token = config.secrets.get(HF_ACCESS_TOKEN_SECRET_NAME) dockerfile_contents = dockerfile_template.render( + should_install_server_requirements=should_install_server_requirements, base_image_name_and_tag=base_image_name_and_tag, should_install_system_requirements=should_install_system_requirements, should_install_requirements=should_install_python_requirements, diff --git a/truss/contexts/local_loader/docker_build_emulator.py b/truss/contexts/local_loader/docker_build_emulator.py new file mode 100644 index 000000000..3e68bb920 --- /dev/null +++ b/truss/contexts/local_loader/docker_build_emulator.py @@ -0,0 +1,62 @@ +from dataclasses import dataclass, field +from pathlib import Path +from typing import Dict, List + +from truss.util.path import copy_tree_or_file + + +@dataclass +class DockerBuildEmulatorResult: + workdir: Path = field(default_factory=lambda: Path("/")) + env: Dict = field(default_factory=dict) + entrypoint: List = field(default_factory=list) + + +class DockerBuildEmulator: + """Emulates Docker Builds + + As running docker builds is expensive, this class emulates the docker build + by parsing the docker file and applying certain commands to create an + appropriate enviroment in a directory to simulate the root of the file system. + + Support COPY, ENV, ENTRYPOINT, WORKDIR commands. All other commands are ignored. + """ + + def __init__( + self, + dockerfile_path: Path, + context_dir: Path, + ) -> None: + import dockerfile + + self._commands = dockerfile.parse_file(str(dockerfile_path)) + self._context_dir = context_dir + + def run(self, fs_root_dir: Path) -> DockerBuildEmulatorResult: + def _resolve_env(key: str) -> str: + if key.startswith("$"): + key = key.replace("$", "", 1) + v = result.env[key] + return v + return key + + def _resolve_values(keys: List[str]) -> List[str]: + return list(map(_resolve_env, keys)) + + result = DockerBuildEmulatorResult() + for cmd in self._commands: + if cmd.cmd not in ["ENV", "ENTRYPOINT", "COPY", "WORKDIR"]: + continue + values = _resolve_values(cmd.value) + if cmd.cmd == "ENV": + result.env[values[0]] = values[1] + if cmd.cmd == "ENTRYPOINT": + result.entrypoint = list(values) + if cmd.cmd == "COPY": + src, dst = values + src = src.replace("./", "", 1) + dst = dst.replace("/", "", 1) + copy_tree_or_file(self._context_dir / src, fs_root_dir / dst) + if cmd.cmd == "WORKDIR": + result.workdir = result.workdir / values[0] + return result diff --git a/truss/contexts/local_loader/load_model_local.py b/truss/contexts/local_loader/load_model_local.py index db9768416..dd621cf70 100644 --- a/truss/contexts/local_loader/load_model_local.py +++ b/truss/contexts/local_loader/load_model_local.py @@ -7,7 +7,7 @@ signature_accepts_keyword_arg, ) from truss.contexts.truss_context import TrussContext -from truss.server.common.patches import apply_patches +from truss.templates.server.common.patches import apply_patches from truss.truss_spec import TrussSpec diff --git a/truss/patch/calc_patch.py b/truss/patch/calc_patch.py index ddd3dc437..58d082f05 100644 --- a/truss/patch/calc_patch.py +++ b/truss/patch/calc_patch.py @@ -6,9 +6,13 @@ from truss.constants import CONFIG_FILE from truss.patch.hash import file_content_hash_str from truss.patch.types import TrussSignature -from truss.server.control.patch.requirement_name_identifier import reqs_by_name -from truss.server.control.patch.system_packages import system_packages_set -from truss.server.control.patch.types import ( +from truss.templates.control.control.helpers.truss_patch.requirement_name_identifier import ( + reqs_by_name, +) +from truss.templates.control.control.helpers.truss_patch.system_packages import ( + system_packages_set, +) +from truss.templates.control.control.helpers.types import ( Action, ConfigPatch, DataPatch, diff --git a/truss/patch/local_truss_patch_applier.py b/truss/patch/local_truss_patch_applier.py index 26a9d300e..ef08e1c58 100644 --- a/truss/patch/local_truss_patch_applier.py +++ b/truss/patch/local_truss_patch_applier.py @@ -3,9 +3,11 @@ from pathlib import Path from typing import List -from truss.server.control.errors import UnsupportedPatch -from truss.server.control.patch.model_code_patch_applier import apply_code_patch -from truss.server.control.patch.types import ( +from truss.templates.control.control.helpers.errors import UnsupportedPatch +from truss.templates.control.control.helpers.truss_patch.model_code_patch_applier import ( + apply_code_patch, +) +from truss.templates.control.control.helpers.types import ( Action, ModelCodePatch, Patch, diff --git a/truss/patch/truss_dir_patch_applier.py b/truss/patch/truss_dir_patch_applier.py index c238235ca..54876c095 100644 --- a/truss/patch/truss_dir_patch_applier.py +++ b/truss/patch/truss_dir_patch_applier.py @@ -2,14 +2,18 @@ from pathlib import Path from typing import List -from truss.server.control.errors import UnsupportedPatch -from truss.server.control.patch.model_code_patch_applier import apply_code_patch -from truss.server.control.patch.requirement_name_identifier import ( +from truss.templates.control.control.helpers.errors import UnsupportedPatch +from truss.templates.control.control.helpers.truss_patch.model_code_patch_applier import ( + apply_code_patch, +) +from truss.templates.control.control.helpers.truss_patch.requirement_name_identifier import ( identify_requirement_name, reqs_by_name, ) -from truss.server.control.patch.system_packages import system_packages_set -from truss.server.control.patch.types import ( +from truss.templates.control.control.helpers.truss_patch.system_packages import ( + system_packages_set, +) +from truss.templates.control.control.helpers.types import ( Action, ConfigPatch, EnvVarPatch, diff --git a/truss/server/control/patch/model_code_patch_applier.py b/truss/server/control/patch/model_code_patch_applier.py deleted file mode 100644 index 7587da196..000000000 --- a/truss/server/control/patch/model_code_patch_applier.py +++ /dev/null @@ -1,46 +0,0 @@ -import logging -import os -from pathlib import Path -from typing import Union - -from truss.server.control.patch.types import Action, ModelCodePatch, PackagePatch - - -def apply_code_patch( - relative_dir: Path, - patch: Union[ModelCodePatch, PackagePatch], - logger: logging.Logger, -): - logger.debug(f"Applying code patch {patch.to_dict()}") - filepath: Path = relative_dir / patch.path - action = patch.action - - if action in [Action.ADD, Action.UPDATE]: - filepath.parent.mkdir(parents=True, exist_ok=True) - action_log = "Adding" if action == Action.ADD else "Updating" - logger.info(f"{action_log} file {filepath}") - with filepath.open("w") as file: - content = patch.content - if content is None: - raise ValueError( - "Invalid patch: content of a file update patch should not be None." - ) - file.write(content) - - elif action == Action.REMOVE: - if not filepath.exists(): - logger.warning(f"Could not delete file {filepath}: not found.") - elif filepath.is_file(): - logger.info(f"Deleting file {filepath}") - filepath.unlink() - # attempt to recursively remove potentially empty directories, if applicable - # os.removedirs raises OSError with errno 39 when this process encounters a non-empty dir - try: - os.removedirs(filepath.parent) - except OSError as e: - if e.errno == 39: # Directory not empty - pass - else: - raise - else: - raise ValueError(f"Unknown patch action {action}") diff --git a/truss/templates/base.Dockerfile.jinja b/truss/templates/base.Dockerfile.jinja index c19c79961..604fdd975 100644 --- a/truss/templates/base.Dockerfile.jinja +++ b/truss/templates/base.Dockerfile.jinja @@ -2,7 +2,6 @@ ARG PYVERSION={{config.python_version}} FROM {{base_image_name_and_tag}} as truss_server ENV PYTHON_EXECUTABLE {{ config.base_image.python_executable_path or 'python3' }} -ENV JSON_LOG True {% block fail_fast %} RUN grep -w 'ID=debian\|ID_LIKE=debian' /etc/os-release || { echo "ERROR: Supplied base image is not a debian image"; exit 1; } @@ -13,15 +12,16 @@ RUN $PYTHON_EXECUTABLE -c "import sys; sys.exit(0) if sys.version_info.major == RUN pip install --upgrade pip --no-cache-dir \ && rm -rf /root/.cache/pip -# Always install the truss package -COPY ./truss/ /lib/truss_pkg/truss -COPY ./pyproject.toml /lib/truss_pkg/ -COPY ./README.md /lib/truss_pkg/ -RUN pip install /lib/truss_pkg --no-cache-dir && rm -rf /root/.cache/pip - {% block base_image_patch %} {% endblock %} +{% if config.model_framework.value == 'huggingface_transformer' %} + {% if config.resources.use_gpu %} +# HuggingFace pytorch gpu support needs mkl +RUN pip install mkl + {% endif %} +{% endif%} + {% block post_base %} {% endblock %} @@ -60,7 +60,7 @@ WORKDIR $APP_HOME {% block bundled_packages_copy %} {%- if bundled_packages_dir_exists %} -COPY ./{{config.bundled_packages_dir}} /app/packages +COPY ./{{config.bundled_packages_dir}} /packages {%- endif %} {% endblock %} diff --git a/truss/server/control/application.py b/truss/templates/control/control/application.py similarity index 84% rename from truss/server/control/application.py rename to truss/templates/control/control/application.py index 3fdb814fb..4b121538e 100644 --- a/truss/server/control/application.py +++ b/truss/templates/control/control/application.py @@ -5,24 +5,16 @@ from typing import Dict import httpx +from endpoints import control_app from fastapi import FastAPI from fastapi.responses import JSONResponse +from helpers.errors import ModelLoadFailed, PatchApplicatonError +from helpers.inference_server_controller import InferenceServerController +from helpers.inference_server_process_controller import InferenceServerProcessController +from helpers.inference_server_starter import async_inference_server_startup_flow +from helpers.truss_patch.model_container_patch_applier import ModelContainerPatchApplier +from shared.logging import setup_logging from starlette.datastructures import State -from truss.server.control.endpoints import control_app -from truss.server.control.errors import ModelLoadFailed, PatchApplicatonError -from truss.server.control.helpers.inference_server_controller import ( - InferenceServerController, -) -from truss.server.control.helpers.inference_server_process_controller import ( - InferenceServerProcessController, -) -from truss.server.control.helpers.inference_server_starter import ( - async_inference_server_startup_flow, -) -from truss.server.control.patch.model_container_patch_applier import ( - ModelContainerPatchApplier, -) -from truss.server.shared.logging import setup_logging async def handle_patch_error(_, exc): diff --git a/truss/server/control/endpoints.py b/truss/templates/control/control/endpoints.py similarity index 98% rename from truss/server/control/endpoints.py rename to truss/templates/control/control/endpoints.py index a1fc76fc2..6bc21b296 100644 --- a/truss/server/control/endpoints.py +++ b/truss/templates/control/control/endpoints.py @@ -4,11 +4,11 @@ import httpx from fastapi import APIRouter from fastapi.responses import JSONResponse, StreamingResponse +from helpers.errors import ModelLoadFailed, ModelNotReady from httpx import URL, ConnectError, RemoteProtocolError from starlette.requests import ClientDisconnect, Request from starlette.responses import Response from tenacity import Retrying, retry_if_exception_type, stop_after_attempt, wait_fixed -from truss.server.control.errors import ModelLoadFailed, ModelNotReady INFERENCE_SERVER_START_WAIT_SECS = 60 diff --git a/truss/server/control/helpers/context_managers.py b/truss/templates/control/control/helpers/context_managers.py similarity index 100% rename from truss/server/control/helpers/context_managers.py rename to truss/templates/control/control/helpers/context_managers.py diff --git a/truss/server/control/errors.py b/truss/templates/control/control/helpers/errors.py similarity index 100% rename from truss/server/control/errors.py rename to truss/templates/control/control/helpers/errors.py diff --git a/truss/server/control/helpers/inference_server_controller.py b/truss/templates/control/control/helpers/inference_server_controller.py similarity index 94% rename from truss/server/control/helpers/inference_server_controller.py rename to truss/templates/control/control/helpers/inference_server_controller.py index 8a5713ee9..83a6614b7 100644 --- a/truss/server/control/helpers/inference_server_controller.py +++ b/truss/templates/control/control/helpers/inference_server_controller.py @@ -4,19 +4,15 @@ import time from typing import Optional -from truss.server.control.errors import ( +from helpers.errors import ( InadmissiblePatch, PatchFailedRecoverable, PatchFailedUnrecoverable, UnsupportedPatch, ) -from truss.server.control.helpers.inference_server_process_controller import ( - InferenceServerProcessController, -) -from truss.server.control.patch.model_container_patch_applier import ( - ModelContainerPatchApplier, -) -from truss.server.control.patch.types import Patch, PatchType +from helpers.inference_server_process_controller import InferenceServerProcessController +from helpers.truss_patch.model_container_patch_applier import ModelContainerPatchApplier +from helpers.types import Patch, PatchType INFERENCE_SERVER_CHECK_INTERVAL_SECS = 10 diff --git a/truss/server/control/helpers/inference_server_process_controller.py b/truss/templates/control/control/helpers/inference_server_process_controller.py similarity index 98% rename from truss/server/control/helpers/inference_server_process_controller.py rename to truss/templates/control/control/helpers/inference_server_process_controller.py index 994bee381..4ca9a34ba 100644 --- a/truss/server/control/helpers/inference_server_process_controller.py +++ b/truss/templates/control/control/helpers/inference_server_process_controller.py @@ -4,7 +4,7 @@ from pathlib import Path from typing import List, Optional -from truss.server.control.helpers.context_managers import current_directory +from helpers.context_managers import current_directory INFERENCE_SERVER_FAILED_FILE = Path("~/inference_server_crashed.txt").expanduser() TERMINATION_TIMEOUT_SECS = 120.0 diff --git a/truss/server/control/helpers/inference_server_starter.py b/truss/templates/control/control/helpers/inference_server_starter.py similarity index 94% rename from truss/server/control/helpers/inference_server_starter.py rename to truss/templates/control/control/helpers/inference_server_starter.py index c74b15a9c..a9d5157c8 100644 --- a/truss/server/control/helpers/inference_server_starter.py +++ b/truss/templates/control/control/helpers/inference_server_starter.py @@ -3,10 +3,8 @@ import requests from anyio import to_thread +from helpers.inference_server_controller import InferenceServerController from tenacity import Retrying, stop_after_attempt, wait_exponential -from truss.server.control.helpers.inference_server_controller import ( - InferenceServerController, -) def inference_server_startup_flow( @@ -58,7 +56,7 @@ def inference_server_startup_flow( if "is_current" in resp_body and resp_body["is_current"] is True: logger.info("Hash is current, starting inference server") inference_server_controller.start() - except Exception as exc: + except Exception as exc: # noqa logger.warning(f"Patch ping attempt failed with error {exc}") raise exc diff --git a/truss/server/control/patch/__init__.py b/truss/templates/control/control/helpers/truss_patch/__init__.py similarity index 100% rename from truss/server/control/patch/__init__.py rename to truss/templates/control/control/helpers/truss_patch/__init__.py diff --git a/truss/server/control/truss_patch/model_code_patch_applier.py b/truss/templates/control/control/helpers/truss_patch/model_code_patch_applier.py similarity index 82% rename from truss/server/control/truss_patch/model_code_patch_applier.py rename to truss/templates/control/control/helpers/truss_patch/model_code_patch_applier.py index deab70dca..954710532 100644 --- a/truss/server/control/truss_patch/model_code_patch_applier.py +++ b/truss/templates/control/control/helpers/truss_patch/model_code_patch_applier.py @@ -2,12 +2,17 @@ import os from pathlib import Path -from truss.server.control.patch.types import Action, ModelCodePatch +# TODO(pankaj) In desparate need of refactoring into separate library +try: + from helpers.types import Action, Patch +except ModuleNotFoundError as exc: + logging.debug(f"Importing helpers from truss core, caused by: {exc}") + from truss.templates.control.control.helpers.types import Action, Patch def apply_code_patch( relative_dir: Path, - patch: ModelCodePatch, + patch: Patch, logger: logging.Logger, ): logger.debug(f"Applying code patch {patch.to_dict()}") diff --git a/truss/server/control/patch/model_container_patch_applier.py b/truss/templates/control/control/helpers/truss_patch/model_container_patch_applier.py similarity index 96% rename from truss/server/control/patch/model_container_patch_applier.py rename to truss/templates/control/control/helpers/truss_patch/model_container_patch_applier.py index 92051558b..75e1015fe 100644 --- a/truss/server/control/patch/model_container_patch_applier.py +++ b/truss/templates/control/control/helpers/truss_patch/model_container_patch_applier.py @@ -3,9 +3,9 @@ from pathlib import Path from typing import Optional -from truss.server.control.errors import UnsupportedPatch -from truss.server.control.patch.model_code_patch_applier import apply_code_patch -from truss.server.control.patch.types import ( +from helpers.errors import UnsupportedPatch +from helpers.truss_patch.model_code_patch_applier import apply_code_patch +from helpers.types import ( Action, ConfigPatch, EnvVarPatch, @@ -36,7 +36,7 @@ def __init__( self._inference_server_home / self._truss_config.model_module_dir ) self._bundled_packages_dir = ( - self._inference_server_home / self._truss_config.bundled_packages_dir + self._inference_server_home / ".." / self._truss_config.bundled_packages_dir ).resolve() self._data_dir = self._inference_server_home / self._truss_config.data_dir self._app_logger = app_logger diff --git a/truss/server/control/patch/requirement_name_identifier.py b/truss/templates/control/control/helpers/truss_patch/requirement_name_identifier.py similarity index 100% rename from truss/server/control/patch/requirement_name_identifier.py rename to truss/templates/control/control/helpers/truss_patch/requirement_name_identifier.py diff --git a/truss/server/control/patch/system_packages.py b/truss/templates/control/control/helpers/truss_patch/system_packages.py similarity index 100% rename from truss/server/control/patch/system_packages.py rename to truss/templates/control/control/helpers/truss_patch/system_packages.py diff --git a/truss/server/control/patch/types.py b/truss/templates/control/control/helpers/types.py similarity index 100% rename from truss/server/control/patch/types.py rename to truss/templates/control/control/helpers/types.py diff --git a/truss/server/control/server.py b/truss/templates/control/control/server.py similarity index 76% rename from truss/server/control/server.py rename to truss/templates/control/control/server.py index 811bd07cb..97be9a223 100644 --- a/truss/server/control/server.py +++ b/truss/templates/control/control/server.py @@ -3,10 +3,24 @@ from pathlib import Path import uvicorn -from truss.server.control.application import create_app +from application import create_app CONTROL_SERVER_PORT = int(os.environ.get("CONTROL_SERVER_PORT", "8080")) INFERENCE_SERVER_PORT = int(os.environ.get("INFERENCE_SERVER_PORT", "8090")) +PYTHON_EXECUTABLE_LOOKUP_PATHS = [ + "/usr/local/bin/python", + "/usr/local/bin/python3", + "/usr/bin/python", + "/usr/bin/python3", +] + + +def _identify_python_executable_path() -> str: + for path in PYTHON_EXECUTABLE_LOOKUP_PATHS: + if Path(path).exists(): + return path + + raise RuntimeError("Unable to find python, make sure it's installed.") class ControlServer: @@ -29,8 +43,7 @@ def run(self): "inference_server_home": self._inf_serv_home, "inference_server_process_args": [ self._python_executable_path, - "-m", - "truss.server.inference_server", + f"{self._inf_serv_home}/inference_server.py", ], "control_server_host": "0.0.0.0", "control_server_port": self._control_server_port, @@ -59,8 +72,8 @@ def run(self): if __name__ == "__main__": control_server = ControlServer( - python_executable_path=os.environ.get("PYTHON_EXECUTABLE", default="python3"), - inf_serv_home=os.environ.get("APP_HOME", default=str(Path.cwd())), + python_executable_path=_identify_python_executable_path(), + inf_serv_home=os.environ["APP_HOME"], control_server_port=CONTROL_SERVER_PORT, inference_server_port=INFERENCE_SERVER_PORT, ) diff --git a/truss/templates/control/requirements.txt b/truss/templates/control/requirements.txt new file mode 100644 index 000000000..734a09225 --- /dev/null +++ b/truss/templates/control/requirements.txt @@ -0,0 +1,9 @@ +dataclasses-json==0.5.7 +truss==0.9.1rc1 +fastapi==0.109.1 +uvicorn==0.24.0 +uvloop==0.19.0 +tenacity==8.1.0 +httpx==0.24.1 +python-json-logger==2.0.2 +loguru==0.7.2 diff --git a/truss/templates/server.Dockerfile.jinja b/truss/templates/server.Dockerfile.jinja index e050fd13e..690980bcd 100644 --- a/truss/templates/server.Dockerfile.jinja +++ b/truss/templates/server.Dockerfile.jinja @@ -21,6 +21,9 @@ RUN apt update && \ && apt-get clean -y \ && rm -rf /var/lib/apt/lists/* +COPY ./{{base_server_requirements_filename}} {{base_server_requirements_filename}} +RUN pip install -r {{base_server_requirements_filename}} --no-cache-dir && rm -rf /root/.cache/pip + {%- if config.live_reload %} RUN $PYTHON_EXECUTABLE -m venv -h >/dev/null \ || { pythonVersion=$(echo $($PYTHON_EXECUTABLE --version) | cut -d" " -f2 | cut -d"." -f1,2) \ @@ -39,6 +42,10 @@ RUN ln -sf {{config.base_image.python_executable_path}} /usr/local/bin/python {% endblock %} {% block install_requirements %} + {%- if should_install_server_requirements %} +COPY ./{{server_requirements_filename}} {{server_requirements_filename}} +RUN pip install -r {{server_requirements_filename}} --no-cache-dir && rm -rf /root/.cache/pip + {%- endif %} {{ super() }} {% endblock %} @@ -60,8 +67,14 @@ RUN mkdir -p {{ dst.parent }}; curl -L "{{ url }}" -o {{ dst }} COPY ./{{config.data_dir}} /app/data {%- endif %} +COPY ./server /app COPY ./{{ config.model_module_dir }} /app/model COPY ./config.yaml /app/config.yaml + {%- if config.live_reload %} +COPY ./control /control +RUN python3 -m venv /control/.env \ + && /control/.env/bin/pip3 install -r /control/requirements.txt + {%- endif %} {% endblock %} {% block run %} @@ -69,12 +82,11 @@ COPY ./config.yaml /app/config.yaml ENV HASH_TRUSS {{truss_hash}} ENV CONTROL_SERVER_PORT 8080 ENV INFERENCE_SERVER_PORT 8090 -ENV PYTHON_EXECUTABLE {{config.base_image.python_executable_path or "python3"}} -ENV SERVER_START_CMD="{{config.base_image.python_executable_path or "python3"}} -m truss.server.control.server" -ENTRYPOINT ["{{config.base_image.python_executable_path or "python3"}}", "-m", "truss.server.control.server"] +ENV SERVER_START_CMD="/control/.env/bin/python3 /control/control/server.py" +ENTRYPOINT ["/control/.env/bin/python3", "/control/control/server.py"] {%- else %} ENV INFERENCE_SERVER_PORT 8080 -ENV SERVER_START_CMD="{{(config.base_image.python_executable_path or "python3") ~ " -m truss.server.inference_server"}}" -ENTRYPOINT ["{{config.base_image.python_executable_path or "python3"}}", "-m", "truss.server.inference_server"] +ENV SERVER_START_CMD="{{(config.base_image.python_executable_path or "python3") ~ " /app/inference_server.py"}}" +ENTRYPOINT ["{{config.base_image.python_executable_path or "python3"}}", "/app/inference_server.py"] {%- endif %} {% endblock %} diff --git a/truss/server/shared/__init__.py b/truss/templates/server/__init__.py similarity index 100% rename from truss/server/shared/__init__.py rename to truss/templates/server/__init__.py diff --git a/truss/server/common/__init__.py b/truss/templates/server/common/__init__.py similarity index 100% rename from truss/server/common/__init__.py rename to truss/templates/server/common/__init__.py diff --git a/truss/server/common/errors.py b/truss/templates/server/common/errors.py similarity index 100% rename from truss/server/common/errors.py rename to truss/templates/server/common/errors.py diff --git a/truss/server/common/patches/__init__.py b/truss/templates/server/common/patches.py similarity index 96% rename from truss/server/common/patches/__init__.py rename to truss/templates/server/common/patches.py index 73e07a1ad..4f2ca6364 100644 --- a/truss/server/common/patches/__init__.py +++ b/truss/templates/server/common/patches.py @@ -12,7 +12,7 @@ def apply_patches(enabled: bool, requirements: list): Apply patches to certain functions. The patches are contained in each patch module under 'patches' directory. If a patch cannot be applied, it logs the name of the function and the exception details. """ - PATCHES_DIR = Path(__file__).parent + PATCHES_DIR = Path(__file__).parent / "patches" if not enabled: return for requirement in requirements: diff --git a/truss/server/common/patches/whisper/patch.py b/truss/templates/server/common/patches/whisper/patch.py similarity index 100% rename from truss/server/common/patches/whisper/patch.py rename to truss/templates/server/common/patches/whisper/patch.py diff --git a/truss/server/common/retry.py b/truss/templates/server/common/retry.py similarity index 100% rename from truss/server/common/retry.py rename to truss/templates/server/common/retry.py diff --git a/truss/server/common/schema.py b/truss/templates/server/common/schema.py similarity index 100% rename from truss/server/common/schema.py rename to truss/templates/server/common/schema.py diff --git a/truss/server/common/termination_handler_middleware.py b/truss/templates/server/common/termination_handler_middleware.py similarity index 100% rename from truss/server/common/termination_handler_middleware.py rename to truss/templates/server/common/termination_handler_middleware.py diff --git a/truss/server/common/truss_server.py b/truss/templates/server/common/truss_server.py similarity index 97% rename from truss/server/common/truss_server.py rename to truss/templates/server/common/truss_server.py index 4e8aa653b..3863c896c 100644 --- a/truss/server/common/truss_server.py +++ b/truss/templates/server/common/truss_server.py @@ -11,25 +11,23 @@ from pathlib import Path from typing import AsyncGenerator, Dict, List, Optional, Union -import truss.server.common.errors as errors -import truss.server.shared.util as utils +import common.errors as errors +import shared.util as utils import uvicorn +from common.termination_handler_middleware import TerminationHandlerMiddleware from fastapi import Depends, FastAPI, HTTPException, Request from fastapi.responses import ORJSONResponse, StreamingResponse from fastapi.routing import APIRoute as FastAPIRoute -from starlette.middleware.base import BaseHTTPMiddleware -from starlette.requests import ClientDisconnect -from starlette.responses import Response -from truss.server.common.termination_handler_middleware import ( - TerminationHandlerMiddleware, -) -from truss.server.model_wrapper import ModelWrapper -from truss.server.shared.logging import setup_logging -from truss.server.shared.serialization import ( +from model_wrapper import ModelWrapper +from shared.logging import setup_logging +from shared.serialization import ( DeepNumpyEncoder, truss_msgpack_deserialize, truss_msgpack_serialize, ) +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import ClientDisconnect +from starlette.responses import Response # [IMPORTANT] A lot of things depend on this currently. # Please consider the following when increasing this: diff --git a/truss/server/inference_server.py b/truss/templates/server/inference_server.py similarity index 65% rename from truss/server/inference_server.py rename to truss/templates/server/inference_server.py index 922486880..9c9162260 100644 --- a/truss/server/inference_server.py +++ b/truss/templates/server/inference_server.py @@ -1,10 +1,9 @@ import os -from pathlib import Path from typing import Dict -from truss.server.common.truss_server import TrussServer -from truss.server.shared.logging import setup_logging -from truss.truss_config import TrussConfig +import yaml +from common.truss_server import TrussServer # noqa: E402 +from shared.logging import setup_logging CONFIG_FILE = "config.yaml" @@ -17,7 +16,8 @@ class ConfiguredTrussServer: def __init__(self, config_path: str, port: int): self._port = port - self._config = TrussConfig.from_yaml(Path(config_path)).to_dict(verbose=True) + with open(config_path, encoding="utf-8") as config_file: + self._config = yaml.safe_load(config_file) def start(self): server = TrussServer(http_port=self._port, config=self._config) diff --git a/truss/server/model_wrapper.py b/truss/templates/server/model_wrapper.py similarity index 97% rename from truss/server/model_wrapper.py rename to truss/templates/server/model_wrapper.py index dbf5478af..428e0dc61 100644 --- a/truss/server/model_wrapper.py +++ b/truss/templates/server/model_wrapper.py @@ -15,12 +15,12 @@ import pydantic from anyio import Semaphore, to_thread +from common.patches import apply_patches +from common.retry import retry +from common.schema import TrussSchema from fastapi import HTTPException from pydantic import BaseModel -from truss.server.common.patches import apply_patches -from truss.server.common.retry import retry -from truss.server.common.schema import TrussSchema -from truss.server.shared.secrets_resolver import SecretsResolver +from shared.secrets_resolver import SecretsResolver MODEL_BASENAME = "model" @@ -87,8 +87,7 @@ def __init__(self, config: Dict): ) ) self._background_tasks: Set[asyncio.Task] = set() - self.truss_schema: Optional[TrussSchema] = None - self.app_home: Path = Path(os.environ.get("APP_HOME", default=str(Path.cwd()))) + self.truss_schema: TrussSchema = None def load(self) -> bool: if self.ready: @@ -133,12 +132,10 @@ def try_load(self): data_dir = Path("data") data_dir.mkdir(exist_ok=True) - sys.path.append(str(self.app_home)) if "bundled_packages_dir" in self._config: - bundled_packages_path = self.app_home / "packages" + bundled_packages_path = Path("/packages") if bundled_packages_path.exists(): sys.path.append(str(bundled_packages_path)) - model_module_name = str( Path(self._config["model_class_filename"]).with_suffix("") ) diff --git a/truss/templates/server/requirements.txt b/truss/templates/server/requirements.txt new file mode 100644 index 000000000..f166536e6 --- /dev/null +++ b/truss/templates/server/requirements.txt @@ -0,0 +1,16 @@ +-i https://pypi.org/simple + +argparse==1.4.0 +aiocontextvars==0.2.2 +cython==3.0.5 +msgpack-numpy==0.4.8 +msgpack==1.0.2 +python-json-logger==2.0.2 +pyyaml==6.0.0 +fastapi==0.109.1 +uvicorn==0.24.0 +uvloop==0.17.0 +psutil==5.9.4 +joblib==1.2.0 +requests==2.31.0 +loguru==0.7.2 diff --git a/truss/templates/shared/README.md b/truss/templates/shared/README.md new file mode 100644 index 000000000..145566331 --- /dev/null +++ b/truss/templates/shared/README.md @@ -0,0 +1,3 @@ +# Shared code between training and serving images + +Code in this directory is common to both training and serving and is copied into them. diff --git a/truss/templates/shared/__init__.py b/truss/templates/shared/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/truss/server/shared/logging.py b/truss/templates/shared/logging.py similarity index 72% rename from truss/server/shared/logging.py rename to truss/templates/shared/logging.py index 01caa7efc..094d6ec02 100644 --- a/truss/server/shared/logging.py +++ b/truss/templates/shared/logging.py @@ -1,13 +1,10 @@ import logging -import os import sys from pythonjsonlogger import jsonlogger LEVEL: int = logging.INFO -use_json_logs = os.environ.get("JSON_LOG", default=False) - JSON_LOG_HANDLER = logging.StreamHandler(stream=sys.stdout) JSON_LOG_HANDLER.set_name("json_logger_handler") JSON_LOG_HANDLER.setLevel(LEVEL) @@ -35,15 +32,15 @@ def setup_logging() -> None: logger.propagate = False setup = False - if use_json_logs: - # let's not thrash the handlers unnecessarily - for handler in logger.handlers: - if handler.name == JSON_LOG_HANDLER.name: - setup = True - - if not setup: - logger.handlers.clear() - logger.addHandler(JSON_LOG_HANDLER) + + # let's not thrash the handlers unnecessarily + for handler in logger.handlers: + if handler.name == JSON_LOG_HANDLER.name: + setup = True + + if not setup: + logger.handlers.clear() + logger.addHandler(JSON_LOG_HANDLER) # some special handling for request logging if logger.name == "uvicorn.access": diff --git a/truss/server/shared/secrets_resolver.py b/truss/templates/shared/secrets_resolver.py similarity index 100% rename from truss/server/shared/secrets_resolver.py rename to truss/templates/shared/secrets_resolver.py diff --git a/truss/server/shared/serialization.py b/truss/templates/shared/serialization.py similarity index 100% rename from truss/server/shared/serialization.py rename to truss/templates/shared/serialization.py diff --git a/truss/server/shared/util.py b/truss/templates/shared/util.py similarity index 94% rename from truss/server/shared/util.py rename to truss/templates/shared/util.py index 51ef017f5..6fc8f245d 100644 --- a/truss/server/shared/util.py +++ b/truss/templates/shared/util.py @@ -1,7 +1,7 @@ import multiprocessing import os import sys -from typing import Callable, Dict, List, Mapping, TypeVar +from typing import Callable, Dict, List, TypeVar import psutil @@ -83,5 +83,5 @@ def kill_child_processes(parent_pid: int): Z = TypeVar("Z") -def transform_keys(d: Mapping[X, Z], fn: Callable[[X], Y]) -> Dict[Y, Z]: +def transform_keys(d: Dict[X, Z], fn: Callable[[X], Y]) -> Dict[Y, Z]: return {fn(key): value for key, value in d.items()} diff --git a/truss/templates/trtllm/packages/constants.py b/truss/templates/trtllm/packages/constants.py index 5940b4c7e..1f19e8065 100644 --- a/truss/templates/trtllm/packages/constants.py +++ b/truss/templates/trtllm/packages/constants.py @@ -1,9 +1,7 @@ from pathlib import Path # If changing model repo path, please updated inside tensorrt_llm config.pbtxt as well -TENSORRT_LLM_MODEL_REPOSITORY_PATH = Path( - "/app/packages/tensorrt_llm_model_repository/" -) +TENSORRT_LLM_MODEL_REPOSITORY_PATH = Path("/packages/tensorrt_llm_model_repository/") GRPC_SERVICE_PORT = 8001 HTTP_SERVICE_PORT = 8003 HF_AUTH_KEY_CONSTANT = "HUGGING_FACE_HUB_TOKEN" diff --git a/truss/templates/trtllm/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt b/truss/templates/trtllm/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt index 09fde9df2..75cb6718f 100644 --- a/truss/templates/trtllm/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt +++ b/truss/templates/trtllm/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt @@ -173,7 +173,7 @@ parameters: { parameters: { key: "gpt_model_path" value: { - string_value: "/app/packages/tensorrt_llm_model_repository/tensorrt_llm/1" + string_value: "/packages/tensorrt_llm_model_repository/tensorrt_llm/1" } } parameters: { diff --git a/truss/test_data/server.Dockerfile b/truss/test_data/server.Dockerfile index c6f42d425..7be57e609 100644 --- a/truss/test_data/server.Dockerfile +++ b/truss/test_data/server.Dockerfile @@ -2,7 +2,6 @@ ARG PYVERSION=py39 FROM baseten/truss-server-base:3.9-v0.4.3 as truss_server ENV PYTHON_EXECUTABLE /usr/local/bin/python3 -ENV JSON_LOG True RUN grep -w 'ID=debian\|ID_LIKE=debian' /etc/os-release || { echo "ERROR: Supplied base image is not a debian image"; exit 1; } RUN $PYTHON_EXECUTABLE -c "import sys; sys.exit(0) if sys.version_info.major == 3 and sys.version_info.minor >=8 and sys.version_info.minor <=11 else sys.exit(1)" \ @@ -11,14 +10,6 @@ RUN $PYTHON_EXECUTABLE -c "import sys; sys.exit(0) if sys.version_info.major == RUN pip install --upgrade pip --no-cache-dir \ && rm -rf /root/.cache/pip - -# Always install the truss package -COPY ./truss/ /lib/truss_pkg/truss -COPY ./pyproject.toml /lib/truss_pkg/ -COPY ./README.md /lib/truss_pkg/ -RUN pip install /lib/truss_pkg --no-cache-dir && rm -rf /root/.cache/pip - - # If user base image is supplied in config, apply build commands from truss base image ENV PYTHONUNBUFFERED True ENV DEBIAN_FRONTEND=noninteractive @@ -34,16 +25,20 @@ RUN apt update && \ && apt-get clean -y \ && rm -rf /var/lib/apt/lists/* +COPY ./base_server_requirements.txt base_server_requirements.txt +RUN pip install -r base_server_requirements.txt --no-cache-dir && rm -rf /root/.cache/pip + ENV APP_HOME /app WORKDIR $APP_HOME # Copy data before code for better caching COPY ./data /app/data +COPY ./server /app COPY ./model /app/model COPY ./config.yaml /app/config.yaml -COPY ./packages /app/packages +COPY ./packages /packages ENV INFERENCE_SERVER_PORT 8080 -ENV SERVER_START_CMD="/usr/local/bin/python3 -m truss.server.inference_server" -ENTRYPOINT ["/usr/local/bin/python3", "-m", "truss.server.inference_server"] +ENV SERVER_START_CMD="/usr/local/bin/python3 /app/inference_server.py" +ENTRYPOINT ["/usr/local/bin/python3", "/app/inference_server.py"] diff --git a/truss/tests/conftest.py b/truss/tests/conftest.py index d59d86282..24c581166 100644 --- a/truss/tests/conftest.py +++ b/truss/tests/conftest.py @@ -6,15 +6,17 @@ import sys import time from pathlib import Path -from typing import Callable, Generator import pytest import requests import yaml from truss.build import init +from truss.contexts.image_builder.serving_image_builder import ( + ServingImageBuilderContext, +) +from truss.contexts.local_loader.docker_build_emulator import DockerBuildEmulator from truss.truss_config import DEFAULT_BUNDLED_PACKAGES_DIR from truss.types import Example -from truss.util.path import copy_tree_path CUSTOM_MODEL_CODE = """ class Model: @@ -204,18 +206,13 @@ def predict(self, model_input): """ -@pytest.fixture -def temp_path(tmpdir): - yield Path(tmpdir) - - @pytest.fixture def pytorch_model_init_args(): return {"arg1": 1, "arg2": 2, "kwarg1": 3, "kwarg2": 4} @pytest.fixture -def custom_model_truss_dir(tmp_path) -> Generator[Path, None, None]: +def custom_model_truss_dir(tmp_path) -> Path: yield _custom_model_from_code( tmp_path, "custom_truss", @@ -510,26 +507,20 @@ def custom_model_truss_dir_for_secrets(tmp_path): @pytest.fixture -def tmp_truss_dir(tmp_path, monkeypatch): +def truss_container_fs(tmp_path): ROOT = Path(__file__).parent.parent.parent.resolve() - tmp_dir = _copy_truss_dir_to_tmp( - ROOT / "truss" / "test_data" / "test_truss", tmp_path - ) - monkeypatch.setenv("APP_HOME", str(tmp_dir)) - return tmp_dir + return _build_truss_fs(ROOT / "truss" / "test_data" / "test_truss", tmp_path) @pytest.fixture -def tmp_truss_control_dir(tmp_path, monkeypatch): +def truss_control_container_fs(tmp_path): ROOT = Path(__file__).parent.parent.parent.resolve() test_truss_dir = ROOT / "truss" / "test_data" / "test_truss" control_truss_dir = tmp_path / "control_truss" shutil.copytree(str(test_truss_dir), str(control_truss_dir)) with _modify_yaml(control_truss_dir / "config.yaml") as content: content["live_reload"] = True - tmp_dir = _copy_truss_dir_to_tmp(control_truss_dir, tmp_path) - monkeypatch.setenv("APP_HOME", str(tmp_dir)) - return tmp_dir + return _build_truss_fs(control_truss_dir, tmp_path) @pytest.fixture @@ -585,7 +576,7 @@ def _custom_model_from_code( where_dir: Path, truss_name: str, model_code: str, - handle_ops: Callable = None, + handle_ops: callable = None, ) -> Path: dir_path = where_dir / truss_name handle = init(str(dir_path)) @@ -648,9 +639,18 @@ def helpers(): return Helpers() -def _copy_truss_dir_to_tmp(truss_dir: Path, tmp_path: Path) -> Path: - copy_tree_path(truss_dir, tmp_path) - return tmp_path +def _build_truss_fs(truss_dir: Path, tmp_path: Path) -> Path: + truss_fs = tmp_path / "truss_fs" + truss_fs.mkdir() + truss_build_dir = tmp_path / "truss_fs_build" + truss_build_dir.mkdir() + image_builder = ServingImageBuilderContext.run(truss_dir) + image_builder.prepare_image_build_dir(truss_build_dir) + dockerfile_path = truss_build_dir / "Dockerfile" + + docker_build_emulator = DockerBuildEmulator(dockerfile_path, truss_build_dir) + docker_build_emulator.run(truss_fs) + return truss_fs @contextlib.contextmanager diff --git a/truss/tests/contexts/image_builder/test_serving_image_builder.py b/truss/tests/contexts/image_builder/test_serving_image_builder.py index 491c8c26c..95ea83b8c 100644 --- a/truss/tests/contexts/image_builder/test_serving_image_builder.py +++ b/truss/tests/contexts/image_builder/test_serving_image_builder.py @@ -19,8 +19,8 @@ def test_serving_image_dockerfile_from_user_base_image(custom_model_truss_dir): th = TrussHandle(custom_model_truss_dir) th.set_base_image("baseten/truss-server-base:3.9-v0.4.3", "/usr/local/bin/python3") - - image_builder = ServingImageBuilderContext.run(th.spec.truss_dir) + builder_context = ServingImageBuilderContext + image_builder = builder_context.run(th.spec.truss_dir) with TemporaryDirectory() as tmp_dir: tmp_path = Path(tmp_dir) image_builder.prepare_image_build_dir(tmp_path) @@ -43,7 +43,8 @@ def filter_empty_lines(lines): def test_requirements_setup_in_build_dir(custom_model_truss_dir): th = TrussHandle(custom_model_truss_dir) th.add_python_requirement("numpy") - image_builder = ServingImageBuilderContext.run(th.spec.truss_dir) + builder_context = ServingImageBuilderContext + image_builder = builder_context.run(th.spec.truss_dir) with TemporaryDirectory() as tmp_dir: tmp_path = Path(tmp_dir) @@ -51,8 +52,10 @@ def test_requirements_setup_in_build_dir(custom_model_truss_dir): with open(tmp_path / "requirements.txt", "r") as f: requirements_content = f.read() - # We are no longer adding the base requirements because the server is installed separately. - assert requirements_content == "numpy\n" + with open(f"{BASE_DIR}/../../../templates/server/requirements.txt", "r") as f: + base_requirements_content = f.read() + + assert requirements_content == base_requirements_content + "numpy\n" def flatten_cached_files(local_cache_files): diff --git a/truss/tests/patch/test_calc_patch.py b/truss/tests/patch/test_calc_patch.py index 566d70d65..dc4a7c75c 100644 --- a/truss/tests/patch/test_calc_patch.py +++ b/truss/tests/patch/test_calc_patch.py @@ -4,7 +4,7 @@ import yaml from truss.patch.calc_patch import calc_truss_patch, calc_unignored_paths from truss.patch.signature import calc_truss_signature -from truss.server.control.patch.types import ( +from truss.templates.control.control.helpers.types import ( Action, ConfigPatch, EnvVarPatch, @@ -631,7 +631,7 @@ def _apply_config_change_and_calc_patches( custom_model_truss_dir: Path, config_op: Callable[[TrussConfig], Any], config_pre_op: Optional[Callable[[TrussConfig], Any]] = None, -) -> Optional[List[Patch]]: +) -> List[Patch]: def modify_config(op): config_path = custom_model_truss_dir / "config.yaml" config = TrussConfig.from_yaml(config_path) diff --git a/truss/tests/patch/test_truss_dir_patch_applier.py b/truss/tests/patch/test_truss_dir_patch_applier.py index 697be2e32..0dc440ccc 100644 --- a/truss/tests/patch/test_truss_dir_patch_applier.py +++ b/truss/tests/patch/test_truss_dir_patch_applier.py @@ -3,7 +3,7 @@ import yaml from truss.patch.truss_dir_patch_applier import TrussDirPatchApplier -from truss.server.control.patch.types import ( +from truss.templates.control.control.helpers.types import ( Action, ConfigPatch, ModelCodePatch, diff --git a/truss/tests/templates/control/control/helpers/test_context_managers.py b/truss/tests/templates/control/control/helpers/test_context_managers.py new file mode 100644 index 000000000..1c222b017 --- /dev/null +++ b/truss/tests/templates/control/control/helpers/test_context_managers.py @@ -0,0 +1,11 @@ +import os + +from truss.templates.control.control.helpers.context_managers import current_directory + + +def test_current_directory(tmp_path): + orig_cwd = os.getcwd() + with current_directory(tmp_path): + assert os.getcwd() == str(tmp_path) + + assert os.getcwd() == orig_cwd diff --git a/truss/tests/server/control/test_model_container_patch_applier.py b/truss/tests/templates/control/control/helpers/test_model_container_patch_applier.py similarity index 70% rename from truss/tests/server/control/test_model_container_patch_applier.py rename to truss/tests/templates/control/control/helpers/test_model_container_patch_applier.py index 3660d1d81..83d151dfd 100644 --- a/truss/tests/server/control/test_model_container_patch_applier.py +++ b/truss/tests/templates/control/control/helpers/test_model_container_patch_applier.py @@ -1,11 +1,26 @@ import os +import sys +from pathlib import Path from unittest import mock import pytest -from truss.server.control.patch.model_container_patch_applier import ( +from truss.truss_config import TrussConfig + +# Needed to simulate the set up on the model docker container +sys.path.append( + str( + Path(__file__).parent.parent.parent.parent.parent.parent + / "templates" + / "control" + / "control" + ) +) + +# Have to use imports in this form, otherwise isinstance checks fail on helper classes +from helpers.truss_patch.model_container_patch_applier import ( # noqa ModelContainerPatchApplier, ) -from truss.server.control.patch.types import ( +from helpers.types import ( # noqa Action, ConfigPatch, EnvVarPatch, @@ -15,16 +30,15 @@ Patch, PatchType, ) -from truss.truss_config import TrussConfig @pytest.fixture -def patch_applier(tmp_truss_dir): - return ModelContainerPatchApplier(tmp_truss_dir, mock.Mock()) +def patch_applier(truss_container_fs): + return ModelContainerPatchApplier(truss_container_fs / "app", mock.Mock()) def test_patch_applier_model_code_patch_add( - patch_applier: ModelContainerPatchApplier, tmp_truss_dir + patch_applier: ModelContainerPatchApplier, truss_container_fs ): patch = Patch( type=PatchType.MODEL_CODE, @@ -35,11 +49,11 @@ def test_patch_applier_model_code_patch_add( ), ) patch_applier(patch, os.environ.copy()) - assert (tmp_truss_dir / "model" / "dummy").exists() + assert (truss_container_fs / "app" / "model" / "dummy").exists() def test_patch_applier_model_code_patch_remove( - patch_applier: ModelContainerPatchApplier, tmp_truss_dir + patch_applier: ModelContainerPatchApplier, truss_container_fs ): patch = Patch( type=PatchType.MODEL_CODE, @@ -48,13 +62,13 @@ def test_patch_applier_model_code_patch_remove( path="model.py", ), ) - assert (tmp_truss_dir / "model" / "model.py").exists() + assert (truss_container_fs / "app" / "model" / "model.py").exists() patch_applier(patch, os.environ.copy()) - assert not (tmp_truss_dir / "model" / "model.py").exists() + assert not (truss_container_fs / "app" / "model" / "model.py").exists() def test_patch_applier_model_code_patch_update( - patch_applier: ModelContainerPatchApplier, tmp_truss_dir + patch_applier: ModelContainerPatchApplier, truss_container_fs ): new_model_file_content = """ class Model: @@ -69,11 +83,13 @@ class Model: ), ) patch_applier(patch, os.environ.copy()) - assert (tmp_truss_dir / "model" / "model.py").read_text() == new_model_file_content + assert ( + truss_container_fs / "app" / "model" / "model.py" + ).read_text() == new_model_file_content def test_patch_applier_package_patch_add( - patch_applier: ModelContainerPatchApplier, tmp_truss_dir + patch_applier: ModelContainerPatchApplier, truss_container_fs ): patch = Patch( type=PatchType.PACKAGE, @@ -84,12 +100,12 @@ def test_patch_applier_package_patch_add( ), ) patch_applier(patch, os.environ.copy()) - assert (tmp_truss_dir / "packages" / "test_package" / "test.py").exists() + assert (truss_container_fs / "packages" / "test_package" / "test.py").exists() def test_patch_applier_package_patch_remove( patch_applier: ModelContainerPatchApplier, - tmp_truss_dir, + truss_container_fs, ): patch = Patch( type=PatchType.PACKAGE, @@ -98,14 +114,14 @@ def test_patch_applier_package_patch_remove( path="test_package/test.py", ), ) - assert (tmp_truss_dir / "packages" / "test_package" / "test.py").exists() + assert (truss_container_fs / "packages" / "test_package" / "test.py").exists() patch_applier(patch, os.environ.copy()) - assert not (tmp_truss_dir / "packages" / "test_package" / "test.py").exists() + assert not (truss_container_fs / "packages" / "test_package" / "test.py").exists() def test_patch_applier_package_patch_update( patch_applier: ModelContainerPatchApplier, - tmp_truss_dir, + truss_container_fs, ): new_package_content = """X = 2""" patch = Patch( @@ -118,12 +134,12 @@ def test_patch_applier_package_patch_update( ) patch_applier(patch, os.environ.copy()) assert ( - tmp_truss_dir / "packages" / "test_package" / "test.py" + truss_container_fs / "packages" / "test_package" / "test.py" ).read_text() == new_package_content def test_patch_applier_config_patch_update( - patch_applier: ModelContainerPatchApplier, tmp_truss_dir + patch_applier: ModelContainerPatchApplier, truss_container_fs ): new_config_dict = {"model_name": "foobar"} patch = Patch( @@ -134,7 +150,7 @@ def test_patch_applier_config_patch_update( ), ) patch_applier(patch, os.environ.copy()) - new_config = TrussConfig.from_yaml(tmp_truss_dir / "config.yaml") + new_config = TrussConfig.from_yaml(truss_container_fs / "app" / "config.yaml") assert new_config.model_name == "foobar" @@ -187,7 +203,7 @@ def test_patch_applier_env_var_patch_remove( def test_patch_applier_external_data_patch_add( patch_applier: ModelContainerPatchApplier, - tmp_truss_dir, + truss_container_fs, ): patch = Patch( type=PatchType.EXTERNAL_DATA, @@ -200,4 +216,4 @@ def test_patch_applier_external_data_patch_add( ), ) patch_applier(patch, os.environ.copy()) - assert (tmp_truss_dir / "data" / "truss_icon").exists() + assert (truss_container_fs / "app" / "data" / "truss_icon").exists() diff --git a/truss/tests/templates/control/control/helpers/test_requirement_name_identifier.py b/truss/tests/templates/control/control/helpers/test_requirement_name_identifier.py new file mode 100644 index 000000000..28ea00c78 --- /dev/null +++ b/truss/tests/templates/control/control/helpers/test_requirement_name_identifier.py @@ -0,0 +1,34 @@ +import pytest +from truss.templates.control.control.helpers.truss_patch.requirement_name_identifier import ( + identify_requirement_name, + reqs_by_name, +) + + +@pytest.mark.parametrize( + "req, expected_name", + [ + ("pytorch", "pytorch"), + ( + "git+https://github.com/huggingface/transformers.git", + "git+https://github.com/huggingface/transformers.git", + ), + ( + " git+https://github.com/huggingface/transformers.git ", + "git+https://github.com/huggingface/transformers.git", + ), + ("pytorch==1.0", "pytorch"), + ("pytorch>=1.0", "pytorch"), + ("pytorch<=1.0", "pytorch"), + ], +) +def test_identify_requirement_name(req, expected_name): + assert expected_name == identify_requirement_name(req) + + +def test_reqs_by_name(): + reqs = [ + "pytorch", + "jinja==1.0", + ] + assert reqs_by_name(reqs) == {"pytorch": "pytorch", "jinja": "jinja==1.0"} diff --git a/truss/tests/server/control/test_server.py b/truss/tests/templates/control/control/test_server.py similarity index 89% rename from truss/tests/server/control/test_server.py rename to truss/tests/templates/control/control/test_server.py index b0e764e4a..5343e3ced 100644 --- a/truss/tests/server/control/test_server.py +++ b/truss/tests/templates/control/control/test_server.py @@ -1,18 +1,36 @@ import os +import sys from contextlib import contextmanager +from pathlib import Path from typing import Dict, List import pytest from httpx import AsyncClient -from truss.server.control.application import create_app -from truss.server.control.patch.types import ( +from truss.types import PatchRequest + +# Needed to simulate the set up on the model docker container +sys.path.append( + str( + Path(__file__).parent.parent.parent.parent.parent + / "templates" + / "control" + / "control" + ) +) + +sys.path.append(str(Path(__file__).parent.parent.parent.parent.parent / "templates")) +sys.path.append( + str(Path(__file__).parent.parent.parent.parent.parent / "templates" / "shared") +) + +from truss.templates.control.control.application import create_app # noqa +from truss.templates.control.control.helpers.types import ( # noqa Action, ModelCodePatch, Patch, PatchType, PythonRequirementPatch, ) -from truss.types import PatchRequest @pytest.fixture @@ -21,17 +39,13 @@ def truss_original_hash(): @pytest.fixture -def app(tmp_truss_dir, truss_original_hash): +def app(truss_container_fs, truss_original_hash): with _env_var({"HASH_TRUSS": truss_original_hash}): - inf_serv_home = tmp_truss_dir + inf_serv_home = truss_container_fs / "app" control_app = create_app( { "inference_server_home": inf_serv_home, - "inference_server_process_args": [ - "python", - "-m", - "truss.server.inference_server", - ], + "inference_server_process_args": ["python", "inference_server.py"], "control_server_host": "*", "control_server_port": 8081, "inference_server_port": 8082, diff --git a/truss/tests/server/control/test_server_integration.py b/truss/tests/templates/control/control/test_server_integration.py similarity index 93% rename from truss/tests/server/control/test_server_integration.py rename to truss/tests/templates/control/control/test_server_integration.py index aaff787d0..2631c2dc5 100644 --- a/truss/tests/server/control/test_server_integration.py +++ b/truss/tests/templates/control/control/test_server_integration.py @@ -15,7 +15,6 @@ import psutil import pytest import requests -from truss.server.control.server import ControlServer PATCH_PING_MAX_DELAY_SECS = 3 @@ -28,8 +27,8 @@ class ControlServerDetails: @pytest.fixture -def control_server(tmp_truss_control_dir): - with _configured_control_server(tmp_truss_control_dir) as server: +def control_server(truss_control_container_fs): + with _configured_control_server(truss_control_container_fs) as server: yield server @@ -96,10 +95,10 @@ def inner(): @pytest.mark.integration -def test_truss_control_server_patch_ping_delays(tmp_truss_control_dir: Path): +def test_truss_control_server_patch_ping_delays(truss_control_container_fs: Path): for _ in range(10): with _configured_control_server( - tmp_truss_control_dir, + truss_control_container_fs, with_patch_ping_flow=True, ) as control_server: # Account for patch ping delays @@ -187,10 +186,16 @@ def start_truss_server(stdout_capture_file_path): "PATCH_PING_URL_TRUSS" ] = f"http://localhost:{patch_ping_server_port}" sys.stdout = open(stdout_capture_file_path, "w") + app_path = truss_control_container_fs / "app" + sys.path.append(str(app_path)) + control_path = truss_control_container_fs / "control" / "control" + sys.path.append(str(control_path)) + + from server import ControlServer control_server = ControlServer( python_executable_path=sys.executable, - inf_serv_home=str(truss_control_container_fs), + inf_serv_home=str(app_path), control_server_port=ctrl_port, inference_server_port=inf_port, ) diff --git a/truss/tests/server/core/server/common/test_truss_server.py b/truss/tests/templates/core/server/common/test_truss_server.py similarity index 82% rename from truss/tests/server/core/server/common/test_truss_server.py rename to truss/tests/templates/core/server/common/test_truss_server.py index e03a90aa5..2f746f405 100644 --- a/truss/tests/server/core/server/common/test_truss_server.py +++ b/truss/tests/templates/core/server/common/test_truss_server.py @@ -9,16 +9,20 @@ import pytest import yaml -from truss.server.common.truss_server import TrussServer @pytest.mark.integration -def test_truss_server_termination(tmp_truss_dir): +def test_truss_server_termination(truss_container_fs): port = 10123 def start_truss_server(stdout_capture_file_path): sys.stdout = open(stdout_capture_file_path, "w") - config = yaml.safe_load((tmp_truss_dir / "config.yaml").read_text()) + app_path = truss_container_fs / "app" + sys.path.append(str(app_path)) + + from common.truss_server import TrussServer + + config = yaml.safe_load((app_path / "config.yaml").read_text()) server = TrussServer(http_port=port, config=config) server.start() diff --git a/truss/tests/server/core/server/common/test_util.py b/truss/tests/templates/core/server/common/test_util.py similarity index 100% rename from truss/tests/server/core/server/common/test_util.py rename to truss/tests/templates/core/server/common/test_util.py diff --git a/truss/tests/server/core/server/test_secrets_resolver.py b/truss/tests/templates/core/server/test_secrets_resolver.py similarity index 95% rename from truss/tests/server/core/server/test_secrets_resolver.py rename to truss/tests/templates/core/server/test_secrets_resolver.py index 7ccb1d273..cad15bd08 100644 --- a/truss/tests/server/core/server/test_secrets_resolver.py +++ b/truss/tests/templates/core/server/test_secrets_resolver.py @@ -2,7 +2,7 @@ from contextlib import contextmanager from pathlib import Path -from truss.server.shared.secrets_resolver import SecretsResolver +from truss.templates.shared.secrets_resolver import SecretsResolver CONFIG = {"secrets": {"secret_key": "default_secret_value"}} diff --git a/truss/tests/server/common/test_retry.py b/truss/tests/templates/server/common/test_retry.py similarity index 93% rename from truss/tests/server/common/test_retry.py rename to truss/tests/templates/server/common/test_retry.py index c477299df..c13fbb4be 100644 --- a/truss/tests/server/common/test_retry.py +++ b/truss/tests/templates/server/common/test_retry.py @@ -1,8 +1,8 @@ -from typing import Any, Callable +from typing import Any from unittest.mock import Mock import pytest -from truss.server.common.retry import retry +from truss.templates.server.common.retry import retry class FailForCallCount: @@ -20,7 +20,7 @@ def call_count(self) -> int: return self._call_count -def fail_for_call_count(count: int) -> Callable: +def fail_for_call_count(count: int) -> callable: call_count = 0 def inner(): diff --git a/truss/tests/server/common/test_termination_handler_middleware.py b/truss/tests/templates/server/common/test_termination_handler_middleware.py similarity index 97% rename from truss/tests/server/common/test_termination_handler_middleware.py rename to truss/tests/templates/server/common/test_termination_handler_middleware.py index 91caa7b57..cd98edead 100644 --- a/truss/tests/server/common/test_termination_handler_middleware.py +++ b/truss/tests/templates/server/common/test_termination_handler_middleware.py @@ -5,7 +5,7 @@ from typing import Awaitable, Callable, List import pytest -from truss.server.common.termination_handler_middleware import ( +from truss.templates.server.common.termination_handler_middleware import ( TerminationHandlerMiddleware, ) diff --git a/truss/tests/server/test_model_wrapper.py b/truss/tests/templates/server/test_model_wrapper.py similarity index 96% rename from truss/tests/server/test_model_wrapper.py rename to truss/tests/templates/server/test_model_wrapper.py index 43e1fde60..d0a4692b1 100644 --- a/truss/tests/server/test_model_wrapper.py +++ b/truss/tests/templates/server/test_model_wrapper.py @@ -9,8 +9,8 @@ @pytest.fixture -def app_path(tmp_truss_dir: Path, helpers: Any): - truss_container_app_path = tmp_truss_dir +def app_path(truss_container_fs: Path, helpers: Any): + truss_container_app_path = truss_container_fs / "app" model_file_content = """ class Model: def __init__(self): diff --git a/truss/tests/server/test_schema.py b/truss/tests/templates/server/test_schema.py similarity index 99% rename from truss/tests/server/test_schema.py rename to truss/tests/templates/server/test_schema.py index 25cf6ae11..643d89125 100644 --- a/truss/tests/server/test_schema.py +++ b/truss/tests/templates/server/test_schema.py @@ -2,7 +2,7 @@ from typing import AsyncGenerator, Awaitable, Generator, Union from pydantic import BaseModel -from truss.server.common.schema import TrussSchema +from truss.templates.server.common.schema import TrussSchema class ModelInput(BaseModel): diff --git a/truss/tests/test_control_truss_patching.py b/truss/tests/test_control_truss_patching.py index 51796ccad..56f4ca5b3 100644 --- a/truss/tests/test_control_truss_patching.py +++ b/truss/tests/test_control_truss_patching.py @@ -1,6 +1,5 @@ from dataclasses import replace from pathlib import Path -from typing import Tuple import pytest from truss.constants import SUPPORTED_PYTHON_VERSIONS @@ -24,7 +23,7 @@ def current_num_docker_images(th: TrussHandle) -> int: @pytest.fixture def control_model_handle_tag_tuple( custom_model_control, -) -> Tuple[Path, TrussHandle, str]: +) -> tuple[Path, TrussHandle, str]: th = TrussHandle(custom_model_control) tag = "test-docker-custom-model-control-tag:0.0.1" return (custom_model_control, th, tag) diff --git a/truss/tests/test_truss_handle.py b/truss/tests/test_truss_handle.py index 4e22599e4..10700bb1e 100644 --- a/truss/tests/test_truss_handle.py +++ b/truss/tests/test_truss_handle.py @@ -10,7 +10,12 @@ from truss.docker import Docker, DockerStates from truss.errors import ContainerIsDownError, ContainerNotFoundError from truss.local.local_config_handler import LocalConfigHandler -from truss.server.control.patch.types import Action, ModelCodePatch, Patch, PatchType +from truss.templates.control.control.helpers.types import ( + Action, + ModelCodePatch, + Patch, + PatchType, +) from truss.tests.test_testing_utilities_for_other_tests import ( ensure_kill_all, kill_all_with_retries, diff --git a/truss/truss_handle.py b/truss/truss_handle.py index 3b4bea8ec..ddd149574 100644 --- a/truss/truss_handle.py +++ b/truss/truss_handle.py @@ -54,7 +54,7 @@ from truss.patch.signature import calc_truss_signature from truss.patch.types import TrussSignature from truss.readme_generator import generate_readme -from truss.server.shared.serialization import ( +from truss.templates.shared.serialization import ( truss_msgpack_deserialize, truss_msgpack_serialize, ) @@ -159,7 +159,6 @@ def docker_run( patch_ping_url: Optional[str] = None, wait_for_server_ready: bool = True, network: Optional[str] = None, - cache: bool = True, ): """ Builds a docker image and runs it as a container. For control trusses, @@ -186,7 +185,7 @@ def docker_run( container = container_if_patched else: image = self.build_serving_docker_image( - build_dir=build_dir, tag=tag, network=network, cache=cache + build_dir=build_dir, tag=tag, network=network ) secrets_mount_dir_path = _prepare_secrets_mount_dir() publish_ports = [[local_port, INFERENCE_SERVER_PORT]] @@ -855,7 +854,7 @@ def _build_image( network: Optional[str] = None, ): image = _docker_image_from_labels(labels=labels) - if cache and image is not None: + if image is not None: return image build_dir_path = Path(build_dir) if build_dir is not None else None diff --git a/truss/types.py b/truss/types.py index 30ee43c78..d51ddd16e 100644 --- a/truss/types.py +++ b/truss/types.py @@ -4,7 +4,7 @@ from pydantic import BaseModel from truss.patch.types import TrussSignature -from truss.server.control.patch.types import Patch +from truss.templates.control.control.helpers.types import Patch class ModelFrameworkType(Enum): diff --git a/truss/util/path.py b/truss/util/path.py index 47d4efa84..7e302cdb4 100644 --- a/truss/util/path.py +++ b/truss/util/path.py @@ -15,15 +15,10 @@ FIXED_TRUSS_IGNORE_PATH = Path(__file__).parent / ".truss_ignore" -def copy_tree_path( - src: Path, dest: Path, ignore_patterns: List[str] = [], ignore_files: bool = True -) -> None: +def copy_tree_path(src: Path, dest: Path, ignore_patterns: List[str] = []) -> None: """Copy a directory tree, ignoring files specified in .truss_ignore.""" - if ignore_files: - patterns = load_trussignore_patterns() - patterns.extend(ignore_patterns) - else: - patterns = [] + patterns = load_trussignore_patterns() + patterns.extend(ignore_patterns) if not dest.exists(): dest.mkdir(parents=True) @@ -45,13 +40,11 @@ def copy_file_path(src: Path, dest: Path) -> Tuple[str, str]: return copy_file(str(src), str(dest), verbose=False) -def copy_tree_or_file( - src: Path, dest: Path, ignore_files: bool = True -) -> Union[List[str], Tuple[str, str]]: +def copy_tree_or_file(src: Path, dest: Path) -> Union[List[str], Tuple[str, str]]: if src.is_file(): return copy_file_path(src, dest) - return copy_tree_path(src, dest, ignore_files=ignore_files) # type: ignore + return copy_tree_path(src, dest) # type: ignore def remove_tree_path(target: Path) -> None: