Skip to content

Commit

Permalink
fix(eval): support setting hard timeout per evaluation instance
Browse files (browse the repository at this point in the history)
  • Loading branch information
xingyaoww committed Nov 18, 2024
1 parent a87b859 commit 42afc43
Showing 1 changed file with 55 additions and 6 deletions.
61 changes: 55 additions & 6 deletions evaluation/utils/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
import multiprocessing as mp
import os
import pathlib
import signal
import subprocess
import time
import traceback
from contextlib import contextmanager
from typing import Any, Awaitable, Callable, TextIO

import pandas as pd
Expand Down Expand Up @@ -92,6 +94,27 @@ class EvalException(Exception):
pass


class EvalTimeoutException(Exception):
    """Signal that an evaluation instance exceeded its hard time limit.

    Raised from the SIGALRM handler installed by ``timeout``. It carries no
    state beyond the message passed to ``Exception``.
    """


@contextmanager
def timeout(seconds: int):
    """Limit the wall-clock time of the enclosed ``with`` body using SIGALRM.

    A SIGALRM handler that raises ``EvalTimeoutException`` is installed and an
    alarm is armed for ``seconds``. On exit (normally or via an exception) the
    alarm is cancelled and the previously installed handler is put back. If
    another alarm was already pending on entry (for example an enclosing
    ``timeout`` block), it is re-armed with its remaining time instead of
    being silently cancelled.

    Args:
        seconds: Whole seconds before the body is interrupted. Per
            ``signal.alarm``, ``0`` schedules no alarm, so the body then runs
            without a limit.

    Raises:
        EvalTimeoutException: Raised inside the body, from the signal
            handler, when the alarm fires.
        ValueError: Propagated from ``signal.signal`` when this is not called
            from the main thread.

    Note:
        ``signal.SIGALRM`` and ``signal.alarm`` are only available on Unix.
        ``EvalTimeoutException`` derives from ``Exception``, so a broad
        ``except Exception`` inside the body can swallow the timeout.
    """

    def timeout_handler(signum, frame):
        raise EvalTimeoutException(f'Function timed out after {seconds} seconds')

    # Install the handler before arming the alarm so the signal can never be
    # delivered to a stale handler.
    original_handler = signal.signal(signal.SIGALRM, timeout_handler)
    # signal.alarm() returns the seconds left on a previously scheduled alarm
    # (0 if there was none); keep it so that alarm can be restored on exit.
    previous_remaining = signal.alarm(seconds)
    started_at = time.monotonic()

    try:
        yield
    finally:
        # Disable our alarm first, then restore the original handler.
        signal.alarm(0)
        signal.signal(signal.SIGALRM, original_handler)
        if previous_remaining:
            # Re-arm the alarm that was pending before this block, minus the
            # time spent here. If its deadline already passed, fire it as soon
            # as alarm() allows (1 second), since alarm(0) would cancel it.
            elapsed = int(time.monotonic() - started_at)
            signal.alarm(max(1, previous_remaining - elapsed))


def codeact_user_response(
state: State,
encapsulate_solution: bool = False,
Expand Down Expand Up @@ -280,15 +303,33 @@ def _process_instance_wrapper(
metadata: EvalMetadata,
use_mp: bool,
max_retries: int = 5,
timeout_seconds: int | None = None,
) -> EvalOutput:
"""Wrap the process_instance_func to handle retries and errors.
Retry an instance up to max_retries times if it fails (e.g., due to transient network/runtime issues).
"""
"""Wrap the process_instance_func to handle retries and errors."""
for attempt in range(max_retries + 1):
try:
result = process_instance_func(instance, metadata, use_mp)
if timeout_seconds is not None:
with timeout(timeout_seconds):
result = process_instance_func(instance, metadata, use_mp)
else:
result = process_instance_func(instance, metadata, use_mp)
return result
except EvalTimeoutException as e:
error = f'Timeout after {timeout_seconds} seconds'
stacktrace = traceback.format_exc()
msg = (
'-' * 10
+ '\n'
+ f'Timeout ({timeout_seconds} seconds) in instance [{instance.instance_id}], Stopped evaluation for this instance.'
+ '\n'
+ '-' * 10
)
logger.exception(e)
return EvalOutput(
instance_id=instance.instance_id,
test_result={},
error=error,
)
except Exception as e:
error = str(e)
stacktrace = traceback.format_exc()
Expand Down Expand Up @@ -337,6 +378,7 @@ def run_evaluation(
[pd.Series, EvalMetadata, bool], Awaitable[EvalOutput]
],
max_retries: int = 5, # number of retries for each instance
timeout_seconds: int | None = None,
):
use_multiprocessing = num_workers > 1

Expand All @@ -357,7 +399,14 @@ def run_evaluation(
if use_multiprocessing:
with mp.Pool(num_workers) as pool:
args_iter = (
(process_instance_func, instance, metadata, True, max_retries)
(
process_instance_func,
instance,
metadata,
True,
max_retries,
timeout_seconds,
)
for _, instance in dataset.iterrows()
)
results = pool.imap_unordered(_process_instance_wrapper_mp, args_iter)
Expand Down

0 comments on commit 42afc43

Please sign in to comment.