Remove postprocess from lock (#700)
* Add test.

* Move postprocess out of the predict lock.

* Bump version.

* Bump pyproject.
squidarth authored Oct 17, 2023
1 parent 8d33e9f commit 3eddce1
Showing 4 changed files with 171 additions and 6 deletions.
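The substance of the change: model_wrapper.py previously ran the user-defined postprocess step while still holding the predict concurrency semaphore, so a slow postprocess would serialize otherwise-concurrent requests. This commit moves the non-streaming postprocess call outside the lock. A minimal sketch of the resulting control flow, assuming an async model interface (illustrative only, not the actual truss implementation):

    import asyncio

    class ModelWrapperSketch:
        def __init__(self, model, concurrency: int = 1):
            self._model = model
            self._predict_semaphore = asyncio.Semaphore(concurrency)

        async def __call__(self, payload):
            async with self._predict_semaphore:
                # Only predict() holds the semaphore now.
                response = await self._model.predict(payload)
            # postprocess() runs after the semaphore is released, so another
            # request's predict() can start while this response is transformed.
            return await self._model.postprocess(response)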
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "truss"
-version = "0.7.12"
+version = "0.7.13rc1"
 description = "A seamless bridge from model development to model delivery"
 license = "MIT"
 readme = "README.md"
2 changes: 1 addition & 1 deletion truss/templates/server/common/truss_server.py
@@ -142,7 +142,7 @@ async def predict(

         # In the case that the model returns a Generator object, return a
         # StreamingResponse instead.
-        if isinstance(response, AsyncGenerator):
+        if isinstance(response, (AsyncGenerator, Generator)):
             # media_type in StreamingResponse sets the Content-Type header
             return StreamingResponse(response, media_type="application/octet-stream")
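The broadened isinstance check works because Starlette's StreamingResponse accepts both async and plain sync iterators (sync iterators are consumed in a threadpool). A standalone illustration with an assumed endpoint, not code from this repository:

    from typing import Generator

    from fastapi import FastAPI
    from fastapi.responses import StreamingResponse

    app = FastAPI()

    def sync_chunks() -> Generator[bytes, None, None]:
        # A plain (non-async) generator, analogous to a sync predict().
        yield b"0"
        yield b"1"

    @app.get("/stream")
    def stream() -> StreamingResponse:
        # media_type sets the Content-Type header of the streamed response.
        return StreamingResponse(sync_chunks(), media_type="application/octet-stream")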

15 changes: 11 additions & 4 deletions truss/templates/server/model_wrapper.py
@@ -225,7 +225,7 @@ async def postprocess(
         if inspect.isasyncgenfunction(
             self._model.postprocess
         ) or inspect.isgeneratorfunction(self._model.postprocess):
-            return self._model.postprocess(response, headers)
+            return self._model.postprocess(response)

         if inspect.iscoroutinefunction(self._model.postprocess):
             return await _intercept_exceptions_async(self._model.postprocess)(response)
@@ -264,10 +264,16 @@ async def __call__(
         async with deferred_semaphore(self._predict_semaphore) as semaphore_manager:
             response = await self.predict(payload, headers)

-            processed_response = await self.postprocess(response)
-
             # Streaming cases
             if inspect.isgenerator(response) or inspect.isasyncgen(response):
+                if hasattr(self._model, "postprocess"):
+                    logging.warning(
+                        "Predict returned a streaming response, while a postprocess is defined. "
+                        "Note that in this case, the postprocess will run within the predict lock."
+                    )
+
+                response = await self.postprocess(response)
+
                 async_generator = _force_async_generator(response)

                 if headers and headers.get("accept") == "application/json":
@@ -309,7 +315,8 @@ async def _response_generator():

             return _response_generator()

-        return processed_response
+        processed_response = await self.postprocess(response)
+        return processed_response


 class ResponseChunk:
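In the streaming path the semaphore cannot simply be released when __call__ returns, because the generator is consumed while the response body is being sent; that is why the code acquires it through deferred_semaphore and receives a semaphore_manager. A minimal sketch of what such a deferred-release helper could look like (an assumed shape, not the actual truss implementation):

    import asyncio
    from contextlib import asynccontextmanager

    class SemaphoreManager:
        def __init__(self, semaphore: asyncio.Semaphore):
            self._semaphore = semaphore
            self.deferred = False

        def defer(self) -> None:
            # The caller takes over responsibility for releasing the
            # semaphore, e.g. in a finally block of the response generator.
            self.deferred = True

        def release(self) -> None:
            self._semaphore.release()

    @asynccontextmanager
    async def deferred_semaphore(semaphore: asyncio.Semaphore):
        await semaphore.acquire()
        manager = SemaphoreManager(semaphore)
        try:
            yield manager
        finally:
            if not manager.deferred:
                semaphore.release()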
158 changes: 158 additions & 0 deletions truss/tests/test_model_inference.py
@@ -397,6 +397,164 @@ def predict(self, request):
        assert "Internal Server Error" in response.json()["error"]


@pytest.mark.integration
def test_postprocess_with_streaming_predict():
    """
    Test a Truss that has streaming responses from both predict and postprocess.
    In this case, the postprocess step still runs within the predict lock, so we
    don't test the lock behavior here, only that the postprocess function is
    applied to the streamed output.
    """
    model = """
    import time
    class Model:
        def postprocess(self, response):
            for item in response:
                time.sleep(1)
                yield item + " modified"
        def predict(self, request):
            for i in range(2):
                yield str(i)
    """

    config = "model_name: error-truss"
    with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir:
        truss_dir = Path(tmp_work_dir, "truss")

        _create_truss(truss_dir, config, textwrap.dedent(model))

        tr = TrussHandle(truss_dir)
        _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
        truss_server_addr = "http://localhost:8090"
        full_url = f"{truss_server_addr}/v1/models/model:predict"
        response = requests.post(full_url, json={}, stream=True)
        # Note that the postprocess function is applied to the
        # streamed response.
        assert response.content == b"0 modified1 modified"


@pytest.mark.integration
def test_streaming_postprocess():
    """
    Tests a Truss where predict returns a non-streaming response but postprocess
    streams, and ensures that the postprocess step does not happen within the
    predict lock. To do this, we sleep for two seconds total during the postprocess
    streaming, and fire off two requests with a shared timeout of 3 seconds,
    ensuring that the test would fail if they were serialized.
    """
    model = """
    import time
    class Model:
        def postprocess(self, response):
            for item in response:
                time.sleep(1)
                yield item + " modified"
        def predict(self, request):
            return ["0", "1"]
    """

    config = "model_name: error-truss"
    with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir:
        truss_dir = Path(tmp_work_dir, "truss")

        _create_truss(truss_dir, config, textwrap.dedent(model))

        tr = TrussHandle(truss_dir)
        _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
        truss_server_addr = "http://localhost:8090"
        full_url = f"{truss_server_addr}/v1/models/model:predict"

        def make_request(delay: float):
            # For streamed responses, requests does not start receiving content from
            # the server until `iter_content` is called, so we must consume the body
            # in order to get an actual timeout.
            time.sleep(delay)
            response = requests.post(full_url, json={}, stream=True)

            assert response.status_code == 200
            assert response.content == b"0 modified1 modified"

        with ThreadPoolExecutor() as e:
            # We use concurrent.futures.wait instead of the timeout property
            # on requests, since the requests timeout property has a complex
            # interaction with streaming.
            first_request = e.submit(make_request, 0)
            second_request = e.submit(make_request, 0.2)
            futures = [first_request, second_request]
            done, _ = concurrent.futures.wait(futures, timeout=3)
            # Ensure that both requests complete within the 3 second timeout,
            # as the predict lock is not held through the postprocess step.
            assert first_request in done
            assert second_request in done

            for future in done:
                # Ensure that both futures completed without error.
                future.result()

@pytest.mark.integration
def test_postprocess():
    """
    Tests a Truss that has a postprocess step defined, and ensures that the
    postprocess does not happen within the predict lock. To do this, we sleep
    for two seconds during the postprocess, and fire off two requests with a
    shared timeout of 3 seconds, ensuring that the test would fail if they
    were serialized.
    """

    model = """
    import time
    class Model:
        def postprocess(self, response):
            updated_items = []
            for item in response:
                time.sleep(1)
                updated_items.append(item + " modified")
            return updated_items
        def predict(self, request):
            return ["0", "1"]
    """

    config = "model_name: error-truss"
    with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir:
        truss_dir = Path(tmp_work_dir, "truss")

        _create_truss(truss_dir, config, textwrap.dedent(model))

        tr = TrussHandle(truss_dir)
        _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
        truss_server_addr = "http://localhost:8090"
        full_url = f"{truss_server_addr}/v1/models/model:predict"

        def make_request(delay: float):
            time.sleep(delay)
            response = requests.post(full_url, json={})
            assert response.status_code == 200
            assert response.json() == ["0 modified", "1 modified"]

        with ThreadPoolExecutor() as e:
            # We use concurrent.futures.wait instead of the timeout property
            # on requests, since the requests timeout property has a complex
            # interaction with streaming.
            first_request = e.submit(make_request, 0)
            second_request = e.submit(make_request, 0.2)
            futures = [first_request, second_request]
            done, _ = concurrent.futures.wait(futures, timeout=3)
            # Ensure that both requests complete within the 3 second timeout,
            # as the predict lock is not held through the postprocess step.
            assert first_request in done
            assert second_request in done

            for future in done:
                # Ensure that both futures completed without error.
                future.result()


@pytest.mark.integration
def test_truss_with_errors():
    model = """