Merge commit '6d0d061f92f2410f872d62cea1e47d8f365537ec'
xingyaoww committed Dec 23, 2024
2 parents f3b3566 + 6d0d061 commit 5ae2c4d
Showing 22 changed files with 620 additions and 129 deletions.
48 changes: 29 additions & 19 deletions evaluation/swe_bench/eval_infer.py
@@ -263,23 +263,29 @@ def process_instance(
test_output_path = os.path.join(log_dir, 'test_output.txt')
with open(test_output_path, 'w') as f:
f.write(test_output)

_report = get_eval_report(
test_spec=test_spec,
prediction={
'model_patch': model_patch,
'instance_id': instance_id,
},
log_path=test_output_path,
include_tests_status=True,
)
report = _report[instance_id]
logger.info(
f"[{instance_id}] report: {report}\nResult for {instance_id}: resolved: {report['resolved']}"
)
instance['test_result']['report']['resolved'] = report[
'resolved'
]
try:
_report = get_eval_report(
test_spec=test_spec,
prediction={
'model_patch': model_patch,
'instance_id': instance_id,
},
log_path=test_output_path,
include_tests_status=True,
)
report = _report[instance_id]
logger.info(
f"[{instance_id}] report: {report}\nResult for {instance_id}: resolved: {report['resolved']}"
)
instance['test_result']['report']['resolved'] = report[
'resolved'
]
except Exception as e:
logger.error(
f'[{instance_id}] Error when getting eval report: {e}'
)
instance['test_result']['report']['resolved'] = False
instance['test_result']['report']['error_eval'] = True
else:
logger.info(f'[{instance_id}] Error when starting eval:\n{obs.content}')
instance['test_result']['report']['error_eval'] = True
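
The hunk above wraps report generation in a try/except: if get_eval_report raises for any reason, the instance is recorded with resolved=False and error_eval=True instead of aborting the whole evaluation run. A minimal sketch of the same pattern, using a hypothetical fetch_report callable in place of the swebench helper:

import logging
from typing import Any, Callable

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def attach_report(instance: dict[str, Any], fetch_report: Callable[[str], dict]) -> None:
    # fetch_report is a hypothetical stand-in for get_eval_report(...)
    report_slot = instance.setdefault('test_result', {}).setdefault('report', {})
    instance_id = instance['instance_id']
    try:
        report = fetch_report(instance_id)
        report_slot['resolved'] = report['resolved']
    except Exception as e:
        # Mirror the diff: record the failure, keep processing other instances.
        logger.error(f'[{instance_id}] Error when getting eval report: {e}')
        report_slot['resolved'] = False
        report_slot['error_eval'] = True
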
@@ -355,7 +361,7 @@ def process_instance(

if 'model_patch' not in predictions.columns:
predictions['model_patch'] = predictions['test_result'].apply(
lambda x: x['git_patch']
lambda x: x.get('git_patch', '')
)
assert {'instance_id', 'model_patch'}.issubset(
set(predictions.columns)
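
The change above swaps x['git_patch'] for x.get('git_patch', ''), so predictions whose test_result never recorded a patch fall back to an empty string instead of raising KeyError. A small reproduction of the difference:

import pandas as pd

predictions = pd.DataFrame(
    {
        'instance_id': ['a', 'b'],
        'test_result': [{'git_patch': 'diff --git a/x b/x'}, {}],  # 'b' has no patch
    }
)

# Old behaviour: raises KeyError on row 'b'.
# predictions['test_result'].apply(lambda x: x['git_patch'])

# New behaviour: missing patches become empty strings.
predictions['model_patch'] = predictions['test_result'].apply(
    lambda x: x.get('git_patch', '')
)
print(predictions[['instance_id', 'model_patch']])
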
@@ -401,7 +407,11 @@ def process_instance(
fields = ['resolved', 'failed_apply_patch', 'error_eval', 'empty_generation']

def count_report_field(row, field):
return row['test_result']['report'][field]
return (
row['test_result']['report'][field]
if 'report' in row['test_result']
else False
)

report = {}
for field in fields:
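
count_report_field now returns False for rows whose test_result carries no report at all (for example when evaluation failed before a report could be attached), so the summary loop can still sum each field. A self-contained sketch of that aggregation:

import pandas as pd

fields = ['resolved', 'error_eval']
predictions = pd.DataFrame(
    {
        'test_result': [
            {'report': {'resolved': True, 'error_eval': False}},
            {},  # evaluation never produced a report for this row
        ]
    }
)


def count_report_field(row, field):
    return (
        row['test_result']['report'][field]
        if 'report' in row['test_result']
        else False
    )


for field in fields:
    total = predictions.apply(count_report_field, args=(field,), axis=1).sum()
    print(f'{field}: {int(total)}')
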
125 changes: 80 additions & 45 deletions evaluation/swe_bench/run_infer.py
@@ -32,7 +32,7 @@
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
from openhands.events.action import CmdRunAction, MessageAction
from openhands.events.action import CmdRunAction, IPythonRunCellAction, MessageAction
from openhands.events.observation import CmdOutputObservation, ErrorObservation
from openhands.events.serialization.event import event_to_dict
from openhands.runtime.base import Runtime
@@ -145,7 +145,7 @@ def get_config(
platform='linux/amd64',
api_key=os.environ.get('ALLHANDS_API_KEY', None),
remote_runtime_api_url=os.environ.get('SANDBOX_REMOTE_RUNTIME_API_URL'),
keep_remote_runtime_alive=False,
keep_runtime_alive=False,
remote_runtime_init_timeout=3600,
),
# do not mount workspace
@@ -303,6 +303,7 @@ def initialize_runtime(
def complete_runtime(
runtime: Runtime,
instance: pd.Series, # this argument is not required, but it is used to get the workspace_dir_name
n_retries: int = 5,
) -> dict[str, Any]:
"""Complete the runtime for the agent.
@@ -321,55 +322,84 @@ def complete_runtime(
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert_and_raise(
isinstance(obs, CmdOutputObservation) and obs.exit_code == 0,
f'Failed to cd to /workspace/{workspace_dir_name}: {str(obs)}',
)

action = CmdRunAction(command='git config --global core.pager ""')
action.timeout = 600
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert_and_raise(
isinstance(obs, CmdOutputObservation) and obs.exit_code == 0,
f'Failed to git config --global core.pager "": {str(obs)}',
)
if isinstance(obs, CmdOutputObservation) and obs.exit_code == 0:
action = CmdRunAction(command='git config --global core.pager ""')
action.timeout = 600
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert_and_raise(
isinstance(obs, CmdOutputObservation) and obs.exit_code == 0,
f'Failed to git config --global core.pager "": {str(obs)}',
)

action = CmdRunAction(command='git add -A')
action.timeout = 600
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert_and_raise(
isinstance(obs, CmdOutputObservation) and obs.exit_code == 0,
f'Failed to git add -A: {str(obs)}',
)
action = CmdRunAction(command='git add -A')
action.timeout = 600
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert_and_raise(
isinstance(obs, CmdOutputObservation) and obs.exit_code == 0,
f'Failed to git add -A: {str(obs)}',
)

n_retries = 0
git_patch = None
while n_retries < 5:
action = CmdRunAction(
command=f'git diff --no-color --cached {instance["base_commit"]}',
keep_prompt=False,
n_retries = 0
git_patch = None
while n_retries < 5:
action = CmdRunAction(
command=f'git diff --no-color --cached {instance["base_commit"]}',
keep_prompt=False,
)
action.timeout = 600 + 100 * n_retries
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
n_retries += 1
if isinstance(obs, CmdOutputObservation):
if obs.exit_code == 0:
git_patch = obs.content.strip()
break
else:
logger.info('Failed to get git diff, retrying...')
sleep_if_should_continue(10)
elif isinstance(obs, ErrorObservation):
logger.error(f'Error occurred: {obs.content}. Retrying...')
sleep_if_should_continue(10)
else:
assert_and_raise(False, f'Unexpected observation type: {str(obs)}')
else:
logger.warning(
f'Failed to cd to /workspace/{workspace_dir_name}... Trying to use IPython to get git diff'
)
action.timeout = 600 + 100 * n_retries
# Git configuration and diff using IPython
cell_code = f"""
import subprocess
def run_git_cmd(cmd):
    result = subprocess.run(cmd, shell=True, capture_output=True, text=True, cwd='/workspace/{workspace_dir_name}')
    return result.stdout, result.returncode
# Configure git
run_git_cmd('git config --global core.pager ""')
run_git_cmd('git add -A')
# Get the diff
stdout, exit_code = run_git_cmd('git diff --no-color --cached {instance["base_commit"]}')
git_patch = stdout.strip()
"""
action = IPythonRunCellAction(code=cell_code)
action.timeout = 600
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
n_retries += 1
if isinstance(obs, CmdOutputObservation):
if obs.exit_code == 0:
git_patch = obs.content.strip()
break
else:
logger.info('Failed to get git diff, retrying...')
sleep_if_should_continue(10)
elif isinstance(obs, ErrorObservation):
logger.error(f'Error occurred: {obs.content}. Retrying...')
sleep_if_should_continue(10)
else:
assert_and_raise(False, f'Unexpected observation type: {str(obs)}')

# Get the git_patch from IPython's namespace
cell_code = 'print(git_patch)'
action = IPythonRunCellAction(code=cell_code)
action.timeout = 600
obs = runtime.run_action(action)
git_patch = obs.content.strip()

assert_and_raise(git_patch is not None, 'Failed to get git diff (None)')

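The restructured block above tries the shell path first (cd into the workspace, git config, git add, then the retried git diff); if the cd fails, it falls back to an IPythonRunCellAction whose cell shells out via subprocess with an explicit cwd, and a second cell prints git_patch back out of the interpreter's namespace. Rendered as a standalone script, with indentation restored and placeholder values for the workspace path and base commit, the fallback cell amounts to:

import subprocess

WORKSPACE = '/workspace/my_repo'  # placeholder for /workspace/{workspace_dir_name}
BASE_COMMIT = 'HEAD'              # placeholder for instance["base_commit"]


def run_git_cmd(cmd: str) -> tuple[str, int]:
    result = subprocess.run(
        cmd, shell=True, capture_output=True, text=True, cwd=WORKSPACE
    )
    return result.stdout, result.returncode


# Configure git and stage all changes
run_git_cmd('git config --global core.pager ""')
run_git_cmd('git add -A')

# Diff the staged changes against the base commit
stdout, exit_code = run_git_cmd(f'git diff --no-color --cached {BASE_COMMIT}')
git_patch = stdout.strip()
print(git_patch)  # in run_infer.py a second cell, 'print(git_patch)', reads this back from the kernel
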
@@ -534,5 +564,10 @@ def filter_dataset(dataset: pd.DataFrame, filter_column: str) -> pd.DataFrame:
instances[col] = instances[col].apply(lambda x: str(x))

run_evaluation(
instances, metadata, output_file, args.eval_num_workers, process_instance
instances,
metadata,
output_file,
args.eval_num_workers,
process_instance,
timeout_seconds=120 * 60, # 2 hour PER instance should be more than enough
)
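
run_evaluation is now invoked with timeout_seconds=120 * 60, bounding how long the harness waits on any single instance. As an illustration of the general idea only (not the run_evaluation implementation), a driver can cap the wait on each worker's result like this:

import concurrent.futures
import time


def process_instance(instance_id: str) -> str:
    time.sleep(1)  # stand-in for the real per-instance evaluation work
    return f'{instance_id}: done'


def run_all(instance_ids, timeout_seconds: float = 120 * 60):
    results = {}
    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
        futures = {pool.submit(process_instance, i): i for i in instance_ids}
        for future, instance_id in futures.items():
            try:
                results[instance_id] = future.result(timeout=timeout_seconds)
            except concurrent.futures.TimeoutError:
                results[instance_id] = 'timed out'
    return results


print(run_all(['instance_a', 'instance_b']))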