Proposal

Due to the sparsity of SAE activations, we don't need to fully multiply the SAE hidden activations by the decoder weights during decoding, since most of the activations are zero. It may be a significant performance improvement to use `torch.nn.functional.embedding_bag` during decoding in place of `acts @ sae.W_dec`. We should benchmark this to see whether it is in fact a performance improvement (a rough benchmark sketch is included after the checklist below). It could be, for example, that finding all the non-zero activation locations is more expensive than just running the standard decode. This is likely to be an improvement for topk SAEs at the very least.
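For a topk SAE the nonzero locations are already known from the topk call, so the sparse decode is essentially a weighted gather of decoder rows. Here is a minimal sketch of what the replacement could look like, assuming `W_dec` has shape `[d_sae, d_model]` and ignoring any decoder bias; the function names and shapes are hypothetical, not the library's API:

```python
import torch
import torch.nn.functional as F


def decode_dense(acts: torch.Tensor, W_dec: torch.Tensor) -> torch.Tensor:
    # Standard decode: a full matmul, even though most entries of `acts` are zero.
    return acts @ W_dec


def decode_sparse_topk(acts: torch.Tensor, W_dec: torch.Tensor, k: int) -> torch.Tensor:
    # For a topk SAE the indices/values below are already produced by the encoder,
    # so the extra `topk` call here could be skipped in practice.
    top_vals, top_idx = acts.topk(k, dim=-1)
    # With mode="sum" and per_sample_weights, embedding_bag computes, per row,
    # sum_i top_vals[b, i] * W_dec[top_idx[b, i]], i.e. acts @ W_dec restricted
    # to the nonzero activations.
    return F.embedding_bag(top_idx, W_dec, per_sample_weights=top_vals, mode="sum")


if __name__ == "__main__":
    # Quick equivalence check on toy shapes with exactly k nonzeros per row.
    batch, d_sae, d_model, k = 8, 1024, 64, 32
    W_dec = torch.randn(d_sae, d_model)
    acts = torch.zeros(batch, d_sae)
    vals, idx = torch.rand(batch, d_sae).topk(k, dim=-1)
    acts.scatter_(-1, idx, vals)
    assert torch.allclose(
        decode_dense(acts, W_dec), decode_sparse_topk(acts, W_dec, k), atol=1e-4
    )
```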
Checklist
I have checked that there is no similar issue in the repo (required)
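Since the proposal calls for a benchmark, here is a rough sketch of how the comparison could be set up with `torch.utils.benchmark`; the shapes are made-up placeholders, not measurements from this repo:

```python
import torch
import torch.nn.functional as F
from torch.utils import benchmark

# Made-up shapes roughly in the range of a topk SAE; adjust to the real config.
batch, d_sae, d_model, k = 4096, 24576, 768, 64
device = "cuda" if torch.cuda.is_available() else "cpu"

W_dec = torch.randn(d_sae, d_model, device=device)
acts = torch.zeros(batch, d_sae, device=device)
vals, idx = torch.rand(batch, d_sae, device=device).topk(k, dim=-1)
acts.scatter_(-1, idx, vals)

timers = [
    benchmark.Timer(
        stmt=stmt,
        globals={"F": F, "acts": acts, "W_dec": W_dec, "idx": idx, "vals": vals},
        label="SAE decode",
        description=desc,
    ).blocked_autorange(min_run_time=1)
    for desc, stmt in [
        ("dense matmul", "acts @ W_dec"),
        ("embedding_bag", 'F.embedding_bag(idx, W_dec, per_sample_weights=vals, mode="sum")'),
    ]
]
benchmark.Compare(timers).print()
```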
`embedding_bag` takes up a ton of memory when the SAE isn't sparse, so it doesn't seem suitable for non-topk SAEs. Even with topk SAEs, I get a memory access violation after it runs for a few minutes on a backwards pass, so I'm not sure what's up with that 🤔