fix: base k_aux on d_in instead of d_sae in topk aux loss #432
Conversation
residual = (sae_in - sae_out).detach()

# Heuristic from Appendix B.1 in the paper
-k_aux = hidden_pre.shape[-1] // 2
+k_aux = sae_in.shape[-1] // 2
Since we're changing calculate_topk_aux_loss(), want to also refactor it to use an early return? It'd be:
def calculate_topk_aux_loss(
    self,
    sae_in: torch.Tensor,
    sae_out: torch.Tensor,
    hidden_pre: torch.Tensor,
    dead_neuron_mask: torch.Tensor | None,
) -> torch.Tensor:
    # Mostly taken from https://github.com/EleutherAI/sae/blob/main/sae/sae.py, except without variance normalization
    # NOTE: checking the number of dead neurons will force a GPU sync, so performance can likely be improved here
    if dead_neuron_mask is None or int(dead_neuron_mask.sum()) == 0:
        return sae_out.new_tensor(0.0)
    num_dead = int(dead_neuron_mask.sum())
    residual = (sae_in - sae_out).detach()
    # Heuristic from Appendix B.1 in the paper
    k_aux = sae_in.shape[-1] // 2
    # Reduce the scale of the loss if there are a small number of dead latents
    scale = min(num_dead / k_aux, 1.0)
    k_aux = min(k_aux, num_dead)
    auxk_acts = _calculate_topk_aux_acts(
        k_aux=k_aux,
        hidden_pre=hidden_pre,
        dead_neuron_mask=dead_neuron_mask,
    )
    # Encourage the top ~50% of dead latents to predict the residual of the
    # top k living latents
    recons = self.decode(auxk_acts)
    auxk_loss = (recons - residual).pow(2).sum(dim=-1).mean()
    return scale * auxk_loss
which has a lot less indentation and is a lot easier to read.
d_in=128,
d_sae=192,
Want to extract all the 128s and 192s to variables d_in and d_sae, respectively?
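Something like the following sketch (build_sae_cfg is just a stand-in for however the test currently constructs cfg; the only point here is the d_in/d_sae extraction):

d_in = 128
d_sae = 192

cfg = build_sae_cfg(  # stand-in for the test's existing config constructor
    d_in=d_in,
    d_sae=d_sae,
    normalize_sae_decoder=False,
)
sae = TrainingSAE(TrainingSAEConfig.from_sae_runner_config(cfg))
comparison_sae = SparseCoder(d_in=d_in, cfg=SparseCoderConfig(num_latents=d_sae, k=26))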
normalize_sae_decoder=False,
)

sae = TrainingSAE(TrainingSAEConfig.from_sae_runner_config(cfg))
comparison_sae = SparseCoder(d_in=128, cfg=SparseCoderConfig(num_latents=192, k=26))
I think sparse_coder might be a better name than comparison_sae, as someone reading the test could quickly tell it's an instance of SparseCoder.
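For example (reusing the d_in and d_sae variables suggested above; the rename is the only change):

sparse_coder = SparseCoder(d_in=d_in, cfg=SparseCoderConfig(num_latents=d_sae, k=26))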
Description
This PR fixes a minor issue with our topk aux loss implementation, where we calculate k_aux using d_sae instead of the correct d_in. This likely doesn't make a huge difference in practice, but can't hurt to fix. This PR also calls detach() on the residual error before calculating the topk aux loss, similar to dictionary_learning's implementation. This should help ensure that the aux loss only pulls dead latents towards the SAE error, and doesn't accidentally pull live latents towards dead latents.

This PR also adds a test asserting that our topk aux loss matches the SparseCoder implementation aside from a normalization factor.
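Concretely, with the test's dimensions (d_in = 128, d_sae = 192), the old heuristic gives k_aux = 192 // 2 = 96, while the fix gives k_aux = 128 // 2 = 64.

For intuition on the detach() change, here is a minimal standalone sketch (plain torch with toy tensors standing in for the live and dead-latent reconstructions, not the SAELens API):

import torch

sae_in = torch.randn(4, 8)
sae_out = torch.randn(4, 8, requires_grad=True)     # stand-in for the live-latent reconstruction
aux_recons = torch.randn(4, 8, requires_grad=True)  # stand-in for the dead-latent reconstruction

residual = (sae_in - sae_out).detach()  # cut the graph so the aux loss can't push on live latents
aux_loss = (aux_recons - residual).pow(2).sum(dim=-1).mean()
aux_loss.backward()

assert sae_out.grad is None          # no gradient reaches the live path through the aux loss
assert aux_recons.grad is not None   # only the dead-latent path is pulled toward the SAE error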
As a side note, we should probably add an optional aux_coefficient to the config to further customize this, but that probably makes sense to hold off on until the refactor is done.

Type of change
Please delete options that are not relevant.
Checklist:
You have tested formatting, typing and tests
You have run make check-ci to check format and linting (you can run make format to format code if needed).