From de4fb30db076f0e6fd45fd3175091aebd7aeb926 Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Fri, 15 Mar 2024 17:00:16 -0400 Subject: [PATCH] Add request options to Vertex AI --- .../Sources/GenerativeAIRequest.swift | 2 +- FirebaseVertexAI/Sources/VertexAI.swift | 25 +++++++++++++------ .../Sources/VertexAIComponent.swift | 20 ++++++++++----- 3 files changed, 33 insertions(+), 14 deletions(-) diff --git a/FirebaseVertexAI/Sources/GenerativeAIRequest.swift b/FirebaseVertexAI/Sources/GenerativeAIRequest.swift index 21f35b1b728..8acdac94b79 100644 --- a/FirebaseVertexAI/Sources/GenerativeAIRequest.swift +++ b/FirebaseVertexAI/Sources/GenerativeAIRequest.swift @@ -25,7 +25,7 @@ protocol GenerativeAIRequest: Encodable { /// Configuration parameters for sending requests to the backend. @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) -public struct RequestOptions { +@objc public class RequestOptions: NSObject { /// The request’s timeout interval in seconds; if not specified uses the default value for a /// `URLRequest`. let timeout: TimeInterval? diff --git a/FirebaseVertexAI/Sources/VertexAI.swift b/FirebaseVertexAI/Sources/VertexAI.swift index 16fc0d20754..cb1b46cbb54 100644 --- a/FirebaseVertexAI/Sources/VertexAI.swift +++ b/FirebaseVertexAI/Sources/VertexAI.swift @@ -27,7 +27,9 @@ open class VertexAI: NSObject { /// Returns an instance of `GoogleGenerativeAI.GenerativeModel` that uses the Vertex AI API. /// /// This instance is configured with the default `FirebaseApp`. - public static func generativeModel(modelName: String, location: String) -> GenerativeModel { + public static func generativeModel(modelName: String, location: String, + requestOptions: RequestOptions = RequestOptions()) + -> GenerativeModel { guard let app = FirebaseApp.app() else { fatalError("No instance of the default Firebase app was found.") } @@ -36,17 +38,30 @@ open class VertexAI: NSObject { /// Returns an instance of `GoogleGenerativeAI.GenerativeModel` that uses the Vertex AI API. public static func generativeModel(app: FirebaseApp, modelName: String, - location: String) -> GenerativeModel { + location: String, + requestOptions: RequestOptions = RequestOptions()) + -> GenerativeModel { guard let provider = ComponentType.instance(for: VertexAIProvider.self, in: app.container) else { fatalError("No \(VertexAIProvider.self) instance found for Firebase app: \(app.name)") } let modelResourceName = modelResourceName(app: app, modelName: modelName, location: location) - let vertexAI = provider.vertexAI(location: location, modelResourceName: modelResourceName) + let vertexAI = provider.vertexAI( + for: app, + location: location, + modelResourceName: modelResourceName, + requestOptions: requestOptions + ) return vertexAI.model } + // MARK: - Internal + + let location: String + + let modelResouceName: String + // MARK: - Private /// The `FirebaseApp` associated with this `VertexAI` instance. @@ -54,10 +69,6 @@ open class VertexAI: NSObject { private let appCheck: AppCheckInterop? - private let location: String - - private let modelResouceName: String - lazy var model: GenerativeModel = { guard let apiKey = app.options.apiKey else { fatalError("The Firebase app named \"\(app.name)\" has no API key in its configuration.") diff --git a/FirebaseVertexAI/Sources/VertexAIComponent.swift b/FirebaseVertexAI/Sources/VertexAIComponent.swift index a8d6c177c74..b53df128dff 100644 --- a/FirebaseVertexAI/Sources/VertexAIComponent.swift +++ b/FirebaseVertexAI/Sources/VertexAIComponent.swift @@ -22,7 +22,8 @@ import Foundation @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) @objc(FIRVertexAIProvider) protocol VertexAIProvider { - @objc func vertexAI(location: String, modelResourceName: String) -> VertexAI + @objc func vertexAI(for app: FirebaseApp, location: String, modelResourceName: String, + requestOptions: RequestOptions) -> VertexAI } @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) @@ -36,7 +37,7 @@ class VertexAIComponent: NSObject, Library, VertexAIProvider { /// A map of active `VertexAI` instances for `app`, keyed by model resource names /// (e.g., "projects/my-project-id/locations/us-central1/publishers/google/models/gemini-pro"). - private var instances: [String: VertexAI] = [:] + private var instances: [String: [VertexAI]] = [:] /// Lock to manage access to the `instances` array to avoid race conditions. private var instancesLock: os_unfair_lock = .init() @@ -64,17 +65,24 @@ class VertexAIComponent: NSObject, Library, VertexAIProvider { // MARK: - VertexAIProvider conformance - func vertexAI(location: String, modelResourceName: String) -> VertexAI { + func vertexAI(for app: FirebaseApp, location: String, modelResourceName: String, + requestOptions: RequestOptions) -> VertexAI { os_unfair_lock_lock(&instancesLock) // Unlock before the function returns. defer { os_unfair_lock_unlock(&instancesLock) } - if let instance = instances[modelResourceName] { - return instance + if let associatedInstances = instances[app.name] { + for instance in associatedInstances { + if instance.location == location && instance.modelResouceName == modelResourceName { + return instance + } + } } + let newInstance = VertexAI(app: app, location: location, modelResourceName: modelResourceName) - instances[modelResourceName] = newInstance + let existingInstances = instances[app.name, default: []] + instances[app.name] = existingInstances + [newInstance] return newInstance } }