Skip to content

Commit

Permalink
[Vertex AI] Add ImageGenerationParameters for input to predict call
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard committed Dec 3, 2024
1 parent 37e2390 commit 8fddb83
Show file tree
Hide file tree
Showing 4 changed files with 275 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import Foundation

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
struct ImageGenerationOutputOptions {
let mimeType: String
let compressionQuality: Int?
}

// MARK: - Codable Conformance

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension ImageGenerationOutputOptions: Encodable {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
struct ImageGenerationParameters {
let sampleCount: Int?
let storageURI: String?
let seed: Int32?
let negativePrompt: String?
let aspectRatio: String?
let safetyFilterLevel: String?
let personGeneration: String?
let outputOptions: ImageGenerationOutputOptions?
let addWatermark: Bool?
let includeResponsibleAIFilterReason: Bool?
}

// MARK: - Codable Conformance

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension ImageGenerationParameters: Encodable {
enum CodingKeys: String, CodingKey {
case sampleCount
case storageURI = "storageUri"
case seed
case negativePrompt
case aspectRatio
case safetyFilterLevel = "safetySetting"
case personGeneration
case outputOptions
case addWatermark
case includeResponsibleAIFilterReason = "includeRaiReason"
}

func encode(to encoder: any Encoder) throws {
var container = encoder.container(keyedBy: CodingKeys.self)
try container.encodeIfPresent(sampleCount, forKey: .sampleCount)
try container.encodeIfPresent(storageURI, forKey: .storageURI)
try container.encodeIfPresent(seed, forKey: .seed)
try container.encodeIfPresent(negativePrompt, forKey: .negativePrompt)
try container.encodeIfPresent(aspectRatio, forKey: .aspectRatio)
try container.encodeIfPresent(safetyFilterLevel, forKey: .safetyFilterLevel)
try container.encodeIfPresent(personGeneration, forKey: .personGeneration)
try container.encodeIfPresent(outputOptions, forKey: .outputOptions)
try container.encodeIfPresent(addWatermark, forKey: .addWatermark)
try container.encodeIfPresent(
includeResponsibleAIFilterReason,
forKey: .includeResponsibleAIFilterReason
)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import XCTest

@testable import FirebaseVertexAI

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
final class ImageGenerationOutputOptionsTests: XCTestCase {
let encoder = JSONEncoder()

override func setUp() {
encoder.outputFormatting = [.sortedKeys, .prettyPrinted, .withoutEscapingSlashes]
}

// MARK: - Encoding Tests

func testEncodeOutputOptions_jpeg_defaultCompressionQuality() throws {
let mimeType = "image/jpeg"
let options = ImageGenerationOutputOptions(mimeType: mimeType, compressionQuality: nil)

let jsonData = try encoder.encode(options)

let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
XCTAssertEqual(json, """
{
"mimeType" : "\(mimeType)"
}
""")
}

func testEncodeOutputOptions_jpeg_customCompressionQuality() throws {
let mimeType = "image/jpeg"
let quality = 50
let options = ImageGenerationOutputOptions(mimeType: mimeType, compressionQuality: quality)

let jsonData = try encoder.encode(options)

let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
XCTAssertEqual(json, """
{
"compressionQuality" : \(quality),
"mimeType" : "\(mimeType)"
}
""")
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import XCTest

@testable import FirebaseVertexAI

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
final class ImageGenerationParametersTests: XCTestCase {
let encoder = JSONEncoder()

override func setUp() {
encoder.outputFormatting = [.sortedKeys, .prettyPrinted, .withoutEscapingSlashes]
}

// MARK: - Encoding Tests

func testEncodeParameters_allSpecified() throws {
let sampleCount = 4
let storageURI = "gs://bucket/folder"
let seed: Int32 = 1_076_107_968
let negativePrompt = "test-negative-prompt"
let aspectRatio = "16:9"
let safetyFilterLevel = "block_low_and_above"
let personGeneration = "allow_adult"
let mimeType = "image/png"
let outputOptions = ImageGenerationOutputOptions(mimeType: mimeType, compressionQuality: nil)
let addWatermark = false
let includeRAIReason = true
let parameters = ImageGenerationParameters(
sampleCount: sampleCount,
storageURI: storageURI,
seed: seed,
negativePrompt: negativePrompt,
aspectRatio: aspectRatio,
safetyFilterLevel: safetyFilterLevel,
personGeneration: personGeneration,
outputOptions: outputOptions,
addWatermark: addWatermark,
includeResponsibleAIFilterReason: includeRAIReason
)

let jsonData = try encoder.encode(parameters)

let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
XCTAssertEqual(json, """
{
"addWatermark" : \(addWatermark),
"aspectRatio" : "\(aspectRatio)",
"includeRaiReason" : \(includeRAIReason),
"negativePrompt" : "\(negativePrompt)",
"outputOptions" : {
"mimeType" : "\(mimeType)"
},
"personGeneration" : "\(personGeneration)",
"safetySetting" : "\(safetyFilterLevel)",
"sampleCount" : \(sampleCount),
"seed" : \(seed),
"storageUri" : "\(storageURI)"
}
""")
}

func testEncodeParameters_someSpecified() throws {
let sampleCount = 2
let aspectRatio = "3:4"
let safetyFilterLevel = "block_medium_and_above"
let addWatermark = true
let parameters = ImageGenerationParameters(
sampleCount: sampleCount,
storageURI: nil,
seed: nil,
negativePrompt: nil,
aspectRatio: aspectRatio,
safetyFilterLevel: safetyFilterLevel,
personGeneration: nil,
outputOptions: nil,
addWatermark: addWatermark,
includeResponsibleAIFilterReason: nil
)

let jsonData = try encoder.encode(parameters)

let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
XCTAssertEqual(json, """
{
"addWatermark" : \(addWatermark),
"aspectRatio" : "\(aspectRatio)",
"safetySetting" : "\(safetyFilterLevel)",
"sampleCount" : \(sampleCount)
}
""")
}

func testEncodeParameters_noneSpecified() throws {
let parameters = ImageGenerationParameters(
sampleCount: nil,
storageURI: nil,
seed: nil,
negativePrompt: nil,
aspectRatio: nil,
safetyFilterLevel: nil,
personGeneration: nil,
outputOptions: nil,
addWatermark: nil,
includeResponsibleAIFilterReason: nil
)

let jsonData = try encoder.encode(parameters)

let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
XCTAssertEqual(json, """
{
}
""")
}
}

0 comments on commit 8fddb83

Please sign in to comment.