From b41f22bdbe16639211364c894bc6dddd66612b77 Mon Sep 17 00:00:00 2001
From: Aniket Maurya
Date: Tue, 18 Jun 2024 11:40:53 +0100
Subject: [PATCH 1/4] send status

---
 src/litserve/server.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/src/litserve/server.py b/src/litserve/server.py
index 105af678..ecc15007 100644
--- a/src/litserve/server.py
+++ b/src/litserve/server.py
@@ -452,7 +452,7 @@ async def data_reader(self, read):
         asyncio.get_event_loop().remove_reader(read.fileno())
         return read.recv()
 
-    async def win_data_streamer(self, read, write):
+    async def win_data_streamer(self, read, write, send_status=False):
         # this is a workaround for Windows since asyncio loop.add_reader is not supported.
         # https://docs.python.org/3/library/asyncio-platforms.html
         while True:
@@ -468,11 +468,14 @@ async def win_data_streamer(self, read, write):
                         "Please check the above traceback."
                     )
                     return
-                yield response
+                if send_status:
+                    yield response, status
+                else:
+                    yield response
 
             await asyncio.sleep(0.0001)
 
-    async def data_streamer(self, read: Connection, write: Connection):
+    async def data_streamer(self, read: Connection, write: Connection, send_status=False):
         data_available = asyncio.Event()
         while True:
             # Calling poll blocks the event loop, so keep the timeout low
@@ -491,7 +494,10 @@ async def data_streamer(self, read: Connection, write: Connection):
                         "Please check the above traceback."
                     )
                     return
-                yield response
+                if send_status:
+                    yield response, status
+                else:
+                    yield response
 
     def cleanup_request(self, request_buffer, uid):
         with contextlib.suppress(KeyError):

From facbe12e91510640adb72f7f63a8dc2f5d03d75c Mon Sep 17 00:00:00 2001
From: Aniket Maurya
Date: Tue, 18 Jun 2024 12:23:51 +0100
Subject: [PATCH 2/4] propagate error

---
 src/litserve/server.py       |  2 ++
 src/litserve/specs/openai.py | 23 ++++++++++++++++---------
 2 files changed, 16 insertions(+), 9 deletions(-)

diff --git a/src/litserve/server.py b/src/litserve/server.py
index ecc15007..1ec7edc0 100644
--- a/src/litserve/server.py
+++ b/src/litserve/server.py
@@ -467,6 +467,7 @@ async def win_data_streamer(self, read, write, send_status=False):
                         "Error occurred while streaming outputs from the inference worker. "
                         "Please check the above traceback."
                     )
+                    yield response, status
                     return
                 if send_status:
                     yield response, status
@@ -493,6 +494,7 @@ async def data_streamer(self, read: Connection, write: Connection, send_status=F
                         "Error occurred while streaming outputs from the inference worker. "
                         "Please check the above traceback."
                     )
+                    yield response, status
                     return
                 if send_status:
                     yield response, status
diff --git a/src/litserve/specs/openai.py b/src/litserve/specs/openai.py
index 5ceeab8e..d51aa1f8 100644
--- a/src/litserve/specs/openai.py
+++ b/src/litserve/specs/openai.py
@@ -26,7 +26,7 @@
 from fastapi.responses import StreamingResponse
 from pydantic import BaseModel, Field
 
-from ..utils import azip
+from ..utils import azip, LitAPIStatus, load_and_raise
 from .base import LitSpec
 
 if typing.TYPE_CHECKING:
@@ -272,9 +272,9 @@ async def get_from_pipes(self, uids, pipes) -> List[AsyncGenerator]:
         choice_pipes = []
         for uid, (read, write) in zip(uids, pipes):
             if sys.version_info[0] == 3 and sys.version_info[1] >= 8 and sys.platform.startswith("win"):
-                data = self._server.win_data_streamer(read, write)
+                data = self._server.win_data_streamer(read, write, send_status=True)
             else:
-                data = self._server.data_streamer(read, write)
+                data = self._server.data_streamer(read, write, send_status=True)
             choice_pipes.append(data)
         return choice_pipes
 
