Skip to content

Commit

Permalink
Revert "Tiny fix on crab-benchmark-v0 main"
Browse files Browse the repository at this point in the history
This reverts commit 0444745.
  • Loading branch information
dandansamax committed Aug 4, 2024
1 parent 071e168 commit 36b99e8
Showing 1 changed file with 6 additions and 33 deletions.
39 changes: 6 additions & 33 deletions crab-benchmark-v0/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import warnings
from pathlib import Path
from typing import Literal
from uuid import uuid4

from crab import (
BenchmarkConfig,
Expand Down Expand Up @@ -169,44 +168,18 @@ def get_benchmark(env: str, ubuntu_url: str):
help="ubuntu, android or cross",
default="cross",
)
parser.add_argument("--task-id", type=str, help="task id", default=None)
parser.add_argument(
"--task-description",
type=str,
help="task description. If provided, will overwrite the task id.",
default=None,
)
parser.add_argument("--task-id", type=str, help="task id")
args = parser.parse_args()
benchmark = get_benchmark(args.env, args.remote_url)

if args.task_description is not None:
task_id = str(uuid4())
benchmark.tasks = [
Task(
id=task_id,
description=args.task_description,
evaluator=empty_evaluator,
)
]
else:
task_id = args.task_id

history_messages_len = 2

if args.model == "gpt4o":
model = OpenAIModel(model="gpt-4o", history_messages_len=history_messages_len)
model = OpenAIModel(model="gpt-4o")
elif args.policy == "gpt4turbo":
model = OpenAIModel(
model="gpt-4-turbo", history_messages_len=history_messages_len
)
model = OpenAIModel(model="gpt-4-turbo")
elif args.policy == "gemini":
model = GeminiModel(
model="gemini-1.5-pro-latest", history_messages_len=history_messages_len
)
model = GeminiModel(model="gemini-1.5-pro-latest")
elif args.policy == "claude":
model = ClaudeModel(
model="claude-3-opus-20240229", history_messages_len=history_messages_len
)
model = ClaudeModel(model="claude-3-opus-20240229")
else:
print("Unsupported model: ", args.model)
exit()
Expand All @@ -228,7 +201,7 @@ def get_benchmark(env: str, ubuntu_url: str):
log_dir = (Path(__file__).parent / "logs").resolve()
expeirment = CrabBenchmarkV0(
benchmark=benchmark,
task_id=task_id,
task_id=args.task_id,
agent_policy=agent_policy,
log_dir=log_dir,
)
Expand Down

0 comments on commit 36b99e8

Please sign in to comment.