XLA? #24
@fredlarochelle great question :). We are fully aware of the importance of the XLA ecosystem (JAX, TF-XLA, PT-XLA), and we are actively discussing with Google how to bring Intel GPU to XLA. The direction we aligned on is to use the PJRT plugin interface to integrate Intel GPU into OpenXLA first; we are working on that internally, and I expect we will open up our XLA solution to support JAX in the near future. Please keep monitoring our repo; I will also post an update here when the Intel GPU XLA solution is out. PS: you can check this RFC to learn more about the integration design if you are interested.
Hi @fredlarochelle, good news! Intel Extension for TensorFlow v1.2.0 adopted the PJRT plugin interface to implement an Intel GPU backend for experimental OpenXLA support. You can follow the instructions outlined here to build the necessary XLA extension using bazel and get started with the provided JAX example.
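Once the extension is built and the plugin registered, a quick sanity check could look like the sketch below; the environment-variable name and library path follow the ITEX build docs and may vary between releases:

```python
# Assumes the PJRT plugin was registered beforehand, e.g. via something like
#   export PJRT_NAMES_AND_LIBRARY_PATHS='xpu:/path/to/libitex_xla_extension.so'
# (exact variable and path per the ITEX build instructions).
import jax
import jax.numpy as jnp

print(jax.devices())  # the Intel GPU should show up here

x = jnp.ones((1024, 1024))
y = jnp.dot(x, x).block_until_ready()  # force execution on the device
print(y.shape, y.dtype)
```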
Thank you, and let us know if you have any other questions!
@yehudaorel Awesome! I will set aside some time over the next few days to try this out and definitely get back to you!
@yehudaorel I get that the build is failing. Running

For info, the system is an Intel Xeon E5-2695 v3 with 128 GB of RAM and an Intel Arc A770. Based on previous experience building IPEX, I used the

Do I need to build LLVM as for IPEX? Or should I simply try to build XLA in the same conda env as IPEX, with the already-built LLVM?
Hi, @fredlarochelle
From the build output:
Although the A770 is not mentioned specifically in the docs yet, it is based on the same DG2-512 chip variant as the A730M and Flex 170. Keep in mind that Arc A-series GPU support is experimental and therefore highly sensitive to breaking; here is the list of GPUs that were tested and verified:
Also take a look at the ITEX docs for the build-from-source procedure: https://github.com/intel/intel-extension-for-tensorflow/blob/main/docs/install/how_to_build.md
Apologies for the delay. I have taken a closer look and it looks pretty good so far! While I haven't conducted extensive testing, it seems to be working great. However, I have encountered three problems. First, it doesn't work properly in a Jupyter notebook: whenever I attempt to run anything apart from the imports, the notebook crashes with a message similar to the following in the logs

Additionally, I am experiencing OOM errors on the A770 even at relatively low memory usage (around 3 GB). Finally, I need to take a deeper look at what is going on, but the JAX test suite doesn't want to run at all. Regarding the
Btw, my tests on the A770 were done with the device type set as
Any reason why float16 operations are about 1.3-1.4x faster than bfloat16?
Thanks for your feedback.
Could you share some more details (test case, test code, environment, etc.)?
What's your environment? Could you check it as follows and put the output here?
The output should be similar to:

======================== Check Python ========================
python3.9 is installed.
==================== Check Python Passed =====================
========================== Check OS ==========================
OS ubuntu:22.04 is Supported.
====================== Check OS Passed =======================
====================== Check Tensorflow ======================
tensorflow2.10 is installed.
================== Check Tensorflow Passed ===================
=================== Check Intel GPU Driver ===================
Intel(R) graphics runtime intel-level-zero-gpu-1.3.25593.18-601 is installed, but is not recommended 1.3.24595.35+i538.
The script

For my environment, I am on Ubuntu 22.04 with an A770, in a conda env with Python 3.10. Neither drivers nor oneAPI should be an issue here; I am building and running IPEX. As an alternative to your
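(The alternative itself did not survive formatting; purely for illustration, a minimal Python-side version check along these lines, with every detail here an assumption, could be:)

```python
# Illustrative only: report the versions the ITEX/JAX stack depends on,
# so they can be compared against the supported-configuration table.
import sys

print("python:", sys.version.split()[0])

import tensorflow as tf
print("tensorflow:", tf.__version__)

import jax
print("jax:", jax.__version__)
print("jax devices:", jax.devices())
```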
As for more information: nothing runs in a Jupyter notebook, and the Jupyter logs say the error is
For the OOM error, here is a quick matmul TFLOPS test I wrote. Setting the
Finally, for the performance disparity between

EDIT: I haven't yet taken a look at the issues I am having running the JAX test suite.
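Since the snippet above was lost in formatting, here is a hypothetical sketch of a matmul TFLOPS micro-benchmark of the kind described, also timing float16 against bfloat16 for the earlier question; the matrix size, iteration count, and helper name are all assumptions:

```python
import time

import jax
import jax.numpy as jnp

def matmul_tflops(dtype, n=4096, iters=20):
    """Rough TFLOPS estimate for an n x n matmul in the given dtype."""
    key = jax.random.PRNGKey(0)
    a = jax.random.normal(key, (n, n)).astype(dtype)
    b = jax.random.normal(key, (n, n)).astype(dtype)
    f = jax.jit(jnp.matmul)
    f(a, b).block_until_ready()  # warm-up: trigger compilation
    start = time.perf_counter()
    for _ in range(iters):
        c = f(a, b)
    c.block_until_ready()        # wait for the async dispatches to finish
    elapsed = time.perf_counter() - start
    return 2 * n**3 * iters / elapsed / 1e12  # ~2*n^3 FLOPs per matmul

for dtype in (jnp.float16, jnp.bfloat16):
    print(dtype.__name__, f"{matmul_tflops(dtype):.2f} TFLOPS")
```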
Hi, I encountered another weird bug where importing either
Has Flax been tested? I am encountering some issues while testing it. For instance, just creating a small model:
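(The snippet itself was lost in formatting; a hypothetical small model of the kind described, just enough to exercise Flax parameter initialization on the device, might look like this:)

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class SmallModel(nn.Module):
    """Two dense layers with a ReLU in between; purely illustrative."""

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=64)(x)
        x = nn.relu(x)
        return nn.Dense(features=10)(x)

model = SmallModel()
# Initializing the parameters is already enough to hit the device backend.
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 32)))
print(jax.tree_util.tree_map(lambda p: p.shape, params))
```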
I get the following error, while it works fine in Colab:
Also, to avoid dependency conflicts, I found that it is important to include the Flax installation in the same pip command as JAX. For instance, I used the following command to install the required packages

EDIT: Should I just start opening new issues for every bug I find?
I solved the issue I was having with the JAX test suite by executing

The tests can't be run in parallel without crashing from Linux killing the process due to OOM errors, even when set to only 2 workers on a system that has 64 GB of RAM. It also somewhat messes with the system, which needed a reboot after every try. I was able to start a test run with a single worker (it took around 1 h) after disabling

Overall, not too bad!
@fredlarochelle Thank you for your feedback. Your Arc A770 has 16 GB of memory, so it is a problem that "the OOM error randomly happens with error saying it failed trying to allocate 2.92 GiB". In the output of

there are several memory allocations of various sizes, but the deallocations are 0. I am not sure Python reclaimed the memory in time.
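If the missing deallocations really are Python references outliving their usefulness, a sketch like the following (names and sizes are illustrative) can help rule that out by dropping buffers explicitly between steps:

```python
import gc

import jax.numpy as jnp

x = jnp.ones((8192, 8192))               # ~256 MB of float32 on the device
y = jnp.dot(x, x).block_until_ready()
print(y[0, 0])

del x, y      # drop the Python references to the device buffers...
gc.collect()  # ...and nudge the collector so the backing memory can be freed
```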
Until now, Flex GPUs are officially supported and fully verified, but Arc is experimental. Here is my test environment:
I appear to have discovered a potential solution to the OOM errors encountered during the matmul test; thus far, I haven't been able to reproduce the error using the following approach. Instead of utilizing the

Upon this change, I briefly retested the previously identified problems, and unfortunately they still persist. Regarding the JAX test suite, it now appears able to run in parallel; however, numerous tests that pass with a single worker are failing. In terms of the operating system, I am currently running Ubuntu Server. To clarify, should I switch to Ubuntu Desktop or stay with Ubuntu Server? Technically, the kernel remains largely the same between the two versions, with the main distinction being the absence of a desktop environment in the server edition.
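For what it's worth, one knob that may be worth ruling out in this kind of OOM investigation is JAX's device-memory preallocation; whether the ITEX PJRT plugin honors this environment variable is an assumption, not something confirmed in this thread:

```python
import os

# Stock JAX reads this before backend initialization, so it must be set
# before jax is imported; "false" allocates on demand instead of grabbing
# most of device memory up front.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax
import jax.numpy as jnp

x = jnp.ones((4096, 4096))
print(jnp.dot(x, x).block_until_ready().shape)
```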
@fredlarochelle Could you summarize which issues are resolved and which are still open?
It works without any issue for the same scenario on Colab (Colab JAX, an older version). Status: open or resolved?
From the message, it seems the class has changed; it should be a version-mismatch issue. Status: open or closed?
Here is the summary, and I have added a couple:
Other than that, from my still-limited testing, it's impressive; performance seems pretty good.
Hi @fredlarochelle, thanks for your feedback. This project focuses on Intel Extension for TensorFlow, and there are currently lots of TensorFlow features to enable on Flex/Arc GPUs. We will try our best to help you resolve issues on this platform/hardware. Here are my suggestions:
Could you check whether other Python applications work in a Jupyter notebook on your system?
Could this be reproduced on an Nvidia platform? I just want to make sure whether or not it is an Intel Extension for TensorFlow issue.
Looks like the data structure changed and some members are not available. Version mismatch.
Intel Extension for TensorFlow (ITEX) depends on TensorFlow; TensorFlow is a must. Won't fix.
Keep it at low priority for us for now.
Keep it at low priority unless a specific issue with ITEX is identified.
@feng-intel, do you have a plan to support newer JAX releases?
For these software configurations and JAX test suite coverage, we have no resources to provide support currently. If you have a specific ITEX issue, for example after looking into a failed JAX test case that succeeds on another platform but fails on ITEX, please report it. As for this XLA support request, it is supported now. Can I close this issue?
Let's close it.
As part of this extension, is any work being done on adding XLA support for Intel GPUs? With XLA support, Intel GPUs could work with the whole JAX ecosystem.