Skip to content

Commit

Permalink
Structured Output
Browse files Browse the repository at this point in the history
  • Loading branch information
Archetapp committed Dec 18, 2024
1 parent b2009e1 commit d6a6b37
Show file tree
Hide file tree
Showing 6 changed files with 238 additions and 92 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@ extension _Gemini.APISpecification {
from response: Request.Response,
context: DecodeOutputContext
) throws -> Output {
print(response)
try response.validate()

return try response.decode(
Expand Down
77 changes: 51 additions & 26 deletions Sources/_Gemini/Intramodular/Models/_Gemini.GenerationConfig.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
// Created by Jared Davidson on 12/18/24.
//

import Foundation

extension _Gemini {
public struct GenerationConfig: Codable {
public let maxOutputTokens: Int?
Expand All @@ -14,7 +16,7 @@ extension _Gemini {
public let presencePenalty: Double?
public let frequencyPenalty: Double?
public let responseMimeType: String?
public let responseSchema: ResponseSchema?
public let responseSchema: SchemaObject?

public init(
maxOutputTokens: Int? = nil,
Expand All @@ -24,7 +26,7 @@ extension _Gemini {
presencePenalty: Double? = nil,
frequencyPenalty: Double? = nil,
responseMimeType: String? = nil,
responseSchema: ResponseSchema? = nil
responseSchema: SchemaObject? = nil
) {
self.maxOutputTokens = maxOutputTokens
self.temperature = temperature
Expand All @@ -37,38 +39,61 @@ extension _Gemini {
}
}

public struct ResponseSchema: Codable {
public let type: SchemaType
public let items: SchemaObject?
public let properties: [String: SchemaObject]?
public indirect enum SchemaObject: Codable {
case object(properties: [String: SchemaObject])
case array(items: SchemaObject)
case string
case number
case boolean

public init(
type: SchemaType,
items: SchemaObject? = nil,
properties: [String: SchemaObject]? = nil
) {
self.type = type
self.items = items
self.properties = properties
public var type: SchemaType {
switch self {
case .object: return .object
case .array: return .array
case .string: return .string
case .number: return .number
case .boolean: return .boolean
}
}

private enum CodingKeys: String, CodingKey {
case type
case items
case properties
case items
}
}

public struct SchemaObject: Codable {
public let type: SchemaType
public let properties: [String: SchemaObject]?

public init(
type: SchemaType,
properties: [String: SchemaObject]? = nil
) {
self.type = type
self.properties = properties
public func encode(to encoder: Encoder) throws {
var container = encoder.container(keyedBy: CodingKeys.self)
try container.encode(type, forKey: .type)

switch self {
case .object(let properties):
try container.encode(properties, forKey: .properties)
case .array(let items):
try container.encode(items, forKey: .items)
case .string, .number, .boolean:
break
}
}

public init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
let type = try container.decode(SchemaType.self, forKey: .type)

switch type {
case .object:
let properties = try container.decode([String: SchemaObject].self, forKey: .properties)
self = .object(properties: properties)
case .array:
let items = try container.decode(SchemaObject.self, forKey: .items)
self = .array(items: items)
case .string:
self = .string
case .number:
self = .number
case .boolean:
self = .boolean
}
}
}

Expand Down
141 changes: 83 additions & 58 deletions Sources/_Gemini/Intramodular/_Gemini.Client+ContentGeneration.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import FoundationX
import Swallow

extension _Gemini.Client {

// FIXME: - I'm not sure where/how a default should be properly placed.
static public let configDefault: _Gemini.GenerationConfig = .init(
maxOutputTokens: 8192,
temperature: 1,
Expand All @@ -24,88 +26,111 @@ extension _Gemini.Client {
)

public func generateContent(
url: URL,
type: HTTPMediaType,
prompt: String,
messages: [_Gemini.Message] = [],
file: _Gemini.File? = nil,
fileURL: URL? = nil,
mimeType: HTTPMediaType? = nil,
model: _Gemini.Model,
config: _Gemini.GenerationConfig = configDefault
) async throws -> _Gemini.Content {
do {
let data = try Data(contentsOf: url)

let uploadedFile = try await uploadFile(
fileData: data,
mimeType: type,
displayName: UUID().uuidString
)
// Handle file URL if provided
if let fileURL = fileURL {
guard let mimeType = mimeType else {
throw _Gemini.APIError.unknown(message: "MIME type is required when using fileURL")
}

return try await self.generateContent(
file: uploadedFile,
prompt: prompt,
model: model,
config: config
)
} catch let error as NSError where error.domain == NSCocoaErrorDomain {
throw _Gemini.APIError.unknown(message: "Failed to read file: \(error.localizedDescription)")
} catch {
throw error
}
}

public func generateContent(
file: _Gemini.File,
prompt: String,
model: _Gemini.Model,
config: _Gemini.GenerationConfig = configDefault
) async throws -> _Gemini.Content {
guard let fileName = file.name else {
throw ContentGenerationError.invalidFileName
do {
let data = try Data(contentsOf: fileURL)
let uploadedFile = try await uploadFile(
fileData: data,
mimeType: mimeType,
displayName: UUID().uuidString
)
return try await generateContent(
messages: messages,
file: uploadedFile,
model: model,
config: config
)
} catch let error as NSError where error.domain == NSCocoaErrorDomain {
throw _Gemini.APIError.unknown(message: "Failed to read file: \(error.localizedDescription)")
}
}

do {
print("Waiting for file processing...")
// Handle file if provided
if let file = file {
guard let fileName = file.name else {
throw ContentGenerationError.invalidFileName
}

let processedFile = try await waitForFileProcessing(name: fileName)
print("File processing complete: \(processedFile)")

guard let mimeType = file.mimeType else {
throw _Gemini.APIError.unknown(message: "Invalid MIME type")
}

let fileUri = processedFile.uri
var contents: [_Gemini.APISpecification.RequestBodies.Content] = []

let fileContent = _Gemini.APISpecification.RequestBodies.Content(
// Add file content first if present
contents.append(_Gemini.APISpecification.RequestBodies.Content(
role: "user",
parts: [
.file(url: fileUri, mimeType: mimeType),
]
)
parts: [.file(url: processedFile.uri, mimeType: mimeType)]
))

let promptContent = _Gemini.APISpecification.RequestBodies.Content(
role: "user",
parts: [
.text(prompt)
]
)
// Add regular messages
contents.append(contentsOf: messages.filter { $0.role != .system }.map { message in
_Gemini.APISpecification.RequestBodies.Content(
role: message.role.rawValue,
parts: [.text(message.content)]
)
})

let systemInstruction = messages.first { $0.role == .system }.map { message in
_Gemini.APISpecification.RequestBodies.Content(
role: message.role.rawValue,
parts: [.text(message.content)]
)
}

let input = _Gemini.APISpecification.RequestBodies.GenerateContentInput(
model: model,
requestBody: .init(
contents: [fileContent, promptContent],
generationConfig: config
contents: contents,
generationConfig: config,
systemInstruction: systemInstruction
)
)

print(input)

let response = try await run(\.generateContent, with: input)

return try _Gemini.Content.init(apiResponse: response)

} catch let error as ContentGenerationError {
throw error
} catch {
throw _Gemini.APIError.unknown(message: "Content generation failed: \(error.localizedDescription)")
return try _Gemini.Content(apiResponse: response)
}

// Handle text-only messages
let contents = messages.filter { $0.role != .system }.map { message in
_Gemini.APISpecification.RequestBodies.Content(
role: message.role.rawValue,
parts: [.text(message.content)]
)
}

let systemInstruction = messages.first { $0.role == .system }.map { message in
_Gemini.APISpecification.RequestBodies.Content(
role: message.role.rawValue,
parts: [.text(message.content)]
)
}

let input = _Gemini.APISpecification.RequestBodies.GenerateContentInput(
model: model,
requestBody: .init(
contents: contents,
generationConfig: config,
systemInstruction: systemInstruction
)
)

let response = try await run(\.generateContent, with: input)
return try _Gemini.Content(apiResponse: response)
}
}

Expand Down
92 changes: 92 additions & 0 deletions Tests/_Gemini/Intramodular/_GeminiTests+StructuredOutput.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
//
// _GeminiTests+StructuredOutput.swift
// AI
//
// Created by Jared Davidson on 12/18/24.
//

import Testing
import Foundation
import _Gemini
import AI

@Suite struct _GeminiStructuredOutputTests {
@Test func testStructuredMovieReview() async throws {
let reviewSchema = _Gemini.SchemaObject.object(properties: [
"title": .string,
"rating": .number,
"genres": .array(items: .string),
"review": .string
])

let config = _Gemini.GenerationConfig(
temperature: 0.7,
responseMimeType: "application/json",
responseSchema: .object(properties: [
"review": reviewSchema
])
)

let messages = [
_Gemini.Message(
role: .user,
content: "Write a review for the movie 'Inception' with a rating from 1-10. Return it as a JSON object."
)
]

let response = try await client.generateContent(
messages: messages,
model: .gemini_1_5_pro_latest,
config: config
)

dump(response)

// Validate the response
#expect(!response.text.isEmpty, "Response should not be empty")

// Attempt to parse the response as JSON
if let jsonData = response.text.data(using: String.Encoding.utf8) {
do {
let wrapper = try JSONDecoder().decode(MovieReviewWrapper.self, from: jsonData)
let review = wrapper.review

// Validate the structured output
#expect(!review.title.isEmpty, "Movie title should not be empty")
#expect(review.rating >= 1 && review.rating <= 10, "Rating should be between 1 and 10")
#expect(!review.genres.isEmpty, "Genres array should not be empty")
#expect(!review.review.isEmpty, "Review text should not be empty")

print("Parsed review:", review)
} catch {
print("JSON parsing error:", error)
print("Response text:", response.text)
#expect(false, "Failed to parse JSON response: \(error)")
}
} else {
#expect(false, "Failed to convert response to data")
}

// Check token usage
if let usage = response.tokenUsage {
#expect(usage.prompt > 0, "Prompt tokens should be greater than 0")
#expect(usage.response > 0, "Response tokens should be greater than 0")
#expect(usage.total == usage.prompt + usage.response, "Total tokens should equal prompt + response")
}

// Check finish reason
#expect(response.finishReason == .stop, "Response should have completed normally")
}
}

// Response structures
private struct MovieReviewWrapper: Codable {
let review: MovieReview
}

private struct MovieReview: Codable {
let title: String
let rating: Double
let genres: [String]
let review: String
}
Loading

0 comments on commit d6a6b37

Please sign in to comment.