From 8c26deacd0397752d0ff13173ef0f618a1ebdecf Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Fri, 13 Sep 2024 17:08:48 -0500 Subject: [PATCH] fix nit --- python/perf-kernels/flash-attention.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/python/perf-kernels/flash-attention.py b/python/perf-kernels/flash-attention.py index 8b534b81bc84..89348a412fe5 100644 --- a/python/perf-kernels/flash-attention.py +++ b/python/perf-kernels/flash-attention.py @@ -319,10 +319,12 @@ def get_MI_autotune_configs(): def get_NAVI_autotune_configs(): - return [ triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, + return [ + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, num_warps=2), triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=2), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, num_warps=2), triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=2), @@ -332,15 +334,19 @@ def get_NAVI_autotune_configs(): num_warps=2), # Fall-back config. triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2),], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'] + num_warps=2), + ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'] + def is_hip(): return triton.runtime.driver.active.get_current_target().backend == "hip" + def is_cdna(): return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942', 'gfx90a', 'gfx908') + def get_gfx_version(): try: # Run the rocminfo command @@ -357,6 +363,7 @@ def get_gfx_version(): print(f"Error: {e}") return None + def is_navi(): try: # Attempt to get the GPU architecture using Triton @@ -367,7 +374,7 @@ def is_navi(): return True else: return False - except Exception as e: + except Exception: # Fallback to using rocminfo if Triton method fails gfx_version = get_gfx_version() if gfx_version in ("gfx1030", "gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201"): @@ -896,8 +903,6 @@ def _attn_bwd(Q, K, V, sm_scale, alibi_slopes, DO, DQ, DK, DV, M, D, tl.store(DQ_block_ptr, dq.to(q.dtype)) - - def get_shape_from_layout(q, k, metadata): if metadata.layout == 'thd': nheads_q, nheads_k = q.shape[1], k.shape[1]