diff --git a/FirebaseVertexAI/Sources/Types/Internal/Imagen/ImageGenerationRequest.swift b/FirebaseVertexAI/Sources/Types/Internal/Imagen/ImagenGenerationRequest.swift similarity index 86% rename from FirebaseVertexAI/Sources/Types/Internal/Imagen/ImageGenerationRequest.swift rename to FirebaseVertexAI/Sources/Types/Internal/Imagen/ImagenGenerationRequest.swift index da972b4403b..e938dc36d83 100644 --- a/FirebaseVertexAI/Sources/Types/Internal/Imagen/ImageGenerationRequest.swift +++ b/FirebaseVertexAI/Sources/Types/Internal/Imagen/ImagenGenerationRequest.swift @@ -15,7 +15,7 @@ import Foundation @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) -struct ImageGenerationRequest { +struct ImagenGenerationRequest { let model: String let options: RequestOptions let instances: [ImageGenerationInstance] @@ -31,8 +31,8 @@ struct ImageGenerationRequest { } @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) -extension ImageGenerationRequest: GenerativeAIRequest where ImageType: Decodable { - typealias Response = ImageGenerationResponse +extension ImagenGenerationRequest: GenerativeAIRequest where ImageType: Decodable { + typealias Response = ImagenGenerationResponse var url: URL { return URL(string: "\(Constants.baseURL)/\(options.apiVersion)/\(model):predict")! @@ -40,7 +40,7 @@ extension ImageGenerationRequest: GenerativeAIRequest where ImageType: Decodable } @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) -extension ImageGenerationRequest: Encodable { +extension ImagenGenerationRequest: Encodable { enum CodingKeys: CodingKey { case instances case parameters diff --git a/FirebaseVertexAI/Sources/Types/Public/Imagen/ImageGenerationResponse.swift b/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenGenerationResponse.swift similarity index 74% rename from FirebaseVertexAI/Sources/Types/Public/Imagen/ImageGenerationResponse.swift rename to FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenGenerationResponse.swift index 92bfbb2f551..db5a49a3daa 100644 --- a/FirebaseVertexAI/Sources/Types/Public/Imagen/ImageGenerationResponse.swift +++ b/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenGenerationResponse.swift @@ -15,15 +15,15 @@ import Foundation @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) -public struct ImageGenerationResponse { - public let images: [ImageType] - public let raiFilteredReason: String? +public struct ImagenGenerationResponse { + public let images: [T] + public let filteredReason: String? } // MARK: - Codable Conformances @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) -extension ImageGenerationResponse: Decodable where ImageType: Decodable { +extension ImagenGenerationResponse: Decodable where T: Decodable { enum CodingKeys: CodingKey { case predictions } @@ -32,19 +32,19 @@ extension ImageGenerationResponse: Decodable where ImageType: Decodable { let container = try decoder.container(keyedBy: CodingKeys.self) guard container.contains(.predictions) else { images = [] - raiFilteredReason = nil + filteredReason = nil // TODO(#14221): Log warning if no predictions. return } var predictionsContainer = try container.nestedUnkeyedContainer(forKey: .predictions) - var images = [ImageType]() - var raiFilteredReasons = [String]() + var images = [T]() + var filteredReasons = [String]() while !predictionsContainer.isAtEnd { - if let image = try? predictionsContainer.decode(ImageType.self) { + if let image = try? predictionsContainer.decode(T.self) { images.append(image) - } else if let filterReason = try? predictionsContainer.decode(RAIFilteredReason.self) { - raiFilteredReasons.append(filterReason.raiFilteredReason) + } else if let filteredReason = try? predictionsContainer.decode(RAIFilteredReason.self) { + filteredReasons.append(filteredReason.raiFilteredReason) } else if let _ = try? predictionsContainer.decode(JSONObject.self) { // TODO(#14221): Log or throw unsupported prediction type } else { @@ -57,7 +57,7 @@ extension ImageGenerationResponse: Decodable where ImageType: Decodable { } self.images = images - raiFilteredReason = raiFilteredReasons.first + filteredReason = filteredReasons.first // TODO(#14221): Log if more than one RAI Filtered Reason; unexpected behaviour. } } diff --git a/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenModel.swift b/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenModel.swift index 5578afe110f..682a47237a6 100644 --- a/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenModel.swift +++ b/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenModel.swift @@ -55,7 +55,7 @@ public final class ImagenModel { public func generateImages(prompt: String, generationConfig: ImagenGenerationConfig? = nil) async throws - -> ImageGenerationResponse { + -> ImagenGenerationResponse { return try await generateImages( prompt: prompt, parameters: ImagenModel.imageGenerationParameters( @@ -69,7 +69,7 @@ public final class ImagenModel { public func generateImages(prompt: String, storageURI: String, generationConfig: ImagenGenerationConfig? = nil) async throws - -> ImageGenerationResponse { + -> ImagenGenerationResponse { return try await generateImages( prompt: prompt, parameters: ImagenModel.imageGenerationParameters( @@ -83,8 +83,8 @@ public final class ImagenModel { func generateImages(prompt: String, parameters: ImageGenerationParameters) async throws - -> ImageGenerationResponse { - let request = ImageGenerationRequest( + -> ImagenGenerationResponse { + let request = ImagenGenerationRequest( model: modelResourceName, options: requestOptions, instances: [ImageGenerationInstance(prompt: prompt)], diff --git a/FirebaseVertexAI/Tests/Unit/Types/Imagen/ImageGenerationRequestTests.swift b/FirebaseVertexAI/Tests/Unit/Types/Imagen/ImagenGenerationRequestTests.swift similarity index 92% rename from FirebaseVertexAI/Tests/Unit/Types/Imagen/ImageGenerationRequestTests.swift rename to FirebaseVertexAI/Tests/Unit/Types/Imagen/ImagenGenerationRequestTests.swift index 8adbc577d83..bc51120cf6c 100644 --- a/FirebaseVertexAI/Tests/Unit/Types/Imagen/ImageGenerationRequestTests.swift +++ b/FirebaseVertexAI/Tests/Unit/Types/Imagen/ImagenGenerationRequestTests.swift @@ -17,7 +17,7 @@ import XCTest @testable import FirebaseVertexAI @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) -final class ImageGenerationRequestTests: XCTestCase { +final class ImagenGenerationRequestTests: XCTestCase { let encoder = JSONEncoder() let requestOptions = RequestOptions(timeout: 30.0) let modelName = "test-model-name" @@ -44,7 +44,7 @@ final class ImageGenerationRequestTests: XCTestCase { } func testInitializeRequest_inlineDataImage() throws { - let request = ImageGenerationRequest( + let request = ImagenGenerationRequest( model: modelName, options: requestOptions, instances: [instance], @@ -62,7 +62,7 @@ final class ImageGenerationRequestTests: XCTestCase { } func testInitializeRequest_fileDataImage() throws { - let request = ImageGenerationRequest( + let request = ImagenGenerationRequest( model: modelName, options: requestOptions, instances: [instance], @@ -82,7 +82,7 @@ final class ImageGenerationRequestTests: XCTestCase { // MARK: - Encoding Tests func testEncodeRequest_inlineDataImage() throws { - let request = ImageGenerationRequest( + let request = ImagenGenerationRequest( model: modelName, options: RequestOptions(), instances: [instance], @@ -110,7 +110,7 @@ final class ImageGenerationRequestTests: XCTestCase { } func testEncodeRequest_fileDataImage() throws { - let request = ImageGenerationRequest( + let request = ImagenGenerationRequest( model: modelName, options: RequestOptions(), instances: [instance], diff --git a/FirebaseVertexAI/Tests/Unit/Types/Imagen/ImageGenerationResponseTests.swift b/FirebaseVertexAI/Tests/Unit/Types/Imagen/ImagenGenerationResponseTests.swift similarity index 86% rename from FirebaseVertexAI/Tests/Unit/Types/Imagen/ImageGenerationResponseTests.swift rename to FirebaseVertexAI/Tests/Unit/Types/Imagen/ImagenGenerationResponseTests.swift index 2bd7dcb88aa..9f703ee9b0f 100644 --- a/FirebaseVertexAI/Tests/Unit/Types/Imagen/ImageGenerationResponseTests.swift +++ b/FirebaseVertexAI/Tests/Unit/Types/Imagen/ImagenGenerationResponseTests.swift @@ -17,7 +17,7 @@ import XCTest @testable import FirebaseVertexAI @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) -final class ImageGenerationResponseTests: XCTestCase { +final class ImagenGenerationResponseTests: XCTestCase { let decoder = JSONDecoder() func testDecodeResponse_oneBase64Image_noneFiltered() throws { @@ -37,12 +37,12 @@ final class ImageGenerationResponseTests: XCTestCase { let jsonData = try XCTUnwrap(json.data(using: .utf8)) let response = try decoder.decode( - ImageGenerationResponse.self, + ImagenGenerationResponse.self, from: jsonData ) XCTAssertEqual(response.images, [image]) - XCTAssertNil(response.raiFilteredReason) + XCTAssertNil(response.filteredReason) } func testDecodeResponse_multipleBase64Images_noneFiltered() throws { @@ -74,12 +74,12 @@ final class ImageGenerationResponseTests: XCTestCase { let jsonData = try XCTUnwrap(json.data(using: .utf8)) let response = try decoder.decode( - ImageGenerationResponse.self, + ImagenGenerationResponse.self, from: jsonData ) XCTAssertEqual(response.images, [image1, image2, image3]) - XCTAssertNil(response.raiFilteredReason) + XCTAssertNil(response.filteredReason) } func testDecodeResponse_multipleBase64Images_someFiltered() throws { @@ -112,12 +112,12 @@ final class ImageGenerationResponseTests: XCTestCase { let jsonData = try XCTUnwrap(json.data(using: .utf8)) let response = try decoder.decode( - ImageGenerationResponse.self, + ImagenGenerationResponse.self, from: jsonData ) XCTAssertEqual(response.images, [image1, image2]) - XCTAssertEqual(response.raiFilteredReason, raiFilteredReason) + XCTAssertEqual(response.filteredReason, raiFilteredReason) } func testDecodeResponse_multipleGCSImages_noneFiltered() throws { @@ -143,12 +143,12 @@ final class ImageGenerationResponseTests: XCTestCase { let jsonData = try XCTUnwrap(json.data(using: .utf8)) let response = try decoder.decode( - ImageGenerationResponse.self, + ImagenGenerationResponse.self, from: jsonData ) XCTAssertEqual(response.images, [image1, image2]) - XCTAssertNil(response.raiFilteredReason) + XCTAssertNil(response.filteredReason) } func testDecodeResponse_noImages_allFiltered() throws { @@ -169,12 +169,12 @@ final class ImageGenerationResponseTests: XCTestCase { let jsonData = try XCTUnwrap(json.data(using: .utf8)) let response = try decoder.decode( - ImageGenerationResponse.self, + ImagenGenerationResponse.self, from: jsonData ) XCTAssertEqual(response.images, []) - XCTAssertEqual(response.raiFilteredReason, raiFilteredReason) + XCTAssertEqual(response.filteredReason, raiFilteredReason) } func testDecodeResponse_noImagesAnd_noFilteredReason() throws { @@ -182,12 +182,12 @@ final class ImageGenerationResponseTests: XCTestCase { let jsonData = try XCTUnwrap(json.data(using: .utf8)) let response = try decoder.decode( - ImageGenerationResponse.self, + ImagenGenerationResponse.self, from: jsonData ) XCTAssertEqual(response.images, []) - XCTAssertNil(response.raiFilteredReason) + XCTAssertNil(response.filteredReason) } func testDecodeResponse_multipleFilterReasons_returnsFirst() throws { @@ -208,13 +208,13 @@ final class ImageGenerationResponseTests: XCTestCase { let jsonData = try XCTUnwrap(json.data(using: .utf8)) let response = try decoder.decode( - ImageGenerationResponse.self, + ImagenGenerationResponse.self, from: jsonData ) XCTAssertEqual(response.images, []) - XCTAssertEqual(response.raiFilteredReason, raiFilteredReason1) - XCTAssertNotEqual(response.raiFilteredReason, raiFilteredReason2) + XCTAssertEqual(response.filteredReason, raiFilteredReason1) + XCTAssertNotEqual(response.filteredReason, raiFilteredReason2) } func testDecodeResponse_unknownPrediction() throws { @@ -230,11 +230,11 @@ final class ImageGenerationResponseTests: XCTestCase { let jsonData = try XCTUnwrap(json.data(using: .utf8)) let response = try decoder.decode( - ImageGenerationResponse.self, + ImagenGenerationResponse.self, from: jsonData ) XCTAssertEqual(response.images, []) - XCTAssertNil(response.raiFilteredReason) + XCTAssertNil(response.filteredReason) } }