Skip to content

Commit

Permalink
Configure pod runtime class based on custom pod specs (#6199)
Browse files Browse the repository at this point in the history
Signed-off-by: Jason Parraga <[email protected]>
  • Loading branch information
Sovietaced authored Jan 29, 2025
1 parent 45ce4c0 commit 448aba9
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 14 deletions.
5 changes: 4 additions & 1 deletion flyteplugins/go/tasks/plugins/k8s/ray/ray.go
Original file line number Diff line number Diff line change
Expand Up @@ -540,12 +540,15 @@ func mergeCustomPodSpec(primaryContainer *v1.Container, podSpec *v1.PodSpec, k8s
continue
}

// Just handle resources for now
if len(container.Resources.Requests) > 0 || len(container.Resources.Limits) > 0 {
primaryContainer.Resources = container.Resources
}
}

if customPodSpec.RuntimeClassName != nil {
podSpec.RuntimeClassName = customPodSpec.RuntimeClassName
}

return podSpec, nil
}

Expand Down
58 changes: 45 additions & 13 deletions flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -454,15 +454,17 @@ func TestBuildResourceRayCustomK8SPod(t *testing.T) {
expectedWorkerResources, err := flytek8s.ToK8sResourceRequirements(workerResources)
require.NoError(t, err)

headPodSpec := &corev1.PodSpec{
nvidiaRuntimeClassName := "nvidia-cdi"

headPodSpecCustomResources := &corev1.PodSpec{
Containers: []corev1.Container{
{
Name: "ray-head",
Resources: *expectedHeadResources,
},
},
}
workerPodSpec := &corev1.PodSpec{
workerPodSpecCustomResources := &corev1.PodSpec{
Containers: []corev1.Container{
{
Name: "ray-worker",
Expand All @@ -471,14 +473,24 @@ func TestBuildResourceRayCustomK8SPod(t *testing.T) {
},
}

headPodSpecCustomRuntimeClass := &corev1.PodSpec{
RuntimeClassName: &nvidiaRuntimeClassName,
}
workerPodSpecCustomRuntimeClass := &corev1.PodSpec{
RuntimeClassName: &nvidiaRuntimeClassName,
}

params := []struct {
name string
taskResources *corev1.ResourceRequirements
headK8SPod *core.K8SPod
workerK8SPod *core.K8SPod
expectedSubmitterResources *corev1.ResourceRequirements
expectedHeadResources *corev1.ResourceRequirements
expectedWorkerResources *corev1.ResourceRequirements
name string
taskResources *corev1.ResourceRequirements
headK8SPod *core.K8SPod
workerK8SPod *core.K8SPod
expectedSubmitterResources *corev1.ResourceRequirements
expectedHeadResources *corev1.ResourceRequirements
expectedWorkerResources *corev1.ResourceRequirements
expectedSubmitterRuntimeClassName *string
expectedHeadRuntimeClassName *string
expectedWorkerRuntimeClassName *string
}{
{
name: "task resources",
Expand All @@ -491,15 +503,30 @@ func TestBuildResourceRayCustomK8SPod(t *testing.T) {
name: "custom worker and head resources",
taskResources: resourceRequirements,
headK8SPod: &core.K8SPod{
PodSpec: transformStructToStructPB(t, headPodSpec),
PodSpec: transformStructToStructPB(t, headPodSpecCustomResources),
},
workerK8SPod: &core.K8SPod{
PodSpec: transformStructToStructPB(t, workerPodSpec),
PodSpec: transformStructToStructPB(t, workerPodSpecCustomResources),
},
expectedSubmitterResources: resourceRequirements,
expectedHeadResources: expectedHeadResources,
expectedWorkerResources: expectedWorkerResources,
},
{
name: "custom runtime class name",
taskResources: resourceRequirements,
expectedSubmitterResources: resourceRequirements,
expectedHeadResources: resourceRequirements,
expectedWorkerResources: resourceRequirements,
headK8SPod: &core.K8SPod{
PodSpec: transformStructToStructPB(t, headPodSpecCustomRuntimeClass),
},
workerK8SPod: &core.K8SPod{
PodSpec: transformStructToStructPB(t, workerPodSpecCustomRuntimeClass),
},
expectedHeadRuntimeClassName: &nvidiaRuntimeClassName,
expectedWorkerRuntimeClassName: &nvidiaRuntimeClassName,
},
}

for _, p := range params {
Expand Down Expand Up @@ -531,18 +558,23 @@ func TestBuildResourceRayCustomK8SPod(t *testing.T) {
&submitterPodResources,
)

headPodResources := rayJob.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.Containers[0].Resources
headPodSpec := rayJob.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec
headPodResources := headPodSpec.Containers[0].Resources
assert.EqualValues(t,
p.expectedHeadResources,
&headPodResources,
)

assert.EqualValues(t, p.expectedHeadRuntimeClassName, headPodSpec.RuntimeClassName)

for _, workerGroupSpec := range rayJob.Spec.RayClusterSpec.WorkerGroupSpecs {
workerPodResources := workerGroupSpec.Template.Spec.Containers[0].Resources
workerPodSpec := workerGroupSpec.Template.Spec
workerPodResources := workerPodSpec.Containers[0].Resources
assert.EqualValues(t,
p.expectedWorkerResources,
&workerPodResources,
)
assert.EqualValues(t, p.expectedWorkerRuntimeClassName, workerPodSpec.RuntimeClassName)
}
})
}
Expand Down

0 comments on commit 448aba9

Please sign in to comment.