Merge pull request #82 from seung-lab/seggraph
Improve synaptor integration
ranlu authored Mar 21, 2024
2 parents d65426e + 6d9aaef commit 02417d6
Showing 3 changed files with 36 additions and 48 deletions.
17 changes: 12 additions & 5 deletions dags/synaptor_dags.py
@@ -1,6 +1,6 @@
"""DAG definition for synaptor workflows."""
from typing import Optional
from datetime import datetime
from datetime import datetime, timedelta
from dataclasses import dataclass

from airflow import DAG
@@ -28,6 +28,9 @@
"start_date": datetime(2022, 2, 22),
"catchup": False,
"retries": 0,
'retry_delay': timedelta(seconds=10),
'retry_exponential_backoff': True,
'max_retry_delay': timedelta(seconds=600),
}
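
These defaults enable exponential backoff between retries: Airflow roughly doubles retry_delay on each attempt (plus per-task jitter) and caps the wait at max_retry_delay. Note that retries stays 0 here, so the backoff only takes effect for operators that override that default. A jitter-free sketch of the nominal schedule these values produce:

from datetime import timedelta

retry_delay = timedelta(seconds=10)
max_retry_delay = timedelta(seconds=600)

for attempt in range(1, 8):
    # nominal doubling; Airflow adds per-task-instance jitter on top
    nominal = min(retry_delay * 2 ** (attempt - 1), max_retry_delay)
    print(f"retry {attempt}: wait ~{int(nominal.total_seconds())}s")
# retry 1: ~10s, retry 2: ~20s, ... retry 7: ~600s (capped)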


@@ -86,6 +89,10 @@ def fill_dag(dag: DAG, tasklist: list[Task], collect_metrics: bool = True) -> DAG:

     drain >> init_cloudvols
 
+    if WORKFLOW_PARAMS.get("workspacetype", "File") == "Database":
+        init_db = manager_op(dag, "init_db", image=SYNAPTOR_IMAGE)
+        init_cloudvols >> init_db
+
     if collect_metrics:
         metrics = collect_metrics_op(dag)
         metrics >> drain
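
The new init_db step only enters the graph for database-backed workspaces. A runnable sketch of the gating, with WORKFLOW_PARAMS standing in for the workflow config loaded elsewhere in the repo:

# Minimal sketch of the gate above; WORKFLOW_PARAMS here is a stand-in.
WORKFLOW_PARAMS = {"workspacetype": "Database"}  # hypothetical config

chain = ["drain", "init_cloudvols"]
if WORKFLOW_PARAMS.get("workspacetype", "File") == "Database":
    chain.append("init_db")  # extra manager task for database workspaces
print(" >> ".join(chain))    # drain >> init_cloudvols >> init_db
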
@@ -142,7 +149,7 @@ def change_cluster_if_required(
         MAX_CLUSTER_SIZE if next_task.cluster_key != "synaptor-seggraph" else 1
     )
     scale_up = scale_up_cluster_op(
-        dag, new_tag, next_task.cluster_key, 1, cluster_size, "cluster"
+        dag, new_tag, next_task.cluster_key, min(10, cluster_size), cluster_size, "cluster"
     )
 
     workers = [
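
Assuming the fourth and fifth arguments of scale_up_cluster_op are the initial and target sizes (consistent with the old call passing 1 and cluster_size), clusters now start at up to ten workers instead of one, while seggraph clusters stay at a single node. A quick illustration — "synaptor-cpu" and the MAX_CLUSTER_SIZE value are stand-ins:

MAX_CLUSTER_SIZE = 100  # hypothetical; the real constant lives elsewhere

for key in ("synaptor-cpu", "synaptor-seggraph"):
    cluster_size = MAX_CLUSTER_SIZE if key != "synaptor-seggraph" else 1
    print(f"{key}: start at {min(10, cluster_size)}, cap at {cluster_size}")
# synaptor-cpu: start at 10, cap at 100
# synaptor-seggraph: start at 1, cap at 1
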
@@ -178,8 +185,9 @@ def scale_down_cluster(
     # cluster sub-dag
     cluster_key = cluster_key_from_tag(prev_cluster_tag)
     scale_down = scale_down_cluster_op(dag, prev_cluster_tag, cluster_key, 0, "cluster")
+    cluster_size = 1 if prev_cluster_tag.startswith("synaptor-seggraph") else MAX_CLUSTER_SIZE
     prev_workers = [
-        dag.get_task(f"worker_{prev_cluster_tag}_{i}") for i in range(MAX_CLUSTER_SIZE)
+        dag.get_task(f"worker_{prev_cluster_tag}_{i}") for i in range(cluster_size)
     ]
 
     prev_workers >> scale_down
@@ -246,10 +254,9 @@ def add_task(
     db_assignment = [
         CPUTask("chunk_ccs"),
         CPUTask("match_contins"),
-        CPUTask("seg_graph_ccs"),
+        GraphTask("seg_graph_ccs"),
         CPUTask("chunk_seg_map"),
         CPUTask("merge_seginfo"),
-        GraphTask("seg_graph_ccs"),
         GPUTask("chunk_edges"),
         CPUTask("pick_edge"),
         CPUTask("merge_dups"),
65 changes: 23 additions & 42 deletions dags/synaptor_ops.py
@@ -11,10 +11,23 @@
 from worker_op import worker_op
 from param_default import default_synaptor_image
 from igneous_and_cloudvolume import check_queue, upload_json, read_single_file
-from slack_message import task_failure_alert, task_done_alert, slack_message
+from slack_message import task_failure_alert, task_retry_alert, task_done_alert, slack_message
 from nglinks import ImageLayer, SegLayer, generate_ng_payload, wrap_payload
 from kombu_helper import drain_messages
 
+from airflow import configuration as conf
+
+airflow_broker_url = conf.get("celery", "broker_url")
+
+maybe_aws = Variable.get("aws-secret.json", None)
+maybe_gcp = Variable.get("google-secret.json", None)
+
+mount_variables = ["synaptor_param.json"]
+
+if maybe_aws is not None:
+    mount_variables.append("aws-secret.json")
+if maybe_gcp is not None:
+    mount_variables.append("google-secret.json")
+
 # hard-coding these for now
 MOUNT_POINT = "/root/.cloudvolume/secrets/"
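
Because these lookups moved to module level, they now run whenever the scheduler parses the DAG file rather than inside each operator. Variable.get with a default returns it instead of raising KeyError, which is what lets the mount list degrade gracefully when no cloud secrets are defined — a minimal illustration (assumes an initialized Airflow environment):

from airflow.models import Variable

# Without a default, a missing Variable raises KeyError; with one, the
# lookup degrades gracefully, so deployments without cloud credentials
# simply skip mounting the secret files.
maybe_aws = Variable.get("aws-secret.json", None)
if maybe_aws is None:
    print("no aws-secret.json Variable defined; not mounting it")
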
@@ -177,15 +190,12 @@ def drain_op(
     queue: Optional[str] = "manager",
 ) -> PythonOperator:
     """Drains leftover messages from the RabbitMQ."""
-    from airflow import configuration as conf
-
-    broker_url = conf.get("celery", "broker_url")
 
     return PythonOperator(
         task_id="drain_messages",
         python_callable=drain_messages,
         priority_weight=100_000,
-        op_args=(broker_url, task_queue_name),
+        op_args=(airflow_broker_url, task_queue_name),
         weight_rule=WeightRule.ABSOLUTE,
         on_failure_callback=task_failure_alert,
         on_success_callback=task_done_alert,
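
drain_messages is imported from kombu_helper; judging from op_args, its signature is (broker_url, queue_name). A minimal sketch of what such a purge helper could look like with kombu — an assumption about its behavior, not the repo's actual implementation:

from kombu import Connection

def drain_messages(broker_url: str, queue_name: str) -> None:
    """Discard any stale messages left on queue_name (sketch)."""
    with Connection(broker_url) as conn:
        # queue_purge drops all ready messages and returns how many
        purged = conn.default_channel.queue_purge(queue_name)
        print(f"purged {purged or 0} messages from {queue_name}")
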
@@ -204,11 +214,8 @@ def manager_op(
     config_path = os.path.join(MOUNT_POINT, "synaptor_param.json")
     command = f"{synaptor_task_name} {config_path}"
 
-    # these variables will be mounted in the containers
-    variables = ["synaptor_param.json"]
-
     return worker_op(
-        variables=variables,
+        variables=mount_variables,
         mount_point=MOUNT_POINT,
         task_id=synaptor_task_name,
         command=command,
@@ -232,24 +239,21 @@ def generate_op(
     image: str = default_synaptor_image,
 ) -> BaseOperator:
     """Generates tasks to run and adds them to the RabbitMQ."""
-    from airflow import configuration as conf
-
-    broker_url = conf.get("celery", "broker_url")
     config_path = os.path.join(MOUNT_POINT, "synaptor_param.json")
 
     command = (
         f"generate {taskname} {config_path}"
-        f" --queueurl {broker_url}"
+        f" --queueurl {airflow_broker_url}"
         f" --queuename {task_queue_name}"
     )
     if taskname == "self_destruct":
         command += f" --clusterkey {tag}"
 
-    # these variables will be mounted in the containers
-    variables = add_secrets_if_defined(["synaptor_param.json"])
-
     task_id = f"generate_{taskname}" if tag is None else f"generate_{taskname}_{tag}"
 
     return worker_op(
-        variables=variables,
+        variables=mount_variables,
         mount_point=MOUNT_POINT,
         task_id=task_id,
         command=command,
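
For concreteness, here is the command string this operator assembles, evaluated with sample values (the queue name and broker URL below are hypothetical):

import os

MOUNT_POINT = "/root/.cloudvolume/secrets/"
taskname = "chunk_ccs"                          # sample task from this workflow
task_queue_name = "synaptor"                    # hypothetical queue name
airflow_broker_url = "amqp://rabbitmq:5672//"   # deployment-specific

config_path = os.path.join(MOUNT_POINT, "synaptor_param.json")
command = (
    f"generate {taskname} {config_path}"
    f" --queueurl {airflow_broker_url}"
    f" --queuename {task_queue_name}"
)
print(command)
# generate chunk_ccs /root/.cloudvolume/secrets/synaptor_param.json --queueurl amqp://rabbitmq:5672// --queuename synaptor
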
@@ -274,25 +278,19 @@ def synaptor_op(
     image: str = default_synaptor_image,
 ) -> BaseOperator:
     """Runs a synaptor worker until it receives a self-destruct task."""
-    from airflow import configuration as conf
-
-    broker_url = conf.get("celery", "broker_url")
     config_path = os.path.join(MOUNT_POINT, "synaptor_param.json")
 
     command = (
         f"worker --configfilename {config_path}"
-        f" --queueurl {broker_url} "
+        f" --queueurl {airflow_broker_url} "
         f" --queuename {task_queue_name}"
         " --lease_seconds 300"
     )
 
-    # these variables will be mounted in the containers
-    variables = add_secrets_if_defined(["synaptor_param.json"])
-
     task_id = f"worker_{i}" if tag is None else f"worker_{tag}_{i}"
 
     return worker_op(
-        variables=variables,
+        variables=mount_variables,
         mount_point=MOUNT_POINT,
         task_id=task_id,
         command=command,
@@ -301,6 +299,7 @@ def synaptor_op(
         image=image,
         priority_weight=100_000,
         weight_rule=WeightRule.ABSOLUTE,
+        on_retry_callback=task_retry_alert,
         queue=op_queue_name,
         dag=dag,
         # qos='quality of service'
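
The new on_retry_callback pairs with the retry defaults added in synaptor_dags.py, so each retry pings Slack instead of passing silently. A sketch of the callback shape Airflow expects — the real task_retry_alert lives in slack_message:

def task_retry_alert(context):
    """Sketch of a retry callback; Airflow passes the task context dict."""
    ti = context["task_instance"]
    print(f"{ti.task_id} failed; retry {ti.try_number} queued")
    # the slack_message version presumably posts this to Slack instead
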
@@ -324,21 +323,3 @@ def wait_op(dag: DAG, taskname: str) -> PythonOperator:
queue="manager",
dag=dag,
)


# Helper functions
def add_secrets_if_defined(variables: list[str]) -> list[str]:
"""Adds CloudVolume secret files to the mounted variables if defined.
Synaptor still needs to store the google-secret.json file sometimes
bc it currently uses an old version of gsutil.
"""
maybe_aws = Variable.get("aws-secret.json", None)
maybe_gcp = Variable.get("google-secret.json", None)

if maybe_aws is not None:
variables.append("aws-secret.json")
if maybe_gcp is not None:
variables.append("google-secret.json")

return variables
2 changes: 1 addition & 1 deletion slackbot/airflow_api.py
@@ -106,7 +106,7 @@ def wait_for_dag_refresh(dag_id):


 def run_dag(dag_id, wait_for_completion=False):
-    dags_need_refresh = ["segmentation", "chunkflow_worker"]
+    dags_need_refresh = ["segmentation", "chunkflow_worker", "synaptor"]
     if dag_id in dags_need_refresh:
         wait_for_dag_refresh(dag_id)
     dagrun = run_in_executor(__run_dag, dag_id)