@@ -320,8 +320,10 @@ async def streaming_completion(self, request: ChatCompletionRequest, pipe_respon
         usage = None
         async for streaming_response in azip(*pipe_responses):
             choices = []
-            for i, chat_msg in enumerate(streaming_response):
-                chat_msg = json.loads(chat_msg)
+            for i, (response, status) in enumerate(streaming_response):
+                if status == LitAPIStatus.ERROR:
+                    load_and_raise(response)
+                chat_msg = json.loads(response)
                 logger.debug(chat_msg)
                 chat_msg = ChoiceDelta(**chat_msg)
                 choice = ChatCompletionStreamingChoice(
@@ -345,15 +347,17 @@ async def streaming_completion(self, request: ChatCompletionRequest, pipe_respon
         yield f"data: {last_chunk}\n\n"
         yield "data: [DONE]\n\n"
 
-    async def non_streaming_completion(self, request: ChatCompletionRequest, pipe_responses: List):
+    async def non_streaming_completion(self, request: ChatCompletionRequest, generator_list: List[AsyncGenerator]):
         model = request.model
         usage = UsageInfo()
         choices = []
-        for i, streaming_response in enumerate(pipe_responses):
+        for i, streaming_response in enumerate(generator_list):
             msgs = []
             tool_calls = None
-            async for chat_msg in streaming_response:
-                chat_msg = json.loads(chat_msg)
+            async for response, status in streaming_response:
+                if status == LitAPIStatus.ERROR:
+                    load_and_raise(response)
+                chat_msg = json.loads(response)
                 logger.debug(chat_msg)
                 chat_msg = ChatMessage(**chat_msg)
                 msgs.append(chat_msg.content)
@@ -364,5 +368,6 @@ async def non_streaming_completion(self, request: ChatCompletionRequest, pipe_re
             msg = {"role": "assistant", "content": content, "tool_calls": tool_calls}
             choice = ChatCompletionResponseChoice(index=i, message=msg, finish_reason="stop")
             choices.append(choice)
+            i += 1
 
         return ChatCompletionResponse(model=model, choices=choices, usage=usage)

From 95750c327bce92b7004019702b4309adb04918e5 Mon Sep 17 00:00:00 2001
From: Aniket Maurya
Date: Tue, 18 Jun 2024 12:26:06 +0100
Subject: [PATCH 3/4] fix

---
 src/litserve/specs/openai.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/litserve/specs/openai.py b/src/litserve/specs/openai.py
index d51aa1f8..f6cf1778 100644
--- a/src/litserve/specs/openai.py
+++ b/src/litserve/specs/openai.py
@@ -368,6 +368,5 @@ async def non_streaming_completion(self, request: ChatCompletionRequest, generat
             msg = {"role": "assistant", "content": content, "tool_calls": tool_calls}
             choice = ChatCompletionResponseChoice(index=i, message=msg, finish_reason="stop")
             choices.append(choice)
-            i += 1
 
         return ChatCompletionResponse(model=model, choices=choices, usage=usage)

From 56ab9f5b474ca3793c7e48bfdb317bf9a1118883 Mon Sep 17 00:00:00 2001
From: Aniket Maurya
Date: Tue, 18 Jun 2024 12:33:47 +0100
Subject: [PATCH 4/4] add test

---
 tests/test_specs.py | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/tests/test_specs.py b/tests/test_specs.py
index 3819acf1..1b34b08e 100644
--- a/tests/test_specs.py
+++ b/tests/test_specs.py
@@ -129,3 +129,20 @@ async def test_oai_prepopulated_context(openai_request_data):
     assert (
         resp.json()["choices"][0]["message"]["content"] == "This is a"
     ), "OpenAISpec must return only 3 tokens as specified using `max_tokens` parameter"
+
+
+class WrongLitAPI(ls.LitAPI):
+    def setup(self, device):
+        self.model = None
+
+    def predict(self, prompt):
+        yield "This is a sample generated text"
+        raise Exception("random error")
+
+
+@pytest.mark.asyncio()
+async def test_fail_http(openai_request_data):
+    server = ls.LitServer(WrongLitAPI(), spec=ls.OpenAISpec())
+    async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
+        resp = await ac.post("/v1/chat/completions", json=openai_request_data, timeout=10)
+        assert resp.status_code == 500, "Server raises an exception so client should fail"