
[Feature Request] Allow Q K has different sequence len #219

Open
vanhowe opened this issue Mar 10, 2025 · 3 comments
Labels
enhancement New feature or request

Comments

vanhowe commented Mar 10, 2025

Feature Request

Allow Q and K to have different sequence lengths, so that I can do cross-modality alignment on GLA.
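
For concreteness, here is a hypothetical shape contract for the feature (the sizes and the gla name are assumptions for illustration, not the library's actual API):

import torch

B, H, K, V = 2, 4, 64, 64    # batch, heads, key dim, value dim (assumed)
T_q, T_kv = 128, 512         # different sequence lengths for Q vs K/V

q = torch.randn(B, H, T_q, K)     # queries from one modality
k = torch.randn(B, H, T_kv, K)    # keys from another modality
v = torch.randn(B, H, T_kv, V)
# desired: o = gla(q, k, v) with o.shape == (B, H, T_q, V)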

Motivation

I tried to make the change myself, but the kernels involve a lot of interdependent knowledge and tangled modules, so it is difficult for me to fix alone.

Your Contribution

If I can get some pointers on where the changes are needed, I think I can help enable separate seq_q and seq_kv.

vanhowe added the enhancement label on Mar 10, 2025

vanhowe commented Mar 10, 2025

Just a quick question: if Q and KV have different lengths T_q and T_kv, they also have different chunk counts. I guess the kernel should loop over the KV chunks with j in range(0, T_kv // C) in addition to the query chunks with i in range(0, T_q // C), but I am new to Triton and not sure how to handle the block pointers and the i loop. The current kernel looks like this (a reference sketch of the loop structure follows the snippet):

# program ids: value block, key block, and (batch * head) index
i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)

# per-chunk time offsets
o_i = tl.arange(0, BT)

# make block pointers; note q, k, v, o all share the same time length T here
p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (0, i_k * BK), (BT, BK), (1, 0))
p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, 0), (BK, BT), (0, 1))  # k stored transposed
p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (0, i_v * BV), (BT, BV), (1, 0))
p_o = tl.make_block_ptr(o + (i_bh + i_k * B * H) * s_v_h, (T, V), (s_v_t, s_v_d), (0, i_v * BV), (BT, BV), (1, 0))

# single loop over time chunks of size BT, shared by q and k/v
for i in range(0, tl.cdiv(T, BT)):
    # ... (loads of b_q, b_k, b_v and the intra-chunk scores b_s elided in the original snippet) ...
    # intra-chunk contribution
    b_o = tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)
    if CHECK and i == 0:
        # inter-chunk contribution from the recurrent state b_h, then state update
        b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)
        b_h = b_h + tl.dot(b_k, b_v, allow_tf32=False)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
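
Not the kernel itself, just a minimal plain-PyTorch reference of the loop structure I have in mind, assuming the simplest non-causal, ungated case: first accumulate the KV state over all T_kv chunks, then apply it to each query chunk. The chunk size C and all names are illustrative.

import torch

def chunked_cross_linear_attn(q, k, v, C=64):
    # q: (B, H, T_q, K); k: (B, H, T_kv, K); v: (B, H, T_kv, V)
    # non-causal linear attention: o = q @ (sum over s of k[s]^T v[s]),
    # computed chunk by chunk so that T_q and T_kv can differ
    B, H, T_q, K = q.shape
    T_kv, V = k.shape[2], v.shape[3]
    h = q.new_zeros(B, H, K, V)                  # recurrent KV state
    for j in range(0, T_kv, C):                  # loop over KV chunks
        b_k = k[:, :, j:j + C]                   # (B, H, <=C, K)
        b_v = v[:, :, j:j + C]                   # (B, H, <=C, V)
        h = h + b_k.transpose(-1, -2) @ b_v      # accumulate k^T v
    o = q.new_empty(B, H, T_q, V)
    for i in range(0, T_q, C):                   # loop over query chunks
        o[:, :, i:i + C] = q[:, :, i:i + C] @ h  # apply state to queries
    return o

In the causal or gated case the state update would also carry decay terms and the query loop would have to interleave with the KV loop, which is where the block pointers above (and the i loop) would need to index T_q and T_kv separately.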

sustcsonglin (Collaborator) commented

Hi @vanhowe, I am wondering what the use case is for allowing different Q/K lengths?


vanhowe commented Mar 11, 2025

Hi, I was doing research and wondering about your efficient gated attention: could extending it to asymmetric Q/KV lengths help bridge modalities too? This small architectural change could enable time-series/LLM alignment research, and other cross-modality areas as well. Honestly, I tried it with GLA on my own but found it difficult to debug, which is why I would like to request it. LOL
