Skip to content

Commit

Permalink
add configs
Browse files Browse the repository at this point in the history
  • Loading branch information
micmelesse committed Sep 13, 2024
1 parent 811221f commit ad620f8
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 5 deletions.
59 changes: 54 additions & 5 deletions python/perf-kernels/flash-attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"""

import argparse
import subprocess
import pytest
import sys
import torch
Expand Down Expand Up @@ -318,12 +319,61 @@ def get_MI_autotune_configs():


def get_NAVI_autotune_configs():
    """Return the autotune search space for AMD Navi (RDNA) GPUs.

    Returns:
        tuple: ``(configs, keys)`` where ``configs`` is a list of
        ``triton.Config`` candidates (block sizes, waves_per_eu, V-preload)
        and ``keys`` is the list of kernel-argument names whose values key
        the autotune cache (re-tuning happens when any of them changes).
    """
    # NOTE: the original pasted body began with an unconditional
    # `return [], []`, which made every config below unreachable dead code.
    # The intended behavior is to return the Navi config list.
    configs = [
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1,
                      num_warps=2),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
                      num_warps=2),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1,
                      num_warps=2),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
                      num_warps=2),
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1,
                      num_warps=2),
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
                      num_warps=2),
        # Fall-back config.
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1,
                      num_warps=2),
    ]
    keys = ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK']
    return configs, keys

def is_hip():
    """Return ``True`` when the active Triton compilation target is AMD HIP."""
    target = triton.runtime.driver.active.get_current_target()
    return target.backend == "hip"

def is_cdna():
    """Return ``True`` when running on an AMD CDNA-class (MI-series) GPU arch."""
    cdna_archs = ('gfx940', 'gfx941', 'gfx942', 'gfx90a', 'gfx908')
    if not is_hip():
        return False
    return triton.runtime.driver.active.get_current_target().arch in cdna_archs

def get_gfx_version():
    """Probe the GPU gfx architecture via the ``rocminfo`` tool.

    Returns:
        str | None: the architecture string (e.g. ``"gfx90a"``) taken from
        the first ``Name: gfx...`` line of rocminfo output, or ``None`` when
        rocminfo is unavailable, errors out, or prints no matching line.
    """
    try:
        proc = subprocess.run(['rocminfo'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
        for raw_line in proc.stdout.splitlines():
            stripped = raw_line.strip()
            if stripped.startswith("Name: gfx"):
                # Everything after the "Name:" label is the gfx identifier.
                return stripped.split("Name:")[1].strip()
    except Exception as e:
        # Best-effort probe: report the failure instead of raising.
        print(f"Error: {e}")
    return None

def is_navi():
    """Return ``True`` when running on an AMD Navi (RDNA) GPU.

    First asks the Triton runtime for the active compilation target; if that
    fails for any reason (e.g. no active driver), falls back to parsing
    ``rocminfo`` output via ``get_gfx_version()``.

    Returns:
        bool: whether the detected architecture is a known Navi gfx id.
    """
    # NOTE: the original pasted body kept the pre-commit one-line `return`
    # ahead of the new try/except implementation, making the fallback dead
    # code. The intended behavior is the try/except version below.
    navi_archs = ("gfx1030", "gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201")
    try:
        # Preferred path: Triton knows the active backend and arch.
        target = triton.runtime.driver.active.get_current_target()
        return target.backend == 'hip' and target.arch in navi_archs
    except Exception:
        # Fallback: Triton probing failed; ask rocminfo instead.
        return get_gfx_version() in navi_archs


def get_autotune_configs():
Expand Down Expand Up @@ -846,7 +896,6 @@ def _attn_bwd(Q, K, V, sm_scale, alibi_slopes, DO, DQ, DK, DV, M, D,
tl.store(DQ_block_ptr, dq.to(q.dtype))


empty = torch.empty(128, device="cuda")


def get_shape_from_layout(q, k, metadata):
Expand Down
1 change: 1 addition & 0 deletions scripts
Submodule scripts added at 072766

0 comments on commit ad620f8

Please sign in to comment.