This page aims to complement the Profiling JAX programs page in the main JAX documentation with advice specific to profiling JAX programs running on NVIDIA GPUs.
As mentioned on that page, the NVIDIA Nsight tools can be used to profile JAX programs on GPU.
The two tools that are most likely to be relevant are Nsight Systems and Nsight Compute.
Nsight Systems provides a high-level overview of activity on the CPU and GPU, and is the best place to start investigating the performance of your program. It has low overhead and should not significantly affect the execution time of your program.
Nsight Compute, on the other hand, enables detailed performance analysis of individual GPU kernels. It repeatedly re-runs the kernel(s) in question to collect different metrics, resulting in an overall program execution time that is much slower. This is a powerful tool to use if you have identified specific GPU kernels that are executing surprisingly slowly. This document does not currently describe its use in any detail; more information is available in the documentation.
The JAX-Toolbox containers already contain the most recent version of Nsight Systems.
You can also install it yourself from here, or use the package repositories here.
To collect a profile, simply launch your program inside nsys, for example:
```shell
$ nsys profile --cuda-graph-trace=node python my_script.py
```
This will produce an .nsys-rep file, by default report1.nsys-rep.
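If you do not yet have a program to try this on, a minimal, hypothetical my_script.py like the one below is enough to put GPU kernels on the timeline. It is only a sketch: it assumes JAX is installed with GPU support (it also runs on CPU, but then the GPU rows of the profile stay empty), and the step and main functions are illustrative names, not part of any API.

```python
import jax
import jax.numpy as jnp


@jax.jit
def step(x):
    # A small amount of work so that some GPU kernels show up in the profile.
    return jnp.tanh(x @ x.T).sum()


def main(n_iters=10, size=1024):
    x = jnp.ones((size, size), dtype=jnp.float32)
    result = None
    for _ in range(n_iters):
        result = step(x)
    # JAX dispatches work asynchronously; wait for it before returning.
    result.block_until_ready()
    return result


if __name__ == "__main__":
    print(main())
```

The first call to step includes JIT compilation, which is visible as a long initial gap in the profile; later calls reuse the compiled executable.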
When collecting profiles from a multi-process program, the simplest approach is to collect one report per process and start by analysing only one of them. Simple JAX programs follow an SPMD (single program, multiple data) model, meaning that the reports should contain similar data. Nsight Systems also supports multi-report analysis if you need to drill into performance differences between ranks.
A good starting point is to open the report file in the Nsight Systems GUI. This can be done in a few different ways.
A common workflow is to collect profiles on a remote system that has attached GPUs, and then download the report files to your local machine to view them. The Nsight Systems GUI supports Linux, macOS and Windows. This is a good option if your network connection to the remote system is slow or has high latency, or if you can only allocate GPU resources for a short time.
If you want to run your JAX program and the GUI on the same system, you can launch the program directly from inside the Nsight Systems GUI on a GPU attached to that machine, as documented here.
Other setups use VNC or WebRTC to stream the GUI from a remote machine, which avoids having to download the report files by hand. Documentation is available here.
If your JAX Python program is structured in a way that leads to deep Python call stacks, for example because you have a lot of wrapper layers and indirection, or because you use a framework that adds similar layers, the default number of call stack frames recorded in the metadata by JAX may be too small. You can remove this limit by setting:
```python
import jax
# Make sure NVTX annotations include full Python stack traces
jax.config.update("jax_traceback_in_locations_limit", -1)
```
Alternatively, set the JAX_TRACEBACK_IN_LOCATIONS_LIMIT environment variable. At the time of writing, the default limit is 10 frames.
If the limit is reached, the text formatting of merged stack traces will not work as expected.
While it is possible to record profiles of the entire application (as above), this is often not the best choice. JAX programs are usually highly repetitive and incur non-trivial JIT compilation and one-off initialisation costs, so it is often enough to record only a few iterations, which may be very fast compared to the JIT overhead. In this case, enabling profile collection only for the iterations of interest is more efficient.
To illustrate this, consider the following JAX example (mnist_vae.py):
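```python
# Excerpt from mnist_vae.py: opt_init, init_params, run_epoch, evaluate,
# train_images and test_images are defined elsewhere in the script.
opt_state = opt_init(init_params)
for epoch in range(num_epochs):
    tic = time.time()
    opt_state = run_epoch(random.PRNGKey(epoch), opt_state, train_images)
    test_elbo, sampled_images = evaluate(opt_state, test_images)
    print(f"{epoch: 3d} {test_elbo} ({time.time() - tic:.3f} sec)")
```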
where num_epochs = 100 by default.
Running this example prints something like:
```text
0 -124.1731185913086 (1.472 sec)
1 -116.52528381347656 (0.382 sec)
2 -113.37870025634766 (0.382 sec)
3 -110.11742401123047 (0.381 sec)
4 -110.05367279052734 (0.382 sec)
...
```
At a minimum, then, we should skip the first iteration, which includes the JIT compilation overhead, to get representative performance numbers.
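The size of this effect can be checked directly from the printed timings. The sketch below uses a hypothetical helper (standard library only) that discards warm-up epochs before averaging; with the five timings shown above, including the first epoch inflates the mean from about 0.382 s to about 0.600 s.

```python
from statistics import mean


def steady_state_mean(epoch_times, warmup=1):
    """Average epoch time after discarding the first `warmup` epochs."""
    if len(epoch_times) <= warmup:
        raise ValueError("need more epochs than warm-up epochs")
    return mean(epoch_times[warmup:])


# Epoch times (in seconds) taken from the mnist_vae.py output above.
times = [1.472, 0.382, 0.382, 0.381, 0.382]
print(f"all epochs:   {mean(times):.3f} s")
print(f"steady state: {steady_state_mean(times):.3f} s")
```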
One way of excluding these first iterations from the profile is to use the CUDA profiler API:
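```python
from ctypes import cdll

# Depending on the installation, the CUDA runtime may only be available under
# a versioned name such as libcudart.so.12.
libcudart = cdll.LoadLibrary('libcudart.so')
for epoch in range(num_epochs):
    if epoch == 2:
        # Start capturing once the JIT-heavy first iterations are done.
        libcudart.cudaProfilerStart()
    tic = time.time()
    ...
libcudart.cudaProfilerStop()
```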
Also reduce the number of epochs, for example to num_epochs = 5.
If we then tell nsys to listen to the CUDA profiler API, with a command like:
```shell
$ PYTHONPATH=/opt/jax nsys profile --capture-range=cudaProfilerApi --cuda-graph-trace=node --capture-range-end=stop python /opt/jax/examples/mnist_vae.py
```
then the resulting profile will contain only 3 iterations of the loop (5 in total, minus the 2 skipped).
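Instead of reducing num_epochs, the capture window can also be closed after a fixed number of iterations, letting the rest of the run continue unprofiled (with --capture-range-end=stop; the default stop-shutdown would kill the application instead). The sketch below assumes the same ctypes approach as above; ProfileWindow and load_cudart are hypothetical names, not part of any library, and start/stop are iteration indices of a half-open window.

```python
import ctypes
import ctypes.util


def load_cudart():
    """Return a ctypes handle to the CUDA runtime, or None if none is found."""
    # Try the unversioned name first, then whatever the system linker reports.
    for name in ("libcudart.so", ctypes.util.find_library("cudart")):
        if name is None:
            continue
        try:
            return ctypes.CDLL(name)
        except OSError:
            pass
    return None


class ProfileWindow:
    """Enable CUDA profiler capture for iterations start <= i < stop."""

    def __init__(self, cudart, start, stop):
        if not 0 <= start < stop:
            raise ValueError("need 0 <= start < stop")
        self.cudart = cudart
        self.start = start
        self.stop = stop
        self.active = False

    def step(self, i):
        """Call at the top of iteration i."""
        if self.cudart is None:
            return  # no CUDA runtime: profiling calls become no-ops
        if i == self.start and not self.active:
            self.cudart.cudaProfilerStart()
            self.active = True
        elif i == self.stop and self.active:
            self.cudart.cudaProfilerStop()
            self.active = False

    def close(self):
        """Stop capturing if the loop ended inside the window."""
        if self.active:
            self.cudart.cudaProfilerStop()
            self.active = False
```

In the training loop, create window = ProfileWindow(load_cudart(), start=2, stop=5), call window.step(epoch) at the top of each iteration and window.close() after the loop. Because JAX dispatches work asynchronously, GPU work from the last profiled iteration may still be running when cudaProfilerStop() is called; calling .block_until_ready() on that iteration's outputs first is a conservative option.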
With --capture-range-end=stop, nsys will stop collecting profile data at cudaProfilerStop() and ignore later calls to cudaProfilerStart(), but it will not kill the application.
The default value, stop-shutdown, will kill the application after cudaProfilerStop(); in this case, buffered output is sometimes not flushed to the console.
If you need to start and stop profiling multiple times in your application, you can pass --capture-range-end=repeat; in this case, a separate report file will be written for each start-stop pair.
Documentation can be found here.
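If profiling runs are launched from a Python driver script, a small helper can assemble the nsys command line and catch typos in the --capture-range-end value before a long job starts. This is a sketch: nsys_profile_argv is a hypothetical name, and it deliberately accepts only the three values discussed above, although nsys itself supports more.

```python
# Only the --capture-range-end values discussed in this section.
CAPTURE_RANGE_END_VALUES = {"stop", "stop-shutdown", "repeat"}


def nsys_profile_argv(program, capture_range_end="stop", output=None):
    """Build an argv list for `nsys profile` gated on the CUDA profiler API."""
    if capture_range_end not in CAPTURE_RANGE_END_VALUES:
        raise ValueError(f"unsupported --capture-range-end value: {capture_range_end!r}")
    argv = [
        "nsys", "profile",
        "--capture-range=cudaProfilerApi",
        f"--capture-range-end={capture_range_end}",
        "--cuda-graph-trace=node",
    ]
    if output is not None:
        argv.append(f"--output={output}")
    return argv + list(program)
```

The result can be passed straight to subprocess.run(..., check=True); returning a list rather than a single string avoids shell-quoting problems with paths that contain spaces.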
The example in the previous section yields a profile like:
The lower part of the screen (under "Threads (9)") shows the CPU timeline, while the upper part (under "CUDA HW") shows the GPU timeline. The "TSL" (CPU) and "NVTX (TSL)" (GPU) rows show annotations generated by JAX via XLA. Each "XlaModule" range corresponds to a call of a JITed JAX function, with the nested "Thunk" ranges providing more granular detail.
Zooming in on the profile, we can clearly see the latency between kernel launches and their execution. These correlations are shown by the light blue highlighted regions when you select a kernel or NVTX marker:
We can also see that JAX is using CUDA graphs, both from the cuGraph* calls in the CUDA API row and from the coloured outlines of kernels in the CUDA HW rows.
JAX's (XLA's) usage of CUDA graphs is not currently fully supported by the Nsight Systems UI, which leads to some missing detail in the annotations for CUDA graph nodes. This is shown by the magenta region in the figure above, and will be fixed in a future version of Nsight Systems.
More complete annotations can be obtained by adding --xla_gpu_enable_command_buffer= (with an empty value) to the XLA_FLAGS environment variable when collecting the profile, which disables the use of CUDA graphs.
Depending on the JAX program, you will probably see a small slowdown when graphs are disabled; keep the size of this effect in mind when interpreting profiles of your program.
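If you set this from a Python wrapper, it is worth appending the flag rather than replacing XLA_FLAGS, so that flags you already rely on are kept. A minimal sketch (with_xla_flag is a hypothetical helper); the variable needs to be set before JAX initialises its GPU backend, and setting it before importing jax is the simplest way to ensure that.

```python
import os


def with_xla_flag(env, flag):
    """Return a copy of env with `flag` appended to XLA_FLAGS, keeping existing flags."""
    new_env = dict(env)
    flags = new_env.get("XLA_FLAGS", "").split()
    if flag not in flags:
        flags.append(flag)
    new_env["XLA_FLAGS"] = " ".join(flags)
    return new_env


if __name__ == "__main__":
    # Disable CUDA graphs (command buffers) for this process, before importing jax.
    os.environ["XLA_FLAGS"] = with_xla_flag(
        os.environ, "--xla_gpu_enable_command_buffer="
    )["XLA_FLAGS"]
```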
Without CUDA graphs, metadata should be available for all kernels in the GPU timeline:
The tooltip contains information about the lines of your JAX program's Python source code that led to this kernel being emitted, as well as the relevant HLO code. This page may help to understand the HLO code. Note that there are two different HLO fields in the tooltip: "HLO" and "Called HLO", where in this example the latter is empty. In the case of fused kernels, the "Called HLO" field shows the body of the fused computation.
If you double-click on an NVTX region in the timeline it will open in the Events View in the lower part of the screen, with the tooltip content shown in the bottom right:
If you have previously opened a different row from the timeline in the Events View then double-clicking on a new row may show a message "A selected event does not exist in the current Events View..."; follow the instructions in the message to get the view shown in the screenshot.
The annotations described above are NVTX ranges (in the "TSL" domain) emitted by JAX via XLA.
You can also add your own custom NVTX ranges using the nvtx Python bindings. If these are not already installed, pip install nvtx will install them.
A simple way of using these bindings is as a Python context manager:
```python
import nvtx

for _ in range(3):
    with nvtx.annotate("MyRange"):
        call_some_jax_code()  # placeholder for your own code
```
This produces three ranges called MyRange under the default NVTX domain in the Nsight Systems GUI. Complete documentation can be found here.
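Using nvtx functions inside JITed JAX code is not supported and will not yield the expected results: Python code inside a jitted function runs only while the function is being traced, not each time the compiled function executes. The bindings are therefore only useful for high-level annotations outside JIT regions.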
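A slightly fuller sketch, assuming the nvtx package from PyPI: it opens one range per training step in a custom domain, always outside any jitted function, and falls back to a no-op when nvtx is not installed. The domain name "MyApp" and the train helper are hypothetical. Because JAX dispatches GPU work asynchronously, such a range on the CPU timeline may close before the GPU work launched inside it has finished.

```python
import contextlib

try:
    import nvtx
    annotate = nvtx.annotate
except ImportError:
    # Fallback so the script still runs where the nvtx bindings are missing.
    def annotate(message=None, color=None, domain=None):
        return contextlib.nullcontext()


def train(num_steps, step_fn, state):
    for step in range(num_steps):
        # One NVTX range per step, opened outside the (possibly jitted) step_fn.
        with annotate(f"step {step}", domain="MyApp"):
            state = step_fn(state)
    return state
```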
Inside JIT regions you can use jax.named_scope and jax.named_call. These do not generate NVTX ranges, but they do allow you to add custom levels to the name stack shown in the metadata emitted by XLA, i.e. names like while/body/transpose[permutation=(1, 0)] shown in the screenshot above.
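A minimal sketch of both APIs inside a jitted function, using a hypothetical normalise-then-project model. When profiled, the names "normalize" and "project" should appear as extra levels of the name stack in the XLA metadata, and hence in the kernel tooltips, rather than as separate NVTX ranges.

```python
import jax
import jax.numpy as jnp


def normalize(x):
    # Ops created in this block get a "normalize" level in their name stack.
    with jax.named_scope("normalize"):
        return (x - x.mean()) / (x.std() + 1e-6)


# jax.named_call wraps a function so that its ops appear under the given name.
project = jax.named_call(lambda x, w: x @ w, name="project")


@jax.jit
def model(x, w):
    return project(normalize(x), w)
```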
The containers published from this repository (ghcr.io/nvidia/jax:XXX) now include an additional wrapper, nsys-jax, to help with collecting Nsight Systems profiles of JAX programs.
Loosely, this corresponds to nsys profile above, i.e. simply run nsys-jax python my_program.py. If you want to pass additional options to nsys profile, the syntax is nsys-jax [nsys profile options] -- python my_program.py; the -- is compulsory.
It is usually a good idea to give the profiles meaningful names using the --output (-o) option of nsys profile.
nsys-jax will read the value of this option and save extra metadata under the same prefix, with the restriction that only %q{ENV_VAR} expansions are supported. An example when using the Slurm job orchestrator is:
```shell
nsys-jax -o /out/job%q{SLURM_JOB_ID}/step%q{SLURM_STEP_ID}/rank%q{SLURM_PROCID} -- python my_program.py
```
which, with SLURM_JOB_ID=42, SLURM_STEP_ID=7 and SLURM_PROCID=0, results in an output archive /out/job42/step7/rank0.zip that contains rank0.nsys-rep and other metadata.
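To predict where the archive for a given rank will end up (for example, in a later post-processing job step), the %q{...} expansion can be mimicked in a few lines. This sketch assumes simple variable names, and treating an unset variable as an empty string is an assumption that may not match nsys exactly; expand_q is a hypothetical helper.

```python
import re

# Matches %q{VAR} where VAR is a conventional environment variable name.
_Q_PATTERN = re.compile(r"%q\{([A-Za-z_][A-Za-z0-9_]*)\}")


def expand_q(pattern, env):
    """Replace each %q{VAR} with env[VAR]; unset variables become "" (assumption)."""
    return _Q_PATTERN.sub(lambda m: env.get(m.group(1), ""), pattern)
```

For example, expand_q(prefix, os.environ) + ".zip" gives the expected archive path inside the job that produced it.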
As well as running nsys profile, nsys-jax automatically sets some of the configuration variables mentioned above, such as JAX_TRACEBACK_IN_LOCATIONS_LIMIT, and sets XLA flags requesting that metadata be saved in Protobuf format.
Important: because nsys-jax manipulates the XLA_FLAGS environment variable, you must make sure that this variable is not overwritten inside the executable that you pass. For example, nsys-jax python my_program.py is fine, but nsys-jax my_script_to_overwrite_xla_flags_and_run_my_program.sh may not be.
The only XLA flag that nsys-jax will overwrite is --xla_dump_to, which sets the output directory for the Protobuf metadata. nsys-jax additionally changes the default value of --xla_dump_hlo_as_proto to true, but will not modify this flag if it has been set explicitly.
Note: because the Protobuf metadata is written at compilation time, using the JAX persistent compilation cache prevents it from being written reliably. Because of this, nsys-jax sets JAX_ENABLE_COMPILATION_CACHE to false if it is not explicitly set.
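The flag and environment handling described in the last two paragraphs can be summarised as a small function. This is an illustrative re-implementation of the described behaviour, not nsys-jax's actual code: nsys_jax_like_env is a hypothetical name, and it leaves out JAX_TRACEBACK_IN_LOCATIONS_LIMIT because the exact value nsys-jax uses is not stated here.

```python
def _is_flag(token, name):
    """True if token is `name` or `name=value`."""
    return token == name or token.startswith(name + "=")


def nsys_jax_like_env(env, dump_dir):
    """Illustrative only: mimic the XLA_FLAGS / JAX variable handling described above."""
    new_env = dict(env)
    # --xla_dump_to is the one flag that is always overwritten.
    flags = [t for t in new_env.get("XLA_FLAGS", "").split()
             if not _is_flag(t, "--xla_dump_to")]
    flags.append(f"--xla_dump_to={dump_dir}")
    # Only the *default* of --xla_dump_hlo_as_proto changes; explicit settings win.
    if not any(_is_flag(t, "--xla_dump_hlo_as_proto") for t in flags):
        flags.append("--xla_dump_hlo_as_proto=true")
    new_env["XLA_FLAGS"] = " ".join(flags)
    # Disable the persistent compilation cache unless it was set explicitly.
    new_env.setdefault("JAX_ENABLE_COMPILATION_CACHE", "false")
    return new_env
```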
After collecting the Nsight Systems profile, nsys-jax triggers two extra processing steps:
- the .nsys-rep file is converted into a .parquet and a .csv.xz file for offline analysis
- the metadata dumped by XLA is scanned for references to Python source code files, i.e. your JAX program and the Python libraries on which it depends. Those files are copied to the output archive.
Finally, a compressed .zip archive is generated. The post-processing uses a local, temporary directory; only the final archive is written to the given output location, which is likely to be on slower, shared storage.
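For quick scripted checks without extracting everything, the archive can also be read with the Python standard library. A minimal sketch: list_members and read_csv_xz_member are hypothetical helpers, and the member names and CSV columns depend on the archive, so list the members first rather than assuming a layout.

```python
import csv
import io
import lzma
import zipfile


def list_members(archive_path):
    """Return the names of all files stored in an nsys-jax .zip archive."""
    with zipfile.ZipFile(archive_path) as zf:
        return zf.namelist()


def read_csv_xz_member(archive_path, member):
    """Decompress one .csv.xz member and return its rows as dicts."""
    with zipfile.ZipFile(archive_path) as zf:
        compressed = zf.read(member)
    text = lzma.decompress(compressed).decode("utf-8")
    return list(csv.DictReader(io.StringIO(text)))
```

For example, [m for m in list_members(path) if m.endswith(".csv.xz")] finds the candidate members to pass to read_csv_xz_member.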
Copy an nsys-jax archive to an interactive system and extract it. At the top level there is an install.sh script that creates a Python virtual environment containing Jupyter Lab and the dependencies of the Analysis.ipynb notebook, which is also distributed in the archive. Run this script, then run the Jupyter Lab launch command that it suggests.
The included notebook is intended to be a template for programmatic analysis of the profile data in conjunction with the metadata from XLA. Out of the box it will provide some basic summaries and visualisations:
Examples include summaries of compilation time, heap memory usage, and straggler analysis of multi-GPU jobs.
You can see a rendered example of this notebook, as generated from the main branch of this repository, here: https://gist.github.com/nvjax/e2cd3520201caab6b67385ed36fad3c1#file-analysis-ipynb.
Note: this code should be considered unstable; the bundled notebook and its input data format may change considerably. It should nevertheless provide a useful playground in which to experiment with your own profile data.