Skip to content

Commit 527d4af

Browse files
authored
[Vertex AI] Add ImagenModelConfig for model-level config params (#14315)
1 parent 53d43d8 commit 527d4af

File tree

6 files changed

+82
-33
lines changed

6 files changed

+82
-33
lines changed

FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenGenerationConfig.swift

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,11 @@ public struct ImagenGenerationConfig {
1717
public var numberOfImages: Int?
1818
public var negativePrompt: String?
1919
public var aspectRatio: ImagenAspectRatio?
20-
public var imageFormat: ImagenImageFormat?
21-
public var addWatermark: Bool?
2220

23-
public init(numberOfImages: Int? = nil,
24-
negativePrompt: String? = nil,
25-
aspectRatio: ImagenAspectRatio? = nil,
26-
imageFormat: ImagenImageFormat? = nil,
27-
addWatermark: Bool? = nil) {
21+
public init(numberOfImages: Int? = nil, negativePrompt: String? = nil,
22+
aspectRatio: ImagenAspectRatio? = nil) {
2823
self.numberOfImages = numberOfImages
2924
self.negativePrompt = negativePrompt
3025
self.aspectRatio = aspectRatio
31-
self.imageFormat = imageFormat
32-
self.addWatermark = addWatermark
3326
}
3427
}

FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenModel.swift

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ public final class ImagenModel {
2424
/// The backing service responsible for sending and receiving model requests to the backend.
2525
let generativeAIService: GenerativeAIService
2626

27+
let modelConfig: ImagenModelConfig?
28+
2729
let safetySettings: ImagenSafetySettings?
2830

2931
/// Configuration parameters for sending requests to the backend.
@@ -32,6 +34,7 @@ public final class ImagenModel {
3234
init(name: String,
3335
projectID: String,
3436
apiKey: String,
37+
modelConfig: ImagenModelConfig?,
3538
safetySettings: ImagenSafetySettings?,
3639
requestOptions: RequestOptions,
3740
appCheck: AppCheckInterop?,
@@ -45,6 +48,7 @@ public final class ImagenModel {
4548
auth: auth,
4649
urlSession: urlSession
4750
)
51+
self.modelConfig = modelConfig
4852
self.safetySettings = safetySettings
4953
self.requestOptions = requestOptions
5054
}
@@ -57,6 +61,7 @@ public final class ImagenModel {
5761
parameters: ImagenModel.imageGenerationParameters(
5862
storageURI: nil,
5963
generationConfig: generationConfig,
64+
modelConfig: modelConfig,
6065
safetySettings: safetySettings
6166
)
6267
)
@@ -70,6 +75,7 @@ public final class ImagenModel {
7075
parameters: ImagenModel.imageGenerationParameters(
7176
storageURI: storageURI,
7277
generationConfig: generationConfig,
78+
modelConfig: modelConfig,
7379
safetySettings: safetySettings
7480
)
7581
)
@@ -90,6 +96,7 @@ public final class ImagenModel {
9096

9197
static func imageGenerationParameters(storageURI: String?,
9298
generationConfig: ImagenGenerationConfig?,
99+
modelConfig: ImagenModelConfig?,
93100
safetySettings: ImagenSafetySettings?)
94101
-> ImageGenerationParameters {
95102
return ImageGenerationParameters(
@@ -99,13 +106,13 @@ public final class ImagenModel {
99106
aspectRatio: generationConfig?.aspectRatio?.rawValue,
100107
safetyFilterLevel: safetySettings?.safetyFilterLevel?.rawValue,
101108
personGeneration: safetySettings?.personFilterLevel?.rawValue,
102-
outputOptions: generationConfig?.imageFormat.map {
109+
outputOptions: modelConfig?.imageFormat.map {
103110
ImageGenerationOutputOptions(
104111
mimeType: $0.mimeType,
105112
compressionQuality: $0.compressionQuality
106113
)
107114
},
108-
addWatermark: generationConfig?.addWatermark,
115+
addWatermark: modelConfig?.addWatermark,
109116
includeResponsibleAIFilterReason: true
110117
)
111118
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
16+
public struct ImagenModelConfig {
17+
let imageFormat: ImagenImageFormat?
18+
let addWatermark: Bool?
19+
20+
public init(imageFormat: ImagenImageFormat? = nil, addWatermark: Bool? = nil) {
21+
self.imageFormat = imageFormat
22+
self.addWatermark = addWatermark
23+
}
24+
}

FirebaseVertexAI/Sources/VertexAI.swift

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,12 +104,14 @@ public class VertexAI {
104104
)
105105
}
106106

107-
public func imagenModel(modelName: String, safetySettings: ImagenSafetySettings? = nil,
107+
public func imagenModel(modelName: String, modelConfig: ImagenModelConfig? = nil,
108+
safetySettings: ImagenSafetySettings? = nil,
108109
requestOptions: RequestOptions = RequestOptions()) -> ImagenModel {
109110
return ImagenModel(
110111
name: modelResourceName(modelName: modelName),
111112
projectID: projectID,
112113
apiKey: apiKey,
114+
modelConfig: modelConfig,
113115
safetySettings: safetySettings,
114116
requestOptions: requestOptions,
115117
appCheck: appCheck,

FirebaseVertexAI/Tests/TestApp/Tests/Integration/IntegrationTests.swift

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ final class IntegrationTests: XCTestCase {
6363
)
6464
imagenModel = vertex.imagenModel(
6565
modelName: "imagen-3.0-fast-generate-001",
66+
modelConfig: ImagenModelConfig(imageFormat: .jpeg(compressionQuality: 70)),
6667
safetySettings: ImagenSafetySettings(
6768
safetyFilterLevel: .blockLowAndAbove,
6869
personFilterLevel: .blockAll
@@ -253,9 +254,7 @@ final class IntegrationTests: XCTestCase {
253254
overlooking a vast African savanna at sunset. Golden hour light, long shadows, sharp focus on
254255
the lion, shallow depth of field, detailed fur texture, DSLR, 85mm lens.
255256
"""
256-
var generationConfig = ImagenGenerationConfig()
257-
generationConfig.aspectRatio = .landscape16x9
258-
generationConfig.imageFormat = .jpeg(compressionQuality: 70)
257+
let generationConfig = ImagenGenerationConfig(aspectRatio: .landscape16x9)
259258

260259
let response = try await imagenModel.generateImages(
261260
prompt: imagePrompt,

FirebaseVertexAI/Tests/Unit/Types/Imagen/ImageGenerationParametersTests.swift

Lines changed: 42 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ final class ImageGenerationParametersTests: XCTestCase {
4040
let parameters = ImagenModel.imageGenerationParameters(
4141
storageURI: nil,
4242
generationConfig: nil,
43+
modelConfig: nil,
4344
safetySettings: nil
4445
)
4546

@@ -63,6 +64,37 @@ final class ImageGenerationParametersTests: XCTestCase {
6364
let parameters = ImagenModel.imageGenerationParameters(
6465
storageURI: storageURI,
6566
generationConfig: nil,
67+
modelConfig: nil,
68+
safetySettings: nil
69+
)
70+
71+
XCTAssertEqual(parameters, expectedParameters)
72+
}
73+
74+
func testParameters_includeModelConfig() throws {
75+
let compressionQuality = 80
76+
let imageFormat = ImagenImageFormat.jpeg(compressionQuality: compressionQuality)
77+
let addWatermark = true
78+
let modelConfig = ImagenModelConfig(imageFormat: imageFormat, addWatermark: addWatermark)
79+
let expectedParameters = ImageGenerationParameters(
80+
sampleCount: 1,
81+
storageURI: nil,
82+
negativePrompt: nil,
83+
aspectRatio: nil,
84+
safetyFilterLevel: nil,
85+
personGeneration: nil,
86+
outputOptions: ImageGenerationOutputOptions(
87+
mimeType: imageFormat.mimeType,
88+
compressionQuality: imageFormat.compressionQuality
89+
),
90+
addWatermark: addWatermark,
91+
includeResponsibleAIFilterReason: true
92+
)
93+
94+
let parameters = ImagenModel.imageGenerationParameters(
95+
storageURI: nil,
96+
generationConfig: nil,
97+
modelConfig: modelConfig,
6698
safetySettings: nil
6799
)
68100

@@ -73,15 +105,10 @@ final class ImageGenerationParametersTests: XCTestCase {
73105
let sampleCount = 2
74106
let negativePrompt = "test-negative-prompt"
75107
let aspectRatio = ImagenAspectRatio.landscape16x9
76-
let compressionQuality = 80
77-
let imageFormat = ImagenImageFormat.jpeg(compressionQuality: compressionQuality)
78-
let addWatermark = true
79108
let generationConfig = ImagenGenerationConfig(
80109
numberOfImages: sampleCount,
81110
negativePrompt: negativePrompt,
82-
aspectRatio: aspectRatio,
83-
imageFormat: imageFormat,
84-
addWatermark: addWatermark
111+
aspectRatio: aspectRatio
85112
)
86113
let expectedParameters = ImageGenerationParameters(
87114
sampleCount: sampleCount,
@@ -90,24 +117,20 @@ final class ImageGenerationParametersTests: XCTestCase {
90117
aspectRatio: aspectRatio.rawValue,
91118
safetyFilterLevel: nil,
92119
personGeneration: nil,
93-
outputOptions: ImageGenerationOutputOptions(
94-
mimeType: imageFormat.mimeType,
95-
compressionQuality: imageFormat.compressionQuality
96-
),
97-
addWatermark: addWatermark,
120+
outputOptions: nil,
121+
addWatermark: nil,
98122
includeResponsibleAIFilterReason: true
99123
)
100124

101125
let parameters = ImagenModel.imageGenerationParameters(
102126
storageURI: nil,
103127
generationConfig: generationConfig,
128+
modelConfig: nil,
104129
safetySettings: nil
105130
)
106131

107132
XCTAssertEqual(parameters, expectedParameters)
108133
XCTAssertEqual(parameters.aspectRatio, "16:9")
109-
XCTAssertEqual(parameters.outputOptions?.mimeType, "image/jpeg")
110-
XCTAssertEqual(parameters.outputOptions?.compressionQuality, compressionQuality)
111134
}
112135

113136
func testDefaultParameters_includeSafetySettings() throws {
@@ -132,6 +155,7 @@ final class ImageGenerationParametersTests: XCTestCase {
132155
let parameters = ImagenModel.imageGenerationParameters(
133156
storageURI: nil,
134157
generationConfig: nil,
158+
modelConfig: nil,
135159
safetySettings: safetySettings
136160
)
137161

@@ -145,15 +169,14 @@ final class ImageGenerationParametersTests: XCTestCase {
145169
let sampleCount = 4
146170
let negativePrompt = "test-negative-prompt"
147171
let aspectRatio = ImagenAspectRatio.portrait3x4
148-
let imageFormat = ImagenImageFormat.png()
149-
let addWatermark = false
150172
let generationConfig = ImagenGenerationConfig(
151173
numberOfImages: sampleCount,
152174
negativePrompt: negativePrompt,
153-
aspectRatio: aspectRatio,
154-
imageFormat: imageFormat,
155-
addWatermark: addWatermark
175+
aspectRatio: aspectRatio
156176
)
177+
let imageFormat = ImagenImageFormat.png()
178+
let addWatermark = false
179+
let modelConfig = ImagenModelConfig(imageFormat: imageFormat, addWatermark: addWatermark)
157180
let safetyFilterLevel = ImagenSafetyFilterLevel.blockNone
158181
let personFilterLevel = ImagenPersonFilterLevel.blockAll
159182
let safetySettings = ImagenSafetySettings(
@@ -178,6 +201,7 @@ final class ImageGenerationParametersTests: XCTestCase {
178201
let parameters = ImagenModel.imageGenerationParameters(
179202
storageURI: storageURI,
180203
generationConfig: generationConfig,
204+
modelConfig: modelConfig,
181205
safetySettings: safetySettings
182206
)
183207

0 commit comments

Comments
 (0)