Skip to content

Commit

Permalink
Add support for usage stats when streaming with the Chat Completions API
Browse files Browse the repository at this point in the history
  • Loading branch information
StefanBratanov committed May 7, 2024
1 parent 4e93533 commit d28acf0
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@
* provided input.
*/
public record ChatCompletionChunk(
String id, List<Choice> choices, long created, String model, String systemFingerprint) {
String id,
List<Choice> choices,
long created,
String model,
String systemFingerprint,
Usage usage) {

public record Choice(Delta delta, int index, Logprobs logprobs, String finishReason) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ public record CreateChatCompletionRequest(
Optional<Integer> seed,
Optional<List<String>> stop,
Optional<Boolean> stream,
Optional<StreamOptions> streamOptions,
Optional<Double> temperature,
Optional<Double> topP,
Optional<List<Tool>> tools,
Expand All @@ -29,6 +30,18 @@ public static Builder newBuilder() {
return new Builder();
}

/**
* @param includeUsage If set, an additional chunk will be streamed before the data: [DONE]
* message. The usage field on this chunk shows the token usage statistics for the entire
* request, and the choices field will always be an empty array. All other chunks will also
* include a usage field, but with a null value.
*/
public record StreamOptions(Boolean includeUsage) {
public static StreamOptions withUsageIncluded() {
return new StreamOptions(true);
}
}

public static class Builder {

private static final String DEFAULT_MODEL = OpenAIModel.GPT_3_5_TURBO.getId();
Expand All @@ -48,6 +61,7 @@ public static class Builder {
private Optional<Integer> seed = Optional.empty();
private final List<String> stop = new LinkedList<>();
private Optional<Boolean> stream = Optional.empty();
private Optional<StreamOptions> streamOptions = Optional.empty();
private Optional<Double> temperature = Optional.empty();
private Optional<Double> topP = Optional.empty();
private final List<Tool> tools = new LinkedList<>();
Expand Down Expand Up @@ -195,6 +209,14 @@ public Builder stream(boolean stream) {
return this;
}

/**
* @param streamOptions Options for streaming response. Only set this when you set stream: true.
*/
public Builder streamOptions(StreamOptions streamOptions) {
this.streamOptions = Optional.of(streamOptions);
return this;
}

/**
* @param temperature What sampling temperature to use, between 0 and 2. Higher values like 0.8
* will make the output more random, while lower values like 0.2 will make it more focused
Expand Down Expand Up @@ -287,6 +309,7 @@ public CreateChatCompletionRequest build() {
seed,
stop.isEmpty() ? Optional.empty() : Optional.of(List.copyOf(stop)),
stream,
streamOptions,
temperature,
topP,
tools.isEmpty() ? Optional.empty() : Optional.of(List.copyOf(tools)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@ public void setupModule(SetupContext context) {
}

/**
* Remove when the following issue is resolved: <a
* Remove when the following issues are resolved: <a
* href="https://github.com/FasterXML/jackson-databind/issues/2992">Properties naming strategy do
* not work with Record #2992</a>
* not work with Record #2992</a> and <a
* href="https://github.com/FasterXML/jackson-databind/issues/4515">Rewrite Bean Property
* Introspection logic in Jackson 2.x (ideally for 2.18) #4515</a>
*/
private static class ValueInstantiatorsModifier extends ValueInstantiators.Base {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import static org.junit.jupiter.api.Assertions.assertThrows;

import io.github.stefanbratanov.jvm.openai.ChatMessage.UserMessage.UserMessageWithContentParts.ContentPart.TextContentPart;
import io.github.stefanbratanov.jvm.openai.CreateChatCompletionRequest.StreamOptions;
import java.io.UncheckedIOException;
import java.net.http.HttpTimeoutException;
import java.nio.file.Path;
Expand Down Expand Up @@ -107,11 +108,21 @@ void testChatClient() {
// test sending content part
.message(ChatMessage.userMessage(new TextContentPart("Say this is a test")))
.stream(true)
// test usage stats
.streamOptions(StreamOptions.withUsageIncluded())
.build();

String joinedContent =
chatClient
.streamChatCompletion(streamRequest)
.filter(
chunk -> {
if (chunk.choices().isEmpty()) {
assertThat(chunk.usage()).isNotNull();
return false;
}
return true;
})
.map(ChatCompletionChunk::choices)
.map(
choices -> {
Expand All @@ -123,6 +134,12 @@ void testChatClient() {

assertThat(joinedContent).containsPattern("(?i)this is (a|the) test");

streamRequest =
CreateChatCompletionRequest.newBuilder()
.message(ChatMessage.userMessage("Say this is a test"))
.stream(true)
.build();

// test streaming with a subscriber
CompletableFuture<String> joinedContentFuture = new CompletableFuture<>();
chatClient.streamChatCompletion(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.github.stefanbratanov.jvm.openai;

import io.github.stefanbratanov.jvm.openai.ChatMessage.UserMessage.UserMessageWithContentParts.ContentPart;
import io.github.stefanbratanov.jvm.openai.CreateChatCompletionRequest.StreamOptions;
import io.github.stefanbratanov.jvm.openai.FineTuningJobIntegration.Wandb;
import io.github.stefanbratanov.jvm.openai.RunStepsClient.PaginatedThreadRunSteps;
import io.github.stefanbratanov.jvm.openai.ThreadMessage.Content.ImageFileContent;
Expand Down Expand Up @@ -46,6 +47,7 @@ public CreateChatCompletionRequest randomCreateChatCompletionRequest() {
.seed(randomInt())
.stop(arrayOf(randomInt(0, 4), () -> randomString(5), String[]::new))
.stream(randomBoolean())
.streamOptions(StreamOptions.withUsageIncluded())
.temperature(randomDouble(0.0, 2.0))
.topP(randomDouble(0.0, 1.0))
.tools(listOf(randomInt(0, 5), this::randomFunctionTool));
Expand Down

0 comments on commit d28acf0

Please sign in to comment.