This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

add norm_ffn_norm to profile script (#282)
Summary:
Pull Request resolved: #282

This PR adds an example FFN with the preceding and subsequent norms
to the profile script. It also adds a few quality-of-life features for
automatic data extraction:
1. extract GPU time and aggregate it per CPU kernel name
2. attribute the kernel GPU time to gemms, float8 overhead or other
3. approximate the time spent syncing scales/amaxes and display it as a percentage of total time
4. if the `TORCHINDUCTOR_PROFILE` environment variable is set, also parse its output for Triton kernel memory bandwidth
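
As a rough illustration of items 1 and 2, the sketch below sums GPU time per kernel name and assigns each kernel to a category. The names `categorize`, `summarize`, and `F8_TOKENS` are hypothetical, and the substring rule is only an assumption inferred from the example output in this message; it is not necessarily the logic the profile script uses.

```python
from collections import defaultdict

# Assumption: substrings that mark a Triton kernel as float8 overhead.
# Inferred from the example output only, not taken from the script.
F8_TOKENS = ("abs", "max", "clamp", "scaled_mm")


def categorize(kernel_name):
    """Map a kernel name to one of the three category labels."""
    if kernel_name in ("aten::mm", "aten::_scaled_mm"):
        return "0_gemm"
    if kernel_name.startswith("triton_") and any(
        token in kernel_name for token in F8_TOKENS
    ):
        return "1_f8_overhead"
    return "2_other"


def summarize(events):
    """Aggregate (kernel_name, gpu_time_ms) pairs into per-kernel rows.

    Returns (kernel, category, time_ms, pct_gpu_time) tuples sorted by
    category, then by descending time.
    """
    per_kernel = defaultdict(float)
    for name, time_ms in events:
        per_kernel[name] += time_ms
    total = sum(per_kernel.values())
    rows = [
        (name, categorize(name), time_ms, time_ms / total if total else 0.0)
        for name, time_ms in per_kernel.items()
    ]
    rows.sort(key=lambda row: (row[1], -row[2]))
    return rows


if __name__ == "__main__":
    # Hypothetical events; a real run would take them from the profiler trace.
    demo = [
        ("aten::_scaled_mm", 4.0),
        ("aten::_scaled_mm", 4.0),
        ("triton_red_fused_abs_max_3", 1.0),
        ("aten::add_", 1.0),
    ]
    for row in summarize(demo):
        print(row)
```

Keeping the category rule in a single function makes it easy to adjust when a new model produces kernel names that the heuristic misclassifies.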

I hope this speeds up debugging of kernel performance on various models, since it automates the collection of many high-level metrics that take longer to extract by inspecting the traces visually.

Example output when testing `norm_ffn_norm` with dynamic scaling and compile, with bandwidth measurements enabled:

```
Summary of GPU time by CPU kernel

    experiment                                                                                      kernel       category  time_ms  pct_gpu_time  bw_gpbs
1       0_ref                                                                                    aten::mm         0_gemm   15.625         0.826     None
9       0_ref                                                     triton_red_fused__to_copy_add_mul_sum_2        2_other    1.375         0.073   241.12
11      0_ref                                                                                  aten::add_        2_other    0.566         0.030     None
8       0_ref                                            triton_poi_fused_add_fill_mul_sigmoid_silu_sub_1        2_other    0.520         0.027  2203.88
2       0_ref                                                                 triton_poi_fused_mul_silu_1        2_other    0.302         0.016  2207.31
10      0_ref                                             triton_red_fused__to_copy_add_div_mul_pow_sum_3        2_other    0.150         0.008  1375.62
7       0_ref                                             triton_red_fused__to_copy_add_div_mul_pow_sum_0        2_other    0.122         0.006  1963.47
3       0_ref                                          triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_2        2_other    0.084         0.004  1935.35
6       0_ref                                                                                 aten::copy_        2_other    0.060         0.003     None
0       0_ref                                          triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_0        2_other    0.052         0.003  1861.85
4       0_ref                                                                                   aten::sum        2_other    0.051         0.003     None
5       0_ref                                                                                 aten::fill_        2_other    0.002         0.000     None
19   1_float8                                                                            aten::_scaled_mm         0_gemm    8.148         0.623     None
36   1_float8  triton_poi_fused__to_copy_add_clamp_copy_empty_like_fill_mul_reciprocal_sigmoid_silu_sub_6  1_f8_overhead    0.408         0.031  2220.22
34   1_float8                                    triton_red_fused_abs_add_fill_max_mul_sigmoid_silu_sub_4  1_f8_overhead    0.314         0.024  2090.00
18   1_float8                                                         triton_poi_fused__scaled_mm_clone_6  1_f8_overhead    0.294         0.022  1439.36
37   1_float8                               triton_poi_fused__scaled_mm_clamp_clone_copy_mul_reciprocal_7  1_f8_overhead    0.292         0.022  1527.26
22   1_float8                                  triton_poi_fused__to_copy_clamp_copy_mul_reciprocal_silu_9  1_f8_overhead    0.255         0.020  2148.34
20   1_float8                                                         triton_red_fused_abs_max_mul_silu_7  1_f8_overhead    0.244         0.019  1834.33
23   1_float8                                                        triton_poi_fused__scaled_mm_clone_10  1_f8_overhead    0.146         0.011  1462.59
15   1_float8                                                                  triton_red_fused_abs_max_3  1_f8_overhead    0.127         0.010  1498.30
31   1_float8                                     triton_red_fused__to_copy_abs_add_div_max_mul_pow_sum_1  1_f8_overhead    0.091         0.007  1971.11
33   1_float8                               triton_poi_fused__scaled_mm_clamp_clone_copy_mul_reciprocal_3  1_f8_overhead    0.090         0.007  1388.84
21   1_float8                                                                  triton_red_fused_abs_max_8  1_f8_overhead    0.062         0.005  1487.67
17   1_float8                                       triton_poi_fused__to_copy_clamp_copy_mul_reciprocal_5  1_f8_overhead    0.046         0.003  1833.92
12   1_float8                                  triton_red_fused__to_copy_abs_add_max_mean_mul_pow_rsqrt_0  1_f8_overhead    0.044         0.003  1246.58
16   1_float8                                        triton_per_fused_abs_clamp_copy_max_mul_reciprocal_4  1_f8_overhead    0.010         0.001     0.32
35   1_float8                  triton_per_fused__scaled_mm_abs_clamp_clone_copy_max_mul_reciprocal_silu_5  1_f8_overhead    0.005         0.000     0.32
32   1_float8                       triton_red_fused__scaled_mm_abs_clamp_clone_copy_max_mul_reciprocal_2  1_f8_overhead    0.003         0.000     4.51
13   1_float8                               triton_red_fused__to_copy_abs_clamp_copy_max_mul_reciprocal_1  1_f8_overhead    0.003         0.000     4.61
38   1_float8                                                     triton_red_fused__to_copy_add_mul_sum_8        2_other    0.804         0.061   244.46
40   1_float8                                                                                  aten::add_        2_other    0.567         0.043     None
30   1_float8                                                         triton_red_fused__to_copy_mul_sum_0        2_other    0.562         0.043   236.73
39   1_float8                                             triton_red_fused__to_copy_add_div_mul_pow_sum_9        2_other    0.149         0.011  1377.16
25   1_float8                                                                   triton_poi_fused_clone_12        2_other    0.146         0.011  1532.05
24   1_float8                                         triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_11        2_other    0.115         0.009  1293.20
29   1_float8                                                                                 aten::copy_        2_other    0.060         0.005     None
27   1_float8                                                                                   aten::sum        2_other    0.050         0.004     None
26   1_float8                                                                   triton_poi_fused_clone_13        2_other    0.043         0.003  1333.22
14   1_float8                                                               triton_poi_fused_empty_like_2        2_other    0.002         0.000     0.00
28   1_float8                                                                                 aten::fill_        2_other    0.002         0.000     None

Float8 amax/scale sync approx ratio of total time: 0.000

Summary of time (ms) by kernel category

 experiment     0_ref  1_float8  f8_div_ref  ref_div_f8
category
0_gemm        15.625     8.148       0.521       1.918
1_f8_overhead  0.000     2.436         inf       0.000
2_other        3.283     2.498       0.761       1.314
All           18.908    13.082       0.692       1.445

```
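
The ratio columns in the last table can be reproduced from the per-category totals. The sketch below is a minimal illustration, not the script's code: `safe_div` and `compare` are hypothetical names, and treating a zero reference time as `inf` is an assumption chosen to match the `1_f8_overhead` row above.

```python
import math


def safe_div(num, den):
    """Divide, returning inf (or nan for 0/0) instead of raising on zero."""
    if den == 0:
        return math.inf if num > 0 else math.nan
    return num / den


def compare(ref, f8):
    """Build rows of (ref_ms, f8_ms, f8_div_ref, ref_div_f8) per category.

    `ref` and `f8` map category name to total time in ms. An "All" row
    holding the sums over every category is appended.
    """
    table = {}
    for cat in sorted(set(ref) | set(f8)):
        r, f = ref.get(cat, 0.0), f8.get(cat, 0.0)
        table[cat] = (r, f, safe_div(f, r), safe_div(r, f))
    r_all, f_all = sum(ref.values()), sum(f8.values())
    table["All"] = (r_all, f_all, safe_div(f_all, r_all), safe_div(r_all, f_all))
    return table


if __name__ == "__main__":
    # Category totals copied from the example output above.
    ref = {"0_gemm": 15.625, "1_f8_overhead": 0.0, "2_other": 3.283}
    f8 = {"0_gemm": 8.148, "1_f8_overhead": 2.436, "2_other": 2.498}
    for cat, (r, f, f_over_r, r_over_f) in compare(ref, f8).items():
        print(f"{cat:<14}{r:>8.3f}{f:>10.3f}{f_over_r:>12.3f}{r_over_f:>12.3f}")
```

Read this way, `ref_div_f8` in the `All` row is the end-to-end GPU-time speedup of the float8 run over the reference, while the `0_gemm` row isolates the speedup of the matrix multiplies alone.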

Reviewed By: drisspg

Differential Revision: D59163495

fbshipit-source-id: 9aacfc996d7c66e204ee7d4460889bd7fca9f48c
vkuzo authored and facebook-github-bot committed Jun 28, 2024
1 parent 9ba2a8e commit 0b60496
Showing 2 changed files with 372 additions and 52 deletions.