Adding Slurm jobs #20

Open: wants to merge 13 commits into main
3 changes: 3 additions & 0 deletions flows/conda/schema.py
@@ -9,3 +9,6 @@ class CondaParams(BaseModel):
        description="Python file to run", default="src/train.py"
    )
    params: Optional[dict] = {}

    class Config:
        extra = "forbid"
9 changes: 8 additions & 1 deletion flows/parent_flow.py
@@ -7,17 +7,20 @@
from flows.conda.schema import CondaParams
from flows.podman.podman_flows import launch_podman
from flows.podman.schema import PodmanParams
from flows.slurm.schema import SlurmParams
from flows.slurm.slurm_flows import launch_slurm


class FlowType(str, Enum):
    podman = "podman"
    conda = "conda"
    slurm = "slurm"


@flow(name="Parent flow")
async def launch_parent_flow(
    flow_type: FlowType,
    params_list: list[Union[PodmanParams, CondaParams]],
    params_list: list[Union[PodmanParams, CondaParams, SlurmParams]],
):
    prefect_logger = get_run_logger()

@@ -31,6 +34,10 @@ async def launch_parent_flow(
            flow_run_id = await launch_conda(
                conda_params=params, prev_flow_run_id=flow_run_id
            )
        elif flow_type == FlowType.slurm:
            flow_run_id = await launch_slurm(
                slurm_params=params, prev_flow_run_id=flow_run_id
            )
        else:
            prefect_logger.error("Flow type not supported")
            raise ValueError("Flow type not supported")
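
For illustration only (not part of the diff): a hypothetical call of the parent flow with the new slurm flow type. All parameter values below are made up.

# Hypothetical usage sketch of the new slurm branch; values are illustrative
# and not taken from this PR.
import asyncio

from flows.parent_flow import FlowType, launch_parent_flow
from flows.slurm.schema import SlurmParams

slurm_params = SlurmParams(
    job_name="train-demo",
    num_nodes=1,
    max_time="01:00:00",
    conda_env_name="mlex",
    forward_ports=["8888:8888"],
    python_file_name="src/train.py",
    params={"io_parameters": {}},
)

asyncio.run(
    launch_parent_flow(flow_type=FlowType.slurm, params_list=[slurm_params])
)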
104 changes: 104 additions & 0 deletions flows/slurm/run_slurm.sh
@@ -0,0 +1,104 @@
#!/bin/bash

# Check if all arguments are provided
if [ $# -ne 10 ]; then
echo "Usage: $0 <job_name> <num_nodes> <partitions> <reservations> <max_time> <conda_env> <forward_ports> <submission_ssh_key> <python_file> <yaml_file>"
fi

# Assign arguments to variables
job_name=$1
num_nodes=$2
partitions=$3
reservations=$4
max_time=$5
conda_env=$6
forward_ports=$7
submission_ssh_key=$8
python_file=$9
yaml_file=${10}

# Create a temporary Slurm batch script
BATCH_SCRIPT=$(mktemp ./tmp_script.XXXXXX)

# Write the Slurm batch script
echo "#!/bin/bash" > $BATCH_SCRIPT

if [ -z "$partitions" ]
then
echo "Partitions is unset or empty"
else
echo "#SBATCH --partition=$partitions" >> $BATCH_SCRIPT
fi

if [ -z "$reservations" ]
then
echo "Reservations is unset or empty"
else
echo "#SBATCH --reservation=$reservations" >> $BATCH_SCRIPT
fi

JOB_SUBMISSION_HOST=$(hostname)

echo "#SBATCH --nodes=$num_nodes" >> $BATCH_SCRIPT
echo "#SBATCH --job-name=$job_name" >> $BATCH_SCRIPT
echo "#SBATCH --time=$max_time" >> $BATCH_SCRIPT
echo "unset LD_PRELOAD" >> $BATCH_SCRIPT
echo "source /etc/profile.d/modules.sh" >> $BATCH_SCRIPT
echo "module load maxwell mamba" >> $BATCH_SCRIPT
echo ". mamba-init" >> $BATCH_SCRIPT
echo "mamba activate $conda_env" >> $BATCH_SCRIPT

Member (review comment, suggested change):

echo "mamba activate $conda_env" >> $BATCH_SCRIPT
if [ ! -z "$forward_ports" ]; then
    echo -n "ssh " >> $BATCH_SCRIPT
    # If an ssh key was specified, pass it as an additional argument
    if [ ! -z "$submission_ssh_key" ]; then
        echo -n "-i $submission_ssh_key " >> $BATCH_SCRIPT
    fi
    # Loop over forward_ports and append each forwarding rule
    for port in $forward_ports; do
        echo -n "-L $port:localhost:8000 " >> $BATCH_SCRIPT
    done
    # -N tells SSH to not execute a remote command
    echo -n "-N " >> $BATCH_SCRIPT
    # Finally, specify host to tunnel to
    # & at the end of the line makes the command run in the background
    echo -n "$USER@$JOB_SUBMISSION_HOST &" >> $BATCH_SCRIPT
fi

if [ ! -z "$forward_ports" ]; then
    echo -n "ssh " >> $BATCH_SCRIPT
    # If an ssh key was specified, pass it as an additional argument
    if [ ! -z "$submission_ssh_key" ]; then
        echo -n "-i $submission_ssh_key " >> $BATCH_SCRIPT
    fi
    # Loop over forward_ports and append each forwarding rule
    IFS=',' read -ra ADDR <<< "$forward_ports"
    for port_pair in "${ADDR[@]}"; do
        echo $port_pair
        local_port=${port_pair%:*}
        remote_port=${port_pair#*:}
        echo -n "-L $local_port:localhost:$remote_port " >> $BATCH_SCRIPT
    done
    # -N tells SSH to not execute a remote command
    echo -n "-N " >> $BATCH_SCRIPT
    # Finally, specify host to tunnel to
    # & at the end of the line makes the command run in the background
    echo "$USER@$JOB_SUBMISSION_HOST &" >> $BATCH_SCRIPT
fi
echo "srun python $python_file $yaml_file" >> $BATCH_SCRIPT
echo $BATCH_SCRIPT

# Submit the Slurm batch script and capture the job ID
JOB_ID=$(sbatch --export=ALL,JOB_SUBMISSION_HOST=$JOB_SUBMISSION_HOST,JOB_SUBMISSION_SSH_KEY="$submission_ssh_key" $BATCH_SCRIPT | awk '{print $4}')

# Print the job ID
echo "Submitted job with ID $JOB_ID"

# Track the progress of the job
while true; do
    # Get the job status
    JOB_STATUS=$(sacct -j $JOB_ID --format=State --noheader | head -1 | awk '{print $1}')

    if [[ $JOB_STATUS == *"COMPLETED"* ]]; then
        # If the job is completed, break the loop
        echo "Job $JOB_ID has completed"
        break
    elif [[ $JOB_STATUS == *"FAILED"* ]]; then
        # If the job has failed, print an error message and exit with a non-zero status
        echo "Job $JOB_ID has failed" >&2

        # Remove the temporary Slurm batch script
        rm $BATCH_SCRIPT

        exit 1
    else
        # If the job is neither completed nor failed, print its status
        echo "Job $JOB_ID is $JOB_STATUS"

Member Author (review comment):
Something that currently needs more work is retrieving the Slurm job logs while the job is running. We could do this by reading the slurm-{job_id}.out file, but I wonder if there are alternative ways to do this.
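
A minimal sketch of the slurm-{job_id}.out approach mentioned above, assuming Slurm's default output file name and that the caller knows the job id; this is illustrative only and not part of the PR:

# Illustrative sketch only: tail slurm-<job_id>.out and forward new lines to a
# logger until asked to stop. File name and polling interval are assumptions.
import asyncio
from pathlib import Path


async def stream_slurm_log(job_id, logger, stop_event, poll_seconds=5.0):
    log_file = Path(f"slurm-{job_id}.out")
    position = 0
    while not stop_event.is_set():
        if log_file.exists():
            with log_file.open() as f:
                # Only read text appended since the last poll
                f.seek(position)
                for line in f:
                    logger.info(line.rstrip())
                position = f.tell()
        await asyncio.sleep(poll_seconds)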

Member Author (review comment):
@Wiebke I think you already added this. If you did, can you add it to this PR please?

    fi
    sleep 10
done

# Remove the temporary Slurm batch script
rm $BATCH_SCRIPT
19 changes: 19 additions & 0 deletions flows/slurm/schema.py
@@ -0,0 +1,19 @@
from typing import List, Optional

from pydantic import BaseModel, Field


class SlurmParams(BaseModel):
    job_name: str
    num_nodes: int
    partitions: Optional[List[str]] = []
    reservations: Optional[List[str]] = []
    max_time: str = Field(pattern=r"^([0-1]?[0-9]|2[0-3]):[0-5][0-9]:[0-5][0-9]$")
    conda_env_name: str
    # TODO: Enforce port pair format
    forward_ports: Optional[List[str]] = []
    submission_ssh_key: Optional[str] = None
    python_file_name: str = Field(
        description="Python file to run", default="src/train.py"
    )
    params: Optional[dict] = {}
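
Regarding the TODO above about enforcing the port pair format, one possible approach (a sketch assuming pydantic v2 and entries of the form "<local>:<remote>"; not part of this PR) would be a field validator:

# Sketch of a pydantic v2 validator for forward_ports; the class name and the
# exact format are assumptions, not part of this PR.
import re
from typing import List, Optional

from pydantic import BaseModel, field_validator

PORT_PAIR_PATTERN = re.compile(r"^\d{1,5}:\d{1,5}$")


class PortForwardingParams(BaseModel):
    forward_ports: Optional[List[str]] = []

    @field_validator("forward_ports")
    @classmethod
    def check_port_pairs(cls, value):
        for pair in value or []:
            if not PORT_PAIR_PATTERN.match(pair):
                raise ValueError(f"Expected '<local_port>:<remote_port>', got '{pair}'")
        return value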
76 changes: 75 additions & 1 deletion flows/slurm/slurm_flows.py
@@ -1 +1,75 @@
# TODO
# TODO: Check pyslurm: https://github.com/PySlurm/pyslurm/tree/main
import sys
import tempfile

import yaml
from prefect import context, flow, get_run_logger
from prefect.states import Failed
from prefect.utilities.processutils import run_process

from flows.slurm.schema import SlurmParams


class Logger:

Member (review comment): Can this be refactored into something common across the flows? (A sketch of one option follows after setup_logger below.)

    def __init__(self, logger, level="info"):
        self.logger = getattr(logger, level)

    def write(self, message):
        if message != "\n":
            self.logger(message)

    def flush(self):
        pass


def setup_logger():
"""
Adopt stdout and stderr to prefect logger
"""
prefect_logger = get_run_logger()
sys.stdout = Logger(prefect_logger, level="info")
sys.stderr = Logger(prefect_logger, level="error")
return prefect_logger
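
One possible shape for the refactor suggested in the review comment above: a hypothetical shared module (for example flows/logger.py; the path and class name are assumptions) that the podman, conda, and slurm flows could all import instead of each redefining the Logger class.

# Hypothetical shared helper, e.g. flows/logger.py; illustration only.
import sys

from prefect import get_run_logger


class PrefectLogRedirect:
    """File-like object that forwards writes to a Prefect logger method."""

    def __init__(self, logger, level="info"):
        self.logger = getattr(logger, level)

    def write(self, message):
        if message != "\n":
            self.logger(message)

    def flush(self):
        pass


def setup_logger():
    """Redirect stdout and stderr to the Prefect run logger."""
    prefect_logger = get_run_logger()
    sys.stdout = PrefectLogRedirect(prefect_logger, level="info")
    sys.stderr = PrefectLogRedirect(prefect_logger, level="error")
    return prefect_logger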


@flow(name="launch_slurm")
async def launch_slurm(
    slurm_params: SlurmParams,
    prev_flow_run_id: str = None,
):
    logger = setup_logger()

    if prev_flow_run_id:
        # Append the previous flow run id to parameters if provided
        slurm_params.params["io_parameters"]["uid_retrieve"] = prev_flow_run_id

    current_flow_run_id = str(context.get_run_context().flow_run.id)

    # Append current flow run id
    slurm_params.params["io_parameters"]["uid_save"] = current_flow_run_id

    # Create temporary file for parameters
    with tempfile.NamedTemporaryFile(mode="w+t", dir=".") as temp_file:
        yaml.dump(slurm_params.params, temp_file)
        # Define slurm command
        cmd = [
            "flows/slurm/run_slurm.sh",
            slurm_params.job_name,
            str(slurm_params.num_nodes),
            ",".join(slurm_params.partitions),
            ",".join(slurm_params.reservations),
            slurm_params.max_time,
            slurm_params.conda_env_name,
            ",".join(slurm_params.forward_ports),
            slurm_params.submission_ssh_key,
            slurm_params.python_file_name,
            temp_file.name,
        ]
        logger.info(f"Launching with command: {cmd}")
        process = await run_process(cmd, stream_output=True)

        if process.returncode != 0:
            return Failed(message="Slurm command failed")

    return current_flow_run_id
14 changes: 13 additions & 1 deletion prefect.yaml
@@ -34,9 +34,21 @@ deployments:
    job_variables: {}
  schedule:
    is_schedule_active: true
- name: launch_parent_flow
- name: launch_slurm
  version: 0.0.1
  tags: []
  description: Launch slurm job
  entrypoint: flows/slurm/slurm_flows.py:launch_slurm
  parameters: {}
  work_pool:
    name: mlex_pool
    work_queue_name: default-queue
    job_variables: {}
  schedule:
    is_schedule_active: true
- name: launch_parent_flow
  version: 0.0.2
  tags: []
  description: Launch parent flow
  entrypoint: flows/parent_flow.py:launch_parent_flow
  parameters: {}