Skip to content

Commit

Permalink
Merge branch 'update-benchmark-main' into macos
Browse files Browse the repository at this point in the history
  • Loading branch information
dandansamax committed Sep 7, 2024
2 parents 766dce6 + 3eb5bc7 commit 04d9193
Showing 1 changed file with 39 additions and 6 deletions.
45 changes: 39 additions & 6 deletions crab-benchmark-v0/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import warnings
from pathlib import Path
from typing import Literal
from uuid import uuid4

from crab import (
BenchmarkConfig,
Expand All @@ -23,6 +24,7 @@
Task,
TaskGenerator,
create_benchmark,
evaluator,
)
from crab.actions.crab_actions import complete
from crab.agents.backend_models import ClaudeModel, GeminiModel, OpenAIModel
Expand Down Expand Up @@ -163,6 +165,11 @@ def get_benchmark(env: str, ubuntu_url: str):
return create_benchmark(benchmark_config)


@evaluator(env_name="ubuntu")
def empty_evaluator() -> bool:
return False


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Script for running benchmark with an agent."
Expand Down Expand Up @@ -191,18 +198,44 @@ def get_benchmark(env: str, ubuntu_url: str):
help="ubuntu, android or cross",
default="cross",
)
parser.add_argument("--task-id", type=str, help="task id")
parser.add_argument("--task-id", type=str, help="task id", default=None)
parser.add_argument(
"--task-description",
type=str,
help="task description. If provided, will overwrite the task id.",
default=None,
)
args = parser.parse_args()
benchmark = get_benchmark(args.env, args.remote_url)

if args.task_description is not None:
task_id = str(uuid4())
benchmark.tasks = [
Task(
id=task_id,
description=args.task_description,
evaluator=empty_evaluator,
)
]
else:
task_id = args.task_id

history_messages_len = 2

if args.model == "gpt4o":
model = OpenAIModel(model="gpt-4o")
model = OpenAIModel(model="gpt-4o", history_messages_len=history_messages_len)
elif args.policy == "gpt4turbo":
model = OpenAIModel(model="gpt-4-turbo")
model = OpenAIModel(
model="gpt-4-turbo", history_messages_len=history_messages_len
)
elif args.policy == "gemini":
model = GeminiModel(model="gemini-1.5-pro-latest")
model = GeminiModel(
model="gemini-1.5-pro-latest", history_messages_len=history_messages_len
)
elif args.policy == "claude":
model = ClaudeModel(model="claude-3-opus-20240229")
model = ClaudeModel(
model="claude-3-opus-20240229", history_messages_len=history_messages_len
)
else:
print("Unsupported model: ", args.model)
exit()
Expand All @@ -224,7 +257,7 @@ def get_benchmark(env: str, ubuntu_url: str):
log_dir = (Path(__file__).parent / "logs").resolve()
expeirment = CrabBenchmarkV0(
benchmark=benchmark,
task_id=args.task_id,
task_id=task_id,
agent_policy=agent_policy,
log_dir=log_dir,
)
Expand Down

0 comments on commit 04d9193

Please sign in to comment.