Use LLM API responses in token counting #5604

Merged: 11 commits, Feb 23, 2025
47 changes: 47 additions & 0 deletions openhands/core/message_utils.py
@@ -29,6 +29,7 @@
from openhands.events.observation.error import ErrorObservation
from openhands.events.observation.observation import Observation
from openhands.events.serialization.event import truncate_content
from openhands.llm.metrics import Metrics, TokensUsage


def events_to_messages(
@@ -362,3 +363,49 @@ def apply_prompt_caching(messages: list[Message]) -> None:
-1
].cache_prompt = True # Last item inside the message content
break


def get_single_tokens_usage_for_event(
event: Event, metrics: Metrics
) -> TokensUsage | None:
"""
Returns at most one token usage record for the `model_response.id` in this event's
`tool_call_metadata`.

If no response_id is found, or none match in metrics.tokens_usages, returns None.
"""
if event.tool_call_metadata and event.tool_call_metadata.model_response:
response_id = event.tool_call_metadata.model_response.get('id')
if response_id:
return next(
(
usage
for usage in metrics.tokens_usages
if usage.response_id == response_id

Collaborator Author: This implementation ignores that model_response might have the tokens data and takes it from TokenUsage instead. We will likely always need response_id, but maybe not all the rest in model_response. (A sketch of that alternative follows after this function.)

),
None,
)
return None
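
For reference, here is a minimal sketch of the alternative described in the comment above: building the usage record straight from `model_response` instead of looking it up in `Metrics`. The helper name is hypothetical, and it assumes `model_response` carries a litellm-style `usage` payload like the one handled in `llm.py` below.

```python
# Hypothetical helper, not part of this PR: derive a TokensUsage record
# directly from the event's model_response, assuming a litellm-style payload.
from openhands.llm.metrics import TokensUsage


def tokens_usage_from_model_response(event) -> TokensUsage | None:
    metadata = event.tool_call_metadata
    if not (metadata and metadata.model_response):
        return None
    response = metadata.model_response
    usage = response.get('usage') or {}
    details = usage.get('prompt_tokens_details') or {}
    extra = usage.get('model_extra') or {}
    # prompt_tokens_details may be a plain dict or a litellm PromptTokensDetails object
    cache_read = (
        details.get('cached_tokens', 0)
        if isinstance(details, dict)
        else getattr(details, 'cached_tokens', 0)
    )
    return TokensUsage(
        model=response.get('model', 'unknown'),
        prompt_tokens=usage.get('prompt_tokens', 0) or 0,
        completion_tokens=usage.get('completion_tokens', 0) or 0,
        cache_read_tokens=cache_read or 0,
        cache_write_tokens=extra.get('cache_creation_input_tokens', 0) or 0,
        response_id=response.get('id', 'unknown'),
    )
```
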


def get_tokens_usage_for_event_id(
events: list[Event], event_id: int, metrics: Metrics
) -> TokensUsage | None:
"""
Starting from the event with .id == event_id and moving backwards in `events`,
find the first TokensUsage record (if any) associated with a response_id from
tool_call_metadata.model_response.id.

Returns the first match found, or None if none is found.
"""
# find the index of the event with the given id
idx = next((i for i, e in enumerate(events) if e.id == event_id), None)
if idx is None:
return None

# search backward from idx down to 0
for i in range(idx, -1, -1):
usage = get_single_tokens_usage_for_event(events[i], metrics)
if usage is not None:
return usage
return None
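
A short usage sketch for the lookup above. The surrounding wiring (`state.history` as the event list, `llm.metrics` as the metrics store, and the event id) is assumed for illustration and is not part of this diff.

```python
# Illustrative call site; the objects passed in are assumptions.
from openhands.core.message_utils import get_tokens_usage_for_event_id

usage = get_tokens_usage_for_event_id(state.history, event_id=42, metrics=llm.metrics)
if usage is not None:
    print(usage.response_id, usage.prompt_tokens, usage.completion_tokens)
```
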
29 changes: 20 additions & 9 deletions openhands/llm/llm.py
@@ -497,20 +497,21 @@ def _post_completion(self, response: ModelResponse) -> float:
stats += 'Response Latency: %.3f seconds\n' % latest_latency.latency

usage: Usage | None = response.get('usage')
response_id = response.get('id', 'unknown')

if usage:
# keep track of the input and output tokens
input_tokens = usage.get('prompt_tokens')
output_tokens = usage.get('completion_tokens')
prompt_tokens = usage.get('prompt_tokens', 0)
completion_tokens = usage.get('completion_tokens', 0)

if input_tokens:
stats += 'Input tokens: ' + str(input_tokens)
if prompt_tokens:
stats += 'Input tokens: ' + str(prompt_tokens)

if output_tokens:
if completion_tokens:
stats += (
(' | ' if input_tokens else '')
(' | ' if prompt_tokens else '')
+ 'Output tokens: '
+ str(output_tokens)
+ str(completion_tokens)
+ '\n'
)

@@ -519,7 +520,7 @@ def _post_completion(self, response: ModelResponse) -> float:
'prompt_tokens_details'
)
cache_hit_tokens = (
prompt_tokens_details.cached_tokens if prompt_tokens_details else None
prompt_tokens_details.cached_tokens if prompt_tokens_details else 0
)
if cache_hit_tokens:
stats += 'Input tokens (cache hit): ' + str(cache_hit_tokens) + '\n'
@@ -528,10 +529,20 @@ def _post_completion(self, response: ModelResponse) -> float:
# but litellm doesn't separate them in the usage stats
# so we can read it from the provider-specific extra field
model_extra = usage.get('model_extra', {})
cache_write_tokens = model_extra.get('cache_creation_input_tokens')
cache_write_tokens = model_extra.get('cache_creation_input_tokens', 0)
if cache_write_tokens:
stats += 'Input tokens (cache write): ' + str(cache_write_tokens) + '\n'

# Record in metrics
# We'll treat cache_hit_tokens as "cache read" and cache_write_tokens as "cache write"
self.metrics.add_tokens_usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
cache_read_tokens=cache_hit_tokens,
cache_write_tokens=cache_write_tokens,
response_id=response_id,
)

# log the stats
if stats:
logger.debug(stats)
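
To make the mapping above concrete, here is a small sketch of how one provider usage payload would be recorded. The numbers and the response id are invented; `cached_tokens` feeds `cache_read_tokens` and the provider-specific `cache_creation_input_tokens` feeds `cache_write_tokens`, as in the hunk above.

```python
# Sketch only: an invented usage payload and how _post_completion records it.
# `metrics` is assumed to be an openhands.llm.metrics.Metrics instance.
usage = {
    'prompt_tokens': 12,  # full input, including any cached portion
    'completion_tokens': 3,
    'prompt_tokens_details': {'cached_tokens': 2},      # -> cache_read_tokens
    'model_extra': {'cache_creation_input_tokens': 5},  # -> cache_write_tokens
}

metrics.add_tokens_usage(
    prompt_tokens=usage['prompt_tokens'],
    completion_tokens=usage['completion_tokens'],
    cache_read_tokens=usage['prompt_tokens_details']['cached_tokens'],
    cache_write_tokens=usage['model_extra']['cache_creation_input_tokens'],
    response_id='resp-example-1',
)
```
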
56 changes: 52 additions & 4 deletions openhands/llm/metrics.py
@@ -17,18 +17,31 @@ class ResponseLatency(BaseModel):
response_id: str


class TokensUsage(BaseModel):

Collaborator: Minor nit, but I'd expect this to be called TokenUsage instead of the pluralized form.

Collaborator Author: Hah, I felt the same, it reads strange 😅, o3-mini wanted it for some mysterious reason. Fixed!

They are going to align humans to their preferences, aren't they? 😂

"""Metric tracking detailed token usage per completion call."""

model: str
prompt_tokens: int
completion_tokens: int
cache_read_tokens: int
cache_write_tokens: int
response_id: str


class Metrics:
"""Metrics class can record various metrics during running and evaluation.
Currently, we define the following metrics:
accumulated_cost: the total cost (USD $) of the current LLM.
response_latency: the time taken for each LLM completion call.
We track:
- accumulated_cost and costs
- A list of ResponseLatency
- A list of TokensUsage (one per call).
"""

def __init__(self, model_name: str = 'default') -> None:
self._accumulated_cost: float = 0.0
self._costs: list[Cost] = []
self._response_latencies: list[ResponseLatency] = []
self.model_name = model_name
self._tokens_usages: list[TokensUsage] = []

@property
def accumulated_cost(self) -> float:
@@ -54,6 +67,16 @@ def response_latencies(self) -> list[ResponseLatency]:
def response_latencies(self, value: list[ResponseLatency]) -> None:
self._response_latencies = value

@property
def tokens_usages(self) -> list[TokensUsage]:
if not hasattr(self, '_tokens_usages'):
self._tokens_usages = []
return self._tokens_usages

@tokens_usages.setter
def tokens_usages(self, value: list[TokensUsage]) -> None:
self._tokens_usages = value

def add_cost(self, value: float) -> None:
if value < 0:
raise ValueError('Added cost cannot be negative.')
@@ -67,10 +90,33 @@ def add_response_latency(self, value: float, response_id: str) -> None:
)
)

def add_tokens_usage(
self,
prompt_tokens: int,
completion_tokens: int,
cache_read_tokens: int,
cache_write_tokens: int,
response_id: str,
) -> None:
"""Add a single usage record."""
self._tokens_usages.append(
TokensUsage(
model=self.model_name,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
cache_read_tokens=cache_read_tokens,
cache_write_tokens=cache_write_tokens,
response_id=response_id,
)
)

def merge(self, other: 'Metrics') -> None:
"""Merge 'other' metrics into this one."""
self._accumulated_cost += other.accumulated_cost
self._costs += other._costs
self._response_latencies += other._response_latencies
# use the property so older pickled objects that lack the field won't crash
self.tokens_usages += other.tokens_usages
self.response_latencies += other.response_latencies

def get(self) -> dict:
"""Return the metrics in a dictionary."""
@@ -80,12 +126,14 @@ def get(self) -> dict:
'response_latencies': [
latency.model_dump() for latency in self._response_latencies
],
'tokens_usages': [usage.model_dump() for usage in self._tokens_usages],
}

def reset(self):
self._accumulated_cost = 0.0
self._costs = []
self._response_latencies = []
self._tokens_usages = []

def log(self):
"""Log the metrics."""
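
For the merge comment above ("use the property so older pickled objects that lack the field won't crash"), a minimal sketch of the scenario it guards against. The old object is simulated here; the only point is that a `Metrics` serialized before this change has no `_tokens_usages` attribute.

```python
# Sketch of the backward-compatibility path: an "old" Metrics object without
# _tokens_usages still works because the tokens_usages property backfills it.
import pickle

from openhands.llm.metrics import Metrics

old_metrics = Metrics(model_name='gpt-4o')
del old_metrics._tokens_usages  # simulate an object pickled before this PR
restored = pickle.loads(pickle.dumps(old_metrics))

assert restored.tokens_usages == []  # property creates the missing list

fresh = Metrics(model_name='gpt-4o')
fresh.add_tokens_usage(
    prompt_tokens=10,
    completion_tokens=2,
    cache_read_tokens=0,
    cache_write_tokens=0,
    response_id='resp-abc',
)
fresh.merge(restored)  # no AttributeError even though `restored` predates the field
```
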
60 changes: 60 additions & 0 deletions tests/unit/test_llm.py
@@ -2,6 +2,7 @@
from unittest.mock import MagicMock, patch

import pytest
from litellm import PromptTokensDetails
from litellm.exceptions import (
RateLimitError,
)
@@ -429,3 +430,62 @@ def test_get_token_count_error_handling(
mock_logger.error.assert_called_once_with(
'Error getting token count for\n model gpt-4o\nToken counting failed'
)


@patch('openhands.llm.llm.litellm_completion')
def test_llm_token_usage(mock_litellm_completion, default_config):
# This mock response includes usage details with prompt_tokens,
# completion_tokens, prompt_tokens_details.cached_tokens, and model_extra.cache_creation_input_tokens
mock_response_1 = {
'id': 'test-response-usage',
'choices': [{'message': {'content': 'Usage test response'}}],
'usage': {
'prompt_tokens': 12,
'completion_tokens': 3,
'prompt_tokens_details': PromptTokensDetails(cached_tokens=2),
'model_extra': {'cache_creation_input_tokens': 5},
},
}

# Create a second usage scenario to test accumulation and a different response_id
mock_response_2 = {
'id': 'test-response-usage-2',
'choices': [{'message': {'content': 'Second usage test response'}}],
'usage': {
'prompt_tokens': 7,
'completion_tokens': 2,
'prompt_tokens_details': PromptTokensDetails(cached_tokens=1),
'model_extra': {'cache_creation_input_tokens': 3},
},
}

# We'll make mock_litellm_completion return these responses in sequence
mock_litellm_completion.side_effect = [mock_response_1, mock_response_2]

llm = LLM(config=default_config)

# First call
llm.completion(messages=[{'role': 'user', 'content': 'Hello usage!'}])

# Verify we have exactly one usage record after first call
tokens_usage_list = llm.metrics.get()['tokens_usages']
assert len(tokens_usage_list) == 1
usage_entry_1 = tokens_usage_list[0]
assert usage_entry_1['prompt_tokens'] == 12
assert usage_entry_1['completion_tokens'] == 3
assert usage_entry_1['cache_read_tokens'] == 2
assert usage_entry_1['cache_write_tokens'] == 5
assert usage_entry_1['response_id'] == 'test-response-usage'

# Second call
llm.completion(messages=[{'role': 'user', 'content': 'Hello again!'}])

# Now we expect two usage records total
tokens_usage_list = llm.metrics.get()['tokens_usages']
assert len(tokens_usage_list) == 2
usage_entry_2 = tokens_usage_list[-1]
assert usage_entry_2['prompt_tokens'] == 7
assert usage_entry_2['completion_tokens'] == 2
assert usage_entry_2['cache_read_tokens'] == 1
assert usage_entry_2['cache_write_tokens'] == 3
assert usage_entry_2['response_id'] == 'test-response-usage-2'