In dictionary_learning's TopK aux loss, the SAE error e is detached before the aux loss is calculated, but in sparsify it is not. Intuitively, detaching feels more correct, since then the aux loss can only be reduced by pulling dead latents toward the SAE error, not by pulling the SAE error closer to the dead latents.
Should Sparsify also detach the error before calculating aux loss, or is there a reason why it's not ideal to do this?
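For concreteness, here is a minimal sketch of the comparison, assuming an OpenAI-style AuxK loss (reconstruct the residual from the top-k_aux dead latents). The tensor names, shapes, and defaults are illustrative, not the actual code from either library:

```python
import torch

def auxk_loss(x, x_hat, pre_acts, W_dec, dead_mask, k_aux=512, detach_error=True):
    # e = x - x_hat is the residual the dead latents are asked to reconstruct.
    error = x - x_hat
    if detach_error:
        # dictionary_learning-style: gradients flow only into the dead latents,
        # so the aux loss cannot be reduced by moving the main reconstruction.
        error = error.detach()

    # Restrict to dead latents and keep their top-k_aux pre-activations.
    masked = pre_acts.masked_fill(~dead_mask, float("-inf"))
    k = min(k_aux, int(dead_mask.sum()))
    vals, idx = masked.topk(k, dim=-1)
    auxk_acts = torch.zeros_like(pre_acts).scatter(-1, idx, vals.relu())

    # Reconstruct the error from the dead latents only and penalize the gap.
    e_hat = auxk_acts @ W_dec
    return (e_hat - error).pow(2).sum(-1).mean()
```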
I will look at this later today; I'm not sure. FWIW, I haven't used the AuxK loss in a very long time, so it's sort of a neglected feature, and I don't think it's ever really necessary. MultiTopK would be a better way to get rid of dead latents. Also, our new Signum optimizer seems to reduce dead latents significantly compared to Adam.
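For reference, a hedged sketch of the textbook Signum rule (sign of a momentum buffer, as in Bernstein et al. 2018); the actual sign_sgd.py may differ in details such as weight decay or the momentum update:

```python
import torch

class Signum(torch.optim.Optimizer):
    """Sign-of-momentum update; hypothetical stand-in, not sparsify's sign_sgd.py."""

    def __init__(self, params, lr=1e-3, momentum=0.9):
        super().__init__(params, dict(lr=lr, momentum=momentum))

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                buf = self.state[p].setdefault("momentum_buffer", torch.zeros_like(p))
                # Exponential moving average of the gradient...
                buf.mul_(group["momentum"]).add_(p.grad, alpha=1 - group["momentum"])
                # ...but only its sign drives the update, so every parameter takes a
                # step of the same magnitude (unlike Adam's per-parameter scaling).
                p.add_(buf.sign(), alpha=-group["lr"])
```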
Is the Signum optimizer what's in sign_sgd.py? Is there a paper describing the MultiTopK technique? It looks like it adds a loss that's equivalent to increasing k by a factor of 4, sort of like a Matryoshka SAE?
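To make that reading concrete, here's a minimal sketch of a Multi-TopK-style objective (a second TopK reconstruction at 4k added with a small coefficient); the function names, the 4x factor, and the 1/8 weight are assumptions for illustration, not sparsify's actual implementation:

```python
import torch

def topk_reconstruct(pre_acts, W_dec, b_dec, k):
    # Keep only the k largest pre-activations per example, zero the rest.
    vals, idx = pre_acts.topk(k, dim=-1)
    acts = torch.zeros_like(pre_acts).scatter(-1, idx, vals.relu())
    return acts @ W_dec + b_dec

def multi_topk_loss(x, pre_acts, W_dec, b_dec, k, aux_coef=1 / 8):
    # Main TopK reconstruction at sparsity k.
    main = (x - topk_reconstruct(pre_acts, W_dec, b_dec, k)).pow(2).sum(-1).mean()
    # Second reconstruction at 4k (assumes 4 * k <= number of latents); its gradient
    # keeps latents just outside the top-k alive, which is how it reduces dead latents.
    wide = (x - topk_reconstruct(pre_acts, W_dec, b_dec, 4 * k)).pow(2).sum(-1).mean()
    return main + aux_coef * wide
```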