From 138bedd685b4c91c1db3c7e40090a16a3ae65b56 Mon Sep 17 00:00:00 2001 From: yauheni Date: Sat, 14 Sep 2024 20:53:23 +0200 Subject: [PATCH 1/2] Added callback_file & callback_name to default_args DAG level and tests --- dagfactory/dagbuilder.py | 20 ++++++++ tests/test_dagbuilder.py | 103 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 123 insertions(+) diff --git a/dagfactory/dagbuilder.py b/dagfactory/dagbuilder.py index 689b3e8..4a9509b 100644 --- a/dagfactory/dagbuilder.py +++ b/dagfactory/dagbuilder.py @@ -254,6 +254,26 @@ def get_dag_params(self) -> Dict[str, Any]: dag_params["on_failure_callback_file"], ) + if utils.check_dict_key( + dag_params["default_args"], "on_success_callback_name" + ) and utils.check_dict_key( + dag_params["default_args"], "on_success_callback_file"): + + dag_params["default_args"]["on_success_callback"]: Callable = utils.get_python_callable( + dag_params["default_args"]["on_success_callback_name"], + dag_params["default_args"]["on_success_callback_file"], + ) + + if utils.check_dict_key( + dag_params["default_args"], "on_failure_callback_name" + ) and utils.check_dict_key( + dag_params["default_args"], "on_failure_callback_file"): + + dag_params["default_args"]["on_failure_callback"]: Callable = utils.get_python_callable( + dag_params["default_args"]["on_failure_callback_name"], + dag_params["default_args"]["on_failure_callback_file"], + ) + if utils.check_dict_key(dag_params, "template_searchpath"): if isinstance( dag_params["template_searchpath"], (list, str) diff --git a/tests/test_dagbuilder.py b/tests/test_dagbuilder.py index 119f806..a43fd94 100644 --- a/tests/test_dagbuilder.py +++ b/tests/test_dagbuilder.py @@ -200,6 +200,69 @@ }, }, } + +DAG_CONFIG_CALLBACK_NAME_AND_FILE = { + "doc_md": "##here is a doc md string", + "default_args": { + "owner": "custom_owner", + }, + "description": "this is an example dag", + "schedule_interval": "0 3 * * *", + "tags": ["tag1", "tag2"], + "on_failure_callback_name": 
"print_context_callback", + "on_failure_callback_file": __file__, + "on_success_callback_name": "print_context_callback", + "on_success_callback_file": __file__, + "tasks": { + "task_1": { + "operator": "airflow.operators.bash_operator.BashOperator", + "bash_command": "echo 1", + "execution_timeout_secs": 5, + }, + "task_2": { + "operator": "airflow.operators.bash_operator.BashOperator", + "bash_command": "echo 2", + "dependencies": ["task_1"], + }, + "task_3": { + "operator": "airflow.operators.bash_operator.BashOperator", + "bash_command": "echo 3", + "dependencies": ["task_1"], + }, + }, +} + +DAG_CONFIG_CALLBACK_NAME_AND_FILE_DEFAULT_ARGS = { + "doc_md": "##here is a doc md string", + "default_args": { + "owner": "custom_owner", + "on_failure_callback_name": "print_context_callback", + "on_failure_callback_file": __file__, + "on_success_callback_name": "print_context_callback", + "on_success_callback_file": __file__, + }, + "description": "this is an example dag", + "schedule_interval": "0 3 * * *", + "tags": ["tag1", "tag2"], + "tasks": { + "task_1": { + "operator": "airflow.operators.bash_operator.BashOperator", + "bash_command": "echo 1", + "execution_timeout_secs": 5, + }, + "task_2": { + "operator": "airflow.operators.bash_operator.BashOperator", + "bash_command": "echo 2", + "dependencies": ["task_1"], + }, + "task_3": { + "operator": "airflow.operators.bash_operator.BashOperator", + "bash_command": "echo 3", + "dependencies": ["task_1"], + }, + }, +} + UTC = pendulum.timezone("UTC") @@ -607,6 +670,46 @@ def test_make_task_with_callback(): assert callable(actual.on_retry_callback) +def test_dag_with_callback_name_and_file(): + td = dagbuilder.DagBuilder("test_dag", DAG_CONFIG_CALLBACK_NAME_AND_FILE, DEFAULT_CONFIG) + dag = td.build().get("dag") + + # Verify that the callbacks have been set up properly per DAG after specifying: + # - 'on_success_callback_file' & 'on_success_callback_name' for 'on_success_callback' + # - 'on_failure_callback_file' & 
'on_failure_callback_name' for 'on_failure_callback' + assert "on_success_callback" in td.dag_config + assert "on_failure_callback" in td.dag_config + assert callable(td.dag_config["on_success_callback"]) + assert callable(td.dag_config["on_failure_callback"]) + assert td.dag_config["on_success_callback"].__name__ == "print_context_callback" + assert td.dag_config["on_failure_callback"].__name__ == "print_context_callback" + + # Ensure that no callbacks were directly provided at the task level. + for td_task_id, td_task in dag.task_dict.items(): + assert not callable(td_task.on_success_callback) + assert not callable(td_task.on_failure_callback) + + +def test_dag_with_callback_name_and_file_default_args(): + td = dagbuilder.DagBuilder("test_dag", DAG_CONFIG_CALLBACK_NAME_AND_FILE_DEFAULT_ARGS, DEFAULT_CONFIG) + dag = td.build().get("dag") + + # Verify that the callbacks have been set up properly per DAG and tasks after specifying through default_args: + # - 'on_success_callback_file' & 'on_success_callback_name' for 'on_success_callback' + # - 'on_failure_callback_file' & 'on_failure_callback_name' for 'on_failure_callback' + td_default_args = td.dag_config.get("default_args") + assert "on_success_callback" in td_default_args + assert "on_failure_callback" in td_default_args + assert callable(td_default_args["on_success_callback"]) + assert callable(td_default_args["on_failure_callback"]) + + for td_task_id, td_task in dag.task_dict.items(): + assert callable(td_task.on_success_callback) + assert callable(td_task.on_failure_callback) + assert td_task.on_success_callback.__name__ == "print_context_callback" + assert td_task.on_failure_callback.__name__ == "print_context_callback" + + def test_make_timetable(): if version.parse(AIRFLOW_VERSION) >= version.parse("2.0.0"): td = dagbuilder.DagBuilder("test_dag", DAG_CONFIG, DEFAULT_CONFIG) From 5f448166b290cc990130dfdcce566fc2c623e661 Mon Sep 17 00:00:00 2001 From: yauheni Date: Fri, 20 Sep 2024 18:45:12 +0200 Subject: 
[PATCH 2/2] Added callback_file & callback_name to default_args DAG level and tests: fmt-check --- dagfactory/dagbuilder.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/dagfactory/dagbuilder.py b/dagfactory/dagbuilder.py index 4a9509b..f9a12a7 100644 --- a/dagfactory/dagbuilder.py +++ b/dagfactory/dagbuilder.py @@ -257,21 +257,27 @@ def get_dag_params(self) -> Dict[str, Any]: if utils.check_dict_key( dag_params["default_args"], "on_success_callback_name" ) and utils.check_dict_key( - dag_params["default_args"], "on_success_callback_file"): + dag_params["default_args"], "on_success_callback_file" + ): - dag_params["default_args"]["on_success_callback"]: Callable = utils.get_python_callable( - dag_params["default_args"]["on_success_callback_name"], - dag_params["default_args"]["on_success_callback_file"], + dag_params["default_args"]["on_success_callback"]: Callable = ( + utils.get_python_callable( + dag_params["default_args"]["on_success_callback_name"], + dag_params["default_args"]["on_success_callback_file"], + ) ) if utils.check_dict_key( dag_params["default_args"], "on_failure_callback_name" ) and utils.check_dict_key( - dag_params["default_args"], "on_failure_callback_file"): + dag_params["default_args"], "on_failure_callback_file" + ): - dag_params["default_args"]["on_failure_callback"]: Callable = utils.get_python_callable( - dag_params["default_args"]["on_failure_callback_name"], - dag_params["default_args"]["on_failure_callback_file"], + dag_params["default_args"]["on_failure_callback"]: Callable = ( + utils.get_python_callable( + dag_params["default_args"]["on_failure_callback_name"], + dag_params["default_args"]["on_failure_callback_file"], + ) ) if utils.check_dict_key(dag_params, "template_searchpath"):