From 46ffa349fd51d3d64d82d77ffd70c8dda09d144f Mon Sep 17 00:00:00 2001 From: Sandipan Panda Date: Sat, 13 Jul 2024 01:20:16 +0530 Subject: [PATCH] Update JAX integration proposal Signed-off-by: Sandipan Panda --- docs/proposals/jax-integration.md | 66 ++++--------------------------- 1 file changed, 8 insertions(+), 58 deletions(-) diff --git a/docs/proposals/jax-integration.md b/docs/proposals/jax-integration.md index ee27bc639b..3fb169dc54 100644 --- a/docs/proposals/jax-integration.md +++ b/docs/proposals/jax-integration.md @@ -76,14 +76,12 @@ As a DevOps engineer, I want to manage JAX distributed training jobs using the K ##### Key Validations -1. **Coordinator Role Validation**: - - Ensure exactly one Coordinator role with `processId` set to `0` and its `replicas` is set to `1`. -2. **Worker Role Validation**: - - Ensure at least one Worker replica. - - Ensure the `replicas` field for each role is greater than `0`. -3. **JAX Parameters Validation**: +1. **Worker Role Validation**: + - Ensure there is at least one Worker replica; the replica with `processId` set to `0` acts as the coordinator. + - Ensure the `replicas` field is greater than `0`. +2. **JAX Parameters Validation**: - Ensure `coordinatorAddress`, `numProcesses`, and `processId` are set and valid across all roles. -4. **Pod Specification Validation**: +3. **Pod Specification Validation**: - Ensure necessary container specifications and `restartPolicy` are correctly set. - Validate `coordinatorAddress` follows the `host:port` format. 
@@ -100,14 +98,6 @@ metadata: name: example-jaxjob spec: jaxReplicaSpecs: - Coordinator: - replicas: 1 - restartPolicy: OnFailure - template: - spec: - containers: - - name: jax-coordinator - image: ghcr.io/kubeflow/jax:latest Worker: replicas: 1 restartPolicy: OnFailure @@ -158,8 +148,6 @@ const ( JAXJobSingular = "jaxjob" // JAXJobFrameworkName is the name of the ML Framework JAXJobFrameworkName = "jax" - // JAXJobReplicaTypeCoordinator is the type of Coordinator of distributed JAX - JAXJobReplicaTypeCoordinator ReplicaType = "Coordinator" // JAXJobReplicaTypeWorker is the type for workers of distributed JAX. JAXJobReplicaTypeWorker ReplicaType = "Worker" ) @@ -199,7 +187,6 @@ type JAXJobSpec struct { // A map of JAXReplicaType (type) to ReplicaSpec (value). Specifies the JAX cluster configuration. // For example, // { - // "Coordinator": JAXReplicaSpec, // "Worker": JAXReplicaSpec, // } JAXReplicaSpecs map[ReplicaType]*ReplicaSpec `json:"jaxReplicaSpecs"` @@ -223,48 +210,10 @@ type JAXJobList struct { func init() { SchemeBuilder.Register(&JAXJob{}, &JAXJobList{}) - SchemeBuilder.SchemeBuilder.Register(addJAXJobDefaultingFuncs) + SchemeBuilder.SchemeBuilder.Register(addJAXDefaultingFuncs) } ``` -##### Resulting Coordinator -```yaml -apiVersion: v1 -kind: Service -metadata: - name: jax-coordinator -spec: - selector: - app: jax-coordinator - ports: - - port: 6666 - targetPort: 6666 -``` -```yaml -apiVersion: v1 -kind: Pod -metadata: - name: jax-coordinator - labels: - app: jax-coordinator -spec: - containers: - - image: ghcr.io/kubeflow/jax:latest - imagePullPolicy: IfNotPresent - name: coordinator - env: - - name: JAX_COORDINATOR_ADDRESS - value: '127.0.0.1:6666' - - name: JAX_NUM_PROCESSES - value: 1 - - name: JAX_PROCESS_ID - value: 0 - # process 0 is coordinator - ports: - - name: coordinatorPort - containerPort: 6666 - restartPolicy: OnFailure -``` ##### Resulting Worker @@ -290,7 +239,8 @@ spec: - name: JAX_NUM_PROCESSES value: 1 - name: JAX_PROCESS_ID - 
value: 1 + value: 0 + # process 0 is coordinator restartPolicy: OnFailure ```