Fix mypy errors in openhands/llm directory
openhands-agent committed Feb 19, 2025
1 parent 592aca0 commit 2589b13
Showing 8 changed files with 98 additions and 76 deletions.
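To reproduce the check this commit addresses, mypy can be invoked programmatically through its public API. A minimal sketch, assuming mypy is installed and the command is run from the repository root (the target path comes from the commit title):

    # Hypothetical verification snippet: type-check the package this commit touches.
    from mypy import api  # requires `pip install mypy`

    report, errors, exit_status = api.run(['openhands/llm'])
    print(report)                # per-file errors, or a success message
    print('exit:', exit_status)  # 0 means the directory is clean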
20 changes: 10 additions & 10 deletions openhands/llm/async_llm.py
@@ -3,6 +3,7 @@
 from typing import Any

 from litellm import acompletion as litellm_acompletion
+from litellm.types.utils import ModelResponse

 from openhands.core.exceptions import UserCancelledError
 from openhands.core.logger import openhands_logger as logger
@@ -17,7 +18,7 @@
 class AsyncLLM(LLM):
     """Asynchronous LLM class."""

-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)

         self._async_completion = partial(
@@ -45,7 +46,7 @@ def __init__(self, *args, **kwargs):
             retry_max_wait=self.config.retry_max_wait,
             retry_multiplier=self.config.retry_multiplier,
         )
-        async def async_completion_wrapper(*args, **kwargs):
+        async def async_completion_wrapper(*args: Any, **kwargs: Any) -> dict[str, Any]:
             """Wrapper for the litellm acompletion function that adds logging and cost tracking."""
             messages: list[dict[str, Any]] | dict[str, Any] = []

@@ -76,7 +77,7 @@ async def async_completion_wrapper(*args, **kwargs):

             self.log_prompt(messages)

-            async def check_stopped():
+            async def check_stopped() -> None:
                 while should_continue():
                     if (
                         hasattr(self.config, 'on_cancel_requested_fn')
@@ -96,10 +97,8 @@ async def check_stopped():
                 self.log_response(message_back)

                 # log costs and tokens used
                 self._post_completion(resp)
-
-                # We do not support streaming in this method, thus return resp
-                return resp
+                return dict(resp)

             except UserCancelledError:
                 logger.debug('LLM request cancelled by user.')
@@ -116,14 +115,15 @@ async def check_stopped():
             except asyncio.CancelledError:
                 pass

-        self._async_completion = async_completion_wrapper  # type: ignore
+        self._async_completion = partial(async_completion_wrapper)

-    async def _call_acompletion(self, *args, **kwargs):
+    async def _call_acompletion(self, *args: Any, **kwargs: Any) -> ModelResponse:
         """Wrapper for the litellm acompletion function."""
-        # Used in testing?
-        return await litellm_acompletion(*args, **kwargs)
+        resp = await litellm_acompletion(*args, **kwargs)
+        return ModelResponse(**resp)

     @property
-    def async_completion(self):
+    def async_completion(self) -> Any:
         """Decorator for the async litellm acompletion function."""
         return self._async_completion
2 changes: 1 addition & 1 deletion openhands/llm/bedrock.py
@@ -28,5 +28,5 @@ def list_foundation_models(
         return []


-def remove_error_modelId(model_list):
+def remove_error_modelId(model_list: list[str]) -> list[str]:
     return list(filter(lambda m: not m.startswith('bedrock'), model_list))
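For reference, the newly annotated helper keeps every entry that does not start with 'bedrock'; the model names below are made up:

    def remove_error_modelId(model_list: list[str]) -> list[str]:
        # Same body as in the diff: drop entries whose name starts with 'bedrock'.
        return list(filter(lambda m: not m.startswith('bedrock'), model_list))


    print(remove_error_modelId(['bedrock/anthropic.claude-v2', 'gpt-4o']))
    # -> ['gpt-4o']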
14 changes: 7 additions & 7 deletions openhands/llm/debug_mixin.py
@@ -7,7 +7,7 @@


 class DebugMixin:
-    def log_prompt(self, messages: list[dict[str, Any]] | dict[str, Any]):
+    def log_prompt(self, messages: list[dict[str, Any]] | dict[str, Any]) -> None:
         if not messages:
             logger.debug('No completion messages!')
             return
@@ -24,30 +24,30 @@ def log_prompt(self, messages: list[dict[str, Any]] | dict[str, Any]):
         else:
             logger.debug('No completion messages!')

-    def log_response(self, message_back: str):
+    def log_response(self, message_back: str) -> None:
         if message_back:
             llm_response_logger.debug(message_back)

-    def _format_message_content(self, message: dict[str, Any]):
+    def _format_message_content(self, message: dict[str, Any]) -> str:
         content = message['content']
         if isinstance(content, list):
             return '\n'.join(
                 self._format_content_element(element) for element in content
             )
         return str(content)

-    def _format_content_element(self, element: dict[str, Any]):
+    def _format_content_element(self, element: dict[str, Any]) -> str:
         if isinstance(element, dict):
             if 'text' in element:
-                return element['text']
+                return str(element['text'])
             if (
                 self.vision_is_active()
                 and 'image_url' in element
                 and 'url' in element['image_url']
             ):
-                return element['image_url']['url']
+                return str(element['image_url']['url'])
         return str(element)

     # This method should be implemented in the class that uses DebugMixin
-    def vision_is_active(self):
+    def vision_is_active(self) -> bool:
         raise NotImplementedError
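The `str(...)` wrappers are what make the new `-> str` annotations verifiable: a lookup on a `dict[str, Any]` is typed `Any`, and returning it from a function declared `-> str` trips mypy's warn-return-any check (presumably enabled in this repo's configuration). A standalone sketch of the pattern; the real method consults `vision_is_active()` rather than a flag:

    from typing import Any


    def format_content_element(element: Any, vision_active: bool = False) -> str:
        # str() turns each Any-typed lookup into the declared return type,
        # even when the payload holds a non-string value.
        if isinstance(element, dict):
            if 'text' in element:
                return str(element['text'])
            if vision_active and 'url' in element.get('image_url', {}):
                return str(element['image_url']['url'])
        return str(element)


    print(format_content_element({'text': 42}))  # -> '42'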
11 changes: 6 additions & 5 deletions openhands/llm/fn_call_converter.py
@@ -9,9 +9,10 @@
 import copy
 import json
 import re
-from typing import Iterable
+from typing import Any, Iterable, TypedDict

 from litellm import ChatCompletionToolParam
+from litellm.types.completion_response import ChatCompletionToolParamFunctionChunk

 from openhands.core.exceptions import (
     FunctionCallConversionError,
@@ -265,7 +266,7 @@ def convert_tool_call_to_string(tool_call: dict) -> str:
     return ret


-def convert_tools_to_description(tools: list[dict]) -> str:
+def convert_tools_to_description(tools: list[ChatCompletionToolParam]) -> str:
     ret = ''
     for i, tool in enumerate(tools):
         assert tool['type'] == 'function'
@@ -474,8 +475,8 @@ def convert_fncall_messages_to_non_fncall_messages(


 def _extract_and_validate_params(
-    matching_tool: dict, param_matches: Iterable[re.Match], fn_name: str
-) -> dict:
+    matching_tool: dict[str, Any], param_matches: Iterable[re.Match], fn_name: str
+) -> dict[str, Any]:
     params = {}
     # Parse and validate parameters
     required_params = set()
@@ -712,7 +713,7 @@ def convert_non_fncall_messages_to_fncall_messages(
                 # Parse parameters
                 param_matches = re.finditer(FN_PARAM_REGEX_PATTERN, fn_body, re.DOTALL)
                 params = _extract_and_validate_params(
-                    matching_tool, param_matches, fn_name
+                    dict(matching_tool), param_matches, fn_name
                 )

                 # Create tool call with unique ID
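The `dict(matching_tool)` at the call site is the subtle part: mypy does not treat a TypedDict as assignable to `dict[str, Any]`, and litellm's ChatCompletionToolParam is such a type (hence the `TypedDict` import added above), so a plain-dict copy is needed once the parameter is annotated. A minimal sketch with a simplified stand-in for the litellm type:

    from typing import Any, TypedDict


    class ToolParam(TypedDict):
        # Simplified stand-in for litellm's ChatCompletionToolParam.
        type: str
        function: dict[str, Any]


    def extract_params(matching_tool: dict[str, Any]) -> dict[str, Any]:
        return dict(matching_tool.get('function', {}).get('parameters', {}))


    tool: ToolParam = {'type': 'function', 'function': {'parameters': {'q': 'string'}}}
    # dict(tool) copies the TypedDict into a plain dict[str, Any]; passing
    # `tool` directly would fail mypy against the annotation above.
    print(extract_params(dict(tool)))  # -> {'q': 'string'}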
67 changes: 36 additions & 31 deletions openhands/llm/llm.py
@@ -13,13 +13,16 @@
 warnings.simplefilter('ignore')
 import litellm

-from litellm import ChatCompletionMessageToolCall, ModelInfo, PromptTokensDetails
+from litellm import ChatCompletionMessageToolCall, PromptTokensDetails
+from litellm.types.router import ModelInfo as RouterModelInfo
+from litellm.types.utils import ModelInfo as UtilsModelInfo
 from litellm import Message as LiteLLMMessage
 from litellm import completion as litellm_completion
 from litellm import completion_cost as litellm_completion_cost
 from litellm.exceptions import (
     RateLimitError,
 )
+from litellm.types.completion_response import Choices, StreamingChoices
 from litellm.types.utils import CostPerToken, ModelResponse, Usage
 from litellm.utils import create_pretrained_tokenizer
@@ -104,7 +107,7 @@ def __init__(
         self.cost_metric_supported: bool = True
         self.config: LLMConfig = copy.deepcopy(config)

-        self.model_info: ModelInfo | None = None
+        self.model_info: RouterModelInfo | UtilsModelInfo | None = None
         self.retry_listener = retry_listener
         if self.config.log_completions:
             if self.config.log_completions_folder is None:
@@ -170,7 +173,7 @@ def __init__(
             retry_multiplier=self.config.retry_multiplier,
             retry_listener=self.retry_listener,
         )
-        def wrapper(*args, **kwargs):
+        def wrapper(*args: Any, **kwargs: Any) -> ModelResponse:
            """Wrapper for the litellm completion function. Logs the input and output of the completion function."""
            from openhands.core.utils import json
@@ -244,18 +247,16 @@ def wrapper(*args, **kwargs):
             # if we mocked function calling, and we have tools, convert the response back to function calling format
             if mock_function_calling and mock_fncall_tools is not None:
                 assert len(resp.choices) == 1
-                non_fncall_response_message = resp.choices[0].message
-                fn_call_messages_with_response = (
-                    convert_non_fncall_messages_to_fncall_messages(
-                        messages + [non_fncall_response_message], mock_fncall_tools
-                    )
-                )
-                fn_call_response_message = fn_call_messages_with_response[-1]
-                if not isinstance(fn_call_response_message, LiteLLMMessage):
-                    fn_call_response_message = LiteLLMMessage(
-                        **fn_call_response_message
-                    )
-                resp.choices[0].message = fn_call_response_message
+                if isinstance(resp.choices[0], (Choices, StreamingChoices)):
+                    non_fncall_response_message = resp.choices[0].message
+                    fn_call_messages_with_response = (
+                        convert_non_fncall_messages_to_fncall_messages(
+                            messages + [dict(non_fncall_response_message)], mock_fncall_tools
+                        )
+                    )
+                    fn_call_response_message = fn_call_messages_with_response[-1]
+                    fn_call_response_message = dict(fn_call_response_message)
+                    resp.choices[0].message = fn_call_response_message

             message_back: str = resp['choices'][0]['message']['content'] or ''
             tool_calls: list[ChatCompletionMessageToolCall] = resp['choices'][0][
@@ -305,17 +306,17 @@ def wrapper(*args, **kwargs):

             return resp

-        self._completion = wrapper
+        self._completion = partial(wrapper)

     @property
-    def completion(self):
+    def completion(self) -> Callable[..., ModelResponse]:
         """Decorator for the litellm completion function.

         Check the complete documentation at https://litellm.vercel.app/docs/completion
         """
         return self._completion

-    def init_model_info(self):
+    def init_model_info(self) -> None:
         if self._tried_model_info:
             return
         self._tried_model_info = True
@@ -443,11 +444,11 @@ def _supports_vision(self) -> bool:
         # remove when litellm is updated to fix https://github.com/BerriAI/litellm/issues/5608
         # Check both the full model name and the name after proxy prefix for vision support
         return (
-            litellm.supports_vision(self.config.model)
-            or litellm.supports_vision(self.config.model.split('/')[-1])
+            bool(litellm.supports_vision(self.config.model))
+            or bool(litellm.supports_vision(self.config.model.split('/')[-1]))
             or (
                 self.model_info is not None
-                and self.model_info.get('supports_vision', False)
+                and bool(self.model_info.get('supports_vision', False))
             )
         )
@@ -592,7 +593,7 @@ def _is_local(self) -> bool:
                 return True
         return False

-    def _completion_cost(self, response) -> float:
+    def _completion_cost(self, response: ModelResponse) -> float:
         """Calculate completion cost and update metrics with running total.

         Calculate the cost of a completion response based on the model. Local models are treated as free.
@@ -631,35 +632,39 @@ def _completion_cost(self, response) -> float:
         try:
             if cost is None:
                 try:
-                    cost = litellm_completion_cost(
-                        completion_response=response, **extra_kwargs
-                    )
+                    cost = float(litellm_completion_cost(
+                        completion_response=response,
+                        custom_cost_per_token=extra_kwargs.get('custom_cost_per_token'),
+                    ))
                 except Exception as e:
                     logger.error(f'Error getting cost from litellm: {e}')

             if cost is None:
                 _model_name = '/'.join(self.config.model.split('/')[1:])
-                cost = litellm_completion_cost(
-                    completion_response=response, model=_model_name, **extra_kwargs
-                )
+                cost = float(litellm_completion_cost(
+                    completion_response=response,
+                    model=_model_name,
+                    custom_cost_per_token=extra_kwargs.get('custom_cost_per_token'),
+                ))
                 logger.debug(
                     f'Using fallback model name {_model_name} to get cost: {cost}'
                 )
-            self.metrics.add_cost(cost)
-            return cost
+            cost_float = float(cost)
+            self.metrics.add_cost(cost_float)
+            return cost_float
         except Exception:
             self.cost_metric_supported = False
             logger.debug('Cost calculation not supported for this model.')
             return 0.0

-    def __str__(self):
+    def __str__(self) -> str:
         if self.config.api_version:
             return f'LLM(model={self.config.model}, api_version={self.config.api_version}, base_url={self.config.base_url})'
         elif self.config.base_url:
             return f'LLM(model={self.config.model}, base_url={self.config.base_url})'
         return f'LLM(model={self.config.model})'

-    def __repr__(self):
+    def __repr__(self) -> str:
         return str(self)

     def reset(self) -> None:
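The isinstance guard added around `resp.choices[0]` is standard union narrowing: the element's static type covers several choice classes, and `.message` is only known to exist on some of them. A self-contained sketch; the dataclasses here are stand-ins for litellm's Choices and StreamingChoices:

    from dataclasses import dataclass, field
    from typing import Any, Union


    @dataclass
    class Choices:
        message: dict[str, Any] = field(default_factory=dict)


    @dataclass
    class StreamingChoices:
        message: dict[str, Any] = field(default_factory=dict)


    def first_message(choice: Union[Choices, StreamingChoices, object]) -> dict[str, Any]:
        # isinstance() narrows the union, so mypy can prove `.message`
        # exists before the attribute access, mirroring the diff's guard.
        if isinstance(choice, (Choices, StreamingChoices)):
            return choice.message
        return {}


    print(first_message(Choices(message={'content': 'hi'})))  # -> {'content': 'hi'}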
6 changes: 3 additions & 3 deletions openhands/llm/metrics.py
@@ -82,18 +82,18 @@ def get(self) -> dict:
         ],
     }

-    def reset(self):
+    def reset(self) -> None:
         self._accumulated_cost = 0.0
         self._costs = []
         self._response_latencies = []

-    def log(self):
+    def log(self) -> str:
         """Log the metrics."""
         metrics = self.get()
         logs = ''
         for key, value in metrics.items():
             logs += f'{key}: {value}\n'
         return logs

-    def __repr__(self):
+    def __repr__(self) -> str:
         return f'Metrics({self.get()})'
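The `-> str` on `log()` documents that the method builds and returns the report rather than printing it, leaving the sink to the caller. A sketch of that pattern in isolation:

    from typing import Any


    def log_metrics(metrics: dict[str, Any]) -> str:
        # Build the report and return it; the caller chooses where it goes.
        logs = ''
        for key, value in metrics.items():
            logs += f'{key}: {value}\n'
        return logs


    print(log_metrics({'accumulated_cost': 0.42, 'costs': []}), end='')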
[Diffs for the remaining two changed files are not shown.]
