diff --git a/dags/synaptor_dags.py b/dags/synaptor_dags.py index beb4ba3a..13753f19 100644 --- a/dags/synaptor_dags.py +++ b/dags/synaptor_dags.py @@ -82,6 +82,14 @@ def __init__(self, name): self.use_gpus = False +class ManagerTask(Task): + + def __init__(self, name): + self.name = name + self.cluster_key = "manager" + self.use_gpus = False + + def fill_dag(dag: DAG, tasklist: list[Task], collect_metrics: bool = True) -> DAG: """Fills a synaptor DAG from a list of Tasks.""" drain = drain_op(dag) @@ -126,7 +134,7 @@ def fill_dag(dag: DAG, tasklist: list[Task], collect_metrics: bool = True) -> DA def change_cluster_if_required( dag: DAG, prev_cluster_tag: str, curr_operator: BaseOperator, next_task: Task, ) -> tuple[str, BaseOperator]: - if next_task.cluster_key in prev_cluster_tag: + if next_task.cluster_key == "manager" or next_task.cluster_key in prev_cluster_tag: # don't need to change the cluster return prev_cluster_tag, curr_operator @@ -220,6 +228,8 @@ def add_task( if task.name == "self_destruct": cluster_key = cluster_key_from_tag(tag) generate = self_destruct_op(dag, queue=cluster_key, tag=tag) + elif task.cluster_key == "manager": + generate = manager_op(dag, task.name, image=SYNAPTOR_IMAGE) else: generate = generate_op(dag, task.name, image=SYNAPTOR_IMAGE, tag=tag) @@ -262,8 +272,10 @@ def add_task( CPUTask("chunk_ccs"), CPUTask("match_contins"), GraphTask("seg_graph_ccs"), - CPUTask("chunk_seg_map"), + ManagerTask("index_seg_map"), + ManagerTask("chunk_seg_map"), CPUTask("merge_seginfo"), + ManagerTask("index_chunked_seg_map"), GPUTask("chunk_edges"), CPUTask("pick_edge"), CPUTask("merge_dups"),