Add documentation for JAXJob #3877

Merged
merged 1 commit into from
Oct 22, 2024

sandipanpanda
Member

Hi @sandipanpanda. Thanks for your PR.

I'm waiting for a kubeflow member to verify that this patch is reasonable to test. If it is, they should reply with /ok-to-test on its own line. Until that is done, I will not automatically test new commits in this PR, but the usual testing commands by org members will still work. Regular contributors should join the org to skip this step.

Once the patch is verified, the new status will be reflected by the ok-to-test label.

I understand the commands that are listed here.

Instructions for interacting with me using PR comments are available here. If you have questions or suggestions related to my behavior, please file an issue against the kubernetes/test-infra repository.

Member

@Arhell Arhell left a comment

/ok-to-test

Member

@andreyvelich andreyvelich left a comment

Thank you for doing this @sandipanpanda!
Let's merge it once we complete the Jax implementation in Training Operator.
/hold
/assign @kubeflow/wg-training-leads @StefanoFioravanzo @hbelmiro for review

Contributor

@hbelmiro hbelmiro left a comment

/lgtm

@google-oss-prow google-oss-prow bot added the lgtm label Sep 26, 2024
Member

@andreyvelich andreyvelich left a comment

Thank you for adding this @sandipanpanda!
I left a few comments.

to run JAX training jobs on Kubernetes. The Kubeflow implementation of
the `JAXJob` is in the [`training-operator`](https://github.com/kubeflow/training-operator).

The current custom resource for JAX has been tested to run multiple processes on CPUs using [gloo](https://github.com/facebookincubator/gloo) for communication between CPUs.

Member

@kubeflow/wg-training-leads @sandipanpanda Do we want to mention that we are looking for user feedback on running JAXJob on TPUs?

Member

User feedback would be a great idea. But even if we get feedback on TPUs, I suspect we cannot implement TPU support in the upstream training-operator, because we do not have any verification infrastructure, right?

Member

@andreyvelich andreyvelich Oct 8, 2024

I feel that it is fine for now, since we don't have infra for GPUs today either.
E.g. we say that you can run those examples on GPUs, but we don't validate them in our CI.

Member

> I feel that it is fine for now, since we don't have infra for GPUs today either.
> E.g. we say that you can run those examples on GPUs, but we don't validate them in our CI.

Yeah, that's true, but GPUs are mostly generic devices, and many developers can access them. So we can improve the GPU utilization mechanisms based on those GPUs.

OTOH, TPUs are only available in Google Cloud, and fewer developers can access them than GPUs.
So my concern is that TPU-specific mechanisms will be abandoned and soon stop working.

## Creating a JAX training job

You can create a training job by defining a `JAXJob` config file. See the manifests for the [simple JAXJob example](https://github.com/kubeflow/training-operator/blob/master/examples/jax/cpu-demo/demo.yaml).
You may change the config file based on your requirements.

Member

What do you mean by changing the config file here?

Member Author

I meant the Job config file, using the wording in the existing user guides. I'll make that clear.
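
For readers following this thread, the "Job config file" is the `JAXJob` manifest itself. The sketch below builds such a manifest as a plain Python dict and prints it as JSON. The field layout (`jaxReplicaSpecs`, a `Worker` replica type, a container named `jax`) is an assumption based on the linked demo manifest and should be checked against it; the image name is a placeholder, and `build_jaxjob` is a hypothetical helper, not part of any Kubeflow SDK.

```python
import json


def build_jaxjob(name, namespace, image, replicas=2):
    """Return a JAXJob manifest as a dict.

    The field names are assumed from the linked demo.yaml; verify them
    against the CRD installed in your cluster before relying on them.
    """
    return {
        "apiVersion": "kubeflow.org/v1",
        "kind": "JAXJob",
        "metadata": {"name": name, "namespace": namespace},
        "spec": {
            "jaxReplicaSpecs": {
                "Worker": {
                    # Number of JAX processes (one pod per process).
                    "replicas": replicas,
                    "restartPolicy": "OnFailure",
                    "template": {
                        "spec": {
                            "containers": [
                                {
                                    "name": "jax",
                                    "image": image,  # placeholder image
                                    "command": ["python3", "train.py"],
                                }
                            ]
                        }
                    },
                }
            }
        },
    }


# Changing the config "based on your requirements" then amounts to
# changing arguments such as the replica count or the image.
manifest = build_jaxjob("jaxjob-simple", "kubeflow", "example.com/your-jax-image:latest")
print(json.dumps(manifest, indent=2))
```

Since `kubectl apply -f` accepts JSON as well as YAML, the printed output could be saved to a file and applied directly.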

```
kubectl get pods -n kubeflow -l training.kubeflow.org/job-name=jaxjob-simple
```

Training takes 5-10 minutes on a CPU cluster. Logs can be inspected to see its training progress.

Member

What do you mean by training if we just compute an all-reduce sum across all JAX processes: https://github.com/kubeflow/training-operator/blob/master/examples/jax/cpu-demo/train.py#L39C15-L39C19?

Member Author

I used it in the context of using JAXJob for distributed training. I'll reword it to computation as it makes more sense.

```
PODNAME=$(kubectl get pods -l training.kubeflow.org/job-name=jaxjob-simple,training.kubeflow.org/replica-type=worker,training.kubeflow.org/replica-index=0 -o name -n kubeflow)
kubectl logs -f ${PODNAME} -n kubeflow
```

Member

Please can you add the output of the logs here.
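
The commands quoted in this thread select pods through three labels: `training.kubeflow.org/job-name`, `training.kubeflow.org/replica-type`, and `training.kubeflow.org/replica-index`. As a rough sketch of how that selector is composed, the hypothetical helpers below (`pod_selector` and `get_pods_command` are illustration names, not an existing API) build the same label selector and `kubectl` argument list in Python without running anything against a cluster.

```python
def pod_selector(job_name, replica_type=None, replica_index=None):
    """Build a comma-separated label selector for the pods of one job."""
    labels = {"training.kubeflow.org/job-name": job_name}
    if replica_type is not None:
        labels["training.kubeflow.org/replica-type"] = replica_type
    if replica_index is not None:
        labels["training.kubeflow.org/replica-index"] = str(replica_index)
    # Dicts keep insertion order, so the labels appear as added above.
    return ",".join(f"{key}={value}" for key, value in labels.items())


def get_pods_command(selector, namespace):
    """Return the kubectl argument list that prints matching pod names."""
    return ["kubectl", "get", "pods", "-l", selector, "-o", "name", "-n", namespace]


# All pods of the job, then only worker 0.
all_pods = pod_selector("jaxjob-simple")
worker_zero = pod_selector("jaxjob-simple", replica_type="worker", replica_index=0)
print(" ".join(get_pods_command(worker_zero, "kubeflow")))
```

The argument list can be handed to `subprocess.run` on a machine where `kubectl` is configured for the cluster.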

Member

@terrytangyuan terrytangyuan left a comment

Thank you!

/lgtm

[APPROVALNOTIFIER] This PR is APPROVED

This pull-request has been approved by: terrytangyuan

The full list of commands accepted by this bot can be found here.

The pull request process is described here

Needs approval from an approver in each of these files:

Approvers can indicate their approval by writing /approve in a comment
Approvers can cancel approval by writing /approve cancel in a comment

@andreyvelich
Member

@sandipanpanda Did you get a chance to review the remaining comments?

@@ -10,7 +10,7 @@ weight = 10

The Training Operator is a Kubernetes-native project for fine-tuning and scalable
distributed training of machine learning (ML) models created with different ML frameworks such as
-PyTorch, TensorFlow, XGBoost, and others.
+PyTorch, TensorFlow, XGBoost, [JAX](https://jax.readthedocs.io/en/latest/), and others.

Member

Suggested change
-PyTorch, TensorFlow, XGBoost, [JAX](https://jax.readthedocs.io/en/latest/), and others.
+PyTorch, TensorFlow, XGBoost, JAX, and others.

For consistency across all supported frameworks.

the `JAXJob` is in the [`training-operator`](https://github.com/kubeflow/training-operator).

The current custom resource for JAX has been tested to run multiple processes on CPUs using [gloo](https://github.com/facebookincubator/gloo) for communication between CPUs.

Member

Additionally, could you mention that the worker with replica 0 is recognized as a JAX coordinator?
IIUC, we do not deploy the dedicated JAX Coordinator replicas, right?
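
To illustrate the coordinator point raised here, the sketch below shows one way a worker script might derive its `jax.distributed.initialize` arguments, with process 0 acting as the coordinator and no dedicated coordinator replica. The environment variable names (`PROCESS_ID`, `NUM_PROCESSES`, `COORDINATOR_ADDRESS`, `COORDINATOR_PORT`) are assumptions about what the operator injects into each pod and should be checked against the training-operator source. The `jax_cpu_collectives_implementation` option is likewise assumed to be available in the installed JAX version. The helper names are hypothetical.

```python
import os


def resolve_init_args(env=None):
    """Collect jax.distributed.initialize arguments from the environment.

    The variable names are assumed, not confirmed by this PR thread.
    """
    env = os.environ if env is None else env
    return {
        "coordinator_address": f'{env["COORDINATOR_ADDRESS"]}:{env["COORDINATOR_PORT"]}',
        "num_processes": int(env["NUM_PROCESSES"]),
        "process_id": int(env["PROCESS_ID"]),
    }


def is_coordinator(args):
    """In JAX, process 0 starts the coordinator service for the job."""
    return args["process_id"] == 0


def run_allreduce(env=None):
    """Initialize multi-process JAX on CPU and compute an all-reduce sum.

    jax is imported lazily so the helpers above work without it installed.
    """
    import jax
    import jax.numpy as jnp

    # Assumed option name for selecting gloo as the CPU collectives backend.
    jax.config.update("jax_cpu_collectives_implementation", "gloo")
    jax.distributed.initialize(**resolve_init_args(env))
    # One element per local device; psum adds them up across all processes.
    xs = jnp.ones(jax.local_device_count())
    return jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(xs)
```

With this layout every replica runs the same script, and only the process ID decides which pod hosts the coordinator service.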

Signed-off-by: Sandipan Panda <[email protected]>

Member

@andreyvelich andreyvelich left a comment

Looks great, thanks for the update @sandipanpanda 🎉
/lgtm
/hold for others
/assign @kubeflow/wg-training-leads

Member

@tenzen-y tenzen-y left a comment

/lgtm

@andreyvelich
Member

Thanks for this @sandipanpanda!
/hold cancel

@google-oss-prow google-oss-prow bot merged commit 463843e into kubeflow:master Oct 22, 2024
6 checks passed
@sandipanpanda sandipanpanda deleted the add-jax-doc branch October 22, 2024 20:07
6 participants