I tried to imitate your educational coding style hehe
Here's a pure PyTorch implementation of Flash Attention, hope you like it @karpathy
```python
import torch


def flash_attention(Q, K, V, is_causal=True, BLOCK_SIZE: int = 64):
    NEG_INFINITY = -1e10
    EPS = 1e-10
    B, nh, T, H = Q.shape
    scale = H ** -0.5
    assert Q.shape == K.shape and Q.shape == V.shape, "Some of Q,K,V are misshapen!"
    # TODO: Allow small sequences
    assert T >= BLOCK_SIZE, "For small sequences, use standard attention!"

    # initialize buffers (running output, running max, running softmax denominator)
    outputs = torch.zeros_like(Q)
    maximums = torch.full((B, nh, T, 1), fill_value=NEG_INFINITY, device=Q.device)
    denominators = torch.full((B, nh, T, 1), fill_value=EPS, device=Q.device)

    # chop up matrices into blocks along the sequence dimension
    Q_blocks, K_blocks, V_blocks = map(
        lambda x: torch.split(x, BLOCK_SIZE, dim=2), (Q, K, V)
    )
    O_blocks, M_blocks, D_blocks = map(
        lambda x: list(torch.split(x, BLOCK_SIZE, dim=2)),
        (outputs, maximums, denominators),
    )

    # helper variables for the causal mask
    positions = torch.arange(0, T, device=Q.device)
    K_index_blocks = torch.split(positions[None, :], BLOCK_SIZE, dim=1)
    Q_index_blocks = torch.split(positions[:, None], BLOCK_SIZE, dim=0)

    for k_index in range(len(K_blocks)):
        k_block = K_blocks[k_index]
        v_block = V_blocks[k_index]
        for q_index in range(len(Q_blocks)):
            # calculate attention scores, masked if causal
            q_block = Q_blocks[q_index]
            attn = q_block @ k_block.permute(0, 1, 3, 2) * scale
            if is_causal:
                causal_mask = K_index_blocks[k_index] <= Q_index_blocks[q_index]
                attn = torch.where(causal_mask, attn, NEG_INFINITY)

            # calculate new maximum attention score per query vector
            old_maximum = M_blocks[q_index]
            local_maximum, _ = torch.max(attn, dim=-1, keepdim=True)
            new_maximum = torch.maximum(old_maximum, local_maximum)

            # Now that the maximum is known, we can safely exponentiate attn scores
            attn = torch.exp(attn - new_maximum)

            # Adjust and update the softmax denominator.
            denominator_scaler = torch.exp(old_maximum - new_maximum)
            denominator_update = torch.sum(attn, dim=-1, keepdim=True)
            old_denominator = D_blocks[q_index] * denominator_scaler
            new_denominator = old_denominator + denominator_update

            # Adjust and update the output of attention.
            output_scaler = old_denominator / new_denominator
            output_update = attn @ v_block / new_denominator
            old_output = O_blocks[q_index] * output_scaler
            new_output = old_output + output_update

            # Store new maximums, new denominators and new attention output.
            M_blocks[q_index] = new_maximum
            D_blocks[q_index] = new_denominator
            O_blocks[q_index] = new_output

    # Patch together the attention output into a single (B, nh, T, H) tensor.
    return torch.cat(O_blocks, dim=2)
```
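If anyone wants to sanity-check the block-wise online-softmax recurrence (each key block rescales the running denominator and output by `exp(old_max - new_max)` before adding its contribution), here's a small test I'd run. It's my own addition, not part of the snippet above: it compares against PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention`, and the shapes and tolerance are arbitrary choices.

```python
import torch
import torch.nn.functional as F

# Hypothetical sanity check (not from the original post): compare flash_attention
# against PyTorch's reference scaled_dot_product_attention on random tensors.
torch.manual_seed(0)
B, nh, T, H = 2, 4, 256, 32  # arbitrary shapes; T must be >= BLOCK_SIZE
Q = torch.randn(B, nh, T, H)
K = torch.randn(B, nh, T, H)
V = torch.randn(B, nh, T, H)

out_flash = flash_attention(Q, K, V, is_causal=True, BLOCK_SIZE=64)
out_ref = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

# The block-wise computation should agree with the reference up to float error.
print(torch.allclose(out_flash, out_ref, atol=1e-5))  # expected: True
```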
Inspired by Shreyansh's implementation.