[Bug] "wait" is ambiguous when building STA #192

Closed
amogkam opened this issue Feb 18, 2025 · 7 comments

Comments

@amogkam

amogkam commented Feb 18, 2025

Environment

FastVideo master
Python 3.10.14
CUDA version 12.4

Describe the bug

I'm trying to build the STA kernel following the instructions, but I'm running into this error:

/home/amogkamsetty/FastVideo/csrc/sliding_tile_attention/st_attn/st_attn_h100.cu(313): error: "wait" is ambiguous
       wait(v_smem_arrived[(kv_idx)%K::stages], (kv_idx/K::stages)%2);
       ^

Reproduction

Run python setup.py install inside the sliding_tile_attention directory.

@amogkam amogkam changed the title [Bug] Cannot build STA [Bug] "wait" is ambiguous when building STA Feb 18, 2025
@jzhang38
Collaborator

Could you paste the full error message?

@jzhang38
Collaborator

I just pushed a commit: f9482d1. Can you try again?

@amogkam
Author

amogkam commented Feb 19, 2025

Still the same issue, unfortunately. It seems like there are conflicting wait functions among the headers that we include?

/home/amogkamsetty/FastVideo/csrc/sliding_tile_attention/st_attn/st_attn_h100.cu(159): error: "wait" is ambiguous
        wait(compute_done[(kv_idx)%K::stages], (kv_idx/K::stages)%2);
        ^

/home/amogkamsetty/FastVideo/csrc/sliding_tile_attention/st_attn/st_attn_h100.cu(186): error: "wait" is ambiguous
          wait(compute_done[(count - 1)%K::stages], ((count - 1)/K::stages)%2);
          ^

/home/amogkamsetty/FastVideo/csrc/sliding_tile_attention/st_attn/st_attn_h100.cu(202): error: "wait" is ambiguous
        wait(compute_done[(count - 1)%K::stages], ((count - 1)/K::stages)%2);
        ^

/home/amogkamsetty/FastVideo/csrc/sliding_tile_attention/st_attn/st_attn_h100.cu(235): error: "wait" is ambiguous
   wait(qsmem_semaphore, 0);
   ^

/home/amogkamsetty/FastVideo/csrc/sliding_tile_attention/st_attn/st_attn_h100.cu(238): error: "wait" is ambiguous
       wait(k_smem_arrived[(kv_idx)%K::stages], (kv_idx/K::stages)%2);
       ^

/home/amogkamsetty/FastVideo/csrc/sliding_tile_attention/st_attn/st_attn_h100.cu(268): error: "wait" is ambiguous
       wait(v_smem_arrived[(kv_idx)%K::stages], (kv_idx/K::stages)%2);
       ^

/home/amogkamsetty/FastVideo/csrc/sliding_tile_attention/st_attn/st_attn_h100.cu(278): error: "wait" is ambiguous
       wait(k_smem_arrived[(kv_idx)%K::stages], (kv_idx/K::stages)%2);
       ^

/home/amogkamsetty/FastVideo/csrc/sliding_tile_attention/st_attn/st_attn_h100.cu(313): error: "wait" is ambiguous
       wait(v_smem_arrived[(kv_idx)%K::stages], (kv_idx/K::stages)%2);
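
For readers unfamiliar with this class of error, here is a minimal host-side C++ sketch of how an unqualified call can become ambiguous. It is an illustration under stated assumptions, not a diagnosis of this build: lib_a, lib_b, and call_qualified are hypothetical names, and the thread does not identify which second declaration of wait the compiler actually saw.

```cpp
#include <string>

// Hypothetical stand-ins; these are NOT the real ThunderKittens or system declarations.
namespace lib_a {
    // Plays the role of a library synchronization helper.
    inline std::string wait(int barrier, int phase) {
        return "lib_a::wait(" + std::to_string(barrier) + ", " + std::to_string(phase) + ")";
    }
}

namespace lib_b {
    // An unrelated function that happens to share the name and parameter types.
    inline std::string wait(int handle, int flags) {
        return "lib_b::wait(" + std::to_string(handle) + ", " + std::to_string(flags) + ")";
    }
}

// Both using-directives make the name `wait` visible without qualification.
using namespace lib_a;
using namespace lib_b;

inline std::string call_qualified() {
    // return wait(0, 1);      // would not compile: both candidates match equally well
    return lib_a::wait(0, 1);  // the qualified name selects exactly one function
}
```

Uncommenting the unqualified call reproduces an ambiguity diagnostic on a host compiler, while the qualified call compiles regardless of what else is in scope.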

@amogkam
Author

amogkam commented Feb 19, 2025

This seems to be the same issue as in ThunderKittens:

HazyResearch/ThunderKittens#54

@jzhang38
Collaborator

jzhang38 commented Feb 19, 2025

Remember to git pull --recurse-submodules first.
I reinstalled everything on my end and it works OK.

Do you have a C++20-capable compiler installed?

https://github.com/HazyResearch/ThunderKittens?tab=readme-ov-file#library-requirements
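
One way to check which language standard a compiler invocation is really using is to inspect the predefined __cplusplus macro. The sketch below is a standalone host-side check and is not part of FastVideo or ThunderKittens; the helper names is_cpp20_or_newer and describe_standard are hypothetical. The threshold 202002L is the value the C++20 standard assigns to __cplusplus.

```cpp
#include <string>

// Hypothetical helper: true when a __cplusplus value denotes C++20 or later.
constexpr bool is_cpp20_or_newer(long standard_value) {
    return standard_value >= 202002L;
}

// Reports the standard that this translation unit was compiled with.
inline std::string describe_standard() {
    return is_cpp20_or_newer(__cplusplus) ? "C++20 or newer" : "older than C++20";
}
```

Compiling this with the same compiler and -std setting that the build uses, then printing describe_standard(), shows whether C++20 mode is actually in effect.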

@jfischoff

These changes resulted in it compiling successfully for me:

diff --git a/csrc/sliding_tile_attention/st_attn/st_attn_h100.cu b/csrc/sliding_tile_attention/st_attn/st_attn_h100.cu
index e0e8c40..8da971e 100644
--- a/csrc/sliding_tile_attention/st_attn/st_attn_h100.cu
+++ b/csrc/sliding_tile_attention/st_attn/st_attn_h100.cu
@@ -161,7 +161,7 @@ void fwd_attend_ker(const __grid_constant__ fwd_globals<D> g) {
                     tma::load_async(k_smem[(kv_idx+1)%K::stages], g.k, kv_tile_idx, k_smem_arrived[(kv_idx+1)%K::stages]);
                     tma::expect_bytes(v_smem_arrived[(kv_idx+1)%K::stages], sizeof(v_tile));
                     tma::load_async(v_smem[(kv_idx+1)%K::stages], g.v, kv_tile_idx, v_smem_arrived[(kv_idx+1)%K::stages]);
-                    wait(compute_done[(kv_idx)%K::stages], (kv_idx/K::stages)%2);
+                    kittens::wait(compute_done[(kv_idx)%K::stages], (kv_idx/K::stages)%2);
                 }
             } else {
                 int qt = seq_idx / 6 / (CH * CW);
@@ -188,7 +188,7 @@ void fwd_attend_ker(const __grid_constant__ fwd_globals<D> g) {
                                     tma::load_async(k_smem[count%K::stages], g.k, kv_tile_idx, k_smem_arrived[count%K::stages]);
                                     tma::expect_bytes(v_smem_arrived[count%K::stages], sizeof(v_tile));
                                     tma::load_async(v_smem[count%K::stages], g.v, kv_tile_idx, v_smem_arrived[count%K::stages]);
-                                    wait(compute_done[(count - 1)%K::stages], ((count - 1)/K::stages)%2);
+                                    kittens::wait(compute_done[(count - 1)%K::stages], ((count - 1)/K::stages)%2);
                                     count += 1;
                                 } else {
                                     count += 1;
@@ -204,7 +204,7 @@ void fwd_attend_ker(const __grid_constant__ fwd_globals<D> g) {
                     tma::load_async(k_smem[count%K::stages], g.k, kv_tile_idx, k_smem_arrived[count%K::stages]);
                     tma::expect_bytes(v_smem_arrived[count%K::stages], sizeof(v_tile));
                     tma::load_async(v_smem[count%K::stages], g.v, kv_tile_idx, v_smem_arrived[count%K::stages]);
-                    wait(compute_done[(count - 1)%K::stages], ((count - 1)/K::stages)%2);
+                    kittens::wait(compute_done[(count - 1)%K::stages], ((count - 1)/K::stages)%2);
                     count += 1;
                 }
             }
@@ -237,10 +237,10 @@ void fwd_attend_ker(const __grid_constant__ fwd_globals<D> g) {
             kv_iters = CLAMP(DT*2+1, 1, CT) * CLAMP(DH*2+1, 1, CH) * CLAMP(DW*2+1, 1, CW) * 3 - 1 ; 
         }
 
-        wait(qsmem_semaphore, 0);
+        kittens::wait(qsmem_semaphore, 0);
         for (auto kv_idx = 0; kv_idx <= kv_iters; kv_idx++) {
 
-            wait(k_smem_arrived[(kv_idx)%K::stages], (kv_idx/K::stages)%2);
+            kittens::wait(k_smem_arrived[(kv_idx)%K::stages], (kv_idx/K::stages)%2);
             warpgroup::mm_ABt(att_block, q_smem[warpgroupid], k_smem[(kv_idx)%K::stages]);
             
             copy(max_vec_last_scaled, max_vec);
@@ -270,7 +270,7 @@ void fwd_attend_ker(const __grid_constant__ fwd_globals<D> g) {
             copy(att_block_mma, att_block); 
             mul_row(o_reg, o_reg, max_vec_last_scaled); 
 
-            wait(v_smem_arrived[(kv_idx)%K::stages], (kv_idx/K::stages)%2); 
+            kittens::wait(v_smem_arrived[(kv_idx)%K::stages], (kv_idx/K::stages)%2); 
 
             warpgroup::mma_AB(o_reg, att_block_mma, v_smem[(kv_idx)%K::stages]);
             warpgroup::mma_async_wait();
@@ -281,7 +281,7 @@ void fwd_attend_ker(const __grid_constant__ fwd_globals<D> g) {
         if constexpr(text_kv) {
             for (auto kv_idx = kv_iters + 1; kv_idx <= kv_iters + 3; kv_idx++) {
 
-                wait(k_smem_arrived[(kv_idx)%K::stages], (kv_idx/K::stages)%2);
+                kittens::wait(k_smem_arrived[(kv_idx)%K::stages], (kv_idx/K::stages)%2);
                 warpgroup::mm_ABt(att_block, q_smem[warpgroupid], k_smem[(kv_idx)%K::stages]);
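
The diff above qualifies every call site individually. In plain C++, a block-scope using-declaration is another way to pin an unqualified name to one namespace. The sketch below shows that language mechanism only: sync_lib, other_lib, and call_with_using_declaration are hypothetical names, this approach was not proposed in the thread, and it is untested against st_attn_h100.cu, so whether it would satisfy nvcc here depends on where the conflicting declaration comes from.

```cpp
#include <string>

// Hypothetical namespaces that both export a function named `wait`.
namespace sync_lib  { inline std::string wait(int, int) { return "sync_lib"; } }
namespace other_lib { inline std::string wait(int, int) { return "other_lib"; } }

using namespace sync_lib;
using namespace other_lib;

inline std::string call_with_using_declaration() {
    // The using-declaration introduces sync_lib::wait into this block scope.
    // Unqualified lookup finds it here first and stops before reaching the
    // names made visible by the using-directives above.
    using sync_lib::wait;
    return wait(0, 1);
}
```

Explicit qualification, as in the diff, stays the more conservative choice because it does not depend on lookup order at all.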

@amogkam
Author

amogkam commented Feb 24, 2025

Thanks, I can get this built now!

@amogkam amogkam closed this as completed Feb 24, 2025