Add request and limit to ray config (#3087)
* Add request and limit to ray config

Signed-off-by: Kevin Su <[email protected]>

* lint

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

* update flytekit version

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

---------

Signed-off-by: Kevin Su <[email protected]>
Signed-off-by: Eduardo Apolinario <[email protected]>
Co-authored-by: Eduardo Apolinario <[email protected]>
pingsutw and eapolinario authored Jan 27, 2025
1 parent f0ba47f commit 4208a64
Showing 6 changed files with 109 additions and 32 deletions.
12 changes: 6 additions & 6 deletions flytekit/core/resources.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass, fields
from typing import Any, List, Optional, Union
from typing import List, Optional, Union

from kubernetes.client import V1Container, V1PodSpec, V1ResourceRequirements
from mashumaro.mixins.json import DataClassJSONMixin
@@ -103,11 +103,11 @@ def convert_resources_to_resource_model(


def pod_spec_from_resources(
k8s_pod_name: str,
primary_container_name: Optional[str] = None,
requests: Optional[Resources] = None,
limits: Optional[Resources] = None,
k8s_gpu_resource_key: str = "nvidia.com/gpu",
) -> dict[str, Any]:
) -> V1PodSpec:
def _construct_k8s_pods_resources(resources: Optional[Resources], k8s_gpu_resource_key: str):
if resources is None:
return None
@@ -133,10 +133,10 @@ def _construct_k8s_pods_resources(resources: Optional[Resources], k8s_gpu_resour
requests = requests or limits
limits = limits or requests

k8s_pod = V1PodSpec(
pod_spec = V1PodSpec(
containers=[
V1Container(
name=k8s_pod_name,
name=primary_container_name,
resources=V1ResourceRequirements(
requests=requests,
limits=limits,
@@ -145,4 +145,4 @@ def _construct_k8s_pods_resources(resources: Optional[Resources], k8s_gpu_resour
]
)

return k8s_pod.to_dict()
return pod_spec
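
With this change, pod_spec_from_resources takes a primary_container_name argument in place of k8s_pod_name and returns a V1PodSpec object instead of a dict. A minimal usage sketch (not part of the diff; the container name and resource values are illustrative):

from flytekit import Resources
from flytekit.core.resources import pod_spec_from_resources

# Build a V1PodSpec whose single container carries the given requests/limits.
pod_spec = pod_spec_from_resources(
    primary_container_name="primary",
    requests=Resources(cpu="1", mem="1Gi"),
    limits=Resources(cpu="2", mem="2Gi"),
)
print(pod_spec.containers[0].resources.requests)  # {'cpu': '1', 'memory': '1Gi'}
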
20 changes: 20 additions & 0 deletions flytekit/models/task.py
@@ -8,6 +8,7 @@
from flyteidl.core import tasks_pb2 as _core_task
from google.protobuf import json_format as _json_format
from google.protobuf import struct_pb2 as _struct
from kubernetes.client import ApiClient

from flytekit.models import common as _common
from flytekit.models import interface as _interface
@@ -16,6 +17,9 @@
from flytekit.models.core import identifier as _identifier
from flytekit.models.documentation import Documentation

if typing.TYPE_CHECKING:
from flytekit import PodTemplate


class Resources(_common.FlyteIdlEntity):
class ResourceName(object):
@@ -1042,6 +1046,22 @@ def from_flyte_idl(cls, pb2_object: _core_task.K8sPod):
else None,
)

def to_pod_template(self) -> "PodTemplate":
from flytekit import PodTemplate

return PodTemplate(
labels=self.metadata.labels,
annotations=self.metadata.annotations,
pod_spec=self.pod_spec,
)

@classmethod
def from_pod_template(cls, pod_template: "PodTemplate") -> "K8sPod":
return cls(
metadata=K8sObjectMetadata(labels=pod_template.labels, annotations=pod_template.annotations),
pod_spec=ApiClient().sanitize_for_serialization(pod_template.pod_spec),
)


class Sql(_common.FlyteIdlEntity):
class Dialect(object):
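
The new to_pod_template/from_pod_template helpers convert between K8sPod and PodTemplate. A rough round-trip sketch (labels, annotations, and resource values are illustrative):

from flytekit import PodTemplate, Resources
from flytekit.core.resources import pod_spec_from_resources
from flytekit.models.task import K8sPod

pod_template = PodTemplate(
    labels={"team": "ml"},
    annotations={"owner": "ray"},
    pod_spec=pod_spec_from_resources(
        primary_container_name="ray-head",
        requests=Resources(cpu="1", mem="1Gi"),
    ),
)

# PodTemplate -> K8sPod: the V1PodSpec is sanitized into a plain dict.
k8s_pod = K8sPod.from_pod_template(pod_template)

# K8sPod -> PodTemplate: labels and annotations come back from the metadata;
# note the restored pod_spec stays a sanitized dict, not a V1PodSpec object.
restored = k8s_pod.to_pod_template()
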
65 changes: 53 additions & 12 deletions plugins/flytekit-ray/flytekitplugins/ray/task.py
@@ -14,20 +14,30 @@
)
from google.protobuf.json_format import MessageToDict

from flytekit import lazy_module
from flytekit import PodTemplate, Resources, lazy_module
from flytekit.configuration import SerializationSettings
from flytekit.core.context_manager import ExecutionParameters, FlyteContextManager
from flytekit.core.python_function_task import PythonFunctionTask
from flytekit.core.resources import pod_spec_from_resources
from flytekit.extend import TaskPlugins
from flytekit.models.task import K8sPod

ray = lazy_module("ray")
_RAY_HEAD_CONTAINER_NAME = "ray-head"
_RAY_WORKER_CONTAINER_NAME = "ray-worker"


@dataclass
class HeadNodeConfig:
ray_start_params: typing.Optional[typing.Dict[str, str]] = None
k8s_pod: typing.Optional[K8sPod] = None
pod_template: typing.Optional[PodTemplate] = None
requests: Optional[Resources] = None
limits: Optional[Resources] = None

def __post_init__(self):
if self.pod_template:
if self.requests and self.limits:
raise ValueError("Cannot specify both pod_template and requests/limits")


@dataclass
@@ -37,7 +47,14 @@ class WorkerNodeConfig:
min_replicas: typing.Optional[int] = None
max_replicas: typing.Optional[int] = None
ray_start_params: typing.Optional[typing.Dict[str, str]] = None
k8s_pod: typing.Optional[K8sPod] = None
pod_template: typing.Optional[PodTemplate] = None
requests: Optional[Resources] = None
limits: Optional[Resources] = None

def __post_init__(self):
if self.pod_template:
if self.requests and self.limits:
raise ValueError("Cannot specify both pod_template and requests/limits")
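
With the new requests/limits fields, per-node resources can be declared directly on the config instead of hand-building a K8sPod; combining pod_template with both requests and limits raises a ValueError. A hedged sketch of how a Ray task might be configured under these changes (group name, replica count, and resource values are illustrative):

from flytekit import Resources, task
from flytekitplugins.ray import HeadNodeConfig
from flytekitplugins.ray.task import RayJobConfig, WorkerNodeConfig

ray_config = RayJobConfig(
    head_node_config=HeadNodeConfig(
        requests=Resources(cpu="1", mem="2Gi"),
        limits=Resources(cpu="2", mem="4Gi"),
    ),
    worker_node_config=[
        WorkerNodeConfig(
            group_name="workers",  # illustrative group name
            replicas=2,
            requests=Resources(cpu="2", mem="4Gi"),
        )
    ],
)

@task(task_config=ray_config)
def train() -> None:
    ...
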


@dataclass
@@ -83,25 +100,49 @@ def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:

def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]]:
cfg = self._task_config

# Deprecated: runtime_env is removed KubeRay >= 1.1.0. It is replaced by runtime_env_yaml
runtime_env = base64.b64encode(json.dumps(cfg.runtime_env).encode()).decode() if cfg.runtime_env else None

runtime_env_yaml = yaml.dump(cfg.runtime_env) if cfg.runtime_env else None

if cfg.head_node_config.requests or cfg.head_node_config.limits:
head_pod_template = PodTemplate(
pod_spec=pod_spec_from_resources(
primary_container_name=_RAY_HEAD_CONTAINER_NAME,
requests=cfg.head_node_config.requests,
limits=cfg.head_node_config.limits,
)
)
else:
head_pod_template = cfg.head_node_config.pod_template

worker_group_spec: typing.List[WorkerGroupSpec] = []
for c in cfg.worker_node_config:
if c.requests or c.limits:
worker_pod_template = PodTemplate(
pod_spec=pod_spec_from_resources(
primary_container_name=_RAY_WORKER_CONTAINER_NAME,
requests=c.requests,
limits=c.limits,
)
)
else:
worker_pod_template = c.pod_template
k8s_pod = K8sPod.from_pod_template(worker_pod_template) if worker_pod_template else None
worker_group_spec.append(
WorkerGroupSpec(c.group_name, c.replicas, c.min_replicas, c.max_replicas, c.ray_start_params, k8s_pod)
)

ray_job = RayJob(
ray_cluster=RayCluster(
head_group_spec=(
HeadGroupSpec(cfg.head_node_config.ray_start_params, cfg.head_node_config.k8s_pod)
HeadGroupSpec(
cfg.head_node_config.ray_start_params,
K8sPod.from_pod_template(head_pod_template) if head_pod_template else None,
)
if cfg.head_node_config
else None
),
worker_group_spec=[
WorkerGroupSpec(
c.group_name, c.replicas, c.min_replicas, c.max_replicas, c.ray_start_params, c.k8s_pod
)
for c in cfg.worker_node_config
],
worker_group_spec=worker_group_spec,
enable_autoscaling=(cfg.enable_autoscaling if cfg.enable_autoscaling else False),
),
runtime_env=runtime_env,
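
For reference, when only requests/limits are supplied, the head group spec is built roughly along the chain below (a simplified sketch of the translation above; resource values are illustrative and ray_start_params is omitted):

from flytekit import PodTemplate, Resources
from flytekit.core.resources import pod_spec_from_resources
from flytekit.models.task import K8sPod
from flytekitplugins.ray.models import HeadGroupSpec

# requests/limits -> V1PodSpec -> PodTemplate -> K8sPod -> HeadGroupSpec
head_pod_template = PodTemplate(
    pod_spec=pod_spec_from_resources(
        primary_container_name="ray-head",
        requests=Resources(cpu="1", mem="1Gi"),
        limits=Resources(cpu="2", mem="2Gi"),
    )
)
head_group_spec = HeadGroupSpec(k8s_pod=K8sPod.from_pod_template(head_pod_template))
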
2 changes: 1 addition & 1 deletion plugins/flytekit-ray/setup.py
@@ -4,7 +4,7 @@

microlib_name = f"flytekitplugins-{PLUGIN_NAME}"

plugin_requires = ["ray[default]", "flytekit>=1.3.0b2,<2.0.0", "flyteidl>=1.13.6"]
plugin_requires = ["ray[default]", "flytekit>1.14.5", "flyteidl>=1.13.6"]

__version__ = "0.0.0+develop"

26 changes: 21 additions & 5 deletions plugins/flytekit-ray/tests/test_ray.py
@@ -3,6 +3,8 @@

import ray
import yaml

from flytekit.core.resources import pod_spec_from_resources
from flytekitplugins.ray import HeadNodeConfig
from flytekitplugins.ray.models import (
HeadGroupSpec,
@@ -13,21 +15,28 @@
from flytekitplugins.ray.task import RayJobConfig, WorkerNodeConfig
from google.protobuf.json_format import MessageToDict

from flytekit import PythonFunctionTask, task
from flytekit import PythonFunctionTask, task, PodTemplate, Resources
from flytekit.configuration import Image, ImageConfig, SerializationSettings
from flytekit.models.task import K8sPod


pod_template=PodTemplate(
primary_container_name="primary",
labels={"lKeyA": "lValA"},
annotations={"aKeyA": "aValA"},
)

config = RayJobConfig(
worker_node_config=[
WorkerNodeConfig(
group_name="test_group",
replicas=3,
min_replicas=0,
max_replicas=10,
k8s_pod=K8sPod(pod_spec={"str": "worker", "int": 1}),
pod_template=pod_template,
)
],
head_node_config=HeadNodeConfig(k8s_pod=K8sPod(pod_spec={"str": "head", "int": 2})),
head_node_config=HeadNodeConfig(requests=Resources(cpu="1", mem="1Gi"), limits=Resources(cpu="2", mem="2Gi")),
runtime_env={"pip": ["numpy"]},
enable_autoscaling=True,
shutdown_after_job_finishes=True,
@@ -55,6 +64,13 @@ def t1(a: int) -> str:
image_config=ImageConfig(default_image=default_img, images=[default_img]),
env={},
)
head_pod_template = PodTemplate(
pod_spec=pod_spec_from_resources(
primary_container_name="ray-head",
requests=Resources(cpu="1", mem="1Gi"),
limits=Resources(cpu="2", mem="2Gi"),
)
)

ray_job_pb = RayJob(
ray_cluster=RayCluster(
Expand All @@ -64,10 +80,10 @@ def t1(a: int) -> str:
replicas=3,
min_replicas=0,
max_replicas=10,
k8s_pod=K8sPod(pod_spec={"str": "worker", "int": 1}),
k8s_pod=K8sPod.from_pod_template(pod_template),
)
],
head_group_spec=HeadGroupSpec(k8s_pod=K8sPod(pod_spec={"str": "head", "int": 2})),
head_group_spec=HeadGroupSpec(k8s_pod=K8sPod.from_pod_template(head_pod_template)),
enable_autoscaling=True,
),
runtime_env=base64.b64encode(json.dumps({"pip": ["numpy"]}).encode()).decode(),
16 changes: 8 additions & 8 deletions tests/flytekit/unit/core/test_resources.py
@@ -110,12 +110,12 @@ def test_resources_round_trip():
def test_pod_spec_from_resources_requests_limits_set():
requests = Resources(cpu="1", mem="1Gi", gpu="1", ephemeral_storage="1Gi")
limits = Resources(cpu="4", mem="2Gi", gpu="1", ephemeral_storage="1Gi")
k8s_pod_name = "foo"
primary_container_name = "foo"

expected_pod_spec = V1PodSpec(
containers=[
V1Container(
name=k8s_pod_name,
name=primary_container_name,
resources=V1ResourceRequirements(
requests={
"cpu": "1",
@@ -133,25 +133,25 @@ def test_pod_spec_from_resources_requests_limits_set():
)
]
)
pod_spec = pod_spec_from_resources(k8s_pod_name=k8s_pod_name, requests=requests, limits=limits)
assert expected_pod_spec == V1PodSpec(**pod_spec)
pod_spec = pod_spec_from_resources(primary_container_name=primary_container_name, requests=requests, limits=limits)
assert expected_pod_spec == pod_spec


def test_pod_spec_from_resources_requests_set():
requests = Resources(cpu="1", mem="1Gi")
limits = None
k8s_pod_name = "foo"
primary_container_name = "foo"

expected_pod_spec = V1PodSpec(
containers=[
V1Container(
name=k8s_pod_name,
name=primary_container_name,
resources=V1ResourceRequirements(
requests={"cpu": "1", "memory": "1Gi"},
limits={"cpu": "1", "memory": "1Gi"},
),
)
]
)
pod_spec = pod_spec_from_resources(k8s_pod_name=k8s_pod_name, requests=requests, limits=limits)
assert expected_pod_spec == V1PodSpec(**pod_spec)
pod_spec = pod_spec_from_resources(primary_container_name=primary_container_name, requests=requests, limits=limits)
assert expected_pod_spec == pod_spec
