From 448aba97201ba42297282d859e6064b7f89537ae Mon Sep 17 00:00:00 2001 From: Jason Parraga Date: Wed, 29 Jan 2025 07:01:05 -0800 Subject: [PATCH] Configure pod runtime class based on custom pod specs (#6199) Signed-off-by: Jason Parraga --- flyteplugins/go/tasks/plugins/k8s/ray/ray.go | 5 +- .../go/tasks/plugins/k8s/ray/ray_test.go | 58 ++++++++++++++----- 2 files changed, 49 insertions(+), 14 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go index 76e595b006..03b4da1e90 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go @@ -540,12 +540,15 @@ func mergeCustomPodSpec(primaryContainer *v1.Container, podSpec *v1.PodSpec, k8s continue } - // Just handle resources for now if len(container.Resources.Requests) > 0 || len(container.Resources.Limits) > 0 { primaryContainer.Resources = container.Resources } } + if customPodSpec.RuntimeClassName != nil { + podSpec.RuntimeClassName = customPodSpec.RuntimeClassName + } + return podSpec, nil } diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go index 42978cac81..0818df8f3e 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go @@ -454,7 +454,9 @@ func TestBuildResourceRayCustomK8SPod(t *testing.T) { expectedWorkerResources, err := flytek8s.ToK8sResourceRequirements(workerResources) require.NoError(t, err) - headPodSpec := &corev1.PodSpec{ + nvidiaRuntimeClassName := "nvidia-cdi" + + headPodSpecCustomResources := &corev1.PodSpec{ Containers: []corev1.Container{ { Name: "ray-head", @@ -462,7 +464,7 @@ func TestBuildResourceRayCustomK8SPod(t *testing.T) { }, }, } - workerPodSpec := &corev1.PodSpec{ + workerPodSpecCustomResources := &corev1.PodSpec{ Containers: []corev1.Container{ { Name: "ray-worker", @@ -471,14 +473,24 @@ func TestBuildResourceRayCustomK8SPod(t *testing.T) { }, } + headPodSpecCustomRuntimeClass := &corev1.PodSpec{ + RuntimeClassName: &nvidiaRuntimeClassName, + } + workerPodSpecCustomRuntimeClass := &corev1.PodSpec{ + RuntimeClassName: &nvidiaRuntimeClassName, + } + params := []struct { - name string - taskResources *corev1.ResourceRequirements - headK8SPod *core.K8SPod - workerK8SPod *core.K8SPod - expectedSubmitterResources *corev1.ResourceRequirements - expectedHeadResources *corev1.ResourceRequirements - expectedWorkerResources *corev1.ResourceRequirements + name string + taskResources *corev1.ResourceRequirements + headK8SPod *core.K8SPod + workerK8SPod *core.K8SPod + expectedSubmitterResources *corev1.ResourceRequirements + expectedHeadResources *corev1.ResourceRequirements + expectedWorkerResources *corev1.ResourceRequirements + expectedSubmitterRuntimeClassName *string + expectedHeadRuntimeClassName *string + expectedWorkerRuntimeClassName *string }{ { name: "task resources", @@ -491,15 +503,30 @@ func TestBuildResourceRayCustomK8SPod(t *testing.T) { name: "custom worker and head resources", taskResources: resourceRequirements, headK8SPod: &core.K8SPod{ - PodSpec: transformStructToStructPB(t, headPodSpec), + PodSpec: transformStructToStructPB(t, headPodSpecCustomResources), }, workerK8SPod: &core.K8SPod{ - PodSpec: transformStructToStructPB(t, workerPodSpec), + PodSpec: transformStructToStructPB(t, workerPodSpecCustomResources), }, expectedSubmitterResources: resourceRequirements, expectedHeadResources: expectedHeadResources, expectedWorkerResources: expectedWorkerResources, }, + { + name: "custom runtime class name", + taskResources: resourceRequirements, + expectedSubmitterResources: resourceRequirements, + expectedHeadResources: resourceRequirements, + expectedWorkerResources: resourceRequirements, + headK8SPod: &core.K8SPod{ + PodSpec: transformStructToStructPB(t, headPodSpecCustomRuntimeClass), + }, + workerK8SPod: &core.K8SPod{ + PodSpec: transformStructToStructPB(t, workerPodSpecCustomRuntimeClass), + }, + expectedHeadRuntimeClassName: &nvidiaRuntimeClassName, + expectedWorkerRuntimeClassName: &nvidiaRuntimeClassName, + }, } for _, p := range params { @@ -531,18 +558,23 @@ func TestBuildResourceRayCustomK8SPod(t *testing.T) { &submitterPodResources, ) - headPodResources := rayJob.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.Containers[0].Resources + headPodSpec := rayJob.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec + headPodResources := headPodSpec.Containers[0].Resources assert.EqualValues(t, p.expectedHeadResources, &headPodResources, ) + assert.EqualValues(t, p.expectedHeadRuntimeClassName, headPodSpec.RuntimeClassName) + for _, workerGroupSpec := range rayJob.Spec.RayClusterSpec.WorkerGroupSpecs { - workerPodResources := workerGroupSpec.Template.Spec.Containers[0].Resources + workerPodSpec := workerGroupSpec.Template.Spec + workerPodResources := workerPodSpec.Containers[0].Resources assert.EqualValues(t, p.expectedWorkerResources, &workerPodResources, ) + assert.EqualValues(t, p.expectedWorkerRuntimeClassName, workerPodSpec.RuntimeClassName) } }) }