From 36b99e893dfe60335a555386b22dcbb40c99bd34 Mon Sep 17 00:00:00 2001 From: Tianqi Xu Date: Sun, 4 Aug 2024 15:00:18 +0300 Subject: [PATCH] Revert "Tiny fix on crab-benchmark-v0 main" This reverts commit 044474530a671626e5cec0753e55b8e0f5c5e81d. --- crab-benchmark-v0/main.py | 39 ++++++--------------------------------- 1 file changed, 6 insertions(+), 33 deletions(-) diff --git a/crab-benchmark-v0/main.py b/crab-benchmark-v0/main.py index 2f8ccac..adefbc8 100644 --- a/crab-benchmark-v0/main.py +++ b/crab-benchmark-v0/main.py @@ -15,7 +15,6 @@ import warnings from pathlib import Path from typing import Literal -from uuid import uuid4 from crab import ( BenchmarkConfig, @@ -169,44 +168,18 @@ def get_benchmark(env: str, ubuntu_url: str): help="ubuntu, android or cross", default="cross", ) - parser.add_argument("--task-id", type=str, help="task id", default=None) - parser.add_argument( - "--task-description", - type=str, - help="task description. If provided, will overwrite the task id.", - default=None, - ) + parser.add_argument("--task-id", type=str, help="task id") args = parser.parse_args() benchmark = get_benchmark(args.env, args.remote_url) - if args.task_description is not None: - task_id = str(uuid4()) - benchmark.tasks = [ - Task( - id=task_id, - description=args.task_description, - evaluator=empty_evaluator, - ) - ] - else: - task_id = args.task_id - - history_messages_len = 2 - if args.model == "gpt4o": - model = OpenAIModel(model="gpt-4o", history_messages_len=history_messages_len) + model = OpenAIModel(model="gpt-4o") elif args.policy == "gpt4turbo": - model = OpenAIModel( - model="gpt-4-turbo", history_messages_len=history_messages_len - ) + model = OpenAIModel(model="gpt-4-turbo") elif args.policy == "gemini": - model = GeminiModel( - model="gemini-1.5-pro-latest", history_messages_len=history_messages_len - ) + model = GeminiModel(model="gemini-1.5-pro-latest") elif args.policy == "claude": - model = ClaudeModel( - model="claude-3-opus-20240229", history_messages_len=history_messages_len - ) + model = ClaudeModel(model="claude-3-opus-20240229") else: print("Unsupported model: ", args.model) exit() @@ -228,7 +201,7 @@ def get_benchmark(env: str, ubuntu_url: str): log_dir = (Path(__file__).parent / "logs").resolve() expeirment = CrabBenchmarkV0( benchmark=benchmark, - task_id=task_id, + task_id=args.task_id, agent_policy=agent_policy, log_dir=log_dir, )