Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Forced function calling #124

Merged
merged 6 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 121 additions & 14 deletions Examples/GenerativeAICLI/Sources/GenerateContent.swift
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ struct GenerateContent: AsyncParsableCommand {
help: "Enable additional debug logging."
) var debugLogEnabled = false

// Function calls pending processing
var functionCalls = [FunctionCall]()

// Input to the model
var input = [ModelContent]()

mutating func validate() throws {
if textPrompt == nil && imageURL == nil {
throw ValidationError(
Expand All @@ -70,7 +76,33 @@ struct GenerateContent: AsyncParsableCommand {
name: modelNameOrDefault(),
apiKey: apiKey,
generationConfig: config,
safetySettings: safetySettings
safetySettings: safetySettings,
tools: [Tool(functionDeclarations: [
FunctionDeclaration(
name: "get_exchange_rate",
description: "Get the exchange rate for currencies between countries",
parameters: [
"currency_from": Schema(
type: .string,
format: "enum",
description: "The currency to convert from in ISO 4217 format",
enumValues: ["USD", "EUR", "JPY", "GBP", "AUD", "CAD"]
),
"currency_to": Schema(
type: .string,
format: "enum",
description: "The currency to convert to in ISO 4217 format",
enumValues: ["USD", "EUR", "JPY", "GBP", "AUD", "CAD"]
),
],
requiredParameters: ["currency_from", "currency_to"]
),
])],
toolConfig: .init(functionCallingConfig: .init(
mode: .any,
allowedFunctionNames: ["get_exchange_rate"]
)),
requestOptions: RequestOptions(apiVersion: "v1beta")
)

var parts = [ModelContent.Part]()
Expand All @@ -93,27 +125,71 @@ struct GenerateContent: AsyncParsableCommand {
parts.append(.data(mimetype: mimeType, imageData))
}

let input = [ModelContent(parts: parts)]
input = [ModelContent(parts: parts)]

repeat {
try await processFunctionCalls()

if isStreaming {
let contentStream = model.generateContentStream(input)
print("Generated Content <streaming>:")
for try await content in contentStream {
if let text = content.text {
print(text)
if isStreaming {
let contentStream = model.generateContentStream(input)
print("Generated Content <streaming>:")
for try await content in contentStream {
processResponseContent(content: content)
}
} else {
// Unary generate content
let content = try await model.generateContent(input)
print("Generated Content:")
processResponseContent(content: content)
}
} else {
let content = try await model.generateContent(input)
if let text = content.text {
print("Generated Content:\n\(text)")
}
}
} while !functionCalls.isEmpty
} catch {
print("Generate Content Error: \(error)")
}
}

mutating func processResponseContent(content: GenerateContentResponse) {
guard let candidate = content.candidates.first else {
fatalError("No candidate.")
}

for part in candidate.content.parts {
switch part {
case let .text(text):
print(text)
case .data:
fatalError("Inline data not supported.")
case let .functionCall(functionCall):
functionCalls.append(functionCall)
case let .functionResponse(functionResponse):
print("FunctionResponse: \(functionResponse)")
}
}
}

mutating func processFunctionCalls() async throws {
for functionCall in functionCalls {
input.append(ModelContent(
role: "model",
parts: [ModelContent.Part.functionCall(functionCall)]
))
switch functionCall.name {
case "get_exchange_rate":
let exchangeRates = getExchangeRate(args: functionCall.args)
input.append(ModelContent(
role: "function",
parts: [ModelContent.Part.functionResponse(FunctionResponse(
name: "get_exchange_rate",
response: exchangeRates
))]
))
default:
fatalError("Unknown function named \"\(functionCall.name)\".")
}
}
functionCalls = []
}

func modelNameOrDefault() -> String {
if let modelName = modelName {
return modelName
Expand All @@ -123,6 +199,37 @@ struct GenerateContent: AsyncParsableCommand {
return "gemini-1.0-pro"
}
}

// MARK: - Callable Functions

func getExchangeRate(args: JSONObject) -> JSONObject {
// 1. Validate and extract the parameters provided by the model (from a `FunctionCall`)
guard case let .string(from) = args["currency_from"] else {
fatalError("Missing `currency_from` parameter.")
}
guard case let .string(to) = args["currency_to"] else {
fatalError("Missing `currency_to` parameter.")
}

// 2. Get the exchange rate
let allRates: [String: [String: Double]] = [
"AUD": ["CAD": 0.89265, "EUR": 0.6072, "GBP": 0.51714, "JPY": 97.75, "USD": 0.66379],
"CAD": ["AUD": 1.1203, "EUR": 0.68023, "GBP": 0.57933, "JPY": 109.51, "USD": 0.74362],
"EUR": ["AUD": 1.6469, "CAD": 1.4701, "GBP": 0.85168, "JPY": 160.99, "USD": 1.0932],
"GBP": ["AUD": 1.9337, "CAD": 1.7261, "EUR": 1.1741, "JPY": 189.03, "USD": 1.2836],
"JPY": ["AUD": 0.01023, "CAD": 0.00913, "EUR": 0.00621, "GBP": 0.00529, "USD": 0.00679],
"USD": ["AUD": 1.5065, "CAD": 1.3448, "EUR": 0.91475, "GBP": 0.77907, "JPY": 147.26],
]
guard let fromRates = allRates[from] else {
return ["error": .string("No data for currency \(from).")]
}
guard let toRate = fromRates[to] else {
return ["error": .string("No data for currency \(to).")]
}

// 3. Return the exchange rates as a JSON object (returned to the model in a `FunctionResponse`)
return ["rates": .number(toRate)]
}
}

enum CLIError: Error {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

import SwiftUI

public struct InputField<Label>: View where Label: View {
@Binding
private var text: String
Expand Down
2 changes: 1 addition & 1 deletion Mintfile
Original file line number Diff line number Diff line change
@@ -1 +1 @@
nicklockwood/SwiftFormat@0.52.10
nicklockwood/SwiftFormat@0.53.5
45 changes: 45 additions & 0 deletions Sources/GoogleAI/FunctionCalling.swift
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,51 @@ public struct Tool: Encodable {
}
}

/// Defines the execution behavior for function calling by defining the
/// execution mode.
public enum FunctionCallingMode: String, Encodable {
/// The default behavior for function calling. The model calls functions to answer queries at its
/// discretion.
case auto = "AUTO"

/// The model always predicts a provided function call to answer every query.
case any = "ANY"

/// The model will never predict a function call to answer a query. This can also be achieved by
/// not passing any tools to the model.
case none = "NONE"
}

/// Configuration for specifying function calling behavior.
public struct FunctionCallingConfig: Encodable {
/// Specifies the mode in which function calling should execute. If
/// unspecified, the default value will be set to AUTO.
let mode: FunctionCallingMode?

/// A set of function names that, when provided, limits the functions the model
/// will call.
///
/// This should only be set when the Mode is ANY. Function names
/// should match [FunctionDeclaration.name]. With mode set to ANY, model will
/// predict a function call from the set of function names provided.
let allowedFunctionNames: [String]?

public init(mode: FunctionCallingMode? = nil, allowedFunctionNames: [String]? = nil) {
self.mode = mode
self.allowedFunctionNames = allowedFunctionNames
}
}

/// Tool configuration for any `Tool` specified in the request.
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
public struct ToolConfig: Encodable {
let functionCallingConfig: FunctionCallingConfig?

public init(functionCallingConfig: FunctionCallingConfig? = nil) {
self.functionCallingConfig = functionCallingConfig
}
}

/// Result output from a ``FunctionCall``.
///
/// Contains a string representing the `FunctionDeclaration.name` and a structured JSON object
Expand Down
2 changes: 2 additions & 0 deletions Sources/GoogleAI/GenerateContentRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ struct GenerateContentRequest {
let generationConfig: GenerationConfig?
let safetySettings: [SafetySetting]?
let tools: [Tool]?
let toolConfig: ToolConfig?
let isStreaming: Bool
let options: RequestOptions
}
Expand All @@ -33,6 +34,7 @@ extension GenerateContentRequest: Encodable {
case generationConfig
case safetySettings
case tools
case toolConfig
}
}

Expand Down
10 changes: 10 additions & 0 deletions Sources/GoogleAI/GenerativeModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ public final class GenerativeModel {
/// A list of tools the model may use to generate the next response.
let tools: [Tool]?

// Tool configuration for any `Tool` specified in the request.
let toolConfig: ToolConfig?

/// Configuration parameters for sending requests to the backend.
let requestOptions: RequestOptions

Expand All @@ -48,19 +51,22 @@ public final class GenerativeModel {
/// - generationConfig: The content generation parameters your model should use.
/// - safetySettings: A value describing what types of harmful content your model should allow.
/// - tools: A list of ``Tool`` objects that the model may use to generate the next response.
/// - toolConfig: Tool configuration for any `Tool` specified in the request.
/// - requestOptions Configuration parameters for sending requests to the backend.
public convenience init(name: String,
apiKey: String,
generationConfig: GenerationConfig? = nil,
safetySettings: [SafetySetting]? = nil,
tools: [Tool]? = nil,
toolConfig: ToolConfig? = nil,
requestOptions: RequestOptions = RequestOptions()) {
self.init(
name: name,
apiKey: apiKey,
generationConfig: generationConfig,
safetySettings: safetySettings,
tools: tools,
toolConfig: toolConfig,
requestOptions: requestOptions,
urlSession: .shared
)
Expand All @@ -72,13 +78,15 @@ public final class GenerativeModel {
generationConfig: GenerationConfig? = nil,
safetySettings: [SafetySetting]? = nil,
tools: [Tool]? = nil,
toolConfig: ToolConfig? = nil,
requestOptions: RequestOptions = RequestOptions(),
urlSession: URLSession) {
modelResourceName = GenerativeModel.modelResourceName(name: name)
generativeAIService = GenerativeAIService(apiKey: apiKey, urlSession: urlSession)
self.generationConfig = generationConfig
self.safetySettings = safetySettings
self.tools = tools
self.toolConfig = toolConfig
self.requestOptions = requestOptions

Logging.default.info("""
Expand Down Expand Up @@ -125,6 +133,7 @@ public final class GenerativeModel {
generationConfig: generationConfig,
safetySettings: safetySettings,
tools: tools,
toolConfig: toolConfig,
isStreaming: false,
options: requestOptions)
response = try await generativeAIService.loadRequest(request: generateContentRequest)
Expand Down Expand Up @@ -197,6 +206,7 @@ public final class GenerativeModel {
generationConfig: generationConfig,
safetySettings: safetySettings,
tools: tools,
toolConfig: toolConfig,
isStreaming: true,
options: requestOptions)

Expand Down
Loading