From c986629328b577db3613e94879b7940b030475c7 Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Mon, 9 Dec 2024 19:03:18 -0500 Subject: [PATCH] [Vertex AI] Add `ImagenSafetySettings` type and param (#14237) --- .../Imagen/ImageGenerationParameters.swift | 3 - .../Types/Public/Imagen/ImagenModel.swift | 28 ++- .../Public/Imagen/ImagenSafetySettings.swift | 65 +++++++ FirebaseVertexAI/Sources/VertexAI.swift | 5 +- .../Tests/Integration/IntegrationTests.swift | 6 +- .../ImageGenerationParametersTests.swift | 174 +++++++++++++++++- .../Imagen/ImageGenerationRequestTests.swift | 1 - 7 files changed, 260 insertions(+), 22 deletions(-) create mode 100644 FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenSafetySettings.swift diff --git a/FirebaseVertexAI/Sources/Types/Internal/Imagen/ImageGenerationParameters.swift b/FirebaseVertexAI/Sources/Types/Internal/Imagen/ImageGenerationParameters.swift index 9ab6641862c..4189e5fbac7 100644 --- a/FirebaseVertexAI/Sources/Types/Internal/Imagen/ImageGenerationParameters.swift +++ b/FirebaseVertexAI/Sources/Types/Internal/Imagen/ImageGenerationParameters.swift @@ -16,7 +16,6 @@ struct ImageGenerationParameters { let sampleCount: Int? let storageURI: String? - let seed: Int32? let negativePrompt: String? let aspectRatio: String? let safetyFilterLevel: String? @@ -36,7 +35,6 @@ extension ImageGenerationParameters: Encodable { enum CodingKeys: String, CodingKey { case sampleCount case storageURI = "storageUri" - case seed case negativePrompt case aspectRatio case safetyFilterLevel = "safetySetting" @@ -50,7 +48,6 @@ extension ImageGenerationParameters: Encodable { var container = encoder.container(keyedBy: CodingKeys.self) try container.encodeIfPresent(sampleCount, forKey: .sampleCount) try container.encodeIfPresent(storageURI, forKey: .storageURI) - try container.encodeIfPresent(seed, forKey: .seed) try container.encodeIfPresent(negativePrompt, forKey: .negativePrompt) try container.encodeIfPresent(aspectRatio, forKey: .aspectRatio) try container.encodeIfPresent(safetyFilterLevel, forKey: .safetyFilterLevel) diff --git a/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenModel.swift b/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenModel.swift index 24df3e25725..9d5e7887969 100644 --- a/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenModel.swift +++ b/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenModel.swift @@ -24,12 +24,15 @@ public final class ImagenModel { /// The backing service responsible for sending and receiving model requests to the backend. let generativeAIService: GenerativeAIService + let safetySettings: ImagenSafetySettings? + /// Configuration parameters for sending requests to the backend. let requestOptions: RequestOptions init(name: String, projectID: String, apiKey: String, + safetySettings: ImagenSafetySettings?, requestOptions: RequestOptions, appCheck: AppCheckInterop?, auth: AuthInterop?, @@ -42,6 +45,7 @@ public final class ImagenModel { auth: auth, urlSession: urlSession ) + self.safetySettings = safetySettings self.requestOptions = requestOptions } @@ -50,7 +54,11 @@ public final class ImagenModel { -> ImageGenerationResponse { return try await generateImages( prompt: prompt, - parameters: imageGenerationParameters(storageURI: nil, generationConfig: generationConfig) + parameters: ImagenModel.imageGenerationParameters( + storageURI: nil, + generationConfig: generationConfig, + safetySettings: safetySettings + ) ) } @@ -59,9 +67,10 @@ public final class ImagenModel { -> ImageGenerationResponse { return try await generateImages( prompt: prompt, - parameters: imageGenerationParameters( + parameters: ImagenModel.imageGenerationParameters( storageURI: storageURI, - generationConfig: generationConfig + generationConfig: generationConfig, + safetySettings: safetySettings ) ) } @@ -79,18 +88,17 @@ public final class ImagenModel { return try await generativeAIService.loadRequest(request: request) } - func imageGenerationParameters(storageURI: String?, - generationConfig: ImagenGenerationConfig? = nil) + static func imageGenerationParameters(storageURI: String?, + generationConfig: ImagenGenerationConfig?, + safetySettings: ImagenSafetySettings?) -> ImageGenerationParameters { - // TODO(#14221): Add support for configuring remaining parameters. return ImageGenerationParameters( sampleCount: generationConfig?.numberOfImages ?? 1, storageURI: storageURI, - seed: nil, negativePrompt: generationConfig?.negativePrompt, aspectRatio: generationConfig?.aspectRatio?.rawValue, - safetyFilterLevel: nil, - personGeneration: nil, + safetyFilterLevel: safetySettings?.safetyFilterLevel?.rawValue, + personGeneration: safetySettings?.personGeneration?.rawValue, outputOptions: generationConfig?.imageFormat.map { ImageGenerationOutputOptions( mimeType: $0.mimeType, @@ -98,7 +106,7 @@ public final class ImagenModel { ) }, addWatermark: generationConfig?.addWatermark, - includeResponsibleAIFilterReason: true + includeResponsibleAIFilterReason: safetySettings?.includeFilterReason ?? true ) } } diff --git a/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenSafetySettings.swift b/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenSafetySettings.swift new file mode 100644 index 00000000000..81b72890a0e --- /dev/null +++ b/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenSafetySettings.swift @@ -0,0 +1,65 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Foundation + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +public struct ImagenSafetySettings { + let safetyFilterLevel: SafetyFilterLevel? + let includeFilterReason: Bool? + let personGeneration: PersonGeneration? + + public init(safetyFilterLevel: SafetyFilterLevel? = nil, includeFilterReason: Bool? = nil, + personGeneration: PersonGeneration? = nil) { + self.safetyFilterLevel = safetyFilterLevel + self.includeFilterReason = includeFilterReason + self.personGeneration = personGeneration + } +} + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +public extension ImagenSafetySettings { + struct SafetyFilterLevel: ProtoEnum { + enum Kind: String { + case blockLowAndAbove = "block_low_and_above" + case blockMediumAndAbove = "block_medium_and_above" + case blockOnlyHigh = "block_only_high" + case blockNone = "block_none" + } + + public static let blockLowAndAbove = SafetyFilterLevel(kind: .blockLowAndAbove) + public static let blockMediumAndAbove = SafetyFilterLevel(kind: .blockMediumAndAbove) + public static let blockOnlyHigh = SafetyFilterLevel(kind: .blockOnlyHigh) + public static let blockNone = SafetyFilterLevel(kind: .blockNone) + + let rawValue: String + } +} + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +public extension ImagenSafetySettings { + struct PersonGeneration: ProtoEnum { + enum Kind: String { + case blockAll = "dont_allow" + case allowAdult = "allow_adult" + case allowAll = "allow_all" + } + + public static let blockAll = PersonGeneration(kind: .blockAll) + public static let allowAdult = PersonGeneration(kind: .allowAdult) + public static let allowAll = PersonGeneration(kind: .allowAll) + + let rawValue: String + } +} diff --git a/FirebaseVertexAI/Sources/VertexAI.swift b/FirebaseVertexAI/Sources/VertexAI.swift index 96df5b4abf5..a2d3c62f529 100644 --- a/FirebaseVertexAI/Sources/VertexAI.swift +++ b/FirebaseVertexAI/Sources/VertexAI.swift @@ -104,12 +104,13 @@ public class VertexAI { ) } - public func imagenModel(modelName: String, requestOptions: RequestOptions = RequestOptions()) - -> ImagenModel { + public func imagenModel(modelName: String, safetySettings: ImagenSafetySettings? = nil, + requestOptions: RequestOptions = RequestOptions()) -> ImagenModel { return ImagenModel( name: modelResourceName(modelName: modelName), projectID: projectID, apiKey: apiKey, + safetySettings: safetySettings, requestOptions: requestOptions, appCheck: appCheck, auth: auth diff --git a/FirebaseVertexAI/Tests/TestApp/Tests/Integration/IntegrationTests.swift b/FirebaseVertexAI/Tests/TestApp/Tests/Integration/IntegrationTests.swift index 073e66db783..11ffd3de600 100644 --- a/FirebaseVertexAI/Tests/TestApp/Tests/Integration/IntegrationTests.swift +++ b/FirebaseVertexAI/Tests/TestApp/Tests/Integration/IntegrationTests.swift @@ -62,7 +62,11 @@ final class IntegrationTests: XCTestCase { systemInstruction: systemInstruction ) imagenModel = vertex.imagenModel( - modelName: "imagen-3.0-fast-generate-001" + modelName: "imagen-3.0-fast-generate-001", + safetySettings: ImagenSafetySettings( + safetyFilterLevel: .blockLowAndAbove, + personGeneration: .blockAll + ) ) storage = Storage.storage() diff --git a/FirebaseVertexAI/Tests/Unit/Types/Imagen/ImageGenerationParametersTests.swift b/FirebaseVertexAI/Tests/Unit/Types/Imagen/ImageGenerationParametersTests.swift index f5ce3a3b81e..f50f1646c8d 100644 --- a/FirebaseVertexAI/Tests/Unit/Types/Imagen/ImageGenerationParametersTests.swift +++ b/FirebaseVertexAI/Tests/Unit/Types/Imagen/ImageGenerationParametersTests.swift @@ -24,12 +24,180 @@ final class ImageGenerationParametersTests: XCTestCase { encoder.outputFormatting = [.sortedKeys, .prettyPrinted, .withoutEscapingSlashes] } + func testDefaultParameters_noneSpecified() throws { + let expectedParameters = ImageGenerationParameters( + sampleCount: 1, + storageURI: nil, + negativePrompt: nil, + aspectRatio: nil, + safetyFilterLevel: nil, + personGeneration: nil, + outputOptions: nil, + addWatermark: nil, + includeResponsibleAIFilterReason: true + ) + + let parameters = ImagenModel.imageGenerationParameters( + storageURI: nil, + generationConfig: nil, + safetySettings: nil + ) + + XCTAssertEqual(parameters, expectedParameters) + } + + func testDefaultParameters_includeStorageURI() throws { + let storageURI = "gs://test-bucket/path" + let expectedParameters = ImageGenerationParameters( + sampleCount: 1, + storageURI: storageURI, + negativePrompt: nil, + aspectRatio: nil, + safetyFilterLevel: nil, + personGeneration: nil, + outputOptions: nil, + addWatermark: nil, + includeResponsibleAIFilterReason: true + ) + + let parameters = ImagenModel.imageGenerationParameters( + storageURI: storageURI, + generationConfig: nil, + safetySettings: nil + ) + + XCTAssertEqual(parameters, expectedParameters) + } + + func testParameters_includeGenerationConfig() throws { + let sampleCount = 2 + let negativePrompt = "test-negative-prompt" + let aspectRatio = ImagenAspectRatio.landscape16x9 + let compressionQuality = 80 + let imageFormat = ImagenImageFormat.jpeg(compressionQuality: compressionQuality) + let addWatermark = true + let generationConfig = ImagenGenerationConfig( + numberOfImages: sampleCount, + negativePrompt: negativePrompt, + aspectRatio: aspectRatio, + imageFormat: imageFormat, + addWatermark: addWatermark + ) + let expectedParameters = ImageGenerationParameters( + sampleCount: sampleCount, + storageURI: nil, + negativePrompt: negativePrompt, + aspectRatio: aspectRatio.rawValue, + safetyFilterLevel: nil, + personGeneration: nil, + outputOptions: ImageGenerationOutputOptions( + mimeType: imageFormat.mimeType, + compressionQuality: imageFormat.compressionQuality + ), + addWatermark: addWatermark, + includeResponsibleAIFilterReason: true + ) + + let parameters = ImagenModel.imageGenerationParameters( + storageURI: nil, + generationConfig: generationConfig, + safetySettings: nil + ) + + XCTAssertEqual(parameters, expectedParameters) + XCTAssertEqual(parameters.aspectRatio, "16:9") + XCTAssertEqual(parameters.outputOptions?.mimeType, "image/jpeg") + XCTAssertEqual(parameters.outputOptions?.compressionQuality, compressionQuality) + } + + func testDefaultParameters_includeSafetySettings() throws { + let safetyFilterLevel = ImagenSafetySettings.SafetyFilterLevel.blockOnlyHigh + let personGeneration = ImagenSafetySettings.PersonGeneration.allowAll + let includeFilterReason = false + let safetySettings = ImagenSafetySettings( + safetyFilterLevel: safetyFilterLevel, + includeFilterReason: includeFilterReason, + personGeneration: personGeneration + ) + let expectedParameters = ImageGenerationParameters( + sampleCount: 1, + storageURI: nil, + negativePrompt: nil, + aspectRatio: nil, + safetyFilterLevel: safetyFilterLevel.rawValue, + personGeneration: personGeneration.rawValue, + outputOptions: nil, + addWatermark: nil, + includeResponsibleAIFilterReason: includeFilterReason + ) + + let parameters = ImagenModel.imageGenerationParameters( + storageURI: nil, + generationConfig: nil, + safetySettings: safetySettings + ) + + XCTAssertEqual(parameters, expectedParameters) + XCTAssertEqual(parameters.safetyFilterLevel, "block_only_high") + XCTAssertEqual(parameters.personGeneration, "allow_all") + } + + func testParameters_includeAll() throws { + let storageURI = "gs://test-bucket/path" + let sampleCount = 4 + let negativePrompt = "test-negative-prompt" + let aspectRatio = ImagenAspectRatio.portrait3x4 + let imageFormat = ImagenImageFormat.png() + let addWatermark = false + let generationConfig = ImagenGenerationConfig( + numberOfImages: sampleCount, + negativePrompt: negativePrompt, + aspectRatio: aspectRatio, + imageFormat: imageFormat, + addWatermark: addWatermark + ) + let safetyFilterLevel = ImagenSafetySettings.SafetyFilterLevel.blockNone + let personGeneration = ImagenSafetySettings.PersonGeneration.blockAll + let includeFilterReason = false + let safetySettings = ImagenSafetySettings( + safetyFilterLevel: safetyFilterLevel, + includeFilterReason: includeFilterReason, + personGeneration: personGeneration + ) + let expectedParameters = ImageGenerationParameters( + sampleCount: sampleCount, + storageURI: storageURI, + negativePrompt: negativePrompt, + aspectRatio: aspectRatio.rawValue, + safetyFilterLevel: safetyFilterLevel.rawValue, + personGeneration: personGeneration.rawValue, + outputOptions: ImageGenerationOutputOptions( + mimeType: imageFormat.mimeType, + compressionQuality: imageFormat.compressionQuality + ), + addWatermark: addWatermark, + includeResponsibleAIFilterReason: includeFilterReason + ) + + let parameters = ImagenModel.imageGenerationParameters( + storageURI: storageURI, + generationConfig: generationConfig, + safetySettings: safetySettings + ) + + XCTAssertEqual(parameters, expectedParameters) + XCTAssertEqual(parameters.aspectRatio, "3:4") + XCTAssertEqual(parameters.safetyFilterLevel, "block_none") + XCTAssertEqual(parameters.personGeneration, "dont_allow") + XCTAssertEqual(parameters.outputOptions?.mimeType, "image/png") + XCTAssertNil(parameters.outputOptions?.compressionQuality) + } + // MARK: - Encoding Tests func testEncodeParameters_allSpecified() throws { let sampleCount = 4 let storageURI = "gs://bucket/folder" - let seed: Int32 = 1_076_107_968 let negativePrompt = "test-negative-prompt" let aspectRatio = "16:9" let safetyFilterLevel = "block_low_and_above" @@ -41,7 +209,6 @@ final class ImageGenerationParametersTests: XCTestCase { let parameters = ImageGenerationParameters( sampleCount: sampleCount, storageURI: storageURI, - seed: seed, negativePrompt: negativePrompt, aspectRatio: aspectRatio, safetyFilterLevel: safetyFilterLevel, @@ -66,7 +233,6 @@ final class ImageGenerationParametersTests: XCTestCase { "personGeneration" : "\(personGeneration)", "safetySetting" : "\(safetyFilterLevel)", "sampleCount" : \(sampleCount), - "seed" : \(seed), "storageUri" : "\(storageURI)" } """) @@ -80,7 +246,6 @@ final class ImageGenerationParametersTests: XCTestCase { let parameters = ImageGenerationParameters( sampleCount: sampleCount, storageURI: nil, - seed: nil, negativePrompt: nil, aspectRatio: aspectRatio, safetyFilterLevel: safetyFilterLevel, @@ -107,7 +272,6 @@ final class ImageGenerationParametersTests: XCTestCase { let parameters = ImageGenerationParameters( sampleCount: nil, storageURI: nil, - seed: nil, negativePrompt: nil, aspectRatio: nil, safetyFilterLevel: nil, diff --git a/FirebaseVertexAI/Tests/Unit/Types/Imagen/ImageGenerationRequestTests.swift b/FirebaseVertexAI/Tests/Unit/Types/Imagen/ImageGenerationRequestTests.swift index 90ca9e7c25d..8adbc577d83 100644 --- a/FirebaseVertexAI/Tests/Unit/Types/Imagen/ImageGenerationRequestTests.swift +++ b/FirebaseVertexAI/Tests/Unit/Types/Imagen/ImageGenerationRequestTests.swift @@ -28,7 +28,6 @@ final class ImageGenerationRequestTests: XCTestCase { lazy var parameters = ImageGenerationParameters( sampleCount: sampleCount, storageURI: nil, - seed: nil, negativePrompt: nil, aspectRatio: aspectRatio, safetyFilterLevel: safetyFilterLevel,