Skip to content

Commit

Permalink
exp_gen bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
XianBW committed Dec 30, 2024
1 parent 593854c commit 77d8b8b
Showing 1 changed file with 9 additions and 19 deletions.
28 changes: 9 additions & 19 deletions rdagent/scenarios/data_science/proposal/exp_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,10 @@ def __init__(self, scen: DataScienceScen, knowledge_base: KnowledgeBase | None =
self.hist: list[tuple[DSHypothesis, DSExperiment, HypothesisFeedback]] = []
self.knowledge_base = knowledge_base

def get_sota_hypothesis_and_experiment(
self, component: COMPONENT | None = None
) -> tuple[DSHypothesis | None, Experiment | None]:
def get_sota_hypothesis_and_experiment(self) -> tuple[DSHypothesis | None, Experiment | None]:
"""Access the last experiment result, sub-task, and the corresponding hypothesis."""
for h, exp, hf in self.hist[::-1]:
if hf.decision:
if component and h.component != component:
continue
return h, exp
return None, None

Expand Down Expand Up @@ -174,7 +170,7 @@ def gen(self, trace: DSTrace) -> DSExperiment:
hypothesis_and_feedback=hypothesis_and_feedback,
)

dependency_exp = trace.get_sota_hypothesis_and_experiment("DataLoadSpec")[1]
dependency_exp = trace.get_sota_hypothesis_and_experiment()[1]
ft = FeatureTask(
name="Feature Engineering",
description=resp_dict.get("description", "Feature description not provided"),
Expand All @@ -192,7 +188,7 @@ def gen(self, trace: DSTrace) -> DSExperiment:
hypothesis_and_feedback=hypothesis_and_feedback,
)

dependency_exp = trace.get_sota_hypothesis_and_experiment("FeatureEng")[1]
dependency_exp = trace.get_sota_hypothesis_and_experiment()[1]
mt = ModelTask(
name=resp_dict.get("model_name", "Model name not provided"),
description=resp_dict.get("description", "Model description not provided"),
Expand All @@ -214,7 +210,7 @@ def gen(self, trace: DSTrace) -> DSExperiment:
hypothesis_and_feedback=hypothesis_and_feedback,
)

dependency_exp = trace.get_sota_hypothesis_and_experiment("Model")[1]
dependency_exp = trace.get_sota_hypothesis_and_experiment()[1]
et = EnsembleTask(
name="Ensemble",
description=resp_dict.get("description", "Ensemble description not provided"),
Expand All @@ -232,7 +228,7 @@ def gen(self, trace: DSTrace) -> DSExperiment:
hypothesis_and_feedback=hypothesis_and_feedback,
)

dependency_exp = trace.get_sota_hypothesis_and_experiment("Ensemble")[1]
dependency_exp = trace.get_sota_hypothesis_and_experiment()[1]
wt = WorkflowTask(
name="Workflow",
description=resp_dict.get("description", "Workflow description not provided"),
Expand Down Expand Up @@ -267,7 +263,7 @@ def gen(self, trace: DSTrace) -> DSExperiment:
scenario_desc=scenario_desc,
task_output_format=T(".prompts:output_format.feature").r(),
)
dependency_exp = trace.get_sota_hypothesis_and_experiment("DataLoadSpec")[1]
dependency_exp = trace.get_sota_hypothesis_and_experiment()[1]
ft = FeatureTask(
name="Feature Engineering",
description=resp_dict.get("description", "Factor description not provided"),
Expand All @@ -281,19 +277,13 @@ def gen(self, trace: DSTrace) -> DSExperiment:
scenario_desc=scenario_desc,
task_output_format=T(".prompts:output_format.model").r(),
)
dependency_exp = trace.get_sota_hypothesis_and_experiment("FeatureEng")[1]
if last_model_exp := trace.get_sota_hypothesis_and_experiment("Model")[1]:
# TODO: model only have one (named "model.py")?
base_code = last_model_exp.experiment_workspace.file_dict["model.py"]
else:
base_code = ""
dependency_exp = trace.get_sota_hypothesis_and_experiment()[1]
mt = ModelTask(
name=resp_dict.get("model_name", "Model name not provided"),
description=resp_dict.get("description", "Model description not provided"),
model_type=resp_dict.get("model_type", "Model type not provided"),
architecture=resp_dict.get("architecture", "Model architecture not provided"),
hyperparameters=resp_dict.get("hyperparameters", "Model hyperparameters not provided"),
base_code=base_code,
)
exp = DSExperiment(sub_tasks=[mt], hypothesis=DSHypothesis("Model"))
exp.experiment_workspace.inject_code_from_folder(dependency_exp.experiment_workspace.workspace_path)
Expand All @@ -304,7 +294,7 @@ def gen(self, trace: DSTrace) -> DSExperiment:
scenario_desc=scenario_desc,
task_output_format=T(".prompts:output_format.ensemble").r(),
)
dependency_exp = trace.get_sota_hypothesis_and_experiment("Model")[1]
dependency_exp = trace.get_sota_hypothesis_and_experiment()[1]
et = EnsembleTask(
name="Ensemble",
description=resp_dict.get("description", "Ensemble description not provided"),
Expand All @@ -318,7 +308,7 @@ def gen(self, trace: DSTrace) -> DSExperiment:
scenario_desc=scenario_desc,
task_output_format=T(".prompts:output_format.workflow").r(),
)
dependency_exp = trace.get_sota_hypothesis_and_experiment("Ensemble")[1]
dependency_exp = trace.get_sota_hypothesis_and_experiment()[1]
wt = WorkflowTask(
name="Workflow",
description=resp_dict.get("description", "Workflow description not provided"),
Expand Down

0 comments on commit 77d8b8b

Please sign in to comment.