Skip to content

Commit

Permalink
Add request options to Vertex AI
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard committed Mar 15, 2024
1 parent ee75091 commit de4fb30
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 14 deletions.
2 changes: 1 addition & 1 deletion FirebaseVertexAI/Sources/GenerativeAIRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ protocol GenerativeAIRequest: Encodable {

/// Configuration parameters for sending requests to the backend.
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
public struct RequestOptions {
@objc public class RequestOptions: NSObject {
/// The request’s timeout interval in seconds; if not specified uses the default value for a
/// `URLRequest`.
let timeout: TimeInterval?
Expand Down
25 changes: 18 additions & 7 deletions FirebaseVertexAI/Sources/VertexAI.swift
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ open class VertexAI: NSObject {
/// Returns an instance of `GoogleGenerativeAI.GenerativeModel` that uses the Vertex AI API.
///
/// This instance is configured with the default `FirebaseApp`.
public static func generativeModel(modelName: String, location: String) -> GenerativeModel {
public static func generativeModel(modelName: String, location: String,
requestOptions: RequestOptions = RequestOptions())
-> GenerativeModel {
guard let app = FirebaseApp.app() else {
fatalError("No instance of the default Firebase app was found.")
}
Expand All @@ -36,28 +38,37 @@ open class VertexAI: NSObject {

/// Returns an instance of `GoogleGenerativeAI.GenerativeModel` that uses the Vertex AI API.
public static func generativeModel(app: FirebaseApp, modelName: String,
location: String) -> GenerativeModel {
location: String,
requestOptions: RequestOptions = RequestOptions())
-> GenerativeModel {
guard let provider = ComponentType<VertexAIProvider>.instance(for: VertexAIProvider.self,
in: app.container) else {
fatalError("No \(VertexAIProvider.self) instance found for Firebase app: \(app.name)")
}
let modelResourceName = modelResourceName(app: app, modelName: modelName, location: location)
let vertexAI = provider.vertexAI(location: location, modelResourceName: modelResourceName)
let vertexAI = provider.vertexAI(
for: app,
location: location,
modelResourceName: modelResourceName,
requestOptions: requestOptions
)

return vertexAI.model
}

// MARK: - Internal

let location: String

let modelResouceName: String

// MARK: - Private

/// The `FirebaseApp` associated with this `VertexAI` instance.
private let app: FirebaseApp

private let appCheck: AppCheckInterop?

private let location: String

private let modelResouceName: String

lazy var model: GenerativeModel = {
guard let apiKey = app.options.apiKey else {
fatalError("The Firebase app named \"\(app.name)\" has no API key in its configuration.")
Expand Down
20 changes: 14 additions & 6 deletions FirebaseVertexAI/Sources/VertexAIComponent.swift
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ import Foundation
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
@objc(FIRVertexAIProvider)
protocol VertexAIProvider {
@objc func vertexAI(location: String, modelResourceName: String) -> VertexAI
@objc func vertexAI(for app: FirebaseApp, location: String, modelResourceName: String,
requestOptions: RequestOptions) -> VertexAI
}

@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
Expand All @@ -36,7 +37,7 @@ class VertexAIComponent: NSObject, Library, VertexAIProvider {

/// A map of active `VertexAI` instances for `app`, keyed by model resource names
/// (e.g., "projects/my-project-id/locations/us-central1/publishers/google/models/gemini-pro").
private var instances: [String: VertexAI] = [:]
private var instances: [String: [VertexAI]] = [:]

/// Lock to manage access to the `instances` array to avoid race conditions.
private var instancesLock: os_unfair_lock = .init()
Expand Down Expand Up @@ -64,17 +65,24 @@ class VertexAIComponent: NSObject, Library, VertexAIProvider {

// MARK: - VertexAIProvider conformance

func vertexAI(location: String, modelResourceName: String) -> VertexAI {
func vertexAI(for app: FirebaseApp, location: String, modelResourceName: String,
requestOptions: RequestOptions) -> VertexAI {
os_unfair_lock_lock(&instancesLock)

// Unlock before the function returns.
defer { os_unfair_lock_unlock(&instancesLock) }

if let instance = instances[modelResourceName] {
return instance
if let associatedInstances = instances[app.name] {
for instance in associatedInstances {
if instance.location == location && instance.modelResouceName == modelResourceName {
return instance
}
}
}

let newInstance = VertexAI(app: app, location: location, modelResourceName: modelResourceName)
instances[modelResourceName] = newInstance
let existingInstances = instances[app.name, default: []]
instances[app.name] = existingInstances + [newInstance]
return newInstance
}
}

0 comments on commit de4fb30

Please sign in to comment.