From 5f17c0e82af0519c07d4088d08cda229fb36a0dc Mon Sep 17 00:00:00 2001 From: callummcdougall Date: Wed, 16 Oct 2024 18:57:41 +0100 Subject: [PATCH 1/8] first commit --- sae_lens/sae.py | 54 ++++++++++++++++++++++++++++++++----------------- 1 file changed, 35 insertions(+), 19 deletions(-) diff --git a/sae_lens/sae.py b/sae_lens/sae.py index 1610090c..749dd1a5 100644 --- a/sae_lens/sae.py +++ b/sae_lens/sae.py @@ -6,16 +6,22 @@ import os import warnings from dataclasses import dataclass, field -from typing import Any, Callable, Literal, Optional, Tuple, TypeVar, Union, overload +from typing import ( + Any, + Callable, + Iterable, + Literal, + Optional, + Tuple, + TypeVar, + Union, + overload, +) T = TypeVar("T", bound="SAE") import einops import torch from jaxtyping import Float -from safetensors.torch import save_file -from torch import nn -from transformer_lens.hook_points import HookedRootModule, HookPoint - from sae_lens.config import DTYPE_MAP from sae_lens.toolkit.pretrained_sae_loaders import ( NAMED_PRETRAINED_SAE_LOADERS, @@ -27,6 +33,9 @@ get_norm_scaling_factor, get_pretrained_saes_directory, ) +from safetensors.torch import save_file +from torch import nn +from transformer_lens.hook_points import HookedRootModule, HookPoint SPARSITY_PATH = "sparsity.safetensors" SAE_WEIGHTS_PATH = "sae_weights.safetensors" @@ -67,7 +76,6 @@ class SAEConfig: @classmethod def from_dict(cls, config_dict: dict[str, Any]) -> "SAEConfig": - # rename dict: rename_dict = { # old : new "hook_point": "hook_name", @@ -196,7 +204,6 @@ def __init__( # handle run time activation normalization if needed: if self.cfg.normalize_activations == "constant_norm_rescale": - # we need to scale the norm of the input and store the scaling factor def run_time_activation_norm_fn_in(x: torch.Tensor) -> torch.Tensor: self.x_norm_coeff = (self.cfg.d_in**0.5) / x.norm(dim=-1, keepdim=True) @@ -212,7 +219,6 @@ def run_time_activation_norm_fn_out(x: torch.Tensor) -> torch.Tensor: # self.run_time_activation_norm_fn_out = run_time_activation_norm_fn_out elif self.cfg.normalize_activations == "layer_norm": - # we need to scale the norm of the input and store the scaling factor def run_time_activation_ln_in( x: torch.Tensor, eps: float = 1e-5 @@ -237,7 +243,6 @@ def run_time_activation_ln_out(x: torch.Tensor, eps: float = 1e-5): self.setup() # Required for `HookedRootModule`s def initialize_weights_basic(self): - # no config changes encoder bias init for now. self.b_enc = nn.Parameter( torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device) @@ -492,8 +497,13 @@ def forward( return self.hook_sae_output(sae_out) def encode_gated( - self, x: Float[torch.Tensor, "... d_in"] + self, x: Float[torch.Tensor, "... d_in"], latents: Iterable[int] | None = None ) -> Float[torch.Tensor, "... 
d_sae"]: + """ + Calculate SAE features from inputs + """ + if latents is None: + latents = slice(None) x = x.to(self.dtype) x = self.reshape_fn_in(x) @@ -502,12 +512,13 @@ def encode_gated( sae_in = x - self.b_dec * self.cfg.apply_b_dec_to_input # Gating path - gating_pre_activation = sae_in @ self.W_enc + self.b_gate + gating_pre_activation = sae_in @ self.W_enc[:, latents] + self.b_gate[latents] active_features = (gating_pre_activation > 0).to(self.dtype) # Magnitude path with weight sharing magnitude_pre_activation = self.hook_sae_acts_pre( - sae_in @ (self.W_enc * self.r_mag.exp()) + self.b_mag + sae_in @ (self.W_enc[:, latents] * self.r_mag[latents].exp()) + + self.b_mag[latents] ) feature_magnitudes = self.activation_fn(magnitude_pre_activation) @@ -516,11 +527,13 @@ def encode_gated( return feature_acts def encode_jumprelu( - self, x: Float[torch.Tensor, "... d_in"] + self, x: Float[torch.Tensor, "... d_in"], latents: Iterable[int] | None = None ) -> Float[torch.Tensor, "... d_sae"]: """ Calculate SAE features from inputs """ + if latents is None: + latents = slice(None) # move x to correct dtype x = x.to(self.dtype) @@ -535,7 +548,9 @@ def encode_jumprelu( sae_in = self.hook_sae_input(x - (self.b_dec * self.cfg.apply_b_dec_to_input)) # "... d_in, d_in d_sae -> ... d_sae", - hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc) + hidden_pre = self.hook_sae_acts_pre( + sae_in @ self.W_enc[:, latents] + self.b_enc[latents] + ) feature_acts = self.hook_sae_acts_post( self.activation_fn(hidden_pre) * (hidden_pre > self.threshold) @@ -544,11 +559,13 @@ def encode_jumprelu( return feature_acts def encode_standard( - self, x: Float[torch.Tensor, "... d_in"] + self, x: Float[torch.Tensor, "... d_in"], latents: Iterable[int] | None = None ) -> Float[torch.Tensor, "... d_sae"]: """ Calculate SAE features from inputs """ + if latents is None: + latents = slice(None) x = x.to(self.dtype) x = self.reshape_fn_in(x) @@ -559,7 +576,9 @@ def encode_standard( sae_in = x - (self.b_dec * self.cfg.apply_b_dec_to_input) # "... d_in, d_in d_sae -> ... d_sae", - hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc) + hidden_pre = self.hook_sae_acts_pre( + sae_in @ self.W_enc[:, latents] + self.b_enc[latents] + ) feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre)) return feature_acts @@ -606,7 +625,6 @@ def fold_activation_norm_scaling_factor( self.cfg.normalize_activations = "none" def save_model(self, path: str, sparsity: Optional[torch.Tensor] = None): - if not os.path.exists(path): os.mkdir(path) @@ -627,7 +645,6 @@ def save_model(self, path: str, sparsity: Optional[torch.Tensor] = None): def load_from_pretrained( cls, path: str, device: str = "cpu", dtype: str | None = None ) -> "SAE": - # get the config config_path = os.path.join(path, SAE_CFG_PATH) with open(config_path, "r") as f: @@ -752,7 +769,6 @@ def from_dict(cls, config_dict: dict[str, Any]) -> "SAE": return cls(SAEConfig.from_dict(config_dict)) def turn_on_forward_pass_hook_z_reshaping(self): - assert self.cfg.hook_name.endswith( "_z" ), "This method should only be called for hook_z SAEs." 
From 27933f7e0834cd5808ee27eb0717066890431528 Mon Sep 17 00:00:00 2001 From: callummcdougall Date: Wed, 16 Oct 2024 19:06:56 +0100 Subject: [PATCH 2/8] fix formatting, add tests --- sae_lens/sae.py | 7 ++++--- tests/unit/training/test_sae_basic.py | 13 ++++++++++++- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/sae_lens/sae.py b/sae_lens/sae.py index 749dd1a5..032303ae 100644 --- a/sae_lens/sae.py +++ b/sae_lens/sae.py @@ -22,6 +22,10 @@ import einops import torch from jaxtyping import Float +from safetensors.torch import save_file +from torch import nn +from transformer_lens.hook_points import HookedRootModule, HookPoint + from sae_lens.config import DTYPE_MAP from sae_lens.toolkit.pretrained_sae_loaders import ( NAMED_PRETRAINED_SAE_LOADERS, @@ -33,9 +37,6 @@ get_norm_scaling_factor, get_pretrained_saes_directory, ) -from safetensors.torch import save_file -from torch import nn -from transformer_lens.hook_points import HookedRootModule, HookPoint SPARSITY_PATH = "sparsity.safetensors" SAE_WEIGHTS_PATH = "sae_weights.safetensors" diff --git a/tests/unit/training/test_sae_basic.py b/tests/unit/training/test_sae_basic.py index 530fca2c..194fab64 100644 --- a/tests/unit/training/test_sae_basic.py +++ b/tests/unit/training/test_sae_basic.py @@ -70,6 +70,18 @@ def test_sae_init(cfg: LanguageModelSAERunnerConfig): assert sae.b_dec.shape == (cfg.d_in,) +def test_sae_encode_with_slicing(cfg: LanguageModelSAERunnerConfig): + sae = SAE.from_dict(cfg.get_base_sae_cfg_dict()) + + activations = torch.randn(10, 4, cfg.d_in, device=cfg.device) + latents = torch.randint(low=0, high=cfg.d_sae, size=(10,)) + feature_activations = sae.encode(activations) + feature_activations_slice = sae.encode(activations, latents=latents) + torch.testing.assert_close( + feature_activations[..., latents], feature_activations_slice + ) + + def test_sae_fold_w_dec_norm(cfg: LanguageModelSAERunnerConfig): sae = SAE.from_dict(cfg.get_base_sae_cfg_dict()) sae.turn_off_forward_pass_hook_z_reshaping() # hook z reshaping not needed here. 
@@ -106,7 +118,6 @@ def test_sae_fold_w_dec_norm(cfg: LanguageModelSAERunnerConfig): def test_sae_fold_norm_scaling_factor(cfg: LanguageModelSAERunnerConfig): - norm_scaling_factor = 3.0 sae = SAE.from_dict(cfg.get_base_sae_cfg_dict()) From 64c6fc4b737d2ed48e65ee4d7a8effc1d9c6e4a9 Mon Sep 17 00:00:00 2001 From: callummcdougall Date: Wed, 16 Oct 2024 19:10:22 +0100 Subject: [PATCH 3/8] fix typing problem --- sae_lens/sae.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/sae_lens/sae.py b/sae_lens/sae.py index 032303ae..e13e0b30 100644 --- a/sae_lens/sae.py +++ b/sae_lens/sae.py @@ -503,8 +503,7 @@ def encode_gated( """ Calculate SAE features from inputs """ - if latents is None: - latents = slice(None) + latents_slice = slice(None) if latents is None else latents x = x.to(self.dtype) x = self.reshape_fn_in(x) @@ -513,13 +512,15 @@ def encode_gated( sae_in = x - self.b_dec * self.cfg.apply_b_dec_to_input # Gating path - gating_pre_activation = sae_in @ self.W_enc[:, latents] + self.b_gate[latents] + gating_pre_activation = ( + sae_in @ self.W_enc[:, latents_slice] + self.b_gate[latents_slice] + ) active_features = (gating_pre_activation > 0).to(self.dtype) # Magnitude path with weight sharing magnitude_pre_activation = self.hook_sae_acts_pre( - sae_in @ (self.W_enc[:, latents] * self.r_mag[latents].exp()) - + self.b_mag[latents] + sae_in @ (self.W_enc[:, latents_slice] * self.r_mag[latents_slice].exp()) + + self.b_mag[latents_slice] ) feature_magnitudes = self.activation_fn(magnitude_pre_activation) @@ -533,8 +534,7 @@ def encode_jumprelu( """ Calculate SAE features from inputs """ - if latents is None: - latents = slice(None) + latents_slice = slice(None) if latents is None else latents # move x to correct dtype x = x.to(self.dtype) @@ -550,7 +550,7 @@ def encode_jumprelu( # "... d_in, d_in d_sae -> ... d_sae", hidden_pre = self.hook_sae_acts_pre( - sae_in @ self.W_enc[:, latents] + self.b_enc[latents] + sae_in @ self.W_enc[:, latents_slice] + self.b_enc[latents_slice] ) feature_acts = self.hook_sae_acts_post( @@ -565,8 +565,7 @@ def encode_standard( """ Calculate SAE features from inputs """ - if latents is None: - latents = slice(None) + latents_slice = slice(None) if latents is None else latents x = x.to(self.dtype) x = self.reshape_fn_in(x) @@ -578,7 +577,7 @@ def encode_standard( # "... d_in, d_in d_sae -> ... 
d_sae", hidden_pre = self.hook_sae_acts_pre( - sae_in @ self.W_enc[:, latents] + self.b_enc[latents] + sae_in @ self.W_enc[:, latents_slice] + self.b_enc[latents_slice] ) feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre)) From cbda3a09ee7cdeb337e8fad7453bcc431818a53d Mon Sep 17 00:00:00 2001 From: callummcdougall Date: Wed, 16 Oct 2024 19:20:30 +0100 Subject: [PATCH 4/8] fix more typing problems --- sae_lens/sae.py | 6 +++--- sae_lens/training/training_sae.py | 18 +++++------------- tests/unit/training/test_sae_basic.py | 1 + 3 files changed, 9 insertions(+), 16 deletions(-) diff --git a/sae_lens/sae.py b/sae_lens/sae.py index e13e0b30..471bce41 100644 --- a/sae_lens/sae.py +++ b/sae_lens/sae.py @@ -503,7 +503,7 @@ def encode_gated( """ Calculate SAE features from inputs """ - latents_slice = slice(None) if latents is None else latents + latents_slice = slice(None) if latents is None else torch.tensor(latents) x = x.to(self.dtype) x = self.reshape_fn_in(x) @@ -534,7 +534,7 @@ def encode_jumprelu( """ Calculate SAE features from inputs """ - latents_slice = slice(None) if latents is None else latents + latents_slice = slice(None) if latents is None else torch.tensor(latents) # move x to correct dtype x = x.to(self.dtype) @@ -565,7 +565,7 @@ def encode_standard( """ Calculate SAE features from inputs """ - latents_slice = slice(None) if latents is None else latents + latents_slice = slice(None) if latents is None else torch.tensor(latents) x = x.to(self.dtype) x = self.reshape_fn_in(x) diff --git a/sae_lens/training/training_sae.py b/sae_lens/training/training_sae.py index b7925d4e..0be5a652 100644 --- a/sae_lens/training/training_sae.py +++ b/sae_lens/training/training_sae.py @@ -5,7 +5,7 @@ import json import os from dataclasses import dataclass, fields -from typing import Any, Optional +from typing import Any, Iterable, Optional import einops import torch @@ -38,7 +38,6 @@ class TrainStepOutput: @dataclass(kw_only=True) class TrainingSAEConfig(SAEConfig): - # Sparsity Loss Calculations l1_coefficient: float lp_norm: float @@ -55,7 +54,6 @@ class TrainingSAEConfig(SAEConfig): def from_sae_runner_config( cls, cfg: LanguageModelSAERunnerConfig ) -> "TrainingSAEConfig": - return cls( # base config architecture=cfg.architecture, @@ -168,7 +166,6 @@ class TrainingSAE(SAE): device: torch.device def __init__(self, cfg: TrainingSAEConfig, use_error_term: bool = False): - base_sae_cfg = SAEConfig.from_dict(cfg.get_base_sae_cfg_dict()) super().__init__(base_sae_cfg) self.cfg = cfg # type: ignore @@ -203,18 +200,20 @@ def check_cfg_compatibility(self): assert self.use_error_term is False, "Gated SAEs do not support error terms" def encode_standard( - self, x: Float[torch.Tensor, "... d_in"] + self, x: Float[torch.Tensor, "... d_in"], latents: Iterable[int] | None = None ) -> Float[torch.Tensor, "... d_sae"]: """ Calcuate SAE features from inputs """ + assert ( + latents is None + ), "Function `encode_standard` in training should always return activations for all latents" feature_acts, _ = self.encode_with_hidden_pre_fn(x) return feature_acts def encode_with_hidden_pre( self, x: Float[torch.Tensor, "... d_in"] ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]: - x = x.to(self.dtype) x = self.reshape_fn_in(x) # type: ignore x = self.hook_sae_input(x) @@ -235,7 +234,6 @@ def encode_with_hidden_pre( def encode_with_hidden_pre_gated( self, x: Float[torch.Tensor, "... d_in"] ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... 
d_sae"]]: - x = x.to(self.dtype) x = self.reshape_fn_in(x) # type: ignore x = self.hook_sae_input(x) @@ -267,7 +265,6 @@ def forward( self, x: Float[torch.Tensor, "... d_in"], ) -> Float[torch.Tensor, "... d_in"]: - feature_acts, _ = self.encode_with_hidden_pre_fn(x) sae_out = self.decode(feature_acts) @@ -279,7 +276,6 @@ def training_forward_pass( current_l1_coefficient: float, dead_neuron_mask: Optional[torch.Tensor] = None, ) -> TrainStepOutput: - # do a forward pass to get SAE out, but we also need the # hidden pre. feature_acts, _ = self.encode_with_hidden_pre_fn(sae_in) @@ -291,7 +287,6 @@ def training_forward_pass( # GHOST GRADS if self.cfg.use_ghost_grads and self.training and dead_neuron_mask is not None: - # first half of second forward pass _, hidden_pre = self.encode_with_hidden_pre_fn(sae_in) ghost_grad_loss = self.calculate_ghost_grad_loss( @@ -362,7 +357,6 @@ def calculate_ghost_grad_loss( hidden_pre: torch.Tensor, dead_neuron_mask: torch.Tensor, ) -> torch.Tensor: - # 1. residual = x - sae_out l2_norm_residual = torch.norm(residual, dim=-1) @@ -394,7 +388,6 @@ def calculate_ghost_grad_loss( @torch.no_grad() def _get_mse_loss_fn(self) -> Any: - def standard_mse_loss_fn( preds: torch.Tensor, target: torch.Tensor ) -> torch.Tensor: @@ -421,7 +414,6 @@ def load_from_pretrained( device: str = "cpu", dtype: str | None = None, ) -> "TrainingSAE": - # get the config config_path = os.path.join(path, SAE_CFG_PATH) with open(config_path, "r") as f: diff --git a/tests/unit/training/test_sae_basic.py b/tests/unit/training/test_sae_basic.py index 194fab64..6801c4be 100644 --- a/tests/unit/training/test_sae_basic.py +++ b/tests/unit/training/test_sae_basic.py @@ -72,6 +72,7 @@ def test_sae_init(cfg: LanguageModelSAERunnerConfig): def test_sae_encode_with_slicing(cfg: LanguageModelSAERunnerConfig): sae = SAE.from_dict(cfg.get_base_sae_cfg_dict()) + assert isinstance(cfg.d_sae, int) activations = torch.randn(10, 4, cfg.d_in, device=cfg.device) latents = torch.randint(low=0, high=cfg.d_sae, size=(10,)) From aef2fc18375cf70e287cdc44eb9c423f0f423d1f Mon Sep 17 00:00:00 2001 From: callummcdougall Date: Wed, 16 Oct 2024 19:29:53 +0100 Subject: [PATCH 5/8] fix typing problems --- tests/unit/training/test_sae_basic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/training/test_sae_basic.py b/tests/unit/training/test_sae_basic.py index 6801c4be..586671a8 100644 --- a/tests/unit/training/test_sae_basic.py +++ b/tests/unit/training/test_sae_basic.py @@ -75,7 +75,7 @@ def test_sae_encode_with_slicing(cfg: LanguageModelSAERunnerConfig): assert isinstance(cfg.d_sae, int) activations = torch.randn(10, 4, cfg.d_in, device=cfg.device) - latents = torch.randint(low=0, high=cfg.d_sae, size=(10,)) + latents = torch.randint(low=0, high=cfg.d_sae, size=(10,)).tolist() feature_activations = sae.encode(activations) feature_activations_slice = sae.encode(activations, latents=latents) torch.testing.assert_close( From e1d105afaa815956c501a4846073f6d14dc76ae7 Mon Sep 17 00:00:00 2001 From: callummcdougall Date: Thu, 17 Oct 2024 15:11:01 +0100 Subject: [PATCH 6/8] error for topk, restructure encode, extend tests to all architectures --- sae_lens/sae.py | 110 +++++++++++++++++--------- tests/unit/training/test_sae_basic.py | 6 +- 2 files changed, 76 insertions(+), 40 deletions(-) diff --git a/sae_lens/sae.py b/sae_lens/sae.py index 471bce41..0ee0c99e 100644 --- a/sae_lens/sae.py +++ b/sae_lens/sae.py @@ -6,22 +6,12 @@ import os import warnings from dataclasses import dataclass, field 
-from typing import ( - Any, - Callable, - Iterable, - Literal, - Optional, - Tuple, - TypeVar, - Union, - overload, -) +from typing import Any, Callable, Literal, Optional, Tuple, TypeVar, Union, overload T = TypeVar("T", bound="SAE") import einops import torch -from jaxtyping import Float +from jaxtyping import Float, Int from safetensors.torch import save_file from torch import nn from transformer_lens.hook_points import HookedRootModule, HookPoint @@ -164,17 +154,8 @@ def __init__( self.device = torch.device(cfg.device) self.use_error_term = use_error_term - if self.cfg.architecture == "standard": - self.initialize_weights_basic() - self.encode = self.encode_standard - elif self.cfg.architecture == "gated": - self.initialize_weights_gated() - self.encode = self.encode_gated - elif self.cfg.architecture == "jumprelu": - self.initialize_weights_jumprelu() - self.encode = self.encode_jumprelu - else: - raise (ValueError) + if self.cfg.architecture not in ["standard", "gated", "jumprelu"]: + raise ValueError(f"Architecture {self.cfg.architecture} not supported") # handle presence / absence of scaling factor. if self.cfg.finetuning_scaling_factor: @@ -243,6 +224,16 @@ def run_time_activation_ln_out(x: torch.Tensor, eps: float = 1e-5): self.setup() # Required for `HookedRootModule`s + def initialize_weights(self): + if self.cfg.architecture == "standard": + self.initialize_weights_basic() + elif self.cfg.architecture == "gated": + self.initialize_weights_gated() + elif self.cfg.architecture == "jumprelu": + self.initialize_weights_jumprelu() + else: + raise (ValueError) + def initialize_weights_basic(self): # no config changes encoder bias init for now. self.b_enc = nn.Parameter( @@ -497,13 +488,39 @@ def forward( return self.hook_sae_output(sae_out) + def encode( + self, x: torch.Tensor, latents: torch.Tensor | None = None + ) -> torch.Tensor: + """ + Calculate SAE latents from inputs. Includes optional `latents` argument to only calculate a subset. Note that + this won't make sense for topk SAEs, because we need to compute all hidden values to apply the topk masking. + """ + if self.cfg.activation_fn_str == "topk": + assert ( + latents is None + ), "Computing a slice of SAE hidden values doesn't make sense in topk SAEs." + + return { + "standard": self.encode_standard, + "gated": self.encode_gated, + "jumprelu": self.encode_jumprelu, + }[self.cfg.architecture](x, latents) + def encode_gated( - self, x: Float[torch.Tensor, "... d_in"], latents: Iterable[int] | None = None + self, + x: Float[torch.Tensor, "... d_in"], + latents: Int[torch.Tensor, "latents"] | None = None, ) -> Float[torch.Tensor, "... d_sae"]: """ - Calculate SAE features from inputs + Computes the latent values of the Sparse Autoencoder (SAE) using a gated architecture. The activation values are + computed as the product of the masking term & the post-activation function magnitude term: + + 1[(x - b_dec) @ W_gate + b_gate > 0] * activation_fn((x - b_dec) @ W_enc + b_enc) + + The `latents` argument allows for the computation of a specific subset of the hidden values. If `latents` is not + provided, all latent values will be computed. 
""" - latents_slice = slice(None) if latents is None else torch.tensor(latents) + latents_tensor = torch.arange(self.cfg.d_sae) if latents is None else latents x = x.to(self.dtype) x = self.reshape_fn_in(x) @@ -513,14 +530,14 @@ def encode_gated( # Gating path gating_pre_activation = ( - sae_in @ self.W_enc[:, latents_slice] + self.b_gate[latents_slice] + sae_in @ self.W_enc[:, latents_tensor] + self.b_gate[latents_tensor] ) active_features = (gating_pre_activation > 0).to(self.dtype) # Magnitude path with weight sharing magnitude_pre_activation = self.hook_sae_acts_pre( - sae_in @ (self.W_enc[:, latents_slice] * self.r_mag[latents_slice].exp()) - + self.b_mag[latents_slice] + sae_in @ (self.W_enc[:, latents_tensor] * self.r_mag[latents_tensor].exp()) + + self.b_mag[latents_tensor] ) feature_magnitudes = self.activation_fn(magnitude_pre_activation) @@ -529,12 +546,20 @@ def encode_gated( return feature_acts def encode_jumprelu( - self, x: Float[torch.Tensor, "... d_in"], latents: Iterable[int] | None = None + self, + x: Float[torch.Tensor, "... d_in"], + latents: Int[torch.Tensor, "latents"] | None = None, ) -> Float[torch.Tensor, "... d_sae"]: """ - Calculate SAE features from inputs + Computes the latent values of the Sparse Autoencoder (SAE) using a gated architecture. The activation values are + computed as: + + activation_fn((x - b_dec) @ W_enc + b_enc) * 1[(x - b_dec) @ W_enc + b_enc > threshold] + + The `latents` argument allows for the computation of a specific subset of the hidden values. If `latents` is not + provided, all latent values will be computed. """ - latents_slice = slice(None) if latents is None else torch.tensor(latents) + latents_tensor = torch.arange(self.cfg.d_sae) if latents is None else latents # move x to correct dtype x = x.to(self.dtype) @@ -550,22 +575,31 @@ def encode_jumprelu( # "... d_in, d_in d_sae -> ... d_sae", hidden_pre = self.hook_sae_acts_pre( - sae_in @ self.W_enc[:, latents_slice] + self.b_enc[latents_slice] + sae_in @ self.W_enc[:, latents_tensor] + self.b_enc[latents_tensor] ) feature_acts = self.hook_sae_acts_post( - self.activation_fn(hidden_pre) * (hidden_pre > self.threshold) + self.activation_fn(hidden_pre) + * (hidden_pre > self.threshold[latents_tensor]) ) return feature_acts def encode_standard( - self, x: Float[torch.Tensor, "... d_in"], latents: Iterable[int] | None = None + self, + x: Float[torch.Tensor, "... d_in"], + latents: Int[torch.Tensor, "latents"] | None = None, ) -> Float[torch.Tensor, "... d_sae"]: """ - Calculate SAE features from inputs + Computes the latent values of the Sparse Autoencoder (SAE) using a gated architecture. The activation values are + computed as: + + activation_fn((x - b_dec) @ W_enc + b_enc) + + The `latents` argument allows for the computation of a specific subset of the hidden values. If `latents` is not + provided, all latent values will be computed. """ - latents_slice = slice(None) if latents is None else torch.tensor(latents) + latents_tensor = torch.arange(self.cfg.d_sae) if latents is None else latents x = x.to(self.dtype) x = self.reshape_fn_in(x) @@ -577,7 +611,7 @@ def encode_standard( # "... d_in, d_in d_sae -> ... 
d_sae", hidden_pre = self.hook_sae_acts_pre( - sae_in @ self.W_enc[:, latents_slice] + self.b_enc[latents_slice] + sae_in @ self.W_enc[:, latents_tensor] + self.b_enc[latents_tensor] ) feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre)) diff --git a/tests/unit/training/test_sae_basic.py b/tests/unit/training/test_sae_basic.py index 586671a8..b514d3fd 100644 --- a/tests/unit/training/test_sae_basic.py +++ b/tests/unit/training/test_sae_basic.py @@ -70,12 +70,14 @@ def test_sae_init(cfg: LanguageModelSAERunnerConfig): assert sae.b_dec.shape == (cfg.d_in,) -def test_sae_encode_with_slicing(cfg: LanguageModelSAERunnerConfig): +@pytest.mark.parametrize("architecture", ["standard", "gated", "jumprelu"]) +def test_sae_encode_with_different_architectures(architecture: str) -> None: + cfg = build_sae_cfg(architecture=architecture) sae = SAE.from_dict(cfg.get_base_sae_cfg_dict()) assert isinstance(cfg.d_sae, int) activations = torch.randn(10, 4, cfg.d_in, device=cfg.device) - latents = torch.randint(low=0, high=cfg.d_sae, size=(10,)).tolist() + latents = torch.randint(low=0, high=cfg.d_sae, size=(10,)) feature_activations = sae.encode(activations) feature_activations_slice = sae.encode(activations, latents=latents) torch.testing.assert_close( From 9fd5f2e0b81114a9410f5921ff24d6bd0f9fac13 Mon Sep 17 00:00:00 2001 From: callummcdougall Date: Thu, 17 Oct 2024 15:24:07 +0100 Subject: [PATCH 7/8] fix typing error --- sae_lens/sae.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sae_lens/sae.py b/sae_lens/sae.py index 0ee0c99e..111148ea 100644 --- a/sae_lens/sae.py +++ b/sae_lens/sae.py @@ -11,7 +11,7 @@ T = TypeVar("T", bound="SAE") import einops import torch -from jaxtyping import Float, Int +from jaxtyping import Float from safetensors.torch import save_file from torch import nn from transformer_lens.hook_points import HookedRootModule, HookPoint @@ -509,7 +509,7 @@ def encode( def encode_gated( self, x: Float[torch.Tensor, "... d_in"], - latents: Int[torch.Tensor, "latents"] | None = None, + latents: torch.Tensor | None = None, ) -> Float[torch.Tensor, "... d_sae"]: """ Computes the latent values of the Sparse Autoencoder (SAE) using a gated architecture. The activation values are @@ -548,7 +548,7 @@ def encode_gated( def encode_jumprelu( self, x: Float[torch.Tensor, "... d_in"], - latents: Int[torch.Tensor, "latents"] | None = None, + latents: torch.Tensor | None = None, ) -> Float[torch.Tensor, "... d_sae"]: """ Computes the latent values of the Sparse Autoencoder (SAE) using a gated architecture. The activation values are @@ -588,7 +588,7 @@ def encode_jumprelu( def encode_standard( self, x: Float[torch.Tensor, "... d_in"], - latents: Int[torch.Tensor, "latents"] | None = None, + latents: torch.Tensor | None = None, ) -> Float[torch.Tensor, "... d_sae"]: """ Computes the latent values of the Sparse Autoencoder (SAE) using a gated architecture. 
The activation values are From 393a43a490017fe6a84dae4a8cc6d8aed1e11fb4 Mon Sep 17 00:00:00 2001 From: callummcdougall Date: Thu, 17 Oct 2024 15:28:18 +0100 Subject: [PATCH 8/8] fix iterable type --- sae_lens/training/training_sae.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sae_lens/training/training_sae.py b/sae_lens/training/training_sae.py index 0be5a652..94ea550c 100644 --- a/sae_lens/training/training_sae.py +++ b/sae_lens/training/training_sae.py @@ -5,7 +5,7 @@ import json import os from dataclasses import dataclass, fields -from typing import Any, Iterable, Optional +from typing import Any, Optional import einops import torch @@ -200,10 +200,11 @@ def check_cfg_compatibility(self): assert self.use_error_term is False, "Gated SAEs do not support error terms" def encode_standard( - self, x: Float[torch.Tensor, "... d_in"], latents: Iterable[int] | None = None + self, x: Float[torch.Tensor, "... d_in"], latents: torch.Tensor | None = None ) -> Float[torch.Tensor, "... d_sae"]: """ - Calcuate SAE features from inputs + Calculate SAE features from inputs. The `latents` argument must be None (it exists only so the type signature matches + the parent class, which uses this argument to compute only a subset of the SAE hidden values) """ assert ( latents is None