diff --git a/pyproject.toml b/pyproject.toml
index f5131fbf8..552bd66ff 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "truss"
-version = "0.9.1"
+version = "0.9.2"
 description = "A seamless bridge from model development to model delivery"
 license = "MIT"
 readme = "README.md"
diff --git a/truss/templates/server/common/schema.py b/truss/templates/server/common/schema.py
index 20fce4a09..1e7d3533e 100644
--- a/truss/templates/server/common/schema.py
+++ b/truss/templates/server/common/schema.py
@@ -1,4 +1,3 @@
-from inspect import Signature
 from types import MappingProxyType
 from typing import (
     Any,
@@ -68,14 +67,20 @@ def _parse_input_type(input_parameters: MappingProxyType) -> Optional[type]:
 
     input_type = parameter_types[0].annotation
 
-    if (
-        input_type == Signature.empty
-        or not isinstance(input_type, type)
-        or not issubclass(input_type, BaseModel)
-    ):
-        return None
+    if _annotation_is_pydantic_model(input_type):
+        return input_type
+
+    return None
+
 
-    return input_type
+def _annotation_is_pydantic_model(annotation: Any) -> bool:
+    # This try/except clause is a workaround for the fact that issubclass()
+    # does not work with generic types (e.g. list, dict)
+    # and raises a TypeError instead.
+    try:
+        return issubclass(annotation, BaseModel)
+    except TypeError:
+        return False
 
 
 def _parse_output_type(output_annotation: Any) -> Optional[OutputType]:
@@ -90,7 +95,7 @@ def _parse_output_type(output_annotation: Any) -> Optional[OutputType]:
 
     If the output_annotation does not match one of these cases, returns None
     """
-    if isinstance(output_annotation, type) and issubclass(output_annotation, BaseModel):
+    if _annotation_is_pydantic_model(output_annotation):
         return OutputType(type=output_annotation, supports_streaming=False)
 
     if _is_generator_type(output_annotation):
diff --git a/truss/templates/server/model_wrapper.py b/truss/templates/server/model_wrapper.py
index 740b9702a..422e03f6c 100644
--- a/truss/templates/server/model_wrapper.py
+++ b/truss/templates/server/model_wrapper.py
@@ -275,6 +275,11 @@ async def __call__(
             Generator: In case of streaming response
         """
+        # The streaming read timeout is the maximum time allowed between streamed chunks before a timeout is triggered.
+        streaming_read_timeout = self._config.get("runtime", {}).get(
+            "streaming_read_timeout", STREAMING_RESPONSE_QUEUE_READ_TIMEOUT_SECS
+        )
+
         if self.truss_schema is not None:
             try:
                 body = self.truss_schema.input_type(**body)
@@ -328,11 +333,12 @@ async def __call__(
             task.add_done_callback(lambda _: semaphore_release_function())
             task.add_done_callback(self._background_tasks.discard)
 
+            # The gap between responses in a stream must be < streaming_read_timeout
             async def _response_generator():
                 while True:
                     chunk = await asyncio.wait_for(
                         response_queue.get(),
-                        timeout=STREAMING_RESPONSE_QUEUE_READ_TIMEOUT_SECS,
+                        timeout=streaming_read_timeout,
                     )
                     if chunk is None:
                         return
diff --git a/truss/test_data/test_streaming_read_timeout/config.yaml b/truss/test_data/test_streaming_read_timeout/config.yaml
new file mode 100644
index 000000000..67af376a4
--- /dev/null
+++ b/truss/test_data/test_streaming_read_timeout/config.yaml
@@ -0,0 +1,12 @@
+model_metadata: {}
+model_name: null
+model_type: custom
+python_version: py39
+requirements: []
+resources:
+  accelerator: null
+  cpu: 500m
+  memory: 512Mi
+  use_gpu: false
+runtime:
+  streaming_read_timeout: 1
diff --git a/truss/test_data/test_streaming_read_timeout/model/model.py b/truss/test_data/test_streaming_read_timeout/model/model.py
new file mode 100644
index 000000000..df8e46f57
--- /dev/null
+++ b/truss/test_data/test_streaming_read_timeout/model/model.py
@@ -0,0 +1,24 @@
+import time
+from typing import Any, Dict, List
+
+
+class Model:
+    def __init__(self, **kwargs) -> None:
+        self._data_dir = kwargs["data_dir"]
+        self._config = kwargs["config"]
+        self._secrets = kwargs["secrets"]
+        self._model = None
+
+    def load(self):
+        # Load model here and assign to self._model.
+        pass
+
+    def predict(self, model_input: Any) -> Dict[str, List]:
+        # Invoke model on model_input and calculate predictions here.
+        def inner():
+            time.sleep(2)
+            for i in range(5):
+                time.sleep(3)
+                yield str(i)
+
+        return inner()
diff --git a/truss/tests/templates/server/test_model_wrapper.py b/truss/tests/templates/server/test_model_wrapper.py
index 79d851ee2..d0a4692b1 100644
--- a/truss/tests/templates/server/test_model_wrapper.py
+++ b/truss/tests/templates/server/test_model_wrapper.py
@@ -68,3 +68,20 @@ def test_model_wrapper_load_error_more_than_allowed(app_path, helpers):
         # Allow load thread to execute
         time.sleep(1)
         assert model_wrapper.load_failed()
+
+
+@pytest.mark.integration
+async def test_model_wrapper_streaming_timeout(app_path):
+    if "model_wrapper" in sys.modules:
+        model_wrapper_module = sys.modules["model_wrapper"]
+        importlib.reload(model_wrapper_module)
+    else:
+        model_wrapper_module = importlib.import_module("model_wrapper")
+    model_wrapper_class = getattr(model_wrapper_module, "ModelWrapper")
+
+    # Create an instance of ModelWrapper with streaming_read_timeout set to 5 seconds
+    config = yaml.safe_load((app_path / "config.yaml").read_text())
+    config["runtime"]["streaming_read_timeout"] = 5
+    model_wrapper = model_wrapper_class(config)
+    model_wrapper.load()
+    assert model_wrapper._config.get("runtime").get("streaming_read_timeout") == 5
diff --git a/truss/tests/templates/server/test_schema.py b/truss/tests/templates/server/test_schema.py
index 6335772aa..643d89125 100644
--- a/truss/tests/templates/server/test_schema.py
+++ b/truss/tests/templates/server/test_schema.py
@@ -14,6 +14,21 @@ class ModelOutput(BaseModel):
     output: str
 
 
+def test_truss_schema_pydantic_empty_annotations():
+    class Model:
+        def predict(self, request):
+            return "hello"
+
+    model = Model()
+
+    input_signature = inspect.signature(model.predict).parameters
+    output_signature = inspect.signature(model.predict).return_annotation
+
+    schema = TrussSchema.from_signature(input_signature, output_signature)
+
+    assert schema is None
+
+
 def test_truss_schema_pydantic_input_and_output():
     class Model:
         def predict(self, request: ModelInput) -> ModelOutput:
@@ -61,6 +76,36 @@ def predict(self, request: ModelInput) -> str:
     assert schema is None
 
 
+def test_truss_schema_list_types():
+    class Model:
+        def predict(self, request: list[str]) -> list[str]:
+            return ["foo", "bar"]
+
+    model = Model()
+
+    input_signature = inspect.signature(model.predict).parameters
+    output_signature = inspect.signature(model.predict).return_annotation
+
+    schema = TrussSchema.from_signature(input_signature, output_signature)
+
+    assert schema is None
+
+
+def test_truss_schema_dict_types():
+    class Model:
+        def predict(self, request: dict[str, str]) -> dict[str, str]:
+            return {"foo": "bar"}
+
+    model = Model()
+
+    input_signature = inspect.signature(model.predict).parameters
+    output_signature = inspect.signature(model.predict).return_annotation
+
+    schema = TrussSchema.from_signature(input_signature, output_signature)
+
+    assert schema is None
+
+
 def test_truss_schema_async():
     class Model:
        async def predict(self, request: ModelInput) -> Awaitable[ModelOutput]:
diff --git a/truss/tests/test_model_inference.py b/truss/tests/test_model_inference.py
index ce719fafb..a06f31c67 100644
--- a/truss/tests/test_model_inference.py
+++ b/truss/tests/test_model_inference.py
@@ -20,18 +20,22 @@
 
 logger = logging.getLogger(__name__)
 
+DEFAULT_LOG_ERROR = "Internal Server Error"
 
-def _log_contains_error(line: dict, error: str):
+
+def _log_contains_error(line: dict, error: str, message: str):
     return (
         line["levelname"] == "ERROR"
-        and line["message"] == "Internal Server Error"
+        and line["message"] == message
         and error in line["exc_info"]
     )
 
 
-def assert_logs_contain_error(logs: str, error: str):
+def assert_logs_contain_error(logs: str, error: str, message=DEFAULT_LOG_ERROR):
     loglines = logs.splitlines()
-    assert any(_log_contains_error(json.loads(line), error) for line in loglines)
+    assert any(
+        _log_contains_error(json.loads(line), error, message) for line in loglines
+    )
 
 
 class PropagatingThread(Thread):
@@ -232,6 +236,36 @@ def test_async_streaming():
         assert predict_non_stream_response.json() == "01234"
 
 
+@pytest.mark.integration
+def test_async_streaming_timeout():
+    with ensure_kill_all():
+        truss_root = Path(__file__).parent.parent.parent.resolve() / "truss"
+
+        truss_dir = truss_root / "test_data" / "test_streaming_read_timeout"
+
+        tr = TrussHandle(truss_dir)
+
+        container = tr.docker_run(
+            local_port=8090, detach=True, wait_for_server_ready=True
+        )
+        truss_server_addr = "http://localhost:8090"
+        predict_url = f"{truss_server_addr}/v1/models/model:predict"
+
+        # ChunkedEncodingError is raised when the chunk does not get processed due to streaming read timeout
+        with pytest.raises(requests.exceptions.ChunkedEncodingError):
+            response = requests.post(predict_url, json={}, stream=True)
+
+            for chunk in response.iter_content():
+                pass
+
+        # Check to ensure the Timeout error is in the container logs
+        assert_logs_contain_error(
+            container.logs(),
+            error="raise exceptions.TimeoutError()",
+            message="Exception in ASGI application\n",
+        )
+
+
 @pytest.mark.integration
 def test_streaming_with_error():
     with ensure_kill_all():
diff --git a/truss/tests/test_model_schema.py b/truss/tests/test_model_schema.py
index bcc5a6bc4..697fb980e 100644
--- a/truss/tests/test_model_schema.py
+++ b/truss/tests/test_model_schema.py
@@ -33,7 +33,32 @@ def test_truss_with_no_annotations():
         schema_response = requests.get(SCHEMA_URL)
 
         assert schema_response.status_code == 404
-        schema_response.json()["error"] == "No schema found"
+        assert schema_response.json()["error"] == "No schema found"
+
+
+@pytest.mark.integration
+def test_truss_with_non_pydantic_annotations():
+    truss_non_pydantic_annotations = """
+class Model:
+    def predict(self, request: str) -> list[str]:
+        return ["hello"]
+"""
+
+    with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir:
+        truss_dir = Path(tmp_work_dir, "truss")
+
+        create_truss(truss_dir, DEFAULT_CONFIG, truss_non_pydantic_annotations)
+
+        tr = TrussHandle(truss_dir)
+        _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
+
+        response = requests.post(INFERENCE_URL, json={"prompt": "value"})
+        assert response.json() == ["hello"]
+
+        schema_response = requests.get(SCHEMA_URL)
+        assert schema_response.status_code == 404
+
+        assert schema_response.json()["error"] == "No schema found"
 
 
 @pytest.mark.integration
diff --git a/truss/truss_config.py b/truss/truss_config.py
index d316a3bb4..1748b3bfd 100644
--- a/truss/truss_config.py
+++ b/truss/truss_config.py
@@ -31,6 +31,7 @@
 DEFAULT_SPEC_VERSION = "2.0"
 DEFAULT_PREDICT_CONCURRENCY = 1
 DEFAULT_NUM_WORKERS = 1
+DEFAULT_STREAMING_RESPONSE_READ_TIMEOUT = 60
 
 DEFAULT_CPU = "1"
 DEFAULT_MEMORY = "2Gi"
@@ -140,21 +141,27 @@ def to_list(self, verbose=False) -> List[Dict[str, str]]:
 class Runtime:
     predict_concurrency: int = DEFAULT_PREDICT_CONCURRENCY
     num_workers: int = DEFAULT_NUM_WORKERS
+    streaming_read_timeout: int = DEFAULT_STREAMING_RESPONSE_READ_TIMEOUT
 
     @staticmethod
     def from_dict(d):
         predict_concurrency = d.get("predict_concurrency", DEFAULT_PREDICT_CONCURRENCY)
         num_workers = d.get("num_workers", DEFAULT_NUM_WORKERS)
+        streaming_read_timeout = d.get(
+            "streaming_read_timeout", DEFAULT_STREAMING_RESPONSE_READ_TIMEOUT
+        )
 
         return Runtime(
             predict_concurrency=predict_concurrency,
             num_workers=num_workers,
+            streaming_read_timeout=streaming_read_timeout,
         )
 
     def to_dict(self):
         return {
             "predict_concurrency": self.predict_concurrency,
             "num_workers": self.num_workers,
+            "streaming_read_timeout": self.streaming_read_timeout,
         }
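
Note (reviewer sketch, not part of the patch): the behaviour the new streaming_read_timeout runtime option controls is the per-chunk wait in _response_generator above. If the model's generator pauses for longer than the configured value between yields, asyncio.wait_for raises asyncio.TimeoutError and the stream is cut off. A minimal, self-contained illustration of that mechanism, using only the standard library (all names below are illustrative):

import asyncio


async def slow_producer(queue: asyncio.Queue) -> None:
    # Emits one chunk, pauses longer than the read timeout, then finishes.
    await queue.put("chunk-0")
    await asyncio.sleep(2)  # gap between chunks
    await queue.put("chunk-1")
    await queue.put(None)  # sentinel: end of stream


async def consume(streaming_read_timeout: float) -> None:
    queue: asyncio.Queue = asyncio.Queue()
    producer = asyncio.create_task(slow_producer(queue))
    try:
        while True:
            # Mirrors the patched _response_generator: every read from the
            # queue must complete within streaming_read_timeout seconds.
            chunk = await asyncio.wait_for(queue.get(), timeout=streaming_read_timeout)
            if chunk is None:
                return
            print("got", chunk)
    except asyncio.TimeoutError:
        print(f"gap between chunks exceeded {streaming_read_timeout}s")
    finally:
        producer.cancel()


# A 1-second timeout trips on the 2-second gap; a 5-second timeout would not.
asyncio.run(consume(streaming_read_timeout=1))

In the integration test added above, the test truss sets streaming_read_timeout: 1 while its model sleeps 3 seconds per chunk, so the same timeout surfaces to the HTTP client as requests.exceptions.ChunkedEncodingError.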
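
Note (reviewer sketch, not part of the patch): per the truss_config.py hunk, the new key defaults to DEFAULT_STREAMING_RESPONSE_READ_TIMEOUT (60 seconds) when absent from the runtime section and round-trips through Runtime.from_dict / Runtime.to_dict. Assuming Runtime is importable from truss.truss_config as in the hunk above, the expected behaviour is roughly:

from truss.truss_config import Runtime

# Key absent from the runtime section: falls back to the 60-second default.
assert Runtime.from_dict({}).streaming_read_timeout == 60

# Key present: carried through to the serialized config.
runtime = Runtime.from_dict({"streaming_read_timeout": 5})
assert runtime.to_dict()["streaming_read_timeout"] == 5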