Skip to content

Commit

Permalink
Added image test
Browse files Browse the repository at this point in the history
  • Loading branch information
Archetapp committed Dec 12, 2024
1 parent 68b9fde commit a1fdad7
Show file tree
Hide file tree
Showing 3 changed files with 171 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,33 @@ extension _Gemini.APISpecification {

// Add metadata as JSON
let metadata = ["file": ["display_name": displayName]]
let metadataData = try JSONSerialization.data(withJSONObject: metadata)

let fileExtension = mimeType.split(separator: "/").last == "quicktime" ? "mov" :
String(mimeType.split(separator: "/").last ?? "bin")
let fileExtension: String = {
guard let subtype = mimeType.split(separator: "/").last else {
return "bin"
}

switch subtype {
case "quicktime":
return "mov"
case "x-m4a":
return "m4a"
case "mp4":
return "mp4"
case "jpeg", "jpg":
return "jpg"
case "png":
return "png"
case "gif":
return "gif"
case "webp":
return "webp"
case "pdf":
return "pdf"
default:
return String(subtype)
}
}()

result.append(
.file(
Expand Down
138 changes: 107 additions & 31 deletions Sources/_Gemini/Intramodular/_Gemini.Client.swift
Original file line number Diff line number Diff line change
Expand Up @@ -43,44 +43,88 @@ extension _Gemini.Client {
delaySeconds: Double = 1.0
) async throws -> _Gemini.File {
guard !name.isEmpty else {
throw _Gemini.APIError.unknown(message: "Invalid file name")
throw FileProcessingError.invalidFileName
}

for attempt in 1...maxAttempts {
let input = _Gemini.APISpecification.RequestBodies.FileStatusInput(name: name)
let fileStatus = try await run(\.getFile, with: input)

print("File status attempt \(attempt): \(fileStatus.state)")

if fileStatus.state == .active {
return fileStatus
}

if attempt < maxAttempts {
try await Task.sleep(nanoseconds: UInt64(delaySeconds * 1_000_000_000))
do {
let input = _Gemini.APISpecification.RequestBodies.FileStatusInput(name: name)
let fileStatus = try await run(\.getFile, with: input)

print("File status attempt \(attempt): \(fileStatus.state)")

switch fileStatus.state {
case .active:
return fileStatus
case .processing:
break
}

if attempt < maxAttempts {
try await Task.sleep(nanoseconds: UInt64(delaySeconds * 1_000_000_000))
}
} catch {
if attempt == maxAttempts {
throw error
}
continue
}
}

throw _Gemini.APIError.unknown(message: "File processing timeout")
throw FileProcessingError.processingTimeout(fileName: name)
}

/// Reads a local file, uploads it, and generates content from the uploaded copy.
///
/// - Parameters:
///   - file: Not read by this overload; the file sent to the model is the one
///     uploaded from `url`. Kept so existing call sites continue to compile.
///   - url: Location of the file whose bytes are uploaded.
///   - type: MIME type reported for the upload.
///   - prompt: Text prompt forwarded to the generation request.
///   - model: Model forwarded to the generation request.
/// - Returns: The result of `generateContent(file:prompt:model:)` for the uploaded file.
/// - Throws: `_Gemini.APIError.unknown` when reading `url` fails with a Cocoa-domain
///   error; every other error (read, upload, generation) is propagated unchanged.
public func generateContent(
    file: _Gemini.File,
    url: URL,
    type: HTTPMediaType,
    prompt: String,
    model: _Gemini.Model
) async throws -> _Gemini.APISpecification.ResponseBodies.GenerateContent {
    // Only the file read is wrapped. Previously the same catch also covered the
    // upload and generation calls, so any of their errors that bridge to
    // NSCocoaErrorDomain (for example a `DecodingError`) were misreported as
    // "Failed to read file".
    let data: Data
    do {
        data = try Data(contentsOf: url)
    } catch let error as NSError where error.domain == NSCocoaErrorDomain {
        throw _Gemini.APIError.unknown(message: "Failed to read file: \(error.localizedDescription)")
    }

    // NOTE(review): the display name is hard-coded; presumably a placeholder —
    // confirm whether it should be derived from `url` instead.
    let uploadedFile = try await uploadFile(
        fileData: data,
        mimeType: type,
        displayName: "Test"
    )

    return try await self.generateContent(
        file: uploadedFile,
        prompt: prompt,
        model: model
    )
}

public func generateContent(
file: _Gemini.File,
prompt: String,
model: _Gemini.Model
) async throws -> _Gemini.APISpecification.ResponseBodies.GenerateContent {
guard let fileName = file.name else {
throw FileProcessingError.invalidFileName
}

do {
print("Waiting for file processing...")
let processedFile = try await waitForFileProcessing(name: file.name ?? "")
let processedFile = try await waitForFileProcessing(name: fileName)
print("File processing complete: \(processedFile)")

// Create content request matching the expected format
guard let mimeType = file.mimeType else {
throw _Gemini.APIError.unknown(message: "Invalid MIME type")
}

let fileUri = processedFile.uri

let fileContent = _Gemini.APISpecification.RequestBodies.Content(
role: "user",
parts: [
.file(url: processedFile.uri, mimeType: file.mimeType ?? ""),
.file(url: fileUri, mimeType: mimeType),
]
)

Expand All @@ -105,9 +149,13 @@ extension _Gemini.Client {
)
)

print(input)

return try await run(\.generateContent, with: input)
} catch {
} catch let error as FileProcessingError {
throw error
} catch {
throw _Gemini.APIError.unknown(message: "Content generation failed: \(error.localizedDescription)")
}
}

Expand All @@ -116,28 +164,56 @@ extension _Gemini.Client {
mimeType: HTTPMediaType,
displayName: String
) async throws -> _Gemini.File {
let input = _Gemini.APISpecification.RequestBodies.FileUploadInput(
fileData: fileData,
mimeType: mimeType.rawValue,
displayName: displayName
)
guard !displayName.isEmpty else {
throw FileProcessingError.invalidFileName
}

let response = try await run(\.uploadFile, with: input)
return response.file
do {
let input = _Gemini.APISpecification.RequestBodies.FileUploadInput(
fileData: fileData,
mimeType: mimeType.rawValue,
displayName: displayName
)

let response = try await run(\.uploadFile, with: input)
return response.file
} catch {
throw _Gemini.APIError.unknown(message: "File upload failed: \(error.localizedDescription)")
}
}

/// Fetches the file record identified by `name`.
///
/// - Parameter name: Identifier of the file to look up; must not be empty.
/// - Returns: The file returned by the `getFile` endpoint.
/// - Throws: `FileProcessingError.invalidFileName` when `name` is empty, or
///   `_Gemini.APIError.unknown` wrapping any failure of the request itself.
public func getFile(
    name: String
) async throws -> _Gemini.File {
    // Validate before touching the network. The stale pre-change statements that
    // used to precede this guard returned early and made it unreachable.
    guard !name.isEmpty else {
        throw FileProcessingError.invalidFileName
    }

    do {
        let input = _Gemini.APISpecification.RequestBodies.FileStatusInput(name: name)
        return try await run(\.getFile, with: input)
    } catch {
        throw _Gemini.APIError.unknown(message: "Failed to get file status: \(error.localizedDescription)")
    }
}

/// Deletes the remote file located at `fileURL`.
///
/// - Parameter fileURL: URL of the file to delete, passed to the `deleteFile` endpoint.
/// - Throws: `_Gemini.APIError.unknown` wrapping any failure of the request.
public func deleteFile(
    fileURL: URL
) async throws {
    // The request is issued exactly once. The stale pre-change call that used to
    // precede this block sent the delete a second time.
    do {
        let input = _Gemini.APISpecification.RequestBodies.DeleteFileInput(fileURL: fileURL)
        try await run(\.deleteFile, with: input)
    } catch {
        throw _Gemini.APIError.unknown(message: "Failed to delete file: \(error.localizedDescription)")
    }
}
}

// Error Handling

/// Errors raised while validating, polling, or locating uploaded files.
///
/// Conforms to `LocalizedError` so that `error.localizedDescription` — which the
/// client uses when building wrapped error messages — yields a specific message
/// instead of a generic system string.
fileprivate enum FileProcessingError: LocalizedError {
    /// The supplied file name or display name was missing or empty.
    case invalidFileName
    /// The file did not become active within the allowed polling attempts.
    case processingTimeout(fileName: String)
    /// The file reported a state that cannot be handled.
    case invalidFileState(state: String)
    /// No file exists under the given name.
    case fileNotFound(name: String)

    var errorDescription: String? {
        switch self {
        case .invalidFileName:
            return "Invalid file name"
        case .processingTimeout(let fileName):
            return "File processing timeout: \(fileName)"
        case .invalidFileState(let state):
            return "Invalid file state: \(state)"
        case .fileNotFound(let name):
            return "File not found: \(name)"
        }
    }
}
40 changes: 38 additions & 2 deletions Tests/_Gemini/Intramodular/_GeminiTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
//

import Testing
import SwiftUIX
import Foundation
import _Gemini

Expand Down Expand Up @@ -70,6 +71,30 @@ private final class BundleHelper {}
model: .gemini_1_5_flash
)

print(response)

#expect(response.candidates != nil)
#expect(!response.candidates!.isEmpty)
} catch let error as GeminiTestError {
print("Detailed error: \(error.localizedDescription)")
#expect(false, "Audio content generation failed: \(error)")
} catch {
throw GeminiTestError.audioProcessingError(error)
}
}

