From 5d8b31cc64439c8ae10e70eb0b65b5c1d828b1b6 Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Tue, 27 Feb 2024 10:21:44 -0500 Subject: [PATCH 1/4] Add types to represent JSON values (#112) --- Sources/GoogleAI/JSONValue.swift | 71 ++++++++++++++++++++ Tests/GoogleAITests/JSONValueTests.swift | 85 ++++++++++++++++++++++++ 2 files changed, 156 insertions(+) create mode 100644 Sources/GoogleAI/JSONValue.swift create mode 100644 Tests/GoogleAITests/JSONValueTests.swift diff --git a/Sources/GoogleAI/JSONValue.swift b/Sources/GoogleAI/JSONValue.swift new file mode 100644 index 0000000..b6166bb --- /dev/null +++ b/Sources/GoogleAI/JSONValue.swift @@ -0,0 +1,71 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Foundation + +/// A collection of name-value pairs representing a JSON object. +/// +/// This may be decoded from, or encoded to, a +/// [`google.protobuf.Struct`](https://protobuf.dev/reference/protobuf/google.protobuf/#struct). +public typealias JSONObject = [String: JSONValue] + +/// Represents a value in one of JSON's data types. +/// +/// This may be decoded from, or encoded to, a +/// [`google.protobuf.Value`](https://protobuf.dev/reference/protobuf/google.protobuf/#value). +public enum JSONValue { + /// A `null` value. + case null + + /// A numeric value. + case number(Double) + + /// A string value. + case string(String) + + /// A boolean value. + case bool(Bool) + + /// A JSON object. + case object(JSONObject) + + /// An array of `JSONValue`s. + case array([JSONValue]) +} + +extension JSONValue: Decodable { + public init(from decoder: Decoder) throws { + let container = try decoder.singleValueContainer() + if container.decodeNil() { + self = .null + } else if let numberValue = try? container.decode(Double.self) { + self = .number(numberValue) + } else if let stringValue = try? container.decode(String.self) { + self = .string(stringValue) + } else if let boolValue = try? container.decode(Bool.self) { + self = .bool(boolValue) + } else if let objectValue = try? container.decode(JSONObject.self) { + self = .object(objectValue) + } else if let arrayValue = try? container.decode([JSONValue].self) { + self = .array(arrayValue) + } else { + throw DecodingError.dataCorruptedError( + in: container, + debugDescription: "Failed to decode JSON value." + ) + } + } +} + +extension JSONValue: Equatable {} diff --git a/Tests/GoogleAITests/JSONValueTests.swift b/Tests/GoogleAITests/JSONValueTests.swift new file mode 100644 index 0000000..14f9d96 --- /dev/null +++ b/Tests/GoogleAITests/JSONValueTests.swift @@ -0,0 +1,85 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +import XCTest + +@testable import GoogleGenerativeAI + +final class JSONValueTests: XCTestCase { + func testDecodeNull() throws { + let jsonData = try XCTUnwrap("null".data(using: .utf8)) + + let jsonObject = try XCTUnwrap(JSONDecoder().decode(JSONValue.self, from: jsonData)) + + XCTAssertEqual(jsonObject, .null) + } + + func testDecodeNumber() throws { + let expectedNumber = 3.14159 + let jsonData = try XCTUnwrap("\(expectedNumber)".data(using: .utf8)) + + let jsonObject = try XCTUnwrap(JSONDecoder().decode(JSONValue.self, from: jsonData)) + + XCTAssertEqual(jsonObject, .number(expectedNumber)) + } + + func testDecodeString() throws { + let expectedString = "hello-world" + let jsonData = try XCTUnwrap("\"\(expectedString)\"".data(using: .utf8)) + + let jsonObject = try XCTUnwrap(JSONDecoder().decode(JSONValue.self, from: jsonData)) + + XCTAssertEqual(jsonObject, .string(expectedString)) + } + + func testDecodeBool() throws { + let expectedBool = true + let jsonData = try XCTUnwrap("\(expectedBool)".data(using: .utf8)) + + let jsonObject = try XCTUnwrap(JSONDecoder().decode(JSONValue.self, from: jsonData)) + + XCTAssertEqual(jsonObject, .bool(expectedBool)) + } + + func testDecodeObject() throws { + let numberKey = "pi" + let numberValue = 3.14159 + let stringKey = "hello" + let stringValue = "world" + let expectedObject: JSONObject = [ + numberKey: .number(numberValue), + stringKey: .string(stringValue), + ] + let json = """ + { + "\(numberKey)": \(numberValue), + "\(stringKey)": "\(stringValue)" + } + """ + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + let jsonObject = try XCTUnwrap(JSONDecoder().decode(JSONValue.self, from: jsonData)) + + XCTAssertEqual(jsonObject, .object(expectedObject)) + } + + func testDecodeArray() throws { + let numberValue = 3.14159 + let expectedArray: [JSONValue] = [.null, .number(numberValue)] + let jsonData = try XCTUnwrap("[ null, \(numberValue) ]".data(using: .utf8)) + + let jsonObject = try XCTUnwrap(JSONDecoder().decode(JSONValue.self, from: jsonData)) + + XCTAssertEqual(jsonObject, .array(expectedArray)) + } +} From 45bc200e2afe9861067d7dcee0959f61ed911ee4 Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Wed, 28 Feb 2024 13:40:31 -0500 Subject: [PATCH 2/4] Add `Encodable` conformance to `JSONValue` (#113) --- Sources/GoogleAI/JSONValue.swift | 25 +++++++ Tests/GoogleAITests/JSONValueTests.swift | 94 +++++++++++++++++++----- 2 files changed, 102 insertions(+), 17 deletions(-) diff --git a/Sources/GoogleAI/JSONValue.swift b/Sources/GoogleAI/JSONValue.swift index b6166bb..5ce52cd 100644 --- a/Sources/GoogleAI/JSONValue.swift +++ b/Sources/GoogleAI/JSONValue.swift @@ -68,4 +68,29 @@ extension JSONValue: Decodable { } } +extension JSONValue: Encodable { + public func encode(to encoder: Encoder) throws { + var container = encoder.singleValueContainer() + switch self { + case .null: + try container.encodeNil() + case let .number(numberValue): + // Convert to `Decimal` before encoding for consistent floating-point serialization across + // platforms. E.g., `Double` serializes 3.14159 as 3.1415899999999999 in some cases and + // 3.14159 in others. See + // https://forums.swift.org/t/jsonencoder-encodable-floating-point-rounding-error/41390/4 for + // more details. + try container.encode(Decimal(numberValue)) + case let .string(stringValue): + try container.encode(stringValue) + case let .bool(boolValue): + try container.encode(boolValue) + case let .object(objectValue): + try container.encode(objectValue) + case let .array(arrayValue): + try container.encode(arrayValue) + } + } +} + extension JSONValue: Equatable {} diff --git a/Tests/GoogleAITests/JSONValueTests.swift b/Tests/GoogleAITests/JSONValueTests.swift index 14f9d96..19c871e 100644 --- a/Tests/GoogleAITests/JSONValueTests.swift +++ b/Tests/GoogleAITests/JSONValueTests.swift @@ -16,46 +16,53 @@ import XCTest @testable import GoogleGenerativeAI final class JSONValueTests: XCTestCase { + let decoder = JSONDecoder() + let encoder = JSONEncoder() + + let numberKey = "pi" + let numberValue = 3.14159 + let numberValueEncoded = "3.14159" + let stringKey = "hello" + let stringValue = "Hello, world!" + + override func setUp() { + encoder.outputFormatting = .sortedKeys + } + func testDecodeNull() throws { let jsonData = try XCTUnwrap("null".data(using: .utf8)) - let jsonObject = try XCTUnwrap(JSONDecoder().decode(JSONValue.self, from: jsonData)) + let jsonObject = try XCTUnwrap(decoder.decode(JSONValue.self, from: jsonData)) XCTAssertEqual(jsonObject, .null) } func testDecodeNumber() throws { - let expectedNumber = 3.14159 - let jsonData = try XCTUnwrap("\(expectedNumber)".data(using: .utf8)) + let jsonData = try XCTUnwrap("\(numberValue)".data(using: .utf8)) - let jsonObject = try XCTUnwrap(JSONDecoder().decode(JSONValue.self, from: jsonData)) + let jsonObject = try XCTUnwrap(decoder.decode(JSONValue.self, from: jsonData)) - XCTAssertEqual(jsonObject, .number(expectedNumber)) + XCTAssertEqual(jsonObject, .number(numberValue)) } func testDecodeString() throws { - let expectedString = "hello-world" - let jsonData = try XCTUnwrap("\"\(expectedString)\"".data(using: .utf8)) + let jsonData = try XCTUnwrap("\"\(stringValue)\"".data(using: .utf8)) - let jsonObject = try XCTUnwrap(JSONDecoder().decode(JSONValue.self, from: jsonData)) + let jsonObject = try XCTUnwrap(decoder.decode(JSONValue.self, from: jsonData)) - XCTAssertEqual(jsonObject, .string(expectedString)) + XCTAssertEqual(jsonObject, .string(stringValue)) } func testDecodeBool() throws { let expectedBool = true let jsonData = try XCTUnwrap("\(expectedBool)".data(using: .utf8)) - let jsonObject = try XCTUnwrap(JSONDecoder().decode(JSONValue.self, from: jsonData)) + let jsonObject = try XCTUnwrap(decoder.decode(JSONValue.self, from: jsonData)) XCTAssertEqual(jsonObject, .bool(expectedBool)) } func testDecodeObject() throws { - let numberKey = "pi" - let numberValue = 3.14159 - let stringKey = "hello" - let stringValue = "world" let expectedObject: JSONObject = [ numberKey: .number(numberValue), stringKey: .string(stringValue), @@ -68,18 +75,71 @@ final class JSONValueTests: XCTestCase { """ let jsonData = try XCTUnwrap(json.data(using: .utf8)) - let jsonObject = try XCTUnwrap(JSONDecoder().decode(JSONValue.self, from: jsonData)) + let jsonObject = try XCTUnwrap(decoder.decode(JSONValue.self, from: jsonData)) XCTAssertEqual(jsonObject, .object(expectedObject)) } func testDecodeArray() throws { - let numberValue = 3.14159 let expectedArray: [JSONValue] = [.null, .number(numberValue)] let jsonData = try XCTUnwrap("[ null, \(numberValue) ]".data(using: .utf8)) - let jsonObject = try XCTUnwrap(JSONDecoder().decode(JSONValue.self, from: jsonData)) + let jsonObject = try XCTUnwrap(decoder.decode(JSONValue.self, from: jsonData)) XCTAssertEqual(jsonObject, .array(expectedArray)) } + + func testEncodeNull() throws { + let jsonData = try encoder.encode(JSONValue.null) + + let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8)) + XCTAssertEqual(json, "null") + } + + func testEncodeNumber() throws { + let jsonData = try encoder.encode(JSONValue.number(numberValue)) + + let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8)) + XCTAssertEqual(json, "\(numberValue)") + } + + func testEncodeString() throws { + let jsonData = try encoder.encode(JSONValue.string(stringValue)) + + let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8)) + XCTAssertEqual(json, "\"\(stringValue)\"") + } + + func testEncodeBool() throws { + let boolValue = true + + let jsonData = try encoder.encode(JSONValue.bool(boolValue)) + + let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8)) + XCTAssertEqual(json, "\(boolValue)") + } + + func testEncodeObject() throws { + let objectValue: JSONObject = [ + numberKey: .number(numberValue), + stringKey: .string(stringValue), + ] + + let jsonData = try encoder.encode(JSONValue.object(objectValue)) + + let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8)) + XCTAssertEqual( + json, + "{\"\(stringKey)\":\"\(stringValue)\",\"\(numberKey)\":\(numberValueEncoded)}" + ) + } + + func testEncodeArray() throws { + let arrayValue: [JSONValue] = [.null, .number(numberValue)] + + let jsonData = try encoder.encode(JSONValue.array(arrayValue)) + + let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8)) + XCTAssertEqual(json, "[null,\(numberValueEncoded)]") + } } From 667eccf5661d5bbe22219ca6b178c3f098998f0e Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Tue, 5 Mar 2024 19:19:10 +0000 Subject: [PATCH 3/4] Add `FunctionCall` decoding (#114) --- Sources/GoogleAI/Chat.swift | 4 ++ Sources/GoogleAI/FunctionCalling.swift | 41 ++++++++++++ Sources/GoogleAI/ModelContent.swift | 11 ++- ...success-function-call-empty-arguments.json | 19 ++++++ ...ry-success-function-call-no-arguments.json | 19 ++++++ ...-success-function-call-with-arguments.json | 22 ++++++ .../GoogleAITests/GenerativeModelTests.swift | 67 +++++++++++++++++++ 7 files changed, 182 insertions(+), 1 deletion(-) create mode 100644 Sources/GoogleAI/FunctionCalling.swift create mode 100644 Tests/GoogleAITests/GenerateContentResponses/unary-success-function-call-empty-arguments.json create mode 100644 Tests/GoogleAITests/GenerateContentResponses/unary-success-function-call-no-arguments.json create mode 100644 Tests/GoogleAITests/GenerateContentResponses/unary-success-function-call-with-arguments.json diff --git a/Sources/GoogleAI/Chat.swift b/Sources/GoogleAI/Chat.swift index c7cfb85..6e7e885 100644 --- a/Sources/GoogleAI/Chat.swift +++ b/Sources/GoogleAI/Chat.swift @@ -162,6 +162,10 @@ public class Chat { } parts.append(part) + + case .functionCall: + // TODO(andrewheard): Add function call to the chat history when encoding is implemented. + fatalError("Function calling not yet implemented in chat.") } } } diff --git a/Sources/GoogleAI/FunctionCalling.swift b/Sources/GoogleAI/FunctionCalling.swift new file mode 100644 index 0000000..5d8ded5 --- /dev/null +++ b/Sources/GoogleAI/FunctionCalling.swift @@ -0,0 +1,41 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Foundation + +/// A predicted function call returned from the model. +public struct FunctionCall: Equatable { + /// The name of the function to call. + let name: String + + /// The function parameters and values. + let args: JSONObject +} + +extension FunctionCall: Decodable { + enum CodingKeys: CodingKey { + case name + case args + } + + public init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + name = try container.decode(String.self, forKey: .name) + if let args = try container.decodeIfPresent(JSONObject.self, forKey: .args) { + self.args = args + } else { + args = JSONObject() + } + } +} diff --git a/Sources/GoogleAI/ModelContent.swift b/Sources/GoogleAI/ModelContent.swift index 44648c5..2ce8876 100644 --- a/Sources/GoogleAI/ModelContent.swift +++ b/Sources/GoogleAI/ModelContent.swift @@ -25,6 +25,7 @@ public struct ModelContent: Codable, Equatable { enum CodingKeys: String, CodingKey { case text case inlineData + case functionCall } enum InlineDataKeys: String, CodingKey { @@ -38,6 +39,9 @@ public struct ModelContent: Codable, Equatable { /// Data with a specified media type. Not all media types may be supported by the AI model. case data(mimetype: String, Data) + /// A predicted function call returned from the model. + case functionCall(FunctionCall) + // MARK: Convenience Initializers /// Convenience function for populating a Part with JPEG data. @@ -64,6 +68,9 @@ public struct ModelContent: Codable, Equatable { ) try inlineDataContainer.encode(mimetype, forKey: .mimeType) try inlineDataContainer.encode(bytes, forKey: .bytes) + case .functionCall: + // TODO(andrewheard): Encode FunctionCalls when when encoding is implemented. + fatalError("FunctionCall encoding not implemented.") } } @@ -79,10 +86,12 @@ public struct ModelContent: Codable, Equatable { let mimetype = try dataContainer.decode(String.self, forKey: .mimeType) let bytes = try dataContainer.decode(Data.self, forKey: .bytes) self = .data(mimetype: mimetype, bytes) + } else if values.contains(.functionCall) { + self = try .functionCall(values.decode(FunctionCall.self, forKey: .functionCall)) } else { throw DecodingError.dataCorrupted(.init( codingPath: [CodingKeys.text, CodingKeys.inlineData], - debugDescription: "Neither text or inline data was found." + debugDescription: "No text, inline data or function call was found." )) } } diff --git a/Tests/GoogleAITests/GenerateContentResponses/unary-success-function-call-empty-arguments.json b/Tests/GoogleAITests/GenerateContentResponses/unary-success-function-call-empty-arguments.json new file mode 100644 index 0000000..703bdf8 --- /dev/null +++ b/Tests/GoogleAITests/GenerateContentResponses/unary-success-function-call-empty-arguments.json @@ -0,0 +1,19 @@ +{ + "candidates": [ + { + "content": { + "parts": [ + { + "functionCall": { + "name": "current_time" + } + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0 + } + ] +} + diff --git a/Tests/GoogleAITests/GenerateContentResponses/unary-success-function-call-no-arguments.json b/Tests/GoogleAITests/GenerateContentResponses/unary-success-function-call-no-arguments.json new file mode 100644 index 0000000..05f4f4d --- /dev/null +++ b/Tests/GoogleAITests/GenerateContentResponses/unary-success-function-call-no-arguments.json @@ -0,0 +1,19 @@ +{ + "candidates": [ + { + "content": { + "parts": [ + { + "functionCall": { + "name": "current_time", + "args": {} + } + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0 + } + ] +} diff --git a/Tests/GoogleAITests/GenerateContentResponses/unary-success-function-call-with-arguments.json b/Tests/GoogleAITests/GenerateContentResponses/unary-success-function-call-with-arguments.json new file mode 100644 index 0000000..025735a --- /dev/null +++ b/Tests/GoogleAITests/GenerateContentResponses/unary-success-function-call-with-arguments.json @@ -0,0 +1,22 @@ +{ + "candidates": [ + { + "content": { + "parts": [ + { + "functionCall": { + "name": "sum", + "args": { + "y": 5, + "x": 4 + } + } + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0 + } + ] +} diff --git a/Tests/GoogleAITests/GenerativeModelTests.swift b/Tests/GoogleAITests/GenerativeModelTests.swift index 2b90ec1..41f6844 100644 --- a/Tests/GoogleAITests/GenerativeModelTests.swift +++ b/Tests/GoogleAITests/GenerativeModelTests.swift @@ -169,6 +169,73 @@ final class GenerativeModelTests: XCTestCase { _ = try await model.generateContent(testPrompt) } + func testGenerateContent_success_functionCall_emptyArguments() async throws { + MockURLProtocol + .requestHandler = try httpRequestHandler( + forResource: "unary-success-function-call-empty-arguments", + withExtension: "json" + ) + + let response = try await model.generateContent(testPrompt) + + XCTAssertEqual(response.candidates.count, 1) + let candidate = try XCTUnwrap(response.candidates.first) + XCTAssertEqual(candidate.content.parts.count, 1) + let part = try XCTUnwrap(candidate.content.parts.first) + guard case let .functionCall(functionCall) = part else { + XCTFail("Part is not a FunctionCall.") + return + } + XCTAssertEqual(functionCall.name, "current_time") + XCTAssertTrue(functionCall.args.isEmpty) + } + + func testGenerateContent_success_functionCall_noArguments() async throws { + MockURLProtocol + .requestHandler = try httpRequestHandler( + forResource: "unary-success-function-call-no-arguments", + withExtension: "json" + ) + + let response = try await model.generateContent(testPrompt) + + XCTAssertEqual(response.candidates.count, 1) + let candidate = try XCTUnwrap(response.candidates.first) + XCTAssertEqual(candidate.content.parts.count, 1) + let part = try XCTUnwrap(candidate.content.parts.first) + guard case let .functionCall(functionCall) = part else { + XCTFail("Part is not a FunctionCall.") + return + } + XCTAssertEqual(functionCall.name, "current_time") + XCTAssertTrue(functionCall.args.isEmpty) + } + + func testGenerateContent_success_functionCall_withArguments() async throws { + MockURLProtocol + .requestHandler = try httpRequestHandler( + forResource: "unary-success-function-call-with-arguments", + withExtension: "json" + ) + + let response = try await model.generateContent(testPrompt) + + XCTAssertEqual(response.candidates.count, 1) + let candidate = try XCTUnwrap(response.candidates.first) + XCTAssertEqual(candidate.content.parts.count, 1) + let part = try XCTUnwrap(candidate.content.parts.first) + guard case let .functionCall(functionCall) = part else { + XCTFail("Part is not a FunctionCall.") + return + } + XCTAssertEqual(functionCall.name, "sum") + XCTAssertEqual(functionCall.args.count, 2) + let argX = try XCTUnwrap(functionCall.args["x"]) + XCTAssertEqual(argX, .number(4)) + let argY = try XCTUnwrap(functionCall.args["y"]) + XCTAssertEqual(argY, .number(5)) + } + func testGenerateContent_failure_invalidAPIKey() async throws { let expectedStatusCode = 400 MockURLProtocol From 1ed073fcd343cfe3efda88f51f29ea13ee6e924f Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Mon, 11 Mar 2024 12:09:11 -0400 Subject: [PATCH 4/4] Automatic function calling prototype --- .../Sources/GenerateContent.swift | 88 ++++++++++++++- Sources/GoogleAI/Chat.swift | 28 ++++- Sources/GoogleAI/FunctionCalling.swift | 104 +++++++++++++++++- Sources/GoogleAI/GenerateContentRequest.swift | 2 + Sources/GoogleAI/GenerativeModel.swift | 34 ++++++ Sources/GoogleAI/ModelContent.swift | 10 +- 6 files changed, 256 insertions(+), 10 deletions(-) diff --git a/Examples/GenerativeAICLI/Sources/GenerateContent.swift b/Examples/GenerativeAICLI/Sources/GenerateContent.swift index ab71c43..488b090 100644 --- a/Examples/GenerativeAICLI/Sources/GenerateContent.swift +++ b/Examples/GenerativeAICLI/Sources/GenerateContent.swift @@ -70,9 +70,20 @@ struct GenerateContent: AsyncParsableCommand { name: modelNameOrDefault(), apiKey: apiKey, generationConfig: config, - safetySettings: safetySettings + safetySettings: safetySettings, + tools: [Tool(functionDeclarations: [ + FunctionDeclaration( + name: "get_exchange_rate", + description: "Get the exchange rate for currencies between countries", + parameters: getExchangeRateSchema(), + function: getExchangeRateWrapper + ), + ])], + requestOptions: RequestOptions(apiVersion: "v1beta") ) + let chat = model.startChat() + var parts = [ModelContent.Part]() if let textPrompt = textPrompt { @@ -96,7 +107,7 @@ struct GenerateContent: AsyncParsableCommand { let input = [ModelContent(parts: parts)] if isStreaming { - let contentStream = model.generateContentStream(input) + let contentStream = chat.sendMessageStream(input) print("Generated Content :") for try await content in contentStream { if let text = content.text { @@ -104,7 +115,8 @@ struct GenerateContent: AsyncParsableCommand { } } } else { - let content = try await model.generateContent(input) + // Unary generate content + let content = try await chat.sendMessage(input) if let text = content.text { print("Generated Content:\n\(text)") } @@ -123,6 +135,76 @@ struct GenerateContent: AsyncParsableCommand { return "gemini-1.0-pro" } } + + // MARK: - Callable Functions + + // Returns exchange rates from the Frankfurter API + // This is an example function that a developer might provide. + func getExchangeRate(amount: Double, date: String, from: String, + to: String) async throws -> String { + var urlComponents = URLComponents(string: "https://api.frankfurter.app")! + urlComponents.path = "/\(date)" + urlComponents.queryItems = [ + .init(name: "amount", value: String(amount)), + .init(name: "from", value: from), + .init(name: "to", value: to), + ] + + let (data, _) = try await URLSession.shared.data(from: urlComponents.url!) + return String(data: data, encoding: .utf8)! + } + + // This is a wrapper for the `getExchangeRate` function. + func getExchangeRateWrapper(args: JSONObject) async throws -> JSONObject { + // 1. Validate and extract the parameters provided by the model (from a `FunctionCall`) + guard case let .string(date) = args["currency_date"] else { + fatalError() + } + guard case let .string(from) = args["currency_from"] else { + fatalError() + } + guard case let .string(to) = args["currency_to"] else { + fatalError() + } + guard case let .number(amount) = args["amount"] else { + fatalError() + } + + // 2. Call the wrapped function + let response = try await getExchangeRate(amount: amount, date: date, from: from, to: to) + + // 3. Return the exchange rates as a JSON object (returned to the model in a `FunctionResponse`) + return ["content": .string(response)] + } + + // Returns the schema of the `getExchangeRate` function + func getExchangeRateSchema() -> Schema { + return Schema( + type: .object, + properties: [ + "currency_date": Schema( + type: .string, + description: """ + A date that must always be in YYYY-MM-DD format or the value 'latest' if a time period + is not specified + """ + ), + "currency_from": Schema( + type: .string, + description: "The currency to convert from in ISO 4217 format" + ), + "currency_to": Schema( + type: .string, + description: "The currency to convert to in ISO 4217 format" + ), + "amount": Schema( + type: .number, + description: "The amount of currency to convert as a double value" + ), + ], + required: ["currency_date", "currency_from", "currency_to", "amount"] + ) + } } enum CLIError: Error { diff --git a/Sources/GoogleAI/Chat.swift b/Sources/GoogleAI/Chat.swift index 6e7e885..74fdebd 100644 --- a/Sources/GoogleAI/Chat.swift +++ b/Sources/GoogleAI/Chat.swift @@ -70,10 +70,32 @@ public class Chat { // Make sure we inject the role into the content received. let toAdd = ModelContent(role: "model", parts: reply.parts) + var functionResponses = [FunctionResponse]() + for part in reply.parts { + if case let .functionCall(functionCall) = part { + try functionResponses.append(await model.executeFunction(functionCall: functionCall)) + } + } + + // Call the functions requested by the model, if any. + let functionResponseContent = try ModelContent( + role: "function", + functionResponses.map { functionResponse in + ModelContent.Part.functionResponse(functionResponse) + } + ) + // Append the request and successful result to history, then return the value. history.append(contentsOf: newContent) history.append(toAdd) - return result + + // If no function calls requested, return the results. + if functionResponses.isEmpty { + return result + } + + // Re-send the message with the function responses. + return try await sendMessage([functionResponseContent]) } /// See ``sendMessageStream(_:)-4abs3``. @@ -166,6 +188,10 @@ public class Chat { case .functionCall: // TODO(andrewheard): Add function call to the chat history when encoding is implemented. fatalError("Function calling not yet implemented in chat.") + + case .functionResponse: + // TODO(andrewheard): Add function response to chat history when encoding is implemented. + fatalError("Function calling not yet implemented in chat.") } } } diff --git a/Sources/GoogleAI/FunctionCalling.swift b/Sources/GoogleAI/FunctionCalling.swift index 5d8ded5..e76e434 100644 --- a/Sources/GoogleAI/FunctionCalling.swift +++ b/Sources/GoogleAI/FunctionCalling.swift @@ -15,14 +15,97 @@ import Foundation /// A predicted function call returned from the model. -public struct FunctionCall: Equatable { +/// +/// REST Docs: https://ai.google.dev/api/rest/v1beta/Content#functioncall +public struct FunctionCall: Equatable, Encodable { /// The name of the function to call. - let name: String + public let name: String /// The function parameters and values. - let args: JSONObject + public let args: JSONObject } +// REST Docs: https://ai.google.dev/api/rest/v1beta/Tool#schema +public class Schema: Encodable { + let type: DataType + + let format: String? + + let description: String? + + let nullable: Bool? + + let enumValues: [String]? + + let items: Schema? + + let properties: [String: Schema]? + + let required: [String]? + + public init(type: DataType, format: String? = nil, description: String? = nil, + nullable: Bool? = nil, + enumValues: [String]? = nil, items: Schema? = nil, + properties: [String: Schema]? = nil, + required: [String]? = nil) { + self.type = type + self.format = format + self.description = description + self.nullable = nullable + self.enumValues = enumValues + self.items = items + self.properties = properties + self.required = required + } +} + +// REST Docs: https://ai.google.dev/api/rest/v1beta/Tool#Type +public enum DataType: String, Encodable { + case string = "STRING" + case number = "NUMBER" + case integer = "INTEGER" + case boolean = "BOOLEAN" + case array = "ARRAY" + case object = "OBJECT" +} + +// REST Docs: https://ai.google.dev/api/rest/v1beta/Tool#FunctionDeclaration +public struct FunctionDeclaration { + let name: String + + let description: String + + let parameters: Schema + + let function: ((JSONObject) async throws -> JSONObject)? + + public init(name: String, description: String, parameters: Schema, + function: ((JSONObject) async throws -> JSONObject)?) { + self.name = name + self.description = description + self.parameters = parameters + self.function = function + } +} + +// REST Docs: https://ai.google.dev/api/rest/v1beta/Tool +public struct Tool: Encodable { + let functionDeclarations: [FunctionDeclaration]? + + public init(functionDeclarations: [FunctionDeclaration]?) { + self.functionDeclarations = functionDeclarations + } +} + +// REST Docs: https://ai.google.dev/api/rest/v1beta/Content#functionresponse +public struct FunctionResponse: Equatable, Encodable { + let name: String + + let response: JSONObject +} + +// MARK: - Codable Conformance + extension FunctionCall: Decodable { enum CodingKeys: CodingKey { case name @@ -39,3 +122,18 @@ extension FunctionCall: Decodable { } } } + +extension FunctionDeclaration: Encodable { + enum CodingKeys: String, CodingKey { + case name + case description + case parameters + } + + public func encode(to encoder: Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + try container.encode(name, forKey: .name) + try container.encode(description, forKey: .description) + try container.encode(parameters, forKey: .parameters) + } +} diff --git a/Sources/GoogleAI/GenerateContentRequest.swift b/Sources/GoogleAI/GenerateContentRequest.swift index 417260b..535ac4d 100644 --- a/Sources/GoogleAI/GenerateContentRequest.swift +++ b/Sources/GoogleAI/GenerateContentRequest.swift @@ -21,6 +21,7 @@ struct GenerateContentRequest { let contents: [ModelContent] let generationConfig: GenerationConfig? let safetySettings: [SafetySetting]? + let tools: [Tool]? let isStreaming: Bool let options: RequestOptions } @@ -31,6 +32,7 @@ extension GenerateContentRequest: Encodable { case contents case generationConfig case safetySettings + case tools } } diff --git a/Sources/GoogleAI/GenerativeModel.swift b/Sources/GoogleAI/GenerativeModel.swift index 03e0191..39fdbb8 100644 --- a/Sources/GoogleAI/GenerativeModel.swift +++ b/Sources/GoogleAI/GenerativeModel.swift @@ -36,6 +36,8 @@ public final class GenerativeModel { /// The safety settings to be used for prompts. let safetySettings: [SafetySetting]? + let tools: [Tool]? + /// Configuration parameters for sending requests to the backend. let requestOptions: RequestOptions @@ -52,12 +54,14 @@ public final class GenerativeModel { apiKey: String, generationConfig: GenerationConfig? = nil, safetySettings: [SafetySetting]? = nil, + tools: [Tool]? = nil, requestOptions: RequestOptions = RequestOptions()) { self.init( name: name, apiKey: apiKey, generationConfig: generationConfig, safetySettings: safetySettings, + tools: tools, requestOptions: requestOptions, urlSession: .shared ) @@ -68,12 +72,14 @@ public final class GenerativeModel { apiKey: String, generationConfig: GenerationConfig? = nil, safetySettings: [SafetySetting]? = nil, + tools: [Tool]? = nil, requestOptions: RequestOptions = RequestOptions(), urlSession: URLSession) { modelResourceName = GenerativeModel.modelResourceName(name: name) generativeAIService = GenerativeAIService(apiKey: apiKey, urlSession: urlSession) self.generationConfig = generationConfig self.safetySettings = safetySettings + self.tools = tools self.requestOptions = requestOptions Logging.default.info(""" @@ -119,6 +125,7 @@ public final class GenerativeModel { contents: content(), generationConfig: generationConfig, safetySettings: safetySettings, + tools: tools, isStreaming: false, options: requestOptions) response = try await generativeAIService.loadRequest(request: generateContentRequest) @@ -190,6 +197,7 @@ public final class GenerativeModel { contents: evaluatedContent, generationConfig: generationConfig, safetySettings: safetySettings, + tools: tools, isStreaming: true, options: requestOptions) @@ -270,6 +278,30 @@ public final class GenerativeModel { } } + func executeFunction(functionCall: FunctionCall) async throws -> FunctionResponse { + guard let tools = tools else { + throw GenerateContentError.internalError(underlying: FunctionCallError()) + } + guard let tool = tools.first(where: { tool in + tool.functionDeclarations != nil + }) else { + throw GenerateContentError.internalError(underlying: FunctionCallError()) + } + guard let functionDeclaration = tool.functionDeclarations?.first(where: { functionDeclaration in + functionDeclaration.name == functionCall.name + }) else { + throw GenerateContentError.internalError(underlying: FunctionCallError()) + } + guard let function = functionDeclaration.function else { + throw GenerateContentError.internalError(underlying: FunctionCallError()) + } + + return try FunctionResponse( + name: functionCall.name, + response: await function(functionCall.args) + ) + } + /// Returns a model resource name of the form "models/model-name" based on `name`. private static func modelResourceName(name: String) -> String { if name.contains("/") { @@ -299,3 +331,5 @@ public final class GenerativeModel { public enum CountTokensError: Error { case internalError(underlying: Error) } + +struct FunctionCallError: Error {} diff --git a/Sources/GoogleAI/ModelContent.swift b/Sources/GoogleAI/ModelContent.swift index 2ce8876..4aefe7b 100644 --- a/Sources/GoogleAI/ModelContent.swift +++ b/Sources/GoogleAI/ModelContent.swift @@ -26,6 +26,7 @@ public struct ModelContent: Codable, Equatable { case text case inlineData case functionCall + case functionResponse } enum InlineDataKeys: String, CodingKey { @@ -42,6 +43,8 @@ public struct ModelContent: Codable, Equatable { /// A predicted function call returned from the model. case functionCall(FunctionCall) + case functionResponse(FunctionResponse) + // MARK: Convenience Initializers /// Convenience function for populating a Part with JPEG data. @@ -68,9 +71,10 @@ public struct ModelContent: Codable, Equatable { ) try inlineDataContainer.encode(mimetype, forKey: .mimeType) try inlineDataContainer.encode(bytes, forKey: .bytes) - case .functionCall: - // TODO(andrewheard): Encode FunctionCalls when when encoding is implemented. - fatalError("FunctionCall encoding not implemented.") + case let .functionCall(functionCall): + try container.encode(functionCall, forKey: .functionCall) + case let .functionResponse(functionResponse): + try container.encode(functionResponse, forKey: .functionResponse) } }