inject context for batching loops #139

Merged · 5 commits · Jun 18, 2024
50 changes: 50 additions & 0 deletions README.md
@@ -820,6 +820,56 @@ if __name__ == "__main__":
server.run(port=8000)
```

When using `context` with dynamic batching, the `predict` method will receive a list of `context` items,
one for each request in the batch.
The `decode_request` and `encode_response` methods will process the `context` for each individual request as usual.


```python
import litserve as ls

class SimpleBatchedAPI(ls.examples.SimpleBatchedAPI):
def predict(self, x_batch, context):
        # context is a list of dicts, one per request in the batch
for c, x in zip(context, x_batch):
c["input"] = x
return self.model(x_batch)

def encode_response(self, output, context):
input = context["input"]
return {"output": input}

if __name__ == "__main__":
api = SimpleBatchedAPI()
server = ls.LitServer(api)
server.run(port=8000)
```
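
Once a server like the one above is running, the endpoint behaves like any other LitServe route. The snippet below is a minimal client sketch (it assumes the `requests` package is installed and the server is listening on `localhost:8000`, as in the example); requests arriving within the batching window share one `predict` call but keep separate `context` entries.

```python
import requests

# POST to the /predict route used throughout these examples.
# The payload and response shape follow the identity-style API above.
response = requests.post("http://127.0.0.1:8000/predict", json={"input": 5.0})
print(response.json())  # e.g. {"output": 5.0}
```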

When you enable batching with streaming, your `encode_response` method will receive a list of `context` items,
one for each request in the batch.

```python
import litserve as ls

class BatchedStreamingAPI(ls.examples.SimpleBatchedAPI):
def predict(self, x_batch, context):
        # context is a list of dicts, one per request in the batch
for c, x in zip(context, x_batch):
c["input"] = x
yield self.model(x_batch)

def encode_response(self, output_stream, context):
for _ in output_stream:
            # context is a list of dicts, one per request in the batch
yield [{"output": ctx["input"]} for ctx in context]

if __name__ == "__main__":
api = BatchedStreamingAPI()
server = ls.LitServer(api, max_batch_size=2, batch_timeout=0.01, stream=True)
server.run(port=8000)
```
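
The streaming variant is consumed over the same route. A minimal sketch, assuming the server above is running locally and simply printing the raw streamed chunks (the exact framing of the streamed body is determined by LitServe's streaming protocol, not by anything shown in this example):

```python
import requests

# Keep the HTTP connection open and read the body incrementally.
with requests.post(
    "http://127.0.0.1:8000/predict", json={"input": 5.0}, stream=True
) as response:
    for chunk in response.iter_content(chunk_size=None, decode_unicode=True):
        # Chunks arrive as the server sends them; boundaries are not
        # guaranteed to line up one-to-one with encode_response yields.
        if chunk:
            print(chunk)
```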


</details>

<details>
44 changes: 35 additions & 9 deletions src/litserve/server.py
@@ -26,7 +26,7 @@
import time
import os
import shutil
from typing import Sequence, Optional, Union
from typing import Sequence, Optional, Union, List
import uuid

from fastapi import FastAPI, Depends, HTTPException, BackgroundTasks, Request, Response
@@ -50,7 +50,7 @@
LONG_TIMEOUT = 100


def _inject_context(context: dict, func, *args, **kwargs):
def _inject_context(context: Union[List[dict], dict], func, *args, **kwargs):
sig = inspect.signature(func)
if "context" in sig.parameters:
return func(*args, **kwargs, context=context)
@@ -96,12 +96,24 @@ def run_batched_loop(lit_api, lit_spec, request_queue: Queue, request_buffer, ma
inputs, pipes = zip(*batches)

try:
x = [lit_api.decode_request(input) for input in inputs]
            # fresh context dict per request so per-request writes do not alias
            contexts = [{} for _ in inputs]
            if hasattr(lit_spec, "populate_context"):
                for input, context in zip(inputs, contexts):
                    lit_spec.populate_context(context, input)

x = [
_inject_context(
context,
lit_api.decode_request,
input,
)
for input, context in zip(inputs, contexts)
]
x = lit_api.batch(x)
y = lit_api.predict(x)
y = _inject_context(contexts, lit_api.predict, x)
outputs = lit_api.unbatch(y)
for y, pipe_s in zip(outputs, pipes):
y_enc = lit_api.encode_response(y)
for y, pipe_s, context in zip(outputs, pipes, contexts):
y_enc = _inject_context(context, lit_api.encode_response, y)

with contextlib.suppress(BrokenPipeError):
pipe_s.send((y_enc, LitAPIStatus.OK))
Expand All @@ -128,6 +140,8 @@ def run_single_loop(lit_api, lit_spec, request_queue: Queue, request_buffer):
continue
try:
context = {}
if hasattr(lit_spec, "populate_context"):
lit_spec.populate_context(context, x_enc)
x = _inject_context(
context,
lit_api.decode_request,
@@ -218,11 +232,23 @@ def run_batched_streaming_loop(lit_api, lit_spec, request_queue: Queue, request_
inputs, pipes = zip(*batches)

try:
x = [lit_api.decode_request(input) for input in inputs]
            # fresh context dict per request so per-request writes do not alias
            contexts = [{} for _ in inputs]
            if hasattr(lit_spec, "populate_context"):
                for input, context in zip(inputs, contexts):
                    lit_spec.populate_context(context, input)

x = [
_inject_context(
context,
lit_api.decode_request,
input,
)
for input, context in zip(inputs, contexts)
]
x = lit_api.batch(x)
y_iter = lit_api.predict(x)
y_iter = _inject_context(contexts, lit_api.predict, x)
unbatched_iter = lit_api.unbatch(y_iter)
y_enc_iter = lit_api.encode_response(unbatched_iter)
y_enc_iter = _inject_context(contexts, lit_api.encode_response, unbatched_iter)

# y_enc_iter -> [[response-1, response-2], [response-1, response-2]]
for y_batch in y_enc_iter:
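
The heart of the change is the `_inject_context` helper near the top of this file: it inspects the user function's signature and passes `context` through only when the function declares a `context` parameter, so APIs written without `context` keep working unchanged. Below is a standalone toy version of that dispatch, a sketch of the idea rather than the exact litserve implementation (whose remaining lines are collapsed above):

```python
import inspect
from typing import List, Union


def inject_context(context: Union[List[dict], dict], func, *args, **kwargs):
    # Only pass context to callables that explicitly accept it.
    if "context" in inspect.signature(func).parameters:
        return func(*args, **kwargs, context=context)
    return func(*args, **kwargs)


def legacy_predict(x):
    return x * 2


def context_aware_predict(x, context):
    context["input"] = x
    return x * 2


ctx = {}
assert inject_context(ctx, legacy_predict, 3) == 6          # context ignored
assert inject_context(ctx, context_aware_predict, 3) == 6   # context filled in
assert ctx == {"input": 3}
```
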
40 changes: 37 additions & 3 deletions tests/test_lit_server.py
@@ -342,7 +342,29 @@ def encode_response(self, output, context):
return {"output": input}


class PredictError(ls.examples.SimpleLitAPI):
class IndentityBatchedAPI(ls.examples.SimpleBatchedAPI):
def predict(self, x_batch, context):
for c, x in zip(context, x_batch):
c["input"] = x
return self.model(x_batch)

def encode_response(self, output, context):
input = context["input"]
return {"output": input}


class IndentityBatchedStreamingAPI(ls.examples.SimpleBatchedAPI):
def predict(self, x_batch, context):
for c, x in zip(context, x_batch):
c["input"] = x
yield self.model(x_batch)

def encode_response(self, output_stream, context):
for _ in output_stream:
yield [{"output": ctx["input"]} for ctx in context]


class PredictErrorAPI(ls.examples.SimpleLitAPI):
def predict(self, x, y, context):
context["input"] = x
return self.model(x)
@@ -360,14 +382,26 @@ def dummy_load_and_raise(resp):

mocked_load_and_raise.side_effect = dummy_load_and_raise

# Test context injection with single loop
api = IndentityAPI()
server = LitServer(api)
async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
resp = await ac.post("/predict", json={"input": 5.0}, timeout=10)
assert resp.json()["output"] == 5.0, "output from Identity server must be same as input"

api = PredictError()
server = LitServer(api)
# Test context injection with batched loop
server = LitServer(IndentityBatchedAPI(), max_batch_size=2, batch_timeout=0.01)
async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
resp = await ac.post("/predict", json={"input": 5.0}, timeout=10)
assert resp.json()["output"] == 5.0, "output from Identity server must be same as input"

# Test context injection with batched streaming loop
server = LitServer(IndentityBatchedStreamingAPI(), max_batch_size=2, batch_timeout=0.01, stream=True)
async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
resp = await ac.post("/predict", json={"input": 5.0}, timeout=10)
assert resp.json()["output"] == 5.0, "output from Identity server must be same as input"

server = LitServer(PredictErrorAPI())
with pytest.raises(TypeError, match=re.escape("predict() missing 1 required positional argument: 'y'")):
async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
resp = await ac.post("/predict", json={"input": 5.0}, timeout=10)