@Test func testImageContentGeneration() async throws {
do {
let file = try await createFile(type: .image)

let response = try await client.generateContent(
file: file,
prompt: "What is this the shape of this image?",
model: .gemini_1_5_flash
)

print(response)

#expect(response.candidates != nil)
#expect(!response.candidates!.isEmpty)
} catch let error as GeminiTestError {
Expand Down Expand Up @@ -127,16 +152,26 @@ private final class BundleHelper {}
return try await client.uploadFile(
fileData: audioData,
mimeType: .custom("audio/x-m4a"),
displayName: UUID().uuidString
displayName: "Test"
)
case .video:
let videoData = try loadTestFile(named: "LintMySwiftSmall", fileExtension: "mov")

return try await client.uploadFile(
fileData: videoData,
mimeType: .custom("video/quicktime"),
displayName: UUID().uuidString
displayName: "Test"
)
case .image:
let image = AppKitOrUIKitImage(_SwiftUIX_systemName: "arrow.up", withConfiguration: .init(pointSize: 50))
guard let imageData = image?.data(using: .png) else { throw GeminiTestError.fileNotFound("System Symbol")}

return try await client.uploadFile(
fileData: imageData,
mimeType: .custom("image/png"),
displayName: "Test"
)

}
} catch let error as GeminiTestError {
throw error
Expand All @@ -148,6 +183,7 @@ private final class BundleHelper {}
/// Kinds of sample media the test helpers can upload.
enum TestFileType {
    case audio, video, image
}
}
// Error Handling
Expand Down

0 comments on commit a1fdad7

Please sign in to comment.