-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding Slurm jobs #20
base: main
Are you sure you want to change the base?
Changes from all commits
60d3033
a2cd6a1
b788a7e
a582f8a
045b35c
147874c
e48334a
0a5c8ff
fe2be00
66101a7
dcbc4fd
02efa7d
4a9362d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
#!/bin/bash
#
# Build a temporary Slurm batch script from the command-line arguments,
# submit it with sbatch, and poll sacct every 10 seconds until the job
# reaches a terminal state.
#
# Arguments (all ten are required; pass "" to leave an optional one unset):
#   job_name, num_nodes, partitions, reservations, max_time, conda_env,
#   forward_ports (comma-separated local:remote pairs), submission_ssh_key,
#   python_file, yaml_file
#
# Exit status: 0 when the job COMPLETED; 1 on a usage error, a failed
# submission, or when the job ends in a failure state.

# Check if all arguments are provided
if [ $# -ne 10 ]; then
    echo "Usage: $0 <job_name> <num_nodes> <partitions> <reservations> <max_time> <conda_env> <forward_ports> <submission_ssh_key> <python_file> <yaml_file>" >&2
    # Stop here: continuing with missing arguments would submit a broken job.
    exit 1
fi

# Assign arguments to variables
job_name=$1
num_nodes=$2
partitions=$3
reservations=$4
max_time=$5
conda_env=$6
forward_ports=$7
submission_ssh_key=$8
python_file=$9
yaml_file=${10}

# Create a temporary Slurm batch script
BATCH_SCRIPT=$(mktemp ./tmp_script.XXXXXX) || exit 1

# Remove the temporary Slurm batch script on every exit path
trap 'rm -f "$BATCH_SCRIPT"' EXIT

# Write the Slurm batch script
echo '#!/bin/bash' > "$BATCH_SCRIPT"

if [ -z "$partitions" ]; then
    echo "Partitions is unset or empty"
else
    echo "#SBATCH --partition=$partitions" >> "$BATCH_SCRIPT"
fi

if [ -z "$reservations" ]; then
    echo "Reservations is unset or empty"
else
    echo "#SBATCH --reservation=$reservations" >> "$BATCH_SCRIPT"
fi

JOB_SUBMISSION_HOST=$(hostname)

{
    echo "#SBATCH --nodes=$num_nodes"
    echo "#SBATCH --job-name=$job_name"
    echo "#SBATCH --time=$max_time"
    echo "unset LD_PRELOAD"
    echo "source /etc/profile.d/modules.sh"
    echo "module load maxwell mamba"
    echo ". mamba-init"
    echo "mamba activate $conda_env"
} >> "$BATCH_SCRIPT"

if [ -n "$forward_ports" ]; then
    ssh_cmd="ssh "
    # If an ssh key was specified, pass it as an additional argument
    if [ -n "$submission_ssh_key" ]; then
        ssh_cmd+="-i $submission_ssh_key "
    fi
    # Loop over forward_ports and append each forwarding rule
    IFS=',' read -ra ADDR <<< "$forward_ports"
    for port_pair in "${ADDR[@]}"; do
        echo "$port_pair"
        local_port=${port_pair%:*}
        remote_port=${port_pair#*:}
        ssh_cmd+="-L $local_port:localhost:$remote_port "
    done
    # -N tells SSH to not execute a remote command
    ssh_cmd+="-N "
    # Finally, specify host to tunnel to
    # & at the end of the line makes the command run in the background
    echo "${ssh_cmd}${USER}@${JOB_SUBMISSION_HOST} &" >> "$BATCH_SCRIPT"
fi
echo "srun python $python_file $yaml_file" >> "$BATCH_SCRIPT"
echo "$BATCH_SCRIPT"

# Submit the Slurm batch script and capture the job ID
# (sbatch prints "Submitted batch job <id>", so the ID is the fourth field)
JOB_ID=$(sbatch --export=ALL,JOB_SUBMISSION_HOST="$JOB_SUBMISSION_HOST",JOB_SUBMISSION_SSH_KEY="$submission_ssh_key" "$BATCH_SCRIPT" | awk '{print $4}')

# An empty ID means the submission failed; polling would never terminate.
if [ -z "$JOB_ID" ]; then
    echo "Failed to submit Slurm job" >&2
    exit 1
fi

# Print the job ID
echo "Submitted job with ID $JOB_ID"

# Track the progress of the job
while true; do
    # Get the job status
    JOB_STATUS=$(sacct -j "$JOB_ID" --format=State --noheader | head -1 | awk '{print $1}')

    case "$JOB_STATUS" in
        *COMPLETED*)
            # If the job is completed, break the loop
            echo "Job $JOB_ID has completed"
            break
            ;;
        *FAILED*)
            # If the job has failed, print an error message and exit with a non-zero status
            echo "Job $JOB_ID has failed" >&2
            exit 1
            ;;
        *BOOT_FAIL*|*CANCELLED*|*DEADLINE*|*NODE_FAIL*|*OUT_OF_MEMORY*|*TIMEOUT*)
            # Other terminal states: the job will not complete, so stop polling
            echo "Job $JOB_ID ended in state $JOB_STATUS" >&2
            exit 1
            ;;
        *)
            # If the job is neither completed nor failed, print its status
            echo "Job $JOB_ID is $JOB_STATUS"
            # TODO(review): retrieve the Slurm job logs while the job is running.
            ;;
    esac
    sleep 10
