Skip to content

Commit

Permalink
Add usageMetadata to GenerateContentResponse in Vertex AI (#12777)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard authored Apr 12, 2024
1 parent 113a100 commit 9d304a1
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,7 @@ struct ErrorDetailsView: View {
],
finishReason: FinishReason.maxTokens,
citationMetadata: nil),
],
promptFeedback: nil)
])
)

return ErrorDetailsView(error: error)
Expand All @@ -200,8 +199,7 @@ struct ErrorDetailsView: View {
],
finishReason: FinishReason.other,
citationMetadata: nil),
],
promptFeedback: nil)
])
)

return ErrorDetailsView(error: error)
Expand Down
3 changes: 1 addition & 2 deletions FirebaseVertexAI/Sample/ChatSample/Views/ErrorView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@ struct ErrorView: View {
],
finishReason: FinishReason.other,
citationMetadata: nil),
],
promptFeedback: nil)
])
)
List {
MessageView(message: ChatMessage.samples[0])
Expand Down
38 changes: 37 additions & 1 deletion FirebaseVertexAI/Sources/GenerateContentResponse.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,28 @@ import Foundation
/// The model's response to a generate content request.
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
public struct GenerateContentResponse {
/// Token usage metadata for processing the generate content request.
///
/// All counts are non-optional; when this type is decoded, a count whose key is absent from the
/// payload is stored as `0`.
public struct UsageMetadata {
/// The number of tokens in the request prompt.
public let promptTokenCount: Int

/// The total number of tokens across the generated response candidates.
public let candidatesTokenCount: Int

/// The total number of tokens in both the request and response.
public let totalTokenCount: Int
}

/// A list of candidate response content, ordered from best to worst.
public let candidates: [CandidateResponse]

/// A value containing the safety ratings for the response, or, if the request was blocked, a
/// reason for blocking the request.
public let promptFeedback: PromptFeedback?

/// Token usage metadata for processing the generate content request.
public let usageMetadata: UsageMetadata?

/// The response's content as text, if it exists.
public var text: String? {
guard let candidate = candidates.first else {
Expand Down Expand Up @@ -51,9 +66,11 @@ public struct GenerateContentResponse {
}

/// Initializer for SwiftUI previews or tests.
public init(candidates: [CandidateResponse], promptFeedback: PromptFeedback?) {
public init(candidates: [CandidateResponse], promptFeedback: PromptFeedback? = nil,
usageMetadata: UsageMetadata? = nil) {
self.candidates = candidates
self.promptFeedback = promptFeedback
self.usageMetadata = usageMetadata
}
}

Expand All @@ -62,6 +79,7 @@ extension GenerateContentResponse: Decodable {
enum CodingKeys: CodingKey {
case candidates
case promptFeedback
case usageMetadata
}

public init(from decoder: Decoder) throws {
Expand All @@ -86,6 +104,7 @@ extension GenerateContentResponse: Decodable {
candidates = []
}
promptFeedback = try container.decodeIfPresent(PromptFeedback.self, forKey: .promptFeedback)
usageMetadata = try container.decodeIfPresent(UsageMetadata.self, forKey: .usageMetadata)
}
}

Expand Down Expand Up @@ -301,3 +320,20 @@ extension PromptFeedback: Decodable {
}
}
}

@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
extension GenerateContentResponse.UsageMetadata: Decodable {
  /// Keys of the token counts read from the decoded payload.
  enum CodingKeys: CodingKey {
    case promptTokenCount, candidatesTokenCount, totalTokenCount
  }

  /// Creates usage metadata from `decoder`, substituting `0` for any token count that is not
  /// present in the payload.
  ///
  /// - Parameter decoder: The decoder to read the token counts from.
  /// - Throws: An error if the keyed container cannot be obtained, or if a count that is present
  ///   cannot be decoded as an `Int`.
  public init(from decoder: any Decoder) throws {
    let values = try decoder.container(keyedBy: CodingKeys.self)

    // Every count is optional in the payload; a missing one falls back to zero.
    func tokenCount(for key: CodingKeys) throws -> Int {
      if let count = try values.decodeIfPresent(Int.self, forKey: key) {
        return count
      }
      return 0
    }

    promptTokenCount = try tokenCount(for: .promptTokenCount)
    candidatesTokenCount = try tokenCount(for: .candidatesTokenCount)
    totalTokenCount = try tokenCount(for: .totalTokenCount)
  }
}
41 changes: 41 additions & 0 deletions FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,21 @@ final class GenerativeModelTests: XCTestCase {
_ = try await model.generateContent(testPrompt)
}

/// Checks that a unary `generateContent` response exposes the token counts from the
/// `unary-success-basic-reply-short` fixture.
func testGenerateContent_usageMetadata() async throws {
  // Serve the canned reply for the request issued below.
  let cannedReply = try httpRequestHandler(
    forResource: "unary-success-basic-reply-short",
    withExtension: "json"
  )
  MockURLProtocol.requestHandler = cannedReply

  let reply = try await model.generateContent(testPrompt)

  // The metadata is optional on the response; a missing value fails the test here.
  let usage = try XCTUnwrap(reply.usageMetadata)
  XCTAssertEqual(usage.promptTokenCount, 6)
  XCTAssertEqual(usage.candidatesTokenCount, 7)
  XCTAssertEqual(usage.totalTokenCount, 13)
}

func testGenerateContent_failure_invalidAPIKey() async throws {
let expectedStatusCode = 400
MockURLProtocol
Expand Down Expand Up @@ -814,6 +829,32 @@ final class GenerativeModelTests: XCTestCase {
for try await _ in stream {}
}

/// Checks that, for a streamed reply, usage metadata is reported on the final streamed response
/// only and carries the token counts from the `streaming-success-basic-reply-short` fixture.
func testGenerateContentStream_usageMetadata() async throws {
  MockURLProtocol.requestHandler = try httpRequestHandler(
    forResource: "streaming-success-basic-reply-short",
    withExtension: "txt"
  )
  var responses = [GenerateContentResponse]()

  let stream = model.generateContentStream(testPrompt)
  for try await response in stream {
    responses.append(response)
  }

  // Unwrap the last response explicitly so that an empty stream fails the test instead of
  // letting it pass without running a single assertion.
  let lastResponse = try XCTUnwrap(responses.last, "Expected at least one streamed response")
  let usageMetadata = try XCTUnwrap(lastResponse.usageMetadata)
  XCTAssertEqual(usageMetadata.promptTokenCount, 6)
  XCTAssertEqual(usageMetadata.candidatesTokenCount, 4)
  XCTAssertEqual(usageMetadata.totalTokenCount, 10)

  // Only the last streamed response contains usage metadata.
  for response in responses.dropLast() {
    XCTAssertNil(response.usageMetadata)
  }
}

func testGenerateContentStream_errorMidStream() async throws {
MockURLProtocol.requestHandler = try httpRequestHandler(
forResource: "streaming-failure-error-mid-stream",
Expand Down

0 comments on commit 9d304a1

Please sign in to comment.