
implement OpenAI token usage #150

Merged · 13 commits · Jun 24, 2024
64 changes: 63 additions & 1 deletion README.md
@@ -633,6 +633,68 @@ if __name__ == "__main__":
```
 

The OpenAI API response includes usage information: the number of tokens in the prompt, the number of tokens in the
generated completion, and the total for the request.

With LitServe, you can return this usage information from the `encode_response` method by yielding it as the last item:

```python
import litserve as ls

class OpenAIUsageAPI(ls.LitAPI):
    def setup(self, device):
        self.model = None

    def predict(self, x):
        yield "10 + 6 is equal to 16."

    def encode_response(self, output):
        for out in output:
            yield {"role": "assistant", "content": out}
        # Get the usage info and yield as last output
        yield {"role": "assistant", "content": "", "prompt_tokens": 25, "completion_tokens": 10, "total_tokens": 35}
```

An equivalent but simpler approach, which does not require overriding `encode_response`, is to yield the usage fields directly from `predict`:

```python
import litserve as ls

class OpenAIUsageAPI(ls.LitAPI):
    def setup(self, device):
        self.model = None

    def predict(self, x):
        yield {"role": "assistant", "content": "10 + 6 is equal to 16.", "prompt_tokens": 25, "completion_tokens": 10, "total_tokens": 35}
```
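
To serve either example, attach the OpenAI spec when creating the server. This is a minimal sketch, assuming the `OpenAIUsageAPI` class from the snippet above is defined in the same file; the port is arbitrary.

```python
import litserve as ls

if __name__ == "__main__":
    # Expose the API at the OpenAI-compatible /v1/chat/completions route
    server = ls.LitServer(OpenAIUsageAPI(), spec=ls.OpenAISpec())
    server.run(port=8000)
```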

With either of the approaches above, the server response includes token usage data:

```json
{
  "id": "chatcmpl-9dEtoQgtrtr3451SZ2s98S",
  "choices": [
    {
      "finish_reason": "stop",
      "index": 0,
      "logprobs": null,
      "message": {
        "content": "10 + 6 is equal to 16.",
        "role": "assistant",
        "function_call": null,
        "tool_calls": null
      }
    }
  ],
  "created": 1719139092,
  "model": "gpt-3.5-turbo-0125",
  "object": "chat.completion",
  "system_fingerprint": null,
  "usage": {"completion_tokens": 10, "prompt_tokens": 25, "total_tokens": 35}
}
```
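
Once the server is running, the usage fields can be read with a standard OpenAI client. The sketch below assumes the server from the examples above is running locally on port 8000 and that no authentication is configured; the `api_key` and model name are placeholders.

```python
from openai import OpenAI

# Point the client at the local LitServe endpoint; the API key is a placeholder
client = OpenAI(base_url="http://localhost:8000/v1", api_key="lit")

response = client.chat.completions.create(
    model="gpt-3.5-turbo-0125",  # placeholder model name
    messages=[{"role": "user", "content": "What is 10 + 6?"}],
)
print(response.choices[0].message.content)  # "10 + 6 is equal to 16."
print(response.usage)  # completion_tokens=10, prompt_tokens=25, total_tokens=35
```

When streaming (`stream=True`), the usage information is included in the final chunk sent before `[DONE]`.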


LitServe's `OpenAISpec` can also handle images in the input. Here is an example:

```python
@@ -901,7 +963,7 @@ if __name__=="__main__":
</details>

<details>
<summary>Customize the endpoint path</summary>

&nbsp;

50 changes: 50 additions & 0 deletions src/litserve/examples/openai_spec_example.py
@@ -74,6 +74,56 @@ def encode_response(self, output_stream, context):
assert ctx["temperature"] == 1.0, f"context {ctx} is not 1.0"


class OpenAIWithUsage(ls.LitAPI):
def setup(self, device):
self.model = None

def predict(self, x):
yield {
"role": "assistant",
"content": "10 + 6 is equal to 16.",
"prompt_tokens": 25,
"completion_tokens": 10,
"total_tokens": 35,
}


class OpenAIWithUsageEncodeResponse(ls.LitAPI):
def setup(self, device):
self.model = None

def predict(self, x):
# streaming tokens
yield from ["10", " +", " ", "6", " is", " equal", " to", " ", "16", "."]

def encode_response(self, output):
for out in output:
yield {"role": "assistant", "content": out}

yield {"role": "assistant", "content": "", "prompt_tokens": 25, "completion_tokens": 10, "total_tokens": 35}


class OpenAIBatchingWithUsage(OpenAIWithUsage):
def batch(self, inputs):
return inputs

def predict(self, x):
n = len(x)
yield ["10 + 6 is equal to 16."] * n

def encode_response(self, output_stream_batch, context):
n = len(context)
for output_batch in output_stream_batch:
yield [{"role": "assistant", "content": out} for out in output_batch]

yield [
{"role": "assistant", "content": "", "prompt_tokens": 25, "completion_tokens": 10, "total_tokens": 35}
] * n

def unbatch(self, output):
return output


if __name__ == "__main__":
server = ls.LitServer(TestAPIWithCustomEncode(), spec=OpenAISpec())
server.run(port=8000)
72 changes: 54 additions & 18 deletions src/litserve/specs/openai.py
@@ -20,7 +20,7 @@
import typing
import uuid
from enum import Enum
from typing import AsyncGenerator, Dict, List, Literal, Optional, Union
from typing import AsyncGenerator, Dict, List, Literal, Optional, Union, Generator

from fastapi import BackgroundTasks, HTTPException, Request, Response
from fastapi.responses import StreamingResponse
@@ -44,6 +44,17 @@ class UsageInfo(BaseModel):
total_tokens: int = 0
completion_tokens: Optional[int] = 0

def __add__(self, other: "UsageInfo") -> "UsageInfo":
other.prompt_tokens += self.prompt_tokens
other.completion_tokens += self.completion_tokens
other.total_tokens += self.total_tokens
return other

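    # sum() starts the reduction with the int 0; __radd__ handles that case so a list of UsageInfo objects can be combined with sum()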
def __radd__(self, other):
if other == 0:
return self
return self.__add__(other)


class TextContent(BaseModel):
type: str
@@ -237,10 +248,20 @@ def batch(self, inputs):
def unbatch(self, output):
yield output

def extract_usage_info(self, output: Dict) -> Dict:
prompt_tokens: int = output.pop("prompt_tokens", 0)
completion_tokens: int = output.pop("completion_tokens", 0)
total_tokens: int = output.pop("total_tokens", 0)
return {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
}

def validate_chat_message(self, obj):
return isinstance(obj, dict) and "role" in obj and "content" in obj

def _encode_response(self, output: Union[Dict[str, str], List[Dict[str, str]]]) -> ChatMessage:
def _encode_response(self, output: Union[Dict[str, str], List[Dict[str, str]]]) -> Dict:
logger.debug(output)
if isinstance(output, str):
message = {"role": "assistant", "content": output}
Expand All @@ -258,12 +279,12 @@ def _encode_response(self, output: Union[Dict[str, str], List[Dict[str, str]]])
)
logger.exception(error)
raise HTTPException(500, error)

return ChatMessage(**message)
usage_info = self.extract_usage_info(message)
return {**message, **usage_info}

def encode_response(
self, output_generator: Union[Dict[str, str], List[Dict[str, str]]], context_kwargs: Optional[dict] = None
) -> ChatMessage:
) -> Generator[Dict, None, None]:
for output in output_generator:
logger.debug(output)
yield self._encode_response(output)
@@ -317,49 +338,63 @@ def callback(_=None):

async def streaming_completion(self, request: ChatCompletionRequest, pipe_responses: List):
model = request.model
usage = None
usage_info = None
async for streaming_response in azip(*pipe_responses):
choices = []
usage_infos = []
# iterate over n choices
for i, (response, status) in enumerate(streaming_response):
if status == LitAPIStatus.ERROR:
load_and_raise(response)
chat_msg = json.loads(response)
logger.debug(chat_msg)
chat_msg = ChoiceDelta(**chat_msg)
encoded_response = json.loads(response)
logger.debug(encoded_response)
chat_msg = ChoiceDelta(**encoded_response)
usage_infos.append(UsageInfo(**encoded_response))
choice = ChatCompletionStreamingChoice(
index=i, delta=chat_msg, system_fingerprint="", usage=usage, finish_reason=None
index=i, delta=chat_msg, system_fingerprint="", finish_reason=None
)

choices.append(choice)

chunk = ChatCompletionChunk(model=model, choices=choices, usage=usage).json()
# Only use the last item from encode_response
usage_info = sum(usage_infos)
chunk = ChatCompletionChunk(model=model, choices=choices, usage=None).json()
logger.debug(chunk)
yield f"data: {chunk}\n\n"

choices = [
ChatCompletionStreamingChoice(index=i, delta=ChoiceDelta(), finish_reason="stop") for i in range(request.n)
ChatCompletionStreamingChoice(
index=i,
delta=ChoiceDelta(),
finish_reason="stop",
)
for i in range(request.n)
]
last_chunk = ChatCompletionChunk(
model=model,
choices=choices,
usage=usage,
usage=usage_info,
).json()
yield f"data: {last_chunk}\n\n"
yield "data: [DONE]\n\n"

async def non_streaming_completion(self, request: ChatCompletionRequest, generator_list: List[AsyncGenerator]):
model = request.model
usage = UsageInfo()
usage_infos = []
choices = []
# iterate over n choices
for i, streaming_response in enumerate(generator_list):
msgs = []
tool_calls = None
usage = None
async for response, status in streaming_response:
if status == LitAPIStatus.ERROR:
load_and_raise(response)
chat_msg = json.loads(response)
logger.debug(chat_msg)
chat_msg = ChatMessage(**chat_msg)
# data from LitAPI.encode_response
encoded_response = json.loads(response)
logger.debug(encoded_response)
chat_msg = ChatMessage(**encoded_response)
usage = UsageInfo(**encoded_response)
msgs.append(chat_msg.content)
if chat_msg.tool_calls:
tool_calls = chat_msg.tool_calls
@@ -368,5 +403,6 @@ async def non_streaming_completion(self, request: ChatCompletionRequest, generat
msg = {"role": "assistant", "content": content, "tool_calls": tool_calls}
choice = ChatCompletionResponseChoice(index=i, message=msg, finish_reason="stop")
choices.append(choice)
usage_infos.append(usage) # Only use the last item from encode_response

return ChatCompletionResponse(model=model, choices=choices, usage=usage)
return ChatCompletionResponse(model=model, choices=choices, usage=sum(usage_infos))
25 changes: 25 additions & 0 deletions tests/conftest.py
@@ -136,6 +136,31 @@ def openai_request_data():
}


@pytest.fixture()
def openai_response_data():
return {
"id": "chatcmpl-9dEtoQu4g45g3431SZ2s98S",
"choices": [
{
"finish_reason": "stop",
"index": 0,
"logprobs": None,
"message": {
"content": "10 + 6 is equal to 16.",
"role": "assistant",
"function_call": None,
"tool_calls": None,
},
}
],
"created": 1719139092,
"model": "gpt-3.5-turbo-0125",
"object": "chat.completion",
"system_fingerprint": None,
"usage": {"completion_tokens": 10, "prompt_tokens": 25, "total_tokens": 35},
}


@pytest.fixture()
def openai_request_data_with_image():
return {
36 changes: 35 additions & 1 deletion tests/test_specs.py
@@ -16,7 +16,7 @@
from asgi_lifespan import LifespanManager
from fastapi import HTTPException
from httpx import AsyncClient
from litserve.examples.openai_spec_example import TestAPI, TestAPIWithCustomEncode, TestAPIWithToolCalls
from litserve.examples.openai_spec_example import (
TestAPI,
TestAPIWithCustomEncode,
TestAPIWithToolCalls,
OpenAIWithUsage,
OpenAIBatchingWithUsage,
OpenAIWithUsageEncodeResponse,
)
from litserve.specs.openai import OpenAISpec, ChatMessage
import litserve as ls

@@ -34,6 +41,33 @@ async def test_openai_spec(openai_request_data):
), "LitAPI predict response should match with the generated output"


# OpenAIWithUsage
@pytest.mark.asyncio()
@pytest.mark.parametrize(
("api", "batch_size"),
[
(OpenAIWithUsage(), 1),
(OpenAIWithUsageEncodeResponse(), 1),
(OpenAIBatchingWithUsage(), 2),
],
)
async def test_openai_token_usage(api, batch_size, openai_request_data, openai_response_data):
server = ls.LitServer(api, spec=ls.OpenAISpec(), max_batch_size=batch_size, batch_timeout=0.01)
async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
resp = await ac.post("/v1/chat/completions", json=openai_request_data, timeout=10)
assert resp.status_code == 200, "Status code should be 200"
result = resp.json()
content = result["choices"][0]["message"]["content"]
assert content == "10 + 6 is equal to 16.", "LitAPI predict response should match with the generated output"
assert result["usage"] == openai_response_data["usage"]

# with streaming
openai_request_data["stream"] = True
resp = await ac.post("/v1/chat/completions", json=openai_request_data, timeout=10)
assert resp.status_code == 200, "Status code should be 200"
assert result["usage"] == openai_response_data["usage"]


@pytest.mark.asyncio()
async def test_openai_spec_with_image(openai_request_data_with_image):
spec = OpenAISpec()