Skip to content

Commit

Permalink
feat: add mergeToChatMessage (#250)
Browse files Browse the repository at this point in the history
  • Loading branch information
aallam authored Nov 5, 2023
1 parent d4785dd commit 570608e
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 37 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package com.aallam.openai.client.extension

import com.aallam.openai.api.ExperimentalOpenAI
import com.aallam.openai.api.chat.ChatChunk
import com.aallam.openai.api.chat.ChatMessage
import com.aallam.openai.client.extension.internal.ChatMessageAssembler

/**
* Merges a list of [ChatChunk]s into a single consolidated [ChatMessage].
*/
@ExperimentalOpenAI
public fun List<ChatChunk>.mergeToChatMessage(): ChatMessage {
return fold(ChatMessageAssembler()) { assembler, chatChunk -> assembler.merge(chatChunk) }.build()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package com.aallam.openai.client.extension.internal

import com.aallam.openai.api.chat.*

/**
* A class to help assemble chat messages from chat chunks.
*/
internal class ChatMessageAssembler {
private val chatFuncName = StringBuilder()
private val chatFuncArgs = StringBuilder()
private val chatContent = StringBuilder()
private var chatRole: ChatRole? = null

/**
* Merges a chat chunk into the chat message being assembled.
*/
fun merge(chunk: ChatChunk): ChatMessageAssembler {
chunk.delta.run {
role?.let { chatRole = it }
content?.let { chatContent.append(it) }
functionCall?.let { call ->
call.nameOrNull?.let { chatFuncName.append(it) }
call.argumentsOrNull?.let { chatFuncArgs.append(it) }
}
}
return this
}

/**
* Builds and returns the assembled chat message.
*/
fun build(): ChatMessage = chatMessage {
this.role = chatRole
this.content = chatContent.toString()
if (chatFuncName.isNotEmpty() || chatFuncArgs.isNotEmpty()) {
this.functionCall = FunctionCall(chatFuncName.toString(), chatFuncArgs.toString())
this.name = chatFuncName.toString()
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
package com.aallam.openai.client

import com.aallam.openai.api.chat.ChatChunk
import com.aallam.openai.api.chat.ChatDelta
import com.aallam.openai.api.chat.ChatMessage
import com.aallam.openai.api.chat.ChatRole
import com.aallam.openai.api.core.FinishReason
import com.aallam.openai.client.extension.mergeToChatMessage
import kotlin.test.Test
import kotlin.test.assertEquals

class TestChatChunk {

@Test
fun testMerge() {
val chunks = listOf(
ChatChunk(
index = 0,
delta = ChatDelta(
role = ChatRole(role = "assistant"),
content = ""
),
finishReason = null
),
ChatChunk(
index = 0,
delta = ChatDelta(
role = null,
content = "The"
),
finishReason = null
),
ChatChunk(
index = 0,
delta = ChatDelta(
role = null,
content = " World"
),
finishReason = null
),
ChatChunk(
index = 0,
delta = ChatDelta(
role = null,
content = " Series"
),
finishReason = null
),
ChatChunk(
index = 0,
delta = ChatDelta(
role = null,
content = " in"
),
finishReason = null
),
ChatChunk(
index = 0,
delta = ChatDelta(
role = null,
content = " "
),
finishReason = null
),
ChatChunk(
index = 0,
delta = ChatDelta(
role = null,
content = "202"
),
finishReason = null
),
ChatChunk(
index = 0,
delta = ChatDelta(
role = null,
content = "0"
),
finishReason = null
),
ChatChunk(
index = 0,
delta = ChatDelta(
role = null,
content = " is"
),
finishReason = null
),
ChatChunk(
index = 0,
delta = ChatDelta(
role = null,
content = " being held"
),
finishReason = null
),
ChatChunk(
index = 0,
delta = ChatDelta(
role = null,
content = " in"
),
finishReason = null
),
ChatChunk(
index = 0,
delta = ChatDelta(
role = null,
content = " Texas"
),
finishReason = null
),
ChatChunk(
index = 0,
delta = ChatDelta(
role = null,
content = "."
),
finishReason = null
),
ChatChunk(
index = 0,
delta = ChatDelta(
role = null,
content = null
),
finishReason = FinishReason(value = "stop")
)
)
val chatMessage = chunks.mergeToChatMessage()
val message = ChatMessage(
role = ChatRole.Assistant,
content = "The World Series in 2020 is being held in Texas.",
name = null,
functionCall = null
)
assertEquals(chatMessage, message)
}
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package com.aallam.openai.sample.jvm

import com.aallam.openai.api.ExperimentalOpenAI
import com.aallam.openai.api.chat.*
import com.aallam.openai.api.model.ModelId
import com.aallam.openai.client.OpenAI
import com.aallam.openai.client.extension.mergeToChatMessage
import kotlinx.coroutines.flow.*
import kotlinx.serialization.Serializable
import kotlinx.serialization.encodeToString
Expand All @@ -12,6 +14,7 @@ import kotlinx.serialization.json.*
* This code snippet demonstrates the use of OpenAI's chat completion capabilities
* with a focus on integrating function calls into the chat conversation.
*/
@OptIn(ExperimentalOpenAI::class)
suspend fun chatFunctionCall(openAI: OpenAI) {
// *** Chat Completion with Function Call *** //

Expand Down Expand Up @@ -73,8 +76,8 @@ suspend fun chatFunctionCall(openAI: OpenAI) {
println("\n> Create Chat Completion function call (stream)...")
val chatMessage = openAI.chatCompletions(request)
.map { completion -> completion.choices.first() }
.fold(initial = ChatMessageAssembler()) { assembler, chunk -> assembler.merge(chunk) }
.build()
.toList()
.mergeToChatMessage()

chatMessages.append(chatMessage)
chatMessage.functionCall?.let { functionCall ->
Expand Down Expand Up @@ -140,38 +143,3 @@ private fun MutableList<ChatMessage>.append(message: ChatMessage) {
private fun MutableList<ChatMessage>.append(functionCall: FunctionCall, functionResponse: String) {
add(ChatMessage(role = ChatRole.Function, name = functionCall.name, content = functionResponse))
}

/**
* A class to help assemble chat messages from chat chunks.
*/
class ChatMessageAssembler {
private val chatFuncName = StringBuilder()
private val chatFuncArgs = StringBuilder()
private val chatContent = StringBuilder()
private var chatRole: ChatRole? = null

/**
* Merges a chat chunk into the chat message being assembled.
*/
fun merge(chunk: ChatChunk): ChatMessageAssembler {
chatRole = chunk.delta.role ?: chatRole
chunk.delta.content?.let { chatContent.append(it) }
chunk.delta.functionCall?.let { call ->
call.nameOrNull?.let { chatFuncName.append(it) }
call.argumentsOrNull?.let { chatFuncArgs.append(it) }
}
return this
}

/**
* Builds and returns the assembled chat message.
*/
fun build(): ChatMessage = chatMessage {
this.role = chatRole
this.content = chatContent.toString()
if (chatFuncName.isNotEmpty() || chatFuncArgs.isNotEmpty()) {
this.functionCall = FunctionCall(chatFuncName.toString(), chatFuncArgs.toString())
this.name = chatFuncName.toString()
}
}
}

0 comments on commit 570608e

Please sign in to comment.