Skip to content

Commit

Permalink
[Vertex AI] Rename ImageGenerationResponse to ImagenGenerationResponse
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard committed Jan 7, 2025
1 parent 527d4af commit 6680858
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import Foundation

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
struct ImageGenerationRequest<ImageType: ImagenImageRepresentable> {
struct ImagenGenerationRequest<ImageType: ImagenImageRepresentable> {
let model: String
let options: RequestOptions
let instances: [ImageGenerationInstance]
Expand All @@ -31,16 +31,16 @@ struct ImageGenerationRequest<ImageType: ImagenImageRepresentable> {
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension ImageGenerationRequest: GenerativeAIRequest where ImageType: Decodable {
typealias Response = ImageGenerationResponse<ImageType>
extension ImagenGenerationRequest: GenerativeAIRequest where ImageType: Decodable {
typealias Response = ImagenGenerationResponse<ImageType>

var url: URL {
return URL(string: "\(Constants.baseURL)/\(options.apiVersion)/\(model):predict")!
}
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension ImageGenerationRequest: Encodable {
extension ImagenGenerationRequest: Encodable {
enum CodingKeys: CodingKey {
case instances
case parameters
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@
import Foundation

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public struct ImageGenerationResponse<ImageType: ImagenImageRepresentable> {
public let images: [ImageType]
public let raiFilteredReason: String?
public struct ImagenGenerationResponse<T: ImagenImageRepresentable> {
public let images: [T]
public let filteredReason: String?
}

// MARK: - Codable Conformances

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension ImageGenerationResponse: Decodable where ImageType: Decodable {
extension ImagenGenerationResponse: Decodable where T: Decodable {
enum CodingKeys: CodingKey {
case predictions
}
Expand All @@ -32,19 +32,19 @@ extension ImageGenerationResponse: Decodable where ImageType: Decodable {
let container = try decoder.container(keyedBy: CodingKeys.self)
guard container.contains(.predictions) else {
images = []
raiFilteredReason = nil
filteredReason = nil
// TODO(#14221): Log warning if no predictions.
return
}
var predictionsContainer = try container.nestedUnkeyedContainer(forKey: .predictions)

var images = [ImageType]()
var raiFilteredReasons = [String]()
var images = [T]()
var filteredReasons = [String]()
while !predictionsContainer.isAtEnd {
if let image = try? predictionsContainer.decode(ImageType.self) {
if let image = try? predictionsContainer.decode(T.self) {
images.append(image)
} else if let filterReason = try? predictionsContainer.decode(RAIFilteredReason.self) {
raiFilteredReasons.append(filterReason.raiFilteredReason)
} else if let filteredReason = try? predictionsContainer.decode(RAIFilteredReason.self) {
filteredReasons.append(filteredReason.raiFilteredReason)
} else if let _ = try? predictionsContainer.decode(JSONObject.self) {
// TODO(#14221): Log or throw unsupported prediction type
} else {
Expand All @@ -57,7 +57,7 @@ extension ImageGenerationResponse: Decodable where ImageType: Decodable {
}

self.images = images
raiFilteredReason = raiFilteredReasons.first
filteredReason = filteredReasons.first
// TODO(#14221): Log if more than one RAI Filtered Reason; unexpected behaviour.
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public final class ImagenModel {

public func generateImages(prompt: String,
generationConfig: ImagenGenerationConfig? = nil) async throws
-> ImageGenerationResponse<ImagenInlineDataImage> {
-> ImagenGenerationResponse<ImagenInlineDataImage> {
return try await generateImages(
prompt: prompt,
parameters: ImagenModel.imageGenerationParameters(
Expand All @@ -69,7 +69,7 @@ public final class ImagenModel {

public func generateImages(prompt: String, storageURI: String,
generationConfig: ImagenGenerationConfig? = nil) async throws
-> ImageGenerationResponse<ImagenFileDataImage> {
-> ImagenGenerationResponse<ImagenFileDataImage> {
return try await generateImages(
prompt: prompt,
parameters: ImagenModel.imageGenerationParameters(
Expand All @@ -83,8 +83,8 @@ public final class ImagenModel {

func generateImages<T: Decodable>(prompt: String,
parameters: ImageGenerationParameters) async throws
-> ImageGenerationResponse<T> {
let request = ImageGenerationRequest<T>(
-> ImagenGenerationResponse<T> {
let request = ImagenGenerationRequest<T>(
model: modelResourceName,
options: requestOptions,
instances: [ImageGenerationInstance(prompt: prompt)],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import XCTest
@testable import FirebaseVertexAI

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
final class ImageGenerationRequestTests: XCTestCase {
final class ImagenGenerationRequestTests: XCTestCase {
let encoder = JSONEncoder()
let requestOptions = RequestOptions(timeout: 30.0)
let modelName = "test-model-name"
Expand All @@ -44,7 +44,7 @@ final class ImageGenerationRequestTests: XCTestCase {
}

func testInitializeRequest_inlineDataImage() throws {
let request = ImageGenerationRequest<ImagenInlineDataImage>(
let request = ImagenGenerationRequest<ImagenInlineDataImage>(
model: modelName,
options: requestOptions,
instances: [instance],
Expand All @@ -62,7 +62,7 @@ final class ImageGenerationRequestTests: XCTestCase {
}

func testInitializeRequest_fileDataImage() throws {
let request = ImageGenerationRequest<ImagenFileDataImage>(
let request = ImagenGenerationRequest<ImagenFileDataImage>(
model: modelName,
options: requestOptions,
instances: [instance],
Expand All @@ -82,7 +82,7 @@ final class ImageGenerationRequestTests: XCTestCase {
// MARK: - Encoding Tests

func testEncodeRequest_inlineDataImage() throws {
let request = ImageGenerationRequest<ImagenInlineDataImage>(
let request = ImagenGenerationRequest<ImagenInlineDataImage>(
model: modelName,
options: RequestOptions(),
instances: [instance],
Expand Down Expand Up @@ -110,7 +110,7 @@ final class ImageGenerationRequestTests: XCTestCase {
}

func testEncodeRequest_fileDataImage() throws {
let request = ImageGenerationRequest<ImagenFileDataImage>(
let request = ImagenGenerationRequest<ImagenFileDataImage>(
model: modelName,
options: RequestOptions(),
instances: [instance],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import XCTest
@testable import FirebaseVertexAI

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
final class ImageGenerationResponseTests: XCTestCase {
final class ImagenGenerationResponseTests: XCTestCase {
let decoder = JSONDecoder()

func testDecodeResponse_oneBase64Image_noneFiltered() throws {
Expand All @@ -37,12 +37,12 @@ final class ImageGenerationResponseTests: XCTestCase {
let jsonData = try XCTUnwrap(json.data(using: .utf8))

let response = try decoder.decode(
ImageGenerationResponse<ImagenInlineDataImage>.self,
ImagenGenerationResponse<ImagenInlineDataImage>.self,
from: jsonData
)

XCTAssertEqual(response.images, [image])
XCTAssertNil(response.raiFilteredReason)
XCTAssertNil(response.filteredReason)
}

func testDecodeResponse_multipleBase64Images_noneFiltered() throws {
Expand Down Expand Up @@ -74,12 +74,12 @@ final class ImageGenerationResponseTests: XCTestCase {
let jsonData = try XCTUnwrap(json.data(using: .utf8))

let response = try decoder.decode(
ImageGenerationResponse<ImagenInlineDataImage>.self,
ImagenGenerationResponse<ImagenInlineDataImage>.self,
from: jsonData
)

XCTAssertEqual(response.images, [image1, image2, image3])
XCTAssertNil(response.raiFilteredReason)
XCTAssertNil(response.filteredReason)
}

func testDecodeResponse_multipleBase64Images_someFiltered() throws {
Expand Down Expand Up @@ -112,12 +112,12 @@ final class ImageGenerationResponseTests: XCTestCase {
let jsonData = try XCTUnwrap(json.data(using: .utf8))

let response = try decoder.decode(
ImageGenerationResponse<ImagenInlineDataImage>.self,
ImagenGenerationResponse<ImagenInlineDataImage>.self,
from: jsonData
)

XCTAssertEqual(response.images, [image1, image2])
XCTAssertEqual(response.raiFilteredReason, raiFilteredReason)
XCTAssertEqual(response.filteredReason, raiFilteredReason)
}

func testDecodeResponse_multipleGCSImages_noneFiltered() throws {
Expand All @@ -143,12 +143,12 @@ final class ImageGenerationResponseTests: XCTestCase {
let jsonData = try XCTUnwrap(json.data(using: .utf8))

let response = try decoder.decode(
ImageGenerationResponse<ImagenFileDataImage>.self,
ImagenGenerationResponse<ImagenFileDataImage>.self,
from: jsonData
)

XCTAssertEqual(response.images, [image1, image2])
XCTAssertNil(response.raiFilteredReason)
XCTAssertNil(response.filteredReason)
}

func testDecodeResponse_noImages_allFiltered() throws {
Expand All @@ -169,25 +169,25 @@ final class ImageGenerationResponseTests: XCTestCase {
let jsonData = try XCTUnwrap(json.data(using: .utf8))

let response = try decoder.decode(
ImageGenerationResponse<ImagenInlineDataImage>.self,
ImagenGenerationResponse<ImagenInlineDataImage>.self,
from: jsonData
)

XCTAssertEqual(response.images, [])
XCTAssertEqual(response.raiFilteredReason, raiFilteredReason)
XCTAssertEqual(response.filteredReason, raiFilteredReason)
}

func testDecodeResponse_noImagesAnd_noFilteredReason() throws {
let json = "{}"
let jsonData = try XCTUnwrap(json.data(using: .utf8))

let response = try decoder.decode(
ImageGenerationResponse<ImagenInlineDataImage>.self,
ImagenGenerationResponse<ImagenInlineDataImage>.self,
from: jsonData
)

XCTAssertEqual(response.images, [])
XCTAssertNil(response.raiFilteredReason)
XCTAssertNil(response.filteredReason)
}

func testDecodeResponse_multipleFilterReasons_returnsFirst() throws {
Expand All @@ -208,13 +208,13 @@ final class ImageGenerationResponseTests: XCTestCase {
let jsonData = try XCTUnwrap(json.data(using: .utf8))

let response = try decoder.decode(
ImageGenerationResponse<ImagenFileDataImage>.self,
ImagenGenerationResponse<ImagenFileDataImage>.self,
from: jsonData
)

XCTAssertEqual(response.images, [])
XCTAssertEqual(response.raiFilteredReason, raiFilteredReason1)
XCTAssertNotEqual(response.raiFilteredReason, raiFilteredReason2)
XCTAssertEqual(response.filteredReason, raiFilteredReason1)
XCTAssertNotEqual(response.filteredReason, raiFilteredReason2)
}

func testDecodeResponse_unknownPrediction() throws {
Expand All @@ -230,11 +230,11 @@ final class ImageGenerationResponseTests: XCTestCase {
let jsonData = try XCTUnwrap(json.data(using: .utf8))

let response = try decoder.decode(
ImageGenerationResponse<ImagenInlineDataImage>.self,
ImagenGenerationResponse<ImagenInlineDataImage>.self,
from: jsonData
)

XCTAssertEqual(response.images, [])
XCTAssertNil(response.raiFilteredReason)
XCTAssertNil(response.filteredReason)
}
}

0 comments on commit 6680858

Please sign in to comment.