-
Notifications
You must be signed in to change notification settings - Fork 308
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Support PythonFunctionTask and reorganize agent structure
1. Add back `PythonFunctionTask` to support running user-defined functions on Slurm 2. Categorize task types into `script/` and `function/` Signed-off-by: JiaWei Jiang <[email protected]>
- Loading branch information
1 parent
e365dee
commit c1064d4
Showing
7 changed files
with
203 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,4 @@ | ||
from .agent import SlurmAgent | ||
from .task import Slurm, SlurmRemoteScript, SlurmShellTask, SlurmTask | ||
from .function.agent import SlurmFunctionAgent | ||
from .function.task import SlurmFunction, SlurmFunctionTask | ||
from .script.agent import SlurmScriptAgent | ||
from .script.task import Slurm, SlurmRemoteScript, SlurmShellTask, SlurmTask |
115 changes: 115 additions & 0 deletions
115
plugins/flytekit-slurm/flytekitplugins/slurm/function/agent.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
from dataclasses import dataclass | ||
from typing import Dict, Optional | ||
|
||
import asyncssh | ||
from asyncssh import SSHClientConnection | ||
|
||
from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta | ||
from flytekit.extend.backend.utils import convert_to_flyte_phase | ||
from flytekit.models.literals import LiteralMap | ||
from flytekit.models.task import TaskTemplate | ||
|
||
|
||
@dataclass | ||
class SlurmJobMetadata(ResourceMeta): | ||
"""Slurm job metadata. | ||
Args: | ||
job_id: Slurm job id. | ||
""" | ||
|
||
job_id: str | ||
slurm_host: str | ||
|
||
|
||
class SlurmFunctionAgent(AsyncAgentBase): | ||
name = "Slurm Function Agent" | ||
|
||
# SSH connection pool for multi-host environment | ||
_conn: Optional[SSHClientConnection] = None | ||
|
||
def __init__(self) -> None: | ||
super(SlurmFunctionAgent, self).__init__(task_type_name="slurm_fn", metadata_type=SlurmJobMetadata) | ||
|
||
async def create( | ||
self, | ||
task_template: TaskTemplate, | ||
inputs: Optional[LiteralMap] = None, | ||
**kwargs, | ||
) -> SlurmJobMetadata: | ||
# Retrieve task config | ||
slurm_host = task_template.custom["slurm_host"] | ||
srun_conf = task_template.custom["srun_conf"] | ||
|
||
# Construct srun command for Slurm cluster | ||
cmd = _get_srun_cmd(srun_conf=srun_conf, entrypoint=" ".join(task_template.container.args)) | ||
|
||
# Run Slurm job | ||
if self._conn is None: | ||
await self._connect(slurm_host) | ||
res = await self._conn.run(cmd, check=True) | ||
|
||
# Direct return for sbatch | ||
# job_id = res.stdout.split()[-1] | ||
# Use echo trick for srun | ||
job_id = res.stdout.strip() | ||
|
||
return SlurmJobMetadata(job_id=job_id, slurm_host=slurm_host) | ||
|
||
async def get(self, resource_meta: SlurmJobMetadata, **kwargs) -> Resource: | ||
await self._connect(resource_meta.slurm_host) | ||
res = await self._conn.run(f"scontrol show job {resource_meta.job_id}", check=True) | ||
|
||
# Determine the current flyte phase from Slurm job state | ||
job_state = "running" | ||
for o in res.stdout.split(" "): | ||
if "JobState" in o: | ||
job_state = o.split("=")[1].strip().lower() | ||
cur_phase = convert_to_flyte_phase(job_state) | ||
|
||
return Resource(phase=cur_phase) | ||
|
||
async def delete(self, resource_meta: SlurmJobMetadata, **kwargs) -> None: | ||
await self._connect(resource_meta.slurm_host) | ||
_ = await self._conn.run(f"scancel {resource_meta.job_id}", check=True) | ||
|
||
async def _connect(self, slurm_host: str) -> None: | ||
"""Make an SSH client connection.""" | ||
self._conn = await asyncssh.connect(host=slurm_host) | ||
|
||
|
||
def _get_srun_cmd(srun_conf: Dict[str, str], entrypoint: str) -> str: | ||
"""Construct Slurm srun command. | ||
Flyte entrypoint, pyflyte-execute, is run within a bash shell process. | ||
Args: | ||
srun_conf: Options of srun command. | ||
entrypoint: Flyte entrypoint. | ||
Returns: | ||
cmd: Slurm srun command. | ||
""" | ||
# Setup srun options | ||
cmd = ["srun"] | ||
for opt, val in srun_conf.items(): | ||
cmd.extend([f"--{opt}", str(val)]) | ||
|
||
cmd.extend(["bash", "-c"]) | ||
cmd = " ".join(cmd) | ||
|
||
cmd += f""" '# Setup environment variables | ||
export PATH=$PATH:/opt/anaconda/anaconda3/bin; | ||
# Run pyflyte-execute in a pre-built conda env | ||
source activate dev; | ||
{entrypoint}; | ||
# A trick to show Slurm job id on stdout | ||
echo $SLURM_JOB_ID;' | ||
""" | ||
|
||
return cmd | ||
|
||
|
||
AgentRegistry.register(SlurmFunctionAgent()) |
72 changes: 72 additions & 0 deletions
72
plugins/flytekit-slurm/flytekitplugins/slurm/function/task.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
""" | ||
Slurm task. | ||
""" | ||
|
||
from dataclasses import dataclass | ||
from typing import Any, Callable, Dict, Optional, Union | ||
|
||
from flytekit import FlyteContextManager, PythonFunctionTask | ||
from flytekit.configuration import SerializationSettings | ||
from flytekit.extend import TaskPlugins | ||
from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin | ||
from flytekit.image_spec import ImageSpec | ||
|
||
|
||
@dataclass | ||
class SlurmFunction(object): | ||
"""Configure Slurm settings. Note that we focus on srun command now. | ||
Compared with spark, please refer to https://api-docs.databricks.com/python/pyspark/latest/api/pyspark.SparkContext.html. | ||
Args: | ||
slurm_host: Slurm host name. We assume there's no default Slurm host now. | ||
srun_conf: Options of srun command. | ||
""" | ||
|
||
slurm_host: str | ||
srun_conf: Optional[Dict[str, str]] = None | ||
|
||
def __post_init__(self): | ||
if self.srun_conf is None: | ||
self.srun_conf = {} | ||
|
||
|
||
class SlurmFunctionTask(AsyncAgentExecutorMixin, PythonFunctionTask[SlurmFunction]): | ||
""" | ||
Actual Plugin that transforms the local python code for execution within a slurm context... | ||
""" | ||
|
||
_TASK_TYPE = "slurm_fn" | ||
|
||
def __init__( | ||
self, | ||
task_config: SlurmFunction, | ||
task_function: Callable, | ||
container_image: Optional[Union[str, ImageSpec]] = None, | ||
**kwargs, | ||
): | ||
super(SlurmFunctionTask, self).__init__( | ||
task_config=task_config, | ||
task_type=self._TASK_TYPE, | ||
task_function=task_function, | ||
container_image=container_image, | ||
**kwargs, | ||
) | ||
|
||
def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: | ||
return { | ||
"slurm_host": self.task_config.slurm_host, | ||
"srun_conf": self.task_config.srun_conf, | ||
} | ||
|
||
def execute(self, **kwargs) -> Any: | ||
ctx = FlyteContextManager.current_context() | ||
if ctx.execution_state and ctx.execution_state.is_local_execution(): | ||
# Mimic the propeller's behavior in local agent test | ||
return AsyncAgentExecutorMixin.execute(self, **kwargs) | ||
else: | ||
# Execute the task with a direct python function call | ||
return PythonFunctionTask.execute(self, **kwargs) | ||
|
||
|
||
TaskPlugins.register_pythontask_plugin(SlurmFunction, SlurmFunctionTask) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,9 +13,13 @@ | |
version=__version__, | ||
author="flyteorg", | ||
author_email="[email protected]", | ||
description="Tiny slurm plugin for flytekit", | ||
description="This package holds the Slurm plugins for flytekit", | ||
namespace_packages=["flytekitplugins"], | ||
packages=[f"flytekitplugins.{PLUGIN_NAME}"], | ||
packages=[ | ||
f"flytekitplugins.{PLUGIN_NAME}", | ||
f"flytekitplugins.{PLUGIN_NAME}.script", | ||
f"flytekitplugins.{PLUGIN_NAME}.function", | ||
], | ||
install_requires=plugin_requires, | ||
license="apache2", | ||
python_requires=">=3.9", | ||
|