From 6b4b4fc91ddca9c921de30fffe71c50299c490c3 Mon Sep 17 00:00:00 2001 From: you-n-g Date: Fri, 26 Jul 2024 14:29:19 +0800 Subject: [PATCH] Support special step customization (#123) * Support special step customization * lint --- rdagent/app/qlib_rd_loop/factor_w_sc.py | 12 ++++++++++++ rdagent/utils/workflow.py | 5 ++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/rdagent/app/qlib_rd_loop/factor_w_sc.py b/rdagent/app/qlib_rd_loop/factor_w_sc.py index 10c9d19e4..e6434f9f0 100644 --- a/rdagent/app/qlib_rd_loop/factor_w_sc.py +++ b/rdagent/app/qlib_rd_loop/factor_w_sc.py @@ -2,16 +2,28 @@ Factor workflow with session control """ +from typing import Any + import fire from rdagent.app.qlib_rd_loop.conf import FACTOR_PROP_SETTING from rdagent.components.workflow.rd_loop import RDLoop from rdagent.core.exception import FactorEmptyError +from rdagent.log import rdagent_logger as logger class FactorRDLoop(RDLoop): skip_loop_error = (FactorEmptyError,) + def exp_gen(self, prev_out: dict[str, Any]): + with logger.tag("r"): # research + exp = self.hypothesis2experiment.convert(prev_out["propose"], self.trace) + if exp is None: + logger.error(f"Factor extraction failed.") + raise FactorEmptyError("Factor extraction failed.") + logger.log_object(exp.sub_tasks, tag="experiment generation") + return exp + def main(path=None, step_n=None): """ diff --git a/rdagent/utils/workflow.py b/rdagent/utils/workflow.py index f89759049..f4e568cda 100644 --- a/rdagent/utils/workflow.py +++ b/rdagent/utils/workflow.py @@ -35,7 +35,10 @@ def __new__(cls, clsname, bases, attrs): steps = LoopMeta._get_steps(bases) # all the base classes of parents for name, attr in attrs.items(): if not name.startswith("__") and isinstance(attr, Callable): - steps.append(name) + if name not in steps: + # NOTE: if we override the step in the subclass + # Then it is not the new step. So we skip it. + steps.append(name) attrs["steps"] = steps return super().__new__(cls, clsname, bases, attrs)