Skip to content

Commit

Permalink
Add env var to make flop formula return worst case
Browse files Browse the repository at this point in the history
ghstack-source-id: 9e5f6ae8be4890f7e635f1f0c12a534e5725d0f5
Pull Request resolved: fairinternal/xformers#1299

__original_commit__ = fairinternal/xformers@e13a77f
  • Loading branch information
lw authored and xFormers Bot committed Feb 10, 2025
1 parent ad4345a commit 144ca64
Showing 1 changed file with 18 additions and 2 deletions.
20 changes: 18 additions & 2 deletions xformers/ops/fmha/flash3.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.


import os
from typing import Any, Iterable, List, Optional, Sequence, Set, Tuple

import torch
Expand Down Expand Up @@ -195,6 +196,18 @@ def mha_fwd_flops(
assert 3 <= query.ndim <= 4
assert 3 <= key.ndim <= 4
assert 3 <= value.ndim <= 4
# This FLOP formula is used by torch.compile's partitioner "automatic
# activation checkpointing" (AutoAC) to decide which ops to preserve
# for backward or to recompute. However, this formula is data-dependent!
# This makes all invocations reuse the choices made based on the first
# inputs, which may not only be sub-optimal but can also lead to inconsistent
# behavior across runs. In the presence of tensor parallelism it might
# also lead to deadlocks if AutoAC recomputes different collectives
# on different ranks. For distributed jobs it seems more robust to have
# all ranks always use the "worst case" FLOP estimate. Ranks are in
# lockstep anyway and will only go as fast as the slowest one.
if os.environ.get("XFORMERS_FLOP_FORMULA_WORST_CASE", "0") == "1":
cu_seqlens_q = cu_seqlens_k = max_seqlen_q = max_seqlen_k = None # type: ignore[assignment]
sizes = _unpack_flash_attention_nested_shapes(
query=query.transpose(-2, -3) if query.ndim == 4 else query,
key=key.transpose(-2, -3) if key.ndim == 4 else key,
Expand All @@ -209,7 +222,7 @@ def mha_fwd_flops(
for query_shape, key_shape, value_shape, _ in sizes
)
if is_causal:
res /= 2
res //= 2
return res

def _create_dq_dk_dv(
Expand Down Expand Up @@ -336,6 +349,9 @@ def mha_bwd_flops(
assert 3 <= query.ndim <= 4
assert 3 <= key.ndim <= 4
assert 3 <= value.ndim <= 4
# See the fwd FLOP formula above for reasoning behind this.
if os.environ.get("XFORMERS_FLOP_FORMULA_WORST_CASE", "0") == "1":
cu_seqlens_q = cu_seqlens_k = max_seqlen_q = max_seqlen_k = None # type: ignore[assignment]
res = _flash_attention_backward_flop(
dout.transpose(-2, -3) if dout.ndim == 4 else dout,
query.transpose(-2, -3) if query.ndim == 4 else query,
Expand All @@ -349,7 +365,7 @@ def mha_bwd_flops(
max_seqlen_k,
)
if is_causal:
res /= 2
res //= 2
return res


Expand Down

0 comments on commit 144ca64

Please sign in to comment.