Hi. I am working on writing a Triton kernel for the backward pass of a sub-quadratic attention architecture. Currently, I'm receiving the following error when compiling the kernel:
The operations involved in the kernel are complex, and I have many loads and intermediate variables created during the derivation. I have a few questions about SRAM usage inside the kernel:
Does the order of `tl.load` calls matter, or is Triton smart enough to compile them into the most memory-optimal form? I.e., can I `tl.load` all required variables at the beginning of the kernel and expect the same memory usage as if I loaded each one right before the operation it is involved in?
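To make the question concrete, here's a minimal sketch of the two styles I'm comparing (the pointer names, offsets, and block size are hypothetical, not from my actual kernel):

```python
import triton
import triton.language as tl


@triton.jit
def load_up_front(q_ptr, k_ptr, v_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    # Style A: load everything at the top of the kernel.
    q = tl.load(q_ptr + offs)
    k = tl.load(k_ptr + offs)
    v = tl.load(v_ptr + offs)  # v isn't needed until the very end
    acc = q * k
    acc = acc + v
    tl.store(out_ptr + offs, acc)


@triton.jit
def load_just_in_time(q_ptr, k_ptr, v_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    # Style B: load each operand right before it is consumed.
    q = tl.load(q_ptr + offs)
    k = tl.load(k_ptr + offs)
    acc = q * k
    v = tl.load(v_ptr + offs)  # deferred until it is actually used
    acc = acc + v
    tl.store(out_ptr + offs, acc)
```

Do these two compile to the same peak on-chip memory usage, or does Style A keep `v` resident for longer than necessary?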
Is there a way to forcibly evict a variable from shared memory after loading it, if I no longer need to use it?
If I use `tl.store` and `tl.load` on the same buffer within one kernel, will this force Triton to write the value out to HBM and then reload it from HBM?
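Concretely, the pattern I have in mind is something like this (the scratch buffer is hypothetical):

```python
import triton
import triton.language as tl


@triton.jit
def store_then_reload(x_ptr, scratch_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)
    partial = x * 2.0
    # Park an intermediate in a scratch buffer...
    tl.store(scratch_ptr + offs, partial)
    # ...and read it back later in the same kernel. Does this pair always
    # round-trip through HBM, or can the value stay on-chip?
    partial = tl.load(scratch_ptr + offs)
    tl.store(out_ptr + offs, partial + x)
```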
If I load `x1 = tl.load(ptr)` and then later reuse the same variable for another load, `x1 = tl.load(ptr2)`, will this overwrite the memory in SRAM?
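A minimal example of what I mean (hypothetical pointers):

```python
import triton
import triton.language as tl


@triton.jit
def reuse_name(a_ptr, b_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x1 = tl.load(a_ptr + offs)   # first load
    acc = x1 * x1                # last use of the first x1
    x1 = tl.load(b_ptr + offs)   # rebind the same name to a second load
    tl.store(out_ptr + offs, acc + x1)
```

I.e., once the first `x1` is dead, does its on-chip storage get reused for the second load?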
Is there a way to see a breakdown of memory usage in a compiled kernel?
Note: I'm using a simple grid of shape `[Batch, Heads]` (like Flash Attention). I don't think block sizes or `num_stages` are relevant here.
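For reference, the launch looks roughly like this stub (the tensor shapes, strides, and kernel body are placeholders for my real backward kernel, and I assume the sequence length is a power of two so `tl.arange` is happy):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def bwd_stub(x_ptr, out_ptr, stride_b, stride_h, SEQ: tl.constexpr):
    # One program per (batch, head) pair, as in the Flash Attention tutorial.
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)
    base = pid_b * stride_b + pid_h * stride_h
    offs = tl.arange(0, SEQ)  # SEQ must be a power of two here
    x = tl.load(x_ptr + base + offs)
    tl.store(out_ptr + base + offs, x)


def launch(x):
    batch, heads, seq = x.shape
    out = torch.empty_like(x)
    grid = (batch, heads)  # simple [Batch, Heads] grid
    bwd_stub[grid](x, out, x.stride(0), x.stride(1), SEQ=seq)
    return out
```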
I'm also happy to share the kernel code, if needed. Hopefully there's some way I can re-arrange operations and evict from SRAM to optimize usage.