
[Feature]: Memory Efficient Flash Attention for gfx1100 (7900xtx) #16

Closed
supernovae opened this issue Apr 19, 2024 · 31 comments · Fixed by #39
Labels: enhancement (New feature or request)

Comments

@supernovae

Suggestion Description

Started using torchlearn to train models in PyTorch on my gfx1100 card, but I get a warning that Torch was not compiled with memory-efficient flash attention.

I see there is a recently merged patch, pending nightly, that adds ROCm flash attention support to PyTorch, but it looks like it only targets the MI200 and MI300 cards. Any plans to support consumer/workstation cards? Can I compile this in myself today?

torchlearn is getting about 1.6 it/s, compared to the ~3 it/s some people get with a 4090. While my card costs half as much as a 4090, I should be within 80-85% of its performance, but I think the lack of memory-efficient flash attention is holding it back.
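A minimal sketch for checking which SDPA backends a given PyTorch build enables (assuming a HIP/CUDA device is visible; the exact warning text varies by PyTorch version):

import torch
import torch.nn.functional as F

# Report which scaled_dot_product_attention backends this build has enabled.
print("flash SDP enabled:        ", torch.backends.cuda.flash_sdp_enabled())
print("mem-efficient SDP enabled:", torch.backends.cuda.mem_efficient_sdp_enabled())
print("math SDP enabled:         ", torch.backends.cuda.math_sdp_enabled())

# A single SDPA call; builds without flash/mem-efficient kernels for this GPU
# fall back to the math path and emit the warning mentioned above.
q = k = v = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)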

Operating System

Ubuntu 22.04

GPU

gfx1100

ROCm Component

No response

@xinyazhang added the enhancement (New feature or request) label on Apr 19, 2024
@xinyazhang
Collaborator

xinyazhang commented Apr 19, 2024

We are already working on efficient attention.

However, since AOTriton uses Triton as its compiler, the actual landing of gfx1100 support for efficient attention may take longer.

Any plans to support consumer/workstation cards?

Currently AOTriton is only compiled for the MI200/MI300 series data-center GPUs (the CDNA2/CDNA3 architectures). We are going to add Navi targets once the Triton compiler supports them.

Can I compile this in myself today?

Unfortunately that would not help, due to missing support for Navi3x in Triton, mainly the missing WMMA compiler support. We are actively working on it and you'll see it land as soon as it is supported.

@supernovae
Author

Thank you for the detailed update! I'll be following closely as I'm super interested in maximizing the utility of the 7900s and seeing more options on the market to compete against Nvidia.

@sdli1995

It seems gfx1100 Triton support has been upstreamed; according to this comment in ROCm/triton#250 (comment), the Triton flash-attention performance is not good.

@supernovae
Author

I was just about to mention that this issue says compiler support for RDNA3 is available in Triton, but I can't find much info beyond that yet: triton-lang/triton#3704. In that case, though, the reporter was on an older-generation card, not a 7900 XTX.

@xinyazhang
Collaborator

RDNA support is a more complicated topic because it requires an upgrade to the Triton compiler, which causes quite a few compatibility problems.

@xinyazhang
Collaborator

cc: @jayfurmanek what is the status of Navi support on both repos?
cc: @groenenboomj

@supernovae
Author

Any update on flash attention being in released builds of aotriton? would love to give this a whirl soon.

@xinyazhang
Collaborator

Any update on flash attention being in released builds of aotriton? would love to give this a whirl soon.

We are doing experiments on Navi and fixing Triton compiler problems now (See: ROCm/triton#596 for more details)

A newer compiler is necessary to support Navi.

@sdli1995

sdli1995 commented Jun 6, 2024

I used upstream Triton, and its Triton flash-attention implementation is not as fast as expected.
Here is my benchmark script:

import time
import torch
import triton

from triton import cdiv, jit
from triton import language as tl
from torch.nn.functional import scaled_dot_product_attention

torch.cuda.set_device('cuda:0')

def is_hip():
   return triton.runtime.driver.active.get_current_target().backend == "hip"


@jit
def _fwd_kernel(Q, K, V, sm_scale,  #
               L,  #
               Out,  #
               stride_qz, stride_qh, stride_qm, stride_qk,  #
               stride_kz, stride_kh, stride_kn, stride_kk,  #
               stride_vz, stride_vh, stride_vn, stride_vk,  #
               stride_oz, stride_oh, stride_om, stride_on,  #
               Z, H, N_CTX,  #
               Z_H_N_CTX,  #
               BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,  #
               BLOCK_N: tl.constexpr,  #
               IS_CAUSAL: tl.constexpr  #
               ):
   start_m = tl.program_id(0)
   off_hz = tl.program_id(1)
   qvk_offset = off_hz * stride_qh
   vk_offset = qvk_offset // stride_qm

   K_block_ptr = tl.make_block_ptr(
       base=K,
       shape=(BLOCK_DMODEL, Z_H_N_CTX),
       strides=(stride_kk, stride_kn),
       offsets=(0, vk_offset),
       block_shape=(BLOCK_DMODEL, BLOCK_N),
       order=(0, 1),
   )
   V_block_ptr = tl.make_block_ptr(
       base=V,
       shape=(Z_H_N_CTX, BLOCK_DMODEL),
       strides=(stride_vn, stride_vk),
       offsets=(vk_offset, 0),
       block_shape=(BLOCK_N, BLOCK_DMODEL),
       order=(1, 0),
   )
   # initialize offsets
   offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
   offs_n = tl.arange(0, BLOCK_N)
   # initialize pointer to m and l
   m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
   l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
   acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
   # credits to: Adam P. Goucher (https://github.com/apgoucher):
   # scale sm_scale by 1/log_2(e) and use
   # 2^x instead of exp in the loop because CSE and LICM
   # don't work as expected with `exp` in the loop
   qk_scale = sm_scale * 1.44269504
   # load q: it will stay in SRAM throughout

   offs_k = tl.arange(0, BLOCK_DMODEL)
   Q_ptrs = Q + qvk_offset + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk
   q = tl.load(Q_ptrs)

   q = (q * qk_scale).to(K.dtype.element_ty)
   lo = 0
   hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX
   for start_n in range(lo, hi, BLOCK_N):
       # -- load k, v --
       k = tl.load(K_block_ptr)
       v = tl.load(V_block_ptr)
       # -- compute qk ---
       qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
       if IS_CAUSAL:
           qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
       qk += tl.dot(q, k)
       # -- compute scaling constant ---
       m_i_new = tl.maximum(m_i, tl.max(qk, 1))
       alpha = tl.math.exp2(m_i - m_i_new)
       p = tl.math.exp2(qk - m_i_new[:, None])
       # -- scale and update acc --
       acc *= alpha[:, None]
       acc += tl.dot(p.to(V.dtype.element_ty), v)
       # -- update m_i and l_i --
       l_i = l_i * alpha + tl.sum(p, 1)
       m_i = m_i_new
       # update pointers
       K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
       V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
   # write back l and m
   acc = acc / l_i[:, None]
   l_ptrs = L + off_hz * N_CTX + offs_m
   tl.store(l_ptrs, m_i + tl.math.log2(l_i))
   # write back O
   O_block_ptr = tl.make_block_ptr(
       base=Out,
       shape=(Z_H_N_CTX, BLOCK_DMODEL),
       strides=(stride_om, stride_on),
       offsets=(vk_offset + start_m * BLOCK_M, 0),
       block_shape=(BLOCK_M, BLOCK_DMODEL),
       order=(1, 0),
   )
   # O_ptrs = Out + qvk_offset + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk
   tl.store(O_block_ptr, acc.to(K.dtype.element_ty))


@jit
def _bwd_preprocess(
   Out,
   DO,
   Delta,
   BLOCK_M: tl.constexpr,
   D_HEAD: tl.constexpr,
):
   off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
   off_n = tl.arange(0, D_HEAD)
   # load
   o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
   do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
   # compute
   delta = tl.sum(o * do, axis=1)
   # write-back
   tl.store(Delta + off_m, delta)


@jit
def _bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale,  #
                             Out, DO,  #
                             DQ, DK, DV,  #
                             L,  #
                             D,  #
                             Q_block_ptr, K_block_ptr, V_block_ptr,  #
                             DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr,  #
                             stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,  #
                             stride_kz, stride_kh, stride_kn, stride_kk,  #
                             stride_vz, stride_vh, stride_vn, stride_vk,  #
                             Z, H, N_CTX,  #
                             off_h, off_z, off_hz, start_n, num_block,  #
                             BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,  #
                             BLOCK_N: tl.constexpr,  #
                             SEQUENCE_PARALLEL: tl.constexpr,  #
                             CAUSAL: tl.constexpr,  #
                             MMA_V3: tl.constexpr  #
                             ):
   if CAUSAL:
       lo = start_n * BLOCK_M
   else:
       lo = 0

   Q_offset = (off_z * stride_qz + off_h * stride_qh) // stride_qm
   DQ_offset = off_z * stride_qz + off_h * stride_qh
   K_offset = (off_z * stride_kz + off_h * stride_kh) // stride_kn
   V_offset = (off_z * stride_vz + off_h * stride_vh) // stride_vn
   if SEQUENCE_PARALLEL:
       DQ_offset += stride_dqa * start_n
   DQ_offset = DQ_offset // stride_qm

   Q_block_ptr = tl.advance(Q_block_ptr, (lo + Q_offset, 0))
   K_block_ptr = tl.advance(K_block_ptr, (start_n * BLOCK_M + K_offset, 0))
   V_block_ptr = tl.advance(V_block_ptr, (start_n * BLOCK_M + V_offset, 0))
   DO_block_ptr = tl.advance(DO_block_ptr, (lo + Q_offset, 0))
   DQ_block_ptr = tl.advance(DQ_block_ptr, (lo + DQ_offset, 0))
   DK_block_ptr = tl.advance(DK_block_ptr, (start_n * BLOCK_M + K_offset, 0))
   DV_block_ptr = tl.advance(DV_block_ptr, (start_n * BLOCK_M + V_offset, 0))

   # initialize row/col offsets
   offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
   offs_m = tl.arange(0, BLOCK_N)
   # pointer to row-wise quantities in value-like data
   D_ptrs = D + off_hz * N_CTX
   l_ptrs = L + off_hz * N_CTX
   # initialize dv amd dk
   dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
   dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
   # k and v stay in SRAM throughout
   k = tl.load(K_block_ptr)
   v = tl.load(V_block_ptr)
   # loop over rows
   for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
       offs_m_curr = start_m + offs_m
       # load q, k, v, do on-chip
       q = tl.load(Q_block_ptr)
       # recompute p = softmax(qk, dim=-1).T
       # NOTE: `do` is pre-divided by `l`; no normalization here
       if CAUSAL:
           qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), float(0.0), float("-inf"))
       else:
           qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
       qk += tl.dot(q, tl.trans(k))
       qk *= qk_scale
       l_i = tl.load(l_ptrs + offs_m_curr)
       p = tl.math.exp2(qk - l_i[:, None])
       # compute dv
       do = tl.load(DO_block_ptr)
       dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
       # compute dp = dot(v, do)
       Di = tl.load(D_ptrs + offs_m_curr)
       # dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
       dp = tl.dot(do, tl.trans(v))
       # compute ds = p * (dp - delta[:, None])
       ds = (p * (dp - Di[:, None]) * sm_scale).to(Q.dtype.element_ty)
       # compute dk = dot(ds.T, q)
       dk += tl.dot(tl.trans(ds), q)
       # compute dq
       if not SEQUENCE_PARALLEL:
           dq = tl.load(DQ_block_ptr)
           dq += tl.dot(ds, k)
           tl.store(DQ_block_ptr, dq.to(Q.dtype.element_ty))
       elif SEQUENCE_PARALLEL:
           if MMA_V3:
               dq = tl.dot(ds, k)
           else:
               # not work with mma v3, because M % 64 != 0
               dq = tl.trans(tl.dot(tl.trans(k), tl.trans(ds)))
           tl.store(DQ_block_ptr, dq.to(Q.dtype.element_ty))

       # increment pointers
       DQ_block_ptr = tl.advance(DQ_block_ptr, (BLOCK_M, 0))
       Q_block_ptr = tl.advance(Q_block_ptr, (BLOCK_M, 0))
       DO_block_ptr = tl.advance(DO_block_ptr, (BLOCK_M, 0))
   # write-back
   tl.store(DV_block_ptr, dv.to(V.dtype.element_ty))
   tl.store(DK_block_ptr, dk.to(K.dtype.element_ty))


@jit
def _bwd_kernel(Q, K, V, sm_scale,  #
               Out, DO,  #
               DQ, DK, DV,  #
               L,  #
               D,  #
               stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,  #
               stride_kz, stride_kh, stride_kn, stride_kk,  #
               stride_vz, stride_vh, stride_vn, stride_vk,  #
               Z, H, N_CTX,  #
               Z_H_N_CTX,  #
               SQ_Z_H_N_CTX,  #
               BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,  #
               BLOCK_N: tl.constexpr,  #
               SEQUENCE_PARALLEL: tl.constexpr,  #
               CAUSAL: tl.constexpr,  #
               MMA_V3: tl.constexpr  #
               ):
   qk_scale = sm_scale * 1.44269504
   off_hz = tl.program_id(0)
   off_z = off_hz // H
   off_h = off_hz % H

   Q_block_ptr = tl.make_block_ptr(
       base=Q,
       shape=(Z_H_N_CTX, BLOCK_DMODEL),
       strides=(stride_qm, stride_qk),
       offsets=(0, 0),
       block_shape=(BLOCK_M, BLOCK_DMODEL),
       order=(1, 0),
   )
   K_block_ptr = tl.make_block_ptr(
       base=K,
       shape=(Z_H_N_CTX, BLOCK_DMODEL),
       strides=(stride_kn, stride_kk),
       offsets=(0, 0),
       block_shape=(BLOCK_M, BLOCK_DMODEL),
       order=(1, 0),
   )
   V_block_ptr = tl.make_block_ptr(
       base=V,
       shape=(Z_H_N_CTX, BLOCK_DMODEL),
       strides=(stride_vn, stride_vk),
       offsets=(0, 0),
       block_shape=(BLOCK_M, BLOCK_DMODEL),
       order=(1, 0),
   )
   DO_block_ptr = tl.make_block_ptr(
       base=DO,
       shape=(Z_H_N_CTX, BLOCK_DMODEL),
       strides=(stride_qm, stride_qk),
       offsets=(0, 0),
       block_shape=(BLOCK_M, BLOCK_DMODEL),
       order=(1, 0),
   )
   if SEQUENCE_PARALLEL:
       DQ_block_ptr = tl.make_block_ptr(
           base=DQ,
           shape=(SQ_Z_H_N_CTX, BLOCK_DMODEL),
           strides=(stride_qm, stride_qk),
           offsets=(0, 0),
           block_shape=(BLOCK_M, BLOCK_DMODEL),
           order=(1, 0),
       )
   else:
       DQ_block_ptr = tl.make_block_ptr(
           base=DQ,
           shape=(Z_H_N_CTX, BLOCK_DMODEL),
           strides=(stride_qm, stride_qk),
           offsets=(0, 0),
           block_shape=(BLOCK_M, BLOCK_DMODEL),
           order=(1, 0),
       )

   DK_block_ptr = tl.make_block_ptr(
       base=DK,
       shape=(Z_H_N_CTX, BLOCK_DMODEL),
       strides=(stride_kn, stride_kk),
       offsets=(0, 0),
       block_shape=(BLOCK_M, BLOCK_DMODEL),
       order=(1, 0),
   )
   DV_block_ptr = tl.make_block_ptr(
       base=DV,
       shape=(Z_H_N_CTX, BLOCK_DMODEL),
       strides=(stride_vn, stride_vk),
       offsets=(0, 0),
       block_shape=(BLOCK_M, BLOCK_DMODEL),
       order=(1, 0),
   )

   num_block_n = tl.cdiv(N_CTX, BLOCK_N)
   if not SEQUENCE_PARALLEL:
       for start_n in range(0, num_block_n):
           _bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, Out, DO,  #
                                     DQ, DK, DV,  #
                                     L,  #
                                     D,  #
                                     Q_block_ptr, K_block_ptr, V_block_ptr,  #
                                     DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr,  #
                                     stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,  #
                                     stride_kz, stride_kh, stride_kn, stride_kk,  #
                                     stride_vz, stride_vh, stride_vn, stride_vk,  #
                                     Z, H, N_CTX,  #
                                     off_h, off_z, off_hz, start_n, num_block_n,  #
                                     BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL,  #
                                     BLOCK_N=BLOCK_N,  #
                                     SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,  #
                                     CAUSAL=CAUSAL,  #
                                     MMA_V3=MMA_V3  #
                                     )
   else:
       start_n = tl.program_id(1)
       _bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, Out, DO,  #
                                 DQ, DK, DV,  #
                                 L,  #
                                 D,  #
                                 Q_block_ptr, K_block_ptr, V_block_ptr,  #
                                 DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr,  #
                                 stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,  #
                                 stride_kz, stride_kh, stride_kn, stride_kk,  #
                                 stride_vz, stride_vh, stride_vn, stride_vk,  #
                                 Z, H, N_CTX,  #
                                 off_h, off_z, off_hz, start_n, num_block_n,  #
                                 BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL,  #
                                 BLOCK_N=BLOCK_N,  #
                                 SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,  #
                                 CAUSAL=CAUSAL,  #
                                 MMA_V3=MMA_V3  #
                                 )


class _attention(torch.autograd.Function):

   @staticmethod
   def forward(ctx, q, k, v, causal, sm_scale, sequence_parallel=False):
       # only support for Ampere now
       capability = torch.cuda.get_device_capability()
       if capability[0] < 8:
           raise RuntimeError("Flash attention currently only supported for compute capability >= 80")
       BLOCK_M = 128
       BLOCK_N = 64
       # shape constraints
       Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
       assert Lq == Lk and Lk == Lv
       assert Lk in {16, 32, 64, 128}
       o = torch.empty_like(q)
       grid = (cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
       L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
       num_warps = 4 if Lk <= 64 else 8
       _fwd_kernel[grid](
           q, k, v, sm_scale,  #
           L,  #
           o,  #
           q.stride(0), q.stride(1), q.stride(2), q.stride(3),  #
           k.stride(0), k.stride(1), k.stride(2), k.stride(3),  #
           v.stride(0), v.stride(1), v.stride(2), v.stride(3),  #
           o.stride(0), o.stride(1), o.stride(2), o.stride(3),  #
           q.shape[0], q.shape[1], q.shape[2],  #
           q.shape[0] * q.shape[1] * q.shape[2],  #
           BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,  #
           IS_CAUSAL=causal,  #
           num_warps=num_warps,  #
           num_stages=4  #
       )

       ctx.save_for_backward(q, k, v, o, L)
       ctx.grid = grid
       ctx.sm_scale = sm_scale
       ctx.BLOCK_DMODEL = Lk
       ctx.causal = causal
       ctx.sequence_parallel = sequence_parallel
       return o

   @staticmethod
   def backward(ctx, do):
       capability = torch.cuda.get_device_capability()
       MMA_V3 = capability[0] >= 9
       BLOCK = 128

       if is_hip():
           # Bwd pass runs out of shared memory on HIP with larger block size.
           BLOCK = 64

       q, k, v, o, L = ctx.saved_tensors
       sequence_parallel = ctx.sequence_parallel
       seq_len_kv = k.shape[2]
       do = do.contiguous()
       if sequence_parallel:
           replicas = cdiv(seq_len_kv, BLOCK)
           new_dq_shape = (replicas, ) + q.shape
           dq = torch.zeros(new_dq_shape, device=q.device, dtype=q.dtype)
       else:
           dq = torch.zeros_like(q, dtype=q.dtype)
       dk = torch.empty_like(k)
       dv = torch.empty_like(v)
       delta = torch.empty_like(L)
       _bwd_preprocess[(cdiv(q.shape[2], BLOCK) * ctx.grid[1], )](
           o,
           do,
           delta,
           BLOCK_M=BLOCK,
           D_HEAD=ctx.BLOCK_DMODEL,
       )
       _bwd_kernel[(ctx.grid[1], cdiv(seq_len_kv, BLOCK) if sequence_parallel else 1)](
           q, k, v, ctx.sm_scale,  #
           o, do,  #
           dq, dk, dv,  #
           L,  #
           delta,  #
           o.numel(), q.stride(0), q.stride(1), q.stride(2), q.stride(3),  #
           k.stride(0), k.stride(1), k.stride(2), k.stride(3),  #
           v.stride(0), v.stride(1), v.stride(2), v.stride(3),  #
           q.shape[0], q.shape[1], q.shape[2],  #
           q.shape[0] * q.shape[1] * q.shape[2],  #
           cdiv(seq_len_kv, BLOCK) * q.shape[0] * q.shape[1] * q.shape[2],  #
           BLOCK_M=BLOCK, BLOCK_N=BLOCK,  #
           BLOCK_DMODEL=ctx.BLOCK_DMODEL,  #
           SEQUENCE_PARALLEL=sequence_parallel,  #
           CAUSAL=ctx.causal,  #
           MMA_V3=MMA_V3,  #
           num_warps=8,  #
           num_stages=1  #
       )

       if len(dq.shape) == 5:
           dq = dq.sum(dim=0)
       return dq, dk, dv, None, None, None


attention = _attention.apply





def pytorch_att(q, k, v, drop=0.0, scale=1.3, casual=True):
    return scaled_dot_product_attention(q, k, v, dropout_p=drop, is_causal=casual, scale=scale)


def triton_att(q, k, v, drop=0.0, scale=1.3, casual=True):
    return attention(q, k, v, casual, scale)


for BATCH in [1, 2, 4]:
    for N_HEADS in [16, 32, 40]:
        for N_CTX in [512, 1024, 2048]:
            for D_HEAD in [128]:
                for casual in [True]:

                    def get_flops(ms, causal=True):
                        # 2 matmuls (QK^T and PV), each 2*B*H*N^2*D FLOPs; halved under a causal mask
                        flops_per_matmul = 2. * BATCH * N_HEADS * N_CTX * N_CTX * D_HEAD
                        total_flops = 2 * flops_per_matmul
                        if causal:
                            total_flops *= 0.5
                        return total_flops / ms * 1e-9  # TFLOPS, given ms per forward pass

                    print(f"Test-self-attention BATCH:{BATCH} N_HEADS:{N_HEADS} N_CTX:{N_CTX} D_HEAD:{D_HEAD} Causal:{casual}")
                    query = torch.rand(BATCH, N_HEADS, N_CTX, D_HEAD, device="cuda:0", dtype=torch.float16)
                    key = torch.rand(BATCH, N_HEADS, N_CTX, D_HEAD, device="cuda:0", dtype=torch.float16)
                    value = torch.rand(BATCH, N_HEADS, N_CTX, D_HEAD, device="cuda:0", dtype=torch.float16)
                    sm_scale = 1.3

                    # warm up, then time 10 forward passes of torch SDPA
                    for i in range(10):
                        y = pytorch_att(query, key, value, casual=casual)
                    torch.cuda.synchronize()
                    delta = 0
                    for i in range(10):
                        torch.cuda.synchronize()
                        t = time.time()
                        y = pytorch_att(query, key, value, casual=casual)
                        torch.cuda.synchronize()
                        delta += time.time() - t
                    ms = delta * 100  # total seconds for 10 runs -> ms per run
                    print(f"SDPA   throughput: {get_flops(ms, casual):.3f} TFlops, {ms:.3f} ms/fwd")

                    # warm up, then time 10 forward passes of the Triton kernel
                    for i in range(10):
                        y = triton_att(query, key, value, casual=casual)
                    torch.cuda.synchronize()
                    delta = 0
                    for i in range(10):
                        torch.cuda.synchronize()
                        t = time.time()
                        y = triton_att(query, key, value, casual=casual)
                        torch.cuda.synchronize()
                        delta += time.time() - t
                    ms = delta * 100
                    print(f"Triton throughput: {get_flops(ms, casual):.3f} TFlops, {ms:.3f} ms/fwd")

The results are:

Test-self-attention BATCH:1 N_HEADS:16 N_CTX:512 D_HEAD:128 Causal:True
/home/work/nsp/nsp/triton_nn/attn.py:467: UserWarning: 1Torch was not compiled with memory efficient attention. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:505.)
  return scaled_dot_product_attention(q, k, v,dropout_p=drop,is_causal=casual,scale=scale)
SDPA   throughput: 2.588 TFlops, 0.415 ms/fwd
Triton throughput: 3.219 TFlops, 0.334 ms/fwd
Test-self-attention BATCH:1 N_HEADS:16 N_CTX:1024 D_HEAD:128 Causal:True
SDPA   throughput: 7.701 TFlops, 0.558 ms/fwd
Triton throughput: 5.790 TFlops, 0.742 ms/fwd
Test-self-attention BATCH:1 N_HEADS:16 N_CTX:2048 D_HEAD:128 Causal:True
SDPA   throughput: 9.412 TFlops, 1.825 ms/fwd
Triton throughput: 8.539 TFlops, 2.012 ms/fwd
Test-self-attention BATCH:1 N_HEADS:32 N_CTX:512 D_HEAD:128 Causal:True
SDPA   throughput: 6.659 TFlops, 0.322 ms/fwd
Triton throughput: 4.693 TFlops, 0.458 ms/fwd
Test-self-attention BATCH:1 N_HEADS:32 N_CTX:1024 D_HEAD:128 Causal:True
SDPA   throughput: 8.998 TFlops, 0.955 ms/fwd
Triton throughput: 7.524 TFlops, 1.142 ms/fwd
Test-self-attention BATCH:1 N_HEADS:32 N_CTX:2048 D_HEAD:128 Causal:True
SDPA   throughput: 10.267 TFlops, 3.347 ms/fwd
Triton throughput: 9.435 TFlops, 3.642 ms/fwd
Test-self-attention BATCH:1 N_HEADS:40 N_CTX:512 D_HEAD:128 Causal:True
SDPA   throughput: 7.206 TFlops, 0.373 ms/fwd
Triton throughput: 4.938 TFlops, 0.544 ms/fwd
Test-self-attention BATCH:1 N_HEADS:40 N_CTX:1024 D_HEAD:128 Causal:True
SDPA   throughput: 8.951 TFlops, 1.200 ms/fwd
Triton throughput: 7.983 TFlops, 1.345 ms/fwd
Test-self-attention BATCH:1 N_HEADS:40 N_CTX:2048 D_HEAD:128 Causal:True
SDPA   throughput: 10.358 TFlops, 4.147 ms/fwd
Triton throughput: 9.599 TFlops, 4.474 ms/fwd
Test-self-attention BATCH:2 N_HEADS:16 N_CTX:512 D_HEAD:128 Causal:True
SDPA   throughput: 6.307 TFlops, 0.341 ms/fwd
Triton throughput: 4.702 TFlops, 0.457 ms/fwd
Test-self-attention BATCH:2 N_HEADS:16 N_CTX:1024 D_HEAD:128 Causal:True
SDPA   throughput: 8.618 TFlops, 0.997 ms/fwd
Triton throughput: 7.582 TFlops, 1.133 ms/fwd
Test-self-attention BATCH:2 N_HEADS:16 N_CTX:2048 D_HEAD:128 Causal:True
SDPA   throughput: 10.156 TFlops, 3.383 ms/fwd
Triton throughput: 9.809 TFlops, 3.503 ms/fwd
Test-self-attention BATCH:2 N_HEADS:32 N_CTX:512 D_HEAD:128 Causal:True
SDPA   throughput: 8.587 TFlops, 0.500 ms/fwd
Triton throughput: 6.019 TFlops, 0.714 ms/fwd
Test-self-attention BATCH:2 N_HEADS:32 N_CTX:1024 D_HEAD:128 Causal:True
SDPA   throughput: 8.574 TFlops, 2.004 ms/fwd
Triton throughput: 8.392 TFlops, 2.047 ms/fwd
Test-self-attention BATCH:2 N_HEADS:32 N_CTX:2048 D_HEAD:128 Causal:True
SDPA   throughput: 10.282 TFlops, 6.683 ms/fwd
Triton throughput: 9.928 TFlops, 6.922 ms/fwd
Test-self-attention BATCH:2 N_HEADS:40 N_CTX:512 D_HEAD:128 Causal:True
SDPA   throughput: 8.708 TFlops, 0.617 ms/fwd
Triton throughput: 6.201 TFlops, 0.866 ms/fwd
Test-self-attention BATCH:2 N_HEADS:40 N_CTX:1024 D_HEAD:128 Causal:True
SDPA   throughput: 8.790 TFlops, 2.443 ms/fwd
Triton throughput: 9.367 TFlops, 2.293 ms/fwd
Test-self-attention BATCH:2 N_HEADS:40 N_CTX:2048 D_HEAD:128 Causal:True
SDPA   throughput: 10.356 TFlops, 8.295 ms/fwd
Triton throughput: 10.096 TFlops, 8.508 ms/fwd
Test-self-attention BATCH:4 N_HEADS:16 N_CTX:512 D_HEAD:128 Causal:True
SDPA   throughput: 7.948 TFlops, 0.540 ms/fwd
Triton throughput: 5.756 TFlops, 0.746 ms/fwd
Test-self-attention BATCH:4 N_HEADS:16 N_CTX:1024 D_HEAD:128 Causal:True
SDPA   throughput: 8.597 TFlops, 1.998 ms/fwd
Triton throughput: 8.349 TFlops, 2.058 ms/fwd
Test-self-attention BATCH:4 N_HEADS:16 N_CTX:2048 D_HEAD:128 Causal:True
SDPA   throughput: 10.291 TFlops, 6.678 ms/fwd
Triton throughput: 9.929 TFlops, 6.921 ms/fwd
Test-self-attention BATCH:4 N_HEADS:32 N_CTX:512 D_HEAD:128 Causal:True
SDPA   throughput: 8.320 TFlops, 1.032 ms/fwd
Triton throughput: 7.124 TFlops, 1.206 ms/fwd
Test-self-attention BATCH:4 N_HEADS:32 N_CTX:1024 D_HEAD:128 Causal:True
SDPA   throughput: 8.700 TFlops, 3.949 ms/fwd
Triton throughput: 9.127 TFlops, 3.765 ms/fwd
Test-self-attention BATCH:4 N_HEADS:32 N_CTX:2048 D_HEAD:128 Causal:True
SDPA   throughput: 10.468 TFlops, 13.129 ms/fwd
Triton throughput: 10.487 TFlops, 13.105 ms/fwd
Test-self-attention BATCH:4 N_HEADS:40 N_CTX:512 D_HEAD:128 Causal:True
SDPA   throughput: 8.271 TFlops, 1.298 ms/fwd
Triton throughput: 7.448 TFlops, 1.442 ms/fwd
Test-self-attention BATCH:4 N_HEADS:40 N_CTX:1024 D_HEAD:128 Causal:True
SDPA   throughput: 8.768 TFlops, 4.898 ms/fwd
Triton throughput: 9.324 TFlops, 4.606 ms/fwd
Test-self-attention BATCH:4 N_HEADS:40 N_CTX:2048 D_HEAD:128 Causal:True
SDPA   throughput: 10.483 TFlops, 16.388 ms/fwd
Triton throughput: 10.606 TFlops, 16.199 ms/fwd
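As an aside, triton.testing.do_bench handles warmup and device synchronization itself, so it may give steadier numbers than hand-rolled time.time() loops; a minimal sketch reusing the wrappers above:

from triton.testing import do_bench

# Returns the wall time in ms for one forward call, with built-in warmup and sync.
ms_sdpa = do_bench(lambda: pytorch_att(query, key, value, casual=True))
ms_triton = do_bench(lambda: triton_att(query, key, value, casual=True))
print(f"SDPA {ms_sdpa:.3f} ms/fwd, Triton {ms_triton:.3f} ms/fwd")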

@evshiron

I got the following numbers by running 06-fused-attention.py with upstream Triton on an RX 7900 XTX in WSL:

fused-attention-batch1-head8-d64-fwd-causal=True:
     N_CTX  Triton [FP16]  Triton [FP8]  Torch SDPA
0   1024.0       8.391028      5.927599    4.443374
1   2048.0      14.409957     10.447128    6.138015
2   4096.0      18.222575     13.817904    6.236508
3   8192.0      21.010927     16.646617    5.542175
4  16384.0      22.218493     19.198227    5.615378
fused-attention-batch1-head8-d64-fwd-causal=False:
     N_CTX  Triton [FP16]  Triton [FP8]  Torch SDPA
0   1024.0      15.979019     13.108997   12.844687
1   2048.0      20.875325     18.150148   16.195271
2   4096.0      23.844223     13.735497   19.157410
3   8192.0      24.475788     16.138601   19.658030
4  16384.0      25.516809     18.638798   20.117615
fused-attention-batch1-head8-d64-bwd-causal=True:
     N_CTX  Triton [FP16]  Triton [FP8]  Torch SDPA
0   1024.0       8.412188      8.576740    8.660512
1   2048.0      15.533114     15.031891    9.992787
2   4096.0      19.582960     19.624496   10.929170
3   8192.0      22.467926     22.662537   11.697946
4  16384.0      24.919043     24.939034   11.873013

I also managed to use AOTriton's Flash Attention implementation with upstream Triton and got it working in SD.Next. The performance is slightly slower (10 it/s down to 8 it/s), but it uses less VRAM (it can upscale 512x768 images to 2.5x without OOM, versus 1.6x before, with VAE Tiling disabled).

@Repeerc

Repeerc commented Jul 9, 2024

I wrote a Flash Attention 2 implementation using the rocWMMA library:
https://github.com/Repeerc/flash-attention-v2-RDNA3-minimal
It now works with Stable Diffusion, and SDXL LoRA training speed can approach a 4090 according to this benchmark:
https://www.pugetsystems.com/labs/articles/stable-diffusion-lora-training-consumer-gpu-analysis/#Performance_%E2%80%93_1024%C3%971024

@sancspro

sancspro commented Jul 9, 2024

Amazing work!!
Will it work with sd auto1111 in Linux?

@Repeerc

Repeerc commented Jul 9, 2024

Maybe. I packaged it as an extension and tested that it works in WSL:
https://github.com/Repeerc/sd-webui-flash-attention2-rdna3-rocm

@sancspro

sancspro commented Jul 9, 2024

Thanks a ton! This works well (I had to install Ninja).
Minor improvement in iterations per second, but a BIG reduction in VRAM usage with my 7800 XT.

@farshadghodsian

farshadghodsian commented Jul 10, 2024

Does this bring us anywhere closer to getting Flash Attention forward and backward pass working on Radeon 7000 GPUs (gfx1100) using PyTorch or is this implementation only for Stable Diffusion?

@minzhezhou

Does this bring us anywhere closer to getting Flash Attention forward and backward pass working on Radeon 7000 GPUs (gfx1100) using PyTorch or is this implementation only for Stable Diffusion?

If you mean using it in the transformers library, I believe it's still quite far from the official flash-attn v2: we would have to implement all of flash_attn.flash_attn_interface and other utilities like flash_attn.bert_padding.

@gel-crabs

For anyone wondering, there's also a CK-based version for Navi3x (ROCm/flash-attention, howiejay/navi_support branch), described here:

ROCm/flash-attention#27 (comment)

It's fast, but it's pinned at FA version 2.0.6 and only the forward pass works.

@Beinsezii

The howiejay CK build is a lot faster in the forward pass than the Repeerc WMMA one, going by his own numbers.

By monkey-patching the torch SDPA function I can hit something like 3.8 it/s on a low-power XTX for 1024² SDXL.
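A rough sketch of what such a monkey patch can look like (hypothetical: it assumes the howiejay flash_attn build is installed, skips attn_mask cases, and only handles fp16/bf16; since that branch is forward-only, it is for inference use):

import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func  # forward-only CK build

_sdpa = F.scaled_dot_product_attention

def patched_sdpa(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
    # Route plain fp16/bf16 attention to flash_attn, which expects (B, S, H, D) layout;
    # fall back to the stock SDPA for anything it cannot handle.
    if attn_mask is None and q.dtype in (torch.float16, torch.bfloat16):
        out = flash_attn_func(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
                              dropout_p=dropout_p, softmax_scale=scale, causal=is_causal)
        return out.transpose(1, 2)  # back to SDPA's (B, H, S, D)
    return _sdpa(q, k, v, attn_mask=attn_mask, dropout_p=dropout_p,
                 is_causal=is_causal, scale=scale)

F.scaled_dot_product_attention = patched_sdpa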

The Repeerc version at a self-reported 3.5 is closer to what I get from the DaoLab Triton JIT version after adding a custom autotune config.

...All this and the base 3090 is still a good deal faster from what I can tell.

@feffy380

The backward pass is what interests me about Repeerc's version, since training uses more memory and torch sdpa does not have a memory-efficient fallback. Unfortunately it still seems to have stability issues.

xinyazhang added a commit that referenced this issue Aug 14, 2024
## What's Changed
1. A whole new tuning system (referred to as `cpptune`/`cpp_tune`/`cpptuning`), based on pre-compiling all GPU kernels with the CMake option `AOTRITON_BUILD_FOR_TUNING` and on kernel-selection parameters provided by all AOTriton APIs
2. GPU kernel compilation can time out (the default limit is 8 minutes), to avoid excessively long Navi31 kernel builds
3. Migrate the backward kernel away from block pointers
4. Improve backward kernel performance by using a better tuning database generated from cpptune
5. Add Navi31 to the tuning database
6. Enable Navi31 by default
7. Default to AOTRITON_COMPRESS_KERNEL=ON, which consequently requires zstd as a runtime dependency
8. Use `pkg-config` to find zstd, since `find_package(zstd)` is not officially supported

## Known problems
1. No official Navi32 support. Users may duplicate the Navi31 tuning database entries to get Navi32 support in AOTriton.

This fixes #16
@sancspro

Do we have any of the FA implementations working for training (such as LoRA training via Kohya)?

@feffy380

@sancspro Wait for pytorch to update to aotriton v0.7b. The PR is pytorch/pytorch#134498

@sancspro

@sancspro Wait for pytorch to update to aotriton v0.7b. The PR is pytorch/pytorch#134498

Ok, thanks for the info @feffy380.

So after this PR is merged, I just need to update PyTorch and then enable SDPA in the cross-attention setting in Kohya to make use of FA?

@sancspro

PyTorch nightly is out: torch 2.5.0.dev20240912+rocm6.2

SDP flash attention works out of the box, with no tweaking or special configuration. Just enable SDP and set the env var:
TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1

7800 XT unleashed! :)
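A minimal sketch of verifying this on an RDNA3 card (assuming that nightly or newer; setting the env var before importing torch is the safest ordering):

import os
os.environ["TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL"] = "1"  # opt in to the experimental Navi backend

import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

q = torch.randn(1, 16, 2048, 128, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Restrict SDPA to the flash-attention backend; this raises if it is unavailable
# for the current GPU/dtype/shape instead of silently falling back to the math path.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)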

@sdfasfsdfasfasafd

Which branch of flash_attention did you install after the nightly update?

@sancspro

Hi. I didn't use any external FA. For example, for Auto1111 I enabled SDP and immediately saw a VRAM reduction. I think they've integrated AOTriton into PyTorch, which works out of the box for some GPUs, and they've now added RDNA3 to the list of supported GPUs. Note that I also set the env variable I mentioned above.

@Kademo15

Hello @sancspro, could I ask you for some numbers with and without FA? I tested it with ComfyUI and saw a super small VRAM reduction in SD 1.5 [512x512] (about 2 GB), and in SDXL [1024x1024] I saw no memory savings at all, only a speed increase.
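For comparison purposes, a small sketch of how peak VRAM for a single attention call can be measured per backend (hypothetical shapes; a real UNet workload will behave differently):

import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

def peak_mem_mib(backend):
    q = torch.randn(2, 8, 4096, 64, device="cuda", dtype=torch.float16)
    k, v = torch.randn_like(q), torch.randn_like(q)
    torch.cuda.reset_peak_memory_stats()
    with sdpa_kernel(backend):
        F.scaled_dot_product_attention(q, k, v)
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2**20  # MiB

print("math  backend:", peak_mem_mib(SDPBackend.MATH), "MiB")
print("flash backend:", peak_mem_mib(SDPBackend.FLASH_ATTENTION), "MiB")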

@blood400cc

blood400cc commented Oct 25, 2024

PyTorch nightly is out: torch 2.5.0.dev20240912+rocm6.2

SDP flash attention works out of the box, with no tweaking or special configuration. Just enable SDP and set the env var: TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1

7800 XT unleashed! :)

Yes, this experimental feature works. I tested it in a pure matrix test.py, setting the OS env variable in Python as you suggested, and it worked well on my 7900!

But how do you enable SDP in Automatic1111? I don't know where to set the env var in Automatic yet. I tried simply enabling SDP in the Automatic1111 settings, and it always came back with a failure notice:
"OutOfMemoryError: HIP out of memory. Tried to allocate 9.00 GiB. GPU 0 has a total capacity of 23.98 GiB of which 924.00 MiB is free. Of the allocated memory 22.51 GiB is allocated by PyTorch, and 36.52 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_HIP_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)".

My environment: Ubuntu 24.04 + ROCm 6.2 + Torch 2.5 + Python 3.10 + Automatic1111 v1.10.1.

@blood400cc

OK, I managed to set the env var in launch.py directly, and now SDP works. But the effect is NOT more speed, it's lower VRAM: it's about 50% slower than Doggettx, but uses only about 1/3 of the VRAM.

Not sure whether that's how it should be, or whether I haven't set something up right...

@xinyazhang
Collaborator

xinyazhang commented Oct 25, 2024

This is a known problem, which actually motivated the following decisions:

  1. Navi31 is marked as experimental in PyTorch 2.5's ROCm SDPA FA/ME backend.
  2. Navi32 is not officially supported, and there is no tuning database for it while Navi31 is still experimental.

@blood400cc

Got it. So SDP should be faster but isn't there yet; hopefully this can be solved in the next iteration of ROCm/PyTorch and this experimental feature becomes official.

@Wintoplay

Wintoplay commented Nov 29, 2024

@sancspro
I see no improvement on the 7900 XTX either.
I also get this:

UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:296.)

Grad_norm goes crazy during training and the model doesn't converge.
