Skip to content

Commit

Permalink
[TaT 2] Extract rich from non-cli lib parts. Use DI instead. (#1247)
Browse files Browse the repository at this point in the history
  • Loading branch information
marius-baseten authored Nov 18, 2024
1 parent 8aebf26 commit 79e1841
Show file tree
Hide file tree
Showing 16 changed files with 367 additions and 314 deletions.
390 changes: 195 additions & 195 deletions poetry.lock

Large diffs are not rendered by default.

7 changes: 3 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ packages = [
]

[tool.poetry.scripts]
truss = "truss.cli:truss_cli"
truss = "truss.cli.cli:truss_cli"

[tool.poetry.urls]
"Homepage" = "https://truss.baseten.co"
Expand Down Expand Up @@ -70,7 +70,7 @@ truss = "truss.cli:truss_cli"
# "When using chains, 3.9 will be required at runtime, but other truss functionality works with 3.8.
python = ">=3.8,<3.13"
huggingface_hub = ">=0.25.0"
pydantic = ">=1.10.0" # We cannot upgrade to v2, due to customer needs.
pydantic = ">=1.10.0" # We cannot upgrade to v2, due to customer constraints.
PyYAML = ">=6.0"
single-source = "^0.3.0"
# "non-base" dependencies.
Expand All @@ -84,7 +84,6 @@ aiohttp = { version = "^3.10.10", optional = false }
blake3 = { version = "^0.3.3", optional = false }
boto3 = { version = "^1.34.85", optional = false }
click = { version = "^8.0.3", optional = false }
fastapi = { version = ">=0.109.1", optional = false }
google-cloud-storage = { version = "2.10.0", optional = false }
httpx = { version = ">=0.24.1", optional = false }
inquirerpy = { version = "^0.3.4", optional = false }
Expand All @@ -111,7 +110,6 @@ aiohttp = { components = "other" }
blake3 = { components = "other" }
boto3 = { components = "other" }
click = { components = "other" }
fastapi = { components = "other" }
google-cloud-storage = { components = "other" }
httpx = { components = "other" }
inquirerpy = { components = "other" }
Expand Down Expand Up @@ -154,6 +152,7 @@ types-setuptools = "^69.0.0.0"
# These packages are needed to run local tests of server components. Note that the actual
# server deps for building the docker image are (so far) defined in `requirements.txt`-files.
dockerfile = "^3.2.0"
fastapi =">=0.109.1"
flask = "^2.3.3"
msgpack = ">=1.0.2"
msgpack-numpy = ">=0.4.8"
Expand Down
6 changes: 4 additions & 2 deletions truss-chains/truss_chains/code_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,9 +631,11 @@ def gen_truss_chainlet(
name=dep.name,
options=dep.options,
)

chainlet_dir = _make_chainlet_dir(chain_name, chainlet_descriptor, gen_root)

logging.info(
f"Code generation for Chainlet `{chainlet_descriptor.name}` "
f"in `{chainlet_dir}`."
)
_make_truss_config(
chainlet_dir,
chainlet_descriptor.chainlet_cls.remote_config,
Expand Down
8 changes: 2 additions & 6 deletions truss-chains/truss_chains/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from truss.base import truss_config
from truss.base.constants import PRODUCTION_ENVIRONMENT_NAME
from truss.remote import baseten as baseten_remote
from truss.remote import remote_cli, remote_factory
from truss.remote import remote_factory

BASETEN_API_SECRET_NAME = "baseten_chain_api_key"
SECRET_DUMMY = "***"
Expand Down Expand Up @@ -635,13 +635,9 @@ def create(
publish: bool,
promote: Optional[bool],
only_generate_trusses: bool,
remote: Optional[str] = None,
remote: str,
environment: Optional[str] = None,
) -> "PushOptionsBaseten":
if not remote:
remote = remote_cli.inquire_remote_name(
remote_factory.RemoteFactory.get_available_config_names()
)
if promote and not environment:
environment = PRODUCTION_ENVIRONMENT_NAME
if environment:
Expand Down
11 changes: 8 additions & 3 deletions truss-chains/truss_chains/public_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import functools
import pathlib
from typing import ContextManager, Mapping, Optional, Type, Union
from typing import TYPE_CHECKING, ContextManager, Mapping, Optional, Type, Union

if TYPE_CHECKING:
from rich import progress

from truss_chains import definitions, framework
from truss_chains import remote as chains_remote
Expand Down Expand Up @@ -122,8 +125,9 @@ def push(
publish: bool = True,
promote: bool = True,
only_generate_trusses: bool = False,
remote: Optional[str] = None,
remote: str = "baseten",
environment: Optional[str] = None,
progress_bar: Optional[Type["progress.Progress"]] = None,
) -> chains_remote.BasetenChainService:
"""
Deploys a chain remotely (with all dependent chainlets).
Expand All @@ -141,6 +145,7 @@ def push(
remote: name of a remote config in `.trussrc`. If not provided, it will be
inquired.
environment: The name of an environment to promote deployment into.
progress_bar: Optional `rich.progress.Progress` if output is desired.
Returns:
A chain service handle to the deployed chain.
Expand All @@ -154,7 +159,7 @@ def push(
remote=remote,
environment=environment,
)
service = chains_remote.push(entrypoint, options)
service = chains_remote.push(entrypoint, options, progress_bar=progress_bar)
assert isinstance(service, chains_remote.BasetenChainService) # Per options above.
return service

Expand Down
46 changes: 21 additions & 25 deletions truss-chains/truss_chains/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@

if TYPE_CHECKING:
from rich import console as rich_console
from rich import progress
from truss.local import local_config_handler
from truss.remote import remote_cli, remote_factory
from truss.remote import remote_factory
from truss.remote.baseten import core as b10_core
from truss.remote.baseten import custom_types as b10_types
from truss.remote.baseten import remote as b10_remote
Expand Down Expand Up @@ -86,14 +87,10 @@ def _push_service_docker(
options: definitions.PushOptionsLocalDocker,
port: int,
) -> None:
logging.info(f"Running in docker container `{chainlet_display_name}` ")

truss_handle = truss_build.load(str(truss_dir))

truss_handle.add_secret(
definitions.BASETEN_API_SECRET_NAME, options.baseten_chain_api_key
)

truss_handle.docker_run(
local_port=port,
detach=True,
Expand Down Expand Up @@ -274,20 +271,22 @@ def _create_baseten_chain(
baseten_options: definitions.PushOptionsBaseten,
entrypoint_artifact: b10_types.ChainletArtifact,
dependency_artifacts: list[b10_types.ChainletArtifact],
progress_bar: Optional[Type["progress.Progress"]],
):
logging.info(
f"Pushing Chain '{baseten_options.chain_name}' to Baseten "
f"(publish={baseten_options.publish}, environment={baseten_options.environment})."
)
chain_deployment_handle, entrypoint_service = (
baseten_options.remote_provider.push_chain_atomic(
chain_name=baseten_options.chain_name,
entrypoint_artifact=entrypoint_artifact,
dependency_artifacts=dependency_artifacts,
publish=baseten_options.publish,
environment=baseten_options.environment,
progress_bar=progress_bar,
)
)

logging.info(f"Pushed Chain '{baseten_options.chain_name}'.")
logging.debug(f"Internal model endpoint: '{entrypoint_service.predict_url}'.")

return BasetenChainService(
baseten_options.chain_name,
entrypoint_service,
Expand Down Expand Up @@ -315,7 +314,7 @@ class _ChainSourceGenerator:
def __init__(
self,
options: definitions.PushOptions,
gen_root: Optional[pathlib.Path] = None,
gen_root: pathlib.Path,
) -> None:
self._options = options
self._gen_root = gen_root or pathlib.Path(tempfile.gettempdir())
Expand All @@ -335,11 +334,6 @@ def generate_chainlet_artifacts(
# we add a random suffix.
model_suffix = str(uuid.uuid4()).split("-")[0]
model_name = f"{model_base_name}-{model_suffix}"

logging.info(
f"Generating Truss Chainlet model for '{chainlet_descriptor.name}'."
)

chainlet_dir = code_gen.gen_truss_chainlet(
chain_root,
self._gen_root,
Expand Down Expand Up @@ -373,6 +367,7 @@ def push(
options: definitions.PushOptions,
non_entrypoint_root_dir: Optional[str] = None,
gen_root: pathlib.Path = pathlib.Path(tempfile.gettempdir()),
progress_bar: Optional[Type["progress.Progress"]] = None,
) -> Optional[ChainService]:
entrypoint_artifact, dependency_artifacts = _ChainSourceGenerator(
options, gen_root
Expand All @@ -386,7 +381,9 @@ def push(

if isinstance(options, definitions.PushOptionsBaseten):
_create_chains_secret_if_missing(options.remote_provider)
return _create_baseten_chain(options, entrypoint_artifact, dependency_artifacts)
return _create_baseten_chain(
options, entrypoint_artifact, dependency_artifacts, progress_bar
)
elif isinstance(options, definitions.PushOptionsLocalDocker):
chainlet_artifacts = [entrypoint_artifact, *dependency_artifacts]
chainlet_to_predict_url: Dict[str, Dict[str, str]] = {}
Expand Down Expand Up @@ -418,15 +415,18 @@ def push(
# paths for each container under the `/tmp` dir.
for chainlet_artifact in chainlet_artifacts:
truss_dir = chainlet_artifact.truss_dir

logging.info(
f"Building Chainlet `{chainlet_artifact.display_name}` docker image."
)
_push_service_docker(
truss_dir,
chainlet_artifact.display_name,
options,
chainlet_to_service[chainlet_artifact.name].port,
)

logging.info(f"Pushed `{chainlet_artifact.display_name}`")
logging.info(
f"Pushed Chainlet `{chainlet_artifact.display_name}` as docker container."
)
logging.debug(
f"Internal model endpoint: `{chainlet_to_predict_url[chainlet_artifact.name]}`"
)
Expand Down Expand Up @@ -457,7 +457,7 @@ def __init__(
source: pathlib.Path,
entrypoint: Optional[str],
name: Optional[str],
remote: Optional[str],
remote: str,
console: "rich_console.Console",
error_console: "rich_console.Console",
show_stack_trace: bool,
Expand All @@ -467,10 +467,6 @@ def __init__(
self._console = console
self._error_console = error_console
self._show_stack_trace = show_stack_trace
if not remote:
remote = remote_cli.inquire_remote_name(
remote_factory.RemoteFactory.get_available_config_names()
)
self._remote_provider = cast(
b10_remote.BasetenRemote,
remote_factory.RemoteFactory.create(remote=remote),
Expand Down Expand Up @@ -682,7 +678,7 @@ def watch(
source: pathlib.Path,
entrypoint: Optional[str],
name: Optional[str],
remote: Optional[str],
remote: str,
console: "rich_console.Console",
error_console: "rich_console.Console",
show_stack_trace: bool,
Expand Down
9 changes: 8 additions & 1 deletion truss/api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from typing import Optional, cast
from typing import TYPE_CHECKING, Optional, Type, cast

if TYPE_CHECKING:
from rich import progress

from truss.api import definitions
from truss.remote.baseten.service import BasetenService
Expand Down Expand Up @@ -57,6 +60,7 @@ def push(
trusted: bool = False,
deployment_name: Optional[str] = None,
environment: Optional[str] = None,
progress_bar: Optional[Type["progress.Progress"]] = None,
) -> definitions.ModelDeployment:
"""
Pushes a Truss to Baseten.
Expand All @@ -76,6 +80,8 @@ def push(
deployment_name: Name of the deployment created by the push. Can only be
used in combination with `publish` or `promote`. Deployment name must
only contain alphanumeric, ’.’, ’-’ or ’_’ characters.
environment: Name of stable environment on baseten.
progress_bar: Optional `rich.progress.Progress` if output is desired.
Returns:
The newly created ModelDeployment.
Expand Down Expand Up @@ -111,6 +117,7 @@ def push(
preserve_previous_prod_deployment=preserve_previous_production_deployment,
deployment_name=deployment_name,
environment=environment,
progress_bar=progress_bar,
) # type: ignore

return definitions.ModelDeployment(cast(BasetenService, service))
3 changes: 0 additions & 3 deletions truss/cli/__init__.py

This file was deleted.

Loading

0 comments on commit 79e1841

Please sign in to comment.