Skip to content

Commit

Permalink
feat: Support PythonFunctionTask and reorganize agent structure
Browse files Browse the repository at this point in the history
1. Add back `PythonFunctionTask` to support running user-defined functions
    on Slurm
2. Categorize task types into `script/` and `function/`

Signed-off-by: JiaWei Jiang <[email protected]>
  • Loading branch information
JiangJiaWei1103 committed Jan 16, 2025
1 parent e365dee commit c1064d4
Show file tree
Hide file tree
Showing 7 changed files with 203 additions and 11 deletions.
5 changes: 2 additions & 3 deletions flytekit/extras/tasks/shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,8 +252,7 @@ def __init__(
fully_qualified_class_name = task_config.__module__ + "." + task_config.__class__.__name__
if fully_qualified_class_name not in [
"flytekitplugins.pod.task.Pod",
"flytekitplugins.slurm.task.Slurm",
"flytekitplugins.slurm.task.SlurmShell",
"flytekitplugins.slurm.script.task.Slurm",
]:
raise ValueError("TaskConfig can either be empty - indicating simple container task or a PodConfig.")

Expand All @@ -263,7 +262,7 @@ def __init__(
# errors.
# This seems like a hack. We should use a plugin_class that doesn't require a fake function to work.
plugin_class = TaskPlugins.find_pythontask_plugin(type(task_config))
if plugin_class.__name__ in ["SlurmTask", "SlurmShellTask"]:
if plugin_class.__name__ in ["SlurmShellTask"]:
self._config_task_instance = None
else:
self._config_task_instance = plugin_class(task_config=task_config, task_function=_dummy_task_func)
Expand Down
6 changes: 4 additions & 2 deletions plugins/flytekit-slurm/flytekitplugins/slurm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
from .agent import SlurmAgent
from .task import Slurm, SlurmRemoteScript, SlurmShellTask, SlurmTask
from .function.agent import SlurmFunctionAgent
from .function.task import SlurmFunction, SlurmFunctionTask
from .script.agent import SlurmScriptAgent
from .script.task import Slurm, SlurmRemoteScript, SlurmShellTask, SlurmTask
115 changes: 115 additions & 0 deletions plugins/flytekit-slurm/flytekitplugins/slurm/function/agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
from dataclasses import dataclass
from typing import Dict, Optional

import asyncssh
from asyncssh import SSHClientConnection

from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta
from flytekit.extend.backend.utils import convert_to_flyte_phase
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate


@dataclass
class SlurmJobMetadata(ResourceMeta):
    """Slurm job metadata.

    Carries what the agent needs to look a submitted job up again: the
    ``get`` and ``delete`` handlers read both fields from this object.

    Args:
        job_id: Slurm job id.
        slurm_host: Slurm host name the job was submitted to.
    """

    job_id: str
    slurm_host: str


class SlurmFunctionAgent(AsyncAgentBase):
    """Agent that runs Flyte function tasks on a Slurm cluster over SSH.

    ``create`` submits the task entrypoint through ``srun``, ``get`` polls the
    job with ``scontrol show job`` and ``delete`` cancels it with ``scancel``.
    """

    name = "Slurm Function Agent"

    # Most recently opened SSH connection and the host it was opened to.
    # Only a single connection is cached, so the host is tracked alongside it
    # to avoid sending commands for one cluster over a connection to another.
    _conn: Optional[SSHClientConnection] = None
    _conn_host: Optional[str] = None

    def __init__(self) -> None:
        super(SlurmFunctionAgent, self).__init__(task_type_name="slurm_fn", metadata_type=SlurmJobMetadata)

    async def create(
        self,
        task_template: TaskTemplate,
        inputs: Optional[LiteralMap] = None,
        **kwargs,
    ) -> SlurmJobMetadata:
        """Submit the task to Slurm through ``srun`` and return its metadata.

        Args:
            task_template: Task template whose ``custom`` dict provides
                ``slurm_host`` and ``srun_conf`` and whose container args form
                the entrypoint command.
            inputs: Task inputs; not used by this method.
            **kwargs: Extra keyword arguments; not used by this method.

        Returns:
            Metadata holding the Slurm job id and the host it was submitted to.
        """
        # Retrieve task config
        slurm_host = task_template.custom["slurm_host"]
        srun_conf = task_template.custom["srun_conf"]

        # Construct srun command for Slurm cluster
        cmd = _get_srun_cmd(srun_conf=srun_conf, entrypoint=" ".join(task_template.container.args))

        # Reuse the cached connection only when it targets the requested host;
        # otherwise the job would be submitted to whichever cluster happened
        # to be contacted last.
        if self._conn is None or self._conn_host != slurm_host:
            await self._connect(slurm_host)
        res = await self._conn.run(cmd, check=True)

        # The srun script echoes $SLURM_JOB_ID as its final command, so the job
        # id is the last line of stdout; anything the entrypoint itself printed
        # comes before it and must not leak into the id.
        output_lines = res.stdout.strip().splitlines()
        job_id = output_lines[-1].strip() if output_lines else ""

        return SlurmJobMetadata(job_id=job_id, slurm_host=slurm_host)

    async def get(self, resource_meta: SlurmJobMetadata, **kwargs) -> Resource:
        """Query the Slurm job state and map it to a Flyte phase.

        Args:
            resource_meta: Metadata returned by ``create``.
            **kwargs: Extra keyword arguments; not used by this method.

        Returns:
            A ``Resource`` carrying the phase derived from the job state.
        """
        await self._connect(resource_meta.slurm_host)
        res = await self._conn.run(f"scontrol show job {resource_meta.job_id}", check=True)

        # Determine the current flyte phase from Slurm job state. The state
        # falls back to "running" when no JobState field is present.
        job_state = "running"
        for field in res.stdout.split():
            if field.startswith("JobState="):
                job_state = field.partition("=")[2].strip().lower()
        cur_phase = convert_to_flyte_phase(job_state)

        return Resource(phase=cur_phase)

    async def delete(self, resource_meta: SlurmJobMetadata, **kwargs) -> None:
        """Cancel the Slurm job with ``scancel``.

        Args:
            resource_meta: Metadata returned by ``create``.
            **kwargs: Extra keyword arguments; not used by this method.
        """
        await self._connect(resource_meta.slurm_host)
        _ = await self._conn.run(f"scancel {resource_meta.job_id}", check=True)

    async def _connect(self, slurm_host: str) -> None:
        """Make an SSH client connection and remember which host it targets."""
        self._conn = await asyncssh.connect(host=slurm_host)
        # Recorded only after a successful connect so a failure never leaves a
        # host name paired with a connection to a different host.
        self._conn_host = slurm_host


def _get_srun_cmd(srun_conf: Dict[str, str], entrypoint: str) -> str:
"""Construct Slurm srun command.
Flyte entrypoint, pyflyte-execute, is run within a bash shell process.
Args:
srun_conf: Options of srun command.
entrypoint: Flyte entrypoint.
Returns:
cmd: Slurm srun command.
"""
# Setup srun options
cmd = ["srun"]
for opt, val in srun_conf.items():
cmd.extend([f"--{opt}", str(val)])

cmd.extend(["bash", "-c"])
cmd = " ".join(cmd)

cmd += f""" '# Setup environment variables
export PATH=$PATH:/opt/anaconda/anaconda3/bin;
# Run pyflyte-execute in a pre-built conda env
source activate dev;
{entrypoint};
# A trick to show Slurm job id on stdout
echo $SLURM_JOB_ID;'
"""

return cmd


# Instantiate the agent and register it with AgentRegistry when this module is imported.
AgentRegistry.register(SlurmFunctionAgent())
72 changes: 72 additions & 0 deletions plugins/flytekit-slurm/flytekitplugins/slurm/function/task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""
Slurm task.
"""

from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Union

from flytekit import FlyteContextManager, PythonFunctionTask
from flytekit.configuration import SerializationSettings
from flytekit.extend import TaskPlugins
from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin
from flytekit.image_spec import ImageSpec


@dataclass
class SlurmFunction:
    """Task config describing where and how to run a function on Slurm.

    Only the ``srun`` command is covered for now. For a comparable
    configuration object in Spark, see
    https://api-docs.databricks.com/python/pyspark/latest/api/pyspark.SparkContext.html.

    Args:
        slurm_host: Slurm host name. There is no default host, so it must be given.
        srun_conf: Options of srun command. ``None`` is replaced by an empty dict.
    """

    slurm_host: str
    srun_conf: Optional[Dict[str, str]] = None

    def __post_init__(self):
        # Normalize a missing config to an empty mapping; a dict supplied by
        # the caller is kept as the very same object.
        if self.srun_conf is not None:
            return
        self.srun_conf = {}


class SlurmFunctionTask(AsyncAgentExecutorMixin, PythonFunctionTask[SlurmFunction]):
    """Python function task plugin that targets a Slurm cluster.

    The task is registered under the ``slurm_fn`` task type and serializes the
    Slurm host and srun options from its ``SlurmFunction`` config.
    """

    _TASK_TYPE = "slurm_fn"

    def __init__(
        self,
        task_config: SlurmFunction,
        task_function: Callable,
        container_image: Optional[Union[str, ImageSpec]] = None,
        **kwargs,
    ):
        """Create the task.

        Args:
            task_config: Slurm settings for this task.
            task_function: User-defined function to run.
            container_image: Optional image name or image spec forwarded to the base task.
            **kwargs: Remaining keyword arguments forwarded to the base task.
        """
        super().__init__(
            task_config=task_config,
            task_function=task_function,
            task_type=self._TASK_TYPE,
            container_image=container_image,
            **kwargs,
        )

    def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
        """Return the Slurm host and srun options stored in the task config."""
        config = self.task_config
        custom: Dict[str, Any] = {"slurm_host": config.slurm_host, "srun_conf": config.srun_conf}
        return custom

    def execute(self, **kwargs) -> Any:
        """Run through the agent for local executions, otherwise call the function directly."""
        context = FlyteContextManager.current_context()
        if not (context.execution_state and context.execution_state.is_local_execution()):
            # Execute the task with a direct python function call
            return PythonFunctionTask.execute(self, **kwargs)

        # Mimic the propeller's behavior in local agent test
        return AsyncAgentExecutorMixin.execute(self, **kwargs)


# Register SlurmFunctionTask as the task plugin for the SlurmFunction config type when this module is imported.
TaskPlugins.register_pythontask_plugin(SlurmFunction, SlurmFunctionTask)
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ class SlurmJobMetadata(ResourceMeta):
slurm_host: str


class SlurmAgent(AsyncAgentBase):
name = "Slurm Agent"
class SlurmScriptAgent(AsyncAgentBase):
name = "Slurm Script Agent"

# SSH connection pool for multi-host environment
# _ssh_clients: Dict[str, SSHClientConnection]
Expand All @@ -37,7 +37,7 @@ class SlurmAgent(AsyncAgentBase):
DUMMY_SCRIPT = "#!/bin/bash"

def __init__(self) -> None:
super(SlurmAgent, self).__init__(task_type_name="slurm", metadata_type=SlurmJobMetadata)
super(SlurmScriptAgent, self).__init__(task_type_name="slurm", metadata_type=SlurmJobMetadata)

async def create(
self,
Expand Down Expand Up @@ -133,4 +133,4 @@ def _get_sbatch_cmd(sbatch_conf: Dict[str, str], batch_script_path: str, batch_s
return cmd


AgentRegistry.register(SlurmAgent())
AgentRegistry.register(SlurmScriptAgent())
8 changes: 6 additions & 2 deletions plugins/flytekit-slurm/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,13 @@
version=__version__,
author="flyteorg",
author_email="[email protected]",
description="Tiny slurm plugin for flytekit",
description="This package holds the Slurm plugins for flytekit",
namespace_packages=["flytekitplugins"],
packages=[f"flytekitplugins.{PLUGIN_NAME}"],
packages=[
f"flytekitplugins.{PLUGIN_NAME}",
f"flytekitplugins.{PLUGIN_NAME}.script",
f"flytekitplugins.{PLUGIN_NAME}.function",
],
install_requires=plugin_requires,
license="apache2",
python_requires=">=3.9",
Expand Down

0 comments on commit c1064d4

Please sign in to comment.