Enable flashinfer when group_size == 6 (#2124)
vinx13 authored Apr 12, 2024
1 parent 9b71443 commit 880c68a
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py
@@ -155,7 +155,7 @@ def create_flashinfer_paged_kv_cache(
                 in self.metadata["model_type"]
             )
             # filter by attention group size
-            or kwargs["num_attention_heads"] // kwargs["num_key_value_heads"] not in [1, 4, 8]
+            or kwargs["num_attention_heads"] // kwargs["num_key_value_heads"] not in [1, 4, 6, 8]
         ):
             return

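Here the attention group size is the ratio of query heads to key-value heads in grouped-query attention, and the commit adds 6 to the set of ratios for which the FlashInfer paged KV cache is dispatched. A minimal standalone sketch of the check (names and head counts are hypothetical, not taken from the MLC LLM codebase):

# Minimal sketch of the group-size filter; names and head counts are
# hypothetical, not the actual MLC LLM code.
SUPPORTED_GROUP_SIZES = [1, 4, 6, 8]  # 6 is newly enabled by this commit

def can_dispatch_flashinfer(num_attention_heads: int, num_key_value_heads: int) -> bool:
    # Group size = query heads per key-value head (grouped-query attention).
    group_size = num_attention_heads // num_key_value_heads
    return group_size in SUPPORTED_GROUP_SIZES

# Example: 48 query heads sharing 8 KV heads gives group size 6,
# which previously fell back to the default implementation.
assert can_dispatch_flashinfer(48, 8)
assert not can_dispatch_flashinfer(40, 8)  # group size 5 is still unsupported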
4 changes: 2 additions & 2 deletions python/mlc_llm/op/attention.py
@@ -103,12 +103,12 @@ def _fallback():
         and k.dtype == "float16"
         and v.dtype == "float16"
     ):
-        if group_size not in [1, 4, 8]:
+        if group_size not in [1, 4, 6, 8]:
             global WARN_FLASHINFER_GROUP_SIZE  # pylint: disable=global-statement
             if not WARN_FLASHINFER_GROUP_SIZE:
                 WARN_FLASHINFER_GROUP_SIZE = True
                 logger.warning(
-                    "FlashInfer only supports group size in [1, 4, 8], but got %d. Skip and "
+                    "FlashInfer only supports group size in [1, 4, 6, 8], but got %d. Skip and "
                     "fallback to default implementation.",
                     group_size,
                 )
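The surrounding code in attention.py logs the fallback warning at most once per process via a module-level flag. A self-contained sketch of that warn-once pattern, assuming the same flag name and message shown in the diff above (simplified; not the full MLC LLM function):

import logging

logger = logging.getLogger(__name__)
WARN_FLASHINFER_GROUP_SIZE = False  # module-level warn-once flag

def warn_unsupported_group_size(group_size: int) -> None:
    """Warn once, then stay silent, when FlashInfer cannot handle the group size."""
    global WARN_FLASHINFER_GROUP_SIZE
    if not WARN_FLASHINFER_GROUP_SIZE:
        WARN_FLASHINFER_GROUP_SIZE = True
        logger.warning(
            "FlashInfer only supports group size in [1, 4, 6, 8], but got %d. Skip and "
            "fallback to default implementation.",
            group_size,
        )

A caller would invoke this only when group_size is outside [1, 4, 6, 8], just before taking the _fallback() path.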
