diff --git a/gui/main.py b/gui/main.py index 8510de7..804ec11 100644 --- a/gui/main.py +++ b/gui/main.py @@ -18,23 +18,24 @@ import customtkinter as ctk from crab import Experiment -from crab.agents.backend_models import OpenAIModel, ClaudeModel, GeminiModel +from crab.agents.backend_models import ClaudeModel, GeminiModel, OpenAIModel from crab.agents.policies import SingleAgentPolicy from gui.utils import get_benchmark warnings.filterwarnings("ignore") AVAILABLE_MODELS = { - "gpt-4o": ("OpenAIModel", "gpt-4o"), - "gpt-4turbo": ("OpenAIModel", "gpt-4-turbo"), - "gemini": ("GeminiModel", "gemini-1.5-pro-latest"), - "claude": ("ClaudeModel", "claude-3-opus-20240229"), + "GPT-4o": ("OpenAIModel", "gpt-4o"), + "GPT-4 Turbo": ("OpenAIModel", "gpt-4-turbo"), + "Gemini": ("GeminiModel", "gemini-1.5-pro-latest"), + "Claude": ("ClaudeModel", "claude-3-opus-20240229"), } + def get_model_instance(model_key: str): if model_key not in AVAILABLE_MODELS: raise ValueError(f"Model {model_key} not supported") - + model_config = AVAILABLE_MODELS[model_key] model_class_name = model_config[0] model_name = model_config[1] @@ -46,12 +47,13 @@ def get_model_instance(model_key: str): elif model_class_name == "ClaudeModel": return ClaudeModel(model=model_name, history_messages_len=2) + def assign_task(): task_description = input_entry.get() input_entry.delete(0, "end") display_message(task_description) - model = get_model_instance(selected_model.get()) + model = get_model_instance(model_dropdown.get()) agent_policy = SingleAgentPolicy(model_backend=model) task_id = str(uuid4()) @@ -80,7 +82,8 @@ def display_message(message, sender="user"): if __name__ == "__main__": - # TODO: Handle JSON decode error from environment action endpoint and display model response in GUI + # TODO: Handle JSON decode error from environment action endpoint and + # display model response in GUI log_dir = (Path(__file__).parent / "logs").resolve() ctk.set_appearance_mode("System") @@ -93,15 +96,14 @@ def display_message(message, sender="user"): model_frame = ctk.CTkFrame(app) model_frame.pack(pady=10, padx=10, fill="x") - model_label = ctk.CTkLabel(model_frame, text="Select Model:") + model_label = ctk.CTkLabel(model_frame, text="Model") model_label.pack(side="left", padx=(0, 10)) - selected_model = ctk.StringVar(value="gpt-4o") model_dropdown = ctk.CTkOptionMenu( model_frame, values=list(AVAILABLE_MODELS.keys()), - variable=selected_model, ) + model_dropdown.set(next(iter(AVAILABLE_MODELS))) model_dropdown.pack(side="left", fill="x", expand=True) chat_display_frame = ctk.CTkFrame(app, width=380, height=380)