Profiling JAX programs on GPU

This page aims to complement the Profiling JAX programs page in the main JAX documentation with advice specific to profiling JAX programs running on NVIDIA GPUs.

As mentioned on that page, the NVIDIA Nsight tools can be used to profile JAX programs on GPU.

The two tools that are most likely to be relevant are Nsight Systems and Nsight Compute.

Nsight Systems provides a high-level overview of activity on the CPU and GPU, and is the best place to start investigating the performance of your program. It has low overhead and should not significantly affect the execution time of your program.

Nsight Compute, on the other hand, enables detailed performance analysis of individual GPU kernels. It repeatedly re-runs the kernel(s) in question to collect different metrics, which makes the overall program execution much slower. It is a powerful tool to use if you have identified specific GPU kernels that are executing surprisingly slowly. This document does not currently describe its use in any detail; more information is available in the documentation.

Nsight Systems

The JAX-Toolbox containers already include the most recent version of Nsight Systems. You can also install it yourself from here, or use the package repositories here. To collect a profile, simply launch your program under nsys, for example:

$ nsys profile --cuda-graph-trace=node python my_script.py

This will produce an .nsys-rep file, by default report1.nsys-rep.

When collecting profiles from a multi-process program, the simplest approach is to collect one report per process and start by analysing only one of them. Simple JAX programs follow an SPMD model, meaning that the reports should contain similar data. Nsight Systems also supports multi-report analysis if you need to drill into differences in performance between ranks.
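For example, under the Slurm launcher (which sets SLURM_PROCID for each process), per-rank report names can be generated using nsys's %q{ENV_VAR} expansion; this is a sketch assuming a Slurm-launched job:

$ srun nsys profile -o report_rank%q{SLURM_PROCID} --cuda-graph-trace=node python my_script.py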

Opening report files

A good starting point is to open the report file in the Nsight Systems GUI. This can be done in a few different ways.

Running the GUI on a local system

A common workflow is to collect profiles on a remote system that has attached GPUs, and then download the report files to your local machine to view them. The Nsight Systems GUI supports Linux, macOS and Windows. This is a good option if your network connection to the remote system is slow or high latency, or if you can only allocate GPU resources for a short time.

Running everything on a local system

If you want to run your JAX program and the GUI on the same system, you can launch the program directly from inside the Nsight Systems GUI on a GPU attached to the same machine, as documented here.

Other configurations

Other configurations are also possible, such as using VNC or WebRTC to stream the GUI from a remote machine. This avoids having to download the report files by hand. Documentation is available here.

Tuning JAX configuration for profiling

If your JAX Python program is structured in a way that leads to deep Python call stacks, for example because you have many wrapper layers and indirection, or because you use a framework that adds such layers, the default number of call stack frames that JAX records in the metadata may be too small. You can remove this limit by setting:

import jax
# Make sure NVTX annotations include full Python stack traces
jax.config.update("jax_traceback_in_locations_limit", -1)

or by setting the JAX_TRACEBACK_IN_LOCATIONS_LIMIT environment variable. At the time of writing, the default limit is 10 frames. If the limit is reached, the text formatting of merged stack traces will not work as expected.
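For example, the environment variable can be set on the command line when collecting the profile:

$ JAX_TRACEBACK_IN_LOCATIONS_LIMIT=-1 nsys profile --cuda-graph-trace=node python my_script.py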

Collecting targeted profiles

While it is possible to record a profile of the entire application (as above), this is often not the best choice. The execution of JAX programs is typically quite repetitive, and JIT compilation and one-off initialisation add non-trivial cost, so the iterations of interest may account for only a small fraction of the total runtime. In such cases, it is more efficient to enable profile collection only for the iterations of interest.

To illustrate this, consider the following JAX example (mnist_vae.py):

  opt_state = opt_init(init_params)
  for epoch in range(num_epochs):
    tic = time.time()
    opt_state = run_epoch(random.PRNGKey(epoch), opt_state, train_images)
    test_elbo, sampled_images = evaluate(opt_state, test_images)
    print(f"{epoch: 3d} {test_elbo} ({time.time() - tic:.3f} sec)")

where by default we have num_epochs = 100 (link).

Running this example prints something like

  0 -124.1731185913086 (1.472 sec)
  1 -116.52528381347656 (0.382 sec)
  2 -113.37870025634766 (0.382 sec)
  3 -110.11742401123047 (0.381 sec)
  4 -110.05367279052734 (0.382 sec)
...

so at a minimum we should skip the first iteration, which includes the JIT compilation overhead, to get representative performance numbers.

One way of doing this is to use the CUDA profiler API:

from ctypes import cdll

# Load the CUDA runtime library, which exposes the profiler control API
libcudart = cdll.LoadLibrary('libcudart.so')
for epoch in range(num_epochs):
  # Skip the first two epochs, which include JIT compilation overhead
  if epoch == 2: libcudart.cudaProfilerStart()
  tic = time.time()
  ...
libcudart.cudaProfilerStop()

and reduce the total number of epochs, for example to num_epochs = 5.

If we then tell nsys to listen to the CUDA profiler API, with a command like:

$ PYTHONPATH=/opt/jax nsys profile --capture-range=cudaProfilerApi --cuda-graph-trace=node --capture-range-end=stop python /opt/jax/examples/mnist_vae.py

then the resulting profile will only contain 3 iterations of the loop (5 total - 2 skipped).

With --capture-range-end=stop, nsys will stop collecting profile data at cudaProfilerStop() and ignore later calls to cudaProfilerStart(), but it will not kill the application. The default value, stop-shutdown, will kill the application after cudaProfilerStop(); in this case, buffered output is sometimes not flushed to the console. If you need to start and stop profiling multiple times in your application, you can pass repeat; in this case, a different report file will be written for each start-stop pair. Documentation can be found here.
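If you use repeat, a small context-manager helper can make the start/stop pairs explicit. This is a minimal sketch, not part of JAX or Nsight Systems; cuda_profiler_range and the run_phase_* functions are hypothetical names:

from contextlib import contextmanager
from ctypes import cdll

libcudart = cdll.LoadLibrary('libcudart.so')

@contextmanager
def cuda_profiler_range():
  # Hypothetical helper: with --capture-range-end=repeat, each
  # start/stop pair produces a separate report file.
  libcudart.cudaProfilerStart()
  try:
    yield
  finally:
    libcudart.cudaProfilerStop()

with cuda_profiler_range():
  run_phase_one()  # placeholder for your own JAX code
with cuda_profiler_range():
  run_phase_two()  # placeholder for your own JAX code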

Understanding the Nsight Systems timeline

The example in the previous section yields a profile like the following:

[Figure: Nsight Systems GUI showing 3 iterations of the mnist_vae.py JAX example]

The lower part of the screen (under "Threads (9)") shows the CPU timeline, while the upper part (under "CUDA HW") shows the GPU timeline. The "TSL" (CPU) and "NVTX (TSL)" (GPU) rows show annotations generated by JAX via XLA. Each "XlaModule" range corresponds to a call to a JITed JAX function, with the nested "Thunk" ranges providing more granular detail.

Zooming in on the profile, we can clearly see the latency between kernel launches and their execution. These correlations are shown by the light blue highlighted regions when you select a kernel or NVTX marker:

[Figure: Nsight Systems GUI showing the launch latency of a particular kernel]

We can also see that JAX is using CUDA graphs, both from the cuGraph* calls in the CUDA API row, and from the coloured outlines of kernels in the CUDA HW rows. JAX's (XLA's) usage of CUDA graphs is not currently fully supported by the Nsight Systems UI, which leads to some missing detail in the annotations for CUDA graph nodes. This is shown by the magenta region in the figure above, and will be fixed in a future version of Nsight Systems.

More complete annotations can be obtained by adding --xla_gpu_enable_command_buffer= to the XLA_FLAGS environment variable when collecting the profile, which disables the use of CUDA graphs. Depending on the JAX program, you will probably see a small slowdown when graphs are disabled; it is worth quantifying the scale of this effect for your program.
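For example, a sketch assuming any other XLA flags you have set should be preserved:

$ XLA_FLAGS="${XLA_FLAGS:-} --xla_gpu_enable_command_buffer=" nsys profile --cuda-graph-trace=node python my_script.py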

Without CUDA graphs, metadata should be available for all kernels in the GPU timeline:

[Figure: Nsight Systems GUI showing graph-free execution and a tool-tip]

The tooltip contains information about the lines of your JAX program's Python source code that led to this kernel being emitted, as well as the relevant HLO code. This page may help you understand the HLO code. Note that there are two different HLO fields in the tooltip: "HLO" and "Called HLO"; in this example the latter is empty. In the case of fused kernels, the "Called HLO" field shows the body of the fused computation.

If you double-click on an NVTX region in the timeline, it will open in the Events View in the lower part of the screen, with the tooltip content shown in the bottom right:

[Figure: Nsight Systems GUI showing a tooltip, events view, and description]

If you have previously opened a different row from the timeline in the Events View then double-clicking on a new row may show a message "A selected event does not exist in the current Events View..."; follow the instructions in the message to get the view shown in the screenshot.

Custom NVTX annotations

The annotations described above are NVTX ranges (in the "TSL" domain) emitted by JAX via XLA. You can also add your own custom NVTX ranges using the nvtx Python bindings. If these are not already installed, pip install nvtx will install them. A simple way of using these bindings is as a Python context manager:

import nvtx

for _ in range(3):
  with nvtx.annotate("MyRange"):
    call_some_jax_code()

which will produce three ranges called MyRange under the default NVTX domain in the Nsight Systems GUI. Complete documentation can be found here.
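The bindings can also be used as a decorator, which annotates every call to the decorated function; a sketch re-using the placeholder call_some_jax_code from above:

import nvtx

@nvtx.annotate("MyFunction")
def call_some_jax_code():
  ...  # placeholder for your own code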

Using nvtx functions inside JITed JAX code is not supported and will not yield the expected results, so this only makes sense for high-level annotations outside JIT regions. Inside JIT regions you can use jax.named_scope and jax.named_call. These will not generate NVTX ranges, but they do allow you to add custom levels to the name stack shown in the metadata emitted by XLA, i.e. the names like while/body/transpose[permutation=(1, 0)] shown in the screenshot above.
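For example, a minimal sketch using jax.named_scope inside a JITed function (the function f and the scope name are illustrative):

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  # "my_scope" is appended to the XLA name stack for the operations
  # traced inside this block, and appears in the emitted metadata
  with jax.named_scope("my_scope"):
    return jnp.sin(x) * 2.0

f(jnp.arange(4.0))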

nsys-jax wrapper for Nsight Systems

The containers published from this repository (ghcr.io/nvidia/jax:XXX) now include an additional wrapper to help with collecting Nsight Systems profiles of JAX programs.

Loosely this corresponds to nsys profile above, i.e. simply run nsys-jax python my_program.py. If you want to pass additional options to nsys profile, the syntax is nsys-jax [nsys profile options] -- python my_program.py; the -- is compulsory.

It is usually a good idea to set the profile name to something meaningful using nsys profile's --output option. nsys-jax will read the value of this option and save extra metadata under the same prefix, with the restriction that only %q{ENV_VAR} expansions are supported. An example when using the Slurm job orchestrator is nsys-jax -o /out/job%q{SLURM_JOB_ID}/step%q{SLURM_STEP_ID}/rank%q{SLURM_PROCID} -- python my_program.py, which will result in an output archive /out/job42/step7/rank0.zip that contains rank0.nsys-rep and other metadata.

As well as running nsys profile, this automatically sets some configuration variables mentioned above, such as JAX_TRACEBACK_IN_LOCATIONS_LIMIT, and sets XLA flags requesting that metadata be saved in Protobuf format.

Important: because nsys-jax manipulates the XLA_FLAGS environment variable, you must make sure that this is not overwritten inside the executable that you pass. For example nsys-jax python my_program.py is fine, but nsys-jax my_script_to_overwrite_xla_flags_and_run_my_program.sh may not be.
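For example, a wrapper script that needs to set its own XLA flags should append to any existing value rather than overwrite it; a sketch re-using the graph-disabling flag from above:

# Inside the wrapper script: preserve flags already set by nsys-jax
export XLA_FLAGS="${XLA_FLAGS:-} --xla_gpu_enable_command_buffer="
python my_program.py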

The only XLA flag that nsys-jax will overwrite is --xla_dump_to, which sets the output directory for the Protobuf metadata. nsys-jax additionally changes the default value of --xla_dump_hlo_as_proto (true), but will not modify this if it has been set explicitly.

Note: because the Protobuf metadata is written at compilation time, using the JAX persistent compilation cache prevents it from being written reliably. Because of this, nsys-jax sets JAX_ENABLE_COMPILATION_CACHE to false if it is not explicitly set.

After collecting the Nsight Systems profile, nsys-jax triggers two extra processing steps:

  • the .nsys-rep file is converted into a .parquet and a .csv.xz file for offline analysis
  • the metadata dumped by XLA is scanned for references to Python source code files -- i.e. your JAX program and the Python libraries on which it depends. Those files are copied to the output archive.

Finally, a compressed .zip archive is generated. The post-processing uses a local, temporary directory. Only the final archive is written to the given output location, which is likely to be on slower, shared storage.

Offline analysis

Copy an nsys-jax archive to an interactive system and extract it. At the top level there is an install.sh script that will create a Python virtual environment containing Jupyter Lab and the dependencies of the Analysis.ipynb notebook that is also distributed in the archive. Run this script, then run the suggested launch command for Jupyter Lab.
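For example, assuming the rank0.zip archive name from the Slurm example above:

$ unzip rank0.zip -d rank0 && cd rank0
$ ./install.sh  # creates the virtual environment and suggests a Jupyter Lab launch command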

The included notebook is intended to be a template for programmatic analysis of the profile data in conjunction with the metadata from XLA. Out of the box it will provide some basic summaries and visualisations:

[Figure: Analysis notebook inside Jupyter Lab showing an interactive flame graph of JAX source code]

Examples include summaries of compilation time, heap memory usage, and straggler analysis of multi-GPU jobs.

You can see a rendered example of this notebook, as generated from the main branch of this repository, here: https://gist.github.com/nvjax/e2cd3520201caab6b67385ed36fad3c1#file-analysis-ipynb.

Note: this code should be considered unstable; the bundled notebook and its input data format may change considerably, but it should provide a useful playground in which to experiment with your own profile data.