Skip to content

Commit

Permalink
[Vertex AI] Add ImagenModelConfig for model-level config params (#1…
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard authored Jan 7, 2025
1 parent 53d43d8 commit 527d4af
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,11 @@ public struct ImagenGenerationConfig {
public var numberOfImages: Int?
public var negativePrompt: String?
public var aspectRatio: ImagenAspectRatio?
public var imageFormat: ImagenImageFormat?
public var addWatermark: Bool?

public init(numberOfImages: Int? = nil,
negativePrompt: String? = nil,
aspectRatio: ImagenAspectRatio? = nil,
imageFormat: ImagenImageFormat? = nil,
addWatermark: Bool? = nil) {
public init(numberOfImages: Int? = nil, negativePrompt: String? = nil,
aspectRatio: ImagenAspectRatio? = nil) {
self.numberOfImages = numberOfImages
self.negativePrompt = negativePrompt
self.aspectRatio = aspectRatio
self.imageFormat = imageFormat
self.addWatermark = addWatermark
}
}
11 changes: 9 additions & 2 deletions FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ public final class ImagenModel {
/// The backing service responsible for sending and receiving model requests to the backend.
let generativeAIService: GenerativeAIService

let modelConfig: ImagenModelConfig?

let safetySettings: ImagenSafetySettings?

/// Configuration parameters for sending requests to the backend.
Expand All @@ -32,6 +34,7 @@ public final class ImagenModel {
init(name: String,
projectID: String,
apiKey: String,
modelConfig: ImagenModelConfig?,
safetySettings: ImagenSafetySettings?,
requestOptions: RequestOptions,
appCheck: AppCheckInterop?,
Expand All @@ -45,6 +48,7 @@ public final class ImagenModel {
auth: auth,
urlSession: urlSession
)
self.modelConfig = modelConfig
self.safetySettings = safetySettings
self.requestOptions = requestOptions
}
Expand All @@ -57,6 +61,7 @@ public final class ImagenModel {
parameters: ImagenModel.imageGenerationParameters(
storageURI: nil,
generationConfig: generationConfig,
modelConfig: modelConfig,
safetySettings: safetySettings
)
)
Expand All @@ -70,6 +75,7 @@ public final class ImagenModel {
parameters: ImagenModel.imageGenerationParameters(
storageURI: storageURI,
generationConfig: generationConfig,
modelConfig: modelConfig,
safetySettings: safetySettings
)
)
Expand All @@ -90,6 +96,7 @@ public final class ImagenModel {

static func imageGenerationParameters(storageURI: String?,
generationConfig: ImagenGenerationConfig?,
modelConfig: ImagenModelConfig?,
safetySettings: ImagenSafetySettings?)
-> ImageGenerationParameters {
return ImageGenerationParameters(
Expand All @@ -99,13 +106,13 @@ public final class ImagenModel {
aspectRatio: generationConfig?.aspectRatio?.rawValue,
safetyFilterLevel: safetySettings?.safetyFilterLevel?.rawValue,
personGeneration: safetySettings?.personFilterLevel?.rawValue,
outputOptions: generationConfig?.imageFormat.map {
outputOptions: modelConfig?.imageFormat.map {
ImageGenerationOutputOptions(
mimeType: $0.mimeType,
compressionQuality: $0.compressionQuality
)
},
addWatermark: generationConfig?.addWatermark,
addWatermark: modelConfig?.addWatermark,
includeResponsibleAIFilterReason: true
)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public struct ImagenModelConfig {
let imageFormat: ImagenImageFormat?
let addWatermark: Bool?

public init(imageFormat: ImagenImageFormat? = nil, addWatermark: Bool? = nil) {
self.imageFormat = imageFormat
self.addWatermark = addWatermark
}
}
4 changes: 3 additions & 1 deletion FirebaseVertexAI/Sources/VertexAI.swift
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,14 @@ public class VertexAI {
)
}

public func imagenModel(modelName: String, safetySettings: ImagenSafetySettings? = nil,
public func imagenModel(modelName: String, modelConfig: ImagenModelConfig? = nil,
safetySettings: ImagenSafetySettings? = nil,
requestOptions: RequestOptions = RequestOptions()) -> ImagenModel {
return ImagenModel(
name: modelResourceName(modelName: modelName),
projectID: projectID,
apiKey: apiKey,
modelConfig: modelConfig,
safetySettings: safetySettings,
requestOptions: requestOptions,
appCheck: appCheck,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ final class IntegrationTests: XCTestCase {
)
imagenModel = vertex.imagenModel(
modelName: "imagen-3.0-fast-generate-001",
modelConfig: ImagenModelConfig(imageFormat: .jpeg(compressionQuality: 70)),
safetySettings: ImagenSafetySettings(
safetyFilterLevel: .blockLowAndAbove,
personFilterLevel: .blockAll
Expand Down Expand Up @@ -253,9 +254,7 @@ final class IntegrationTests: XCTestCase {
overlooking a vast African savanna at sunset. Golden hour light, long shadows, sharp focus on
the lion, shallow depth of field, detailed fur texture, DSLR, 85mm lens.
"""
var generationConfig = ImagenGenerationConfig()
generationConfig.aspectRatio = .landscape16x9
generationConfig.imageFormat = .jpeg(compressionQuality: 70)
let generationConfig = ImagenGenerationConfig(aspectRatio: .landscape16x9)

let response = try await imagenModel.generateImages(
prompt: imagePrompt,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ final class ImageGenerationParametersTests: XCTestCase {
let parameters = ImagenModel.imageGenerationParameters(
storageURI: nil,
generationConfig: nil,
modelConfig: nil,
safetySettings: nil
)

Expand All @@ -63,6 +64,37 @@ final class ImageGenerationParametersTests: XCTestCase {
let parameters = ImagenModel.imageGenerationParameters(
storageURI: storageURI,
generationConfig: nil,
modelConfig: nil,
safetySettings: nil
)

XCTAssertEqual(parameters, expectedParameters)
}

func testParameters_includeModelConfig() throws {
let compressionQuality = 80
let imageFormat = ImagenImageFormat.jpeg(compressionQuality: compressionQuality)
let addWatermark = true
let modelConfig = ImagenModelConfig(imageFormat: imageFormat, addWatermark: addWatermark)
let expectedParameters = ImageGenerationParameters(
sampleCount: 1,
storageURI: nil,
negativePrompt: nil,
aspectRatio: nil,
safetyFilterLevel: nil,
personGeneration: nil,
outputOptions: ImageGenerationOutputOptions(
mimeType: imageFormat.mimeType,
compressionQuality: imageFormat.compressionQuality
),
addWatermark: addWatermark,
includeResponsibleAIFilterReason: true
)

let parameters = ImagenModel.imageGenerationParameters(
storageURI: nil,
generationConfig: nil,
modelConfig: modelConfig,
safetySettings: nil
)

Expand All @@ -73,15 +105,10 @@ final class ImageGenerationParametersTests: XCTestCase {
let sampleCount = 2
let negativePrompt = "test-negative-prompt"
let aspectRatio = ImagenAspectRatio.landscape16x9
let compressionQuality = 80
let imageFormat = ImagenImageFormat.jpeg(compressionQuality: compressionQuality)
let addWatermark = true
let generationConfig = ImagenGenerationConfig(
numberOfImages: sampleCount,
negativePrompt: negativePrompt,
aspectRatio: aspectRatio,
imageFormat: imageFormat,
addWatermark: addWatermark
aspectRatio: aspectRatio
)
let expectedParameters = ImageGenerationParameters(
sampleCount: sampleCount,
Expand All @@ -90,24 +117,20 @@ final class ImageGenerationParametersTests: XCTestCase {
aspectRatio: aspectRatio.rawValue,
safetyFilterLevel: nil,
personGeneration: nil,
outputOptions: ImageGenerationOutputOptions(
mimeType: imageFormat.mimeType,
compressionQuality: imageFormat.compressionQuality
),
addWatermark: addWatermark,
outputOptions: nil,
addWatermark: nil,
includeResponsibleAIFilterReason: true
)

let parameters = ImagenModel.imageGenerationParameters(
storageURI: nil,
generationConfig: generationConfig,
modelConfig: nil,
safetySettings: nil
)

XCTAssertEqual(parameters, expectedParameters)
XCTAssertEqual(parameters.aspectRatio, "16:9")
XCTAssertEqual(parameters.outputOptions?.mimeType, "image/jpeg")
XCTAssertEqual(parameters.outputOptions?.compressionQuality, compressionQuality)
}

func testDefaultParameters_includeSafetySettings() throws {
Expand All @@ -132,6 +155,7 @@ final class ImageGenerationParametersTests: XCTestCase {
let parameters = ImagenModel.imageGenerationParameters(
storageURI: nil,
generationConfig: nil,
modelConfig: nil,
safetySettings: safetySettings
)

Expand All @@ -145,15 +169,14 @@ final class ImageGenerationParametersTests: XCTestCase {
let sampleCount = 4
let negativePrompt = "test-negative-prompt"
let aspectRatio = ImagenAspectRatio.portrait3x4
let imageFormat = ImagenImageFormat.png()
let addWatermark = false
let generationConfig = ImagenGenerationConfig(
numberOfImages: sampleCount,
negativePrompt: negativePrompt,
aspectRatio: aspectRatio,
imageFormat: imageFormat,
addWatermark: addWatermark
aspectRatio: aspectRatio
)
let imageFormat = ImagenImageFormat.png()
let addWatermark = false
let modelConfig = ImagenModelConfig(imageFormat: imageFormat, addWatermark: addWatermark)
let safetyFilterLevel = ImagenSafetyFilterLevel.blockNone
let personFilterLevel = ImagenPersonFilterLevel.blockAll
let safetySettings = ImagenSafetySettings(
Expand All @@ -178,6 +201,7 @@ final class ImageGenerationParametersTests: XCTestCase {
let parameters = ImagenModel.imageGenerationParameters(
storageURI: storageURI,
generationConfig: generationConfig,
modelConfig: modelConfig,
safetySettings: safetySettings
)

Expand Down

0 comments on commit 527d4af

Please sign in to comment.