Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add usageMetadata to GenerateContentResponse in Vertex AI #12777

Merged
merged 3 commits into the base branch on
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,7 @@ struct ErrorDetailsView: View {
],
finishReason: FinishReason.maxTokens,
citationMetadata: nil),
],
promptFeedback: nil)
])
)

return ErrorDetailsView(error: error)
Expand All @@ -200,8 +199,7 @@ struct ErrorDetailsView: View {
],
finishReason: FinishReason.other,
citationMetadata: nil),
],
promptFeedback: nil)
])
)

return ErrorDetailsView(error: error)
Expand Down
3 changes: 1 addition & 2 deletions FirebaseVertexAI/Sample/ChatSample/Views/ErrorView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@ struct ErrorView: View {
],
finishReason: FinishReason.other,
citationMetadata: nil),
],
promptFeedback: nil)
])
)
List {
MessageView(message: ChatMessage.samples[0])
Expand Down
38 changes: 37 additions & 1 deletion FirebaseVertexAI/Sources/GenerateContentResponse.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,28 @@ import Foundation
/// The model's response to a generate content request.
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
public struct GenerateContentResponse {
/// Token usage metadata for processing the generate content request.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Shouldn't internal line-based comments be two slashes?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't seem to be very consistent about // vs. /// for internal variables / constants (e.g.,

/// The resource name of the model in the backend; has the format "models/model-name".
let modelResourceName: String
/// The backing service responsible for sending and receiving model requests to the backend.
let generativeAIService: GenerativeAIService
/// Configuration parameters used for the MultiModalModel.
let generationConfig: GenerationConfig?
/// The safety settings to be used for prompts.
let safetySettings: [SafetySetting]?
/// A list of tools the model may use to generate the next response.
let tools: [Tool]?
/// Tool configuration for any `Tool` specified in the request.
let toolConfig: ToolConfig?
/// Instructions that direct the model to behave a certain way.
let systemInstruction: ModelContent?
/// Configuration parameters for sending requests to the backend.
let requestOptions: RequestOptions
) but you pointed out a very important issue... these should be public (wasn't caught because the test uses @testable import) and should remain /// for DocC. Adding public now.

/// Token usage metadata for processing the generate content request.
///
/// Populated from the `usageMetadata` field of the backend response; in streaming
/// responses only the final chunk carries it (see the streaming unit test below).
public struct UsageMetadata {
  /// The number of tokens in the request prompt.
  public let promptTokenCount: Int

  /// The total number of tokens across the generated response candidates.
  public let candidatesTokenCount: Int

  /// The total number of tokens in both the request and response.
  ///
  /// NOTE(review): decoding defaults each count to 0 when the backend omits the field.
  public let totalTokenCount: Int
}

/// A list of candidate response content, ordered from best to worst.
public let candidates: [CandidateResponse]

/// A value containing the safety ratings for the response, or, if the request was blocked, a
/// reason for blocking the request.
public let promptFeedback: PromptFeedback?

/// Token usage metadata for processing the generate content request.
public let usageMetadata: UsageMetadata?

/// The response's content as text, if it exists.
public var text: String? {
guard let candidate = candidates.first else {
Expand Down Expand Up @@ -51,9 +66,11 @@ public struct GenerateContentResponse {
}

/// Initializer for SwiftUI previews or tests.
public init(candidates: [CandidateResponse], promptFeedback: PromptFeedback?) {
/// - Parameters:
///   - candidates: Candidate response content, ordered from best to worst.
///   - promptFeedback: Feedback for the prompt, or `nil` if the request was not blocked;
///     defaulted so existing two-argument call sites keep compiling.
///   - usageMetadata: Token usage for the request, or `nil` when unavailable.
public init(candidates: [CandidateResponse], promptFeedback: PromptFeedback? = nil,
usageMetadata: UsageMetadata? = nil) {
self.candidates = candidates
self.promptFeedback = promptFeedback
self.usageMetadata = usageMetadata
}
}

Expand All @@ -62,6 +79,7 @@ extension GenerateContentResponse: Decodable {
enum CodingKeys: CodingKey {
case candidates
case promptFeedback
case usageMetadata
}

public init(from decoder: Decoder) throws {
Expand All @@ -86,6 +104,7 @@ extension GenerateContentResponse: Decodable {
candidates = []
}
promptFeedback = try container.decodeIfPresent(PromptFeedback.self, forKey: .promptFeedback)
usageMetadata = try container.decodeIfPresent(UsageMetadata.self, forKey: .usageMetadata)
}
}

Expand Down Expand Up @@ -301,3 +320,20 @@ extension PromptFeedback: Decodable {
}
}
}

@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
extension GenerateContentResponse.UsageMetadata: Decodable {
  enum CodingKeys: CodingKey {
    case promptTokenCount
    case candidatesTokenCount
    case totalTokenCount
  }

  /// Decodes usage metadata, treating any token count the backend omits as zero.
  public init(from decoder: any Decoder) throws {
    let values = try decoder.container(keyedBy: CodingKeys.self)
    // Each field is optional in the wire format; fall back to 0 rather than throwing.
    self.promptTokenCount = try values.decodeIfPresent(
      Int.self, forKey: .promptTokenCount
    ) ?? 0
    self.candidatesTokenCount = try values.decodeIfPresent(
      Int.self, forKey: .candidatesTokenCount
    ) ?? 0
    self.totalTokenCount = try values.decodeIfPresent(
      Int.self, forKey: .totalTokenCount
    ) ?? 0
  }
}
41 changes: 41 additions & 0 deletions FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,21 @@ final class GenerativeModelTests: XCTestCase {
_ = try await model.generateContent(testPrompt)
}

func testGenerateContent_usageMetadata() async throws {
  // Stub the transport with a canned unary response that includes usageMetadata.
  MockURLProtocol.requestHandler = try httpRequestHandler(
    forResource: "unary-success-basic-reply-short",
    withExtension: "json"
  )

  let response = try await model.generateContent(testPrompt)

  // Token counts must match the fixture exactly.
  let metadata = try XCTUnwrap(response.usageMetadata)
  XCTAssertEqual(metadata.promptTokenCount, 6)
  XCTAssertEqual(metadata.candidatesTokenCount, 7)
  XCTAssertEqual(metadata.totalTokenCount, 13)
}

func testGenerateContent_failure_invalidAPIKey() async throws {
let expectedStatusCode = 400
MockURLProtocol
Expand Down Expand Up @@ -814,6 +829,32 @@ final class GenerativeModelTests: XCTestCase {
for try await _ in stream {}
}

func testGenerateContentStream_usageMetadata() async throws {
  // Stub the transport with a canned streaming response.
  MockURLProtocol.requestHandler = try httpRequestHandler(
    forResource: "streaming-success-basic-reply-short",
    withExtension: "txt"
  )

  var collected = [GenerateContentResponse]()
  for try await chunk in model.generateContentStream(testPrompt) {
    collected.append(chunk)
  }

  // Only the last streamed response contains usage metadata; every earlier
  // chunk must report nil.
  for (offset, chunk) in collected.enumerated() {
    guard offset == collected.count - 1 else {
      XCTAssertNil(chunk.usageMetadata)
      continue
    }
    let metadata = try XCTUnwrap(chunk.usageMetadata)
    XCTAssertEqual(metadata.promptTokenCount, 6)
    XCTAssertEqual(metadata.candidatesTokenCount, 4)
    XCTAssertEqual(metadata.totalTokenCount, 10)
  }
}

func testGenerateContentStream_errorMidStream() async throws {
MockURLProtocol.requestHandler = try httpRequestHandler(
forResource: "streaming-failure-error-mid-stream",
Expand Down
Loading