diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp
index 8c639a4c0..4c6fb36fa 100644
--- a/csrc/flash_attn/flash_api.cpp
+++ b/csrc/flash_attn/flash_api.cpp
@@ -447,7 +447,11 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
     } else {
         v_padded = v;
     }
+    // // Otherwise the kernel will be launched from cuda:0 device
+    // // Cast to char to avoid compiler warning about narrowing
+    // at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+    // auto opts = q.options();
 
     at::Tensor out;
     if (out_.has_value()) {
         out = out_.value();
@@ -459,11 +463,11 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
             out = out.reshape({batch_size, num_heads_maxkv, ngroups, v_head_size_og}).transpose(1, 2);
         }
         if (v_head_size_og % 8 != 0) {
-            out = torch::empty({batch_size, seqlen_q, num_heads, v_head_size_og}, q.options());
+            // out = torch::empty({batch_size, num_heads, seqlen_q, v_head_size_og}, q.options()).transpose(1, 2);
             out = torch::nn::functional::pad(out, torch::nn::functional::PadFuncOptions({0, 8 - v_head_size_og % 8}));
         }
     } else {
-        out = torch::empty({batch_size, seqlen_q, num_heads, v_head_size_og}, q.options());
+        out = torch::empty({batch_size, num_heads, seqlen_q, v_head_size_og}, q.options()).transpose(1, 2);
         if (v_head_size_og % 8 != 0) {
             out = torch::nn::functional::pad(out, torch::nn::functional::PadFuncOptions({0, 8 - v_head_size_og % 8}));
         }
@@ -623,7 +627,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
     CHECK_CONTIGUOUS(cu_seqlens_k);
 
     const auto sizes = q.sizes();
-    const int v_head_size_og = v.sizes()[2];
+    const int v_head_size_og = paged_KV ? v.sizes()[3] : v.sizes()[2];
     const int batch_size = cu_seqlens_q.numel() - 1;
     int num_heads = sizes[1];
     const int head_size_og = sizes[2];
@@ -710,14 +714,19 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
         TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
         CHECK_SHAPE(out, sizes[0], sizes[1], v_head_size_og);
         if (seqlenq_ngroups_swapped) {
-            out = out.reshape({batch_size, num_heads_maxkv, ngroups, v_head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_maxkv, head_size_og});
+            out = out.reshape({batch_size, num_heads_maxkv, ngroups, v_head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_maxkv, v_head_size_og});
         }
         if (v_head_size_og % 8 != 0) {
-            out = torch::empty({total_q, num_heads, v_head_size_og}, q.options());
+            // out = torch::empty({total_q, num_heads, v_head_size_og}, q.options());
             out = torch::nn::functional::pad(out, torch::nn::functional::PadFuncOptions({0, 8 - v_head_size_og % 8}));
         }
     } else {
+        if (seqlenq_ngroups_swapped) {
+            out = torch::empty({batch_size, num_heads_maxkv, ngroups, v_head_size_og}, q.options()).transpose(1, 2).reshape({batch_size * ngroups, num_heads_maxkv, v_head_size_og});
+        }
+        else {
         out = torch::empty({total_q, num_heads, v_head_size_og}, q.options());
+        }
         if (v_head_size_og % 8 != 0) {
             out = torch::nn::functional::pad(out, torch::nn::functional::PadFuncOptions({0, 8 - v_head_size_og % 8}));
         }
@@ -1024,7 +1033,6 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
     }
 
     Flash_bwd_params params;
-
     set_params_dgrad(params,
                      batch_size,
                      seqlen_q, seqlen_k,
diff --git a/flex_head_fa/__init__.py b/flex_head_fa/__init__.py
index 2f1a0b2d6..1de0756ab 100644
--- a/flex_head_fa/__init__.py
+++ b/flex_head_fa/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "0.1.1" # flash attn __version__ = "2.6.3"
+__version__ = "0.1.2" # flash attn __version__ = "2.6.3"
 
 from flex_head_fa.flash_attn_interface import (
     flash_attn_func,
diff --git a/setup.py b/setup.py
index 57daee483..86ae86e48 100644
--- a/setup.py
+++ b/setup.py
@@ -455,6 +455,7 @@ def get_wheel_url():
     wheel_filename = f"{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
 
     wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{flash_version}", wheel_name=wheel_filename)
+    print(wheel_url)
     return wheel_url, wheel_filename
 
 
@@ -484,7 +485,8 @@ def run(self):
             impl_tag, abi_tag, plat_tag = self.get_tag()
             archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"
 
-            wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
+            # wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
+            wheel_path = os.path.join(self.dist_dir, wheel_filename)
             print("Raw wheel path", wheel_path)
             os.rename(wheel_filename, wheel_path)
         except (urllib.error.HTTPError, urllib.error.URLError):