diff --git a/config/gpt2_small_fast.yaml b/config/gpt2_small_fast.yaml
index 6242a37bc..054977899 100644
--- a/config/gpt2_small_fast.yaml
+++ b/config/gpt2_small_fast.yaml
@@ -19,6 +19,9 @@ trainer:
   train_batch_size: 256
   num_train_steps: 20000
+
+# tensor_parallel_axes: ["position", "key_position"]
+# tensor_parallel_axes: ["heads", "mlp"]
 optimizer:
   learning_rate: 1E-3
   weight_decay: 0.1
diff --git a/src/levanter/models/attention.py b/src/levanter/models/attention.py
index e7c94f50b..633feee68 100644
--- a/src/levanter/models/attention.py
+++ b/src/levanter/models/attention.py
@@ -836,6 +836,10 @@ def flatten(axes):
         check_rep=False,
     )
     def wrap_flash_attention(q, k, v):
+        # NB: inside the function, q, k, and v are partitioned, so in general the lengths of dims are not the same
+        Sq = q.shape[2]
+        Sk = k.shape[2]
+        Hq = q.shape[1]
         block_sizes = splash_attention_kernel.BlockSizes(
             block_q=min(block_size, Sq),
             block_kv_compute=min(block_size, Sk),
@@ -848,14 +852,14 @@ def wrap_flash_attention(q, k, v):
         )

         if mask is None:
-            kernel_mask = splash_attention_mask.FullMask(_shape=(Sq, Sk))
+            base_mask = splash_attention_mask.FullMask(_shape=(Sq, Sk))
         elif isinstance(mask, AttentionMask):
             if mask.is_causal:
-                masks = [splash_attention_mask.CausalMask(shape=(Sq, Sq)) for i in range(Hq)]
-                kernel_mask = splash_attention_mask.MultiHeadMask(masks=masks)
+                base_mask = splash_attention_mask.CausalMask(shape=(Sq, Sk))
             else:
-                kernel_mask = splash_attention_mask.FullMask(_shape=(Sq, Sk))
+                base_mask = splash_attention_mask.FullMask(_shape=(Sq, Sk))
+            # This is going to be a pain to support
             if mask.explicit_mask is not None:
                 raise NotImplementedError("Explicit masks are not yet supported for splash attention")
         elif isinstance(mask, NamedArray):
@@ -863,6 +867,8 @@ def wrap_flash_attention(q, k, v):
         else:
             raise ValueError(f"Unknown mask type: {mask}")

+        kernel_mask = splash_attention_mask.MultiHeadMask(masks=[base_mask for _ in range(Hq)])
+
         # copied from MaxText
         splash_kernel = splash_attention_kernel.make_splash_mha(
             mask=kernel_mask, head_shards=1, q_seq_shards=1, block_sizes=block_sizes
@@ -879,22 +885,23 @@ def wrap_flash_attention(q, k, v):
     # the output shape is B, S_q, H_q, D_v. Right now we're requiring D_k == D_v
     # we can reshape it to match our expected output
     attn_output = _unflatten_bshd(attn_output, q_class, v_class)

-    reference_out_shape = eqx.filter_eval_shape(
-        simple_attention_with_dropout,
-        QPos,
-        KPos,
-        Key,
-        query,
-        key,
-        value,
-        mask,
-        bias,
-        inference,
-        dropout,
-        attention_dtype,
-        precision,
-        prng=prng,
-    )
+    with haliax.axis_mapping({}):
+        reference_out_shape = eqx.filter_eval_shape(
+            simple_attention_with_dropout,
+            QPos,
+            KPos,
+            Key,
+            query,
+            key,
+            value,
+            mask,
+            bias,
+            inference,
+            dropout,
+            attention_dtype,
+            precision,
+            prng=prng,
+        )
     attn_output = attn_output.rearrange(reference_out_shape.axes).astype(reference_out_shape.dtype)

     attn_output = haliax.shard(attn_output)
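
Illustrative sketch (not part of the diff): the refactor above builds one base mask and replicates it per query head via MultiHeadMask, instead of constructing the per-head list only in the causal branch. The snippet below mirrors that pattern against JAX's TPU splash-attention mask module; the Sq/Sk/Hq values are hypothetical, and the rectangular (Sq, Sk) shape reflects the diff's note that q and k are partitioned inside the shard_map, so their sequence lengths can differ.

    from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask

    # Hypothetical per-shard sizes: query length, key length, query-head count.
    Sq, Sk, Hq = 1024, 2048, 8

    # With the sequence axis sharded, the causal mask is rectangular (Sq, Sk),
    # which is why the diff drops the old square (Sq, Sq) shape.
    base_mask = splash_attention_mask.CausalMask(shape=(Sq, Sk))

    # Replicate the same base mask across all query heads, matching the diff's
    # MultiHeadMask(masks=[base_mask for _ in range(Hq)]).
    kernel_mask = splash_attention_mask.MultiHeadMask(masks=[base_mask for _ in range(Hq)])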