Commit 344cda2

force is_causal for ndc
1 parent d3b0681 commit 344cda2

1 file changed: +3 −1 lines

sharktank/sharktank/layers/paged_llama_attention_block.py (+3 −1)
@@ -170,7 +170,9 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
                 attn_weights, values
             )  # (bs, heads, slen, head_dim)
         else:
-            is_causal = attention_mask is None and batch_seq_len == 1
+            # Use the builtin attention mask when not decomposed
+            is_causal = True
+            attention_mask = None
             attn_output = torch.nn.functional.scaled_dot_product_attention(
                 query=xq,  # [bs, ..., sl, dim]
                 key=keys,  # [bs, ..., sl, dim]
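The change hardcodes is_causal=True and clears the materialized mask in the non-decomposed path, so torch.nn.functional.scaled_dot_product_attention applies causal masking internally instead of consuming an explicit mask tensor. Below is a minimal sketch, not part of the commit, of the equivalence this relies on; the shapes and tolerance are illustrative assumptions.

import torch
import torch.nn.functional as F

# Hypothetical shapes matching the (bs, heads, sl, head_dim) layout in the diff.
bs, heads, sl, head_dim = 2, 4, 8, 16
q = torch.randn(bs, heads, sl, head_dim)
k = torch.randn(bs, heads, sl, head_dim)
v = torch.randn(bs, heads, sl, head_dim)

# Path taken after this commit: no materialized mask, builtin causal masking.
out_builtin = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)

# Equivalent explicit causal mask: 0.0 where attention is allowed,
# -inf strictly above the diagonal (future positions).
mask = torch.full((sl, sl), float("-inf")).triu(diagonal=1)
out_explicit = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=False)

assert torch.allclose(out_builtin, out_explicit, atol=1e-5)

Letting SDPA build the causal mask itself avoids materializing an [sl, sl] bias tensor and can keep fused attention backends eligible (arbitrary attn_mask tensors rule some of them out), at the cost of this branch no longer supporting non-causal masks.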
