Skip to content

Commit

Permalink
Tiny fix on crab-benchmark-v0 main
Browse files Browse the repository at this point in the history
  • Loading branch information
dandansamax committed Aug 4, 2024
1 parent 36b99e8 commit 3eb5bc7
Showing 1 changed file with 40 additions and 6 deletions.
46 changes: 40 additions & 6 deletions crab-benchmark-v0/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,16 @@
import warnings
from pathlib import Path
from typing import Literal
from uuid import uuid4

from crab import (
BenchmarkConfig,
Experiment,
MessageType,
Task,
TaskGenerator,
create_benchmark,
evaluator,
)
from crab.actions.crab_actions import complete
from crab.agents.backend_models import ClaudeModel, GeminiModel, OpenAIModel
Expand Down Expand Up @@ -140,6 +143,11 @@ def get_benchmark(env: str, ubuntu_url: str):
return create_benchmark(benchmark_config)


@evaluator(env_name="ubuntu")
def empty_evaluator() -> bool:
    """Placeholder evaluator that never reports success.

    Used as the ``evaluator`` of the ad-hoc ``Task`` built from
    ``--task-description``, which has no real completion check.

    Returns:
        Always ``False``.
    """
    return False


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Script for running benchmark with an agent."
Expand Down Expand Up @@ -168,18 +176,44 @@ def get_benchmark(env: str, ubuntu_url: str):
help="ubuntu, android or cross",
default="cross",
)
parser.add_argument("--task-id", type=str, help="task id")
parser.add_argument("--task-id", type=str, help="task id", default=None)
parser.add_argument(
"--task-description",
type=str,
help="task description. If provided, will overwrite the task id.",
default=None,
)
args = parser.parse_args()
benchmark = get_benchmark(args.env, args.remote_url)

if args.task_description is not None:
task_id = str(uuid4())
benchmark.tasks = [
Task(
id=task_id,
description=args.task_description,
evaluator=empty_evaluator,
)
]
else:
task_id = args.task_id

history_messages_len = 2

# Pick the backend model named by --model.
#
# Every branch dispatches on args.model, the same attribute the fallback
# message reports. The elif branches previously compared args.policy against
# model names, so only "gpt4o" could ever be selected through --model.
# NOTE(review): args.policy presumably selects the agent policy used further
# down (agent_policy) -- confirm against the argument definitions.
#
# Each model is constructed exactly once, with the shared history window
# (history_messages_len) passed to every backend.
if args.model == "gpt4o":
    model = OpenAIModel(model="gpt-4o", history_messages_len=history_messages_len)
elif args.model == "gpt4turbo":
    model = OpenAIModel(
        model="gpt-4-turbo", history_messages_len=history_messages_len
    )
elif args.model == "gemini":
    model = GeminiModel(
        model="gemini-1.5-pro-latest", history_messages_len=history_messages_len
    )
elif args.model == "claude":
    model = ClaudeModel(
        model="claude-3-opus-20240229", history_messages_len=history_messages_len
    )
else:
    # Unknown model name: report it and stop before any experiment is built.
    print("Unsupported model: ", args.model)
    exit()
Expand All @@ -201,7 +235,7 @@ def get_benchmark(env: str, ubuntu_url: str):
log_dir = (Path(__file__).parent / "logs").resolve()
expeirment = CrabBenchmarkV0(
benchmark=benchmark,
task_id=args.task_id,
task_id=task_id,
agent_policy=agent_policy,
log_dir=log_dir,
)
Expand Down

0 comments on commit 3eb5bc7

Please sign in to comment.