diff --git a/openhands/runtime/utils/bash.py b/openhands/runtime/utils/bash.py index b9c25978e50e..28a8dd00d814 100644 --- a/openhands/runtime/utils/bash.py +++ b/openhands/runtime/utils/bash.py @@ -157,12 +157,14 @@ def _get_command_output( raw_command_output: str, continue_prefix: str = '', suffix: str = '', + keep_prompt: bool = False, ) -> str: """Get the command output with the previous command output removed. Args: continue_prefix: The prefix to add to the command output if it's a continuation of the previous command. suffix: The suffix to add to the command output. + keep_prompt: If True, keep the prompt in the output. """ # remove the previous command output from the new output if any custom_prefix = '' @@ -172,11 +174,12 @@ def _get_command_output( else: command_output = raw_command_output self.prev_output = raw_command_output # update current command output anyway - command_output = _remove_command_prefix(command_output, command) + if not keep_prompt: + command_output = _remove_command_prefix(command_output, command) command_output = f'{custom_prefix}{command_output}{suffix}' return command_output - def _handle_completed_command(self, command: str) -> CmdOutputObservation: + def _handle_completed_command(self, command: str, keep_prompt: bool = False) -> CmdOutputObservation: full_output = self._get_pane_content(full=True) ps1_matches = CmdOutputMetadata.matches_ps1_metadata(full_output) @@ -195,6 +198,7 @@ def _handle_completed_command(self, command: str) -> CmdOutputObservation: if not is_special_key else f'\n\n[The command completed with exit code {metadata.exit_code}. CTRL+{command[-1].upper()} was sent.]' ), + keep_prompt=keep_prompt, ) self.prev_status = BashCommandStatus.COMPLETED self.prev_output = '' # Reset previous command output @@ -205,7 +209,7 @@ def _handle_completed_command(self, command: str) -> CmdOutputObservation: metadata=metadata, ) - def _handle_nochange_timeout_command(self, command: str) -> CmdOutputObservation: + def _handle_nochange_timeout_command(self, command: str, keep_prompt: bool = False) -> CmdOutputObservation: self.prev_status = BashCommandStatus.NO_CHANGE_TIMEOUT full_output = self._get_pane_content(full=True) @@ -223,6 +227,7 @@ def _handle_nochange_timeout_command(self, command: str) -> CmdOutputObservation 'send other commands to interact with the current process, ' 'or send keys to interrupt/kill the command.]' ), + keep_prompt=keep_prompt, ) return CmdOutputObservation( content=command_output, @@ -231,7 +236,7 @@ def _handle_nochange_timeout_command(self, command: str) -> CmdOutputObservation ) def _handle_hard_timeout_command( - self, command: str, timeout: float + self, command: str, timeout: float, keep_prompt: bool = False ) -> CmdOutputObservation: self.prev_status = BashCommandStatus.HARD_TIMEOUT full_output = self._get_pane_content(full=True) @@ -249,6 +254,7 @@ def _handle_hard_timeout_command( 'send other commands to interact with the current process, ' 'or send keys to interrupt/kill the command.]' ), + keep_prompt=keep_prompt, ) return CmdOutputObservation( @@ -308,7 +314,7 @@ def execute(self, action: CmdRunAction) -> CmdOutputObservation | ErrorObservati # 1) Execution completed # if the last command output contains the end marker if cur_pane_output.rstrip().endswith(CMD_OUTPUT_PS1_END.rstrip()): - return self._handle_completed_command(action.command) + return self._handle_completed_command(action.command, action.keep_prompt) # 2) Execution timed out since there's no change in output # for a while (self.NO_CHANGE_TIMEOUT_SECONDS) @@ -318,10 +324,10 @@ def execute(self, action: CmdRunAction) -> CmdOutputObservation | ErrorObservati not action.blocking and time_since_last_change >= self.NO_CHANGE_TIMEOUT_SECONDS ): - return self._handle_nochange_timeout_command(action.command) + return self._handle_nochange_timeout_command(action.command, action.keep_prompt) # 3) Execution timed out due to hard timeout if action.timeout and time.time() - start_time >= action.timeout: - return self._handle_hard_timeout_command(action.command, action.timeout) + return self._handle_hard_timeout_command(action.command, action.timeout, action.keep_prompt) time.sleep(self.POLL_INTERVAL) diff --git a/tests/unit/test_bash_session.py b/tests/unit/test_bash_session.py index 42058c2c1ad8..10344d441a61 100644 --- a/tests/unit/test_bash_session.py +++ b/tests/unit/test_bash_session.py @@ -259,6 +259,42 @@ def test_ansi_escape_codes(): session.close() +def test_keep_prompt(): + session = BashSession(work_dir=os.getcwd()) + + # Test command with keep_prompt=False (default) + obs = session.execute(CmdRunAction('echo "test"')) + logger.info(obs, extra={'msg_type': 'OBSERVATION'}) + assert 'test' in obs.content + assert 'echo "test"' not in obs.content + assert obs.metadata.exit_code == 0 + + # Test command with keep_prompt=True + obs = session.execute(CmdRunAction('echo "test"', keep_prompt=True)) + logger.info(obs, extra={'msg_type': 'OBSERVATION'}) + assert 'test' in obs.content + assert 'echo "test"' in obs.content + assert obs.metadata.exit_code == 0 + + # Test long-running command with keep_prompt=True + obs = session.execute( + CmdRunAction('for i in {1..2}; do echo $i; sleep 3; done', keep_prompt=True, blocking=False) + ) + logger.info(obs, extra={'msg_type': 'OBSERVATION'}) + assert 'for i in {1..2}; do echo $i; sleep 3; done' in obs.content + assert '1' in obs.content + assert obs.metadata.exit_code == -1 + + # Continue watching output with keep_prompt=True + obs = session.execute(CmdRunAction('', keep_prompt=True)) + logger.info(obs, extra={'msg_type': 'OBSERVATION'}) + assert '[Command output continued from previous command]' in obs.content + assert '2' in obs.content + assert obs.metadata.exit_code == -1 + + session.close() + + def test_long_output(): session = BashSession(work_dir=os.getcwd())