done
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from typing import List, Optional | ||
|
||
from pydantic import BaseModel, Field | ||
|
||
|
||
class SlurmParams(BaseModel):
    """Parameters describing a single Slurm job submission."""

    # Job identification and sizing.
    job_name: str
    num_nodes: int

    # Scheduling constraints; an empty list leaves the option unset.
    partitions: Optional[List[str]] = []
    reservations: Optional[List[str]] = []

    # Wall-clock limit, validated against the H:MM:SS / HH:MM:SS pattern below.
    max_time: str = Field(pattern=r"^([0-1]?[0-9]|2[0-3]):[0-5][0-9]:[0-5][0-9]$")

    # Name of the conda environment the job runs in.
    conda_env_name: str

    # TODO: Enforce port pair format
    forward_ports: Optional[List[str]] = []
    submission_ssh_key: Optional[str] = None

    python_file_name: str = Field(
        default="src/train.py",
        description="Python file to run",
    )

    # Free-form parameters for the job.
    params: Optional[dict] = {}
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,75 @@ | ||
# TODO | ||
# TODO: Check pyslurm: https://github.com/PySlurm/pyslurm/tree/main | ||
import sys | ||
import tempfile | ||
|
||
import yaml | ||
from prefect import context, flow, get_run_logger | ||
from prefect.states import Failed | ||
from prefect.utilities.processutils import run_process | ||
|
||
from flows.slurm.schema import SlurmParams | ||
|
||
|
||
class Logger:
    """Minimal file-like adapter that forwards written text to a logger.

    An instance can stand in for ``sys.stdout`` / ``sys.stderr``: every
    non-blank chunk passed to :meth:`write` is emitted through the logger
    method selected by ``level``.
    """

    # TODO(review): Can this be refactored into something common across the flows?

    def __init__(self, logger, level="info"):
        """Bind to the method named ``level`` (e.g. ``"info"``) on ``logger``."""
        self.logger = getattr(logger, level)

    def write(self, message):
        """Log ``message`` unless it is blank; return the characters consumed.

        ``print`` issues separate writes for separators and the trailing
        newline, so whitespace-only chunks are dropped instead of producing
        empty log records. The length is returned, as text streams do.
        """
        if message.strip():
            self.logger(message)
        return len(message)

    def flush(self):
        """Do nothing; messages are handed to the logger immediately."""
        pass
|
||
|
||
def setup_logger():
    """
    Route stdout and stderr through the Prefect run logger.

    stdout is forwarded at "info" level and stderr at "error" level.
    Returns the Prefect run logger itself.
    """
    run_logger = get_run_logger()
    for stream_name, level in (("stdout", "info"), ("stderr", "error")):
        setattr(sys, stream_name, Logger(run_logger, level=level))
    return run_logger
|
||
|
||
@flow(name="launch_slurm")
async def launch_slurm(
    slurm_params: SlurmParams,
    prev_flow_run_id: str = None,
):
    """Submit a Slurm job through ``flows/slurm/run_slurm.sh`` and wait for it.

    The job parameters are written to a temporary YAML file in the current
    directory, and its path is passed to the launcher script as the last
    argument.

    Args:
        slurm_params: Job description; ``params`` is dumped to the YAML file.
        prev_flow_run_id: If given, stored under
            ``params["io_parameters"]["uid_retrieve"]``.

    Returns:
        The current flow run id as a string, or a ``Failed`` state when the
        launcher script exits with a non-zero return code.
    """
    logger = setup_logger()

    # ``params`` is Optional and defaults to an empty dict, so make sure the
    # nested ``io_parameters`` mapping exists before writing into it.
    if slurm_params.params is None:
        slurm_params.params = {}
    io_parameters = slurm_params.params.setdefault("io_parameters", {})

    if prev_flow_run_id:
        # Append the previous flow run id to parameters if provided
        io_parameters["uid_retrieve"] = prev_flow_run_id

    current_flow_run_id = str(context.get_run_context().flow_run.id)

    # Append current flow run id
    io_parameters["uid_save"] = current_flow_run_id

    # Create temporary file for parameters
    with tempfile.NamedTemporaryFile(mode="w+t", dir=".") as temp_file:
        yaml.dump(slurm_params.params, temp_file)
        # Ensure the YAML is on disk before the child process reads it.
        temp_file.flush()

        # Every argv entry must be a string: optional fields that are None
        # are passed as "" so the script still receives ten arguments.
        cmd = [
            "flows/slurm/run_slurm.sh",
            slurm_params.job_name,
            str(slurm_params.num_nodes),
            ",".join(slurm_params.partitions or []),
            ",".join(slurm_params.reservations or []),
            slurm_params.max_time,
            slurm_params.conda_env_name,
            ",".join(slurm_params.forward_ports or []),
            slurm_params.submission_ssh_key or "",
            slurm_params.python_file_name,
            temp_file.name,
        ]
        logger.info(f"Launching with command: {cmd}")
        # Run inside the ``with`` block: the file is deleted when it closes.
        process = await run_process(cmd, stream_output=True)

    if process.returncode != 0:
        return Failed(message="Slurm command failed")

    return current_flow_run_id
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.