Commit

🔨 refactor: refactor the steps
huangyz0918 committed Jun 5, 2024
1 parent 62b6e12 commit f7b4941
Showing 3 changed files with 193 additions and 147 deletions.
10 changes: 5 additions & 5 deletions agent/function/plan_agent.py
@@ -51,7 +51,7 @@ def gen_file_name(project, llm_agent):
console.log(f"Error parsing filenames: {e}")
file_name_candidates = [] # Fallback to an empty list if parsing fails
else:
file_name_candidates = (file_name_candidates, )
file_name_candidates = (file_name_candidates,)

file_path_candidates = [os.path.join(project.path, filename.strip("'")) for filename in file_name_candidates]

@@ -99,7 +99,7 @@ def pmpt_task_select():
Output Format:
Your response should include three entries formatted as a list of strings, where each string contains the task name followed by its description, e.g.:
["Task1, Description of Task1", "Task2, Description of Task2", "Task3, Description of Task3"]
["Task1: Description of Task1", "Task2: Description of Task2", "Task3: Description of Task3"]
Note: Return only the task names followed by a brief description, without any additional information or punctuation.
"""
@@ -124,9 +124,9 @@ def pmpt_model_select():
Please return a list of three strings, where each string includes the model's name followed by its summary.
Ensure the description highlights how the model meets one of the selection criteria (best, balanced, fastest).
Example format:
["ModelName1, Best for task X with high accuracy of Y%, suitable for complex data analysis.",
"ModelName2, Balanced model, offers moderate accuracy with better speed, good for real-time applications.",
"ModelName3, Fastest model with lower accuracy, best for quick processing where speed is prioritized over precision."]
["ModelName1: Best for task X with high accuracy of Y%, suitable for complex data analysis.",
"ModelName2: Balanced model, offers moderate accuracy with better speed, good for real-time applications.",
"ModelName3: Fastest model with lower accuracy, best for quick processing where speed is prioritized over precision."]
Note: Ensure that the architecture names with summary are returned without any additional punctuation.
"""
327 changes: 186 additions & 141 deletions agent/function/tech_leader.py
@@ -50,153 +50,198 @@ def __init__(self, project: Project, model):
if self.project.plan is None:
self.project.plan = Plan(current_task=0)

def start(self):
def user_requirement_understanding(self):
"""
Execute the chain.
:return: the result of the chain.
(STEP-0) User Requirement Understanding.
:return:
"""
try:
is_running = True
while is_running:
show_panel("STEP 1: User Requirements Understanding")
if self.project.requirement:
self.console.log(f"[cyan]User Requirement:[/cyan] {self.project.requirement}")
else:
self.requirement = questionary.text("Hi, what are your requirements?").ask()
self.project.requirement = self.requirement

if not self.requirement:
raise SystemExit("The user requirement is not provided.")

# Generate entry file name based on requirement
enhanced_requirement = self.requirement
if self.entry_file is None:
self.entry_file = gen_file_name(self.project, self.model)
self.console.log(f"The entry file is: {self.entry_file}")

show_panel("STEP 2: Dataset Selection")
if self.project.plan.data_kind is None and self.project.plan.dataset is None:
self.project.plan.data_kind = analyze_requirement(
enhanced_requirement,
pmpt_dataset_detect(),
if self.project.requirement:
self.console.log(f"[cyan]User Requirement:[/cyan] {self.project.requirement}")
else:
show_panel("STEP 0: User Requirements Understanding")
self.requirement = questionary.text("Hi, what are your requirements?").ask()
self.project.requirement = self.requirement

if not self.requirement:
raise SystemExit("The user requirement is not provided.")

# Generate entry file name based on requirement
self.project.enhanced_requirement = self.requirement
if self.entry_file is None:
self.entry_file = gen_file_name(self.project, self.model)
self.console.log(f"The entry file is: {self.entry_file}")
update_project_state(self.project)

def dataset_selection(self):
"""
(STEP-1) Dataset Selection.
:return:
"""

if self.project.plan.data_kind is None and self.project.plan.dataset is None:
show_panel("STEP 1: Dataset Selection")
self.project.plan.data_kind = analyze_requirement(
self.project.enhanced_requirement,
pmpt_dataset_detect(),
self.model
)

if self.project.plan.data_kind == 'no_data_information_provided':
public_dataset_list = analyze_requirement(
self.project.enhanced_requirement,
pmpt_public_dataset_guess(),
self.model
)
public_dataset_list = ast.literal_eval(public_dataset_list)
self.project.plan.dataset = questionary.select(
"Please select the dataset:",
choices=public_dataset_list
).ask()
elif self.project.plan.data_kind == 'csv_data':
self.project.plan.dataset = questionary.text("Please provide the CSV data path:").ask()
if os.path.exists(self.project.plan.dataset) is False:
public_dataset_list = analyze_requirement(
self.project.enhanced_requirement,
pmpt_public_dataset_guess(),
self.model
)
if self.project.plan.data_kind == 'no_data_information_provided':
public_dataset_list = analyze_requirement(
enhanced_requirement,
pmpt_public_dataset_guess(),
self.model
)
public_dataset_list = ast.literal_eval(public_dataset_list)
self.project.plan.dataset = questionary.select(
"Please select the dataset:",
choices=public_dataset_list
).ask()

elif self.project.plan.data_kind == 'csv_data':
self.project.plan.dataset = questionary.text("Please provide the CSV data path:").ask()
if os.path.exists(self.project.plan.dataset) is False:
public_dataset_list = analyze_requirement(
enhanced_requirement,
pmpt_public_dataset_guess(),
self.model
)
public_dataset_list = ast.literal_eval(public_dataset_list)
self.project.plan.dataset = questionary.select(
"Please select the dataset:",
choices=public_dataset_list
).ask()

if self.project.plan.dataset is None:
raise SystemExit("There is no dataset information. Aborted.")
else:
self.console.log(f"[cyan]Data source:[/cyan] {self.project.plan.dataset}")
if self.project.plan.data_kind == 'csv_data':
csv_data_sample = read_csv_file(self.project.plan.dataset)
self.console.log(f"[cyan]Dataset examples:[/cyan] {csv_data_sample}")
enhanced_requirement += f"\nDataset: {self.project.plan.dataset}"
enhanced_requirement += f"\nDataset Sample: {csv_data_sample}"
update_project_state(self.project)

show_panel("STEP 3: Task & Model Selection")
if self.project.plan.ml_task_type is None:
ml_task_list = analyze_requirement(enhanced_requirement, pmpt_task_select(), self.model)
ml_task_list = ast.literal_eval(ml_task_list)
ml_task_type = questionary.select(
"Please select the ML task type:",
choices=ml_task_list
public_dataset_list = ast.literal_eval(public_dataset_list)
self.project.plan.dataset = questionary.select(
"Please select the dataset:",
choices=public_dataset_list
).ask()

self.console.log(f"[cyan]ML task type detected:[/cyan] {ml_task_type}")
confirm_ml_task_type = questionary.confirm("Are you sure to use this ml task type?").ask()
if confirm_ml_task_type:
self.project.plan.ml_task_type = ml_task_type
update_project_state(self.project)
else:
self.console.log("Seems you are not satisfied with the task type. Aborting the chain.")
return

enhanced_requirement += f"\n\nML task type: {self.project.plan.ml_task_type}"
if self.project.plan.ml_model_arch is None:
ml_model_list = analyze_requirement(enhanced_requirement, pmpt_model_select(), self.model)
ml_model_list = ast.literal_eval(ml_model_list)
ml_model_arch = questionary.select(
"Please select the ML model architecture:",
choices=ml_model_list
).ask()
self.console.log(f"[cyan]Model architecture detected:[/cyan] {ml_model_arch}")
confirm_ml_model_arch = questionary.confirm("Are you sure to use this ml arch?").ask()
if confirm_ml_model_arch:
self.project.plan.ml_model_arch = ml_model_arch
update_project_state(self.project)
else:
self.console.log("Seems you are not satisfied with the model architecture. Aborting the chain.")
return

enhanced_requirement += f"\nModel architecture: {self.project.plan.ml_model_arch}"

show_panel("STEP 4: Planning")
if self.project.plan.tasks is None:
self.console.log(
f"The project [cyan]{self.project.name}[/cyan] has no existing plans. Start planning..."
)
enhanced_requirement += f"\nDataset: {self.project.plan.dataset}"
with self.console.status("Planning the tasks for you..."):
task_dicts = plan_generator(
enhanced_requirement,
self.model,
)
self.console.print(generate_plan_card_ascii(task_dicts), highlight=False)
self.project.plan.tasks = []
for task_dict in task_dicts.get('tasks'):
task = match_plan(task_dict)
if task:
self.project.plan.tasks.append(task)

# Confirm the plan
confirm_plan = questionary.confirm("Are you sure to use this plan?").ask()
if confirm_plan:
update_project_state(self.project)
else:
self.console.log("Seems you are not satisfied with the plan. Aborting the chain.")
return

task_num = len(self.project.plan.tasks)
# check if all tasks are completed.
# if self.project.plan.current_task == task_num:
# self.console.log(":tada: Looks like all tasks are completed.")
# return

# code generation
show_panel("STEP 5: Code Generation")
code_generation_agent = CodeAgent(self.model, self.project)
code_generation_agent.invoke(task_num, self.requirement)

# install the dependencies for this plan and code.
show_panel("STEP 6: Execution and Refection")
launch_agent = SetupAgent(self.model, self.project)
launch_agent.invoke()
is_running = False
if self.project.plan.dataset is None:
raise SystemExit("There is no dataset information. Aborted.")
else:
self.console.log(f"[cyan]Data source:[/cyan] {self.project.plan.dataset}")
if self.project.plan.data_kind == 'csv_data':
csv_data_sample = read_csv_file(self.project.plan.dataset)
self.console.log(f"[cyan]Dataset examples:[/cyan] {csv_data_sample}")
self.project.enhanced_requirement += f"\nDataset Sample: {csv_data_sample}"

self.project.enhanced_requirement += f"\nDataset: {self.project.plan.dataset}"
update_project_state(self.project)

def task_model_selection(self):
"""
(STEP-2) Task & Model Selection.
:return:
"""
if self.project.plan.ml_task_type is None or self.project.plan.ml_model_arch is None:
show_panel("STEP 2: Task & Model Selection")

# select the ml task type
if self.project.plan.ml_task_type is None:
ml_task_list = analyze_requirement(self.project.enhanced_requirement, pmpt_task_select(), self.model)
ml_task_list = ast.literal_eval(ml_task_list)
ml_task_type = questionary.select(
"Please select the ML task type:",
choices=ml_task_list
).ask()

self.console.log(f"[cyan]ML task type detected:[/cyan] {ml_task_type}")
confirm_ml_task_type = questionary.confirm("Are you sure to use this ml task type?").ask()
if confirm_ml_task_type:
self.project.plan.ml_task_type = ml_task_type
else:
self.console.log("Seems you are not satisfied with the task type. Aborting the chain.")
return
self.project.enhanced_requirement += f"\n\nML task type: {self.project.plan.ml_task_type}"

# select the model architecture
if self.project.plan.ml_model_arch is None:
ml_model_list = analyze_requirement(self.project.enhanced_requirement, pmpt_model_select(), self.model)
ml_model_list = ast.literal_eval(ml_model_list)
ml_model_arch = questionary.select(
"Please select the ML model architecture:",
choices=ml_model_list
).ask()
self.console.log(f"[cyan]Model architecture detected:[/cyan] {ml_model_arch}")
confirm_ml_model_arch = questionary.confirm("Are you sure to use this ml arch?").ask()
if confirm_ml_model_arch:
self.project.plan.ml_model_arch = ml_model_arch
else:
self.console.log("Seems you are not satisfied with the model architecture. Aborting the chain.")
return

update_project_state(self.project)
self.project.enhanced_requirement += f"\nModel architecture: {self.project.plan.ml_model_arch}"

def task_planning(self):
"""
(STEP-3) Task Planning.
:return:
"""
if self.project.plan.tasks is None:
show_panel("STEP 3: Task Planning")
self.console.log(
f"The project [cyan]{self.project.name}[/cyan] has no existing plans. Start planning..."
)
self.project.enhanced_requirement += f"\nDataset: {self.project.plan.dataset}"
with self.console.status("Planning the tasks for you..."):
task_dicts = plan_generator(self.project.enhanced_requirement, self.model)
self.console.print(generate_plan_card_ascii(task_dicts), highlight=False)
self.project.plan.tasks = []
for task_dict in task_dicts.get('tasks'):
task = match_plan(task_dict)
if task:
self.project.plan.tasks.append(task)

# Confirm the plan
confirm_plan = questionary.confirm("Are you sure to use this plan?").ask()
if confirm_plan:
update_project_state(self.project)
else:
self.console.log("Seems you are not satisfied with the plan. Aborting the chain.")
return
else:
tasks = []
for t in self.project.plan.tasks:
tasks.append({'name': t.name, 'resources': [r.name for r in t.resources], 'description': t.description})
self.console.print(generate_plan_card_ascii({'tasks': tasks}), highlight=False)

def code_generation(self):
"""
(STEP-4) Code Generation.
:return:
"""
task_num = len(self.project.plan.tasks)
if self.project.plan.current_task < task_num:
show_panel("STEP 4: Code Generation")
code_generation_agent = CodeAgent(self.model, self.project)
code_generation_agent.invoke(task_num, self.requirement)
update_project_state(self.project)

def execution_and_reflection(self):
"""
(STEP-5) Execution and Reflection.
:return:
"""
show_panel("STEP 5: Execution and Reflection")
launch_agent = SetupAgent(self.model, self.project)
launch_agent.invoke()
update_project_state(self.project)

def start(self):
"""
Execute the chain.
:return: the result of the chain.
"""
try:
# STEP 0: User Requirement Understanding
self.user_requirement_understanding()
# STEP 1: Dataset Selection
self.dataset_selection()
# STEP 2: Task & Model Selection
self.task_model_selection()
# STEP 3: Task Planning
self.task_planning()
# STEP 4: Code Generation
self.code_generation()
# STEP 5: Execution and Reflection
self.execution_and_reflection()
self.console.log("The chain has been completed.")
except KeyboardInterrupt:
self.console.log("MLE Plan Agent has been interrupted.")
return
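One effect of this refactor is that the accumulated requirement context now lives on the Project model as enhanced_requirement, rather than in a local variable inside start(), so each step method can append to it and persist it via update_project_state. A rough sketch of that accumulation pattern, using only field names that appear in this diff (the project object and update_project_state are assumed to be in scope):

# Sketch only: how the per-step methods extend the shared requirement context.
project.enhanced_requirement = project.requirement                                      # STEP 0
project.enhanced_requirement += f"\nDataset: {project.plan.dataset}"                    # STEP 1
project.enhanced_requirement += f"\n\nML task type: {project.plan.ml_task_type}"        # STEP 2
project.enhanced_requirement += f"\nModel architecture: {project.plan.ml_model_arch}"   # STEP 2
update_project_state(project)  # called after each step in the diff to persist progress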
3 changes: 2 additions & 1 deletion agent/types/base.py
@@ -50,11 +50,12 @@ class Plan(BaseModel):

class Project(BaseModel):
name: str
path: Optional[str] = None
lang: str
llm: str
path: Optional[str] = None
plan: Optional[Plan] = None
entry_file: Optional[str] = None
debug_env: Optional[str] = DebugEnv.local
description: Optional[str] = None
requirement: Optional[str] = None
enhanced_requirement: Optional[str] = None
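For reference, a hypothetical construction of the updated Project model with the new enhanced_requirement field (the import path is assumed from the changed file agent/types/base.py; all values are placeholders):

from agent.types.base import Plan, Project  # import path assumed from the file location

project = Project(
    name="demo-project",
    lang="python",
    llm="gpt-4o",
    path="/tmp/demo-project",
    plan=Plan(current_task=0),
    requirement="Train a classifier on a CSV dataset",
    enhanced_requirement=None,  # new optional field introduced in this commit
)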
