Add utility function to benchmark performance of fusion region with nvfuser and torch.compile #1682

Merged: 4 commits merged into Lightning-AI:main on Jan 27, 2025

Conversation

kshitij12345 (Collaborator) commented Jan 22, 2025

Adds a utility function to benchmark the performance of nvFuser and torch.compile for a given fusion region.

Example usage:

import torch
import thunder
from thunder.dev_utils.utils import _benchmark_fusion_region_with_nvfuser_and_torch_compile

def fn(x):
    y = (x * x).sum()
    z = x @ x
    return z.sin() + y.cos()

jfn = thunder.jit(fn, nv_store_fusion_inputs=True)

jfn(torch.randn(16, 16, device="cuda"))

trc = thunder.last_traces(jfn)[-1]

for bsym in trc.bound_symbols:
    if bsym.sym.is_fusion and "nvFusion" in bsym.sym.name:
        benchmark_comparison_data = _benchmark_fusion_region_with_nvfuser_and_torch_compile(bsym)
        nvfuser_walltime = benchmark_comparison_data.nvfuser_walltime
        nvfuser_kerneltime = benchmark_comparison_data.nvfuser_kernel_time
        nvfuser_prof_data = benchmark_comparison_data.nvfuser_profiler_data

        torch_compile_walltime = benchmark_comparison_data.torch_compile_walltime
        torch_compile_kerneltime = benchmark_comparison_data.torch_compile_kernel_time

        print(bsym)
        print(nvfuser_walltime)
        print(torch_compile_walltime)
        print(nvfuser_kerneltime)
        print(torch_compile_kerneltime)
        print(nvfuser_prof_data)
        print(nvfuser_prof_data.kernel_profiles[0].percentage_peak_bandwidth)
        print("----"*32)
Output
[y] = nvFusion0(x)
  # t0 = prims.mul(x, x)  # t0: "cuda:0 f32[16, 16]"
  # y = prims.sum(t0, (0, 1))  # y: "cuda:0 f32[]"
<torch.utils.benchmark.utils.common.Measurement object at 0x7e5bccfce1d0>
nvfuser_callable(*inputs)
  Median: 31.15 us
  IQR:    0.13 us (31.10 to 31.23)
  7 measurements, 10000 runs per measurement, 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7e5bccfcde70>
torch_compile_callable(*inputs)
  Median: 21.16 us
  IQR:    0.29 us (21.03 to 21.32)
  93578 measurements, 1 runs per measurement, 1 thread
0.007135999854654074
0.004095999989658594
<nvfuser._C.FusionProfile object at 0x7e5be85318f0>
0.04780027949585994
--------------------------------------------------------------------------------------------------------------------------------
[t6] = nvFusion1(t8, y)
  # t3 = prims.sin(t8)  # t3: "cuda:0 f32[16, 16]"
  # t4 = prims.cos(y)  # t4: "cuda:0 f32[]"
  # t5 = prims.broadcast_in_dim(t4, (16, 16), ())  # t5: "cuda:0 f32[16, 16]"
  # t6 = prims.add(t3, t5)  # t6: "cuda:0 f32[16, 16]"
<torch.utils.benchmark.utils.common.Measurement object at 0x7e5bb0b43f10>
nvfuser_callable(*inputs)
  Median: 25.20 us
  IQR:    0.11 us (25.15 to 25.25)
  8 measurements, 10000 runs per measurement, 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7e5bccfcd330>
torch_compile_callable(*inputs)
  Median: 24.49 us
  IQR:    0.35 us (24.34 to 24.69)
  80906 measurements, 1 runs per measurement, 1 thread
0.004608000162988901
0.003967999946326017
<nvfuser._C.FusionProfile object at 0x7e5be85318f0>
0.151795899955459
-------------------------------------------------------------------------------------------------------------------------------
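For readers unfamiliar with how wall-clock comparisons like the ones above are gathered, here is a minimal, framework-free sketch of the same idea using only the standard library's `timeit`. The two candidate functions and the input are made up for illustration and are not part of the PR's utility, which times the actual nvFuser and torch.compile callables (and additionally reports CUDA kernel times):

```python
import timeit

def candidate_a(xs):
    # Stand-in for one backend's compiled callable.
    return sum(x * x for x in xs)

def candidate_b(xs):
    # Stand-in for the other backend's compiled callable
    # (same computation, different loop structure).
    total = 0.0
    for x in xs:
        total += x * x
    return total

inputs = [float(i) for i in range(1024)]

# Time each callable over many runs and take the best of several
# repeats, mirroring the Measurement medians printed above.
time_a = min(timeit.repeat(lambda: candidate_a(inputs), number=1000, repeat=5))
time_b = min(timeit.repeat(lambda: candidate_b(inputs), number=1000, repeat=5))

print(f"candidate_a: {time_a:.4f}s, candidate_b: {time_b:.4f}s")
print(f"ratio (a/b): {time_a / time_b:.2f}x")
```

The real utility relies on `torch.utils.benchmark`, which additionally handles CUDA synchronization and adaptively chooses the number of measurements, which is why the two Measurement objects above report different run counts.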

kshitij12345 (Collaborator, Author) commented:
cc: @kiya00 I think this could be useful for automated report generation, would be great to have your review on this!

kiya00 (Collaborator) commented Jan 22, 2025

> cc: @kiya00 I think this could be useful for automated report generation, would be great to have your review on this!

I think that's a great idea! I've prepared an initial version of the report, and we can gradually enrich its content: #1636

crcrpar (Collaborator) left a comment:
Looks useful, though I have some questions that might be orthogonal to this PR itself.

(Resolved review discussions on thunder/dev_utils/utils.py)
mruberry (Collaborator) commented:
This looks ready to take out of draft to me, @kshitij12345, but maybe you have some more ideas

Running this and translating the slower fusions into bugs was incredibly interesting and helpful

FYI @kiya00, @riccardofelluga: I expect we'll want to develop the ability to have torch.compile or thunderfx run different fx graphs, and maybe also the ability to auto-benchmark torch.compile vs nvFuser on each fusion to select the faster of the two. We can talk about this more as we develop the new reporting tools.
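The per-fusion auto-selection idea could, in principle, look like the following framework-free sketch (all names are hypothetical; a real implementation would reuse the CUDA-event-based timings this PR's utility collects rather than `timeit`):

```python
import timeit
from typing import Any, Callable, Sequence

def pick_faster(callables: dict[str, Callable], args: Sequence[Any],
                number: int = 100, repeat: int = 3) -> str:
    """Benchmark each candidate on the same inputs and return the name
    of the fastest one, judged by best-of-`repeat` wall time."""
    timings = {
        # Bind `fn` via a default argument so each lambda times
        # its own candidate rather than the last loop variable.
        name: min(timeit.repeat(lambda fn=fn: fn(*args),
                                number=number, repeat=repeat))
        for name, fn in callables.items()
    }
    return min(timings, key=timings.get)

# Hypothetical usage: in practice the two entries would be the nvFuser
# and torch.compile callables produced for one fusion region.
winner = pick_faster(
    {"nvfuser": lambda x: x + 1, "torch_compile": lambda x: x + 1},
    args=[41],
)
print(winner)
```

A per-fusion selector like this would slot naturally into the executor dispatch step, though measurement noise (visible in the IQRs above) suggests it should only switch backends when the gap is clearly outside the noise band.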

kshitij12345 marked this pull request as ready for review January 24, 2025 18:27
kshitij12345 (Collaborator, Author) commented:
> This looks ready to take out of draft to me, @kshitij12345, but maybe you have some more ideas

I just wanted to update the description with an example usage, which is done. I have marked the PR as ready for review.

kshitij12345 changed the title from "[WIP] Add utility function to benchmark performance of fusion region with nvfuser and torch.compile" to "Add utility function to benchmark performance of fusion region with nvfuser and torch.compile" Jan 27, 2025
mruberry (Collaborator) left a comment.
mruberry enabled auto-merge (squash) January 27, 2025 17:52
mruberry merged commit 980b63c into Lightning-AI:main Jan 27, 2025
52 checks passed