From da6b356c0088c05a34906cc3b5da8e22117be00a Mon Sep 17 00:00:00 2001
From: Kai-Hsun Chen
Date: Sun, 19 Jan 2025 18:00:00 -0800
Subject: [PATCH] [Refactor] Add a util function IsAutoscalingEnabled and
 refactor validations of RayJob deletion policy (#2775)

Signed-off-by: kaihsun
---
 .../volcano/volcano_scheduler.go              |  2 +-
 ray-operator/controllers/ray/common/pod.go    |  6 +--
 .../controllers/ray/common/pod_test.go        | 20 ++++-----
 .../controllers/ray/raycluster_controller.go  | 15 ++++---
 .../controllers/ray/rayjob_controller.go      | 41 +++++++++++--------
 .../ray/rayjob_controller_unit_test.go        |  2 +-
 ray-operator/controllers/ray/utils/util.go    | 13 ++++++
 .../controllers/ray/utils/util_test.go        | 39 ++++++++++++++++++
 8 files changed, 98 insertions(+), 40 deletions(-)

diff --git a/ray-operator/controllers/ray/batchscheduler/volcano/volcano_scheduler.go b/ray-operator/controllers/ray/batchscheduler/volcano/volcano_scheduler.go
index a58ea2c83c9..6c7763ab9db 100644
--- a/ray-operator/controllers/ray/batchscheduler/volcano/volcano_scheduler.go
+++ b/ray-operator/controllers/ray/batchscheduler/volcano/volcano_scheduler.go
@@ -51,7 +51,7 @@ func (v *VolcanoBatchScheduler) Name() string {
 func (v *VolcanoBatchScheduler) DoBatchSchedulingOnSubmission(ctx context.Context, app *rayv1.RayCluster) error {
 	var minMember int32
 	var totalResource corev1.ResourceList
-	if app.Spec.EnableInTreeAutoscaling == nil || !*app.Spec.EnableInTreeAutoscaling {
+	if !utils.IsAutoscalingEnabled(app) {
 		minMember = utils.CalculateDesiredReplicas(ctx, app) + 1
 		totalResource = utils.CalculateDesiredResources(app)
 	} else {
diff --git a/ray-operator/controllers/ray/common/pod.go b/ray-operator/controllers/ray/common/pod.go
index 1967b2b557a..3ab5cc6fad8 100644
--- a/ray-operator/controllers/ray/common/pod.go
+++ b/ray-operator/controllers/ray/common/pod.go
@@ -173,7 +173,7 @@ func DefaultHeadPodTemplate(ctx context.Context, instance rayv1.RayCluster, head
 	initTemplateAnnotations(instance, &podTemplate)
 
 	// if in-tree autoscaling is enabled, then autoscaler container should be injected into head pod.
-	if instance.Spec.EnableInTreeAutoscaling != nil && *instance.Spec.EnableInTreeAutoscaling {
+	if utils.IsAutoscalingEnabled(&instance) {
 		// The default autoscaler is not compatible with Kubernetes. As a result, we disable
 		// the monitor process by default and inject a KubeRay autoscaler side container into the head pod.
 		headSpec.RayStartParams["no-monitor"] = "true"
@@ -380,7 +380,7 @@ func initLivenessAndReadinessProbe(rayContainer *corev1.Container, rayNodeType r
 }
 
 // BuildPod a pod config
-func BuildPod(ctx context.Context, podTemplateSpec corev1.PodTemplateSpec, rayNodeType rayv1.RayNodeType, rayStartParams map[string]string, headPort string, enableRayAutoscaler *bool, creatorCRDType utils.CRDType, fqdnRayIP string) (aPod corev1.Pod) {
+func BuildPod(ctx context.Context, podTemplateSpec corev1.PodTemplateSpec, rayNodeType rayv1.RayNodeType, rayStartParams map[string]string, headPort string, enableRayAutoscaler bool, creatorCRDType utils.CRDType, fqdnRayIP string) (aPod corev1.Pod) {
 	log := ctrl.LoggerFrom(ctx)
 
 	// For Worker Pod: Traffic readiness is determined by the readiness probe.
@@ -405,7 +405,7 @@ func BuildPod(ctx context.Context, podTemplateSpec corev1.PodTemplateSpec, rayNo
 
 	// Add /dev/shm volumeMount for the object store to avoid performance degradation.
 	addEmptyDir(ctx, &pod.Spec.Containers[utils.RayContainerIndex], &pod, SharedMemoryVolumeName, SharedMemoryVolumeMountPath, corev1.StorageMediumMemory)
-	if rayNodeType == rayv1.HeadNode && enableRayAutoscaler != nil && *enableRayAutoscaler {
+	if rayNodeType == rayv1.HeadNode && enableRayAutoscaler {
 		// The Ray autoscaler writes logs which are read by the Ray head.
 		// We need a shared log volume to enable this information flow.
 		// Specifically, this is required for the event-logging functionality
diff --git a/ray-operator/controllers/ray/common/pod_test.go b/ray-operator/controllers/ray/common/pod_test.go
index 1e38fb80586..e1e08655bcb 100644
--- a/ray-operator/controllers/ray/common/pod_test.go
+++ b/ray-operator/controllers/ray/common/pod_test.go
@@ -576,7 +576,7 @@ func TestBuildPod(t *testing.T) {
 	// Test head pod
 	podName := strings.ToLower(cluster.Name + utils.DashSymbol + string(rayv1.HeadNode) + utils.DashSymbol + utils.FormatInt32(0))
 	podTemplateSpec := DefaultHeadPodTemplate(ctx, *cluster, cluster.Spec.HeadGroupSpec, podName, "6379")
-	pod := BuildPod(ctx, podTemplateSpec, rayv1.HeadNode, cluster.Spec.HeadGroupSpec.RayStartParams, "6379", nil, utils.GetCRDType(""), "")
+	pod := BuildPod(ctx, podTemplateSpec, rayv1.HeadNode, cluster.Spec.HeadGroupSpec.RayStartParams, "6379", false, utils.GetCRDType(""), "")
 
 	// Check environment variables
 	rayContainer := pod.Spec.Containers[utils.RayContainerIndex]
@@ -631,7 +631,7 @@ func TestBuildPod(t *testing.T) {
 	podName = cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0)
 	fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace)
 	podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379")
-	pod = BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", nil, utils.GetCRDType(""), fqdnRayIP)
+	pod = BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", false, utils.GetCRDType(""), fqdnRayIP)
 
 	// Check resources
 	rayContainer = pod.Spec.Containers[utils.RayContainerIndex]
@@ -694,7 +694,7 @@ func TestBuildPod_WithNoCPULimits(t *testing.T) {
 	// Test head pod
 	podName := strings.ToLower(cluster.Name + utils.DashSymbol + string(rayv1.HeadNode) + utils.DashSymbol + utils.FormatInt32(0))
 	podTemplateSpec := DefaultHeadPodTemplate(ctx, *cluster, cluster.Spec.HeadGroupSpec, podName, "6379")
-	pod := BuildPod(ctx, podTemplateSpec, rayv1.HeadNode, cluster.Spec.HeadGroupSpec.RayStartParams, "6379", nil, utils.GetCRDType(""), "")
+	pod := BuildPod(ctx, podTemplateSpec, rayv1.HeadNode, cluster.Spec.HeadGroupSpec.RayStartParams, "6379", false, utils.GetCRDType(""), "")
 	expectedCommandArg := splitAndSort("ulimit -n 65536; ray start --head --block --dashboard-agent-listen-port=52365 --memory=1073741824 --num-cpus=2 --metrics-export-port=8080 --dashboard-host=0.0.0.0")
 	actualCommandArg := splitAndSort(pod.Spec.Containers[0].Args[0])
 	if !reflect.DeepEqual(expectedCommandArg, actualCommandArg) {
@@ -706,7 +706,7 @@ func TestBuildPod_WithNoCPULimits(t *testing.T) {
 	podName = cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0)
 	fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace)
 	podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379")
-	pod = BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", nil, utils.GetCRDType(""), fqdnRayIP)
+	pod = BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", false, utils.GetCRDType(""), fqdnRayIP)
 	expectedCommandArg = splitAndSort("ulimit -n 65536; ray start --block --dashboard-agent-listen-port=52365 --memory=1073741824 --num-cpus=2 --num-gpus=3 --address=raycluster-sample-head-svc.default.svc.cluster.local:6379 --port=6379 --metrics-export-port=8080")
 	actualCommandArg = splitAndSort(pod.Spec.Containers[0].Args[0])
 	if !reflect.DeepEqual(expectedCommandArg, actualCommandArg) {
@@ -730,7 +730,7 @@ func TestBuildPod_WithOverwriteCommand(t *testing.T) {
 
 	podName := strings.ToLower(cluster.Name + utils.DashSymbol + string(rayv1.HeadNode) + utils.DashSymbol + utils.FormatInt32(0))
 	podTemplateSpec := DefaultHeadPodTemplate(ctx, *cluster, cluster.Spec.HeadGroupSpec, podName, "6379")
-	headPod := BuildPod(ctx, podTemplateSpec, rayv1.HeadNode, cluster.Spec.HeadGroupSpec.RayStartParams, "6379", nil, utils.GetCRDType(""), "")
+	headPod := BuildPod(ctx, podTemplateSpec, rayv1.HeadNode, cluster.Spec.HeadGroupSpec.RayStartParams, "6379", false, utils.GetCRDType(""), "")
 	headContainer := headPod.Spec.Containers[utils.RayContainerIndex]
 	assert.Equal(t, headContainer.Command, []string{"I am head"})
 	assert.Equal(t, headContainer.Args, []string{"I am head again"})
@@ -739,7 +739,7 @@ func TestBuildPod_WithOverwriteCommand(t *testing.T) {
 	podName = cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0)
 	fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace)
 	podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379")
-	workerPod := BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", nil, utils.GetCRDType(""), fqdnRayIP)
+	workerPod := BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", false, utils.GetCRDType(""), fqdnRayIP)
 	workerContainer := workerPod.Spec.Containers[utils.RayContainerIndex]
 	assert.Equal(t, workerContainer.Command, []string{"I am worker"})
 	assert.Equal(t, workerContainer.Args, []string{"I am worker again"})
@@ -751,7 +751,7 @@ func TestBuildPod_WithAutoscalerEnabled(t *testing.T) {
 	cluster.Spec.EnableInTreeAutoscaling = &trueFlag
 	podName := strings.ToLower(cluster.Name + utils.DashSymbol + string(rayv1.HeadNode) + utils.DashSymbol + utils.FormatInt32(0))
 	podTemplateSpec := DefaultHeadPodTemplate(ctx, *cluster, cluster.Spec.HeadGroupSpec, podName, "6379")
-	pod := BuildPod(ctx, podTemplateSpec, rayv1.HeadNode, cluster.Spec.HeadGroupSpec.RayStartParams, "6379", &trueFlag, utils.GetCRDType(""), "")
+	pod := BuildPod(ctx, podTemplateSpec, rayv1.HeadNode, cluster.Spec.HeadGroupSpec.RayStartParams, "6379", true, utils.GetCRDType(""), "")
 
 	actualResult := pod.Labels[utils.RayClusterLabelKey]
 	expectedResult := cluster.Name
@@ -808,7 +808,7 @@ func TestBuildPod_WithCreatedByRayService(t *testing.T) {
 	cluster.Spec.EnableInTreeAutoscaling = &trueFlag
 	podName := strings.ToLower(cluster.Name + utils.DashSymbol + string(rayv1.HeadNode) + utils.DashSymbol + utils.FormatInt32(0))
 	podTemplateSpec := DefaultHeadPodTemplate(ctx, *cluster, cluster.Spec.HeadGroupSpec, podName, "6379")
-	pod := BuildPod(ctx, podTemplateSpec, rayv1.HeadNode, cluster.Spec.HeadGroupSpec.RayStartParams, "6379", &trueFlag, utils.RayServiceCRD, "")
+	pod := BuildPod(ctx, podTemplateSpec, rayv1.HeadNode, cluster.Spec.HeadGroupSpec.RayStartParams, "6379", true, utils.RayServiceCRD, "")
 
 	val, ok := pod.Labels[utils.RayClusterServingServiceLabelKey]
 	assert.True(t, ok, "Expected serve label is not present")
@@ -819,7 +819,7 @@ func TestBuildPod_WithCreatedByRayService(t *testing.T) {
 	podName = cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0)
 	fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace)
 	podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379")
-	pod = BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", nil, utils.RayServiceCRD, fqdnRayIP)
+	pod = BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", false, utils.RayServiceCRD, fqdnRayIP)
 
 	val, ok = pod.Labels[utils.RayClusterServingServiceLabelKey]
 	assert.True(t, ok, "Expected serve label is not present")
@@ -891,7 +891,7 @@ func TestBuildPodWithAutoscalerOptions(t *testing.T) {
 		SecurityContext: &customSecurityContext,
 	}
 	podTemplateSpec := DefaultHeadPodTemplate(ctx, *cluster, cluster.Spec.HeadGroupSpec, podName, "6379")
-	pod := BuildPod(ctx, podTemplateSpec, rayv1.HeadNode, cluster.Spec.HeadGroupSpec.RayStartParams, "6379", &trueFlag, utils.GetCRDType(""), "")
+	pod := BuildPod(ctx, podTemplateSpec, rayv1.HeadNode, cluster.Spec.HeadGroupSpec.RayStartParams, "6379", true, utils.GetCRDType(""), "")
 	expectedContainer := *autoscalerContainer.DeepCopy()
 	expectedContainer.Image = customAutoscalerImage
 	expectedContainer.ImagePullPolicy = customPullPolicy
diff --git a/ray-operator/controllers/ray/raycluster_controller.go b/ray-operator/controllers/ray/raycluster_controller.go
index 3f7f5533916..812045e5a0b 100644
--- a/ray-operator/controllers/ray/raycluster_controller.go
+++ b/ray-operator/controllers/ray/raycluster_controller.go
@@ -276,8 +276,7 @@ func validateRayClusterSpec(instance *rayv1.RayCluster) error {
 		}
 	}
 
-	enableInTreeAutoscaling := (instance.Spec.EnableInTreeAutoscaling != nil) && (*instance.Spec.EnableInTreeAutoscaling)
-	if enableInTreeAutoscaling {
+	if utils.IsAutoscalingEnabled(instance) {
 		for _, workerGroup := range instance.Spec.WorkerGroupSpecs {
 			if workerGroup.Suspend != nil && *workerGroup.Suspend {
 				// TODO (rueian): This can be supported in future Ray. We should check the RayVersion once we know the version.
@@ -943,7 +942,7 @@ func (r *RayClusterReconciler) reconcilePods(ctx context.Context, instance *rayv
 	// diff < 0 indicates the need to delete some Pods to match the desired number of replicas. However,
 	// randomly deleting Pods is certainly not ideal. So, if autoscaling is enabled for the cluster, we
 	// will disable random Pod deletion, making Autoscaler the sole decision-maker for Pod deletions.
-	enableInTreeAutoscaling := (instance.Spec.EnableInTreeAutoscaling != nil) && (*instance.Spec.EnableInTreeAutoscaling)
+	enableInTreeAutoscaling := utils.IsAutoscalingEnabled(instance)
 
 	// TODO (kevin85421): `enableRandomPodDelete` is a feature flag for KubeRay v0.6.0. If users want to use
 	// the old behavior, they can set the environment variable `ENABLE_RANDOM_POD_DELETE` to `true`. When the
@@ -1174,7 +1173,7 @@ func (r *RayClusterReconciler) buildHeadPod(ctx context.Context, instance rayv1.
 	fqdnRayIP := utils.GenerateFQDNServiceName(ctx, instance, instance.Namespace) // Fully Qualified Domain Name
 	// The Ray head port used by workers to connect to the cluster (GCS server port for Ray >= 1.11.0, Redis port for older Ray.)
 	headPort := common.GetHeadPort(instance.Spec.HeadGroupSpec.RayStartParams)
-	autoscalingEnabled := instance.Spec.EnableInTreeAutoscaling
+	autoscalingEnabled := utils.IsAutoscalingEnabled(&instance)
 	podConf := common.DefaultHeadPodTemplate(ctx, instance, instance.Spec.HeadGroupSpec, podName, headPort)
 	if len(r.headSidecarContainers) > 0 {
 		podConf.Spec.Containers = append(podConf.Spec.Containers, r.headSidecarContainers...)
@@ -1202,7 +1201,7 @@ func (r *RayClusterReconciler) buildWorkerPod(ctx context.Context, instance rayv
 
 	// The Ray head port used by workers to connect to the cluster (GCS server port for Ray >= 1.11.0, Redis port for older Ray.)
 	headPort := common.GetHeadPort(instance.Spec.HeadGroupSpec.RayStartParams)
-	autoscalingEnabled := instance.Spec.EnableInTreeAutoscaling
+	autoscalingEnabled := utils.IsAutoscalingEnabled(&instance)
 	podTemplateSpec := common.DefaultWorkerPodTemplate(ctx, instance, worker, podName, fqdnRayIP, headPort)
 	if len(r.workerSidecarContainers) > 0 {
 		podTemplateSpec.Spec.Containers = append(podTemplateSpec.Spec.Containers, r.workerSidecarContainers...)
@@ -1580,7 +1579,7 @@ func (r *RayClusterReconciler) updateHeadInfo(ctx context.Context, instance *ray
 
 func (r *RayClusterReconciler) reconcileAutoscalerServiceAccount(ctx context.Context, instance *rayv1.RayCluster) error {
 	logger := ctrl.LoggerFrom(ctx)
-	if instance.Spec.EnableInTreeAutoscaling == nil || !*instance.Spec.EnableInTreeAutoscaling {
+	if !utils.IsAutoscalingEnabled(instance) {
 		return nil
 	}
@@ -1637,7 +1636,7 @@ func (r *RayClusterReconciler) reconcileAutoscalerServiceAccount(ctx context.Con
 
 func (r *RayClusterReconciler) reconcileAutoscalerRole(ctx context.Context, instance *rayv1.RayCluster) error {
 	logger := ctrl.LoggerFrom(ctx)
-	if instance.Spec.EnableInTreeAutoscaling == nil || !*instance.Spec.EnableInTreeAutoscaling {
+	if !utils.IsAutoscalingEnabled(instance) {
 		return nil
 	}
@@ -1679,7 +1678,7 @@ func (r *RayClusterReconciler) reconcileAutoscalerRole(ctx context.Context, inst
 
 func (r *RayClusterReconciler) reconcileAutoscalerRoleBinding(ctx context.Context, instance *rayv1.RayCluster) error {
 	logger := ctrl.LoggerFrom(ctx)
-	if instance.Spec.EnableInTreeAutoscaling == nil || !*instance.Spec.EnableInTreeAutoscaling {
+	if !utils.IsAutoscalingEnabled(instance) {
 		return nil
 	}
diff --git a/ray-operator/controllers/ray/rayjob_controller.go b/ray-operator/controllers/ray/rayjob_controller.go
index 4e54f7d2a73..d5344c71e69 100644
--- a/ray-operator/controllers/ray/rayjob_controller.go
+++ b/ray-operator/controllers/ray/rayjob_controller.go
@@ -885,10 +885,12 @@ func validateRayJobSpec(rayJob *rayv1.RayJob) error {
 	if rayJob.Spec.Suspend && !rayJob.Spec.ShutdownAfterJobFinishes {
 		return fmt.Errorf("a RayJob with shutdownAfterJobFinishes set to false is not allowed to be suspended")
 	}
-	if rayJob.Spec.Suspend && len(rayJob.Spec.ClusterSelector) != 0 {
+
+	isClusterSelectorMode := len(rayJob.Spec.ClusterSelector) != 0
+	if rayJob.Spec.Suspend && isClusterSelectorMode {
 		return fmt.Errorf("the ClusterSelector mode doesn't support the suspend operation")
 	}
-	if rayJob.Spec.RayClusterSpec == nil && len(rayJob.Spec.ClusterSelector) == 0 {
+	if rayJob.Spec.RayClusterSpec == nil && !isClusterSelectorMode {
 		return fmt.Errorf("one of RayClusterSpec or ClusterSelector must be set")
 	}
 	// Validate whether RuntimeEnvYAML is a valid YAML string. Note that this only checks its validity
@@ -905,21 +907,26 @@ func validateRayJobSpec(rayJob *rayv1.RayJob) error {
 	if !features.Enabled(features.RayJobDeletionPolicy) && rayJob.Spec.DeletionPolicy != nil {
 		return fmt.Errorf("RayJobDeletionPolicy feature gate must be enabled to use the DeletionPolicy feature")
 	}
-	if rayJob.Spec.ClusterSelector != nil &&
-		rayJob.Spec.DeletionPolicy != nil && *rayJob.Spec.DeletionPolicy == rayv1.DeleteClusterDeletionPolicy {
-		return fmt.Errorf("the ClusterSelector mode doesn't support DeletionPolicy=DeleteCluster")
-	}
-	if rayJob.Spec.ClusterSelector != nil &&
-		rayJob.Spec.DeletionPolicy != nil && *rayJob.Spec.DeletionPolicy == rayv1.DeleteWorkersDeletionPolicy {
-		return fmt.Errorf("the ClusterSelector mode doesn't support DeletionPolicy=DeleteWorkers")
-	}
-	if rayJob.Spec.DeletionPolicy != nil && *rayJob.Spec.DeletionPolicy == rayv1.DeleteWorkersDeletionPolicy &&
-		rayJob.Spec.RayClusterSpec.EnableInTreeAutoscaling != nil && *rayJob.Spec.RayClusterSpec.EnableInTreeAutoscaling {
-		// TODO (rueian): This can be supported in future Ray. We should check the RayVersion once we know the version.
-		return fmt.Errorf("DeletionPolicy=DeleteWorkers currently does not support RayClusterSpec.EnableInTreeAutoscaling")
-	}
-	if rayJob.Spec.ShutdownAfterJobFinishes && rayJob.Spec.DeletionPolicy != nil && *rayJob.Spec.DeletionPolicy == rayv1.DeleteNoneDeletionPolicy {
-		return fmt.Errorf("shutdownAfterJobFinshes is set to 'true' while deletion policy is 'DeleteNone'")
+
+	if rayJob.Spec.DeletionPolicy != nil {
+		policy := *rayJob.Spec.DeletionPolicy
+		if isClusterSelectorMode {
+			switch policy {
+			case rayv1.DeleteClusterDeletionPolicy:
+				return fmt.Errorf("the ClusterSelector mode doesn't support DeletionPolicy=DeleteCluster")
+			case rayv1.DeleteWorkersDeletionPolicy:
+				return fmt.Errorf("the ClusterSelector mode doesn't support DeletionPolicy=DeleteWorkers")
+			}
+		}
+
+		if policy == rayv1.DeleteWorkersDeletionPolicy && utils.IsAutoscalingEnabled(rayJob) {
+			// TODO (rueian): This can be supported in a future Ray version. We should check the RayVersion once we know it.
+ return fmt.Errorf("DeletionPolicy=DeleteWorkers currently does not support RayCluster with autoscaling enabled") + } + + if rayJob.Spec.ShutdownAfterJobFinishes && policy == rayv1.DeleteNoneDeletionPolicy { + return fmt.Errorf("shutdownAfterJobFinshes is set to 'true' while deletion policy is 'DeleteNone'") + } } return nil } diff --git a/ray-operator/controllers/ray/rayjob_controller_unit_test.go b/ray-operator/controllers/ray/rayjob_controller_unit_test.go index 115896df13f..bb8c05f0423 100644 --- a/ray-operator/controllers/ray/rayjob_controller_unit_test.go +++ b/ray-operator/controllers/ray/rayjob_controller_unit_test.go @@ -403,7 +403,7 @@ func TestValidateRayJobSpec(t *testing.T) { }, }, }) - assert.ErrorContains(t, err, "DeletionPolicy=DeleteWorkers currently does not support RayClusterSpec.EnableInTreeAutoscaling") + assert.ErrorContains(t, err, "DeletionPolicy=DeleteWorkers currently does not support RayCluster with autoscaling enabled") err = validateRayJobSpec(&rayv1.RayJob{ Spec: rayv1.RayJobSpec{ diff --git a/ray-operator/controllers/ray/utils/util.go b/ray-operator/controllers/ray/utils/util.go index d902cacfb04..0631a12662c 100644 --- a/ray-operator/controllers/ray/utils/util.go +++ b/ray-operator/controllers/ray/utils/util.go @@ -620,3 +620,16 @@ func ManagedByExternalController(controllerName *string) *string { } return nil } + +func IsAutoscalingEnabled[T *rayv1.RayCluster | *rayv1.RayJob | *rayv1.RayService](obj T) bool { + switch obj := (interface{})(obj).(type) { + case *rayv1.RayCluster: + return obj.Spec.EnableInTreeAutoscaling != nil && *obj.Spec.EnableInTreeAutoscaling + case *rayv1.RayJob: + return obj.Spec.RayClusterSpec != nil && obj.Spec.RayClusterSpec.EnableInTreeAutoscaling != nil && *obj.Spec.RayClusterSpec.EnableInTreeAutoscaling + case *rayv1.RayService: + return obj.Spec.RayClusterSpec.EnableInTreeAutoscaling != nil && *obj.Spec.RayClusterSpec.EnableInTreeAutoscaling + default: + panic(fmt.Sprintf("unsupported type: %T", obj)) + } +} diff --git a/ray-operator/controllers/ray/utils/util_test.go b/ray-operator/controllers/ray/utils/util_test.go index db64998f483..91b367f8cce 100644 --- a/ray-operator/controllers/ray/utils/util_test.go +++ b/ray-operator/controllers/ray/utils/util_test.go @@ -713,3 +713,42 @@ func TestErrRayClusterReplicaFailureReason(t *testing.T) { assert.Equal(t, RayClusterReplicaFailureReason(errors.Join(ErrFailedCreateWorkerPod, errors.New("other error"))), "FailedCreateWorkerPod") assert.Equal(t, RayClusterReplicaFailureReason(errors.New("other error")), "") } + +func TestIsAutoscalingEnabled(t *testing.T) { + // Test: RayCluster + cluster := &rayv1.RayCluster{} + assert.False(t, IsAutoscalingEnabled(cluster)) + + cluster = &rayv1.RayCluster{ + Spec: rayv1.RayClusterSpec{ + EnableInTreeAutoscaling: ptr.To[bool](true), + }, + } + assert.True(t, IsAutoscalingEnabled(cluster)) + + // Test: RayJob + job := &rayv1.RayJob{} + assert.False(t, IsAutoscalingEnabled(job)) + + job = &rayv1.RayJob{ + Spec: rayv1.RayJobSpec{ + RayClusterSpec: &rayv1.RayClusterSpec{ + EnableInTreeAutoscaling: ptr.To[bool](true), + }, + }, + } + assert.True(t, IsAutoscalingEnabled(job)) + + // Test: RayService + service := &rayv1.RayService{} + assert.False(t, IsAutoscalingEnabled(service)) + + service = &rayv1.RayService{ + Spec: rayv1.RayServiceSpec{ + RayClusterSpec: rayv1.RayClusterSpec{ + EnableInTreeAutoscaling: ptr.To[bool](true), + }, + }, + } + assert.True(t, IsAutoscalingEnabled(service)) +}