Skip to content

Commit

Permalink
[Vertex AI] Make ImageGenerationResponse generic and add image types (
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard authored Dec 6, 2024
1 parent e773941 commit 9d0a8d8
Show file tree
Hide file tree
Showing 9 changed files with 329 additions and 225 deletions.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@
import Foundation

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
struct ImageGenerationResponse {
let images: [InternalImagenImage]
let raiFilteredReason: String?
public struct ImageGenerationResponse<ImageType: ImagenImageRepresentable> {
public let images: [ImageType]
public let raiFilteredReason: String?
}

// MARK: - Codable Conformances

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension ImageGenerationResponse: Decodable {
extension ImageGenerationResponse: Decodable where ImageType: Decodable {
enum CodingKeys: CodingKey {
case predictions
}
Expand All @@ -38,15 +38,15 @@ extension ImageGenerationResponse: Decodable {
}
var predictionsContainer = try container.nestedUnkeyedContainer(forKey: .predictions)

var images = [InternalImagenImage]()
var images = [ImageType]()
var raiFilteredReasons = [String]()
while !predictionsContainer.isAtEnd {
if let image = try? predictionsContainer.decode(InternalImagenImage.self) {
if let image = try? predictionsContainer.decode(ImageType.self) {
images.append(image)
} else if let filterReason = try? predictionsContainer.decode(RAIFilteredReason.self) {
raiFilteredReasons.append(filterReason.raiFilteredReason)
} else if let _ = try? predictionsContainer.decode(JSONObject.self) {
// TODO: Log or throw unsupported prediction type
// TODO(#14221): Log or throw unsupported prediction type
} else {
// This should never be thrown since JSONObject accepts any valid JSON.
throw DecodingError.dataCorruptedError(
Expand All @@ -58,6 +58,6 @@ extension ImageGenerationResponse: Decodable {

self.images = images
raiFilteredReason = raiFilteredReasons.first
// TODO: Log if more than one RAI Filtered Reason; unexpected behaviour.
// TODO(#14221): Log if more than one RAI Filtered Reason; unexpected behaviour.
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,4 @@ struct InternalImagenImage {
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension InternalImagenImage: DecodableImagenImage {}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension InternalImagenImage: Equatable {}
extension InternalImagenImage: ImagenImage {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import Foundation

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public struct ImagenFileDataImage {
public let mimeType: String
public let gcsURI: String

init(mimeType: String, gcsURI: String) {
self.mimeType = mimeType
self.gcsURI = gcsURI
}
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension ImagenFileDataImage: ImagenImageRepresentable {
public var imagenImage: any ImagenImage {
InternalImagenImage(mimeType: mimeType, bytesBase64Encoded: nil, gcsURI: gcsURI)
}
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension ImagenFileDataImage: Equatable {}

// MARK: - Codable Conformances

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension ImagenFileDataImage: Decodable {
enum CodingKeys: String, CodingKey {
case mimeType
case gcsURI = "gcsUri"
}

public init(from decoder: any Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
let mimeType = try container.decode(String.self, forKey: .mimeType)
let gcsURI = try container.decode(String.self, forKey: .gcsURI)
self.init(mimeType: mimeType, gcsURI: gcsURI)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import Foundation

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public struct ImagenInlineDataImage {
public let mimeType: String
public let data: Data

init(mimeType: String, bytesBase64Encoded: String) {
self.mimeType = mimeType
guard let data = Data(base64Encoded: bytesBase64Encoded) else {
// TODO(#14221): Add error handling for invalid base64 bytes.
fatalError("Creating a `Data` from `bytesBase64Encoded` failed.")
}
self.data = data
}
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension ImagenInlineDataImage: ImagenImageRepresentable {
public var imagenImage: any ImagenImage {
InternalImagenImage(
mimeType: mimeType,
bytesBase64Encoded: data.base64EncodedString(),
gcsURI: nil
)
}
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension ImagenInlineDataImage: Equatable {}

// MARK: - Codable Conformances

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension ImagenInlineDataImage: Decodable {
enum CodingKeys: CodingKey {
case mimeType
case bytesBase64Encoded
}

public init(from decoder: any Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
let mimeType = try container.decode(String.self, forKey: .mimeType)
let bytesBase64Encoded = try container.decode(String.self, forKey: .bytesBase64Encoded)
self.init(mimeType: mimeType, bytesBase64Encoded: bytesBase64Encoded)
}
}
Loading

0 comments on commit 9d0a8d8

Please sign in to comment.