@@ -530,17 +530,19 @@ template <class GemmKernel, bool isVarLen> struct ExampleRunner {
530
530
run (params);
531
531
}
532
532
syclcompat::wait ();
533
-
534
- double effective_seq_len_kv = options.is_causal ?
535
- options.seq_len_kv / 2.0 :
536
- options.seq_len_kv ;
537
-
533
+ // When seq_len_qo is not equal to seq_len_kv, we use a bottom-up approach for the masking.
534
+ // The following changes adjust effective_seq_len_kv when masking is applied in such cases.
535
+ auto offset = cute::min (options.seq_len_qo , options.seq_len_kv );
536
+ auto discard_seq_coord = options.seq_len_qo - offset;
537
+ auto full_tile_offset = options.seq_len_kv - offset;
538
+ auto effective_seq_len_kv = options.is_causal ? full_tile_offset + (options.seq_len_kv / 2.0 ): options.seq_len_kv ;
539
+ auto effective_seq_len_qo = options.is_causal ? options.seq_len_qo - discard_seq_coord : options.seq_len_qo ;
538
540
double cute_time = timer.seconds () / options.iterations ;
539
- double flops_qk = 2.0 * options.batch * options.num_heads_q * options. seq_len_qo * effective_seq_len_kv * options.head_size_qk ;
540
- double flops_pv = 2.0 * options.batch * options.num_heads_q * options. seq_len_qo * options.head_size_vo * effective_seq_len_kv;
541
+ double flops_qk = 2.0 * options.batch * options.num_heads_q * effective_seq_len_qo * effective_seq_len_kv * options.head_size_qk ;
542
+ double flops_pv = 2.0 * options.batch * options.num_heads_q * effective_seq_len_qo * options.head_size_vo * effective_seq_len_kv;
541
543
double tflops = ((flops_qk + flops_pv) * 1e-12 ) / cute_time;
542
- double gbps_qk = 2.0 * options.batch * options.num_heads_q * (options. seq_len_qo * options.head_size_qk + effective_seq_len_kv * options.head_size_qk );
543
- double gbps_pv = 2.0 * options.batch * options.num_heads_q * (effective_seq_len_kv * options. seq_len_qo + options. seq_len_qo * options.head_size_vo );
544
+ double gbps_qk = 2.0 * options.batch * options.num_heads_q * (effective_seq_len_qo * options.head_size_qk + effective_seq_len_kv * options.head_size_qk );
545
+ double gbps_pv = 2.0 * options.batch * options.num_heads_q * (effective_seq_len_kv * effective_seq_len_qo + effective_seq_len_qo * options.head_size_vo );
544
546
double gbps = ((gbps_qk + gbps_pv) * 1e-9 ) / (cute_time);
545
547
std::cout << " Batch: " << options.batch << " \t NumHeads_q: " << options.num_heads_q << " \t NumHeads_kv: " << options.num_heads_kv << " \t Seq Length QO: " << options.seq_len_qo
546
548
<< " \t Seq Length KV: " << options.seq_len_kv << " \t Head Size QK: " << options.head_size_qk << " \t Head Size VO: " << options.head_size_vo
0 commit comments