Commit

add test
aniketmaurya committed Dec 16, 2024
1 parent 1756c64 commit 0826d0f
Showing 2 changed files with 67 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/litserve/loops.py
@@ -751,6 +751,8 @@ def has_finished(self, uid: str, token: str, max_sequence_length: int) -> bool:

    def add_request(self, uid: str, request: Any, lit_api: LitAPI, lit_spec: Optional[LitSpec]) -> None:
        """Add a new sequence to active sequences and perform any action before prediction such as filling the cache."""
        if hasattr(lit_api, "add_request"):
            lit_api.add_request(uid, request)
        decoded_request = lit_api.decode_request(request)
        self.active_sequences[uid] = {"input": decoded_request, "current_length": 0, "generated_sequence": []}
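For context, the hook added above is optional: the loop only calls `lit_api.add_request(uid, request)` when the API defines it, right before `decode_request`. Below is a minimal sketch of a user-facing API that opts in, assuming the standard `LitAPI` interface; the `CachedAPI` name and its per-uid cache layout are illustrative and not part of this commit, and only the hook-related methods are shown.

import litserve as ls


class CachedAPI(ls.LitAPI):
    def setup(self, device):
        # Per-request state keyed by uid, prefilled before decoding.
        self.cache = {}

    def add_request(self, uid, request):
        # Invoked by the loop before decode_request (see the hunk above).
        self.cache[uid] = {"raw_request": request}

    def decode_request(self, request):
        return request["input"]

    def encode_response(self, output):
        return {"output": output}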

65 changes: 65 additions & 0 deletions tests/test_loops.py
@@ -28,7 +28,9 @@
from litserve import LitAPI
from litserve.callbacks import CallbackRunner
from litserve.loops import (
    ContinuousBatchingLoop,
    LitLoop,
    Output,
    _BaseLoop,
    inference_worker,
    notify_timed_out_requests,
@@ -556,3 +558,66 @@ def test_notify_timed_out_requests():
    assert response_2[0] == "UUID-002"
    assert isinstance(response_2[1][0], HTTPException)
    assert response_2[1][1] == LitAPIStatus.ERROR


class ContinuousBatchingAPI(ls.LitAPI):
    def setup(self, spec: Optional[LitSpec]):
        self.model = {}

    def add_request(self, uid: str, request):
        self.model[uid] = {"outputs": list(range(5))}

    def decode_request(self, input: str):
        return input

    def encode_response(self, output: str):
        return {"output": output}

    def step(self, prev_outputs: Optional[List[Output]]) -> List[Output]:
        # Emit the next pending token for every active sequence.
        outputs = []
        for k in self.model:
            v = self.model[k]
            if v["outputs"]:
                o = v["outputs"].pop(0)
                outputs.append(Output(k, o, LitAPIStatus.OK))
        # Sequences that produced nothing this step are exhausted: signal the
        # end of the stream and drop their state.
        keys = list(self.model.keys())
        for k in keys:
            if k not in [o.uid for o in outputs]:
                outputs.append(Output(k, "", LitAPIStatus.FINISH_STREAMING))
                del self.model[k]
        return outputs
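As a quick illustration (not part of the commit), the helper above drains one token per active sequence on each `step` call and only closes a sequence once its queue is empty; the sketch below relies only on names defined in the class and on the `uid` attribute of `Output` that the class itself uses.

api = ContinuousBatchingAPI()
api.setup(None)
api.add_request("UUID-001", {"input": "Hello"})

first = api.step(None)   # first round: one OK output carrying token 0
assert [o.uid for o in first] == ["UUID-001"]

for _ in range(4):       # rounds for tokens 1..4
    api.step(None)

api.step(None)           # queue exhausted: FINISH_STREAMING, state removed
assert api.model == {}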


@pytest.fixture
def continuous_batching_setup():
    lit_api = ContinuousBatchingAPI()
    lit_api.stream = True
    lit_api.request_timeout = 0.1
    lit_api.pre_setup(2, None)
    lit_api.setup(None)
    request_queue = Queue()
    response_queues = [Queue()]
    loop = ContinuousBatchingLoop()
    return lit_api, loop, request_queue, response_queues


def test_continuous_batching_pre_setup(continuous_batching_setup):
    lit_api, loop, request_queue, response_queues = continuous_batching_setup
    request_queue.put((0, "UUID-001", time.monotonic(), {"input": "Hello"}))
    loop.run(lit_api, None, "cpu", 0, request_queue, response_queues, 2, 0.1, True, {}, NOOP_CB_RUNNER)

    results = []
    for i in range(5):
        response = response_queues[0].get()
        uid, (response_data, status) = response
        o = json.loads(response_data)["output"]
        assert o == i
        assert status == LitAPIStatus.OK
        assert uid == "UUID-001"
        results.append(o)
    assert results == list(range(5)), "API must return a sequence of numbers from 0 to 4"
    response = response_queues[0].get()
    uid, (response_data, status) = response
    o = json.loads(response_data)["output"]
    assert o == ""
    assert status == LitAPIStatus.FINISH_STREAMING
