diff --git a/pyproject.toml b/pyproject.toml
index 2515bee1d..9901ebef2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "truss"
-version = "0.7.12"
+version = "0.7.13rc1"
 description = "A seamless bridge from model development to model delivery"
 license = "MIT"
 readme = "README.md"
diff --git a/truss/templates/server/common/truss_server.py b/truss/templates/server/common/truss_server.py
index 62584f428..7bae4b55f 100644
--- a/truss/templates/server/common/truss_server.py
+++ b/truss/templates/server/common/truss_server.py
@@ -142,7 +142,7 @@ async def predict(
 
         # In the case that the model returns a Generator object, return a
         # StreamingResponse instead.
-        if isinstance(response, AsyncGenerator):
+        if isinstance(response, (AsyncGenerator, Generator)):
             # media_type in StreamingResponse sets the Content-Type header
             return StreamingResponse(response, media_type="application/octet-stream")
 
diff --git a/truss/templates/server/model_wrapper.py b/truss/templates/server/model_wrapper.py
index 89de7e8b8..f6132d90f 100644
--- a/truss/templates/server/model_wrapper.py
+++ b/truss/templates/server/model_wrapper.py
@@ -225,7 +225,7 @@ async def postprocess(
         if inspect.isasyncgenfunction(
             self._model.postprocess
         ) or inspect.isgeneratorfunction(self._model.postprocess):
-            return self._model.postprocess(response, headers)
+            return self._model.postprocess(response)
 
         if inspect.iscoroutinefunction(self._model.postprocess):
             return await _intercept_exceptions_async(self._model.postprocess)(response)
@@ -264,10 +264,16 @@ async def __call__(
         async with deferred_semaphore(self._predict_semaphore) as semaphore_manager:
             response = await self.predict(payload, headers)
 
-            processed_response = await self.postprocess(response)
-
             # Streaming cases
             if inspect.isgenerator(response) or inspect.isasyncgen(response):
+                if hasattr(self._model, "postprocess"):
+                    logging.warning(
+                        "Predict returned a streaming response, while a postprocess is defined. "
+                        "Note that in this case, the postprocess will run within the predict lock."
+                    )
+
+                response = await self.postprocess(response)
+
                 async_generator = _force_async_generator(response)
 
                 if headers and headers.get("accept") == "application/json":
@@ -309,7 +315,8 @@ async def _response_generator():
 
             return _response_generator()
 
-        return processed_response
+        processed_response = await self.postprocess(response)
+        return processed_response
 
 
 class ResponseChunk:
diff --git a/truss/tests/test_model_inference.py b/truss/tests/test_model_inference.py
index 1375a4e2f..454472ec0 100644
--- a/truss/tests/test_model_inference.py
+++ b/truss/tests/test_model_inference.py
@@ -397,6 +397,164 @@ def predict(self, request):
         assert "Internal Server Error" in response.json()["error"]
 
 
+@pytest.mark.integration
+def test_postprocess_with_streaming_predict():
+    """
+    Test a Truss that has a streaming response from both predict and postprocess.
+    In this case, the postprocess step continues to happen within the predict lock,
+    so we don't bother testing the lock scenario, just the behavior that the postprocess
+    function is applied.
+ """ + model = """ + import time + + class Model: + def postprocess(self, response): + for item in response: + time.sleep(1) + yield item + " modified" + + def predict(self, request): + for i in range(2): + yield str(i) + """ + + config = "model_name: error-truss" + with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir: + truss_dir = Path(tmp_work_dir, "truss") + + _create_truss(truss_dir, config, textwrap.dedent(model)) + + tr = TrussHandle(truss_dir) + _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True) + truss_server_addr = "http://localhost:8090" + full_url = f"{truss_server_addr}/v1/models/model:predict" + response = requests.post(full_url, json={}, stream=True) + # Note that the postprocess function is applied to the + # streamed response. + assert response.content == b"0 modified1 modified" + + +@pytest.mark.integration +def test_streaming_postprocess(): + """ + Tests a Truss where predict returns non-streaming, but postprocess is streamd, and + ensures that the postprocess step does not happen within the predict lock. To do this, + we sleep for two seconds during the postprocess streaming process, and fire off two + requests with a total timeout of 3 seconds, ensuring that if they were serialized + the test would fail. + """ + model = """ + import time + + class Model: + def postprocess(self, response): + for item in response: + time.sleep(1) + yield item + " modified" + + def predict(self, request): + return ["0", "1"] + """ + + config = "model_name: error-truss" + with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir: + truss_dir = Path(tmp_work_dir, "truss") + + _create_truss(truss_dir, config, textwrap.dedent(model)) + + tr = TrussHandle(truss_dir) + _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True) + truss_server_addr = "http://localhost:8090" + full_url = f"{truss_server_addr}/v1/models/model:predict" + + def make_request(delay: int): + # For streamed responses, requests does not start receiving content from server until + # `iter_content` is called, so we must call this in order to get an actual timeout. + time.sleep(delay) + response = requests.post(full_url, json={}, stream=True) + + assert response.status_code == 200 + assert response.content == b"0 modified1 modified" + + with ThreadPoolExecutor() as e: + # We use concurrent.futures.wait instead of the timeout property + # on requests, since requests timeout property has a complex interaction + # with streaming. + first_request = e.submit(make_request, 0) + second_request = e.submit(make_request, 0.2) + futures = [first_request, second_request] + done, _ = concurrent.futures.wait(futures, timeout=3) + # Ensure that both requests complete within the 3 second timeout, + # as the predict lock is not held through the postprocess step + assert first_request in done + assert second_request in done + + for future in done: + # Ensure that both futures completed without error + future.result() + + +@pytest.mark.integration +def test_postprocess(): + """ + Tests a Truss that has a postprocess step defined, and ensures that the + postprocess does not happen within the predict lock. To do this, we sleep + for two seconds during the postprocess, and fire off two requests with a total + timeout of 3 seconds, ensureing that if they were serialized the test would fail. 
+ """ + + model = """ + import time + + class Model: + def postprocess(self, response): + updated_items = [] + for item in response: + time.sleep(1) + updated_items.append(item + " modified") + return updated_items + + def predict(self, request): + return ["0", "1"] + + """ + + config = "model_name: error-truss" + with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir: + truss_dir = Path(tmp_work_dir, "truss") + + _create_truss(truss_dir, config, textwrap.dedent(model)) + + tr = TrussHandle(truss_dir) + _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True) + truss_server_addr = "http://localhost:8090" + full_url = f"{truss_server_addr}/v1/models/model:predict" + + def make_request(delay: int): + time.sleep(delay) + response = requests.post(full_url, json={}) + assert response.status_code == 200 + assert response.json() == ["0 modified", "1 modified"] + + with ThreadPoolExecutor() as e: + # We use concurrent.futures.wait instead of the timeout property + # on requests, since requests timeout property has a complex interaction + # with streaming. + first_request = e.submit(make_request, 0) + second_request = e.submit(make_request, 0.2) + futures = [first_request, second_request] + done, _ = concurrent.futures.wait(futures, timeout=3) + # Ensure that both requests complete within the 3 second timeout, + # as the predict lock is not held through the postprocess step + assert first_request in done + assert second_request in done + + for future in done: + # Ensure that both futures completed without error + future.result() + + @pytest.mark.integration def test_truss_with_errors(): model = """