diff --git a/.changes/common/angle-carpenter-beam-clock.json b/.changes/common/angle-carpenter-beam-clock.json new file mode 100644 index 00000000..a6746a8a --- /dev/null +++ b/.changes/common/angle-carpenter-beam-clock.json @@ -0,0 +1 @@ +{"type":"MAJOR","changes":["add code execution tool"]} diff --git a/.changes/generativeai/direction-bee-brass-aftermath.json b/.changes/generativeai/direction-bee-brass-aftermath.json new file mode 100644 index 00000000..a6746a8a --- /dev/null +++ b/.changes/generativeai/direction-bee-brass-aftermath.json @@ -0,0 +1 @@ +{"type":"MAJOR","changes":["add code execution tool"]} diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt index c8fc3fb6..e32b9196 100644 --- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt @@ -18,6 +18,7 @@ package com.google.ai.client.generativeai.common.client import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonObject @Serializable data class GenerationConfig( @@ -33,7 +34,12 @@ data class GenerationConfig( @SerialName("response_schema") val responseSchema: Schema? = null, ) -@Serializable data class Tool(val functionDeclarations: List) +@Serializable +data class Tool( + val functionDeclarations: List? = null, + // This is a json object because it is not possible to make a data class with no parameters. + val codeExecution: JsonObject? = null, +) @Serializable data class ToolConfig( diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/shared/Types.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/shared/Types.kt index 64e816e2..7dc31858 100644 --- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/shared/Types.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/shared/Types.kt @@ -57,6 +57,11 @@ data class Content(@EncodeDefault val role: String? = "user", val parts: List) @@ -71,6 +76,18 @@ data class FileData( @Serializable data class Blob(@SerialName("mime_type") val mimeType: String, val data: Base64) +@Serializable data class ExecutableCode(val language: String, val code: String) + +@Serializable data class CodeExecutionResult(val outcome: Outcome, val output: String) + +@Serializable +enum class Outcome { + @SerialName("OUTCOME_UNSPECIFIED") UNSPECIFIED, + OUTCOME_OK, + OUTCOME_FAILED, + OUTCOME_DEADLINE_EXCEEDED, +} + @Serializable data class SafetySetting( val category: HarmCategory, @@ -101,8 +118,10 @@ object PartSerializer : JsonContentPolymorphicSerializer(Part::class) { "text" in jsonObject -> TextPart.serializer() "functionCall" in jsonObject -> FunctionCallPart.serializer() "functionResponse" in jsonObject -> FunctionResponsePart.serializer() - "inline_data" in jsonObject -> BlobPart.serializer() - "file_data" in jsonObject -> FileDataPart.serializer() + "inlineData" in jsonObject -> BlobPart.serializer() + "fileData" in jsonObject -> FileDataPart.serializer() + "executableCode" in jsonObject -> ExecutableCodePart.serializer() + "codeExecutionResult" in jsonObject -> CodeExecutionResultPart.serializer() else -> throw SerializationException("Unknown Part type") } } diff --git a/common/src/test/java/com/google/ai/client/generativeai/common/APIControllerTests.kt b/common/src/test/java/com/google/ai/client/generativeai/common/APIControllerTests.kt index 776e0d0c..aa3c402b 100644 --- a/common/src/test/java/com/google/ai/client/generativeai/common/APIControllerTests.kt +++ b/common/src/test/java/com/google/ai/client/generativeai/common/APIControllerTests.kt @@ -17,6 +17,7 @@ package com.google.ai.client.generativeai.common import com.google.ai.client.generativeai.common.client.FunctionCallingConfig +import com.google.ai.client.generativeai.common.client.Tool import com.google.ai.client.generativeai.common.client.ToolConfig import com.google.ai.client.generativeai.common.shared.Content import com.google.ai.client.generativeai.common.shared.TextPart @@ -43,6 +44,7 @@ import kotlin.time.Duration.Companion.seconds import kotlinx.coroutines.delay import kotlinx.coroutines.withTimeout import kotlinx.serialization.encodeToString +import kotlinx.serialization.json.JsonObject import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.Parameterized @@ -259,6 +261,41 @@ internal class RequestFormatTests { mockEngine.requestHistory.first().headers.contains("header1") shouldBe false } + + @Test + fun `code execution tool serialization contains correct keys`() = doBlocking { + val channel = ByteChannel(autoFlush = true) + val mockEngine = MockEngine { + respond(channel, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json")) + } + prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) } + + val controller = + APIController( + "super_cool_test_key", + "gemini-pro-1.0", + RequestOptions(), + mockEngine, + TEST_CLIENT_ID, + null, + ) + + withTimeout(5.seconds) { + controller + .generateContentStream( + GenerateContentRequest( + model = "unused", + contents = listOf(Content(parts = listOf(TextPart("Arbitrary")))), + tools = listOf(Tool(codeExecution = JsonObject(emptyMap()))), + ) + ) + .collect { channel.close() } + } + + val requestBodyAsText = (mockEngine.requestHistory.first().body as TextContent).text + + requestBodyAsText shouldContainJsonKey "tools[0].codeExecution" + } } @RunWith(Parameterized::class) diff --git a/common/src/test/java/com/google/ai/client/generativeai/common/UnarySnapshotTests.kt b/common/src/test/java/com/google/ai/client/generativeai/common/UnarySnapshotTests.kt index 8801a955..ad5c634a 100644 --- a/common/src/test/java/com/google/ai/client/generativeai/common/UnarySnapshotTests.kt +++ b/common/src/test/java/com/google/ai/client/generativeai/common/UnarySnapshotTests.kt @@ -20,8 +20,13 @@ import com.google.ai.client.generativeai.common.server.BlockReason import com.google.ai.client.generativeai.common.server.FinishReason import com.google.ai.client.generativeai.common.server.HarmProbability import com.google.ai.client.generativeai.common.server.HarmSeverity +import com.google.ai.client.generativeai.common.shared.CodeExecutionResult +import com.google.ai.client.generativeai.common.shared.CodeExecutionResultPart +import com.google.ai.client.generativeai.common.shared.ExecutableCode +import com.google.ai.client.generativeai.common.shared.ExecutableCodePart import com.google.ai.client.generativeai.common.shared.FunctionCallPart import com.google.ai.client.generativeai.common.shared.HarmCategory +import com.google.ai.client.generativeai.common.shared.Outcome import com.google.ai.client.generativeai.common.shared.TextPart import com.google.ai.client.generativeai.common.util.goldenUnaryFile import com.google.ai.client.generativeai.common.util.shouldNotBeNullOrEmpty @@ -331,4 +336,23 @@ internal class UnarySnapshotTests { callPart.functionCall.args["current"] shouldBe "true" } } + + @Test + fun `code execution parses correctly`() = + goldenUnaryFile("success-code-execution.json") { + withTimeout(testTimeout) { + val response = apiController.generateContent(textGenerateContentRequest("prompt")) + val content = response.candidates.shouldNotBeNullOrEmpty().first().content + content.shouldNotBeNull() + val executableCodePart = content.parts[0] + val codeExecutionResult = content.parts[1] + + executableCodePart.shouldBe( + ExecutableCodePart(ExecutableCode("PYTHON", "print(\"Hello World\")")) + ) + codeExecutionResult.shouldBe( + CodeExecutionResultPart(CodeExecutionResult(Outcome.OUTCOME_OK, "Hello World")) + ) + } + } } diff --git a/common/src/test/resources/golden-files/unary/success-code-execution.json b/common/src/test/resources/golden-files/unary/success-code-execution.json new file mode 100644 index 00000000..3b8f4c25 --- /dev/null +++ b/common/src/test/resources/golden-files/unary/success-code-execution.json @@ -0,0 +1,48 @@ +{ + "candidates": [ + { + "content": { + "parts": [ + { + "executableCode": { + "language": "PYTHON", + "code": "print(\"Hello World\")" + } + }, + { + "codeExecutionResult": { + "outcome": "OUTCOME_OK", + "output": "Hello World" + } + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0, + "safetyRatings": [ + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE" + } + ] + } + ], + "usageMetadata": { + "promptTokenCount": 774, + "candidatesTokenCount": 4176, + "totalTokenCount": 4950 + } +} \ No newline at end of file diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt index ed093005..c886a6dc 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt @@ -33,7 +33,11 @@ import com.google.ai.client.generativeai.common.server.PromptFeedback import com.google.ai.client.generativeai.common.server.SafetyRating import com.google.ai.client.generativeai.common.shared.Blob import com.google.ai.client.generativeai.common.shared.BlobPart +import com.google.ai.client.generativeai.common.shared.CodeExecutionResult +import com.google.ai.client.generativeai.common.shared.CodeExecutionResultPart import com.google.ai.client.generativeai.common.shared.Content +import com.google.ai.client.generativeai.common.shared.ExecutableCode +import com.google.ai.client.generativeai.common.shared.ExecutableCodePart import com.google.ai.client.generativeai.common.shared.FileData import com.google.ai.client.generativeai.common.shared.FileDataPart import com.google.ai.client.generativeai.common.shared.FunctionCall @@ -42,11 +46,13 @@ import com.google.ai.client.generativeai.common.shared.FunctionResponse import com.google.ai.client.generativeai.common.shared.FunctionResponsePart import com.google.ai.client.generativeai.common.shared.HarmBlockThreshold import com.google.ai.client.generativeai.common.shared.HarmCategory +import com.google.ai.client.generativeai.common.shared.Outcome import com.google.ai.client.generativeai.common.shared.Part import com.google.ai.client.generativeai.common.shared.SafetySetting import com.google.ai.client.generativeai.common.shared.TextPart import com.google.ai.client.generativeai.type.BlockThreshold import com.google.ai.client.generativeai.type.CitationMetadata +import com.google.ai.client.generativeai.type.ExecutionOutcome import com.google.ai.client.generativeai.type.FunctionCallingConfig import com.google.ai.client.generativeai.type.FunctionDeclaration import com.google.ai.client.generativeai.type.ImagePart @@ -80,6 +86,10 @@ internal fun com.google.ai.client.generativeai.type.Part.toInternal(): Part { FunctionResponsePart(FunctionResponse(name, response.toInternal())) is com.google.ai.client.generativeai.type.FileDataPart -> FileDataPart(FileData(fileUri = uri, mimeType = mimeType)) + is com.google.ai.client.generativeai.type.ExecutableCodePart -> + ExecutableCodePart(ExecutableCode(language, code)) + is com.google.ai.client.generativeai.type.CodeExecutionResultPart -> + CodeExecutionResultPart(CodeExecutionResult(outcome.toInternal(), output)) else -> throw SerializationException( "The given subclass of Part (${javaClass.simpleName}) is not supported in the serialization yet." @@ -122,8 +132,19 @@ internal fun BlockThreshold.toInternal() = BlockThreshold.UNSPECIFIED -> HarmBlockThreshold.UNSPECIFIED } +internal fun ExecutionOutcome.toInternal() = + when (this) { + ExecutionOutcome.UNSPECIFIED -> Outcome.UNSPECIFIED + ExecutionOutcome.OK -> Outcome.OUTCOME_OK + ExecutionOutcome.FAILED -> Outcome.OUTCOME_FAILED + ExecutionOutcome.DEADLINE_EXCEEDED -> Outcome.OUTCOME_DEADLINE_EXCEEDED + } + internal fun Tool.toInternal() = - com.google.ai.client.generativeai.common.client.Tool(functionDeclarations.map { it.toInternal() }) + com.google.ai.client.generativeai.common.client.Tool( + functionDeclarations?.map { it.toInternal() }, + codeExecution = codeExecution?.toInternal(), + ) internal fun ToolConfig.toInternal() = com.google.ai.client.generativeai.common.client.ToolConfig( @@ -204,6 +225,16 @@ internal fun Part.toPublic(): com.google.ai.client.generativeai.type.Part { ) is FileDataPart -> com.google.ai.client.generativeai.type.FileDataPart(fileData.fileUri, fileData.mimeType) + is ExecutableCodePart -> + com.google.ai.client.generativeai.type.ExecutableCodePart( + executableCode.language, + executableCode.code, + ) + is CodeExecutionResultPart -> + com.google.ai.client.generativeai.type.CodeExecutionResultPart( + codeExecutionResult.outcome.toPublic(), + codeExecutionResult.output, + ) else -> throw SerializationException( "Unsupported part type \"${javaClass.simpleName}\" provided. This model may not be supported by this SDK." @@ -267,6 +298,14 @@ internal fun BlockReason.toPublic() = BlockReason.UNKNOWN -> com.google.ai.client.generativeai.type.BlockReason.UNKNOWN } +internal fun Outcome.toPublic() = + when (this) { + Outcome.UNSPECIFIED -> ExecutionOutcome.UNSPECIFIED + Outcome.OUTCOME_OK -> ExecutionOutcome.OK + Outcome.OUTCOME_FAILED -> ExecutionOutcome.FAILED + Outcome.OUTCOME_DEADLINE_EXCEEDED -> ExecutionOutcome.DEADLINE_EXCEEDED + } + internal fun GenerateContentResponse.toPublic() = com.google.ai.client.generativeai.type.GenerateContentResponse( candidates?.map { it.toPublic() }.orEmpty(), diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/ExecutionOutcome.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/ExecutionOutcome.kt new file mode 100644 index 00000000..a7decd43 --- /dev/null +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/ExecutionOutcome.kt @@ -0,0 +1,24 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.client.generativeai.type + +enum class ExecutionOutcome { + UNSPECIFIED, + OK, + FAILED, + DEADLINE_EXCEEDED, +} diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/GenerateContentResponse.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/GenerateContentResponse.kt index fad24fd6..bbaa3f23 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/GenerateContentResponse.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/GenerateContentResponse.kt @@ -32,7 +32,19 @@ class GenerateContentResponse( ) { /** Convenience field representing all the text parts in the response, if they exists. */ val text: String? by lazy { - candidates.first().content.parts.filterIsInstance().joinToString(" ") { it.text } + candidates + .first() + .content + .parts + .filter { it is TextPart || it is ExecutableCodePart || it is CodeExecutionResultPart } + .joinToString(" ") { + when (it) { + is TextPart -> it.text + is ExecutableCodePart -> "\n```${it.language.lowercase()}\n${it.code}\n```" + is CodeExecutionResultPart -> "\n```\n${it.output}\n```" + else -> throw RuntimeException("unreachable") + } + } } /** Convenience field representing the first function call part in the request, if it exists */ diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Part.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Part.kt index b72d99a0..9087a535 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Part.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Part.kt @@ -27,6 +27,10 @@ import org.json.JSONObject * * [ImagePart] representing image data. * * [BlobPart] representing MIME typed binary data. * * [FileDataPart] representing MIME typed binary data. + * * [FunctionCallPart] representing a requested clientside function call by the model + * * [FunctionResponsePart] representing the result of a clientside function call + * * [ExecutableCodePart] representing code generated and executed by the model + * * [CodeExecutionResultPart] representing the result of running code generated by the model. */ interface Part @@ -54,6 +58,12 @@ class FunctionCallPart(val name: String, val args: Map) : Part /** Represents function call output to be returned to the model when it requests a function call */ class FunctionResponsePart(val name: String, val response: JSONObject) : Part +/** Represents an internal function call written by the model */ +class ExecutableCodePart(val language: String, val code: String) : Part + +/** Represents the results of an internal function call written by the model */ +class CodeExecutionResultPart(val outcome: ExecutionOutcome, val output: String) : Part + /** @return The part as a [String] if it represents text, and null otherwise */ fun Part.asTextOrNull(): String? = (this as? TextPart)?.text diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Tool.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Tool.kt index 42b8c287..602ebdf5 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Tool.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Tool.kt @@ -16,10 +16,20 @@ package com.google.ai.client.generativeai.type +import org.json.JSONObject + /** * Contains a set of function declarations that the model has access to. These can be used to gather * information, or complete tasks * * @param functionDeclarations The set of functions that this tool allows the model access to + * @param codeExecution This is a flag value to enable Code Execution. Use [CODE_EXECUTION]. */ -class Tool(val functionDeclarations: List) +class Tool( + val functionDeclarations: List? = null, + val codeExecution: JSONObject? = null, +) { + companion object { + val CODE_EXECUTION = Tool(codeExecution = JSONObject()) + } +} diff --git a/generativeai/src/test/java/com/google/ai/client/generativeai/GenerateContentResponseTest.kt b/generativeai/src/test/java/com/google/ai/client/generativeai/GenerateContentResponseTest.kt index 335e81d5..31234299 100644 --- a/generativeai/src/test/java/com/google/ai/client/generativeai/GenerateContentResponseTest.kt +++ b/generativeai/src/test/java/com/google/ai/client/generativeai/GenerateContentResponseTest.kt @@ -17,6 +17,9 @@ package com.google.ai.client.generativeai import com.google.ai.client.generativeai.type.Candidate +import com.google.ai.client.generativeai.type.CodeExecutionResultPart +import com.google.ai.client.generativeai.type.ExecutableCodePart +import com.google.ai.client.generativeai.type.ExecutionOutcome import com.google.ai.client.generativeai.type.FunctionCallPart import com.google.ai.client.generativeai.type.GenerateContentResponse import com.google.ai.client.generativeai.type.content @@ -74,6 +77,40 @@ internal class GenerateContentResponseTest { response.text shouldBe "This is a textPart" } + @Test + fun `generate response should add generated code to the response`() { + val response = + GenerateContentResponse( + candidates = + listOf( + Candidate( + content { + text("I can calculate that for you!") + part(ExecutableCodePart("python", "print(\"hello world\")")) + part(CodeExecutionResultPart(ExecutionOutcome.OK, "hello world")) + }, + listOf(), + listOf(), + null, + ) + ), + null, + null, + ) + + response.text shouldBe + """ + I can calculate that for you! + ```python + print("hello world") + ``` + ``` + hello world + ``` + """ + .trimIndent() + } + @Test fun `generate response should get strings and concatenate them together`() { val response =