diff --git a/openhands/controller/agent_controller.py b/openhands/controller/agent_controller.py index 57484669d152..8f614745b153 100644 --- a/openhands/controller/agent_controller.py +++ b/openhands/controller/agent_controller.py @@ -284,6 +284,8 @@ async def _handle_observation(self, observation: Observation) -> None: self.agent.llm.metrics.merge(observation.llm_metrics) if self._pending_action and self._pending_action.id == observation.cause: + if self.state.agent_state == AgentState.AWAITING_USER_CONFIRMATION: + return self._pending_action = None if self.state.agent_state == AgentState.USER_CONFIRMED: await self.set_agent_state_to(AgentState.RUNNING) @@ -369,6 +371,7 @@ async def set_agent_state_to(self, new_state: AgentState) -> None: else: confirmation_state = ActionConfirmationStatus.REJECTED self._pending_action.confirmation_state = confirmation_state # type: ignore[attr-defined] + self._pending_action._id = None # type: ignore[attr-defined] self.event_stream.add_event(self._pending_action, EventSource.AGENT) self.state.agent_state = new_state diff --git a/openhands/core/cli.py b/openhands/core/cli.py index 53db5ca27747..660de7af37ab 100644 --- a/openhands/core/cli.py +++ b/openhands/core/cli.py @@ -11,6 +11,7 @@ from openhands.controller import AgentController from openhands.controller.agent import Agent from openhands.core.config import ( + AppConfig, get_parser, load_app_config, ) @@ -20,6 +21,7 @@ from openhands.events import EventSource, EventStream, EventStreamSubscriber from openhands.events.action import ( Action, + ActionConfirmationStatus, ChangeAgentStateAction, CmdRunAction, FileEditAction, @@ -30,10 +32,12 @@ AgentStateChangedObservation, CmdOutputObservation, FileEditObservation, + NullObservation, ) from openhands.llm.llm import LLM from openhands.runtime import get_runtime_cls from openhands.runtime.base import Runtime +from openhands.security import SecurityAnalyzer, options from openhands.storage import get_file_store @@ -45,6 +49,15 @@ def display_command(command: str): print('❯ ' + colored(command + '\n', 'green')) +def display_confirmation(confirmation_state: ActionConfirmationStatus): + if confirmation_state == ActionConfirmationStatus.CONFIRMED: + print(colored('✅ ' + confirmation_state + '\n', 'green')) + elif confirmation_state == ActionConfirmationStatus.REJECTED: + print(colored('❌ ' + confirmation_state + '\n', 'red')) + else: + print(colored('⏳ ' + confirmation_state + '\n', 'yellow')) + + def display_command_output(output: str): lines = output.split('\n') for line in lines: @@ -59,7 +72,7 @@ def display_file_edit(event: FileEditAction | FileEditObservation): print(colored(str(event), 'green')) -def display_event(event: Event): +def display_event(event: Event, config: AppConfig): if isinstance(event, Action): if hasattr(event, 'thought'): display_message(event.thought) @@ -74,6 +87,8 @@ def display_event(event: Event): display_file_edit(event) if isinstance(event, FileEditObservation): display_file_edit(event) + if hasattr(event, 'confirmation_state') and config.security.confirmation_mode: + display_confirmation(event.confirmation_state) async def main(): @@ -119,12 +134,18 @@ async def main(): headless_mode=True, ) + if config.security.security_analyzer: + options.SecurityAnalyzers.get( + config.security.security_analyzer, SecurityAnalyzer + )(event_stream) + controller = AgentController( agent=agent, max_iterations=config.max_iterations, max_budget_per_task=config.max_budget_per_task, agent_to_llm_config=config.get_agent_to_llm_config_map(), event_stream=event_stream, + confirmation_mode=config.security.confirmation_mode, ) async def prompt_for_next_task(): @@ -143,14 +164,34 @@ async def prompt_for_next_task(): action = MessageAction(content=next_message) event_stream.add_event(action, EventSource.USER) + async def prompt_for_user_confirmation(): + loop = asyncio.get_event_loop() + user_confirmation = await loop.run_in_executor( + None, lambda: input('Confirm action (possible security risk)? (y/n) >> ') + ) + return user_confirmation.lower() == 'y' + async def on_event(event: Event): - display_event(event) + display_event(event, config) if isinstance(event, AgentStateChangedObservation): if event.agent_state in [ AgentState.AWAITING_USER_INPUT, AgentState.FINISHED, ]: await prompt_for_next_task() + if ( + isinstance(event, NullObservation) + and controller.state.agent_state == AgentState.AWAITING_USER_CONFIRMATION + ): + user_confirmed = await prompt_for_user_confirmation() + if user_confirmed: + event_stream.add_event( + ChangeAgentStateAction(AgentState.USER_CONFIRMED), EventSource.USER + ) + else: + event_stream.add_event( + ChangeAgentStateAction(AgentState.USER_REJECTED), EventSource.USER + ) event_stream.subscribe(EventStreamSubscriber.MAIN, on_event, str(uuid4())) diff --git a/openhands/core/config/security_config.py b/openhands/core/config/security_config.py index a4c49c2b0cda..60645f305736 100644 --- a/openhands/core/config/security_config.py +++ b/openhands/core/config/security_config.py @@ -32,5 +32,9 @@ def __str__(self): return f"SecurityConfig({', '.join(attr_str)})" + @classmethod + def from_dict(cls, security_config_dict: dict) -> 'SecurityConfig': + return cls(**security_config_dict) + def __repr__(self): return self.__str__() diff --git a/openhands/core/config/utils.py b/openhands/core/config/utils.py index 437754ef22ae..00f41dc1da7b 100644 --- a/openhands/core/config/utils.py +++ b/openhands/core/config/utils.py @@ -18,6 +18,7 @@ ) from openhands.core.config.llm_config import LLMConfig from openhands.core.config.sandbox_config import SandboxConfig +from openhands.core.config.security_config import SecurityConfig load_dotenv() @@ -144,6 +145,12 @@ def load_from_toml(cfg: AppConfig, toml_file: str = 'config.toml'): ) llm_config = LLMConfig.from_dict(nested_value) cfg.set_llm_config(llm_config, nested_key) + elif key is not None and key.lower() == 'security': + logger.openhands_logger.debug( + 'Attempt to load security config from config toml' + ) + security_config = SecurityConfig.from_dict(value) + cfg.security = security_config elif not key.startswith('sandbox') and key.lower() != 'core': logger.openhands_logger.warning( f'Unknown key in {toml_file}: "{key}"' diff --git a/openhands/security/invariant/analyzer.py b/openhands/security/invariant/analyzer.py index 52651876926d..f843e9304359 100644 --- a/openhands/security/invariant/analyzer.py +++ b/openhands/security/invariant/analyzer.py @@ -300,7 +300,7 @@ async def confirm(self, event: Event) -> None: ) # we should confirm only on agent actions event_source = event.source if event.source else EventSource.AGENT - await call_sync_from_async(self.event_stream.add_event, new_event, event_source) + self.event_stream.add_event(new_event, event_source) async def security_risk(self, event: Action) -> ActionSecurityRisk: logger.debug('Calling security_risk on InvariantAnalyzer')