From 35129f7dc84d8e796a43611358abbb67a34a0c01 Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Wed, 11 Dec 2024 17:07:43 +0000 Subject: [PATCH] Add `loop.pre_setup` to allow fine-grained LitAPI validation based on inference loop (#393) * pre_setup loop * add test * fix tests * apply feedback --- src/litserve/api.py | 64 +--------------- src/litserve/loops.py | 170 +++++++++++++++++++++++++++++------------ src/litserve/server.py | 11 ++- tests/test_batch.py | 2 +- tests/test_litapi.py | 22 +++--- tests/test_loops.py | 22 +++++- 6 files changed, 161 insertions(+), 130 deletions(-) diff --git a/src/litserve/api.py b/src/litserve/api.py index c54a7b80..e2dd9eea 100644 --- a/src/litserve/api.py +++ b/src/litserve/api.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import inspect import json import warnings from abc import ABC, abstractmethod @@ -113,76 +112,15 @@ def device(self): def device(self, value): self._device = value - def _sanitize(self, max_batch_size: int, spec: Optional[LitSpec]): + def pre_setup(self, max_batch_size: int, spec: Optional[LitSpec]): self.max_batch_size = max_batch_size if self.stream: self._default_unbatch = self._unbatch_stream else: self._default_unbatch = self._unbatch_no_stream - # we will sanitize regularly if no spec - # in case, we have spec then: - # case 1: spec implements a streaming API - # Case 2: spec implements a non-streaming API if spec: - # TODO: Implement sanitization self._spec = spec - return - - original = self.unbatch.__code__ is LitAPI.unbatch.__code__ - if ( - self.stream - and max_batch_size > 1 - and not all([ - inspect.isgeneratorfunction(self.predict), - inspect.isgeneratorfunction(self.encode_response), - (original or inspect.isgeneratorfunction(self.unbatch)), - ]) - ): - raise ValueError( - """When `stream=True` with max_batch_size > 1, `lit_api.predict`, `lit_api.encode_response` and - `lit_api.unbatch` must generate values using `yield`. - - Example: - - def predict(self, inputs): - ... - for i in range(max_token_length): - yield prediction - - def encode_response(self, outputs): - for output in outputs: - encoded_output = ... - yield encoded_output - - def unbatch(self, outputs): - for output in outputs: - unbatched_output = ... - yield unbatched_output - """ - ) - - if self.stream and not all([ - inspect.isgeneratorfunction(self.predict), - inspect.isgeneratorfunction(self.encode_response), - ]): - raise ValueError( - """When `stream=True` both `lit_api.predict` and - `lit_api.encode_response` must generate values using `yield`. - - Example: - - def predict(self, inputs): - ... - for i in range(max_token_length): - yield prediction - - def encode_response(self, outputs): - for output in outputs: - encoded_output = ... - yield encoded_output - """ - ) def set_logger_queue(self, queue: Queue): """Set the queue for logging events.""" diff --git a/src/litserve/loops.py b/src/litserve/loops.py index 8e10653f..46a309c8 100644 --- a/src/litserve/loops.py +++ b/src/litserve/loops.py @@ -441,6 +441,9 @@ def run( """ + def pre_setup(self, lit_api: LitAPI, spec: Optional[LitSpec]): + pass + def __call__( self, lit_api: LitAPI, @@ -487,7 +490,109 @@ def run( raise NotImplementedError -class SingleLoop(_BaseLoop): +class LitLoop(_BaseLoop): + def __init__(self): + self._context = {} + + def get_batch_requests(self, lit_api: LitAPI, request_queue: Queue, max_batch_size: int, batch_timeout: float): + if max_batch_size <= 1: + raise ValueError("max_batch_size must be greater than 1") + + batches, timed_out_uids = collate_requests( + lit_api, + request_queue, + max_batch_size, + batch_timeout, + ) + return batches, timed_out_uids + + def get_request(self, request_queue: Queue, timeout: float = 1.0): + response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=timeout) + return response_queue_id, uid, timestamp, x_enc + + def populate_context(self, lit_spec: LitSpec, request: Any): + if lit_spec and hasattr(lit_spec, "populate_context"): + lit_spec.populate_context(self._context, request) + + def put_response( + self, response_queues: List[Queue], response_queue_id: int, uid: str, response_data: Any, status: LitAPIStatus + ) -> None: + response_queues[response_queue_id].put((uid, (response_data, status))) + + def put_error_response( + self, response_queues: List[Queue], response_queue_id: int, uid: str, error: Exception + ) -> None: + response_queues[response_queue_id].put((uid, (error, LitAPIStatus.ERROR))) + + +class DefaultLoop(LitLoop): + def pre_setup(self, lit_api: LitAPI, spec: Optional[LitSpec]): + # we will sanitize regularly if no spec + # in case, we have spec then: + # case 1: spec implements a streaming API + # Case 2: spec implements a non-streaming API + if spec: + # TODO: Implement sanitization + lit_api._spec = spec + return + + original = lit_api.unbatch.__code__ is LitAPI.unbatch.__code__ + if ( + lit_api.stream + and lit_api.max_batch_size > 1 + and not all([ + inspect.isgeneratorfunction(lit_api.predict), + inspect.isgeneratorfunction(lit_api.encode_response), + (original or inspect.isgeneratorfunction(lit_api.unbatch)), + ]) + ): + raise ValueError( + """When `stream=True` with max_batch_size > 1, `lit_api.predict`, `lit_api.encode_response` and + `lit_api.unbatch` must generate values using `yield`. + + Example: + + def predict(self, inputs): + ... + for i in range(max_token_length): + yield prediction + + def encode_response(self, outputs): + for output in outputs: + encoded_output = ... + yield encoded_output + + def unbatch(self, outputs): + for output in outputs: + unbatched_output = ... + yield unbatched_output + """ + ) + + if lit_api.stream and not all([ + inspect.isgeneratorfunction(lit_api.predict), + inspect.isgeneratorfunction(lit_api.encode_response), + ]): + raise ValueError( + """When `stream=True` both `lit_api.predict` and + `lit_api.encode_response` must generate values using `yield`. + + Example: + + def predict(self, inputs): + ... + for i in range(max_token_length): + yield prediction + + def encode_response(self, outputs): + for output in outputs: + encoded_output = ... + yield encoded_output + """ + ) + + +class SingleLoop(DefaultLoop): def __call__( self, lit_api: LitAPI, @@ -505,7 +610,7 @@ def __call__( run_single_loop(lit_api, lit_spec, request_queue, response_queues, callback_runner) -class BatchedLoop(_BaseLoop): +class BatchedLoop(DefaultLoop): def __call__( self, lit_api: LitAPI, @@ -531,7 +636,7 @@ def __call__( ) -class StreamingLoop(_BaseLoop): +class StreamingLoop(DefaultLoop): def __call__( self, lit_api: LitAPI, @@ -549,7 +654,7 @@ def __call__( run_streaming_loop(lit_api, lit_spec, request_queue, response_queues, callback_runner) -class BatchedStreamingLoop(_BaseLoop): +class BatchedStreamingLoop(DefaultLoop): def __call__( self, lit_api: LitAPI, @@ -593,41 +698,6 @@ class Output: status: LitAPIStatus -class LitLoop(_BaseLoop): - def __init__(self): - self._context = {} - - def get_batch_requests(self, lit_api: LitAPI, request_queue: Queue, max_batch_size: int, batch_timeout: float): - if max_batch_size <= 1: - raise ValueError("max_batch_size must be greater than 1") - - batches, timed_out_uids = collate_requests( - lit_api, - request_queue, - max_batch_size, - batch_timeout, - ) - return batches, timed_out_uids - - def get_request(self, request_queue: Queue, timeout: float = 1.0): - response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=timeout) - return response_queue_id, uid, timestamp, x_enc - - def populate_context(self, lit_spec: LitSpec, request: Any): - if lit_spec and hasattr(lit_spec, "populate_context"): - lit_spec.populate_context(self._context, request) - - def put_response( - self, response_queues: List[Queue], response_queue_id: int, uid: str, response_data: Any, status: LitAPIStatus - ) -> None: - response_queues[response_queue_id].put((uid, (response_data, status))) - - def put_error_response( - self, response_queues: List[Queue], response_queue_id: int, uid: str, error: Exception - ) -> None: - response_queues[response_queue_id].put((uid, (error, LitAPIStatus.ERROR))) - - class ContinuousBatchingLoop(LitLoop): def __init__(self, max_sequence_length: int = 2048): super().__init__() @@ -840,15 +910,7 @@ def inference_worker( logging.info(f"LitServe will use {lit_spec.__class__.__name__} spec") if loop == "auto": - loop = ( - BatchedStreamingLoop() - if stream and max_batch_size > 1 - else StreamingLoop() - if stream - else BatchedLoop() - if max_batch_size > 1 - else SingleLoop() - ) + loop = get_default_loop(stream, max_batch_size) loop( lit_api, @@ -863,3 +925,15 @@ def inference_worker( workers_setup_status, callback_runner, ) + + +def get_default_loop(stream: bool, max_batch_size: int) -> _BaseLoop: + return ( + BatchedStreamingLoop() + if stream and max_batch_size > 1 + else StreamingLoop() + if stream + else BatchedLoop() + if max_batch_size > 1 + else SingleLoop() + ) diff --git a/src/litserve/server.py b/src/litserve/server.py index a5ac5f6d..42e4f097 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -40,7 +40,7 @@ from litserve.callbacks.base import Callback, CallbackRunner, EventTypes from litserve.connector import _Connector from litserve.loggers import Logger, _LoggerConnector -from litserve.loops import _BaseLoop, inference_worker +from litserve.loops import LitLoop, get_default_loop, inference_worker from litserve.middlewares import MaxSizeMiddleware, RequestCountMiddleware from litserve.python_client import client_template from litserve.specs import OpenAISpec @@ -113,7 +113,7 @@ def __init__( spec: Optional[LitSpec] = None, max_payload_size=None, track_requests: bool = False, - loop: Optional[Union[str, _BaseLoop]] = "auto", + loop: Optional[Union[str, LitLoop]] = "auto", callbacks: Optional[Union[List[Callback], Callback]] = None, middlewares: Optional[list[Union[Callable, tuple[Callable, dict]]]] = None, loggers: Optional[Union[Logger, List[Logger]]] = None, @@ -154,6 +154,8 @@ def __init__( if isinstance(loop, str) and loop != "auto": raise ValueError("loop must be an instance of _BaseLoop or 'auto'") + if loop == "auto": + loop = get_default_loop(stream, max_batch_size) if middlewares is None: middlewares = [] @@ -198,7 +200,7 @@ def __init__( "but the max_batch_size parameter was not set." ) - self._loop = loop + self._loop: LitLoop = loop self.api_path = api_path self.healthcheck_path = healthcheck_path self.info_path = info_path @@ -206,7 +208,8 @@ def __init__( self.timeout = timeout lit_api.stream = stream lit_api.request_timeout = self.timeout - lit_api._sanitize(max_batch_size, spec=spec) + lit_api.pre_setup(max_batch_size, spec=spec) + self._loop.pre_setup(lit_api, spec=spec) self.app = FastAPI(lifespan=self.lifespan) self.app.response_queue_id = None self.response_queue_id = None diff --git a/tests/test_batch.py b/tests/test_batch.py index dd0cec30..c8710d81 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -154,7 +154,7 @@ def test_max_batch_size_warning(): def test_batch_predict_string_warning(): api = ls.test_examples.SimpleBatchedAPI() - api._sanitize(2, None) + api.pre_setup(2, None) api.predict = MagicMock(return_value="This is a string") mock_input = torch.tensor([[1.0], [2.0]]) diff --git a/tests/test_litapi.py b/tests/test_litapi.py index e69d31f4..797ecbdf 100644 --- a/tests/test_litapi.py +++ b/tests/test_litapi.py @@ -65,7 +65,7 @@ def encode_response(self, output_stream): def test_default_batch_unbatch(): api = TestDefaultBatchedAPI() - api._sanitize(max_batch_size=4, spec=None) + api.pre_setup(max_batch_size=4, spec=None) inputs = [1, 2, 3, 4] output = api.batch(inputs) assert output == inputs, "Default batch should not change input" @@ -81,7 +81,7 @@ def predict(self, x): def test_default_batch_unbatch_stream(): api = TestStreamAPIBatched() api.stream = True - api._sanitize(max_batch_size=4, spec=None) + api.pre_setup(max_batch_size=4, spec=None) inputs = [1, 2, 3, 4] expected_output = [[0, 0, 0, 0], [1, 2, 3, 4], [2, 4, 6, 8], [3, 6, 9, 12]] output = api.batch(inputs) @@ -93,7 +93,7 @@ def test_default_batch_unbatch_stream(): def test_custom_batch_unbatch(): api = TestCustomBatchedAPI() - api._sanitize(max_batch_size=4, spec=None) + api.pre_setup(max_batch_size=4, spec=None) inputs = [1, 2, 3, 4] output = api.batch(inputs) assert np.all(output == np.array(inputs)), "Custom batch stacks input as numpy array" @@ -102,7 +102,7 @@ def test_custom_batch_unbatch(): def test_batch_unbatch_stream(): api = TestStreamAPI() - api._sanitize(max_batch_size=4, spec=None) + api.pre_setup(max_batch_size=4, spec=None) inputs = [1, 2, 3, 4] output = api.batch(inputs) output = api.predict(output) @@ -128,7 +128,7 @@ def test_decode_request(): def test_decode_request_with_openai_spec(): api = ls.test_examples.TestAPI() - api._sanitize(max_batch_size=1, spec=ls.OpenAISpec()) + api.pre_setup(max_batch_size=1, spec=ls.OpenAISpec()) request = ChatCompletionRequest(messages=[{"role": "system", "content": "Hello"}]) decoded_request = api.decode_request(request) assert decoded_request[0]["content"] == "Hello", "Decode request should return the input message" @@ -136,7 +136,7 @@ def test_decode_request_with_openai_spec(): def test_decode_request_with_openai_spec_wrong_request(): api = ls.test_examples.TestAPI() - api._sanitize(max_batch_size=1, spec=ls.OpenAISpec()) + api.pre_setup(max_batch_size=1, spec=ls.OpenAISpec()) with pytest.raises(AttributeError, match="object has no attribute 'messages'"): api.decode_request({"input": "Hello"}) @@ -149,7 +149,7 @@ def test_encode_response(): def test_encode_response_with_openai_spec(): api = ls.test_examples.TestAPI() - api._sanitize(max_batch_size=1, spec=ls.OpenAISpec()) + api.pre_setup(max_batch_size=1, spec=ls.OpenAISpec()) response = "This is a LLM generated text".split() generated_tokens = [] for output in api.encode_response(response): @@ -166,7 +166,7 @@ def predict(): generated_tokens = [] api = ls.test_examples.TestAPI() - api._sanitize(max_batch_size=1, spec=ls.OpenAISpec()) + api.pre_setup(max_batch_size=1, spec=ls.OpenAISpec()) for output in api.encode_response(predict()): assert output["role"] == "assistant", "Role should be assistant" @@ -181,7 +181,7 @@ def encode_response(self, output_stream): yield {"content": output} api = ls.test_examples.TestAPI() - api._sanitize(max_batch_size=1, spec=CustomSpecAPI()) + api.pre_setup(max_batch_size=1, spec=CustomSpecAPI()) response = "This is a LLM generated text".split() generated_tokens = [] for output in api.encode_response(response): @@ -191,7 +191,7 @@ def encode_response(self, output_stream): def test_encode_response_with_openai_spec_invalid_input(): api = ls.test_examples.TestAPI() - api._sanitize(max_batch_size=1, spec=ls.OpenAISpec()) + api.pre_setup(max_batch_size=1, spec=ls.OpenAISpec()) response = 10 with pytest.raises(TypeError, match="object is not iterable"): next(api.encode_response(response)) @@ -202,7 +202,7 @@ def predict(): yield {"hello": "world"} api = ls.test_examples.TestAPI() - api._sanitize(max_batch_size=1, spec=ls.OpenAISpec()) + api.pre_setup(max_batch_size=1, spec=ls.OpenAISpec()) with pytest.raises(HTTPException, match=r"Malformed output from LitAPI.predict"): next(api.encode_response(predict())) diff --git a/tests/test_loops.py b/tests/test_loops.py index 4cf4806c..96581ef3 100644 --- a/tests/test_loops.py +++ b/tests/test_loops.py @@ -258,7 +258,7 @@ def test_run_single_loop_timeout(): def test_run_batched_loop(): lit_api = ls.test_examples.SimpleBatchedAPI() lit_api.setup(None) - lit_api._sanitize(2, None) + lit_api.pre_setup(2, None) assert lit_api.model is not None, "Setup must initialize the model" lit_api.request_timeout = 1 @@ -292,7 +292,7 @@ def test_run_batched_loop_timeout(): ls.configure_logging(stream=stream) lit_api = ls.test_examples.SimpleBatchedAPI() lit_api.setup(None) - lit_api._sanitize(2, None) + lit_api.pre_setup(2, None) assert lit_api.model is not None, "Setup must initialize the model" lit_api.request_timeout = 0.1 @@ -388,7 +388,7 @@ def off_test_run_batched_streaming_loop(openai_request_data): lit_api.request_timeout = 1 lit_api.stream = True spec = ls.OpenAISpec() - lit_api._sanitize(2, spec) + lit_api.pre_setup(2, spec) request_queue = Queue() # response_queue_id, uid, timestamp, x_enc @@ -479,3 +479,19 @@ def test_loop_with_server(): with wrap_litserve_start(server) as server, TestClient(server.app) as client: response = client.post("/predict", json={"input": 4.0}) assert response.json() == {"output": 1600.0} # use LitAPI.load_cache to multiply the input by 10 + + +def test_get_default_loop(): + loop = ls.loops.get_default_loop(stream=False, max_batch_size=1) + assert isinstance(loop, ls.loops.SingleLoop), "SingleLoop must be returned when stream=False" + + loop = ls.loops.get_default_loop(stream=False, max_batch_size=4) + assert isinstance(loop, ls.loops.BatchedLoop), "BatchedLoop must be returned when stream=False and max_batch_size>1" + + loop = ls.loops.get_default_loop(stream=True, max_batch_size=1) + assert isinstance(loop, ls.loops.StreamingLoop), "StreamingLoop must be returned when stream=True" + + loop = ls.loops.get_default_loop(stream=True, max_batch_size=4) + assert isinstance(loop, ls.loops.BatchedStreamingLoop), ( + "BatchedStreamingLoop must be returned when stream=True and max_batch_size>1" + )