
Commit 77ca02f

Enable flash attention by default (#690)
1 parent 3d8cad8 commit 77ca02f

3 files changed, +3 -5 lines changed

sharktank/sharktank/layers/configs/llm_configs.py (+1 -1)

@@ -167,7 +167,7 @@ class LlamaModelConfig:
     tensor_parallelism_size: int = 1

     # Which attention kernel to use.
-    attention_kernel: str = "decomposed"
+    attention_kernel: str = "torch"

     # Indicates if running with HuggingFace implementation and ensures
     # numerical equivalency to HuggingFace's LLaMa if true (by modifying
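
The practical effect of this default is which code path computes attention. Below is a minimal sketch of the distinction, not sharktank's ops implementation: the attention helper and its signature are hypothetical, and the only firm assumption is PyTorch 2.x, where torch.nn.functional.scaled_dot_product_attention can dispatch to a fused flash-attention backend.

    # Hypothetical sketch, not sharktank's ops.scaled_dot_product_attention:
    # "torch" corresponds to the fused SDPA path (flash-attention eligible),
    # "decomposed" to spelling out softmax(q @ k^T / sqrt(d)) @ v explicitly.
    import math
    import torch

    def attention(q, k, v, mask=None, kernel="torch"):
        if kernel == "torch":
            # Fused kernel; PyTorch may dispatch to flash attention on supported GPUs.
            return torch.nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=mask, is_causal=False
            )
        # Decomposed path: materializes the full [sl, sl] score matrix.
        scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
        if mask is not None:
            scores = scores + mask
        return torch.softmax(scores, dim=-1) @ v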

sharktank/sharktank/layers/paged_llama_attention_block.py (+1 -3)

@@ -216,14 +216,12 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
                 attn_weights, values
             )  # (bs, heads, slen, head_dim)
         else:
-            is_causal = True
-            attention_mask = None
             attn_output = ops.scaled_dot_product_attention(
                 q=xq,  # [bs, ..., sl, dim]
                 k=keys,  # [bs, ..., sl, dim]
                 v=values,  # [bs, ..., sl, dim]
                 a=attention_mask,  # [bs, ..., sl, sl]
-                is_causal=is_causal,  # assumes causal masking when true
+                is_causal=False,  # assumes causal masking when true
                 scale=None,  # defaults to 1/sqrt(dim)
             )
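The removed override forced a no-mask causal path; after this change the block always passes the explicit attention_mask and sets is_causal=False. A quick sanity sketch of why that is equivalent when the supplied mask is the standard additive causal mask (hypothetical shapes, plain PyTorch SDPA rather than sharktank's ops wrapper):

    import torch

    bs, heads, sl, dim = 1, 2, 8, 16
    q, k, v = (torch.randn(bs, heads, sl, dim) for _ in range(3))

    # Additive causal mask: 0 on and below the diagonal, -inf strictly above it.
    causal_mask = torch.full((sl, sl), float("-inf")).triu(diagonal=1)

    # Explicit mask with is_causal=False ...
    out_masked = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=causal_mask, is_causal=False
    )
    # ... matches letting SDPA construct the causal mask itself.
    out_causal = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, is_causal=True
    )
    torch.testing.assert_close(out_masked, out_causal)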

sharktank/sharktank/utils/cli.py (+1 -1)

@@ -66,7 +66,7 @@ def add_model_options(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--attention-kernel",
         type=str,
-        default="decomposed",
+        default="torch",
         choices=["decomposed", "torch"],
     )
     parser.add_argument(
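
For reference, a self-contained reproduction of the flag's behavior after this change (argparse maps the dashed flag to the attention_kernel attribute); the sharktank scripts that consume the flag are not shown here:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--attention-kernel",
        type=str,
        default="torch",
        choices=["decomposed", "torch"],
    )

    # Omitting the flag now selects the fused torch kernel by default.
    print(parser.parse_args([]).attention_kernel)  # -> torch
    # The decomposed kernel remains available as an explicit opt-in.
    print(parser.parse_args(["--attention-kernel", "decomposed"]).attention_kernel)  # -> decomposed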
