Skip to content

Commit

Permalink
Fix mypy errors in security/invariant directory
Browse files Browse the repository at this point in the history
- Fix default_factory type in Event class
- Add missing type annotations to __rich_repr__ method
- Add missing type annotations to _Policy and _Monitor classes
  • Loading branch information
openhands-agent committed Feb 24, 2025
1 parent b147f6a commit 9db2531
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
8 changes: 4 additions & 4 deletions openhands/security/invariant/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def close_session(self) -> Union[None, Exception]:
return None

class _Policy:
def __init__(self, invariant):
def __init__(self, invariant: 'InvariantClient') -> None:
self.server = invariant.server
self.session_id = invariant.session_id

Expand All @@ -77,7 +77,7 @@ def get_template(self) -> tuple[str | None, Exception | None]:
except (ConnectionError, Timeout, HTTPError) as err:
return None, err

def from_string(self, rule: str):
def from_string(self, rule: str) -> 'InvariantClient._Policy':
policy_id, err = self._create_policy(rule)
if err:
raise err
Expand All @@ -97,7 +97,7 @@ def analyze(self, trace: list[dict]) -> Union[Any, Exception]:
return None, err

class _Monitor:
def __init__(self, invariant):
def __init__(self, invariant: 'InvariantClient') -> None:
self.server = invariant.server
self.session_id = invariant.session_id
self.policy = ''
Expand All @@ -114,7 +114,7 @@ def _create_monitor(self, rule: str) -> tuple[str | None, Exception | None]:
except (ConnectionError, Timeout, HTTPError) as err:
return None, err

def from_string(self, rule: str):
def from_string(self, rule: str) -> 'InvariantClient._Monitor':
monitor_id, err = self._create_monitor(rule)
if err:
raise err
Expand Down
5 changes: 3 additions & 2 deletions openhands/security/invariant/nodes.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Any, Iterable, Tuple
from pydantic import BaseModel, Field
from pydantic.dataclasses import dataclass

Expand All @@ -10,7 +11,7 @@ class LLM:

class Event(BaseModel):
metadata: dict | None = Field(
default_factory=dict, description='Metadata associated with the event'
default_factory=lambda: dict(), description='Metadata associated with the event'
)


Expand All @@ -30,7 +31,7 @@ class Message(Event):
content: str | None
tool_calls: list[ToolCall] | None = None

def __rich_repr__(self):
def __rich_repr__(self) -> Iterable[Any | tuple[Any] | tuple[str, Any] | tuple[str, Any, Any]]:
# Print on separate line
yield 'role', self.role
yield 'content', self.content
Expand Down

0 comments on commit 9db2531

Please sign in to comment.