
fix: base k_aux on d_in instead of d_sae in topk aux loss #432

Open · wants to merge 2 commits into main

Conversation

@chanind (Collaborator) commented on Feb 22, 2025

Description

This PR fixes a minor issue with our topk aux loss implementation, where we calculate k_aux using d_sae instead of the correct d_in. This likely doesn't make a huge difference in practice, but it can't hurt to fix. This PR also calls detach() on the residual error before calculating the topk aux loss, similar to dictionary_learning's implementation. This should help ensure that the aux loss only pulls dead latents towards the SAE error and doesn't accidentally pull live latents towards dead latents.
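
Concretely, the relevant part of the change in calculate_topk_aux_loss() looks roughly like this (see the review diff further down for the full hunk):

    # Detach the residual so gradients from the aux loss can't flow back
    # into the live-latent reconstruction path.
    residual = (sae_in - sae_out).detach()

    # Heuristic from Appendix B.1 in the paper: k_aux = d_in // 2.
    # Previously this was hidden_pre.shape[-1] // 2, i.e. d_sae // 2.
    k_aux = sae_in.shape[-1] // 2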

This PR also adds a test asserting that our topk aux loss matches EleutherAI's sparsify implementation aside from a normalization factor.

As a side note, we should probably add an optional aux_coefficient to the config to further customize this, but this is something that probably makes sense to hold off on until the refactor is done.

Type of change

Please delete options that are not relevant.

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

Checklist:

  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

You have tested formatting, typing and tests

  • I have run make check-ci to check format and linting. (you can run make format to format code if needed.)

Comment on lines +476 to +479
residual = (sae_in - sae_out).detach()

# Heuristic from Appendix B.1 in the paper
- k_aux = hidden_pre.shape[-1] // 2
+ k_aux = sae_in.shape[-1] // 2

Since we're changing calculate_topk_aux_loss(), want to also refactor it to use an early return?

It'd be

    def calculate_topk_aux_loss(
        self,
        sae_in: torch.Tensor,
        sae_out: torch.Tensor,
        hidden_pre: torch.Tensor,
        dead_neuron_mask: torch.Tensor | None,
    ) -> torch.Tensor:
        # Mostly taken from https://github.com/EleutherAI/sae/blob/main/sae/sae.py, except without variance normalization
        # NOTE: checking the number of dead neurons will force a GPU sync, so performance can likely be improved here
        if dead_neuron_mask is None or int(dead_neuron_mask.sum()) == 0:
            return sae_out.new_tensor(0.0)

        num_dead = int(dead_neuron_mask.sum())
        residual = (sae_in - sae_out).detach()

        # Heuristic from Appendix B.1 in the paper
        k_aux = sae_in.shape[-1] // 2

        # Reduce the scale of the loss if there are a small number of dead latents
        scale = min(num_dead / k_aux, 1.0)
        k_aux = min(k_aux, num_dead)

        auxk_acts = _calculate_topk_aux_acts(
            k_aux=k_aux,
            hidden_pre=hidden_pre,
            dead_neuron_mask=dead_neuron_mask,
        )

        # Encourage the top ~50% of dead latents to predict the residual of the
        # top k living latents
        recons = self.decode(auxk_acts)
        auxk_loss = (recons - residual).pow(2).sum(dim=-1).mean()
        return scale * auxk_loss

which has a lot less indentation and is a lot easier to read.

Comment on lines +107 to +108
d_in=128,
d_sae=192,

Want to extract all the 128s and 192s to variables d_in and d_sae, respectively?
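
Something along these lines (sketch only; the rest of the test setup is unchanged):

    d_in = 128
    d_sae = 192

with the config then using the variables:

    d_in=d_in,
    d_sae=d_sae,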

normalize_sae_decoder=False,
)

sae = TrainingSAE(TrainingSAEConfig.from_sae_runner_config(cfg))
comparison_sae = SparseCoder(d_in=128, cfg=SparseCoderConfig(num_latents=192, k=26))

I think sparse_coder might be a better name than comparison_sae, as someone reading the test could quickly tell it's an instance of SparseCoder.
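
e.g. (sketch, reusing the d_in/d_sae variables suggested above):

    sparse_coder = SparseCoder(d_in=d_in, cfg=SparseCoderConfig(num_latents=d_sae, k=26))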
