Skip to content

Commit 844da1c

Browse files
committed
add assertion on len == 1
1 parent 9c3a519 commit 844da1c

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

torch/_inductor/kernel/flex_attention.py

+2
Original file line numberDiff line numberDiff line change
@@ -814,6 +814,8 @@ def flex_attention(
814814
choices: List[Any] = []
815815
input_nodes = [query, key, value, kv_indices]
816816
if score_mod_other_buffers and mask_mod_other_buffers:
817+
assert len(score_mod_other_buffers) == 1
818+
assert len(mask_mod_other_buffers) == 1
817819
input_nodes += [score_mod_other_buffers[0], mask_mod_other_buffers[0]]
818820
CppMHATemplate.add_choices(
819821
choices=choices,

0 commit comments

Comments
 (0)