
add tests for continuous batching and Default loops #396

Merged
7 commits merged on Dec 16, 2024
2 changes: 1 addition & 1 deletion src/litserve/__about__.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.2.6.dev1"
__version__ = "0.2.6.dev2"
__author__ = "Lightning-AI et al."
__author_email__ = "community@lightning.ai"
__license__ = "Apache-2.0"
11 changes: 6 additions & 5 deletions src/litserve/loops.py
@@ -495,9 +495,6 @@ def __init__(self):
self._context = {}

def get_batch_requests(self, lit_api: LitAPI, request_queue: Queue, max_batch_size: int, batch_timeout: float):
-        if max_batch_size <= 1:
-            raise ValueError("max_batch_size must be greater than 1")
-
batches, timed_out_uids = collate_requests(
lit_api,
request_queue,
@@ -507,8 +504,10 @@ def get_batch_requests(self, lit_api: LitAPI, request_queue: Queue, max_batch_size: int, batch_timeout: float):
return batches, timed_out_uids

def get_request(self, request_queue: Queue, timeout: float = 1.0):
-        response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=timeout)
-        return response_queue_id, uid, timestamp, x_enc
+        try:
+            return request_queue.get(timeout=timeout)
+        except Empty:
+            return None

def populate_context(self, lit_spec: LitSpec, request: Any):
if lit_spec and hasattr(lit_spec, "populate_context"):
@@ -751,6 +750,8 @@ def has_finished(self, uid: str, token: str, max_sequence_length: int) -> bool:

def add_request(self, uid: str, request: Any, lit_api: LitAPI, lit_spec: Optional[LitSpec]) -> None:
"""Add a new sequence to active sequences and perform any action before prediction such as filling the cache."""
+        if hasattr(lit_api, "add_request"):
+            lit_api.add_request(uid, request)
decoded_request = lit_api.decode_request(request)
self.active_sequences[uid] = {"input": decoded_request, "current_length": 0, "generated_sequence": []}

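For orientation, a minimal illustrative sketch (not part of this diff) of how a worker loop might consume the revised helpers: get_request now returns None when the queue is empty instead of letting queue.Empty propagate, and ContinuousBatchingLoop.add_request forwards to an optional add_request hook on the user's LitAPI so the API can prepare state (for example, a cache) before prediction. The poll_once helper below is hypothetical; only get_request and add_request come from the diff above.

from queue import Queue

def poll_once(loop, lit_api, request_queue: Queue) -> bool:
    # Hypothetical wrapper: admit at most one request into the loop.
    item = loop.get_request(request_queue, timeout=1.0)
    if item is None:
        return False  # queue was empty within the timeout; caller may idle
    response_queue_id, uid, timestamp, x_enc = item
    # ContinuousBatchingLoop.add_request calls lit_api.add_request(uid, request)
    # when the API defines it, then decodes and registers the sequence.
    loop.add_request(uid, x_enc, lit_api, None)
    return True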
156 changes: 156 additions & 0 deletions tests/test_loops.py
@@ -14,6 +14,7 @@
import inspect
import io
import json
+import re
import threading
import time
from queue import Queue
@@ -28,8 +29,13 @@
from litserve import LitAPI
from litserve.callbacks import CallbackRunner
from litserve.loops import (
+    ContinuousBatchingLoop,
+    DefaultLoop,
+    LitLoop,
+    Output,
    _BaseLoop,
    inference_worker,
+    notify_timed_out_requests,
run_batched_loop,
run_batched_streaming_loop,
run_single_loop,
@@ -495,3 +501,153 @@ def test_get_default_loop():
assert isinstance(loop, ls.loops.BatchedStreamingLoop), (
"BatchedStreamingLoop must be returned when stream=True and max_batch_size>1"
)


@pytest.fixture
def lit_loop_setup():
lit_loop = LitLoop()
lit_api = MagicMock(request_timeout=0.1)
request_queue = Queue()
return lit_loop, lit_api, request_queue


def test_lit_loop_get_batch_requests(lit_loop_setup):
lit_loop, lit_api, request_queue = lit_loop_setup
request_queue.put((0, "UUID-001", time.monotonic(), {"input": 4.0}))
request_queue.put((0, "UUID-002", time.monotonic(), {"input": 5.0}))
batches, timed_out_uids = lit_loop.get_batch_requests(lit_api, request_queue, 2, 0.001)
assert len(batches) == 2
assert batches == [(0, "UUID-001", {"input": 4.0}), (0, "UUID-002", {"input": 5.0})]
assert timed_out_uids == []


def test_lit_loop_get_request(lit_loop_setup):
lit_loop, _, request_queue = lit_loop_setup
t = time.monotonic()
request_queue.put((0, "UUID-001", t, {"input": 4.0}))
response_queue_id, uid, timestamp, x_enc = lit_loop.get_request(request_queue, timeout=1)
assert uid == "UUID-001"
assert response_queue_id == 0
assert timestamp == t
assert x_enc == {"input": 4.0}
assert lit_loop.get_request(request_queue, timeout=0.001) is None


def test_lit_loop_put_response(lit_loop_setup):
lit_loop, _, request_queue = lit_loop_setup
response_queues = [Queue()]
lit_loop.put_response(response_queues, 0, "UUID-001", {"output": 16.0}, LitAPIStatus.OK)
response = response_queues[0].get()
assert response == ("UUID-001", ({"output": 16.0}, LitAPIStatus.OK))


def test_notify_timed_out_requests():
response_queues = [Queue()]

# Simulate timed out requests
timed_out_uids = [(0, "UUID-001"), (0, "UUID-002")]

# Call the function to notify timed out requests
notify_timed_out_requests(response_queues, timed_out_uids)

# Check the responses in the response queue
response_1 = response_queues[0].get()
response_2 = response_queues[0].get()

assert response_1[0] == "UUID-001"
assert response_1[1][1] == LitAPIStatus.ERROR
assert isinstance(response_1[1][0], HTTPException)
assert response_2[0] == "UUID-002"
assert isinstance(response_2[1][0], HTTPException)
assert response_2[1][1] == LitAPIStatus.ERROR


class ContinuousBatchingAPI(ls.LitAPI):
def setup(self, spec: Optional[LitSpec]):
self.model = {}

def add_request(self, uid: str, request):
self.model[uid] = {"outputs": list(range(5))}

def decode_request(self, input: str):
return input

def encode_response(self, output: str):
return {"output": output}

    def step(self, prev_outputs: Optional[List[Output]]) -> List[Output]:
        # Emit the next queued token for every sequence that still has output.
        outputs = []
        for k in self.model:
            v = self.model[k]
            if v["outputs"]:
                o = v["outputs"].pop(0)
                outputs.append(Output(k, o, LitAPIStatus.OK))
        # Any sequence that produced nothing this step is finished: signal
        # FINISH_STREAMING and drop it from the active set.
        keys = list(self.model.keys())
        for k in keys:
            if k not in [o.uid for o in outputs]:
                outputs.append(Output(k, "", LitAPIStatus.FINISH_STREAMING))
                del self.model[k]
        return outputs


@pytest.mark.parametrize(
("stream", "max_batch_size", "error_msg"),
[
(True, 4, "`lit_api.unbatch` must generate values using `yield`."),
(True, 1, "`lit_api.encode_response` must generate values using `yield`."),
],
)
def test_default_loop_pre_setup_error(stream, max_batch_size, error_msg):
lit_api = ls.test_examples.SimpleLitAPI()
lit_api.stream = stream
lit_api.max_batch_size = max_batch_size
loop = DefaultLoop()
with pytest.raises(ValueError, match=error_msg):
loop.pre_setup(lit_api, None)


@pytest.fixture
def continuous_batching_setup():
lit_api = ContinuousBatchingAPI()
lit_api.stream = True
lit_api.request_timeout = 0.1
lit_api.pre_setup(2, None)
lit_api.setup(None)
request_queue = Queue()
response_queues = [Queue()]
loop = ContinuousBatchingLoop()
return lit_api, loop, request_queue, response_queues


def test_continuous_batching_pre_setup(continuous_batching_setup):
lit_api, loop, request_queue, response_queues = continuous_batching_setup
lit_api.stream = False
with pytest.raises(
ValueError,
match=re.escape(
"Continuous batching loop requires streaming to be enabled. Please set LitServe(..., stream=True)"
),
):
loop.pre_setup(lit_api, None)


def test_continuous_batching_run(continuous_batching_setup):
lit_api, loop, request_queue, response_queues = continuous_batching_setup
request_queue.put((0, "UUID-001", time.monotonic(), {"input": "Hello"}))
loop.run(lit_api, None, "cpu", 0, request_queue, response_queues, 2, 0.1, True, {}, NOOP_CB_RUNNER)

results = []
for i in range(5):
response = response_queues[0].get()
uid, (response_data, status) = response
o = json.loads(response_data)["output"]
assert o == i
assert status == LitAPIStatus.OK
assert uid == "UUID-001"
results.append(o)
assert results == list(range(5)), "API must return a sequence of numbers from 0 to 4"
response = response_queues[0].get()
uid, (response_data, status) = response
o = json.loads(response_data)["output"]
assert o == ""
assert status == LitAPIStatus.FINISH_STREAMING
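

A natural extension of the test above, sketched for illustration only (not part of this PR): with max_batch_size=2, two queued requests are active at the same time, and step() emits one token per active UID on each iteration, so the two sequences finish with interleaved outputs. Everything here reuses the fixture and names from the tests above; the expected response count assumes 5 OK tokens plus one FINISH_STREAMING marker per request.

def test_continuous_batching_run_two_requests_sketch(continuous_batching_setup):
    lit_api, loop, request_queue, response_queues = continuous_batching_setup
    request_queue.put((0, "UUID-A", time.monotonic(), {"input": "Hello"}))
    request_queue.put((0, "UUID-B", time.monotonic(), {"input": "World"}))
    loop.run(lit_api, None, "cpu", 0, request_queue, response_queues, 2, 0.1, True, {}, NOOP_CB_RUNNER)

    # 5 OK tokens per UID plus one FINISH_STREAMING marker each -> 12 responses.
    per_uid = {"UUID-A": [], "UUID-B": []}
    for _ in range(12):
        uid, (response_data, status) = response_queues[0].get()
        if status == LitAPIStatus.OK:
            per_uid[uid].append(json.loads(response_data)["output"])
    assert per_uid["UUID-A"] == list(range(5))
    assert per_uid["UUID-B"] == list(range(5))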