From 788650147c35489cb02d3bbe533fa516f89956c9 Mon Sep 17 00:00:00 2001
From: David Chanin
Date: Fri, 21 Feb 2025 20:08:59 -0800
Subject: [PATCH 1/2] fix: base k_aux on d_in instead of d_sae in topk aux loss

---
 pyproject.toml                      |  1 +
 sae_lens/training/training_sae.py   |  2 +-
 tests/training/test_training_sae.py | 41 +++++++++++++++++++++++++++++
 3 files changed, 43 insertions(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 3be7adca..f8de674a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -58,6 +58,7 @@ mkdocstrings = "^0.25.2"
 mkdocstrings-python = "^1.10.9"
 tabulate = "^0.9.0"
 ruff = "^0.7.4"
+sparsify = {git = "https://github.com/EleutherAI/sparsify"}
 
 [tool.poetry.extras]
 mamba = ["mamba-lens"]
diff --git a/sae_lens/training/training_sae.py b/sae_lens/training/training_sae.py
index ad2bda01..d38c3ccb 100644
--- a/sae_lens/training/training_sae.py
+++ b/sae_lens/training/training_sae.py
@@ -476,7 +476,7 @@ def calculate_topk_aux_loss(
             residual = sae_in - sae_out
 
             # Heuristic from Appendix B.1 in the paper
-            k_aux = hidden_pre.shape[-1] // 2
+            k_aux = sae_in.shape[-1] // 2
 
             # Reduce the scale of the loss if there are a small number of dead latents
             scale = min(num_dead / k_aux, 1.0)
diff --git a/tests/training/test_training_sae.py b/tests/training/test_training_sae.py
index 82e0b9ac..9629a39a 100644
--- a/tests/training/test_training_sae.py
+++ b/tests/training/test_training_sae.py
@@ -3,6 +3,7 @@
 
 import pytest
 import torch
+from sparsify import SparseCoder, SparseCoderConfig
 
 from sae_lens.sae import SAE
 from sae_lens.training.training_sae import (
@@ -101,6 +102,46 @@ def test_calculate_topk_aux_acts_k_less_than_dead():
     assert torch.allclose(result, expected)
 
 
+def test_TrainingSAE_topk_aux_loss_matches_unnormalized_sparsify_implementation():
+    cfg = build_sae_cfg(
+        d_in=128,
+        d_sae=192,
+        architecture="topk",
+        activation_fn_kwargs={"k": 26},
+        normalize_sae_decoder=False,
+    )
+
+    sae = TrainingSAE(TrainingSAEConfig.from_sae_runner_config(cfg))
+    comparison_sae = SparseCoder(d_in=128, cfg=SparseCoderConfig(num_latents=192, k=26))
+
+    with torch.no_grad():
+        # increase b_enc so all features are likely above 0
+        # sparsify includes a relu() in their pre_acts, but
+        # this is not something we need to try to replicate.
+        sae.b_enc.data = sae.b_enc + 10.0
+        # make sure all params are the same
+        comparison_sae.encoder.weight.data = sae.W_enc.T
+        comparison_sae.encoder.bias.data = sae.b_enc
+        comparison_sae.b_dec.data = sae.b_dec
+        comparison_sae.W_dec.data = sae.W_dec  # type: ignore
+
+    dead_neuron_mask = torch.randn(192) > 0.1
+    input_acts = torch.randn(200, 128)
+    input_var = (input_acts - input_acts.mean(0)).pow(2).sum()
+
+    sae_out = sae.training_forward_pass(
+        sae_in=input_acts,
+        current_l1_coefficient=0.0,
+        dead_neuron_mask=dead_neuron_mask,
+    )
+    comparison_sae_out = comparison_sae.forward(input_acts, dead_mask=dead_neuron_mask)
+
+    normalization = input_var / input_acts.shape[0]
+    raw_aux_loss = sae_out.losses["auxiliary_reconstruction_loss"].item()  # type: ignore
+    norm_aux_loss = raw_aux_loss / normalization
+    assert norm_aux_loss == pytest.approx(comparison_sae_out.auxk_loss, abs=1e-4)
+
+
 def test_TrainingSAE_calculate_topk_aux_loss():
     # Create a small test SAE with d_sae=4, d_in=3
     cfg = build_sae_cfg(

From 3a6b7506a5a2d0fa32bb87ebf10d9a476e1495ba Mon Sep 17 00:00:00 2001
From: David Chanin
Date: Sat, 22 Feb 2025 20:51:57 -0800
Subject: [PATCH 2/2] detaching error before aux loss and fixing tests

---
 sae_lens/training/training_sae.py   |  2 +-
 tests/training/test_training_sae.py | 46 +++++++++++++++--------------
 2 files changed, 25 insertions(+), 23 deletions(-)

diff --git a/sae_lens/training/training_sae.py b/sae_lens/training/training_sae.py
index d38c3ccb..ea6ee428 100644
--- a/sae_lens/training/training_sae.py
+++ b/sae_lens/training/training_sae.py
@@ -473,7 +473,7 @@ def calculate_topk_aux_loss(
             dead_neuron_mask is not None
             and (num_dead := int(dead_neuron_mask.sum())) > 0
         ):
-            residual = sae_in - sae_out
+            residual = (sae_in - sae_out).detach()
 
             # Heuristic from Appendix B.1 in the paper
             k_aux = sae_in.shape[-1] // 2
diff --git a/tests/training/test_training_sae.py b/tests/training/test_training_sae.py
index 9629a39a..6a8c1186 100644
--- a/tests/training/test_training_sae.py
+++ b/tests/training/test_training_sae.py
@@ -118,7 +118,7 @@ def test_TrainingSAE_topk_aux_loss_matches_unnormalized_sparsify_implementation(
         # increase b_enc so all features are likely above 0
         # sparsify includes a relu() in their pre_acts, but
        # this is not something we need to try to replicate.
-        sae.b_enc.data = sae.b_enc + 10.0
+        sae.b_enc.data = sae.b_enc + 100.0
         # make sure all params are the same
         comparison_sae.encoder.weight.data = sae.W_enc.T
         comparison_sae.encoder.bias.data = sae.b_enc
         comparison_sae.b_dec.data = sae.b_dec
@@ -135,35 +135,35 @@ def test_TrainingSAE_topk_aux_loss_matches_unnormalized_sparsify_implementation(
         dead_neuron_mask=dead_neuron_mask,
     )
     comparison_sae_out = comparison_sae.forward(input_acts, dead_mask=dead_neuron_mask)
+    comparison_aux_loss = comparison_sae_out.auxk_loss.detach().item()
 
     normalization = input_var / input_acts.shape[0]
     raw_aux_loss = sae_out.losses["auxiliary_reconstruction_loss"].item()  # type: ignore
     norm_aux_loss = raw_aux_loss / normalization
-    assert norm_aux_loss == pytest.approx(comparison_sae_out.auxk_loss, abs=1e-4)
+    assert norm_aux_loss == pytest.approx(comparison_aux_loss, abs=1e-4)
 
 
 def test_TrainingSAE_calculate_topk_aux_loss():
-    # Create a small test SAE with d_sae=4, d_in=3
+    # Create a small test SAE with d_sae=3, d_in=4
     cfg = build_sae_cfg(
-        d_in=3,
-        d_sae=4,
+        d_in=4,
+        d_sae=3,
         architecture="topk",
         normalize_sae_decoder=False,
     )
-
     sae = TrainingSAE(TrainingSAEConfig.from_sae_runner_config(cfg))
 
     # Set up test inputs
     hidden_pre = torch.tensor(
-        [[1.0, -2.0, 3.0, -4.0], [1.0, 0.0, -3.0, -4.0]]  # batch size 2
+        [[1.0, -2.0, 3.0], [1.0, 0.0, -3.0]]  # batch size 2
     )
-    sae.W_dec.data = torch.tensor(2 * torch.ones((4, 3)))
-    sae.b_dec.data = torch.tensor(torch.zeros(3))
+    sae.W_dec.data = torch.tensor(2 * torch.ones((3, 4)))
+    sae.b_dec.data = torch.tensor(torch.zeros(4))
 
-    sae_out = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
-    sae_in = torch.tensor([[2.0, 1.0, 3.0], [5.0, 4.0, 6.0]])
-    # Mark neurons 1 and 3 as dead
-    dead_neuron_mask = torch.tensor([False, True, False, True])
+    sae_out = torch.tensor([[1.0, 2.0, 3.0, 4.0], [4.0, 5.0, 6.0, 7.0]])
+    sae_in = torch.tensor([[2.0, 1.0, 3.0, 4.0], [5.0, 4.0, 6.0, 7.0]])
+    # Mark neurons 1 and 2 as dead
+    dead_neuron_mask = torch.tensor([False, True, True])
 
     # Calculate loss
     loss = sae.calculate_topk_aux_loss(
@@ -174,15 +174,17 @@ def test_TrainingSAE_calculate_topk_aux_loss():
     )
 
     # The loss should:
-    # 1. Select top k_aux=2 (half of d_sae) dead neurons
-    # 2. Decode their activations (should be 2x the sum of the activations of the dead neurons)
-    # thus, (-12, -12, -12), (-8, -8, -8)
-    # and the residual is (1, -1, 0), (1, -1, 0)
-    # Thus, squared errors are (169, 121, 144), (81, 49, 64)
-    # and the sums are (434, 194)
-    # and the mean of these is 314
-
-    assert loss == 314
+    # 1. Select top k_aux=2 (half of d_in=4) dead neurons
+    # 2. Decode their activations (should be 2x the activations of the dead neurons)
+    # For batch 1: dead neurons are [-2.0, 3.0] -> activations [-4.0, 6.0] -> sum 2.0 for each output dim
+    # For batch 2: dead neurons are [0.0, -3.0] -> activations [0.0, -6.0] -> sum -6.0 for each output dim
+    # Residuals are: [1.0, -1.0, 0.0, 0.0], [1.0, -1.0, 0.0, 0.0]
+    # errors are: [1.0, 3.0, 2.0, 2.0], [-7., -5., -6., -6.]
+    # Squared errors are: [1.0, 9.0, 4.0, 4.0], [49.0, 25.0, 36.0, 36.0]
+    # Sum over features: 18.0, 146.0
+    # Mean over batch: 82.0
+
+    assert loss == 82
 
 
 def test_TrainingSAE_forward_includes_topk_loss_with_topk_architecture():
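
Taken together, the two patches change the TopK auxiliary loss in two ways: k_aux is now half of the input width d_in (the Appendix B.1 heuristic the code comment cites), rather than half of the latent width d_sae, and the residual that dead latents are trained to reconstruct is detached, so the auxiliary loss no longer backpropagates through the main reconstruction. Below is a minimal standalone sketch of the resulting computation; the function name and argument list are illustrative rather than the sae_lens API, and the top-k masking details are an assumption modeled on the EleutherAI sparsify implementation the new test compares against:

import torch


def topk_aux_loss_sketch(
    sae_in: torch.Tensor,  # [batch, d_in] SAE inputs
    sae_out: torch.Tensor,  # [batch, d_in] main reconstruction
    hidden_pre: torch.Tensor,  # [batch, d_sae] encoder pre-activations
    W_dec: torch.Tensor,  # [d_sae, d_in] decoder weights
    dead_neuron_mask: torch.Tensor,  # [d_sae] bool, True = dead latent
) -> torch.Tensor:
    # Hypothetical free function; sae_lens implements this logic as
    # TrainingSAE.calculate_topk_aux_loss.
    num_dead = int(dead_neuron_mask.sum())
    if num_dead == 0:
        return sae_in.new_zeros(())

    # Patch 2: detach the residual so the aux loss only trains the dead
    # latents and sends no gradient back through the main reconstruction.
    residual = (sae_in - sae_out).detach()

    # Patch 1: k_aux is half the input width (d_in), per the Appendix B.1
    # heuristic, not half the latent width (d_sae).
    k_aux = sae_in.shape[-1] // 2

    # Reduce the scale of the loss if there are only a few dead latents.
    scale = min(num_dead / k_aux, 1.0)
    k_aux = min(k_aux, num_dead)

    # Keep the k_aux largest pre-activations among dead latents only.
    dead_pre = hidden_pre.masked_fill(~dead_neuron_mask, float("-inf"))
    topk = dead_pre.topk(k_aux, dim=-1)
    aux_acts = torch.zeros_like(hidden_pre).scatter(-1, topk.indices, topk.values)

    # Decode the dead-latent activations and regress them onto the residual.
    aux_recon = aux_acts @ W_dec
    return scale * (aux_recon - residual).pow(2).sum(-1).mean()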
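
As a sanity check, the sketch reproduces the worked example from the updated test_TrainingSAE_calculate_topk_aux_loss, using the same tensors the patch sets up:

loss = topk_aux_loss_sketch(
    sae_in=torch.tensor([[2.0, 1.0, 3.0, 4.0], [5.0, 4.0, 6.0, 7.0]]),
    sae_out=torch.tensor([[1.0, 2.0, 3.0, 4.0], [4.0, 5.0, 6.0, 7.0]]),
    hidden_pre=torch.tensor([[1.0, -2.0, 3.0], [1.0, 0.0, -3.0]]),
    W_dec=2 * torch.ones((3, 4)),
    dead_neuron_mask=torch.tensor([False, True, True]),
)
assert loss.item() == 82.0  # per-batch squared-error sums 18 and 146, mean 82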