[Proposal] Optimize decoding in training using torch.nn.functional.embedding_bag #428

Open · 1 task done
chanind opened this issue Feb 16, 2025 · 1 comment
chanind commented Feb 16, 2025

Proposal

Due to the sparsity of SAE activations, we don't need to fully multiply the SAE hidden activations by the decoder weights during decoding, since most of the activations are zero. Replacing `acts @ sae.W_dec` with `torch.nn.functional.embedding_bag` during decoding may be a significant performance improvement. We should benchmark this to check that it is in fact faster: it could be that finding all the non-zero activation locations is more expensive than just running the standard dense decode. It will most likely be an improvement for topk SAEs at the very least.
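
A minimal sketch of what this could look like for a topk SAE, assuming a decoder with a `W_dec` of shape `(d_sae, d_model)` and a `b_dec` of shape `(d_model,)` (the function names and the explicit `k` argument are illustrative, not the actual SAELens API):

```python
import torch
import torch.nn.functional as F

def decode_dense(acts: torch.Tensor, W_dec: torch.Tensor, b_dec: torch.Tensor) -> torch.Tensor:
    # Current approach: full dense matmul, even though most entries of `acts` are zero.
    return acts @ W_dec + b_dec

def decode_embedding_bag(acts: torch.Tensor, W_dec: torch.Tensor, b_dec: torch.Tensor, k: int) -> torch.Tensor:
    # Only gather the k decoder rows with non-zero activations and sum them,
    # weighted by the activation values. W_dec acts as an embedding table of
    # shape (d_sae, d_model); each row of `top_idx` is one "bag".
    top_vals, top_idx = acts.topk(k, dim=-1)  # both (batch, k)
    recon = F.embedding_bag(
        top_idx,
        W_dec,
        per_sample_weights=top_vals,  # per_sample_weights requires mode="sum"
        mode="sum",
    )
    return recon + b_dec
```

For a non-topk SAE the non-zero indices would have to be found per example (e.g. with `torch.nonzero`) and flattened into offsets or padded to a fixed length, which is where the bookkeeping cost mentioned above could end up dominating.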

Checklist

  • I have checked that there is no similar issue in the repo (required)
chanind commented Feb 21, 2025

I tried getting this working here: https://colab.research.google.com/drive/1uOc0ggPBV9VIRlw8pEDvOl9hnNauxqit

embedding_bag takes up a ton of memory when the SAE isn't sparse, so it doesn't seem suitable for non-topk SAEs. Even with topk SAEs, I get a memory access violation on the backwards pass after it runs for a few minutes, so I'm not sure what's up with that 🤔
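
In case it helps anyone reproduce the comparison, here's a rough self-contained sketch (the shapes are made up, not the exact ones from the colab) that measures peak CUDA memory for the dense matmul vs. embedding_bag, including the backward pass where the crash showed up:

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes for a quick peak-memory comparison; adjust to the SAE under test.
batch, d_sae, d_model, k = 4096, 24576, 768, 64
device = "cuda" if torch.cuda.is_available() else "cpu"

W_dec = torch.randn(d_sae, d_model, device=device, requires_grad=True)
acts = torch.relu(torch.randn(batch, d_sae, device=device))  # deliberately not very sparse

def dense():
    return acts @ W_dec

def bagged():
    top_vals, top_idx = acts.topk(k, dim=-1)
    return F.embedding_bag(top_idx, W_dec, per_sample_weights=top_vals, mode="sum")

for name, fn in [("dense", dense), ("embedding_bag", bagged)]:
    if device == "cuda":
        torch.cuda.reset_peak_memory_stats()
    fn().sum().backward()  # include the backward pass
    if device == "cuda":
        torch.cuda.synchronize()
        print(name, "peak MiB:", torch.cuda.max_memory_allocated() / 2**20)
    W_dec.grad = None
```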
