diff --git a/truss/templates/server/model_wrapper.py b/truss/templates/server/model_wrapper.py index cff58f2ca..579cc94e2 100644 --- a/truss/templates/server/model_wrapper.py +++ b/truss/templates/server/model_wrapper.py @@ -12,6 +12,7 @@ import time import weakref from contextlib import asynccontextmanager +from datetime import datetime, timezone from enum import Enum from functools import cached_property from multiprocessing import Lock @@ -122,6 +123,7 @@ def _is_request_type(obj: Any) -> bool: class ArgConfig(enum.Enum): + NONE = enum.auto() INPUTS_ONLY = enum.auto() REQUEST_ONLY = enum.auto() INPUTS_AND_REQUEST = enum.auto() @@ -134,11 +136,12 @@ def from_signature( ) -> "ArgConfig": parameters = list(signature.parameters.values()) - if len(parameters) == 1: + if len(parameters) == 0: + return cls.NONE + elif len(parameters) == 1: if _is_request_type(parameters[0].annotation): return cls.REQUEST_ONLY return cls.INPUTS_ONLY - elif len(parameters) == 2: # First arg can be whatever, except request. Second arg must be request. 
param1, param2 = parameters @@ -204,6 +207,7 @@ class ModelDescriptor: postprocess: Optional[MethodDescriptor] truss_schema: Optional[TrussSchema] setup_environment: Optional[MethodDescriptor] + is_ready: Optional[MethodDescriptor] @cached_property def skip_input_parsing(self) -> bool: @@ -263,12 +267,18 @@ def from_model(cls, model) -> "ModelDescriptor": else: setup_environment = None + if hasattr(model, "is_ready"): + is_ready = MethodDescriptor.from_method(model.is_ready, "is_ready") + else: + is_ready = None + return cls( preprocess=preprocess, predict=predict, postprocess=postprocess, truss_schema=TrussSchema.from_signature(parameters, return_annotation), setup_environment=setup_environment, + is_ready=is_ready, ) @@ -282,6 +292,7 @@ class ModelWrapper: _predict_semaphore: Semaphore _poll_for_environment_updates_task: Optional[asyncio.Task] _environment: Optional[dict] + _first_health_check_failure: Optional[datetime] class Status(Enum): NOT_READY = 0 @@ -311,6 +322,7 @@ def __init__(self, config: Dict, tracer: sdk_trace.Tracer): ) self._poll_for_environment_updates_task = None self._environment = None + self._first_health_check_failure = None @property def _model(self) -> Any: @@ -528,6 +540,40 @@ async def poll_for_environment_updates(self) -> None: exc_info=errors.filter_traceback(self._model_file_name), ) + async def is_ready(self) -> Optional[bool]: + descriptor = self.model_descriptor.is_ready + is_ready: Optional[bool] = None + if not descriptor or self.load_failed: + return is_ready + try: + if descriptor.is_async: + is_ready = await self._model.is_ready() + else: + # Offload sync functions to thread, to not block event loop. 
+ is_ready = await to_thread.run_sync(self._model.is_ready) + except Exception as e: + is_ready = False + self._logger.exception( + "Exception while checking if model is ready: " + str(e), + exc_info=errors.filter_traceback(self._model_file_name), + ) + if not is_ready: + if self._first_health_check_failure is None: + self._first_health_check_failure = datetime.now(timezone.utc) + self._logger.warning("Model is not ready. Health checks failing.") + else: + seconds_since_first_failure = round( + ( + datetime.now(timezone.utc) - self._first_health_check_failure + ).total_seconds() + ) + self._logger.warning( + f"Model is not ready. Health checks failing for {seconds_since_first_failure} seconds." + ) + elif is_ready: + self._first_health_check_failure = None + return is_ready + async def preprocess( self, inputs: InputType, diff --git a/truss/templates/server/truss_server.py b/truss/templates/server/truss_server.py index 8e30a884b..9f34b56ea 100644 --- a/truss/templates/server/truss_server.py +++ b/truss/templates/server/truss_server.py @@ -79,7 +79,12 @@ def check_healthy(model: ModelWrapper): raise errors.ModelNotReady(model.name) async def model_ready(self, model_name: str) -> Dict[str, Union[str, bool]]: - self.check_healthy(self._safe_lookup_model(model_name)) + model: ModelWrapper = self._safe_lookup_model(model_name) + is_ready = await model.is_ready() + if is_ready is None: + self.check_healthy(model) + elif not is_ready: + raise errors.ModelNotReady(model.name) return {} @@ -152,7 +157,7 @@ async def predict( model: ModelWrapper = self._safe_lookup_model(model_name) - self.check_healthy(model) + self.check_healthy(model) # NOTE(review): keep — predict must fail fast while the model is loading; is_ready() is only consulted by model_ready(). trace_ctx = otel_propagate.extract(request.headers) or None # This is the top-level span in the truss-server, so we set the context here. # Nested spans "inherit" context automatically. 
diff --git a/truss/tests/test_model_inference.py b/truss/tests/test_model_inference.py index 1ee0b3015..744c6bdf8 100644 --- a/truss/tests/test_model_inference.py +++ b/truss/tests/test_model_inference.py @@ -969,6 +969,108 @@ def predict(self, model_input): ) +@pytest.mark.integration +def test_is_ready(): + model = """ + import time + class Model: + def load(self) -> bool: + raise Exception("not loaded") + + def is_ready(self) -> bool: + return True + + def predict(self, model_input): + return model_input + """ + with ensure_kill_all(), _temp_truss(model, "") as tr: + container = tr.docker_run( + local_port=8090, detach=True, wait_for_server_ready=False + ) + + truss_server_addr = "http://localhost:8090" + for _ in range(5): + time.sleep(1) + ready = requests.get(f"{truss_server_addr}/v1/models/model") + if ready.status_code == 503: + break + assert ready.status_code == 200 + assert ready.status_code == 503 + diff = container.diff() + assert "/root/inference_server_crashed.txt" in diff + assert diff["/root/inference_server_crashed.txt"] == "A" + + model = """ + class Model: + def is_ready(self) -> bool: + raise Exception("not ready") + + def predict(self, model_input): + return model_input + """ + with ensure_kill_all(), _temp_truss(model, "") as tr: + container = tr.docker_run( + local_port=8090, detach=True, wait_for_server_ready=False + ) + + # Sleep a few seconds to get the server some time to wake up + time.sleep(10) + + truss_server_addr = "http://localhost:8090" + + ready = requests.get(f"{truss_server_addr}/v1/models/model") + assert ready.status_code == 503 + assert ( + "Exception while checking if model is ready: not ready" in container.logs() + ) + assert "Model is not ready. Health checks failing." 
in container.logs() + + model = """ + class Model: + def is_ready(self) -> bool: + return False + + def predict(self, model_input): + return model_input + """ + with ensure_kill_all(), _temp_truss(model, "") as tr: + container = tr.docker_run( + local_port=8090, detach=True, wait_for_server_ready=False + ) + + # Sleep a few seconds to get the server some time to wake up + time.sleep(10) + + truss_server_addr = "http://localhost:8090" + + ready = requests.get(f"{truss_server_addr}/v1/models/model") + assert ready.status_code == 503 + assert "Model is not ready. Health checks failing." in container.logs() + time.sleep(5) + ready = requests.get(f"{truss_server_addr}/v1/models/model") + assert ready.status_code == 503 + assert ( + "Model is not ready. Health checks failing for 5 seconds." + in container.logs() + ) + + model = """ + class Model: + def is_ready(self) -> bool: + return True + + def predict(self, model_input): + return model_input + """ + with ensure_kill_all(), _temp_truss(model, "") as tr: + _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True) + + truss_server_addr = "http://localhost:8090" + + ready = requests.get(f"{truss_server_addr}/v1/models/model") + assert ready.status_code == 200 + + def _patch_termination_timeout(container: Container, seconds: int, truss_container_fs): app_path = truss_container_fs / "app" sys.path.append(str(app_path))