Skip to content

Commit

Permalink
Remove workers when inference is done
Browse files Browse the repository at this point in the history
Set TASK_NUM to 1 so the dag can update with one worker. Use a static
tag for the scaling operators so they won't rerun after the parameter
change.
  • Loading branch information
ranlu committed Mar 24, 2024
1 parent 6649b63 commit 0cbba42
Showing 1 changed file with 19 additions and 3 deletions.
22 changes: 19 additions & 3 deletions dags/chunkflow_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,13 @@ def setup_env_op(dag, param, queue):
dag=dag
)

def remove_workers():
    """Reduce the configured inference worker count to one.

    Reads the ``inference_param`` Airflow Variable (JSON-deserialized). When
    its ``TASK_NUM`` entry is greater than 1, the entry is set to 1, the
    Variable is written back, and the call sleeps for 60 seconds. When
    ``TASK_NUM`` is absent or not greater than 1, nothing is written.
    """
    # Function-scope import kept from the original block.
    from time import sleep

    param = Variable.get("inference_param", deserialize_json=True)
    # Read TASK_NUM with the same default of 1 that the cluster scaling code
    # uses (param.get("TASK_NUM", 1)), so a missing key is a no-op instead of
    # a KeyError.
    if param.get("TASK_NUM", 1) > 1:
        param["TASK_NUM"] = 1
        Variable.set("inference_param", param, serialize_json=True)
        # NOTE(review): indentation was lost in the view this was taken from;
        # the pause is assumed to apply only after the Variable was changed,
        # presumably to let the DAG refresh with the new value - confirm.
        sleep(60)


def inference_op(dag, param, queue, wid):
Expand Down Expand Up @@ -513,8 +520,8 @@ def process_output(**kwargs):


# Build the GPU cluster scale-up / scale-down tasks for dag_worker.
# The scale-up call caps the first numeric argument at 20 via
# min(TASK_NUM, 20) and passes estimate_worker_instances(TASK_NUM,
# cluster_info["gpu"]) as the second; TASK_NUM defaults to 1 when absent.
# NOTE(review): this is a diff view - the two calls without `tag=` are the
# removed lines, and the two calls with tag="up" / tag="down" replace them.
# NOTE(review): the bare `except:` catches every exception (for example a
# missing cluster_info["gpu"] entry) and binds both names to placeholder ops
# instead; consider narrowing it so real errors are not hidden.
try:
scale_up_cluster_task = scale_up_cluster_op(dag_worker, "chunkflow", "gpu", min(param.get("TASK_NUM",1), 20), estimate_worker_instances(param.get("TASK_NUM",1), cluster_info["gpu"]), "cluster")
scale_down_cluster_task = scale_down_cluster_op(dag_worker, "chunkflow", "gpu", 0, "cluster")
scale_up_cluster_task = scale_up_cluster_op(dag_worker, "chunkflow", "gpu", min(param.get("TASK_NUM",1), 20), estimate_worker_instances(param.get("TASK_NUM",1), cluster_info["gpu"]), "cluster", tag="up")
scale_down_cluster_task = scale_down_cluster_op(dag_worker, "chunkflow", "gpu", 0, "cluster", tag="down")
except:
scale_up_cluster_task = placeholder_op(dag_worker, "chunkflow_gpu_scale_up_dummy")
scale_down_cluster_task = placeholder_op(dag_worker, "chunkflow_gpu_scale_down_dummy")
Expand All @@ -529,6 +536,15 @@ def process_output(**kwargs):
dag=dag_worker
)

# Task "remove_extra_workers": runs remove_workers() on the "manager" queue
# as part of dag_worker.
# priority_weight=100000 combined with WeightRule.ABSOLUTE: presumably so the
# scheduler picks this task ahead of other queued tasks - confirm against the
# Airflow priority-weight documentation.
remove_workers_op = PythonOperator(
task_id="remove_extra_workers",
python_callable=remove_workers,
priority_weight=100000,
weight_rule=WeightRule.ABSOLUTE,
queue="manager",
dag=dag_worker
)

generate_ng_link_task = PythonOperator(
task_id="generate_ng_link",
python_callable=generate_ng_link,
Expand Down Expand Up @@ -562,4 +578,4 @@ def process_output(**kwargs):

# Task ordering.
# First chain: setup_redis_task and update_mount_secrets_op both precede
# sanity_check_task, followed by image_parameters, set_env_task and
# process_output_task.
# Second chain: scale up, wait for chunkflow, mark done, generate the ng
# link, then scale down.
# NOTE(review): this is a diff view - the chain without remove_workers_op is
# the removed line; the replacement inserts remove_workers_op between
# wait_for_chunkflow_task and mark_done_task.
[setup_redis_task, update_mount_secrets_op] >> sanity_check_task >> image_parameters >> set_env_task >> process_output_task

scale_up_cluster_task >> wait_for_chunkflow_task >> mark_done_task >> generate_ng_link_task >> scale_down_cluster_task
scale_up_cluster_task >> wait_for_chunkflow_task >> remove_workers_op >> mark_done_task >> generate_ng_link_task >> scale_down_cluster_task

0 comments on commit 0cbba42

Please sign in to comment.