diff --git a/FirebaseVertexAI/Sample/ChatSample/Views/ErrorDetailsView.swift b/FirebaseVertexAI/Sample/ChatSample/Views/ErrorDetailsView.swift
index 7b3b78753db..dc5ce8f9561 100644
--- a/FirebaseVertexAI/Sample/ChatSample/Views/ErrorDetailsView.swift
+++ b/FirebaseVertexAI/Sample/ChatSample/Views/ErrorDetailsView.swift
@@ -176,8 +176,7 @@ struct ErrorDetailsView: View {
       ],
       finishReason: FinishReason.maxTokens,
       citationMetadata: nil),
-    ],
-    promptFeedback: nil)
+    ])
   )
 
   return ErrorDetailsView(error: error)
@@ -200,8 +199,7 @@ struct ErrorDetailsView: View {
       ],
       finishReason: FinishReason.other,
       citationMetadata: nil),
-    ],
-    promptFeedback: nil)
+    ])
   )
 
   return ErrorDetailsView(error: error)
diff --git a/FirebaseVertexAI/Sample/ChatSample/Views/ErrorView.swift b/FirebaseVertexAI/Sample/ChatSample/Views/ErrorView.swift
index 1307eee62d4..d4db2d67dc5 100644
--- a/FirebaseVertexAI/Sample/ChatSample/Views/ErrorView.swift
+++ b/FirebaseVertexAI/Sample/ChatSample/Views/ErrorView.swift
@@ -51,8 +51,7 @@ struct ErrorView: View {
         ],
         finishReason: FinishReason.other,
         citationMetadata: nil),
-      ],
-      promptFeedback: nil)
+      ])
     )
     List {
       MessageView(message: ChatMessage.samples[0])
diff --git a/FirebaseVertexAI/Sources/GenerateContentResponse.swift b/FirebaseVertexAI/Sources/GenerateContentResponse.swift
index f43da308b13..29c5b14f552 100644
--- a/FirebaseVertexAI/Sources/GenerateContentResponse.swift
+++ b/FirebaseVertexAI/Sources/GenerateContentResponse.swift
@@ -17,6 +17,18 @@ import Foundation
 /// The model's response to a generate content request.
 @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
 public struct GenerateContentResponse {
+  /// Token usage metadata for processing the generate content request.
+  public struct UsageMetadata {
+    /// The number of tokens in the request prompt.
+    public let promptTokenCount: Int
+
+    /// The total number of tokens across the generated response candidates.
+    public let candidatesTokenCount: Int
+
+    /// The total number of tokens in both the request and response.
+    public let totalTokenCount: Int
+  }
+
   /// A list of candidate response content, ordered from best to worst.
   public let candidates: [CandidateResponse]
 
@@ -24,6 +36,9 @@ public struct GenerateContentResponse {
   /// reason for blocking the request.
   public let promptFeedback: PromptFeedback?
 
+  /// Token usage metadata for processing the generate content request.
+  public let usageMetadata: UsageMetadata?
+
   /// The response's content as text, if it exists.
   public var text: String? {
     guard let candidate = candidates.first else {
@@ -51,9 +66,11 @@ public struct GenerateContentResponse {
   }
 
   /// Initializer for SwiftUI previews or tests.
-  public init(candidates: [CandidateResponse], promptFeedback: PromptFeedback?) {
+  public init(candidates: [CandidateResponse], promptFeedback: PromptFeedback? = nil,
+              usageMetadata: UsageMetadata? = nil) {
     self.candidates = candidates
     self.promptFeedback = promptFeedback
+    self.usageMetadata = usageMetadata
   }
 }
 
@@ -62,6 +79,7 @@ extension GenerateContentResponse: Decodable {
   enum CodingKeys: CodingKey {
     case candidates
     case promptFeedback
+    case usageMetadata
   }
 
   public init(from decoder: Decoder) throws {
@@ -86,6 +104,7 @@ extension GenerateContentResponse: Decodable {
       candidates = []
     }
     promptFeedback = try container.decodeIfPresent(PromptFeedback.self, forKey: .promptFeedback)
+    usageMetadata = try container.decodeIfPresent(UsageMetadata.self, forKey: .usageMetadata)
   }
 }
 
@@ -301,3 +320,21 @@ extension PromptFeedback: Decodable {
     }
   }
 }
+
+@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
+extension GenerateContentResponse.UsageMetadata: Decodable {
+  enum CodingKeys: CodingKey {
+    case promptTokenCount
+    case candidatesTokenCount
+    case totalTokenCount
+  }
+
+  public init(from decoder: any Decoder) throws {
+    let container = try decoder.container(keyedBy: CodingKeys.self)
+    // Token counts that are absent from the response (or null) decode as 0.
+    promptTokenCount = try container.decodeIfPresent(Int.self, forKey: .promptTokenCount) ?? 0
+    candidatesTokenCount = try container
+      .decodeIfPresent(Int.self, forKey: .candidatesTokenCount) ?? 0
+    totalTokenCount = try container.decodeIfPresent(Int.self, forKey: .totalTokenCount) ?? 0
+  }
+}
diff --git a/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift b/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift
index 92585266891..e7ea54f5dd9 100644
--- a/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift
+++ b/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift
@@ -287,6 +287,21 @@ final class GenerativeModelTests: XCTestCase {
     _ = try await model.generateContent(testPrompt)
   }
 
+  func testGenerateContent_usageMetadata() async throws {
+    MockURLProtocol
+      .requestHandler = try httpRequestHandler(
+        forResource: "unary-success-basic-reply-short",
+        withExtension: "json"
+      )
+
+    let response = try await model.generateContent(testPrompt)
+
+    let usageMetadata = try XCTUnwrap(response.usageMetadata)
+    XCTAssertEqual(usageMetadata.promptTokenCount, 6)
+    XCTAssertEqual(usageMetadata.candidatesTokenCount, 7)
+    XCTAssertEqual(usageMetadata.totalTokenCount, 13)
+  }
+
   func testGenerateContent_failure_invalidAPIKey() async throws {
     let expectedStatusCode = 400
     MockURLProtocol
@@ -814,6 +829,32 @@ final class GenerativeModelTests: XCTestCase {
     for try await _ in stream {}
   }
 
+  func testGenerateContentStream_usageMetadata() async throws {
+    MockURLProtocol
+      .requestHandler = try httpRequestHandler(
+        forResource: "streaming-success-basic-reply-short",
+        withExtension: "txt"
+      )
+    var responses = [GenerateContentResponse]()
+
+    let stream = model.generateContentStream(testPrompt)
+    for try await response in stream {
+      responses.append(response)
+    }
+
+    // Fail, rather than pass vacuously, if the stream yielded no responses.
+    let lastResponse = try XCTUnwrap(responses.last)
+    let usageMetadata = try XCTUnwrap(lastResponse.usageMetadata)
+    XCTAssertEqual(usageMetadata.promptTokenCount, 6)
+    XCTAssertEqual(usageMetadata.candidatesTokenCount, 4)
+    XCTAssertEqual(usageMetadata.totalTokenCount, 10)
+
+    // Only the last streamed response contains usage metadata.
+    for response in responses.dropLast() {
+      XCTAssertNil(response.usageMetadata)
+    }
+  }
+
   func testGenerateContentStream_errorMidStream() async throws {
     MockURLProtocol.requestHandler = try httpRequestHandler(
       forResource: "streaming-failure-error-mid-stream",