Skip to content

Commit

Permalink
Simplify function calling (#176)
Browse files Browse the repository at this point in the history
Per [b/344929335](https://b.corp.google.com/issues/344929335),

This simplifies our function calling to not ascertain types. As a
byproduct, we also no longer provide a means to execute functions.
Additional documentation has been provided to showcase new usage.
  • Loading branch information
daymxn committed Jun 17, 2024
1 parent d8139ed commit 12274f4
Show file tree
Hide file tree
Showing 6 changed files with 138 additions and 219 deletions.
1 change: 1 addition & 0 deletions .changes/generativeai/beef-collar-burn-aftermath.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"type":"MAJOR","changes":["Simplify function calling and remove provided function execution."]}
3 changes: 2 additions & 1 deletion generativeai/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ android {
testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner"
consumerProguardFiles("consumer-rules.pro")

buildConfigField("String", "VERSION_NAME", "\"${project.version.toString()}\"")
buildConfigField("String", "VERSION_NAME", "\"${project.version}\"")
}

publishing {
Expand Down Expand Up @@ -85,6 +85,7 @@ dependencies {
implementation("com.google.guava:listenablefuture:1.0")
implementation("androidx.concurrent:concurrent-futures:1.2.0-alpha03")
implementation("androidx.concurrent:concurrent-futures-ktx:1.2.0-alpha03")
testImplementation("org.json:json:20210307") // Required for JSONObject to function in tests
testImplementation("junit:junit:4.13.2")
testImplementation("io.kotest:kotest-assertions-core:5.5.5")
testImplementation("io.kotest:kotest-assertions-core-jvm:5.5.5")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,29 +26,21 @@ import com.google.ai.client.generativeai.internal.util.toPublic
import com.google.ai.client.generativeai.type.Content
import com.google.ai.client.generativeai.type.CountTokensResponse
import com.google.ai.client.generativeai.type.FinishReason
import com.google.ai.client.generativeai.type.FourParameterFunction
import com.google.ai.client.generativeai.type.FunctionCallPart
import com.google.ai.client.generativeai.type.GenerateContentResponse
import com.google.ai.client.generativeai.type.GenerationConfig
import com.google.ai.client.generativeai.type.GoogleGenerativeAIException
import com.google.ai.client.generativeai.type.InvalidStateException
import com.google.ai.client.generativeai.type.NoParameterFunction
import com.google.ai.client.generativeai.type.OneParameterFunction
import com.google.ai.client.generativeai.type.PromptBlockedException
import com.google.ai.client.generativeai.type.RequestOptions
import com.google.ai.client.generativeai.type.ResponseStoppedException
import com.google.ai.client.generativeai.type.SafetySetting
import com.google.ai.client.generativeai.type.SerializationException
import com.google.ai.client.generativeai.type.ThreeParameterFunction
import com.google.ai.client.generativeai.type.Tool
import com.google.ai.client.generativeai.type.ToolConfig
import com.google.ai.client.generativeai.type.TwoParameterFunction
import com.google.ai.client.generativeai.type.content
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.flow.map
import kotlinx.serialization.ExperimentalSerializationApi
import org.json.JSONObject

/**
* A facilitator for a given multimodal model (eg; Gemini).
Expand Down Expand Up @@ -199,36 +191,6 @@ internal constructor(
return countTokens(content { image(prompt) })
}

/**
* Executes a function requested by the model.
*
* @param functionCallPart A [FunctionCallPart] from the model, containing a function call and
* parameters
* @return The output of the requested function call
*/
suspend fun executeFunction(functionCallPart: FunctionCallPart): JSONObject {
if (tools == null) {
throw InvalidStateException("No registered tools")
}
val callable =
tools.flatMap { it.functionDeclarations }.firstOrNull { it.name == functionCallPart.name }
?: throw InvalidStateException("No registered function named ${functionCallPart.name}")
return when (callable) {
is NoParameterFunction -> callable.execute()
is OneParameterFunction<*> ->
(callable as OneParameterFunction<Any?>).execute(functionCallPart)
is TwoParameterFunction<*, *> ->
(callable as TwoParameterFunction<Any?, Any?>).execute(functionCallPart)
is ThreeParameterFunction<*, *, *> ->
(callable as ThreeParameterFunction<Any?, Any?, Any?>).execute(functionCallPart)
is FourParameterFunction<*, *, *, *> ->
(callable as FourParameterFunction<Any?, Any?, Any?, Any?>).execute(functionCallPart)
else -> {
throw RuntimeException("UNREACHABLE")
}
}
}

private fun constructRequest(vararg prompt: Content) =
GenerateContentRequest(
modelName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ internal fun com.google.ai.client.generativeai.type.Part.toInternal(): Part {
is com.google.ai.client.generativeai.type.BlobPart ->
BlobPart(Blob(mimeType, Base64.encodeToString(blob, BASE_64_FLAGS)))
is com.google.ai.client.generativeai.type.FunctionCallPart ->
FunctionCallPart(FunctionCall(name, args.orEmpty()))
FunctionCallPart(FunctionCall(name, args))
is com.google.ai.client.generativeai.type.FunctionResponsePart ->
FunctionResponsePart(FunctionResponse(name, response.toInternal()))
is com.google.ai.client.generativeai.type.FileDataPart ->
Expand Down Expand Up @@ -147,8 +147,8 @@ internal fun FunctionDeclaration.toInternal() =
name,
description,
Schema(
properties = getParameters().associate { it.name to it.toInternal() },
required = getParameters().map { it.name },
properties = parameters.associate { it.name to it.toInternal() },
required = requiredParameters,
type = "OBJECT",
nullable = false,
),
Expand Down Expand Up @@ -196,10 +196,7 @@ internal fun Part.toPublic(): com.google.ai.client.generativeai.type.Part {
}
}
is FunctionCallPart ->
com.google.ai.client.generativeai.type.FunctionCallPart(
functionCall.name,
functionCall.args.orEmpty(),
)
com.google.ai.client.generativeai.type.FunctionCallPart(functionCall.name, functionCall.args)
is FunctionResponsePart ->
com.google.ai.client.generativeai.type.FunctionResponsePart(
functionResponse.name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,144 +19,24 @@ package com.google.ai.client.generativeai.type
import org.json.JSONObject

/**
* A declared function, including implementation, that a model can be given access to in order to
* gain info or complete tasks.
*
* @property name The name of the function call, this should be clear and descriptive for the model
* @property description A description of what the function does and its output.
* @property function the function implementation
*/
class NoParameterFunction(
name: String,
description: String,
val function: suspend () -> JSONObject,
) : FunctionDeclaration(name, description) {
override fun getParameters() = listOf<Schema<Any>>()

suspend fun execute() = function()

override suspend fun execute(part: FunctionCallPart) = function()
}

/**
* A declared function, including implementation, that a model can be given access to in order to
* gain info or complete tasks.
*
* @property name The name of the function call, this should be clear and descriptive for the model
* @property description A description of what the function does and its output.
* @property param A description of the first function parameter
* @property function the function implementation
*/
class OneParameterFunction<T>(
name: String,
description: String,
val param: Schema<T>,
val function: suspend (T) -> JSONObject,
) : FunctionDeclaration(name, description) {
override fun getParameters() = listOf(param)

override suspend fun execute(part: FunctionCallPart): JSONObject {
val arg1 = part.getArgOrThrow(param)
return function(arg1)
}
}

/**
* A declared function, including implementation, that a model can be given access to in order to
* gain info or complete tasks.
*
* @property name The name of the function call, this should be clear and descriptive for the model
* @property description A description of what the function does and its output.
* @property param1 A description of the first function parameter
* @property param2 A description of the second function parameter
* @property function the function implementation
*/
class TwoParameterFunction<T, U>(
name: String,
description: String,
val param1: Schema<T>,
val param2: Schema<U>,
val function: suspend (T, U) -> JSONObject,
) : FunctionDeclaration(name, description) {
override fun getParameters() = listOf(param1, param2)

override suspend fun execute(part: FunctionCallPart): JSONObject {
val arg1 = part.getArgOrThrow(param1)
val arg2 = part.getArgOrThrow(param2)
return function(arg1, arg2)
}
}

/**
* A declared function, including implementation, that a model can be given access to in order to
* gain info or complete tasks.
*
* @property name The name of the function call, this should be clear and descriptive for the model
* @property description A description of what the function does and its output.
* @property param1 A description of the first function parameter
* @property param2 A description of the second function parameter
* @property param3 A description of the third function parameter
* @property function the function implementation
*/
class ThreeParameterFunction<T, U, V>(
name: String,
description: String,
val param1: Schema<T>,
val param2: Schema<U>,
val param3: Schema<V>,
val function: suspend (T, U, V) -> JSONObject,
) : FunctionDeclaration(name, description) {
override fun getParameters() = listOf(param1, param2, param3)

override suspend fun execute(part: FunctionCallPart): JSONObject {
val arg1 = part.getArgOrThrow(param1)
val arg2 = part.getArgOrThrow(param2)
val arg3 = part.getArgOrThrow(param3)
return function(arg1, arg2, arg3)
}
}

/**
* A declared function, including implementation, that a model can be given access to in order to
* gain info or complete tasks.
* Representation of a function that a model can invoke.
*
* @property name The name of the function call, this should be clear and descriptive for the model
* @property description A description of what the function does and its output.
* @property param1 A description of the first function parameter
* @property param2 A description of the second function parameter
* @property param3 A description of the third function parameter
* @property param4 A description of the fourth function parameter
* @property function the function implementation
* @see defineFunction
*/
class FourParameterFunction<T, U, V, W>(
name: String,
description: String,
val param1: Schema<T>,
val param2: Schema<U>,
val param3: Schema<V>,
val param4: Schema<W>,
val function: suspend (T, U, V, W) -> JSONObject,
) : FunctionDeclaration(name, description) {
override fun getParameters() = listOf(param1, param2, param3, param4)

override suspend fun execute(part: FunctionCallPart): JSONObject {
val arg1 = part.getArgOrThrow(param1)
val arg2 = part.getArgOrThrow(param2)
val arg3 = part.getArgOrThrow(param3)
val arg4 = part.getArgOrThrow(param4)
return function(arg1, arg2, arg3, arg4)
}
}

abstract class FunctionDeclaration(val name: String, val description: String) {
abstract fun getParameters(): List<Schema<out Any?>>

abstract suspend fun execute(part: FunctionCallPart): JSONObject
}
class FunctionDeclaration(
val name: String,
val description: String,
val parameters: List<Schema<*>>,
val requiredParameters: List<String>,
)

/**
* Represents a parameter for a declared function
*
* ```
* val currencyFrom = Schema.str("currencyFrom", "The currency to convert from.")
* ```
*
* @property name: The name of the parameter
* @property description: The description of what the parameter should contain or represent
* @property format: format information for the parameter, this can include bitlength in the case of
Expand All @@ -180,6 +60,21 @@ class Schema<T>(
val items: Schema<out Any>? = null,
val type: FunctionType<T>,
) {

/**
* Attempts to parse a string to the type [T] assigned to this schema.
*
* Will return null if the provided string is null. May also return null if the provided string is
* not a valid string of the expected type; but this should not be relied upon, as it may throw in
* certain scenarios (eg; the type is an object or array, and the string is not valid json).
*
* ```
* val currenciesSchema = Schema.arr("currencies", "The currencies available to use.")
* val currencies: List<String> = currenciesSchema.fromString("""
* ["USD", "EUR", "CAD", "GBP", "JPY"]
* """)
* ```
*/
fun fromString(value: String?) = type.parse(value)

companion object {
Expand Down Expand Up @@ -259,46 +154,31 @@ class Schema<T>(
}
}

fun defineFunction(name: String, description: String, function: suspend () -> JSONObject) =
NoParameterFunction(name, description, function)

fun <T> defineFunction(
name: String,
description: String,
arg1: Schema<T>,
function: suspend (T) -> JSONObject,
) = OneParameterFunction(name, description, arg1, function)

fun <T, U> defineFunction(
name: String,
description: String,
arg1: Schema<T>,
arg2: Schema<U>,
function: suspend (T, U) -> JSONObject,
) = TwoParameterFunction(name, description, arg1, arg2, function)

fun <T, U, W> defineFunction(
name: String,
description: String,
arg1: Schema<T>,
arg2: Schema<U>,
arg3: Schema<W>,
function: suspend (T, U, W) -> JSONObject,
) = ThreeParameterFunction(name, description, arg1, arg2, arg3, function)

fun <T, U, W, Z> defineFunction(
/**
* A declared function, including implementation, that a model can be given access to in order to
* gain info or complete tasks.
*
* ```
* val getExchangeRate = defineFunction(
* name = "getExchangeRate",
* description = "Get the exchange rate for currencies between countries.",
* parameters = listOf(
* Schema.str("currencyFrom", "The currency to convert from."),
* Schema.str("currencyTo", "The currency to convert to.")
* ),
* requiredParameters = listOf("currencyFrom", "currencyTo")
* )
* ```
*
* @param name The name of the function call, this should be clear and descriptive for the model.
* @param description A description of what the function does and its output.
* @param parameters A list of parameters that the function accepts.
* @param requiredParameters A list of parameters that the function requires to run.
* @see Schema
*/
fun defineFunction(
name: String,
description: String,
arg1: Schema<T>,
arg2: Schema<U>,
arg3: Schema<W>,
arg4: Schema<Z>,
function: suspend (T, U, W, Z) -> JSONObject,
) = FourParameterFunction(name, description, arg1, arg2, arg3, arg4, function)

private fun <T> FunctionCallPart.getArgOrThrow(param: Schema<T>): T {
return param.fromString(args[param.name])
?: throw RuntimeException(
"Missing argument for parameter \"${param.name}\" for function \"$name\""
)
}
parameters: List<Schema<*>> = emptyList(),
requiredParameters: List<String> = emptyList(),
) = FunctionDeclaration(name, description, parameters, requiredParameters)
Loading

0 comments on commit 12274f4

Please sign in to comment.