propagate error with OpenAISpec #143

Merged: 6 commits, Jun 20, 2024
Changes from 4 commits
src/litserve/server.py: 12 additions & 4 deletions
@@ -452,7 +452,7 @@ async def data_reader(self, read):
         asyncio.get_event_loop().remove_reader(read.fileno())
         return read.recv()
 
-    async def win_data_streamer(self, read, write):
+    async def win_data_streamer(self, read, write, send_status=False):
         # this is a workaround for Windows since asyncio loop.add_reader is not supported.
         # https://docs.python.org/3/library/asyncio-platforms.html
         while True:
@@ -467,12 +467,16 @@ async def win_data_streamer(self, read, write):
                         "Error occurred while streaming outputs from the inference worker. "
                         "Please check the above traceback."
                     )
+                    yield response, status
                     return
-                yield response
+                if send_status:
+                    yield response, status
+                else:
+                    yield response
 
             await asyncio.sleep(0.0001)
 
-    async def data_streamer(self, read: Connection, write: Connection):
+    async def data_streamer(self, read: Connection, write: Connection, send_status=False):
         data_available = asyncio.Event()
         while True:
             # Calling poll blocks the event loop, so keep the timeout low
@@ -490,8 +494,12 @@ async def data_streamer(self, read: Connection, write: Connection):
                         "Error occurred while streaming outputs from the inference worker. "
                         "Please check the above traceback."
                     )
+                    yield response, status
                     return
-                yield response
+                if send_status:
+                    yield response, status
+                else:
+                    yield response
 
     def cleanup_request(self, request_buffer, uid):
         with contextlib.suppress(KeyError):
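With `send_status=True`, both streamers now yield `(response, status)` tuples instead of bare responses, so callers can react to worker errors as they arrive. A minimal sketch of such a consumer, assuming `LitAPIStatus` is importable from `litserve.utils`; the `consume` helper below is illustrative and not part of this PR:

```python
from litserve.utils import LitAPIStatus


async def consume(streamer):
    # Illustrative consumer: with send_status=True the streamer yields
    # (response, status) tuples, so a worker error can be surfaced instead
    # of being silently dropped.
    async for response, status in streamer:
        if status == LitAPIStatus.ERROR:
            raise RuntimeError("inference worker reported an error")
        yield response
```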
src/litserve/specs/openai.py: 13 additions & 9 deletions
@@ -26,7 +26,7 @@
 from fastapi.responses import StreamingResponse
 from pydantic import BaseModel, Field
 
-from ..utils import azip
+from ..utils import azip, LitAPIStatus, load_and_raise
 from .base import LitSpec
 
 if typing.TYPE_CHECKING:
@@ -272,9 +272,9 @@ async def get_from_pipes(self, uids, pipes) -> List[AsyncGenerator]:
         choice_pipes = []
         for uid, (read, write) in zip(uids, pipes):
             if sys.version_info[0] == 3 and sys.version_info[1] >= 8 and sys.platform.startswith("win"):
-                data = self._server.win_data_streamer(read, write)
+                data = self._server.win_data_streamer(read, write, send_status=True)
             else:
-                data = self._server.data_streamer(read, write)
+                data = self._server.data_streamer(read, write, send_status=True)
 
             choice_pipes.append(data)
         return choice_pipes
Expand Down Expand Up @@ -320,8 +320,10 @@ async def streaming_completion(self, request: ChatCompletionRequest, pipe_respon
usage = None
async for streaming_response in azip(*pipe_responses):
choices = []
for i, chat_msg in enumerate(streaming_response):
chat_msg = json.loads(chat_msg)
for i, (response, status) in enumerate(streaming_response):
if status == LitAPIStatus.ERROR:
load_and_raise(response)
chat_msg = json.loads(response)
logger.debug(chat_msg)
chat_msg = ChoiceDelta(**chat_msg)
choice = ChatCompletionStreamingChoice(
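`load_and_raise` is what turns a worker-side failure into a client-visible error. Its implementation is not part of this diff; the sketch below is only an assumption of what such a helper could look like if the worker forwards the exception as a pickled payload, not the actual code in `litserve/utils.py`:

```python
import pickle

from fastapi import HTTPException


def load_and_raise(response):
    # Hypothetical sketch: decode the exception forwarded by the inference
    # worker and re-raise it so the HTTP layer answers with a 500.
    try:
        exception = pickle.loads(response)
    except pickle.UnpicklingError as err:
        raise HTTPException(status_code=500, detail="Could not decode the error sent by the inference worker") from err
    raise exception
```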
@@ -345,15 +347,17 @@ async def streaming_completion(self, request: ChatCompletionRequest, pipe_responses):
         yield f"data: {last_chunk}\n\n"
         yield "data: [DONE]\n\n"
 
-    async def non_streaming_completion(self, request: ChatCompletionRequest, pipe_responses: List):
+    async def non_streaming_completion(self, request: ChatCompletionRequest, generator_list: List[AsyncGenerator]):
         model = request.model
         usage = UsageInfo()
         choices = []
-        for i, streaming_response in enumerate(pipe_responses):
+        for i, streaming_response in enumerate(generator_list):
             msgs = []
             tool_calls = None
-            async for chat_msg in streaming_response:
-                chat_msg = json.loads(chat_msg)
+            async for response, status in streaming_response:
+                if status == LitAPIStatus.ERROR:
+                    load_and_raise(response)
+                chat_msg = json.loads(response)
                 logger.debug(chat_msg)
                 chat_msg = ChatMessage(**chat_msg)
                 msgs.append(chat_msg.content)
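The streaming path iterates `azip(*pipe_responses)` so that one chunk from every choice is handled per step before the `(response, status)` tuples are unpacked. `azip` lives in `litserve/utils.py` and is not shown in this diff; a minimal sketch of an async zip, under the assumption that it stops at the shortest generator:

```python
async def azip(*async_iterables):
    # Sketch of an async zip: advance every async iterator once per step and
    # yield the results as a tuple, stopping when any iterator is exhausted.
    iterators = [iterable.__aiter__() for iterable in async_iterables]
    while True:
        batch = []
        for iterator in iterators:
            try:
                batch.append(await iterator.__anext__())
            except StopAsyncIteration:
                return
        yield tuple(batch)
```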
tests/test_specs.py: 17 additions & 0 deletions
@@ -129,3 +129,20 @@ async def test_oai_prepopulated_context(openai_request_data):
     assert (
         resp.json()["choices"][0]["message"]["content"] == "This is a"
     ), "OpenAISpec must return only 3 tokens as specified using `max_tokens` parameter"
+
+
+class WrongLitAPI(ls.LitAPI):
+    def setup(self, device):
+        self.model = None
+
+    def predict(self, prompt):
+        yield "This is a sample generated text"
+        raise Exception("random error")
+
+
+@pytest.mark.asyncio()
+async def test_fail_http(openai_request_data):
+    server = ls.LitServer(WrongLitAPI(), spec=ls.OpenAISpec())
+    async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
+        resp = await ac.post("/v1/chat/completions", json=openai_request_data, timeout=10)
+        assert resp.status_code == 500, "Server raises an exception so client should fail"
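Beyond the regression test, the new behaviour can be checked by hand against a running server. The snippet below is an illustrative client-side check, assuming a LitServer with `OpenAISpec` wrapping a failing `predict` is serving at `http://127.0.0.1:8000`; it is not part of the PR:

```python
import httpx

# Illustrative check: once errors are propagated, a faulty predict() should
# surface as an HTTP 500 on the OpenAI-compatible endpoint instead of a hung
# or truncated stream.
payload = {
    "model": "lit",
    "messages": [{"role": "user", "content": "Hello"}],
}
resp = httpx.post("http://127.0.0.1:8000/v1/chat/completions", json=payload, timeout=10)
if resp.status_code == 500:
    print("worker error propagated to the client:", resp.text)
```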