Skip to content

Commit 167f160

Browse files
mehdi-golimuhammad-tanvir-1211aacostadiaz
authored
correct flop calculation for causal mask when qk_seq_len!=kv_seq_len (#332)
This PR fixes the computed number of floating-point operations and read/write bytes when causal masking is applied and `qk_seq_len` is not equal to `kv_seq_len`. --------- Co-authored-by: Muhammad Tanvir <[email protected]> Co-authored-by: Alejandro Acosta <[email protected]>
1 parent 9cc9df2 commit 167f160

File tree

2 files changed

+23
-15
lines changed

2 files changed

+23
-15
lines changed

benchmarks/pvc/flash_attention_v2/benchmark_runner.hpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -535,13 +535,19 @@ template <class FMHAConfiguration> struct BenchmarkRunnerFMHA {
535535
extra_label << "layoutV=RowMajor ";
536536

537537
state.SetLabel(extra_label.str());
538-
539-
double flops_qk = 2.0 * options.batch * options.num_heads_q * options.seq_len_qo * options.seq_len_kv * options.head_size_qk;
540-
double flops_pv = 2.0 * options.batch * options.num_heads_q * options.seq_len_qo * options.head_size_vo * options.seq_len_kv;
538+
// When seq_len_qo is not equal to seq_len_kv, a bottom-up approach is used for the masking.
539+
// The following adjusts effective_seq_len_kv and effective_seq_len_qo when masking is applied in such cases.
540+
auto offset = cute::min(options.seq_len_qo, options.seq_len_kv);
541+
auto discard_seq_coord = options.seq_len_qo - offset;
542+
auto full_tile_offset = options.seq_len_kv - offset;
543+
auto effective_seq_len_kv = Causal ? full_tile_offset + (options.seq_len_kv / 2.0): options.seq_len_kv;
544+
auto effective_seq_len_qo = Causal ? options.seq_len_qo - discard_seq_coord : options.seq_len_qo;
545+
546+
double flops_qk = 2.0 * options.batch * options.num_heads_q * effective_seq_len_qo * effective_seq_len_kv * options.head_size_qk;
547+
double flops_pv = 2.0 * options.batch * options.num_heads_q * effective_seq_len_qo * options.head_size_vo * effective_seq_len_kv;
541548
double gflops = (flops_qk + flops_pv) * 1e-9;
542-
543-
double gbps_qk = 2.0 * options.batch * options.num_heads_q * (options.seq_len_qo * options.head_size_qk + options.seq_len_kv * options.head_size_qk);
544-
double gbps_pv = 2.0 * options.batch * options.num_heads_q * (options.seq_len_kv * options.seq_len_qo + options.seq_len_qo * options.head_size_vo);
549+
double gbps_qk = 2.0 * options.batch * options.num_heads_q * (effective_seq_len_qo * options.head_size_qk + effective_seq_len_kv * options.head_size_qk);
550+
double gbps_pv = 2.0 * options.batch * options.num_heads_q * (effective_seq_len_kv * effective_seq_len_qo + effective_seq_len_qo * options.head_size_vo);
545551
double mega_bytes_transferred = (gbps_qk + gbps_pv) * (1e-6);
546552

547553
initialize_counters(state);

examples/sycl/06_pvc_flash_attention/pvc_flash_attn_runner.hpp

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -530,17 +530,19 @@ template <class GemmKernel, bool isVarLen> struct ExampleRunner {
530530
run(params);
531531
}
532532
syclcompat::wait();
533-
534-
double effective_seq_len_kv = options.is_causal ?
535-
options.seq_len_kv / 2.0 :
536-
options.seq_len_kv;
537-
533+
// When seq_len_qo is not equal to seq_len_kv, a bottom-up approach is used for the masking.
534+
// The following adjusts effective_seq_len_kv and effective_seq_len_qo when masking is applied in such cases.
535+
auto offset = cute::min(options.seq_len_qo, options.seq_len_kv);
536+
auto discard_seq_coord = options.seq_len_qo - offset;
537+
auto full_tile_offset = options.seq_len_kv - offset;
538+
auto effective_seq_len_kv = options.is_causal ? full_tile_offset + (options.seq_len_kv / 2.0): options.seq_len_kv;
539+
auto effective_seq_len_qo = options.is_causal ? options.seq_len_qo - discard_seq_coord : options.seq_len_qo;
538540
double cute_time = timer.seconds() / options.iterations;
539-
double flops_qk = 2.0 * options.batch * options.num_heads_q * options.seq_len_qo * effective_seq_len_kv * options.head_size_qk;
540-
double flops_pv = 2.0 * options.batch * options.num_heads_q * options.seq_len_qo * options.head_size_vo * effective_seq_len_kv;
541+
double flops_qk = 2.0 * options.batch * options.num_heads_q * effective_seq_len_qo * effective_seq_len_kv * options.head_size_qk;
542+
double flops_pv = 2.0 * options.batch * options.num_heads_q * effective_seq_len_qo * options.head_size_vo * effective_seq_len_kv;
541543
double tflops = ((flops_qk + flops_pv) * 1e-12) / cute_time;
542-
double gbps_qk = 2.0 * options.batch * options.num_heads_q * (options.seq_len_qo * options.head_size_qk + effective_seq_len_kv * options.head_size_qk);
543-
double gbps_pv = 2.0 * options.batch * options.num_heads_q * (effective_seq_len_kv * options.seq_len_qo + options.seq_len_qo * options.head_size_vo);
544+
double gbps_qk = 2.0 * options.batch * options.num_heads_q * (effective_seq_len_qo * options.head_size_qk + effective_seq_len_kv * options.head_size_qk);
545+
double gbps_pv = 2.0 * options.batch * options.num_heads_q * (effective_seq_len_kv * effective_seq_len_qo + effective_seq_len_qo * options.head_size_vo);
544546
double gbps = ((gbps_qk + gbps_pv) * 1e-9) / (cute_time);
545547
std::cout << "Batch: " << options.batch << "\tNumHeads_q: " << options.num_heads_q << "\tNumHeads_kv: " << options.num_heads_kv << "\tSeq Length QO: " << options.seq_len_qo
546548
<< "\tSeq Length KV: " << options.seq_len_kv << "\tHead Size QK: " << options.head_size_qk << "\tHead Size VO: " << options.head_size_vo

0 commit comments

Comments
 (0)