Skip to content

Commit

Permalink
Davidmotson.code execution (#186)
Browse files Browse the repository at this point in the history
Co-authored-by: David Motsonashvili <davidmotson@google.com>
  • Loading branch information
davidmotson and David Motsonashvili committed Jun 28, 2024
1 parent 6c5550f commit 69329f7
Show file tree
Hide file tree
Showing 13 changed files with 274 additions and 6 deletions.
1 change: 1 addition & 0 deletions .changes/common/angle-carpenter-beam-clock.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"type":"MAJOR","changes":["add code execution tool"]}
1 change: 1 addition & 0 deletions .changes/generativeai/direction-bee-brass-aftermath.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"type":"MAJOR","changes":["add code execution tool"]}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package com.google.ai.client.generativeai.common.client

import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.JsonObject

@Serializable
data class GenerationConfig(
Expand All @@ -33,7 +34,12 @@ data class GenerationConfig(
@SerialName("response_schema") val responseSchema: Schema? = null,
)

@Serializable data class Tool(val functionDeclarations: List<FunctionDeclaration>)
@Serializable
data class Tool(
val functionDeclarations: List<FunctionDeclaration>? = null,
// This is a json object because it is not possible to make a data class with no parameters.
val codeExecution: JsonObject? = null,
)

@Serializable
data class ToolConfig(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ data class Content(@EncodeDefault val role: String? = "user", val parts: List<Pa

@Serializable data class FunctionResponsePart(val functionResponse: FunctionResponse) : Part

@Serializable data class ExecutableCodePart(val executableCode: ExecutableCode) : Part

@Serializable
data class CodeExecutionResultPart(val codeExecutionResult: CodeExecutionResult) : Part

@Serializable data class FunctionResponse(val name: String, val response: JsonObject)

@Serializable data class FunctionCall(val name: String, val args: Map<String, String?>)
Expand All @@ -71,6 +76,18 @@ data class FileData(

@Serializable data class Blob(@SerialName("mime_type") val mimeType: String, val data: Base64)

@Serializable data class ExecutableCode(val language: String, val code: String)

@Serializable data class CodeExecutionResult(val outcome: Outcome, val output: String)

@Serializable
enum class Outcome {
@SerialName("OUTCOME_UNSPECIFIED") UNSPECIFIED,
OUTCOME_OK,
OUTCOME_FAILED,
OUTCOME_DEADLINE_EXCEEDED,
}

@Serializable
data class SafetySetting(
val category: HarmCategory,
Expand Down Expand Up @@ -101,8 +118,10 @@ object PartSerializer : JsonContentPolymorphicSerializer<Part>(Part::class) {
"text" in jsonObject -> TextPart.serializer()
"functionCall" in jsonObject -> FunctionCallPart.serializer()
"functionResponse" in jsonObject -> FunctionResponsePart.serializer()
"inline_data" in jsonObject -> BlobPart.serializer()
"file_data" in jsonObject -> FileDataPart.serializer()
"inlineData" in jsonObject -> BlobPart.serializer()
"fileData" in jsonObject -> FileDataPart.serializer()
"executableCode" in jsonObject -> ExecutableCodePart.serializer()
"codeExecutionResult" in jsonObject -> CodeExecutionResultPart.serializer()
else -> throw SerializationException("Unknown Part type")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.google.ai.client.generativeai.common

import com.google.ai.client.generativeai.common.client.FunctionCallingConfig
import com.google.ai.client.generativeai.common.client.Tool
import com.google.ai.client.generativeai.common.client.ToolConfig
import com.google.ai.client.generativeai.common.shared.Content
import com.google.ai.client.generativeai.common.shared.TextPart
Expand All @@ -43,6 +44,7 @@ import kotlin.time.Duration.Companion.seconds
import kotlinx.coroutines.delay
import kotlinx.coroutines.withTimeout
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.JsonObject
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
Expand Down Expand Up @@ -259,6 +261,41 @@ internal class RequestFormatTests {

mockEngine.requestHistory.first().headers.contains("header1") shouldBe false
}

@Test
fun `code execution tool serialization contains correct keys`() = doBlocking {
val channel = ByteChannel(autoFlush = true)
val mockEngine = MockEngine {
respond(channel, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json"))
}
prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) }

val controller =
APIController(
"super_cool_test_key",
"gemini-pro-1.0",
RequestOptions(),
mockEngine,
TEST_CLIENT_ID,
null,
)

withTimeout(5.seconds) {
controller
.generateContentStream(
GenerateContentRequest(
model = "unused",
contents = listOf(Content(parts = listOf(TextPart("Arbitrary")))),
tools = listOf(Tool(codeExecution = JsonObject(emptyMap()))),
)
)
.collect { channel.close() }
}

val requestBodyAsText = (mockEngine.requestHistory.first().body as TextContent).text

requestBodyAsText shouldContainJsonKey "tools[0].codeExecution"
}
}

@RunWith(Parameterized::class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,13 @@ import com.google.ai.client.generativeai.common.server.BlockReason
import com.google.ai.client.generativeai.common.server.FinishReason
import com.google.ai.client.generativeai.common.server.HarmProbability
import com.google.ai.client.generativeai.common.server.HarmSeverity
import com.google.ai.client.generativeai.common.shared.CodeExecutionResult
import com.google.ai.client.generativeai.common.shared.CodeExecutionResultPart
import com.google.ai.client.generativeai.common.shared.ExecutableCode
import com.google.ai.client.generativeai.common.shared.ExecutableCodePart
import com.google.ai.client.generativeai.common.shared.FunctionCallPart
import com.google.ai.client.generativeai.common.shared.HarmCategory
import com.google.ai.client.generativeai.common.shared.Outcome
import com.google.ai.client.generativeai.common.shared.TextPart
import com.google.ai.client.generativeai.common.util.goldenUnaryFile
import com.google.ai.client.generativeai.common.util.shouldNotBeNullOrEmpty
Expand Down Expand Up @@ -331,4 +336,23 @@ internal class UnarySnapshotTests {
callPart.functionCall.args["current"] shouldBe "true"
}
}

@Test
fun `code execution parses correctly`() =
goldenUnaryFile("success-code-execution.json") {
withTimeout(testTimeout) {
val response = apiController.generateContent(textGenerateContentRequest("prompt"))
val content = response.candidates.shouldNotBeNullOrEmpty().first().content
content.shouldNotBeNull()
val executableCodePart = content.parts[0]
val codeExecutionResult = content.parts[1]

executableCodePart.shouldBe(
ExecutableCodePart(ExecutableCode("PYTHON", "print(\"Hello World\")"))
)
codeExecutionResult.shouldBe(
CodeExecutionResultPart(CodeExecutionResult(Outcome.OUTCOME_OK, "Hello World"))
)
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
{
"candidates": [
{
"content": {
"parts": [
{
"executableCode": {
"language": "PYTHON",
"code": "print(\"Hello World\")"
}
},
{
"codeExecutionResult": {
"outcome": "OUTCOME_OK",
"output": "Hello World"
}
}
],
"role": "model"
},
"finishReason": "STOP",
"index": 0,
"safetyRatings": [
{
"category": "HARM_CATEGORY_HARASSMENT",
"probability": "NEGLIGIBLE"
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"probability": "NEGLIGIBLE"
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"probability": "NEGLIGIBLE"
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"probability": "NEGLIGIBLE"
}
]
}
],
"usageMetadata": {
"promptTokenCount": 774,
"candidatesTokenCount": 4176,
"totalTokenCount": 4950
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@ import com.google.ai.client.generativeai.common.server.PromptFeedback
import com.google.ai.client.generativeai.common.server.SafetyRating
import com.google.ai.client.generativeai.common.shared.Blob
import com.google.ai.client.generativeai.common.shared.BlobPart
import com.google.ai.client.generativeai.common.shared.CodeExecutionResult
import com.google.ai.client.generativeai.common.shared.CodeExecutionResultPart
import com.google.ai.client.generativeai.common.shared.Content
import com.google.ai.client.generativeai.common.shared.ExecutableCode
import com.google.ai.client.generativeai.common.shared.ExecutableCodePart
import com.google.ai.client.generativeai.common.shared.FileData
import com.google.ai.client.generativeai.common.shared.FileDataPart
import com.google.ai.client.generativeai.common.shared.FunctionCall
Expand All @@ -42,11 +46,13 @@ import com.google.ai.client.generativeai.common.shared.FunctionResponse
import com.google.ai.client.generativeai.common.shared.FunctionResponsePart
import com.google.ai.client.generativeai.common.shared.HarmBlockThreshold
import com.google.ai.client.generativeai.common.shared.HarmCategory
import com.google.ai.client.generativeai.common.shared.Outcome
import com.google.ai.client.generativeai.common.shared.Part
import com.google.ai.client.generativeai.common.shared.SafetySetting
import com.google.ai.client.generativeai.common.shared.TextPart
import com.google.ai.client.generativeai.type.BlockThreshold
import com.google.ai.client.generativeai.type.CitationMetadata
import com.google.ai.client.generativeai.type.ExecutionOutcome
import com.google.ai.client.generativeai.type.FunctionCallingConfig
import com.google.ai.client.generativeai.type.FunctionDeclaration
import com.google.ai.client.generativeai.type.ImagePart
Expand Down Expand Up @@ -80,6 +86,10 @@ internal fun com.google.ai.client.generativeai.type.Part.toInternal(): Part {
FunctionResponsePart(FunctionResponse(name, response.toInternal()))
is com.google.ai.client.generativeai.type.FileDataPart ->
FileDataPart(FileData(fileUri = uri, mimeType = mimeType))
is com.google.ai.client.generativeai.type.ExecutableCodePart ->
ExecutableCodePart(ExecutableCode(language, code))
is com.google.ai.client.generativeai.type.CodeExecutionResultPart ->
CodeExecutionResultPart(CodeExecutionResult(outcome.toInternal(), output))
else ->
throw SerializationException(
"The given subclass of Part (${javaClass.simpleName}) is not supported in the serialization yet."
Expand Down Expand Up @@ -122,8 +132,19 @@ internal fun BlockThreshold.toInternal() =
BlockThreshold.UNSPECIFIED -> HarmBlockThreshold.UNSPECIFIED
}

internal fun ExecutionOutcome.toInternal() =
when (this) {
ExecutionOutcome.UNSPECIFIED -> Outcome.UNSPECIFIED
ExecutionOutcome.OK -> Outcome.OUTCOME_OK
ExecutionOutcome.FAILED -> Outcome.OUTCOME_FAILED
ExecutionOutcome.DEADLINE_EXCEEDED -> Outcome.OUTCOME_DEADLINE_EXCEEDED
}

internal fun Tool.toInternal() =
com.google.ai.client.generativeai.common.client.Tool(functionDeclarations.map { it.toInternal() })
com.google.ai.client.generativeai.common.client.Tool(
functionDeclarations?.map { it.toInternal() },
codeExecution = codeExecution?.toInternal(),
)

internal fun ToolConfig.toInternal() =
com.google.ai.client.generativeai.common.client.ToolConfig(
Expand Down Expand Up @@ -204,6 +225,16 @@ internal fun Part.toPublic(): com.google.ai.client.generativeai.type.Part {
)
is FileDataPart ->
com.google.ai.client.generativeai.type.FileDataPart(fileData.fileUri, fileData.mimeType)
is ExecutableCodePart ->
com.google.ai.client.generativeai.type.ExecutableCodePart(
executableCode.language,
executableCode.code,
)
is CodeExecutionResultPart ->
com.google.ai.client.generativeai.type.CodeExecutionResultPart(
codeExecutionResult.outcome.toPublic(),
codeExecutionResult.output,
)
else ->
throw SerializationException(
"Unsupported part type \"${javaClass.simpleName}\" provided. This model may not be supported by this SDK."
Expand Down Expand Up @@ -267,6 +298,14 @@ internal fun BlockReason.toPublic() =
BlockReason.UNKNOWN -> com.google.ai.client.generativeai.type.BlockReason.UNKNOWN
}

internal fun Outcome.toPublic() =
when (this) {
Outcome.UNSPECIFIED -> ExecutionOutcome.UNSPECIFIED
Outcome.OUTCOME_OK -> ExecutionOutcome.OK
Outcome.OUTCOME_FAILED -> ExecutionOutcome.FAILED
Outcome.OUTCOME_DEADLINE_EXCEEDED -> ExecutionOutcome.DEADLINE_EXCEEDED
}

internal fun GenerateContentResponse.toPublic() =
com.google.ai.client.generativeai.type.GenerateContentResponse(
candidates?.map { it.toPublic() }.orEmpty(),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.ai.client.generativeai.type

enum class ExecutionOutcome {
UNSPECIFIED,
OK,
FAILED,
DEADLINE_EXCEEDED,
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,19 @@ class GenerateContentResponse(
) {
/** Convenience field representing all the text parts in the response, if they exists. */
val text: String? by lazy {
candidates.first().content.parts.filterIsInstance<TextPart>().joinToString(" ") { it.text }
candidates
.first()
.content
.parts
.filter { it is TextPart || it is ExecutableCodePart || it is CodeExecutionResultPart }
.joinToString(" ") {
when (it) {
is TextPart -> it.text
is ExecutableCodePart -> "\n```${it.language.lowercase()}\n${it.code}\n```"
is CodeExecutionResultPart -> "\n```\n${it.output}\n```"
else -> throw RuntimeException("unreachable")
}
}
}

/** Convenience field representing the first function call part in the request, if it exists */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ import org.json.JSONObject
* * [ImagePart] representing image data.
* * [BlobPart] representing MIME typed binary data.
* * [FileDataPart] representing MIME typed binary data.
* * [FunctionCallPart] representing a requested clientside function call by the model
* * [FunctionResponsePart] representing the result of a clientside function call
* * [ExecutableCodePart] representing code generated and executed by the model
* * [CodeExecutionResultPart] representing the result of running code generated by the model.
*/
interface Part

Expand Down Expand Up @@ -54,6 +58,12 @@ class FunctionCallPart(val name: String, val args: Map<String, String?>) : Part
/** Represents function call output to be returned to the model when it requests a function call */
class FunctionResponsePart(val name: String, val response: JSONObject) : Part

/** Represents an internal function call written by the model */
class ExecutableCodePart(val language: String, val code: String) : Part

/** Represents the results of an internal function call written by the model */
class CodeExecutionResultPart(val outcome: ExecutionOutcome, val output: String) : Part

/** @return The part as a [String] if it represents text, and null otherwise */
fun Part.asTextOrNull(): String? = (this as? TextPart)?.text

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,20 @@

package com.google.ai.client.generativeai.type

import org.json.JSONObject

/**
* Contains a set of function declarations that the model has access to. These can be used to gather
* information, or complete tasks
*
* @param functionDeclarations The set of functions that this tool allows the model access to
* @param codeExecution This is a flag value to enable Code Execution. Use [CODE_EXECUTION].
*/
class Tool(val functionDeclarations: List<FunctionDeclaration>)
class Tool(
val functionDeclarations: List<FunctionDeclaration>? = null,
val codeExecution: JSONObject? = null,
) {
companion object {
val CODE_EXECUTION = Tool(codeExecution = JSONObject())
}
}
Loading

0 comments on commit 69329f7

Please sign in to comment.