diff --git a/dpgen2/superop/prep_run_dp_train.py b/dpgen2/superop/prep_run_dp_train.py index 29f5afb3..87bd74d0 100644 --- a/dpgen2/superop/prep_run_dp_train.py +++ b/dpgen2/superop/prep_run_dp_train.py @@ -232,6 +232,7 @@ def _prep_run_dp_train( run_template_config = run_config.pop("template_config") prep_executor = init_executor(prep_config.pop("executor")) run_executor = init_executor(run_config.pop("executor")) + template_slice_config = run_config.pop("template_slice_config", {}) prep_train = Step( "prep-train", @@ -261,6 +262,7 @@ def _prep_run_dp_train( input_parameter=["task_name"], input_artifact=["task_path", "init_model"], output_artifact=["model", "lcurve", "log", "script"], + **template_slice_config, ), python_packages=upload_python_packages, **run_template_config,