
Commit

fix: schema
bojiang committed Sep 27, 2024
1 parent e77302c commit 35c3e49
Showing 1 changed file with 30 additions and 13 deletions.
43 changes: 30 additions & 13 deletions src/vllm-chat/service.py
@@ -4,7 +4,7 @@
import os
import traceback
from argparse import Namespace
from typing import AsyncGenerator, Literal, Optional
from typing import AsyncGenerator, Literal, Optional, Union

import bentoml
import fastapi
@@ -19,16 +19,17 @@
class URL(pydantic.BaseModel):
    url: str

class TextContent(pydantic.BaseModel):
    type: Literal["text"] = "text"
    text: str

class Content(pydantic.BaseModel):
    type: Literal["text", "image_url"] = "text"
    text: Optional[str] = None
    image_url: Optional[URL] = None

class ImageContent(pydantic.BaseModel):
    type: Literal["image_url"] = "image_url"
    image_url: URL

class Message(pydantic.BaseModel):
    role: Literal["system", "user", "assistant"] = "user"
    content: list[Content]
    content: list[Union[TextContent, ImageContent]]


PARAMETER_YAML = os.path.join(os.path.dirname(__file__), "openllm_config.yaml")
@@ -116,8 +117,24 @@ def __init__(self) -> None:
    async def generate(
        self, prompt: str = "what is this?"
    ) -> AsyncGenerator[str, None]:
        async for text in self.generate_with_image(prompt):
            yield text
        from openai import AsyncOpenAI

        client = AsyncOpenAI(base_url="http://127.0.0.1:3000/v1", api_key="dummy")
        content = [TextContent(text=prompt)]
        message = Message(role="user", content=content)

        try:
            completion = await client.chat.completions.create(  # type: ignore
                model=ENGINE_CONFIG["model"],
                messages=[message.model_dump()],  # type: ignore
                stream=True,
            )
            async for chunk in completion:
                yield chunk.choices[0].delta.content or ""
        except Exception:
            yield traceback.format_exc()
        # async for text in self.generate_with_image(prompt):
        #     yield text

    @bentoml.api
    async def generate_with_image(
@@ -132,12 +149,12 @@ async def generate_with_image(
            img_str = base64.b64encode(buffered.getvalue()).decode()
            buffered.close()
            image_url = f"data:image/png;base64,{img_str}"
            content = [
                Content(type="image_url", image_url=URL(url=image_url)),
                Content(type="text", text=prompt),
            content: list[Union[ImageContent, TextContent]] = [
                ImageContent(image_url=URL(url=image_url)),
                TextContent(text=prompt),
            ]
        else:
            content = [Content(type="text", text=prompt)]
            content = [TextContent(text=prompt)]
        message = Message(role="user", content=content)

        try:
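For context on the schema change: the commit replaces the single Content model (with optional text and image_url fields) by TextContent and ImageContent models discriminated by type, so each serialized content part carries only the field relevant to its type instead of emitting text=None / image_url=None entries, which is likely what "fix: schema" refers to. Below is a minimal sketch, not part of the commit, of what a Message serializes to under the new models; it assumes pydantic v2 (model_dump, as used in the diff) and the example URL and prompt values are made up.

# Sketch only: models copied from the diff above, example values hypothetical.
from typing import Literal, Union

import pydantic

class URL(pydantic.BaseModel):
    url: str

class TextContent(pydantic.BaseModel):
    type: Literal["text"] = "text"
    text: str

class ImageContent(pydantic.BaseModel):
    type: Literal["image_url"] = "image_url"
    image_url: URL

class Message(pydantic.BaseModel):
    role: Literal["system", "user", "assistant"] = "user"
    content: list[Union[TextContent, ImageContent]]

message = Message(
    role="user",
    content=[
        ImageContent(image_url=URL(url="data:image/png;base64,...")),
        TextContent(text="what is this?"),
    ],
)

# model_dump() now produces OpenAI-style content parts:
# {"role": "user",
#  "content": [
#      {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
#      {"type": "text", "text": "what is this?"}]}
print(message.model_dump())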

