fix: base k_aux on d_in instead of d_sae in topk aux loss #432

Open · wants to merge 2 commits into main
1 change: 1 addition & 0 deletions pyproject.toml
@@ -58,6 +58,7 @@ mkdocstrings = "^0.25.2"
mkdocstrings-python = "^1.10.9"
tabulate = "^0.9.0"
ruff = "^0.7.4"
sparsify = {git = "https://github.com/EleutherAI/sparsify"}

[tool.poetry.extras]
mamba = ["mamba-lens"]
4 changes: 2 additions & 2 deletions sae_lens/training/training_sae.py
@@ -473,10 +473,10 @@ def calculate_topk_aux_loss(
dead_neuron_mask is not None
and (num_dead := int(dead_neuron_mask.sum())) > 0
):
residual = sae_in - sae_out
residual = (sae_in - sae_out).detach()

# Heuristic from Appendix B.1 in the paper
k_aux = hidden_pre.shape[-1] // 2
k_aux = sae_in.shape[-1] // 2
Collaborator commented on lines +476 to +479:

Since we're changing calculate_topk_aux_loss(), want to also refactor it to use an early return?

It'd be

    def calculate_topk_aux_loss(
        self,
        sae_in: torch.Tensor,
        sae_out: torch.Tensor,
        hidden_pre: torch.Tensor,
        dead_neuron_mask: torch.Tensor | None,
    ) -> torch.Tensor:
        # Mostly taken from https://github.com/EleutherAI/sae/blob/main/sae/sae.py, except without variance normalization
        # NOTE: checking the number of dead neurons will force a GPU sync, so performance can likely be improved here
        if dead_neuron_mask is None or int(dead_neuron_mask.sum()) == 0:
            return sae_out.new_tensor(0.0)

        num_dead = int(dead_neuron_mask.sum())
        residual = (sae_in - sae_out).detach()

        # Heuristic from Appendix B.1 in the paper
        k_aux = sae_in.shape[-1] // 2

        # Reduce the scale of the loss if there are a small number of dead latents
        scale = min(num_dead / k_aux, 1.0)
        k_aux = min(k_aux, num_dead)

        auxk_acts = _calculate_topk_aux_acts(
            k_aux=k_aux,
            hidden_pre=hidden_pre,
            dead_neuron_mask=dead_neuron_mask,
        )

        # Encourage the top ~50% of dead latents to predict the residual of the
        # top k living latents
        recons = self.decode(auxk_acts)
        auxk_loss = (recons - residual).pow(2).sum(dim=-1).mean()
        return scale * auxk_loss

which has a lot less indentation and is a lot easier to read.


# Reduce the scale of the loss if there are a small number of dead latents
scale = min(num_dead / k_aux, 1.0)
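
For context, here is an illustrative sketch (not part of the diff) of what the fix changes when the SAE has an expansion factor, i.e. d_sae > d_in. The shapes below are hypothetical:

    import torch

    # Hypothetical shapes for a 4x expansion SAE (values not taken from the PR).
    d_in, d_sae, batch = 128, 512, 8
    sae_in = torch.randn(batch, d_in)
    hidden_pre = torch.randn(batch, d_sae)

    k_aux_before = hidden_pre.shape[-1] // 2  # 256, half of d_sae (pre-fix)
    k_aux_after = sae_in.shape[-1] // 2       # 64, half of d_in (post-fix)

Basing k_aux on the input dimension appears to match the EleutherAI implementation referenced in the code comment, which also halves the input width rather than the latent width.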
83 changes: 63 additions & 20 deletions tests/training/test_training_sae.py
@@ -3,6 +3,7 @@

import pytest
import torch
from sparsify import SparseCoder, SparseCoderConfig

from sae_lens.sae import SAE
from sae_lens.training.training_sae import (
@@ -101,28 +102,68 @@ def test_calculate_topk_aux_acts_k_less_than_dead():
assert torch.allclose(result, expected)


def test_TrainingSAE_calculate_topk_aux_loss():
# Create a small test SAE with d_sae=4, d_in=3
def test_TrainingSAE_topk_aux_loss_matches_unnormalized_sparsify_implementation():
cfg = build_sae_cfg(
d_in=3,
d_sae=4,
d_in=128,
d_sae=192,
Collaborator commented on lines +107 to +108:

Want to extract all the 128s and 192s to variables d_in and d_sae, respectively?

architecture="topk",
activation_fn_kwargs={"k": 26},
normalize_sae_decoder=False,
)

sae = TrainingSAE(TrainingSAEConfig.from_sae_runner_config(cfg))
comparison_sae = SparseCoder(d_in=128, cfg=SparseCoderConfig(num_latents=192, k=26))
Collaborator commented:

I think sparse_coder might be a better name than comparison_sae as someone reading the test could quickly tell it's an instance of SparseCoder.


with torch.no_grad():
# increase b_enc so all features are likely above 0
# sparsify includes a relu() in their pre_acts, but
# this is not something we need to try to replicate.
sae.b_enc.data = sae.b_enc + 100.0
# make sure all params are the same
comparison_sae.encoder.weight.data = sae.W_enc.T
comparison_sae.encoder.bias.data = sae.b_enc
comparison_sae.b_dec.data = sae.b_dec
comparison_sae.W_dec.data = sae.W_dec # type: ignore

dead_neuron_mask = torch.randn(192) > 0.1
input_acts = torch.randn(200, 128)
input_var = (input_acts - input_acts.mean(0)).pow(2).sum()

sae_out = sae.training_forward_pass(
sae_in=input_acts,
current_l1_coefficient=0.0,
dead_neuron_mask=dead_neuron_mask,
)
comparison_sae_out = comparison_sae.forward(input_acts, dead_mask=dead_neuron_mask)
comparison_aux_loss = comparison_sae_out.auxk_loss.detach().item()

normalization = input_var / input_acts.shape[0]
raw_aux_loss = sae_out.losses["auxiliary_reconstruction_loss"].item() # type: ignore
norm_aux_loss = raw_aux_loss / normalization
assert norm_aux_loss == pytest.approx(comparison_aux_loss, abs=1e-4)
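
A note on the normalization above (my reading of the comparison, not stated in the diff): SAE Lens sums the squared aux error over features and averages over the batch, while sparsify's auxk_loss is normalized by the total input variance, so dividing the SAE Lens value by input_var / batch_size should put the two on the same scale. A minimal sketch of that algebra, with hypothetical stand-in tensors:

    import torch

    batch, d_in = 200, 128
    errors = torch.randn(batch, d_in)                  # stand-in for (recons - residual)
    input_acts = torch.randn(batch, d_in)
    input_var = (input_acts - input_acts.mean(0)).pow(2).sum()

    sae_lens_style = errors.pow(2).sum(dim=-1).mean()  # sum over features, mean over batch
    sparsify_style = errors.pow(2).sum() / input_var   # total squared error / total variance

    # Dividing the SAE Lens value by (input_var / batch) recovers the sparsify value.
    assert torch.allclose(sae_lens_style / (input_var / batch), sparsify_style)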


def test_TrainingSAE_calculate_topk_aux_loss():
# Create a small test SAE with d_sae=3, d_in=4
cfg = build_sae_cfg(
d_in=4,
d_sae=3,
architecture="topk",
normalize_sae_decoder=False,
)
sae = TrainingSAE(TrainingSAEConfig.from_sae_runner_config(cfg))

# Set up test inputs
hidden_pre = torch.tensor(
[[1.0, -2.0, 3.0, -4.0], [1.0, 0.0, -3.0, -4.0]] # batch size 2
[[1.0, -2.0, 3.0], [1.0, 0.0, -3.0]] # batch size 2
)
sae.W_dec.data = torch.tensor(2 * torch.ones((4, 3)))
sae.b_dec.data = torch.tensor(torch.zeros(3))
sae.W_dec.data = torch.tensor(2 * torch.ones((3, 4)))
sae.b_dec.data = torch.tensor(torch.zeros(4))

sae_out = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
sae_in = torch.tensor([[2.0, 1.0, 3.0], [5.0, 4.0, 6.0]])
# Mark neurons 1 and 3 as dead
dead_neuron_mask = torch.tensor([False, True, False, True])
sae_out = torch.tensor([[1.0, 2.0, 3.0, 4.0], [4.0, 5.0, 6.0, 7.0]])
sae_in = torch.tensor([[2.0, 1.0, 3.0, 4.0], [5.0, 4.0, 6.0, 7.0]])
# Mark neurons 1 and 2 as dead
dead_neuron_mask = torch.tensor([False, True, True])

# Calculate loss
loss = sae.calculate_topk_aux_loss(
@@ -133,15 +174,17 @@ def test_TrainingSAE_calculate_topk_aux_loss():
)

# The loss should:
# 1. Select top k_aux=2 (half of d_sae) dead neurons
# 2. Decode their activations (should be 2x the sum of the activations of the dead neurons)
# thus, (-12, -12, -12), (-8, -8, -8)
# and the residual is (1, -1, 0), (1, -1, 0)
# Thus, squared errors are (169, 121, 144), (81, 49, 64)
# and the sums are (434, 194)
# and the mean of these is 314

assert loss == 314
# 1. Select top k_aux=2 (half of d_in=4) dead neurons
# 2. Decode their activations (should be 2x the activations of the dead neurons)
# For batch 1: dead neurons are [-2.0, 3.0] -> activations [-4.0, 6.0] -> sum 2.0 for each output dim
# For batch 2: dead neurons are [0.0, -3.0] -> activations [0.0, -6.0] -> sum -6.0 for each output dim
# Residuals are: [1.0, -1.0, 0.0, 0.0], [1.0, -1.0, 0.0, 0.0]
# errors are: [1.0, 3.0, 2.0, 2.0], [-7., -5., -6., -6.]
# Squared errors are: [1.0, 9.0, 4.0, 4.0], [49.0, 25.0, 36.0, 36.0]
# Sum over features: 18.0, 146.0
# Mean over batch: 82.0

assert loss == 82
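
As a sanity check (illustrative only, not part of the PR), the arithmetic in the comment above can be reproduced with plain tensor ops. Both dead latents survive the top-k_aux selection here because k_aux == num_dead == 2, and scale is 1.0, so neither affects the value:

    import torch

    hidden_pre = torch.tensor([[1.0, -2.0, 3.0], [1.0, 0.0, -3.0]])
    W_dec = 2 * torch.ones((3, 4))
    b_dec = torch.zeros(4)
    sae_out = torch.tensor([[1.0, 2.0, 3.0, 4.0], [4.0, 5.0, 6.0, 7.0]])
    sae_in = torch.tensor([[2.0, 1.0, 3.0, 4.0], [5.0, 4.0, 6.0, 7.0]])
    dead_neuron_mask = torch.tensor([False, True, True])

    # Zero out living latents; the two dead latents are exactly the top-k_aux set.
    auxk_acts = torch.where(dead_neuron_mask, hidden_pre, torch.zeros_like(hidden_pre))
    recons = auxk_acts @ W_dec + b_dec  # [[2, 2, 2, 2], [-6, -6, -6, -6]]
    residual = sae_in - sae_out         # [[1, -1, 0, 0], [1, -1, 0, 0]]
    loss = (recons - residual).pow(2).sum(dim=-1).mean()
    assert loss.item() == 82.0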


def test_TrainingSAE_forward_includes_topk_loss_with_topk_architecture():