Add support for qk hidden dim different from v hidden dim #1166

Open

xiayuqing0622 wants to merge 54 commits into Dao-AILab:main

Commits (54)
5eeef1e
intermediate save
xiayuqing0622 Aug 12, 2024
331a601
support var dim
xiayuqing0622 Aug 13, 2024
02da101
modify readme
xiayuqing0622 Aug 13, 2024
2bce87c
compatible
xiayuqing0622 Aug 15, 2024
51a8bcb
Merge branch 'main' into dim
xiayuqing0622 Aug 19, 2024
e8b4082
test_head_dim
smallscientist1 Aug 19, 2024
ebf0b16
add test headdim
smallscientist1 Aug 20, 2024
ab35fc2
fix some config bug
smallscientist1 Aug 20, 2024
4e94c20
update test headdim
smallscientist1 Aug 20, 2024
e31b6a4
Merge branch 'Dao-AILab:main' into dim
xiayuqing0622 Aug 20, 2024
89dbe52
update test headdim splitkv
smallscientist1 Aug 20, 2024
fc094a4
Merge commit '89dbe521b48000ee4f3d942d7c3498c698817159' into dim
smallscientist1 Aug 20, 2024
d11b7ae
update ReadMe.md
smallscientist1 Aug 20, 2024
21ca4bc
remove unused file
smallscientist1 Aug 20, 2024
4c3462a
revert Readme
smallscientist1 Aug 20, 2024
f63411d
create bench headdim
smallscientist1 Aug 21, 2024
3e0c7c4
update bench result
smallscientist1 Aug 22, 2024
3caa059
update Readme
smallscientist1 Aug 22, 2024
493a430
reorg code to reduce compile time
smallscientist1 Aug 22, 2024
0607e6c
update (128,256) config
smallscientist1 Aug 22, 2024
fd6fc29
add (192,128)
smallscientist1 Aug 26, 2024
b6d7493
add config (192,128)
smallscientist1 Aug 26, 2024
85fb8d2
fix bug
smallscientist1 Aug 26, 2024
f0644c2
fix bug backward
smallscientist1 Aug 27, 2024
0092285
fix bug
smallscientist1 Aug 27, 2024
6e88a4d
Add support for dim(192,128) (#1)
smallscientist1 Aug 27, 2024
255cd5a
add optional dim compile
smallscientist1 Aug 28, 2024
e666f96
Merge branch 'Dao-AILab:main' into dim
xiayuqing0622 Sep 3, 2024
00979f5
support different head kv
smallscientist1 Sep 4, 2024
feeab17
add test_head
smallscientist1 Sep 4, 2024
18b309d
update flash api head
smallscientist1 Sep 4, 2024
6909ab4
fix interface bug
smallscientist1 Sep 4, 2024
3c8bb2b
Merge pull request #2 from xiayuqing0622/head
smallscientist1 Sep 4, 2024
5f26eb0
update README
smallscientist1 Sep 4, 2024
536a8cc
benchmark head_headdim
smallscientist1 Sep 4, 2024
ca6335d
fix bench bug
smallscientist1 Sep 4, 2024
def41c0
fix bug for numhead
smallscientist1 Sep 5, 2024
6e8d537
add autotuner
smallscientist1 Sep 6, 2024
83fd7a5
basetuner fwd
smallscientist1 Sep 6, 2024
7cf4858
update autotuner FLashFwd
smallscientist1 Sep 10, 2024
1ca8397
autotuner fwd
smallscientist1 Sep 10, 2024
1e5c49d
update code
smallscientist1 Sep 10, 2024
409bdde
update autotuner log
smallscientist1 Sep 12, 2024
d4b620a
update tunner
smallscientist1 Sep 12, 2024
be21a0a
fix bug kernel launch
smallscientist1 Sep 12, 2024
90fa651
update autotuner tile space
smallscientist1 Sep 18, 2024
1ba39eb
update cutlass bugfix
smallscientist1 Sep 18, 2024
c5fa3c9
add autotuner doc
smallscientist1 Sep 18, 2024
31ea0bb
Merge pull request #3 from xiayuqing0622/dim_autotuner
smallscientist1 Sep 18, 2024
b09eaee
update readme
smallscientist1 Sep 18, 2024
cd9fee4
update autotuner
smallscientist1 Sep 19, 2024
014c349
update readme
smallscientist1 Sep 19, 2024
cd91625
Merge branch 'dim_pr' into dim_pr1
smallscientist1 Sep 19, 2024
d578cff
Merge pull request #4 from xiayuqing0622/dim_pr1
smallscientist1 Sep 19, 2024
Binary file added assets/Customflash2_a100_fwd_bwd_benchmark.png
15 changes: 15 additions & 0 deletions autotuner/arch/A100.py
@@ -0,0 +1,15 @@
from .arch_base import Arch
class A100(Arch):
    def __init__(self):
        self.reg_cap = 65536  # 32768
        self.smem_cap = 163 * 1024  # 164*1024
        self.compute_max_core = 108
        self.warp_size = 32
        self.sm_partition = 4
        self.transaction_size = [32, 128]  # in bytes
        self.max_smem_usage = 164 * 1024
        self.bandwidth = [1319, 16308]
        self.platform = "CUDA"
        self.compute_capability = "80"
        self.cutlass_mma = [16, 8, 16]
        self.register_per_thread = 255
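These per-SM limits are what the autotuner can prune candidate tile sizes against. Below is a rough, hypothetical feasibility check (not part of the PR, whose validate_* hooks are left as stubs in base_tunner.py), assuming 2-byte fp16/bf16 operands and Q/K/V tiles of Br×Kd, Bc×Kd, and Bc×D:

from arch import A100  # assumes the autotuner/ directory is on sys.path

def tile_fits_smem(arch, Br, Bc, Kd, D, bytes_per_elem=2):
    # Rough shared-memory footprint of one CTA: Q tile (Br x Kd), K tile (Bc x Kd), V tile (Bc x D).
    smem_bytes = (Br * Kd + Bc * Kd + Bc * D) * bytes_per_elem
    return smem_bytes <= arch.smem_cap

arch = A100()
print(tile_fits_smem(arch, Br=128, Bc=64, Kd=192, D=128))  # True: ~88 KB of the 163 KB budget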
15 changes: 15 additions & 0 deletions autotuner/arch/RTX4090.py
@@ -0,0 +1,15 @@
from .arch_base import Arch
class RTX4090(Arch):
    def __init__(self):
        self.reg_cap = 65536  # 32768
        self.smem_cap = 100 * 1024  # 164*1024
        self.compute_max_core = 128
        self.warp_size = 32
        self.sm_partition = 4
        self.transaction_size = [32, 128]  # in bytes
        self.max_smem_usage = 100 * 1024
        self.bandwidth = [1008, 0]  # TODO: 1
        self.platform = "CUDA"
        self.compute_capability = "89"
        self.cutlass_mma = [16, 8, 16]
        self.register_per_thread = 255
3 changes: 3 additions & 0 deletions autotuner/arch/__init__.py
@@ -0,0 +1,3 @@
from .arch_base import Arch
from .A100 import *
from .RTX4090 import *
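The package exports both targets; a small, hypothetical helper (not in the PR) that one might write to pick the matching Arch entry from the local GPU's compute capability:

import torch
from arch import A100, RTX4090  # assumes the autotuner/ directory is on sys.path

def detect_arch():
    # Map the local GPU's compute capability to one of the bundled Arch profiles.
    major, minor = torch.cuda.get_device_capability()
    cc = f"{major}{minor}"
    known = {"80": A100, "89": RTX4090}
    if cc not in known:
        raise NotImplementedError(f"no Arch profile for sm_{cc}")
    return known[cc]()

print(detect_arch().compute_capability)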
13 changes: 13 additions & 0 deletions autotuner/arch/arch_base.py
@@ -0,0 +1,13 @@
class Arch:
    def __init__(self) -> None:
        self.reg_cap = 0
        self.smem_cap = 0
        self.compute_max_core = 0
        self.warp_size = 0
        self.sm_partition = 0
        self.transaction_size = [0, 0]
        self.max_smem_usage = 0
        self.bandwidth = [0, 0]
        self.platform = "unknown"
        self.compute_capability = "unknown"
        self.register_per_thread = 0
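Arch is a plain record of hardware limits; supporting another GPU means subclassing it and filling in the fields, as A100 and RTX4090 do above. A hypothetical sketch with placeholder numbers (take real values from the vendor datasheet or deviceQuery):

from arch_base import Arch  # assumes the autotuner/arch directory is on sys.path

class MyGPU(Arch):
    def __init__(self):
        # Placeholder values only; substitute the real limits for your part.
        self.reg_cap = 65536              # 32-bit registers per SM
        self.smem_cap = 96 * 1024         # usable shared memory per SM, in bytes
        self.compute_max_core = 80        # number of SMs
        self.warp_size = 32
        self.sm_partition = 4
        self.transaction_size = [32, 128]  # in bytes
        self.max_smem_usage = 96 * 1024
        self.bandwidth = [800, 0]
        self.platform = "CUDA"
        self.compute_capability = "86"
        self.cutlass_mma = [16, 8, 16]
        self.register_per_thread = 255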
247 changes: 247 additions & 0 deletions autotuner/base_tunner.py
@@ -0,0 +1,247 @@
import ctypes
import os
from concurrent.futures import ThreadPoolExecutor
# import multiprocessing
# from functools import partial
import tempfile
import subprocess
import importlib.util

import ctypes
import torch
from configs import BaseConfig, supported_configs

import pprint
import json

import time

from code_emitter import CodeEmitter, ShapeConfig, ProfileConfig
from profile_attn import profile_fwd





class CompileResult:
    def __init__(self, config: BaseConfig, lib_name: str) -> None:
        self.config = config
        self.lib_name = lib_name

def _create_code_for_profiling(config):
    profile_code_path = os.path.join(config.template_dir, config.operation, "profile_code.py")

    spec = importlib.util.spec_from_file_location("ProfileCode", profile_code_path)
    foo = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(foo)
    # from template.flash_kernels.retnet.regfuse.profile_code import profile_code
    # return profile_code.format(Br=config.Br, Bc=config.Bc, Kd=config.Kd, D=config.D, unrollLastIter=int(config.unrollLastIter), BlockKSmem=config.BlockKSmem, num_stages_qk=config.num_stages_qk, num_stages_mask=config.num_stages_mask, BlockKSmem2=config.BlockKSmem2, num_stages_v=config.num_stages_v, Nthreads=config.Nthreads)
    # from template.flash_kernels.retnet.smemfuse.profile_code import profile_code
    # return profile_code.format(Br=config.Br, Bc=config.Bc, Kd=config.Kd, D=config.D, unrollLastIter=int(config.unrollLastIter), BlockKSmem=config.BlockKSmem, num_stages_qk=config.num_stages_qk, num_stages_mask=config.num_stages_mask, BlockKSmem2=config.BlockKSmem2, num_stages_v=config.num_stages_v, Nthreads=config.Nthreads, warps_mma1_n=config.warps_mma1_n, warps_mma_n=config.warps_mma_n)
    return foo.profile_code.format_map(config.__dict__)

# def _compile(config, arch, temp_dir:str, timeout: float = None):
# ## compile

# profiling_code = _create_code_for_profiling(config)
# src = tempfile.NamedTemporaryFile(mode="w",suffix=".cu", delete=True, dir=temp_dir)
# lib_name = src.name.replace(".cu", ".so")
# compute_version = arch.compute_capability
# cutlass_dir = os.path.join(os.path.dirname(__file__), "../../third_party/cutlass/include")
# csrc_dir = os.path.join(os.path.dirname(__file__), "../../csrc")
# if config.fuse_type == "register":
# template_dir = os.path.join(config.template_dir , "regfuse/")
# elif config.fuse_type == "shared":
# template_dir = os.path.join(config.template_dir , "smemfuse/")
# else: # bwd
# template_dir = config.template_dir
# command = ["nvcc","-std=c++17","-O3","--use_fast_math","--expt-relaxed-constexpr","--disable-warnings", "--compiler-options", "'-fPIC'", "--shared", src.name, "-lcuda",
# f"-gencode=arch=compute_{compute_version},code=sm_{compute_version}",
# f"-I{cutlass_dir}",f"-I{template_dir}",f"-I{csrc_dir}", "-o", lib_name]
# src.write(profiling_code)
# src.flush()
# try:
# ret = subprocess.run(command, timeout=timeout)
# except subprocess.TimeoutExpired:
# return None
# if ret.returncode != 0:
# return None
# return CompileResult(config,lib_name)

class BaseTunner:
    def __init__(self, arch, torch_array: list, op_name, shape_config: ShapeConfig, profile_config: ProfileConfig, tempdir):
        self.arch = arch
        self.torch_array = torch_array
        self.Br_list = [32, 64, 96, 128, 160, 192, 224, 256]  # [32, 64, 128, 256]
        self.Bc_list = [32, 64, 96, 128, 160, 192, 224, 256]  # [32, 64, 128, 256]

        self.template_dir = "autotuner/template"
        self.op_name = op_name
        # TODO: workaround for dropout_p
        self.cache_path = os.path.join(os.path.dirname(__file__), "./cache/", str(profile_config.dropout_p != 0))
        self.problem_key = {
            "dim_qk": torch_array[0].shape[-1],
            "dim_v": torch_array[2].shape[-1]
        }
        assert torch_array[0].shape[-1] == shape_config.Kd
        assert torch_array[2].shape[-1] == shape_config.D
        self.shape_config = shape_config
        self.profile_config = profile_config
        self.tempdir = tempdir

    def compile(self, configs: list, timeout: float = None):
        temp_dir = self.tempdir
        code_emitter = CodeEmitter(self.template_dir, temp_dir)
        code_emitter.generate_code(self.shape_config, configs)


    def profile(self, config: BaseConfig, repeat=30, load_only=False) -> float:
        spec = importlib.util.spec_from_file_location("flash_attn_func", self.tempdir + "/" + config.output_dir + "/flash_attn_profile_interface.py")
        mod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(mod)
        flash_attn_func = mod.flash_attn_func
        if load_only:
            return None
        latency = profile_fwd(flash_attn_func, self.shape_config.Kd, self.shape_config.D, batch_size=self.profile_config.batch_size, seqlen=self.profile_config.seqlen_q, nheads=self.profile_config.nheads, dropout_p=self.profile_config.dropout_p, is_bf16=self.shape_config.is_bf16, causal=self.shape_config.is_causal, device=self.profile_config.device, repeats=repeat)
        if latency < 0:
            latency = 1e8
        # remove lib
        # subprocess.run(["rm", libname], check=True)
        return latency

    def get_tuned_configs(self):
        dim_qk = self.problem_key["dim_qk"]
        dim_v = self.problem_key["dim_v"]
        configs = []
        for Br in self.Br_list:
            for Bc in self.Bc_list:
                cur_configs = self.generate_configs(Br, Bc, dim_qk, dim_v)
                for cur_config in cur_configs:
                    if self.op_name == "flash_fwd" and self.validate_register_fuse(cur_config):
                        configs.append(cur_config)
                    else:  # BWD
                        if self.validate_kernel(cur_config):
                            configs.append(cur_config)
        return configs

    def tune(self, log_path="./logs/"):
        st = time.time()

        dim_qk = self.problem_key["dim_qk"]
        dim_v = self.problem_key["dim_v"]

        best_config = self.check_cache()
        if best_config is not None:
            # print("Best config found in cache: ")
            # pprint.pprint(best_config)
            return best_config

        configs = self.get_tuned_configs()

        # print configs
        print("Configs to be tuned: ")
        for config in configs:
            # print(config)
            pprint.pprint(config)

        # cresults = self.compile(configs,src_dir.name,timeout=1200)
        # cresults = self.compile_parallel(configs,src_dir.name,timeout=120)
        self.compile(configs, timeout=120)

        # warm up (parallel compile module)
        # module name must be different in api.py
        with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
            latencys = executor.map(self.profile, configs, [1 for _ in range(len(configs))], [True for _ in range(len(configs))])
        # with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
        #     latencys = executor.map(_profile, [self.tempdir for _ in range(len(configs))], [self.shape_config for _ in range(len(configs))], configs, ["cuda:0" for _ in range(len(configs))], [1 for _ in range(len(configs))])
        # multiprocessing.set_start_method('spawn', force=True)
        # pool = multiprocessing.Pool(os.cpu_count())
        # outs = pool.map(partial(self.profile, repeat=1), configs)

        profile_dict = {}
        latency = 1e8
        best_config = None
        for config in configs:
            lib_latency = self.profile(config)
            if lib_latency == 1e8:
                # print(cresult.config)
                pprint.pprint(config)
                print("profile runtime error")
            if lib_latency < latency:
                latency = lib_latency
                best_config = config
            profile_dict[config] = lib_latency

        end = time.time()

        print("##########################################################")
        print("Operation type: ", best_config.operation)
        print("Best config: ")  # , best_config)
        pprint.pprint(best_config)
        print("Latency: ", latency)

        file_name = "profile_result_{}_{}_{}_p{}_{}_{}_{}_c{}.txt".format(best_config.operation, dim_qk, dim_v, self.profile_config.batch_size, self.profile_config.seqlen_q, self.profile_config.nheads, self.profile_config.dropout_p, self.shape_config.is_causal)
        os.makedirs(log_path, exist_ok=True)
        with open(os.path.join(log_path, file_name), "a") as f:
            for config in profile_dict:
                f.write(repr(config) + "\n")
                f.write(str(profile_dict[config]) + "\n")
            f.write("\n")
            f.write("best config: \n")
            f.write(repr(best_config) + "\n")
            f.write(str(latency) + "\n")
            f.write("\nsearch time: " + str(end - st) + "s" + "\n\n")

        cache_path = self.cache_path
        os.makedirs(cache_path, exist_ok=True)
        with open(os.path.join(cache_path, "best_config_{}_{}_{}.json".format(self.op_name, dim_qk, dim_v)), "w") as f:
            json.dump(best_config.__dict__, f)

        return best_config

    def check_cache(self):
        cache_path = self.cache_path
        op_name = self.op_name
        dim_qk = self.problem_key["dim_qk"]
        dim_v = self.problem_key["dim_v"]
        if os.path.exists(os.path.join(cache_path, "best_config_{}_{}_{}.json".format(op_name, dim_qk, dim_v))):
            with open(os.path.join(cache_path, "best_config_{}_{}_{}.json".format(op_name, dim_qk, dim_v)), "r") as f:
                best_config_dict = json.load(f)
                best_config = supported_configs[best_config_dict["operation"]].from_dict(best_config_dict)
            return best_config

        return None

    def validate_shared_fuse(self, config):
        return False
    def validate_register_fuse(self, config):
        return False
    def validate_kernel(self, config):
        return False
    def generate_configs(self, Br: int, Bc: int, dim_qk: int, dim_v: int):
        configs = []
        return configs

if __name__ == "__main__":
    import torch
    from configs.fwd_config import FlashFwdConfig
    batch_size = 4
    seqlen = 2048
    nheads = 8
    headdim = 192
    v_headdim = 128
    device = 'cuda'
    dtype = torch.bfloat16
    q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype,
                    requires_grad=True)
    k = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype,
                    requires_grad=True)
    v = torch.randn(batch_size, seqlen, nheads, v_headdim, device=device, dtype=dtype,
                    requires_grad=True)
    base_tunner = BaseTunner(arch=None, torch_array=[q, k, v], op_name="flash_fwd", shape_config=ShapeConfig(headdim, v_headdim), profile_config=ProfileConfig(batch_size, seqlen, seqlen, nheads, nheads, nheads, device, dtype, 0), tempdir="autotuner/temp")

    config = FlashFwdConfig(headdim, v_headdim, 64, 64)
    base_tunner.compile([config])
    base_tunner.profile(config)
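Beyond the compile()/profile() smoke test above, the intended entry point is tune(), which writes its winner to the JSON cache that check_cache() later reads. A short sketch of inspecting that cache after a run, with the file name format taken from the code above (path shown for the dropout_p == 0 case):

import json
import os

# str(dropout_p != 0) -> "False" when dropout is disabled; files are keyed by op name and head dims.
cache_file = os.path.join("autotuner", "cache", "False", "best_config_flash_fwd_192_128.json")

if os.path.exists(cache_file):
    with open(cache_file) as f:
        best = json.load(f)  # __dict__ of the winning config (tile sizes, pipeline stages, ...)
    print(best)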