Skip to content

Commit

Permalink
[Vertex AI] Add ImageGenerationResponse for decoding PredictResponse (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard authored Dec 6, 2024
1 parent 78fe33c commit e773941
Show file tree
Hide file tree
Showing 9 changed files with 640 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
protocol DecodableImagenImage: ImagenImage, Decodable {
init(mimeType: String, bytesBase64Encoded: String?, gcsURI: String?)
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
enum ImagenImageCodingKeys: String, CodingKey {
case mimeType
case bytesBase64Encoded
case gcsURI = "gcsUri"
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension DecodableImagenImage {
init(from decoder: any Decoder) throws {
let container = try decoder.container(keyedBy: ImagenImageCodingKeys.self)
let mimeType = try container.decode(String.self, forKey: .mimeType)
let bytesBase64Encoded = try container.decodeIfPresent(
String.self,
forKey: .bytesBase64Encoded
)
let gcsURI = try container.decodeIfPresent(String.self, forKey: .gcsURI)
guard bytesBase64Encoded != nil || gcsURI != nil else {
throw DecodingError.dataCorrupted(
DecodingError.Context(
codingPath: [ImagenImageCodingKeys.bytesBase64Encoded, ImagenImageCodingKeys.gcsURI],
debugDescription: """
Expected one of \(ImagenImageCodingKeys.bytesBase64Encoded.rawValue) or \
\(ImagenImageCodingKeys.gcsURI.rawValue); both are nil.
"""
)
)
}
guard bytesBase64Encoded == nil || gcsURI == nil else {
throw DecodingError.dataCorrupted(
DecodingError.Context(
codingPath: [ImagenImageCodingKeys.bytesBase64Encoded, ImagenImageCodingKeys.gcsURI],
debugDescription: """
Expected one of \(ImagenImageCodingKeys.bytesBase64Encoded.rawValue) or \
\(ImagenImageCodingKeys.gcsURI.rawValue); both are specified.
"""
)
)
}

self.init(mimeType: mimeType, bytesBase64Encoded: bytesBase64Encoded, gcsURI: gcsURI)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import Foundation

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
struct ImageGenerationResponse {
let images: [InternalImagenImage]
let raiFilteredReason: String?
}

// MARK: - Codable Conformances

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension ImageGenerationResponse: Decodable {
enum CodingKeys: CodingKey {
case predictions
}

public init(from decoder: any Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
guard container.contains(.predictions) else {
images = []
raiFilteredReason = nil
// TODO: Log warning if no predictions.
return
}
var predictionsContainer = try container.nestedUnkeyedContainer(forKey: .predictions)

var images = [InternalImagenImage]()
var raiFilteredReasons = [String]()
while !predictionsContainer.isAtEnd {
if let image = try? predictionsContainer.decode(InternalImagenImage.self) {
images.append(image)
} else if let filterReason = try? predictionsContainer.decode(RAIFilteredReason.self) {
raiFilteredReasons.append(filterReason.raiFilteredReason)
} else if let _ = try? predictionsContainer.decode(JSONObject.self) {
// TODO: Log or throw unsupported prediction type
} else {
// This should never be thrown since JSONObject accepts any valid JSON.
throw DecodingError.dataCorruptedError(
in: predictionsContainer,
debugDescription: "Failed to decode Prediction."
)
}
}

self.images = images
raiFilteredReason = raiFilteredReasons.first
// TODO: Log if more than one RAI Filtered Reason; unexpected behaviour.
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
struct InternalImagenImage {
let mimeType: String
let bytesBase64Encoded: String?
let gcsURI: String?

init(mimeType: String, bytesBase64Encoded: String?, gcsURI: String?) {
self.mimeType = mimeType
self.bytesBase64Encoded = bytesBase64Encoded
self.gcsURI = gcsURI
}
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension InternalImagenImage: DecodableImagenImage {}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension InternalImagenImage: Equatable {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
struct RAIFilteredReason {
let raiFilteredReason: String
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension RAIFilteredReason: Decodable {
enum CodingKeys: CodingKey {
case raiFilteredReason
}
}
22 changes: 22 additions & 0 deletions FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenImage.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import Foundation

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public protocol ImagenImage: ImagenImageRepresentable {
var mimeType: String { get }
var bytesBase64Encoded: String? { get }
var gcsURI: String? { get }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import Foundation

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public protocol ImagenImageRepresentable {
var imagenImage: any ImagenImage { get }
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public extension ImagenImage {
var imagenImage: any ImagenImage {
return self
}
}
Loading

0 comments on commit e773941

Please sign in to comment.