Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Vertex AI] Rename ImageGenerationResponse to ImagenGenerationResponse #14317

Merged
merged 2 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import Foundation

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
struct ImageGenerationRequest<ImageType: ImagenImageRepresentable> {
struct ImagenGenerationRequest<ImageType: ImagenImageRepresentable> {
let model: String
let options: RequestOptions
let instances: [ImageGenerationInstance]
Expand All @@ -31,16 +31,16 @@ struct ImageGenerationRequest<ImageType: ImagenImageRepresentable> {
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension ImageGenerationRequest: GenerativeAIRequest where ImageType: Decodable {
typealias Response = ImageGenerationResponse<ImageType>
extension ImagenGenerationRequest: GenerativeAIRequest where ImageType: Decodable {
typealias Response = ImagenGenerationResponse<ImageType>

var url: URL {
return URL(string: "\(Constants.baseURL)/\(options.apiVersion)/\(model):predict")!
}
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension ImageGenerationRequest: Encodable {
extension ImagenGenerationRequest: Encodable {
enum CodingKeys: CodingKey {
case instances
case parameters
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@
import Foundation

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public struct ImageGenerationResponse<ImageType: ImagenImageRepresentable> {
public let images: [ImageType]
public let raiFilteredReason: String?
public struct ImagenGenerationResponse<T: ImagenImageRepresentable> {
public let images: [T]
public let filteredReason: String?
}

// MARK: - Codable Conformances

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension ImageGenerationResponse: Decodable where ImageType: Decodable {
extension ImagenGenerationResponse: Decodable where T: Decodable {
enum CodingKeys: CodingKey {
case predictions
}
Expand All @@ -32,19 +32,19 @@ extension ImageGenerationResponse: Decodable where ImageType: Decodable {
let container = try decoder.container(keyedBy: CodingKeys.self)
guard container.contains(.predictions) else {
images = []
raiFilteredReason = nil
filteredReason = nil
// TODO(#14221): Log warning if no predictions.
return
}
var predictionsContainer = try container.nestedUnkeyedContainer(forKey: .predictions)

var images = [ImageType]()
var raiFilteredReasons = [String]()
var images = [T]()
var filteredReasons = [String]()
while !predictionsContainer.isAtEnd {
if let image = try? predictionsContainer.decode(ImageType.self) {
if let image = try? predictionsContainer.decode(T.self) {
images.append(image)
} else if let filterReason = try? predictionsContainer.decode(RAIFilteredReason.self) {
raiFilteredReasons.append(filterReason.raiFilteredReason)
} else if let filteredReason = try? predictionsContainer.decode(RAIFilteredReason.self) {
filteredReasons.append(filteredReason.raiFilteredReason)
} else if let _ = try? predictionsContainer.decode(JSONObject.self) {
// TODO(#14221): Log or throw unsupported prediction type
} else {
Expand All @@ -57,7 +57,7 @@ extension ImageGenerationResponse: Decodable where ImageType: Decodable {
}

self.images = images
raiFilteredReason = raiFilteredReasons.first
filteredReason = filteredReasons.first
// TODO(#14221): Log if more than one RAI Filtered Reason; unexpected behaviour.
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public final class ImagenModel {

public func generateImages(prompt: String,
generationConfig: ImagenGenerationConfig? = nil) async throws
-> ImageGenerationResponse<ImagenInlineDataImage> {
-> ImagenGenerationResponse<ImagenInlineDataImage> {
return try await generateImages(
prompt: prompt,
parameters: ImagenModel.imageGenerationParameters(
Expand All @@ -69,7 +69,7 @@ public final class ImagenModel {

public func generateImages(prompt: String, storageURI: String,
generationConfig: ImagenGenerationConfig? = nil) async throws
-> ImageGenerationResponse<ImagenFileDataImage> {
-> ImagenGenerationResponse<ImagenFileDataImage> {
return try await generateImages(
prompt: prompt,
parameters: ImagenModel.imageGenerationParameters(
Expand All @@ -83,8 +83,8 @@ public final class ImagenModel {

func generateImages<T: Decodable>(prompt: String,
parameters: ImageGenerationParameters) async throws
-> ImageGenerationResponse<T> {
let request = ImageGenerationRequest<T>(
-> ImagenGenerationResponse<T> {
let request = ImagenGenerationRequest<T>(
model: modelResourceName,
options: requestOptions,
instances: [ImageGenerationInstance(prompt: prompt)],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ final class IntegrationTests: XCTestCase {
generationConfig: generationConfig
)

XCTAssertNil(response.raiFilteredReason)
XCTAssertNil(response.filteredReason)
XCTAssertEqual(response.images.count, 1)
let image = try XCTUnwrap(response.images.first)
XCTAssertEqual(image.mimeType, "image/jpeg")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import XCTest
@testable import FirebaseVertexAI

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
final class ImageGenerationRequestTests: XCTestCase {
final class ImagenGenerationRequestTests: XCTestCase {
let encoder = JSONEncoder()
let requestOptions = RequestOptions(timeout: 30.0)
let modelName = "test-model-name"
Expand All @@ -44,7 +44,7 @@ final class ImageGenerationRequestTests: XCTestCase {
}

func testInitializeRequest_inlineDataImage() throws {
let request = ImageGenerationRequest<ImagenInlineDataImage>(
let request = ImagenGenerationRequest<ImagenInlineDataImage>(
model: modelName,
options: requestOptions,
instances: [instance],
Expand All @@ -62,7 +62,7 @@ final class ImageGenerationRequestTests: XCTestCase {
}

func testInitializeRequest_fileDataImage() throws {
let request = ImageGenerationRequest<ImagenFileDataImage>(
let request = ImagenGenerationRequest<ImagenFileDataImage>(
model: modelName,
options: requestOptions,
instances: [instance],
Expand All @@ -82,7 +82,7 @@ final class ImageGenerationRequestTests: XCTestCase {
// MARK: - Encoding Tests

func testEncodeRequest_inlineDataImage() throws {
let request = ImageGenerationRequest<ImagenInlineDataImage>(
let request = ImagenGenerationRequest<ImagenInlineDataImage>(
model: modelName,
options: RequestOptions(),
instances: [instance],
Expand Down Expand Up @@ -110,7 +110,7 @@ final class ImageGenerationRequestTests: XCTestCase {
}

func testEncodeRequest_fileDataImage() throws {
let request = ImageGenerationRequest<ImagenFileDataImage>(
let request = ImagenGenerationRequest<ImagenFileDataImage>(
model: modelName,
options: RequestOptions(),
instances: [instance],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import XCTest
@testable import FirebaseVertexAI

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
final class ImageGenerationResponseTests: XCTestCase {
final class ImagenGenerationResponseTests: XCTestCase {
let decoder = JSONDecoder()

func testDecodeResponse_oneBase64Image_noneFiltered() throws {
Expand All @@ -37,12 +37,12 @@ final class ImageGenerationResponseTests: XCTestCase {
let jsonData = try XCTUnwrap(json.data(using: .utf8))

let response = try decoder.decode(
ImageGenerationResponse<ImagenInlineDataImage>.self,
ImagenGenerationResponse<ImagenInlineDataImage>.self,
from: jsonData
)

XCTAssertEqual(response.images, [image])
XCTAssertNil(response.raiFilteredReason)
XCTAssertNil(response.filteredReason)
}

func testDecodeResponse_multipleBase64Images_noneFiltered() throws {
Expand Down Expand Up @@ -74,12 +74,12 @@ final class ImageGenerationResponseTests: XCTestCase {
let jsonData = try XCTUnwrap(json.data(using: .utf8))

let response = try decoder.decode(
ImageGenerationResponse<ImagenInlineDataImage>.self,
ImagenGenerationResponse<ImagenInlineDataImage>.self,
from: jsonData
)

XCTAssertEqual(response.images, [image1, image2, image3])
XCTAssertNil(response.raiFilteredReason)
XCTAssertNil(response.filteredReason)
}

func testDecodeResponse_multipleBase64Images_someFiltered() throws {
Expand Down Expand Up @@ -112,12 +112,12 @@ final class ImageGenerationResponseTests: XCTestCase {
let jsonData = try XCTUnwrap(json.data(using: .utf8))

let response = try decoder.decode(
ImageGenerationResponse<ImagenInlineDataImage>.self,
ImagenGenerationResponse<ImagenInlineDataImage>.self,
from: jsonData
)

XCTAssertEqual(response.images, [image1, image2])
XCTAssertEqual(response.raiFilteredReason, raiFilteredReason)
XCTAssertEqual(response.filteredReason, raiFilteredReason)
}

func testDecodeResponse_multipleGCSImages_noneFiltered() throws {
Expand All @@ -143,12 +143,12 @@ final class ImageGenerationResponseTests: XCTestCase {
let jsonData = try XCTUnwrap(json.data(using: .utf8))

let response = try decoder.decode(
ImageGenerationResponse<ImagenFileDataImage>.self,
ImagenGenerationResponse<ImagenFileDataImage>.self,
from: jsonData
)

XCTAssertEqual(response.images, [image1, image2])
XCTAssertNil(response.raiFilteredReason)
XCTAssertNil(response.filteredReason)
}

func testDecodeResponse_noImages_allFiltered() throws {
Expand All @@ -169,25 +169,25 @@ final class ImageGenerationResponseTests: XCTestCase {
let jsonData = try XCTUnwrap(json.data(using: .utf8))

let response = try decoder.decode(
ImageGenerationResponse<ImagenInlineDataImage>.self,
ImagenGenerationResponse<ImagenInlineDataImage>.self,
from: jsonData
)

XCTAssertEqual(response.images, [])
XCTAssertEqual(response.raiFilteredReason, raiFilteredReason)
XCTAssertEqual(response.filteredReason, raiFilteredReason)
}

func testDecodeResponse_noImagesAnd_noFilteredReason() throws {
let json = "{}"
let jsonData = try XCTUnwrap(json.data(using: .utf8))

let response = try decoder.decode(
ImageGenerationResponse<ImagenInlineDataImage>.self,
ImagenGenerationResponse<ImagenInlineDataImage>.self,
from: jsonData
)

XCTAssertEqual(response.images, [])
XCTAssertNil(response.raiFilteredReason)
XCTAssertNil(response.filteredReason)
}

func testDecodeResponse_multipleFilterReasons_returnsFirst() throws {
Expand All @@ -208,13 +208,13 @@ final class ImageGenerationResponseTests: XCTestCase {
let jsonData = try XCTUnwrap(json.data(using: .utf8))

let response = try decoder.decode(
ImageGenerationResponse<ImagenFileDataImage>.self,
ImagenGenerationResponse<ImagenFileDataImage>.self,
from: jsonData
)

XCTAssertEqual(response.images, [])
XCTAssertEqual(response.raiFilteredReason, raiFilteredReason1)
XCTAssertNotEqual(response.raiFilteredReason, raiFilteredReason2)
XCTAssertEqual(response.filteredReason, raiFilteredReason1)
XCTAssertNotEqual(response.filteredReason, raiFilteredReason2)
}

func testDecodeResponse_unknownPrediction() throws {
Expand All @@ -230,11 +230,11 @@ final class ImageGenerationResponseTests: XCTestCase {
let jsonData = try XCTUnwrap(json.data(using: .utf8))

let response = try decoder.decode(
ImageGenerationResponse<ImagenInlineDataImage>.self,
ImagenGenerationResponse<ImagenInlineDataImage>.self,
from: jsonData
)

XCTAssertEqual(response.images, [])
XCTAssertNil(response.raiFilteredReason)
XCTAssertNil(response.filteredReason)
}
}
Loading