fix: base k_aux on d_in instead of d_sae in topk aux loss #432
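For context: in the TopK SAE architecture, the auxiliary ("AuxK") loss trains dead latents to reconstruct the residual of the main reconstruction, using at most k_aux dead latents. This PR bases that cap on the input dimension d_in rather than the latent dimension d_sae, matching EleutherAI's sparsify implementation. A minimal before/after sketch, using variable names like those in the tests below (an illustration, not the actual sae_lens diff):

    # before (hypothetical): k_aux derived from the latent dimension
    k_aux = hidden_pre.shape[-1] // 2  # d_sae // 2

    # after: k_aux derived from the input dimension, as in sparsify
    k_aux = sae_in.shape[-1] // 2  # d_in // 2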
base: main
Changes from all commits
@@ -3,6 +3,7 @@
 import pytest
 import torch
+from sparsify import SparseCoder, SparseCoderConfig

 from sae_lens.sae import SAE
 from sae_lens.training.training_sae import (
@@ -101,28 +102,68 @@ def test_calculate_topk_aux_acts_k_less_than_dead():
     assert torch.allclose(result, expected)


-def test_TrainingSAE_calculate_topk_aux_loss():
-    # Create a small test SAE with d_sae=4, d_in=3
+def test_TrainingSAE_topk_aux_loss_matches_unnormalized_sparsify_implementation():
     cfg = build_sae_cfg(
-        d_in=3,
-        d_sae=4,
+        d_in=128,
+        d_sae=192,

Review comment on lines +107 to +108: Want to extract all the …
         architecture="topk",
+        activation_fn_kwargs={"k": 26},
         normalize_sae_decoder=False,
     )
     sae = TrainingSAE(TrainingSAEConfig.from_sae_runner_config(cfg))
+    comparison_sae = SparseCoder(d_in=128, cfg=SparseCoderConfig(num_latents=192, k=26))

Review comment: I think …
+
+    with torch.no_grad():
+        # increase b_enc so all features are likely above 0
+        # sparsify includes a relu() in their pre_acts, but
+        # this is not something we need to try to replicate.
+        sae.b_enc.data = sae.b_enc + 100.0
+        # make sure all params are the same
+        comparison_sae.encoder.weight.data = sae.W_enc.T
+        comparison_sae.encoder.bias.data = sae.b_enc
+        comparison_sae.b_dec.data = sae.b_dec
+        comparison_sae.W_dec.data = sae.W_dec  # type: ignore
+
+    dead_neuron_mask = torch.randn(192) > 0.1
+    input_acts = torch.randn(200, 128)
+    input_var = (input_acts - input_acts.mean(0)).pow(2).sum()
+
+    sae_out = sae.training_forward_pass(
+        sae_in=input_acts,
+        current_l1_coefficient=0.0,
+        dead_neuron_mask=dead_neuron_mask,
+    )
+    comparison_sae_out = comparison_sae.forward(input_acts, dead_mask=dead_neuron_mask)
+    comparison_aux_loss = comparison_sae_out.auxk_loss.detach().item()
+
+    normalization = input_var / input_acts.shape[0]
+    raw_aux_loss = sae_out.losses["auxiliary_reconstruction_loss"].item()  # type: ignore
+    norm_aux_loss = raw_aux_loss / normalization
+    assert norm_aux_loss == pytest.approx(comparison_aux_loss, abs=1e-4)
+
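A note on the normalization at the end of this test: sparsify's auxk_loss is variance-normalized, while sae_lens's auxiliary_reconstruction_loss is a raw summed squared error, so the raw value has to be rescaled before the two can be compared. A small sketch of that rescaling as a standalone helper (the name normalize_aux_loss is hypothetical, not from either library):

    import torch

    def normalize_aux_loss(raw_aux_loss: float, input_acts: torch.Tensor) -> float:
        # Divide the unnormalized loss by the total variance per example,
        # which is the normalization the test assumes sparsify applies.
        input_var = (input_acts - input_acts.mean(0)).pow(2).sum()
        return raw_aux_loss / (input_var / input_acts.shape[0]).item()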
+
+def test_TrainingSAE_calculate_topk_aux_loss():
+    # Create a small test SAE with d_sae=3, d_in=4
+    cfg = build_sae_cfg(
+        d_in=4,
+        d_sae=3,
+        architecture="topk",
+        normalize_sae_decoder=False,
+    )
+    sae = TrainingSAE(TrainingSAEConfig.from_sae_runner_config(cfg))

     # Set up test inputs
     hidden_pre = torch.tensor(
-        [[1.0, -2.0, 3.0, -4.0], [1.0, 0.0, -3.0, -4.0]]  # batch size 2
+        [[1.0, -2.0, 3.0], [1.0, 0.0, -3.0]]  # batch size 2
     )
-    sae.W_dec.data = torch.tensor(2 * torch.ones((4, 3)))
-    sae.b_dec.data = torch.tensor(torch.zeros(3))
+    sae.W_dec.data = torch.tensor(2 * torch.ones((3, 4)))
+    sae.b_dec.data = torch.tensor(torch.zeros(4))

-    sae_out = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
-    sae_in = torch.tensor([[2.0, 1.0, 3.0], [5.0, 4.0, 6.0]])
-    # Mark neurons 1 and 3 as dead
-    dead_neuron_mask = torch.tensor([False, True, False, True])
+    sae_out = torch.tensor([[1.0, 2.0, 3.0, 4.0], [4.0, 5.0, 6.0, 7.0]])
+    sae_in = torch.tensor([[2.0, 1.0, 3.0, 4.0], [5.0, 4.0, 6.0, 7.0]])
+    # Mark neurons 1 and 2 as dead
+    dead_neuron_mask = torch.tensor([False, True, True])

     # Calculate loss
     loss = sae.calculate_topk_aux_loss(
@@ -133,15 +174,17 @@ def test_TrainingSAE_calculate_topk_aux_loss():
     )

     # The loss should:
-    # 1. Select top k_aux=2 (half of d_sae) dead neurons
-    # 2. Decode their activations (should be 2x the sum of the activations of the dead neurons)
-    # thus, (-12, -12, -12), (-8, -8, -8)
-    # and the residual is (1, -1, 0), (1, -1, 0)
-    # Thus, squared errors are (169, 121, 144), (81, 49, 64)
-    # and the sums are (434, 194)
-    # and the mean of these is 314
+    # 1. Select top k_aux=2 (half of d_in=4) dead neurons
+    # 2. Decode their activations (should be 2x the activations of the dead neurons)
+    # For batch 1: dead neurons are [-2.0, 3.0] -> activations [-4.0, 6.0] -> sum 2.0 for each output dim
+    # For batch 2: dead neurons are [0.0, -3.0] -> activations [0.0, -6.0] -> sum -6.0 for each output dim
+    # Residuals are: [1.0, -1.0, 0.0, 0.0], [1.0, -1.0, 0.0, 0.0]
+    # Errors are: [1.0, 3.0, 2.0, 2.0], [-7.0, -5.0, -6.0, -6.0]
+    # Squared errors are: [1.0, 9.0, 4.0, 4.0], [49.0, 25.0, 36.0, 36.0]
+    # Sum over features: 18.0, 146.0
+    # Mean over batch: 82.0

-    assert loss == 314
+    assert loss == 82


 def test_TrainingSAE_forward_includes_topk_loss_with_topk_architecture():
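The arithmetic in the comment block above can be verified standalone with just the constants from the test (b_dec is zero, so decoding reduces to a matmul):

    import torch

    hidden_pre = torch.tensor([[1.0, -2.0, 3.0], [1.0, 0.0, -3.0]])
    W_dec = 2 * torch.ones((3, 4))
    sae_out = torch.tensor([[1.0, 2.0, 3.0, 4.0], [4.0, 5.0, 6.0, 7.0]])
    sae_in = torch.tensor([[2.0, 1.0, 3.0, 4.0], [5.0, 4.0, 6.0, 7.0]])
    dead_neuron_mask = torch.tensor([False, True, True])

    # Both dead latents fit within k_aux = d_in // 2 = 2, so all are kept.
    dead_acts = hidden_pre * dead_neuron_mask  # zero out live latents
    recons = dead_acts @ W_dec                 # [[2, 2, 2, 2], [-6, -6, -6, -6]]
    residual = sae_in - sae_out                # [[1, -1, 0, 0], [1, -1, 0, 0]]
    loss = (recons - residual).pow(2).sum(-1).mean()
    print(loss)  # tensor(82.)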
Review comment: Since we're changing calculate_topk_aux_loss(), want to also refactor it to use an early return? It'd be …, which has a lot less indentation/is a lot easier to read.
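On the early-return suggestion: a hedged sketch of what that refactor might look like follows. The guard condition, the scale heuristic, and the _calculate_topk_aux_acts signature are assumptions modeled on the tests above and on EleutherAI's implementation, not the actual sae_lens source:

    def calculate_topk_aux_loss(
        self,
        sae_in: torch.Tensor,
        sae_out: torch.Tensor,
        hidden_pre: torch.Tensor,
        dead_neuron_mask: torch.Tensor | None,
    ) -> torch.Tensor:
        # Early return: with no dead latents there is no auxiliary loss,
        # so the main body no longer sits inside a nested if-block.
        num_dead = 0 if dead_neuron_mask is None else int(dead_neuron_mask.sum())
        if num_dead == 0:
            return sae_out.new_tensor(0.0)

        residual = (sae_in - sae_out).detach()

        # The fix in this PR: k_aux is half of d_in, not half of d_sae.
        k_aux = sae_in.shape[-1] // 2

        # Downscale the loss when only a few latents are dead (assumed heuristic).
        scale = min(num_dead / k_aux, 1.0)
        k_aux = min(k_aux, num_dead)

        # Keep the top-k_aux pre-activations among dead latents only.
        auxk_acts = _calculate_topk_aux_acts(
            k_aux=k_aux,
            hidden_pre=hidden_pre,
            dead_neuron_mask=dead_neuron_mask,
        )

        # Dead latents learn to reconstruct the residual.
        recons = self.decode(auxk_acts)
        return scale * (recons - residual).pow(2).sum(dim=-1).mean()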