diff --git a/.changes/generativeai/beef-collar-burn-aftermath.json b/.changes/generativeai/beef-collar-burn-aftermath.json new file mode 100644 index 00000000..1c27047c --- /dev/null +++ b/.changes/generativeai/beef-collar-burn-aftermath.json @@ -0,0 +1 @@ +{"type":"MAJOR","changes":["Simplify function calling and remove provided function execution."]} diff --git a/generativeai/build.gradle.kts b/generativeai/build.gradle.kts index ba88cefb..9a0e0259 100644 --- a/generativeai/build.gradle.kts +++ b/generativeai/build.gradle.kts @@ -41,7 +41,7 @@ android { testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner" consumerProguardFiles("consumer-rules.pro") - buildConfigField("String", "VERSION_NAME", "\"${project.version.toString()}\"") + buildConfigField("String", "VERSION_NAME", "\"${project.version}\"") } publishing { @@ -85,6 +85,7 @@ dependencies { implementation("com.google.guava:listenablefuture:1.0") implementation("androidx.concurrent:concurrent-futures:1.2.0-alpha03") implementation("androidx.concurrent:concurrent-futures-ktx:1.2.0-alpha03") + testImplementation("org.json:json:20210307") // Required for JSONObject to function in tests testImplementation("junit:junit:4.13.2") testImplementation("io.kotest:kotest-assertions-core:5.5.5") testImplementation("io.kotest:kotest-assertions-core-jvm:5.5.5") diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt index 6ea9df87..c68b49a5 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt @@ -26,29 +26,21 @@ import com.google.ai.client.generativeai.internal.util.toPublic import com.google.ai.client.generativeai.type.Content import com.google.ai.client.generativeai.type.CountTokensResponse import com.google.ai.client.generativeai.type.FinishReason -import com.google.ai.client.generativeai.type.FourParameterFunction -import com.google.ai.client.generativeai.type.FunctionCallPart import com.google.ai.client.generativeai.type.GenerateContentResponse import com.google.ai.client.generativeai.type.GenerationConfig import com.google.ai.client.generativeai.type.GoogleGenerativeAIException -import com.google.ai.client.generativeai.type.InvalidStateException -import com.google.ai.client.generativeai.type.NoParameterFunction -import com.google.ai.client.generativeai.type.OneParameterFunction import com.google.ai.client.generativeai.type.PromptBlockedException import com.google.ai.client.generativeai.type.RequestOptions import com.google.ai.client.generativeai.type.ResponseStoppedException import com.google.ai.client.generativeai.type.SafetySetting import com.google.ai.client.generativeai.type.SerializationException -import com.google.ai.client.generativeai.type.ThreeParameterFunction import com.google.ai.client.generativeai.type.Tool import com.google.ai.client.generativeai.type.ToolConfig -import com.google.ai.client.generativeai.type.TwoParameterFunction import com.google.ai.client.generativeai.type.content import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.catch import kotlinx.coroutines.flow.map import kotlinx.serialization.ExperimentalSerializationApi -import org.json.JSONObject /** * A facilitator for a given multimodal model (eg; Gemini). @@ -199,36 +191,6 @@ internal constructor( return countTokens(content { image(prompt) }) } - /** - * Executes a function requested by the model. - * - * @param functionCallPart A [FunctionCallPart] from the model, containing a function call and - * parameters - * @return The output of the requested function call - */ - suspend fun executeFunction(functionCallPart: FunctionCallPart): JSONObject { - if (tools == null) { - throw InvalidStateException("No registered tools") - } - val callable = - tools.flatMap { it.functionDeclarations }.firstOrNull { it.name == functionCallPart.name } - ?: throw InvalidStateException("No registered function named ${functionCallPart.name}") - return when (callable) { - is NoParameterFunction -> callable.execute() - is OneParameterFunction<*> -> - (callable as OneParameterFunction).execute(functionCallPart) - is TwoParameterFunction<*, *> -> - (callable as TwoParameterFunction).execute(functionCallPart) - is ThreeParameterFunction<*, *, *> -> - (callable as ThreeParameterFunction).execute(functionCallPart) - is FourParameterFunction<*, *, *, *> -> - (callable as FourParameterFunction).execute(functionCallPart) - else -> { - throw RuntimeException("UNREACHABLE") - } - } - } - private fun constructRequest(vararg prompt: Content) = GenerateContentRequest( modelName, diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt index dfa8c3c0..ed093005 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt @@ -75,7 +75,7 @@ internal fun com.google.ai.client.generativeai.type.Part.toInternal(): Part { is com.google.ai.client.generativeai.type.BlobPart -> BlobPart(Blob(mimeType, Base64.encodeToString(blob, BASE_64_FLAGS))) is com.google.ai.client.generativeai.type.FunctionCallPart -> - FunctionCallPart(FunctionCall(name, args.orEmpty())) + FunctionCallPart(FunctionCall(name, args)) is com.google.ai.client.generativeai.type.FunctionResponsePart -> FunctionResponsePart(FunctionResponse(name, response.toInternal())) is com.google.ai.client.generativeai.type.FileDataPart -> @@ -147,8 +147,8 @@ internal fun FunctionDeclaration.toInternal() = name, description, Schema( - properties = getParameters().associate { it.name to it.toInternal() }, - required = getParameters().map { it.name }, + properties = parameters.associate { it.name to it.toInternal() }, + required = requiredParameters, type = "OBJECT", nullable = false, ), @@ -196,10 +196,7 @@ internal fun Part.toPublic(): com.google.ai.client.generativeai.type.Part { } } is FunctionCallPart -> - com.google.ai.client.generativeai.type.FunctionCallPart( - functionCall.name, - functionCall.args.orEmpty(), - ) + com.google.ai.client.generativeai.type.FunctionCallPart(functionCall.name, functionCall.args) is FunctionResponsePart -> com.google.ai.client.generativeai.type.FunctionResponsePart( functionResponse.name, diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt index eb8285a9..88247f3d 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt @@ -19,144 +19,24 @@ package com.google.ai.client.generativeai.type import org.json.JSONObject /** - * A declared function, including implementation, that a model can be given access to in order to - * gain info or complete tasks. - * - * @property name The name of the function call, this should be clear and descriptive for the model - * @property description A description of what the function does and its output. - * @property function the function implementation - */ -class NoParameterFunction( - name: String, - description: String, - val function: suspend () -> JSONObject, -) : FunctionDeclaration(name, description) { - override fun getParameters() = listOf>() - - suspend fun execute() = function() - - override suspend fun execute(part: FunctionCallPart) = function() -} - -/** - * A declared function, including implementation, that a model can be given access to in order to - * gain info or complete tasks. - * - * @property name The name of the function call, this should be clear and descriptive for the model - * @property description A description of what the function does and its output. - * @property param A description of the first function parameter - * @property function the function implementation - */ -class OneParameterFunction( - name: String, - description: String, - val param: Schema, - val function: suspend (T) -> JSONObject, -) : FunctionDeclaration(name, description) { - override fun getParameters() = listOf(param) - - override suspend fun execute(part: FunctionCallPart): JSONObject { - val arg1 = part.getArgOrThrow(param) - return function(arg1) - } -} - -/** - * A declared function, including implementation, that a model can be given access to in order to - * gain info or complete tasks. - * - * @property name The name of the function call, this should be clear and descriptive for the model - * @property description A description of what the function does and its output. - * @property param1 A description of the first function parameter - * @property param2 A description of the second function parameter - * @property function the function implementation - */ -class TwoParameterFunction( - name: String, - description: String, - val param1: Schema, - val param2: Schema, - val function: suspend (T, U) -> JSONObject, -) : FunctionDeclaration(name, description) { - override fun getParameters() = listOf(param1, param2) - - override suspend fun execute(part: FunctionCallPart): JSONObject { - val arg1 = part.getArgOrThrow(param1) - val arg2 = part.getArgOrThrow(param2) - return function(arg1, arg2) - } -} - -/** - * A declared function, including implementation, that a model can be given access to in order to - * gain info or complete tasks. - * - * @property name The name of the function call, this should be clear and descriptive for the model - * @property description A description of what the function does and its output. - * @property param1 A description of the first function parameter - * @property param2 A description of the second function parameter - * @property param3 A description of the third function parameter - * @property function the function implementation - */ -class ThreeParameterFunction( - name: String, - description: String, - val param1: Schema, - val param2: Schema, - val param3: Schema, - val function: suspend (T, U, V) -> JSONObject, -) : FunctionDeclaration(name, description) { - override fun getParameters() = listOf(param1, param2, param3) - - override suspend fun execute(part: FunctionCallPart): JSONObject { - val arg1 = part.getArgOrThrow(param1) - val arg2 = part.getArgOrThrow(param2) - val arg3 = part.getArgOrThrow(param3) - return function(arg1, arg2, arg3) - } -} - -/** - * A declared function, including implementation, that a model can be given access to in order to - * gain info or complete tasks. + * Representation of a function that a model can invoke. * - * @property name The name of the function call, this should be clear and descriptive for the model - * @property description A description of what the function does and its output. - * @property param1 A description of the first function parameter - * @property param2 A description of the second function parameter - * @property param3 A description of the third function parameter - * @property param4 A description of the fourth function parameter - * @property function the function implementation + * @see defineFunction */ -class FourParameterFunction( - name: String, - description: String, - val param1: Schema, - val param2: Schema, - val param3: Schema, - val param4: Schema, - val function: suspend (T, U, V, W) -> JSONObject, -) : FunctionDeclaration(name, description) { - override fun getParameters() = listOf(param1, param2, param3, param4) - - override suspend fun execute(part: FunctionCallPart): JSONObject { - val arg1 = part.getArgOrThrow(param1) - val arg2 = part.getArgOrThrow(param2) - val arg3 = part.getArgOrThrow(param3) - val arg4 = part.getArgOrThrow(param4) - return function(arg1, arg2, arg3, arg4) - } -} - -abstract class FunctionDeclaration(val name: String, val description: String) { - abstract fun getParameters(): List> - - abstract suspend fun execute(part: FunctionCallPart): JSONObject -} +class FunctionDeclaration( + val name: String, + val description: String, + val parameters: List>, + val requiredParameters: List, +) /** * Represents a parameter for a declared function * + * ``` + * val currencyFrom = Schema.str("currencyFrom", "The currency to convert from.") + * ``` + * * @property name: The name of the parameter * @property description: The description of what the parameter should contain or represent * @property format: format information for the parameter, this can include bitlength in the case of @@ -180,6 +60,21 @@ class Schema( val items: Schema? = null, val type: FunctionType, ) { + + /** + * Attempts to parse a string to the type [T] assigned to this schema. + * + * Will return null if the provided string is null. May also return null if the provided string is + * not a valid string of the expected type; but this should not be relied upon, as it may throw in + * certain scenarios (eg; the type is an object or array, and the string is not valid json). + * + * ``` + * val currenciesSchema = Schema.arr("currencies", "The currencies available to use.") + * val currencies: List = currenciesSchema.fromString(""" + * ["USD", "EUR", "CAD", "GBP", "JPY"] + * """) + * ``` + */ fun fromString(value: String?) = type.parse(value) companion object { @@ -269,46 +164,31 @@ class Schema( } } -fun defineFunction(name: String, description: String, function: suspend () -> JSONObject) = - NoParameterFunction(name, description, function) - -fun defineFunction( - name: String, - description: String, - arg1: Schema, - function: suspend (T) -> JSONObject, -) = OneParameterFunction(name, description, arg1, function) - -fun defineFunction( - name: String, - description: String, - arg1: Schema, - arg2: Schema, - function: suspend (T, U) -> JSONObject, -) = TwoParameterFunction(name, description, arg1, arg2, function) - -fun defineFunction( - name: String, - description: String, - arg1: Schema, - arg2: Schema, - arg3: Schema, - function: suspend (T, U, W) -> JSONObject, -) = ThreeParameterFunction(name, description, arg1, arg2, arg3, function) - -fun defineFunction( +/** + * A declared function, including implementation, that a model can be given access to in order to + * gain info or complete tasks. + * + * ``` + * val getExchangeRate = defineFunction( + * name = "getExchangeRate", + * description = "Get the exchange rate for currencies between countries.", + * parameters = listOf( + * Schema.str("currencyFrom", "The currency to convert from."), + * Schema.str("currencyTo", "The currency to convert to.") + * ), + * requiredParameters = listOf("currencyFrom", "currencyTo") + * ) + * ``` + * + * @param name The name of the function call, this should be clear and descriptive for the model. + * @param description A description of what the function does and its output. + * @param parameters A list of parameters that the function accepts. + * @param requiredParameters A list of parameters that the function requires to run. + * @see Schema + */ +fun defineFunction( name: String, description: String, - arg1: Schema, - arg2: Schema, - arg3: Schema, - arg4: Schema, - function: suspend (T, U, W, Z) -> JSONObject, -) = FourParameterFunction(name, description, arg1, arg2, arg3, arg4, function) - -private fun FunctionCallPart.getArgOrThrow(param: Schema): T { - return param.fromString(args[param.name]) - ?: throw RuntimeException( - "Missing argument for parameter \"${param.name}\" for function \"$name\"" - ) -} + parameters: List> = emptyList(), + requiredParameters: List = emptyList(), +) = FunctionDeclaration(name, description, parameters, requiredParameters) diff --git a/generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt b/generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt index 51501515..904130c0 100644 --- a/generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt +++ b/generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt @@ -26,24 +26,35 @@ import com.google.ai.client.generativeai.common.server.Candidate as Candidate_Co import com.google.ai.client.generativeai.common.server.CitationMetadata as CitationMetadata_Common import com.google.ai.client.generativeai.common.server.CitationSources import com.google.ai.client.generativeai.common.shared.Content as Content_Common +import com.google.ai.client.generativeai.common.shared.FunctionCall +import com.google.ai.client.generativeai.common.shared.FunctionCallPart as FunctionCallPart_Common import com.google.ai.client.generativeai.common.shared.TextPart as TextPart_Common import com.google.ai.client.generativeai.type.Candidate import com.google.ai.client.generativeai.type.CitationMetadata import com.google.ai.client.generativeai.type.Content +import com.google.ai.client.generativeai.type.FunctionResponsePart import com.google.ai.client.generativeai.type.GenerateContentResponse import com.google.ai.client.generativeai.type.InvalidAPIKeyException import com.google.ai.client.generativeai.type.PromptFeedback +import com.google.ai.client.generativeai.type.Schema import com.google.ai.client.generativeai.type.TextPart +import com.google.ai.client.generativeai.type.Tool import com.google.ai.client.generativeai.type.UnsupportedUserLocationException import com.google.ai.client.generativeai.type.UsageMetadata +import com.google.ai.client.generativeai.type.content +import com.google.ai.client.generativeai.type.defineFunction import io.kotest.assertions.throwables.shouldThrow import io.kotest.matchers.collections.shouldHaveSize import io.kotest.matchers.equality.shouldBeEqualToUsingFields +import io.kotest.matchers.maps.shouldContain +import io.kotest.matchers.nulls.shouldNotBeNull +import io.kotest.matchers.shouldBe import io.mockk.coEvery import io.mockk.mockk import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.flow.flow import kotlinx.coroutines.runBlocking +import org.json.JSONObject import org.junit.Test internal class GenerativeModelTests { @@ -153,6 +164,73 @@ internal class GenerativeModelTests { model.generateContentStream("Why's the sky blue?").collect {} } } + + @Test + fun `generateContent function parts work as expected`() = doBlocking { + val getExchangeRate = + defineFunction( + name = "getExchangeRate", + description = "Get the exchange rate for currencies between countries.", + parameters = + listOf( + Schema.str("currencyFrom", "The currency to convert from."), + Schema.str("currencyTo", "The currency to convert to."), + ), + requiredParameters = listOf("currencyFrom", "currencyTo"), + ) + val tools = listOf(Tool(listOf(getExchangeRate))) + val model = + GenerativeModel("gemini-pro-1.0", apiKey, tools = tools, controller = mockApiController) + val chat = Chat(model) + + coEvery { mockApiController.generateContent(any()) } returns + GenerateContentResponse_Common( + listOf( + Candidate_Common( + Content_Common( + parts = + listOf( + FunctionCallPart_Common( + FunctionCall( + "getExchangeRate", + mapOf("currencyFrom" to "USD", "currencyTo" to "EUR"), + ) + ) + ) + ) + ) + ) + ) + + val request = content { text("How much is $25 USD in EUR?") } + + val response = chat.sendMessage(request) + + response.functionCalls.firstOrNull()?.let { + it.shouldNotBeNull() + it.name shouldBe "getExchangeRate" + it.args shouldContain ("currencyFrom" to "USD") + it.args shouldContain ("currencyTo" to "EUR") + } + + coEvery { mockApiController.generateContent(any()) } returns + GenerateContentResponse_Common( + listOf( + Candidate_Common( + Content_Common(parts = listOf(TextPart_Common("$25 USD is $50 in EUR."))) + ) + ) + ) + + val functionResponse = + content("function") { + part(FunctionResponsePart("getExchangeRate", JSONObject(mapOf("exchangeRate" to "200%")))) + } + + val finalResponse = chat.sendMessage(functionResponse) + + finalResponse.text shouldBe "$25 USD is $50 in EUR." + } } internal fun doBlocking(block: suspend CoroutineScope.() -> Unit) {