From 9d0a8d870145e3d760ee592eda9d92cbfe776aab Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Fri, 6 Dec 2024 16:33:20 -0500 Subject: [PATCH] [Vertex AI] Make `ImageGenerationResponse` generic and add image types (#14224) --- .../Imagen/DecodableImagenImage.swift | 62 ----------- .../Imagen/ImageGenerationResponse.swift | 16 +-- .../Internal/Imagen/InternalImagenImage.swift | 5 +- .../Public/Imagen/ImagenFileDataImage.swift | 53 +++++++++ .../Public/Imagen/ImagenInlineDataImage.swift | 61 +++++++++++ .../Imagen/ImageGenerationResponseTests.swift | 96 ++++++++--------- .../Imagen/ImagenFileDataImageTests.swift | 80 ++++++++++++++ .../Imagen/ImagenInlineDataImageTests.swift | 80 ++++++++++++++ .../Imagen/InternalImagenImageTests.swift | 101 ------------------ 9 files changed, 329 insertions(+), 225 deletions(-) delete mode 100644 FirebaseVertexAI/Sources/Types/Internal/Imagen/DecodableImagenImage.swift create mode 100644 FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenFileDataImage.swift create mode 100644 FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenInlineDataImage.swift create mode 100644 FirebaseVertexAI/Tests/Unit/Types/Imagen/ImagenFileDataImageTests.swift create mode 100644 FirebaseVertexAI/Tests/Unit/Types/Imagen/ImagenInlineDataImageTests.swift delete mode 100644 FirebaseVertexAI/Tests/Unit/Types/Imagen/InternalImagenImageTests.swift diff --git a/FirebaseVertexAI/Sources/Types/Internal/Imagen/DecodableImagenImage.swift b/FirebaseVertexAI/Sources/Types/Internal/Imagen/DecodableImagenImage.swift deleted file mode 100644 index 2b4cc888830..00000000000 --- a/FirebaseVertexAI/Sources/Types/Internal/Imagen/DecodableImagenImage.swift +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright 2024 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) -protocol DecodableImagenImage: ImagenImage, Decodable { - init(mimeType: String, bytesBase64Encoded: String?, gcsURI: String?) -} - -@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) -enum ImagenImageCodingKeys: String, CodingKey { - case mimeType - case bytesBase64Encoded - case gcsURI = "gcsUri" -} - -@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) -extension DecodableImagenImage { - init(from decoder: any Decoder) throws { - let container = try decoder.container(keyedBy: ImagenImageCodingKeys.self) - let mimeType = try container.decode(String.self, forKey: .mimeType) - let bytesBase64Encoded = try container.decodeIfPresent( - String.self, - forKey: .bytesBase64Encoded - ) - let gcsURI = try container.decodeIfPresent(String.self, forKey: .gcsURI) - guard bytesBase64Encoded != nil || gcsURI != nil else { - throw DecodingError.dataCorrupted( - DecodingError.Context( - codingPath: [ImagenImageCodingKeys.bytesBase64Encoded, ImagenImageCodingKeys.gcsURI], - debugDescription: """ - Expected one of \(ImagenImageCodingKeys.bytesBase64Encoded.rawValue) or \ - \(ImagenImageCodingKeys.gcsURI.rawValue); both are nil. - """ - ) - ) - } - guard bytesBase64Encoded == nil || gcsURI == nil else { - throw DecodingError.dataCorrupted( - DecodingError.Context( - codingPath: [ImagenImageCodingKeys.bytesBase64Encoded, ImagenImageCodingKeys.gcsURI], - debugDescription: """ - Expected one of \(ImagenImageCodingKeys.bytesBase64Encoded.rawValue) or \ - \(ImagenImageCodingKeys.gcsURI.rawValue); both are specified. - """ - ) - ) - } - - self.init(mimeType: mimeType, bytesBase64Encoded: bytesBase64Encoded, gcsURI: gcsURI) - } -} diff --git a/FirebaseVertexAI/Sources/Types/Internal/Imagen/ImageGenerationResponse.swift b/FirebaseVertexAI/Sources/Types/Internal/Imagen/ImageGenerationResponse.swift index bb1b2809bd1..6cf0cce9111 100644 --- a/FirebaseVertexAI/Sources/Types/Internal/Imagen/ImageGenerationResponse.swift +++ b/FirebaseVertexAI/Sources/Types/Internal/Imagen/ImageGenerationResponse.swift @@ -15,15 +15,15 @@ import Foundation @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) -struct ImageGenerationResponse { - let images: [InternalImagenImage] - let raiFilteredReason: String? +public struct ImageGenerationResponse { + public let images: [ImageType] + public let raiFilteredReason: String? } // MARK: - Codable Conformances @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) -extension ImageGenerationResponse: Decodable { +extension ImageGenerationResponse: Decodable where ImageType: Decodable { enum CodingKeys: CodingKey { case predictions } @@ -38,15 +38,15 @@ extension ImageGenerationResponse: Decodable { } var predictionsContainer = try container.nestedUnkeyedContainer(forKey: .predictions) - var images = [InternalImagenImage]() + var images = [ImageType]() var raiFilteredReasons = [String]() while !predictionsContainer.isAtEnd { - if let image = try? predictionsContainer.decode(InternalImagenImage.self) { + if let image = try? predictionsContainer.decode(ImageType.self) { images.append(image) } else if let filterReason = try? predictionsContainer.decode(RAIFilteredReason.self) { raiFilteredReasons.append(filterReason.raiFilteredReason) } else if let _ = try? predictionsContainer.decode(JSONObject.self) { - // TODO: Log or throw unsupported prediction type + // TODO(#14221): Log or throw unsupported prediction type } else { // This should never be thrown since JSONObject accepts any valid JSON. throw DecodingError.dataCorruptedError( @@ -58,6 +58,6 @@ extension ImageGenerationResponse: Decodable { self.images = images raiFilteredReason = raiFilteredReasons.first - // TODO: Log if more than one RAI Filtered Reason; unexpected behaviour. + // TODO(#14221): Log if more than one RAI Filtered Reason; unexpected behaviour. } } diff --git a/FirebaseVertexAI/Sources/Types/Internal/Imagen/InternalImagenImage.swift b/FirebaseVertexAI/Sources/Types/Internal/Imagen/InternalImagenImage.swift index a9f175b9241..16296245f95 100644 --- a/FirebaseVertexAI/Sources/Types/Internal/Imagen/InternalImagenImage.swift +++ b/FirebaseVertexAI/Sources/Types/Internal/Imagen/InternalImagenImage.swift @@ -26,7 +26,4 @@ struct InternalImagenImage { } @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) -extension InternalImagenImage: DecodableImagenImage {} - -@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) -extension InternalImagenImage: Equatable {} +extension InternalImagenImage: ImagenImage {} diff --git a/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenFileDataImage.swift b/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenFileDataImage.swift new file mode 100644 index 00000000000..48b94504293 --- /dev/null +++ b/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenFileDataImage.swift @@ -0,0 +1,53 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Foundation + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +public struct ImagenFileDataImage { + public let mimeType: String + public let gcsURI: String + + init(mimeType: String, gcsURI: String) { + self.mimeType = mimeType + self.gcsURI = gcsURI + } +} + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +extension ImagenFileDataImage: ImagenImageRepresentable { + public var imagenImage: any ImagenImage { + InternalImagenImage(mimeType: mimeType, bytesBase64Encoded: nil, gcsURI: gcsURI) + } +} + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +extension ImagenFileDataImage: Equatable {} + +// MARK: - Codable Conformances + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +extension ImagenFileDataImage: Decodable { + enum CodingKeys: String, CodingKey { + case mimeType + case gcsURI = "gcsUri" + } + + public init(from decoder: any Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + let mimeType = try container.decode(String.self, forKey: .mimeType) + let gcsURI = try container.decode(String.self, forKey: .gcsURI) + self.init(mimeType: mimeType, gcsURI: gcsURI) + } +} diff --git a/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenInlineDataImage.swift b/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenInlineDataImage.swift new file mode 100644 index 00000000000..f3c1343721b --- /dev/null +++ b/FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenInlineDataImage.swift @@ -0,0 +1,61 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Foundation + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +public struct ImagenInlineDataImage { + public let mimeType: String + public let data: Data + + init(mimeType: String, bytesBase64Encoded: String) { + self.mimeType = mimeType + guard let data = Data(base64Encoded: bytesBase64Encoded) else { + // TODO(#14221): Add error handling for invalid base64 bytes. + fatalError("Creating a `Data` from `bytesBase64Encoded` failed.") + } + self.data = data + } +} + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +extension ImagenInlineDataImage: ImagenImageRepresentable { + public var imagenImage: any ImagenImage { + InternalImagenImage( + mimeType: mimeType, + bytesBase64Encoded: data.base64EncodedString(), + gcsURI: nil + ) + } +} + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +extension ImagenInlineDataImage: Equatable {} + +// MARK: - Codable Conformances + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +extension ImagenInlineDataImage: Decodable { + enum CodingKeys: CodingKey { + case mimeType + case bytesBase64Encoded + } + + public init(from decoder: any Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + let mimeType = try container.decode(String.self, forKey: .mimeType) + let bytesBase64Encoded = try container.decode(String.self, forKey: .bytesBase64Encoded) + self.init(mimeType: mimeType, bytesBase64Encoded: bytesBase64Encoded) + } +} diff --git a/FirebaseVertexAI/Tests/Unit/Types/Imagen/ImageGenerationResponseTests.swift b/FirebaseVertexAI/Tests/Unit/Types/Imagen/ImageGenerationResponseTests.swift index ab66bffa6e8..2bd7dcb88aa 100644 --- a/FirebaseVertexAI/Tests/Unit/Types/Imagen/ImageGenerationResponseTests.swift +++ b/FirebaseVertexAI/Tests/Unit/Types/Imagen/ImageGenerationResponseTests.swift @@ -22,12 +22,8 @@ final class ImageGenerationResponseTests: XCTestCase { func testDecodeResponse_oneBase64Image_noneFiltered() throws { let mimeType = "image/png" - let bytesBase64Encoded = "test-base64-bytes" - let image = InternalImagenImage( - mimeType: mimeType, - bytesBase64Encoded: bytesBase64Encoded, - gcsURI: nil - ) + let bytesBase64Encoded = "dGVzdC1iYXNlNjQtZGF0YQ==" + let image = ImagenInlineDataImage(mimeType: mimeType, bytesBase64Encoded: bytesBase64Encoded) let json = """ { "predictions": [ @@ -40,7 +36,10 @@ final class ImageGenerationResponseTests: XCTestCase { """ let jsonData = try XCTUnwrap(json.data(using: .utf8)) - let response = try decoder.decode(ImageGenerationResponse.self, from: jsonData) + let response = try decoder.decode( + ImageGenerationResponse.self, + from: jsonData + ) XCTAssertEqual(response.images, [image]) XCTAssertNil(response.raiFilteredReason) @@ -48,24 +47,12 @@ final class ImageGenerationResponseTests: XCTestCase { func testDecodeResponse_multipleBase64Images_noneFiltered() throws { let mimeType = "image/png" - let bytesBase64Encoded1 = "test-base64-bytes-1" - let bytesBase64Encoded2 = "test-base64-bytes-2" - let bytesBase64Encoded3 = "test-base64-bytes-3" - let image1 = InternalImagenImage( - mimeType: mimeType, - bytesBase64Encoded: bytesBase64Encoded1, - gcsURI: nil - ) - let image2 = InternalImagenImage( - mimeType: mimeType, - bytesBase64Encoded: bytesBase64Encoded2, - gcsURI: nil - ) - let image3 = InternalImagenImage( - mimeType: mimeType, - bytesBase64Encoded: bytesBase64Encoded3, - gcsURI: nil - ) + let bytesBase64Encoded1 = "dGVzdC1iYXNlNjQtYnl0ZXMtMQ==" + let bytesBase64Encoded2 = "dGVzdC1iYXNlNjQtYnl0ZXMtMg==" + let bytesBase64Encoded3 = "dGVzdC1iYXNlNjQtYnl0ZXMtMw==" + let image1 = ImagenInlineDataImage(mimeType: mimeType, bytesBase64Encoded: bytesBase64Encoded1) + let image2 = ImagenInlineDataImage(mimeType: mimeType, bytesBase64Encoded: bytesBase64Encoded2) + let image3 = ImagenInlineDataImage(mimeType: mimeType, bytesBase64Encoded: bytesBase64Encoded3) let json = """ { "predictions": [ @@ -86,7 +73,10 @@ final class ImageGenerationResponseTests: XCTestCase { """ let jsonData = try XCTUnwrap(json.data(using: .utf8)) - let response = try decoder.decode(ImageGenerationResponse.self, from: jsonData) + let response = try decoder.decode( + ImageGenerationResponse.self, + from: jsonData + ) XCTAssertEqual(response.images, [image1, image2, image3]) XCTAssertNil(response.raiFilteredReason) @@ -94,18 +84,10 @@ final class ImageGenerationResponseTests: XCTestCase { func testDecodeResponse_multipleBase64Images_someFiltered() throws { let mimeType = "image/png" - let bytesBase64Encoded1 = "test-base64-bytes-1" - let bytesBase64Encoded2 = "test-base64-bytes-2" - let image1 = InternalImagenImage( - mimeType: mimeType, - bytesBase64Encoded: bytesBase64Encoded1, - gcsURI: nil - ) - let image2 = InternalImagenImage( - mimeType: mimeType, - bytesBase64Encoded: bytesBase64Encoded2, - gcsURI: nil - ) + let bytesBase64Encoded1 = "dGVzdC1iYXNlNjQtYnl0ZXMtMQ==" + let bytesBase64Encoded2 = "dGVzdC1iYXNlNjQtYnl0ZXMtMg==" + let image1 = ImagenInlineDataImage(mimeType: mimeType, bytesBase64Encoded: bytesBase64Encoded1) + let image2 = ImagenInlineDataImage(mimeType: mimeType, bytesBase64Encoded: bytesBase64Encoded2) let raiFilteredReason = """ Your current safety filter threshold filtered out 2 generated images. You will not be charged \ for blocked images. Try rephrasing the prompt. If you think this was an error, send feedback. @@ -129,7 +111,10 @@ final class ImageGenerationResponseTests: XCTestCase { """ let jsonData = try XCTUnwrap(json.data(using: .utf8)) - let response = try decoder.decode(ImageGenerationResponse.self, from: jsonData) + let response = try decoder.decode( + ImageGenerationResponse.self, + from: jsonData + ) XCTAssertEqual(response.images, [image1, image2]) XCTAssertEqual(response.raiFilteredReason, raiFilteredReason) @@ -139,12 +124,8 @@ final class ImageGenerationResponseTests: XCTestCase { let mimeType = "image/png" let gcsURI1 = "gs://test-bucket/images/123456789/sample_0.png" let gcsURI2 = "gs://test-bucket/images/123456789/sample_1.png" - let image1 = InternalImagenImage( - mimeType: mimeType, - bytesBase64Encoded: nil, - gcsURI: gcsURI1 - ) - let image2 = InternalImagenImage(mimeType: mimeType, bytesBase64Encoded: nil, gcsURI: gcsURI2) + let image1 = ImagenFileDataImage(mimeType: mimeType, gcsURI: gcsURI1) + let image2 = ImagenFileDataImage(mimeType: mimeType, gcsURI: gcsURI2) let json = """ { "predictions": [ @@ -161,7 +142,10 @@ final class ImageGenerationResponseTests: XCTestCase { """ let jsonData = try XCTUnwrap(json.data(using: .utf8)) - let response = try decoder.decode(ImageGenerationResponse.self, from: jsonData) + let response = try decoder.decode( + ImageGenerationResponse.self, + from: jsonData + ) XCTAssertEqual(response.images, [image1, image2]) XCTAssertNil(response.raiFilteredReason) @@ -184,7 +168,10 @@ final class ImageGenerationResponseTests: XCTestCase { """ let jsonData = try XCTUnwrap(json.data(using: .utf8)) - let response = try decoder.decode(ImageGenerationResponse.self, from: jsonData) + let response = try decoder.decode( + ImageGenerationResponse.self, + from: jsonData + ) XCTAssertEqual(response.images, []) XCTAssertEqual(response.raiFilteredReason, raiFilteredReason) @@ -194,7 +181,10 @@ final class ImageGenerationResponseTests: XCTestCase { let json = "{}" let jsonData = try XCTUnwrap(json.data(using: .utf8)) - let response = try decoder.decode(ImageGenerationResponse.self, from: jsonData) + let response = try decoder.decode( + ImageGenerationResponse.self, + from: jsonData + ) XCTAssertEqual(response.images, []) XCTAssertNil(response.raiFilteredReason) @@ -217,7 +207,10 @@ final class ImageGenerationResponseTests: XCTestCase { """ let jsonData = try XCTUnwrap(json.data(using: .utf8)) - let response = try decoder.decode(ImageGenerationResponse.self, from: jsonData) + let response = try decoder.decode( + ImageGenerationResponse.self, + from: jsonData + ) XCTAssertEqual(response.images, []) XCTAssertEqual(response.raiFilteredReason, raiFilteredReason1) @@ -236,7 +229,10 @@ final class ImageGenerationResponseTests: XCTestCase { """ let jsonData = try XCTUnwrap(json.data(using: .utf8)) - let response = try decoder.decode(ImageGenerationResponse.self, from: jsonData) + let response = try decoder.decode( + ImageGenerationResponse.self, + from: jsonData + ) XCTAssertEqual(response.images, []) XCTAssertNil(response.raiFilteredReason) diff --git a/FirebaseVertexAI/Tests/Unit/Types/Imagen/ImagenFileDataImageTests.swift b/FirebaseVertexAI/Tests/Unit/Types/Imagen/ImagenFileDataImageTests.swift new file mode 100644 index 00000000000..21327d0dd7f --- /dev/null +++ b/FirebaseVertexAI/Tests/Unit/Types/Imagen/ImagenFileDataImageTests.swift @@ -0,0 +1,80 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import XCTest + +@testable import FirebaseVertexAI + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +final class ImagenFileDataImageTests: XCTestCase { + let decoder = JSONDecoder() + + func testDecodeImage_gcsURI() throws { + let gcsURI = "gs://test-bucket/images/123456789/sample_0.png" + let mimeType = "image/jpeg" + let json = """ + { + "mimeType": "\(mimeType)", + "gcsUri": "\(gcsURI)" + } + """ + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + let image = try decoder.decode(ImagenFileDataImage.self, from: jsonData) + + XCTAssertEqual(image.mimeType, mimeType) + XCTAssertEqual(image.gcsURI, gcsURI) + XCTAssertEqual(image.imagenImage.mimeType, mimeType) + XCTAssertEqual(image.imagenImage.gcsURI, gcsURI) + XCTAssertNil(image.imagenImage.bytesBase64Encoded) + } + + func testDecodeImage_missingGCSURI_throws() throws { + let json = """ + { + "mimeType": "image/jpeg" + } + """ + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + do { + _ = try decoder.decode(ImagenFileDataImage.self, from: jsonData) + XCTFail("Expected an error; none thrown.") + } catch let DecodingError.keyNotFound(codingKey, _) { + let codingKey = try XCTUnwrap(codingKey as? ImagenFileDataImage.CodingKeys) + XCTAssertEqual(codingKey, .gcsURI) + } catch { + XCTFail("Expected a DecodingError.keyNotFound error; got \(error).") + } + } + + func testDecodeImage_missingMimeType_throws() throws { + let json = """ + { + "gcsUri": "gs://test-bucket/images/123456789/sample_0.png" + } + """ + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + do { + _ = try decoder.decode(ImagenFileDataImage.self, from: jsonData) + XCTFail("Expected an error; none thrown.") + } catch let DecodingError.keyNotFound(codingKey, _) { + let codingKey = try XCTUnwrap(codingKey as? ImagenFileDataImage.CodingKeys) + XCTAssertEqual(codingKey, .mimeType) + } catch { + XCTFail("Expected a DecodingError.keyNotFound error; got \(error).") + } + } +} diff --git a/FirebaseVertexAI/Tests/Unit/Types/Imagen/ImagenInlineDataImageTests.swift b/FirebaseVertexAI/Tests/Unit/Types/Imagen/ImagenInlineDataImageTests.swift new file mode 100644 index 00000000000..8479f3e0079 --- /dev/null +++ b/FirebaseVertexAI/Tests/Unit/Types/Imagen/ImagenInlineDataImageTests.swift @@ -0,0 +1,80 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import XCTest + +@testable import FirebaseVertexAI + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +final class ImagenInlineDataImageTests: XCTestCase { + let decoder = JSONDecoder() + + func testDecodeImage_bytesBase64Encoded() throws { + let mimeType = "image/png" + let bytesBase64Encoded = "dGVzdC1iYXNlNjQtZGF0YQ==" + let json = """ + { + "bytesBase64Encoded": "\(bytesBase64Encoded)", + "mimeType": "\(mimeType)" + } + """ + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + let image = try decoder.decode(ImagenInlineDataImage.self, from: jsonData) + + XCTAssertEqual(image.mimeType, mimeType) + XCTAssertEqual(image.data.base64EncodedString(), bytesBase64Encoded) + XCTAssertEqual(image.imagenImage.mimeType, mimeType) + XCTAssertEqual(image.imagenImage.bytesBase64Encoded, bytesBase64Encoded) + XCTAssertNil(image.imagenImage.gcsURI) + } + + func testDecodeImage_missingBytesBase64Encoded_throws() throws { + let json = """ + { + "mimeType": "image/jpeg" + } + """ + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + do { + _ = try decoder.decode(ImagenInlineDataImage.self, from: jsonData) + XCTFail("Expected an error; none thrown.") + } catch let DecodingError.keyNotFound(codingKey, _) { + let codingKey = try XCTUnwrap(codingKey as? ImagenInlineDataImage.CodingKeys) + XCTAssertEqual(codingKey, .bytesBase64Encoded) + } catch { + XCTFail("Expected a DecodingError.keyNotFound error; got \(error).") + } + } + + func testDecodeImage_missingMimeType_throws() throws { + let json = """ + { + "bytesBase64Encoded": "dGVzdC1iYXNlNjQtZGF0YQ==" + } + """ + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + do { + _ = try decoder.decode(ImagenInlineDataImage.self, from: jsonData) + XCTFail("Expected an error; none thrown.") + } catch let DecodingError.keyNotFound(codingKey, _) { + let codingKey = try XCTUnwrap(codingKey as? ImagenInlineDataImage.CodingKeys) + XCTAssertEqual(codingKey, .mimeType) + } catch { + XCTFail("Expected a DecodingError.keyNotFound error; got \(error).") + } + } +} diff --git a/FirebaseVertexAI/Tests/Unit/Types/Imagen/InternalImagenImageTests.swift b/FirebaseVertexAI/Tests/Unit/Types/Imagen/InternalImagenImageTests.swift deleted file mode 100644 index c66cc052785..00000000000 --- a/FirebaseVertexAI/Tests/Unit/Types/Imagen/InternalImagenImageTests.swift +++ /dev/null @@ -1,101 +0,0 @@ -// Copyright 2024 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -import XCTest - -@testable import FirebaseVertexAI - -@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) -final class InternalImagenImageTests: XCTestCase { - let decoder = JSONDecoder() - - func testDecodeImage_bytesBase64Encoded() throws { - let mimeType = "image/png" - let bytesBase64Encoded = "test-base64-bytes" - let json = """ - { - "bytesBase64Encoded": "\(bytesBase64Encoded)", - "mimeType": "\(mimeType)" - } - """ - let jsonData = try XCTUnwrap(json.data(using: .utf8)) - - let image = try decoder.decode(InternalImagenImage.self, from: jsonData) - - XCTAssertEqual(image.mimeType, mimeType) - XCTAssertEqual(image.bytesBase64Encoded, bytesBase64Encoded) - XCTAssertNil(image.gcsURI) - } - - func testDecodeImage_gcsURI() throws { - let gcsURI = "gs://test-bucket/images/123456789/sample_0.png" - let mimeType = "image/jpeg" - let json = """ - { - "mimeType": "\(mimeType)", - "gcsUri": "\(gcsURI)" - } - """ - let jsonData = try XCTUnwrap(json.data(using: .utf8)) - - let image = try decoder.decode(InternalImagenImage.self, from: jsonData) - - XCTAssertEqual(image.mimeType, mimeType) - XCTAssertEqual(image.gcsURI, gcsURI) - XCTAssertNil(image.bytesBase64Encoded) - } - - func testDecodeImage_missingBytesBase64EncodedAndGCSURI_throws() throws { - let json = """ - { - "mimeType": "image/jpeg" - } - """ - let jsonData = try XCTUnwrap(json.data(using: .utf8)) - - do { - _ = try decoder.decode(InternalImagenImage.self, from: jsonData) - XCTFail("Expected an error; none thrown.") - } catch let DecodingError.dataCorrupted(context) { - let codingPath = try XCTUnwrap(context - .codingPath as? [ImagenImageCodingKeys]) - XCTAssertEqual(codingPath, [.bytesBase64Encoded, .gcsURI]) - XCTAssertTrue(context.debugDescription.contains("both are nil")) - } catch { - XCTFail("Expected a DecodingError.dataCorrupted error; got \(error).") - } - } - - func testDecodeImage_bothBytesBase64EncodedAndGCSURI_throws() throws { - let json = """ - { - "bytesBase64Encoded": "test-base64-bytes", - "mimeType": "image/png", - "gcsUri": "gs://test-bucket/images/123456789/sample_0.png" - } - """ - let jsonData = try XCTUnwrap(json.data(using: .utf8)) - - do { - _ = try decoder.decode(InternalImagenImage.self, from: jsonData) - XCTFail("Expected an error; none thrown.") - } catch let DecodingError.dataCorrupted(context) { - let codingPath = try XCTUnwrap(context.codingPath as? [ImagenImageCodingKeys]) - XCTAssertEqual(codingPath, [.bytesBase64Encoded, .gcsURI]) - XCTAssertTrue(context.debugDescription.contains("both are specified")) - } catch { - XCTFail("Expected a DecodingError.dataCorrupted error; got \(error).") - } - } -}