From 527d4af572d162b52b1d833ed1938e1dcd8cf1f6 Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Tue, 7 Jan 2025 16:55:20 -0500 Subject: [PATCH] [Vertex AI] Add `ImagenModelConfig` for model-level config params (#14315) --- .../Imagen/ImagenGenerationConfig.swift | 11 +--- .../Types/Public/Imagen/ImagenModel.swift | 11 +++- .../Public/Imagen/ImagenModelConfig.swift | 24 ++++++++ FirebaseVertexAI/Sources/VertexAI.swift | 4 +- .../Tests/Integration/IntegrationTests.swift | 5 +- .../ImageGenerationParametersTests.swift | 60 +++++++++++++------ 6 files changed, 82 insertions(+), 33 deletions(-) create mode 100644 FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenModelConfig.swift diff --git a/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenGenerationConfig.swift b/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenGenerationConfig.swift index f15838b948e..ad26b9d5538 100644 --- a/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenGenerationConfig.swift +++ b/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenGenerationConfig.swift @@ -17,18 +17,11 @@ public struct ImagenGenerationConfig { public var numberOfImages: Int? public var negativePrompt: String? public var aspectRatio: ImagenAspectRatio? - public var imageFormat: ImagenImageFormat? - public var addWatermark: Bool? - public init(numberOfImages: Int? = nil, - negativePrompt: String? = nil, - aspectRatio: ImagenAspectRatio? = nil, - imageFormat: ImagenImageFormat? = nil, - addWatermark: Bool? = nil) { + public init(numberOfImages: Int? = nil, negativePrompt: String? = nil, + aspectRatio: ImagenAspectRatio? = nil) { self.numberOfImages = numberOfImages self.negativePrompt = negativePrompt self.aspectRatio = aspectRatio - self.imageFormat = imageFormat - self.addWatermark = addWatermark } } diff --git a/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenModel.swift b/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenModel.swift index afd7d370448..5578afe110f 100644 --- a/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenModel.swift +++ b/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenModel.swift @@ -24,6 +24,8 @@ public final class ImagenModel { /// The backing service responsible for sending and receiving model requests to the backend. let generativeAIService: GenerativeAIService + let modelConfig: ImagenModelConfig? + let safetySettings: ImagenSafetySettings? /// Configuration parameters for sending requests to the backend. @@ -32,6 +34,7 @@ public final class ImagenModel { init(name: String, projectID: String, apiKey: String, + modelConfig: ImagenModelConfig?, safetySettings: ImagenSafetySettings?, requestOptions: RequestOptions, appCheck: AppCheckInterop?, @@ -45,6 +48,7 @@ public final class ImagenModel { auth: auth, urlSession: urlSession ) + self.modelConfig = modelConfig self.safetySettings = safetySettings self.requestOptions = requestOptions } @@ -57,6 +61,7 @@ public final class ImagenModel { parameters: ImagenModel.imageGenerationParameters( storageURI: nil, generationConfig: generationConfig, + modelConfig: modelConfig, safetySettings: safetySettings ) ) @@ -70,6 +75,7 @@ public final class ImagenModel { parameters: ImagenModel.imageGenerationParameters( storageURI: storageURI, generationConfig: generationConfig, + modelConfig: modelConfig, safetySettings: safetySettings ) ) @@ -90,6 +96,7 @@ public final class ImagenModel { static func imageGenerationParameters(storageURI: String?, generationConfig: ImagenGenerationConfig?, + modelConfig: ImagenModelConfig?, safetySettings: ImagenSafetySettings?) -> ImageGenerationParameters { return ImageGenerationParameters( @@ -99,13 +106,13 @@ public final class ImagenModel { aspectRatio: generationConfig?.aspectRatio?.rawValue, safetyFilterLevel: safetySettings?.safetyFilterLevel?.rawValue, personGeneration: safetySettings?.personFilterLevel?.rawValue, - outputOptions: generationConfig?.imageFormat.map { + outputOptions: modelConfig?.imageFormat.map { ImageGenerationOutputOptions( mimeType: $0.mimeType, compressionQuality: $0.compressionQuality ) }, - addWatermark: generationConfig?.addWatermark, + addWatermark: modelConfig?.addWatermark, includeResponsibleAIFilterReason: true ) } diff --git a/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenModelConfig.swift b/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenModelConfig.swift new file mode 100644 index 00000000000..36bcad5940e --- /dev/null +++ b/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenModelConfig.swift @@ -0,0 +1,24 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +public struct ImagenModelConfig { + let imageFormat: ImagenImageFormat? + let addWatermark: Bool? + + public init(imageFormat: ImagenImageFormat? = nil, addWatermark: Bool? = nil) { + self.imageFormat = imageFormat + self.addWatermark = addWatermark + } +} diff --git a/FirebaseVertexAI/Sources/VertexAI.swift b/FirebaseVertexAI/Sources/VertexAI.swift index a2d3c62f529..6d623858c97 100644 --- a/FirebaseVertexAI/Sources/VertexAI.swift +++ b/FirebaseVertexAI/Sources/VertexAI.swift @@ -104,12 +104,14 @@ public class VertexAI { ) } - public func imagenModel(modelName: String, safetySettings: ImagenSafetySettings? = nil, + public func imagenModel(modelName: String, modelConfig: ImagenModelConfig? = nil, + safetySettings: ImagenSafetySettings? = nil, requestOptions: RequestOptions = RequestOptions()) -> ImagenModel { return ImagenModel( name: modelResourceName(modelName: modelName), projectID: projectID, apiKey: apiKey, + modelConfig: modelConfig, safetySettings: safetySettings, requestOptions: requestOptions, appCheck: appCheck, diff --git a/FirebaseVertexAI/Tests/TestApp/Tests/Integration/IntegrationTests.swift b/FirebaseVertexAI/Tests/TestApp/Tests/Integration/IntegrationTests.swift index 69ede59d802..4f96cae0777 100644 --- a/FirebaseVertexAI/Tests/TestApp/Tests/Integration/IntegrationTests.swift +++ b/FirebaseVertexAI/Tests/TestApp/Tests/Integration/IntegrationTests.swift @@ -63,6 +63,7 @@ final class IntegrationTests: XCTestCase { ) imagenModel = vertex.imagenModel( modelName: "imagen-3.0-fast-generate-001", + modelConfig: ImagenModelConfig(imageFormat: .jpeg(compressionQuality: 70)), safetySettings: ImagenSafetySettings( safetyFilterLevel: .blockLowAndAbove, personFilterLevel: .blockAll @@ -253,9 +254,7 @@ final class IntegrationTests: XCTestCase { overlooking a vast African savanna at sunset. Golden hour light, long shadows, sharp focus on the lion, shallow depth of field, detailed fur texture, DSLR, 85mm lens. """ - var generationConfig = ImagenGenerationConfig() - generationConfig.aspectRatio = .landscape16x9 - generationConfig.imageFormat = .jpeg(compressionQuality: 70) + let generationConfig = ImagenGenerationConfig(aspectRatio: .landscape16x9) let response = try await imagenModel.generateImages( prompt: imagePrompt, diff --git a/FirebaseVertexAI/Tests/Unit/Types/Imagen/ImageGenerationParametersTests.swift b/FirebaseVertexAI/Tests/Unit/Types/Imagen/ImageGenerationParametersTests.swift index 6b313afcd8c..ceb4132d156 100644 --- a/FirebaseVertexAI/Tests/Unit/Types/Imagen/ImageGenerationParametersTests.swift +++ b/FirebaseVertexAI/Tests/Unit/Types/Imagen/ImageGenerationParametersTests.swift @@ -40,6 +40,7 @@ final class ImageGenerationParametersTests: XCTestCase { let parameters = ImagenModel.imageGenerationParameters( storageURI: nil, generationConfig: nil, + modelConfig: nil, safetySettings: nil ) @@ -63,6 +64,37 @@ final class ImageGenerationParametersTests: XCTestCase { let parameters = ImagenModel.imageGenerationParameters( storageURI: storageURI, generationConfig: nil, + modelConfig: nil, + safetySettings: nil + ) + + XCTAssertEqual(parameters, expectedParameters) + } + + func testParameters_includeModelConfig() throws { + let compressionQuality = 80 + let imageFormat = ImagenImageFormat.jpeg(compressionQuality: compressionQuality) + let addWatermark = true + let modelConfig = ImagenModelConfig(imageFormat: imageFormat, addWatermark: addWatermark) + let expectedParameters = ImageGenerationParameters( + sampleCount: 1, + storageURI: nil, + negativePrompt: nil, + aspectRatio: nil, + safetyFilterLevel: nil, + personGeneration: nil, + outputOptions: ImageGenerationOutputOptions( + mimeType: imageFormat.mimeType, + compressionQuality: imageFormat.compressionQuality + ), + addWatermark: addWatermark, + includeResponsibleAIFilterReason: true + ) + + let parameters = ImagenModel.imageGenerationParameters( + storageURI: nil, + generationConfig: nil, + modelConfig: modelConfig, safetySettings: nil ) @@ -73,15 +105,10 @@ final class ImageGenerationParametersTests: XCTestCase { let sampleCount = 2 let negativePrompt = "test-negative-prompt" let aspectRatio = ImagenAspectRatio.landscape16x9 - let compressionQuality = 80 - let imageFormat = ImagenImageFormat.jpeg(compressionQuality: compressionQuality) - let addWatermark = true let generationConfig = ImagenGenerationConfig( numberOfImages: sampleCount, negativePrompt: negativePrompt, - aspectRatio: aspectRatio, - imageFormat: imageFormat, - addWatermark: addWatermark + aspectRatio: aspectRatio ) let expectedParameters = ImageGenerationParameters( sampleCount: sampleCount, @@ -90,24 +117,20 @@ final class ImageGenerationParametersTests: XCTestCase { aspectRatio: aspectRatio.rawValue, safetyFilterLevel: nil, personGeneration: nil, - outputOptions: ImageGenerationOutputOptions( - mimeType: imageFormat.mimeType, - compressionQuality: imageFormat.compressionQuality - ), - addWatermark: addWatermark, + outputOptions: nil, + addWatermark: nil, includeResponsibleAIFilterReason: true ) let parameters = ImagenModel.imageGenerationParameters( storageURI: nil, generationConfig: generationConfig, + modelConfig: nil, safetySettings: nil ) XCTAssertEqual(parameters, expectedParameters) XCTAssertEqual(parameters.aspectRatio, "16:9") - XCTAssertEqual(parameters.outputOptions?.mimeType, "image/jpeg") - XCTAssertEqual(parameters.outputOptions?.compressionQuality, compressionQuality) } func testDefaultParameters_includeSafetySettings() throws { @@ -132,6 +155,7 @@ final class ImageGenerationParametersTests: XCTestCase { let parameters = ImagenModel.imageGenerationParameters( storageURI: nil, generationConfig: nil, + modelConfig: nil, safetySettings: safetySettings ) @@ -145,15 +169,14 @@ final class ImageGenerationParametersTests: XCTestCase { let sampleCount = 4 let negativePrompt = "test-negative-prompt" let aspectRatio = ImagenAspectRatio.portrait3x4 - let imageFormat = ImagenImageFormat.png() - let addWatermark = false let generationConfig = ImagenGenerationConfig( numberOfImages: sampleCount, negativePrompt: negativePrompt, - aspectRatio: aspectRatio, - imageFormat: imageFormat, - addWatermark: addWatermark + aspectRatio: aspectRatio ) + let imageFormat = ImagenImageFormat.png() + let addWatermark = false + let modelConfig = ImagenModelConfig(imageFormat: imageFormat, addWatermark: addWatermark) let safetyFilterLevel = ImagenSafetyFilterLevel.blockNone let personFilterLevel = ImagenPersonFilterLevel.blockAll let safetySettings = ImagenSafetySettings( @@ -178,6 +201,7 @@ final class ImageGenerationParametersTests: XCTestCase { let parameters = ImagenModel.imageGenerationParameters( storageURI: storageURI, generationConfig: generationConfig, + modelConfig: modelConfig, safetySettings: safetySettings )