Memory-efficient attention #7

Open
justheuristic opened this issue Mar 15, 2022 · 6 comments

@justheuristic
Member

justheuristic commented Mar 15, 2022

This is a discussion of how to minimize memory usage of attention.

Current state: investigating apex's scaled_masked_softmax to check how it operates

@krunt
Contributor

krunt commented Mar 16, 2022

Regarding scaled_masked_softmax_cuda:
scaled_masked_softmax_cuda from apex/csrc/megatron behaves the same as the PyTorch softmax on the forward pass, but its backward pass runs in place, saving two buffers (the temporary and the return value).

It supports seq_len <= 2048 (which I think is easy to extend) and float16 only.

Regarding the next iteration of memory saving: it is implemented here
https://github.com/krunt/mytorchcudamodules/blob/master/modules/mine_self_multihead_attn_func.py
as a loop over the batch dimension, based on the Python version of multihead attention from apex (see the sketch below).

This code and its tests still need to be committed to this repo.
The logic should be enabled by an input argument flag.
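
For illustration only, a hedged sketch of the batch-loop idea in plain PyTorch (this is not the code from the fork above; the function and its arguments are hypothetical):

import torch

def batch_chunked_attention(q, k, v, chunk_size=4):
    # q, k, v: [batch, heads, seq_len, head_dim]; only chunk_size sequences'
    # [seq_len, seq_len] score matrices are materialized at any one time
    outputs = []
    scale = q.shape[-1] ** -0.5
    for start in range(0, q.shape[0], chunk_size):
        qc = q[start:start + chunk_size]
        kc = k[start:start + chunk_size]
        vc = v[start:start + chunk_size]
        scores = qc @ kc.transpose(-2, -1) * scale   # [chunk, heads, seq, seq]
        weights = torch.softmax(scores, dim=-1)
        outputs.append(weights @ vc)                 # [chunk, heads, seq, head_dim]
    return torch.cat(outputs, dim=0)

Note that under autograd every chunk's softmax output is still saved for the backward pass, so by itself this mainly bounds the peak memory of forward temporaries; it would need to be combined with gradient checkpointing to also reduce activation memory during training.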

@justheuristic
Member Author

Summary based on @krunt's recent talk about the FMHA design:

  • there are two ways to implement attention:
    -- the naive way:
    - 1. compute the query-key dot products tile by tile and store the dot products in global memory,
    - 2. then compute the attention weights (softmax of the dot products) and store the results in global memory,
    - 3. then compute the weighted sum of values with the attention weights and store the sums in global memory.
    -- the shmemory way:
    - load a subset of queries and all keys/values from global to shared memory,
    - compute the dot products and keep them in shared memory without offloading to global memory,
    - compute the attention weights via an in-place softmax in shared memory,
    - then compute the weighted sum of values with the attention weights, and only then store the results in global memory.

The shmemory way is significantly faster (~10x on the fmha benchmark #8), but requires that all keys/values fit into shared memory. As a result, both FMHA and FasterTransformer are limited to head dimension ≤ 64 and sequence length ≤ 512.

In turn, the naive way supports arbitrary head size and sequence length, but is significantly slower because it needs to store/load intermediate values in global memory.
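
For context, the naive way corresponds to a plain framework-level implementation in which every intermediate is materialized as a full [queries x keys] tensor (which, inside a kernel, means round-trips through global memory). A minimal PyTorch illustration:

import torch

def naive_attention(q, k, v):
    # q, k, v: [..., seq_len, head_dim]; scaling and masking omitted for brevity
    scores = q @ k.transpose(-2, -1)         # step 1: all query-key dot products at once
    weights = torch.softmax(scores, dim=-1)  # step 2: all attention weights at once
    return weights @ v                       # step 3: weighted sum of values

Both scores and weights are [..., seq_len, seq_len] tensors, which is exactly the intermediate memory this issue is trying to avoid.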

@justheuristic
Member Author

justheuristic commented Mar 22, 2022

Based on these two solutions, we can produce a middle-of-the-road implementation that combines the flexibility of the naive strategy with most of the performance of the shmemory-based strategy.

Stage 1: compute log-sum-exps

for each query, compute a scalar log-sum-exp of its dot products with all keys, i.e.
result[i] = log(sum_over_j(exp(<query_i, key_j>)))

Log-sum-exps can be computed incrementally in chunks of tile_size tokens.
The second, third, etc. tiles do the following:

# for all tiles i = 0...num_queries/tile_size, j = 0...num_keys/tile_size
logaddexp_accumulators_i = load_logsumexp_outputs_from_previous_part()  # initially a 1d [tile_size] vector of -inf
new_log_add_exps_ij = compute_dotproduct_logsumexp(query_tiles[i], key_tiles[j])
logaddexp_accumulators_i[:] = safe_logaddexp_pair(logaddexp_accumulators_i, new_log_add_exps_ij)

Here, compute_dotproduct_logsumexp stands for computing the dot products of queries with keys followed by a reduce_logsumexp over all keys, in parallel for each query; safe_logaddexp_pair is an element-wise log-sum-exp of two arguments, equivalent to torch.logaddexp.

i/o: load queries and keys, 2x [tile_size x head_size], and store the log-sum-exps, which are small [tile_size] vectors
flops: roughly half of fusedMHA's forward pass, since there is no need to compute the weighted sum of values at this stage
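
A hedged PyTorch sketch of Stage 1 with illustrative names (the tile loop is plain Python here, queries are kept whole for clarity, and scaling/masking are omitted, matching the formula above):

import torch

def stage1_logsumexp(q, k, tile_size=128):
    # result[i] = log(sum_over_j(exp(<q_i, k_j>))), accumulated over key tiles
    acc = torch.full((q.shape[0],), float("-inf"), dtype=q.dtype, device=q.device)
    for j in range(0, k.shape[0], tile_size):
        dots = q @ k[j:j + tile_size].T                            # dot products against one key tile
        acc = torch.logaddexp(acc, torch.logsumexp(dots, dim=-1))  # safe_logaddexp_pair
    return acc

This matches torch.logsumexp(q @ k.T, dim=-1) up to floating-point error while only ever materializing a [num_queries, tile_size] slice of the dot products.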

Stage 2: forward (given logsumexp)

Once we know the log-sum-exps, we no longer need to fit the entire set of keys and values into shared memory at once.

Instead, we can load one chunk at a time, compute partial attention outputs from that chunk, add them to the accumulator, then load the next chunk, and so on.

# for all tiles i = 0...num_queries/tile_size, j = 0...num_keys/tile_size
query_tiles[i], key_tiles[j], value_tiles[j] = load_into_shmemory()
attention_accumulators_i = load_partial_results_from_previous_part()  # initially 2d[num_queries, head_dim] of zeros
logsumexp_accumulator_i = load_from_stage_1_for_queries_i()

dot_product_ij = dot_product(query_tiles[i], key_tiles[j])
softmax_tile_ij = exp(dot_product_ij - logsumexp_accumulator_i)
attention_output_tile_ij = dot_product(softmax_tile_ij, value_tiles[j])
attention_accumulators_i[:] = attention_accumulators_i + attention_output_tile_ij

i/o: same as shmemory-based MHA, but with one extra tensor loaded
flops: slightly less than shmemory-based MHA, since the softmax denominator is pre-computed
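
A hedged PyTorch sketch of Stage 2 in the same style (illustrative names, the key/value tile loop in plain Python, queries kept whole, scaling and masking omitted as above):

import torch

def stage2_forward(q, k, v, lse, tile_size=128):
    # lse: per-query log-sum-exps from Stage 1; only one key/value tile is live at a time
    out = torch.zeros(q.shape[0], v.shape[1], dtype=q.dtype, device=q.device)
    for j in range(0, k.shape[0], tile_size):
        k_tile, v_tile = k[j:j + tile_size], v[j:j + tile_size]
        dots = q @ k_tile.T                       # dot_product_ij
        weights = torch.exp(dots - lse[:, None])  # softmax_tile_ij: denominator comes from Stage 1
        out += weights @ v_tile                   # accumulate attention_output_tile_ij
    return out

# sanity check against the naive reference (unscaled, unmasked)
q, k, v = (torch.randn(512, 64) for _ in range(3))
lse = torch.logsumexp(q @ k.T, dim=-1)  # what Stage 1 computes tile by tile
assert torch.allclose(stage2_forward(q, k, v, lse),
                      torch.softmax(q @ k.T, dim=-1) @ v, atol=1e-4)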

Stage 3: backward

Use the same backward logic as in the shmemory way, but this time reuse the log-sum-exps saved from the forward pass and accumulate gradients by tiles.
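
One way this could look, sketched in the same style (this is my own illustration of tiled gradient accumulation with a saved log-sum-exp, not the kernel's actual code; scaling and masking are again omitted):

import torch

def stage3_backward(q, k, v, out, d_out, lse, tile_size=128):
    # Gradients of out = softmax(q @ k.T) @ v, recomputing attention weights per tile from lse.
    dq, dk, dv = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v)
    # delta[i] = <d_out[i], out[i]> equals the row sum of (weights * d_weights), needed for softmax backward
    delta = (d_out * out).sum(dim=-1)
    for j in range(0, k.shape[0], tile_size):
        sl = slice(j, j + tile_size)
        weights = torch.exp(q @ k[sl].T - lse[:, None])  # recomputed attention weights for this key tile
        dv[sl] += weights.T @ d_out                      # gradient w.r.t. values
        d_weights = d_out @ v[sl].T                      # gradient w.r.t. attention weights
        d_dots = weights * (d_weights - delta[:, None])  # softmax backward using the precomputed delta
        dq += d_dots @ k[sl]                             # accumulate gradient w.r.t. queries
        dk[sl] += d_dots.T @ q                           # gradient w.r.t. keys
    return dq, dk, dv

These tiled gradients can be checked against autograd applied to the naive softmax(q @ k.T) @ v reference.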

Notes:

  • compute log-sum-exps during the forward pass and reuse them for the backward pass
  • if it works well enough, maybe contribute this to apex in order to avoid compilation here?

@krunt
Contributor

krunt commented Apr 12, 2022

[image attachment]

Forward fmha for longer sequences is implemented in this fork: https://github.com/krunt/apex

K and V always stay in shared memory (no offload (!) to global memory during the iteration over Q).

  1. The forward pass is 2x-2.5x faster than lean (and memory-efficient too!).
  2. head_dim > 64 is supported, though not optimally (fmha itself does not support it at all); hopefully the implementation is correct - the results suggest so.

TODO:

  1. support initialization of cacc_max, cacc_sum, vacc
  2. measure the slowdown from offloading cacc_max, cacc_sum to global memory
  3. support different sequence lengths (via a mask; currently not fixed in fmha, but easy to do)
  4. backward pass
  5. switch the forward accumulators to float (need to check whether this is necessary)

@krunt
Contributor

krunt commented Apr 18, 2022

bwd is ported:

[image attachment]

@krunt
Contributor

krunt commented Apr 18, 2022

fwd+bwd results:

[three image attachments with the results]
