Skip to content

Commit

Permalink
fix nit
Browse files Browse the repository at this point in the history
  • Loading branch information
micmelesse committed Sep 17, 2024
1 parent b8d254a commit 8c26dea
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions python/perf-kernels/flash-attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,10 +319,12 @@ def get_MI_autotune_configs():


def get_NAVI_autotune_configs():
return [ triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1,
return [
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=2),
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=2), triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=2),
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=2),
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=2),
Expand All @@ -332,15 +334,19 @@ def get_NAVI_autotune_configs():
num_warps=2),
# Fall-back config.
triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=2),], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK']
num_warps=2),
], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK']


def is_hip():
return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna():
return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
'gfx90a', 'gfx908')


def get_gfx_version():
try:
# Run the rocminfo command
Expand All @@ -357,6 +363,7 @@ def get_gfx_version():
print(f"Error: {e}")
return None


def is_navi():
try:
# Attempt to get the GPU architecture using Triton
Expand All @@ -367,7 +374,7 @@ def is_navi():
return True
else:
return False
except Exception as e:
except Exception:
# Fallback to using rocminfo if Triton method fails
gfx_version = get_gfx_version()
if gfx_version in ("gfx1030", "gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201"):
Expand Down Expand Up @@ -896,8 +903,6 @@ def _attn_bwd(Q, K, V, sm_scale, alibi_slopes, DO, DQ, DK, DV, M, D,
tl.store(DQ_block_ptr, dq.to(q.dtype))




def get_shape_from_layout(q, k, metadata):
if metadata.layout == 'thd':
nheads_q, nheads_k = q.shape[1], k.shape[1]
Expand Down

0 comments on commit 8c26dea

Please sign in to comment.