
Commit

chore: add tokenizer example [skip ci]
aallam committed Oct 27, 2023
1 parent ba2f4c8 commit d4785dd
Showing 4 changed files with 93 additions and 0 deletions.
1 change: 1 addition & 0 deletions gradle/libs.versions.toml
@@ -36,6 +36,7 @@ okio-fakefilesystem = { group = "com.squareup.okio", name = "okio-fakefilesystem
logback-classic = { group = "ch.qos.logback", name = "logback-classic", version.ref = "logback" }
# ulid
ulid = { group = "com.aallam.ulid", name = "ulid-kotlin", version = "1.2.1" }
ktoken = { group = "com.aallam.ktoken", name = "ktoken", version = "0.3.0" }

[plugins]
kotlin-multiplaform = { id = "org.jetbrains.kotlin.multiplatform", version.ref = "kotlin" }
1 change: 1 addition & 0 deletions sample/jvm/build.gradle.kts
@@ -8,6 +8,7 @@ dependencies {
//implementation("com.aallam.openai:openai-client:<version>")
implementation(projects.openaiClient)
implementation(libs.ktor.client.apache)
implementation(libs.ktoken)
}

application {
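With the catalog entry and the implementation(libs.ktoken) line above in place, the tokenizer can also be exercised on its own. A minimal sketch, assuming only the Tokenizer.of and encode calls that Tokens.kt below already relies on; ktokenSmokeTest is a hypothetical name and is not part of this commit:

import com.aallam.ktoken.Tokenizer

// Hypothetical helper, not part of this commit: count tokens for a plain string.
suspend fun ktokenSmokeTest() {
    val tokenizer = Tokenizer.of("gpt-4")        // loads the BPE encoding for the given model
    val tokens = tokenizer.encode("hello world") // token ids for the input text
    println("${tokens.size} tokens: $tokens")
}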
@@ -19,6 +19,7 @@ fun main() = runBlocking {
println("5 - Chat")
println("6 - Chat (w/ Function)")
println("7 - Whisper")
println("8 - Tokens")
println("0 - Quit")

when (val option = readlnOrNull()?.toIntOrNull()) {
@@ -29,6 +30,7 @@
5 -> chat(openAI)
6 -> chatFunctionCall(openAI)
7 -> whisper(openAI)
8 -> tokensCount(openAI)
0 -> {
println("Exiting...")
return@runBlocking
89 changes: 89 additions & 0 deletions sample/jvm/src/main/kotlin/com/aallam/openai/sample/jvm/Tokens.kt
@@ -0,0 +1,89 @@
package com.aallam.openai.sample.jvm

import com.aallam.ktoken.Tokenizer
import com.aallam.openai.api.chat.ChatCompletionRequest
import com.aallam.openai.api.chat.ChatMessage
import com.aallam.openai.api.chat.ChatRole
import com.aallam.openai.api.model.ModelId
import com.aallam.openai.client.OpenAI

suspend fun tokensCount(openAI: OpenAI) {
val messages = listOf(
ChatMessage(
role = ChatRole.System,
content = "You are a helpful, pattern-following assistant that translates corporate jargon into plain English.",
),
ChatMessage(
role = ChatRole.System,
name = "example_user",
content = "New synergies will help drive top-line growth.",
),
ChatMessage(
role = ChatRole.System,
name = "example_assistant",
content = "Things working well together will increase revenue.",
),
ChatMessage(
role = ChatRole.System,
name = "example_user",
content = "Let's circle back when we have more bandwidth to touch base on opportunities for increased leverage.",
),
ChatMessage(
role = ChatRole.System,
name = "example_assistant",
content = "Let's talk later when we're less busy about how to do better.",
),
ChatMessage(
role = ChatRole.User,
content = "This late pivot means we don't have time to boil the ocean for the client deliverable.",
),
)

val models = listOf(
"gpt-3.5-turbo-0301", "gpt-3.5-turbo-0613", "gpt-3.5-turbo", "gpt-4-0314", "gpt-4-0613", "gpt-4",
)
for (model in models) {
println(model)
val tokens = encoding(messages, model)
println("$tokens prompt tokens counted by Ktoken.")
val request = ChatCompletionRequest(
model = ModelId(model),
messages = messages,
temperature = 0.0,
maxTokens = 1,
)
val completion = openAI.chatCompletion(request)
println("${completion.usage?.promptTokens} prompt tokens counted by the OpenAI API\n")
}
}

suspend fun encoding(messages: List<ChatMessage>, model: String): Int {
val (tokensPerMessage, tokensPerName) = when (model) {
"gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", "gpt-4-0314", "gpt-4-32k-0314", "gpt-4-0613", "gpt-4-32k-0613" -> 3 to 1
"gpt-3.5-turbo-0301" -> 4 to -1 // every message follows <|start|>{role/name}\n{content}<|end|>\n
"gpt-3.5-turbo" -> {
println("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
return encoding(messages, "gpt-3.5-turbo-0613")
}

"gpt-4" -> {
println("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
return encoding(messages, "gpt-4-0613")
}

else -> error("unsupported model")
}

val tokenizer = Tokenizer.of(model)
var numTokens = 0
for (message in messages) {
numTokens += tokensPerMessage
message.run {
numTokens += tokenizer.encode(role.role).size
name?.let { numTokens += tokensPerName + tokenizer.encode(it).size }
content?.let { numTokens += tokenizer.encode(it).size }
}
}
numTokens += 3 // every reply is primed with <|start|>assistant<|message|>
return numTokens
}
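For reference, the accounting in encoding() mirrors OpenAI's cookbook recipe: each message costs tokensPerMessage plus the encoded role, the encoded name plus tokensPerName when a name is set, and the encoded content; 3 tokens are then added because every reply is primed with <|start|>assistant<|message|>. A hedged usage sketch, assuming the same package and imports as Tokens.kt above; countSinglePrompt is a hypothetical name and is not part of this commit:

// Hypothetical usage, not part of this commit: count tokens for a one-message prompt
// with the encoding() helper defined above.
suspend fun countSinglePrompt(): Int {
    val prompt = listOf(
        ChatMessage(role = ChatRole.User, content = "Say hello"),
    )
    // total = 3 (tokensPerMessage) + encode("user").size + encode("Say hello").size + 3 (reply priming)
    return encoding(prompt, "gpt-4")
}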
