
Commit

fix sequence parallel attention in splash attention (#738)
* fix sequence parallel attention in splash attention

* revert head change
dlwh committed Sep 22, 2024
1 parent 07b3f16 commit fe3e2f3
Showing 2 changed files with 30 additions and 20 deletions.
3 changes: 3 additions & 0 deletions config/gpt2_small_fast.yaml
@@ -19,6 +19,9 @@ trainer:

   train_batch_size: 256
   num_train_steps: 20000
+
+# tensor_parallel_axes: ["position", "key_position"]
+# tensor_parallel_axes: ["heads", "mlp"]
 optimizer:
   learning_rate: 1E-3
   weight_decay: 0.1
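The new (commented-out) lines document two tensor-parallel axis choices for this config. As a minimal sketch, assuming tensor_parallel_axes sits at the trainer level as the hunk context ("@@ ... trainer:") suggests, the sequence-parallel attention path exercised by this fix would be enabled by uncommenting the first variant:

trainer:
  train_batch_size: 256
  num_train_steps: 20000
  tensor_parallel_axes: ["position", "key_position"]  # shard attention over the query/key sequence axes
  # tensor_parallel_axes: ["heads", "mlp"]            # alternative: shard over heads and MLP instead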
47 changes: 27 additions & 20 deletions src/levanter/models/attention.py
@@ -836,6 +836,10 @@ def flatten(axes):
         check_rep=False,
     )
     def wrap_flash_attention(q, k, v):
+        # NB: inside the function, q, k, and v are partitioned, so in general the lengths of dims are not the same
+        Sq = q.shape[2]
+        Sk = k.shape[2]
+        Hq = q.shape[1]
         block_sizes = splash_attention_kernel.BlockSizes(
             block_q=min(block_size, Sq),
             block_kv_compute=min(block_size, Sk),
@@ -848,21 +852,23 @@ def wrap_flash_attention(q, k, v):
         )

         if mask is None:
-            kernel_mask = splash_attention_mask.FullMask(_shape=(Sq, Sk))
+            base_mask = splash_attention_mask.FullMask(_shape=(Sq, Sk))
         elif isinstance(mask, AttentionMask):
             if mask.is_causal:
-                masks = [splash_attention_mask.CausalMask(shape=(Sq, Sq)) for i in range(Hq)]
-                kernel_mask = splash_attention_mask.MultiHeadMask(masks=masks)
+                base_mask = splash_attention_mask.CausalMask(shape=(Sq, Sk))
             else:
-                kernel_mask = splash_attention_mask.FullMask(_shape=(Sq, Sk))
+                base_mask = splash_attention_mask.FullMask(_shape=(Sq, Sk))

             # This is going to be a pain to support
             if mask.explicit_mask is not None:
                 raise NotImplementedError("Explicit masks are not yet supported for splash attention")
         elif isinstance(mask, NamedArray):
             raise NotImplementedError("NamedArray masks are not yet supported for splash attention")
         else:
             raise ValueError(f"Unknown mask type: {mask}")

+        kernel_mask = splash_attention_mask.MultiHeadMask(masks=[base_mask for _ in range(Hq)])
+
         # copied from MaxText
         splash_kernel = splash_attention_kernel.make_splash_mha(
             mask=kernel_mask, head_shards=1, q_seq_shards=1, block_sizes=block_sizes
@@ -879,22 +885,23 @@ def wrap_flash_attention(q, k, v):
     # the output shape is B, S_q, H_q, D_v. Right now we're requiring D_k == D_v
     # we can reshape it to match our expected output
     attn_output = _unflatten_bshd(attn_output, q_class, v_class)
-    reference_out_shape = eqx.filter_eval_shape(
-        simple_attention_with_dropout,
-        QPos,
-        KPos,
-        Key,
-        query,
-        key,
-        value,
-        mask,
-        bias,
-        inference,
-        dropout,
-        attention_dtype,
-        precision,
-        prng=prng,
-    )
+    with haliax.axis_mapping({}):
+        reference_out_shape = eqx.filter_eval_shape(
+            simple_attention_with_dropout,
+            QPos,
+            KPos,
+            Key,
+            query,
+            key,
+            value,
+            mask,
+            bias,
+            inference,
+            dropout,
+            attention_dtype,
+            precision,
+            prng=prng,
+        )
     attn_output = attn_output.rearrange(reference_out_shape.axes).astype(reference_out_shape.dtype)

     attn_output = haliax.shard(attn_output)
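For context on the Sq/Sk/Hq lines added in the first hunk: inside a shard-mapped function, each argument is the local shard, so its dimensions are the per-device sizes rather than the global ones. A small illustrative sketch of that behavior (standalone JAX, not Levanter code; the mesh and the axis name "data" are assumptions):

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

# A 1-D mesh over however many devices are available.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

def local_square(x):
    # Inside shard_map, x is the *local* shard: x.shape[0] is the global length divided by
    # the device count on "data" -- the same reason the splash attention wrapper reads
    # Sq, Sk, and Hq from the partitioned q/k/v shapes rather than from the global axes.
    return x * x

g = shard_map(local_square, mesh=mesh, in_specs=P("data"), out_specs=P("data"))

x = jnp.arange(8.0)  # the global length must divide evenly across the "data" axis
print(g(x).shape)    # (8,): the global shape is reassembled on the way out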
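And a minimal sketch of the mask construction this commit moves to: a single base mask over the local (Sq, Sk) extent, replicated once per query head into the MultiHeadMask that the splash kernel consumes. The helper name build_kernel_mask is hypothetical; the mask classes are the ones used in the diff above:

from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask

def build_kernel_mask(Sq: int, Sk: int, Hq: int, causal: bool):
    # Sq/Sk are the per-shard query/key lengths (q.shape[2], k.shape[2]); Hq is the local
    # number of query heads (q.shape[1]).
    if causal:
        base_mask = splash_attention_mask.CausalMask(shape=(Sq, Sk))
    else:
        base_mask = splash_attention_mask.FullMask(_shape=(Sq, Sk))
    # The splash kernel expects one mask per head; replicate the same base mask Hq times.
    return splash_attention_mask.MultiHeadMask(masks=[base_mask for _ in range(Hq)])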
