Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix mypy errors in security/invariant directory #6908

Merged
merged 13 commits into from
Feb 24, 2025
Merged
12 changes: 9 additions & 3 deletions openhands/security/invariant/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,11 +307,17 @@ async def security_risk(self, event: Action) -> ActionSecurityRisk:
new_elements = parse_element(self.trace, event)
input = [e.model_dump(exclude_none=True) for e in new_elements] # type: ignore [call-overload]
self.trace.extend(new_elements)
result, err = self.monitor.check(self.input, input)
check_result = self.monitor.check(self.input, input)
self.input.extend(input)
risk = ActionSecurityRisk.UNKNOWN
if err:
logger.warning(f'Error checking policy: {err}')

if isinstance(check_result, tuple):
result, err = check_result
if err:
logger.warning(f'Error checking policy: {err}')
return risk
else:
logger.warning(f'Error checking policy: {check_result}')
return risk

risk = self.get_risk(result)
Expand Down
8 changes: 4 additions & 4 deletions openhands/security/invariant/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def close_session(self) -> Union[None, Exception]:
return None

class _Policy:
def __init__(self, invariant):
def __init__(self, invariant: 'InvariantClient') -> None:
self.server = invariant.server
self.session_id = invariant.session_id

Expand All @@ -77,7 +77,7 @@ def get_template(self) -> tuple[str | None, Exception | None]:
except (ConnectionError, Timeout, HTTPError) as err:
return None, err

def from_string(self, rule: str):
def from_string(self, rule: str) -> 'InvariantClient._Policy':
policy_id, err = self._create_policy(rule)
if err:
raise err
Expand All @@ -97,7 +97,7 @@ def analyze(self, trace: list[dict]) -> Union[Any, Exception]:
return None, err

class _Monitor:
def __init__(self, invariant):
def __init__(self, invariant: 'InvariantClient') -> None:
self.server = invariant.server
self.session_id = invariant.session_id
self.policy = ''
Expand All @@ -114,7 +114,7 @@ def _create_monitor(self, rule: str) -> tuple[str | None, Exception | None]:
except (ConnectionError, Timeout, HTTPError) as err:
return None, err

def from_string(self, rule: str):
def from_string(self, rule: str) -> 'InvariantClient._Monitor':
monitor_id, err = self._create_monitor(rule)
if err:
raise err
Expand Down
5 changes: 3 additions & 2 deletions openhands/security/invariant/nodes.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Any, Iterable, Tuple
from pydantic import BaseModel, Field
from pydantic.dataclasses import dataclass

Expand All @@ -10,7 +11,7 @@ class LLM:

class Event(BaseModel):
metadata: dict | None = Field(
default_factory=dict, description='Metadata associated with the event'
default_factory=lambda: dict(), description='Metadata associated with the event'
)


Expand All @@ -30,7 +31,7 @@ class Message(Event):
content: str | None
tool_calls: list[ToolCall] | None = None

def __rich_repr__(self):
def __rich_repr__(self) -> Iterable[Any | tuple[Any] | tuple[str, Any] | tuple[str, Any, Any]]:
# Print on separate line
yield 'role', self.role
yield 'content', self.content
Expand Down
Loading