Skip to content

Commit

Permalink
[Vertex AI] Add ImagenSafetySettings type and param (#14237)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard authored Dec 10, 2024
1 parent c5472fc commit c986629
Show file tree
Hide file tree
Showing 7 changed files with 260 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
struct ImageGenerationParameters {
let sampleCount: Int?
let storageURI: String?
let seed: Int32?
let negativePrompt: String?
let aspectRatio: String?
let safetyFilterLevel: String?
Expand All @@ -36,7 +35,6 @@ extension ImageGenerationParameters: Encodable {
enum CodingKeys: String, CodingKey {
case sampleCount
case storageURI = "storageUri"
case seed
case negativePrompt
case aspectRatio
case safetyFilterLevel = "safetySetting"
Expand All @@ -50,7 +48,6 @@ extension ImageGenerationParameters: Encodable {
var container = encoder.container(keyedBy: CodingKeys.self)
try container.encodeIfPresent(sampleCount, forKey: .sampleCount)
try container.encodeIfPresent(storageURI, forKey: .storageURI)
try container.encodeIfPresent(seed, forKey: .seed)
try container.encodeIfPresent(negativePrompt, forKey: .negativePrompt)
try container.encodeIfPresent(aspectRatio, forKey: .aspectRatio)
try container.encodeIfPresent(safetyFilterLevel, forKey: .safetyFilterLevel)
Expand Down
28 changes: 18 additions & 10 deletions FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,15 @@ public final class ImagenModel {
/// The backing service responsible for sending and receiving model requests to the backend.
let generativeAIService: GenerativeAIService

let safetySettings: ImagenSafetySettings?

/// Configuration parameters for sending requests to the backend.
let requestOptions: RequestOptions

init(name: String,
projectID: String,
apiKey: String,
safetySettings: ImagenSafetySettings?,
requestOptions: RequestOptions,
appCheck: AppCheckInterop?,
auth: AuthInterop?,
Expand All @@ -42,6 +45,7 @@ public final class ImagenModel {
auth: auth,
urlSession: urlSession
)
self.safetySettings = safetySettings
self.requestOptions = requestOptions
}

Expand All @@ -50,7 +54,11 @@ public final class ImagenModel {
-> ImageGenerationResponse<ImagenInlineDataImage> {
return try await generateImages(
prompt: prompt,
parameters: imageGenerationParameters(storageURI: nil, generationConfig: generationConfig)
parameters: ImagenModel.imageGenerationParameters(
storageURI: nil,
generationConfig: generationConfig,
safetySettings: safetySettings
)
)
}

Expand All @@ -59,9 +67,10 @@ public final class ImagenModel {
-> ImageGenerationResponse<ImagenFileDataImage> {
return try await generateImages(
prompt: prompt,
parameters: imageGenerationParameters(
parameters: ImagenModel.imageGenerationParameters(
storageURI: storageURI,
generationConfig: generationConfig
generationConfig: generationConfig,
safetySettings: safetySettings
)
)
}
Expand All @@ -79,26 +88,25 @@ public final class ImagenModel {
return try await generativeAIService.loadRequest(request: request)
}

func imageGenerationParameters(storageURI: String?,
generationConfig: ImagenGenerationConfig? = nil)
static func imageGenerationParameters(storageURI: String?,
generationConfig: ImagenGenerationConfig?,
safetySettings: ImagenSafetySettings?)
-> ImageGenerationParameters {
// TODO(#14221): Add support for configuring remaining parameters.
return ImageGenerationParameters(
sampleCount: generationConfig?.numberOfImages ?? 1,
storageURI: storageURI,
seed: nil,
negativePrompt: generationConfig?.negativePrompt,
aspectRatio: generationConfig?.aspectRatio?.rawValue,
safetyFilterLevel: nil,
personGeneration: nil,
safetyFilterLevel: safetySettings?.safetyFilterLevel?.rawValue,
personGeneration: safetySettings?.personGeneration?.rawValue,
outputOptions: generationConfig?.imageFormat.map {
ImageGenerationOutputOptions(
mimeType: $0.mimeType,
compressionQuality: $0.compressionQuality
)
},
addWatermark: generationConfig?.addWatermark,
includeResponsibleAIFilterReason: true
includeResponsibleAIFilterReason: safetySettings?.includeFilterReason ?? true
)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import Foundation

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public struct ImagenSafetySettings {
let safetyFilterLevel: SafetyFilterLevel?
let includeFilterReason: Bool?
let personGeneration: PersonGeneration?

public init(safetyFilterLevel: SafetyFilterLevel? = nil, includeFilterReason: Bool? = nil,
personGeneration: PersonGeneration? = nil) {
self.safetyFilterLevel = safetyFilterLevel
self.includeFilterReason = includeFilterReason
self.personGeneration = personGeneration
}
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public extension ImagenSafetySettings {
struct SafetyFilterLevel: ProtoEnum {
enum Kind: String {
case blockLowAndAbove = "block_low_and_above"
case blockMediumAndAbove = "block_medium_and_above"
case blockOnlyHigh = "block_only_high"
case blockNone = "block_none"
}

public static let blockLowAndAbove = SafetyFilterLevel(kind: .blockLowAndAbove)
public static let blockMediumAndAbove = SafetyFilterLevel(kind: .blockMediumAndAbove)
public static let blockOnlyHigh = SafetyFilterLevel(kind: .blockOnlyHigh)
public static let blockNone = SafetyFilterLevel(kind: .blockNone)

let rawValue: String
}
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public extension ImagenSafetySettings {
struct PersonGeneration: ProtoEnum {
enum Kind: String {
case blockAll = "dont_allow"
case allowAdult = "allow_adult"
case allowAll = "allow_all"
}

public static let blockAll = PersonGeneration(kind: .blockAll)
public static let allowAdult = PersonGeneration(kind: .allowAdult)
public static let allowAll = PersonGeneration(kind: .allowAll)

let rawValue: String
}
}
5 changes: 3 additions & 2 deletions FirebaseVertexAI/Sources/VertexAI.swift
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,13 @@ public class VertexAI {
)
}

public func imagenModel(modelName: String, requestOptions: RequestOptions = RequestOptions())
-> ImagenModel {
public func imagenModel(modelName: String, safetySettings: ImagenSafetySettings? = nil,
requestOptions: RequestOptions = RequestOptions()) -> ImagenModel {
return ImagenModel(
name: modelResourceName(modelName: modelName),
projectID: projectID,
apiKey: apiKey,
safetySettings: safetySettings,
requestOptions: requestOptions,
appCheck: appCheck,
auth: auth
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,11 @@ final class IntegrationTests: XCTestCase {
systemInstruction: systemInstruction
)
imagenModel = vertex.imagenModel(
modelName: "imagen-3.0-fast-generate-001"
modelName: "imagen-3.0-fast-generate-001",
safetySettings: ImagenSafetySettings(
safetyFilterLevel: .blockLowAndAbove,
personGeneration: .blockAll
)
)

storage = Storage.storage()
Expand Down
Loading

0 comments on commit c986629

Please sign in to comment.