Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow PartsRepresentable to throw errors #88

Merged
merged 16 commits into from
Feb 21, 2024
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class PhotoReasoningViewModel: ObservableObject {

let prompt = "Look at the image(s), and then answer the following question: \(userInput)"

var images = [PartsRepresentable]()
var images = [any ThrowingPartsRepresentable]()
for item in selectedItems {
if let data = try? await item.loadTransferable(type: Data.self) {
guard let image = UIImage(data: data) else {
Expand Down
40 changes: 33 additions & 7 deletions Sources/GoogleAI/Chat.swift
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ public class Chat {
public var history: [ModelContent]

/// See ``sendMessage(_:)-3ify5``.
public func sendMessage(_ parts: PartsRepresentable...) async throws -> GenerateContentResponse {
public func sendMessage(_ parts: any ThrowingPartsRepresentable...) async throws
-> GenerateContentResponse {
return try await sendMessage([ModelContent(parts: parts)])
}

Expand All @@ -40,9 +41,19 @@ public class Chat {
/// - Parameter content: The new content to send as a single chat message.
/// - Returns: The model's response if no error occurred.
/// - Throws: A ``GenerateContentError`` if an error occurred.
public func sendMessage(_ content: [ModelContent]) async throws -> GenerateContentResponse {
public func sendMessage(_ content: @autoclosure () throws -> [ModelContent]) async throws
-> GenerateContentResponse {
// Ensure that the new content has the role set.
let newContent: [ModelContent] = content.map(populateContentRole(_:))
let newContent: [ModelContent]
do {
newContent = try content().map(populateContentRole(_:))
} catch let underlying {
if let contentError = underlying as? ImageConversionError {
throw GenerateContentError.promptImageContentError(underlying: contentError)
} else {
throw GenerateContentError.internalError(underlying: underlying)
}
}

// Send the history alongside the new message as context.
let request = history + newContent
Expand All @@ -67,24 +78,39 @@ public class Chat {

/// See ``sendMessageStream(_:)-4abs3``.
@available(macOS 12.0, *)
public func sendMessageStream(_ parts: PartsRepresentable...)
public func sendMessageStream(_ parts: any ThrowingPartsRepresentable...)
-> AsyncThrowingStream<GenerateContentResponse, Error> {
return sendMessageStream([ModelContent(parts: parts)])
return try sendMessageStream([ModelContent(parts: parts)])
}

/// Sends a message using the existing history of this chat as context. If successful, the message
/// and response will be added to the history. If unsuccessful, history will remain unchanged.
/// - Parameter content: The new content to send as a single chat message.
/// - Returns: A stream containing the model's response or an error if an error occurred.
@available(macOS 12.0, *)
public func sendMessageStream(_ content: [ModelContent])
public func sendMessageStream(_ content: @autoclosure () throws -> [ModelContent])
-> AsyncThrowingStream<GenerateContentResponse, Error> {
let resolvedContent: [ModelContent]
do {
resolvedContent = try content()
} catch let underlying {
return AsyncThrowingStream { continuation in
let error: Error
if let contentError = underlying as? ImageConversionError {
error = GenerateContentError.promptImageContentError(underlying: contentError)
} else {
error = GenerateContentError.internalError(underlying: underlying)
}
continuation.finish(throwing: error)
}
}

return AsyncThrowingStream { continuation in
Task {
var aggregatedContent: [ModelContent] = []

// Ensure that the new content has the role set.
let newContent: [ModelContent] = content.map(populateContentRole(_:))
let newContent: [ModelContent] = resolvedContent.map(populateContentRole(_:))

// Send the history alongside the new message as context.
let request = history + newContent
Expand Down
3 changes: 3 additions & 0 deletions Sources/GoogleAI/GenerateContentError.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ import Foundation
/// Errors that occur when generating content from a model.
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
public enum GenerateContentError: Error {
/// An error occurred when constructing the prompt. Examine the related error for details.
case promptImageContentError(underlying: ImageConversionError)

/// An internal error occurred. See the underlying error for more context.
case internalError(underlying: Error)

Expand Down
70 changes: 46 additions & 24 deletions Sources/GoogleAI/GenerativeModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,12 @@ public final class GenerativeModel {
/// [few-shot](https://developers.google.com/machine-learning/glossary/generative#few-shot-prompting)
/// prompts, see ``generateContent(_:)-58rm0``.
///
/// - Parameter content: The input(s) given to the model as a prompt (see ``PartsRepresentable``
/// - Parameter content: The input(s) given to the model as a prompt (see
/// ``ThrowingPartsRepresentable``
/// for conforming types).
/// - Returns: The content generated by the model.
/// - Throws: A ``GenerateContentError`` if the request failed.
public func generateContent(_ parts: PartsRepresentable...)
public func generateContent(_ parts: any ThrowingPartsRepresentable...)
async throws -> GenerateContentResponse {
return try await generateContent([ModelContent(parts: parts)])
}
Expand All @@ -110,18 +111,21 @@ public final class GenerativeModel {
/// - Parameter content: The input(s) given to the model as a prompt.
/// - Returns: The generated content response from the model.
/// - Throws: A ``GenerateContentError`` if the request failed.
public func generateContent(_ content: [ModelContent]) async throws
public func generateContent(_ content: @autoclosure () throws -> [ModelContent]) async throws
-> GenerateContentResponse {
let generateContentRequest = GenerateContentRequest(model: modelResourceName,
contents: content,
generationConfig: generationConfig,
safetySettings: safetySettings,
isStreaming: false,
options: requestOptions)
let response: GenerateContentResponse
do {
let generateContentRequest = try GenerateContentRequest(model: modelResourceName,
contents: content(),
generationConfig: generationConfig,
safetySettings: safetySettings,
isStreaming: false,
options: requestOptions)
response = try await generativeAIService.loadRequest(request: generateContentRequest)
} catch {
if let imageError = error as? ImageConversionError {
throw GenerateContentError.promptImageContentError(underlying: imageError)
}
throw GenerativeModel.generateContentError(from: error)
}

Expand All @@ -148,14 +152,15 @@ public final class GenerativeModel {
/// [few-shot](https://developers.google.com/machine-learning/glossary/generative#few-shot-prompting)
/// prompts, see ``generateContent(_:)-58rm0``.
///
/// - Parameter content: The input(s) given to the model as a prompt (see ``PartsRepresentable``
/// - Parameter content: The input(s) given to the model as a prompt (see
/// ``ThrowingPartsRepresentable``
/// for conforming types).
/// - Returns: A stream wrapping content generated by the model or a ``GenerateContentError``
/// error if an error occurred.
@available(macOS 12.0, *)
public func generateContentStream(_ parts: PartsRepresentable...)
public func generateContentStream(_ parts: any ThrowingPartsRepresentable...)
-> AsyncThrowingStream<GenerateContentResponse, Error> {
return generateContentStream([ModelContent(parts: parts)])
return try generateContentStream([ModelContent(parts: parts)])
}

/// Generates new content from input content given to the model as a prompt.
Expand All @@ -164,10 +169,25 @@ public final class GenerativeModel {
/// - Returns: A stream wrapping content generated by the model or a ``GenerateContentError``
/// error if an error occurred.
@available(macOS 12.0, *)
public func generateContentStream(_ content: [ModelContent])
public func generateContentStream(_ content: @autoclosure () throws -> [ModelContent])
-> AsyncThrowingStream<GenerateContentResponse, Error> {
let evaluatedContent: [ModelContent]
do {
evaluatedContent = try content()
} catch let underlying {
return AsyncThrowingStream { continuation in
let error: Error
if let contentError = underlying as? ImageConversionError {
error = GenerateContentError.promptImageContentError(underlying: contentError)
} else {
error = GenerateContentError.internalError(underlying: underlying)
}
continuation.finish(throwing: error)
}
}

let generateContentRequest = GenerateContentRequest(model: modelResourceName,
contents: content,
contents: evaluatedContent,
generationConfig: generationConfig,
safetySettings: safetySettings,
isStreaming: true,
Expand Down Expand Up @@ -218,12 +238,14 @@ public final class GenerativeModel {
/// [few-shot](https://developers.google.com/machine-learning/glossary/generative#few-shot-prompting)
/// input, see ``countTokens(_:)-9spwl``.
///
/// - Parameter content: The input(s) given to the model as a prompt (see ``PartsRepresentable``
/// - Parameter content: The input(s) given to the model as a prompt (see
/// ``ThrowingPartsRepresentable``
/// for conforming types).
/// - Returns: The results of running the model's tokenizer on the input; contains
/// ``CountTokensResponse/totalTokens``.
/// - Throws: A ``CountTokensError`` if the tokenization request failed.
public func countTokens(_ parts: PartsRepresentable...) async throws -> CountTokensResponse {
public func countTokens(_ parts: any ThrowingPartsRepresentable...) async throws
-> CountTokensResponse {
return try await countTokens([ModelContent(parts: parts)])
}

Expand All @@ -232,16 +254,16 @@ public final class GenerativeModel {
/// - Parameter content: The input given to the model as a prompt.
/// - Returns: The results of running the model's tokenizer on the input; contains
/// ``CountTokensResponse/totalTokens``.
/// - Throws: A ``CountTokensError`` if the tokenization request failed.
public func countTokens(_ content: [ModelContent]) async throws
/// - Throws: A ``CountTokensError`` if the tokenization request failed or the input content was
/// invalid.
public func countTokens(_ content: @autoclosure () throws -> [ModelContent]) async throws
-> CountTokensResponse {
let countTokensRequest = CountTokensRequest(
model: modelResourceName,
contents: content,
options: requestOptions
)

do {
let countTokensRequest = try CountTokensRequest(
model: modelResourceName,
contents: content(),
options: requestOptions
)
return try await generativeAIService.loadRequest(request: countTokensRequest)
} catch {
throw CountTokensError.internalError(underlying: error)
Expand Down
25 changes: 21 additions & 4 deletions Sources/GoogleAI/ModelContent.swift
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,14 @@ public struct ModelContent: Codable, Equatable {
public let parts: [Part]

/// Creates a new value from any data or `Array` of data interpretable as a
/// ``Part``. See ``PartsRepresentable`` for types that can be interpreted as `Part`s.
/// ``Part``. See ``ThrowingPartsRepresentable`` for types that can be interpreted as `Part`s.
public init(role: String? = "user", parts: some ThrowingPartsRepresentable) throws {
self.role = role
try self.parts = parts.tryPartsValue()
}

/// Creates a new value from any data or `Array` of data interpretable as a
/// ``Part``. See ``ThrowingPartsRepresentable`` for types that can be interpreted as `Part`s.
public init(role: String? = "user", parts: some PartsRepresentable) {
self.role = role
self.parts = parts.partsValue
Expand All @@ -116,9 +123,19 @@ public struct ModelContent: Codable, Equatable {
self.parts = parts
}

/// Creates a new value from any data interpretable as a ``Part``. See ``PartsRepresentable``
/// Creates a new value from any data interpretable as a ``Part``. See
/// ``ThrowingPartsRepresentable``
/// for types that can be interpreted as `Part`s.
public init(role: String? = "user", _ parts: any ThrowingPartsRepresentable...) throws {
let content = try parts.flatMap { try $0.tryPartsValue() }
self.init(role: role, parts: content)
}

/// Creates a new value from any data interpretable as a ``Part``. See
/// ``ThrowingPartsRepresentable``
/// for types that can be interpreted as `Part`s.
public init(role: String? = "user", _ parts: PartsRepresentable...) {
self.init(role: role, parts: parts)
public init(role: String? = "user", _ parts: [PartsRepresentable]) {
let content = parts.flatMap { $0.partsValue }
self.init(role: role, parts: content)
}
}
107 changes: 107 additions & 0 deletions Sources/GoogleAI/PartsRepresentable+Image.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import UniformTypeIdentifiers
#if canImport(UIKit)
import UIKit // For UIImage extensions.
#elseif canImport(AppKit)
import AppKit // For NSImage extensions.
#endif

private let imageCompressionQuality: CGFloat = 0.8

/// An enum describing failures that can occur when converting image types to model content data.
/// For some image types like `CIImage`, creating valid model content requires creating a JPEG
/// representation of the image that may not yet exist, which may be computationally expensive.
public enum ImageConversionError: Error {
/// The image (the receiver of the call `toModelContentParts()`) was invalid.
case invalidUnderlyingImage

/// A valid image destination could not be allocated.
case couldNotAllocateDestination

/// JPEG image data conversion failed, accompanied by the original image, which may be an
/// instance of `NSImageRep`, `UIImage`, `CGImage`, or `CIImage`.
case couldNotConvertToJPEG(Any)
}

#if canImport(UIKit)
/// Enables images to be representable as ``ThrowingPartsRepresentable``.
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
extension UIImage: ThrowingPartsRepresentable {
public func tryPartsValue() throws -> [ModelContent.Part] {
guard let data = jpegData(compressionQuality: imageCompressionQuality) else {
throw ImageConversionError.couldNotConvertToJPEG(self)
}
return [ModelContent.Part.data(mimetype: "image/jpeg", data)]
}
}

#elseif canImport(AppKit)
/// Enables images to be representable as ``ThrowingPartsRepresentable``.
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
extension NSImage: ThrowingPartsRepresentable {
public func tryPartsValue() throws -> [ModelContent.Part] {
guard let cgImage = cgImage(forProposedRect: nil, context: nil, hints: nil) else {
throw ImageConversionError.invalidUnderlyingImage
}
let bmp = NSBitmapImageRep(cgImage: cgImage)
guard let data = bmp.representation(using: .jpeg, properties: [.compressionFactor: 0.8])
else {
throw ImageConversionError.couldNotConvertToJPEG(bmp)
}
return [ModelContent.Part.data(mimetype: "image/jpeg", data)]
}
}
#endif

/// Enables `CGImages` to be representable as model content.
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
extension CGImage: ThrowingPartsRepresentable {
public func tryPartsValue() throws -> [ModelContent.Part] {
let output = NSMutableData()
guard let imageDestination = CGImageDestinationCreateWithData(
output, UTType.jpeg.identifier as CFString, 1, nil
) else {
throw ImageConversionError.couldNotAllocateDestination
}
CGImageDestinationAddImage(imageDestination, self, nil)
CGImageDestinationSetProperties(imageDestination, [
kCGImageDestinationLossyCompressionQuality: imageCompressionQuality,
] as CFDictionary)
if CGImageDestinationFinalize(imageDestination) {
return [.data(mimetype: "image/jpeg", output as Data)]
}
throw ImageConversionError.couldNotConvertToJPEG(self)
}
}

/// Enables `CIImages` to be representable as model content.
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
extension CIImage: ThrowingPartsRepresentable {
public func tryPartsValue() throws -> [ModelContent.Part] {
let context = CIContext()
let jpegData = (colorSpace ?? CGColorSpace(name: CGColorSpace.sRGB))
.flatMap {
// The docs specify kCGImageDestinationLossyCompressionQuality as a supported option, but
// Swift's type system does not allow this.
// [kCGImageDestinationLossyCompressionQuality: imageCompressionQuality]
context.jpegRepresentation(of: self, colorSpace: $0, options: [:])
}
if let jpegData = jpegData {
return [.data(mimetype: "image/jpeg", jpegData)]
}
throw ImageConversionError.couldNotConvertToJPEG(self)
}
}
Loading
Loading