Skip to content

Commit

Permalink
Add instance_id param to chat completions
Browse files Browse the repository at this point in the history
  • Loading branch information
StefanBratanov committed Feb 13, 2024
1 parent 900a8e6 commit 97562b8
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ public record CreateChatCompletionRequest(
String model,
List<ChatMessage> messages,
Optional<Double> frequencyPenalty,
Optional<String> instanceId,
Optional<Map<Integer, Integer>> logitBias,
Optional<Boolean> logprobs,
Optional<Integer> topLogprobs,
Expand Down Expand Up @@ -48,6 +49,7 @@ public static class Builder {
private final List<ChatMessage> messages = new LinkedList<>();

private Optional<Double> frequencyPenalty = Optional.empty();
private Optional<String> instanceId = Optional.empty();
private Optional<Map<Integer, Integer>> logitBias = Optional.empty();
private Optional<Boolean> logprobs = Optional.empty();
private Optional<Integer> topLogprobs = Optional.empty();
Expand Down Expand Up @@ -102,6 +104,15 @@ public Builder frequencyPenalty(double frequencyPenalty) {
return this;
}

/**
* @param instanceId A unique identifier to a custom instance to execute the request. The
* requesting organization is required to have access to the instance.
*/
public Builder instanceId(String instanceId) {
this.instanceId = Optional.of(instanceId);
return this;
}

/**
* @param logitBias A map that maps tokens (specified by their token ID in the tokenizer) to an
* associated bias value from -100 to 100. Mathematically, the bias is added to the logits
Expand Down Expand Up @@ -305,6 +316,7 @@ public CreateChatCompletionRequest build() {
model,
List.copyOf(messages),
frequencyPenalty,
instanceId,
logitBias,
logprobs,
topLogprobs,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public CreateChatCompletionRequest randomCreateChatCompletionRequest() {
.model(randomModel())
.messages(listOf(randomInt(1, 3), this::randomChatMessage))
.frequencyPenalty(randomDouble(-2.0, 2.0))
.instanceId(randomString(7))
.logitBias(randomLogitBias(randomInt(0, 6)))
.logprobs(randomBoolean())
.topLogprobs(randomInt(0, 5))
Expand Down

0 comments on commit 97562b8

Please sign in to comment.