Encode with slice #334

Status: Open. Wants to merge 8 commits into base: main.
Changes from 5 commits
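
Before the diffs, a minimal usage sketch of the behaviour this PR adds, assuming the new `latents` keyword is threaded through `SAE.encode` as shown in the diff and test below; the pretrained release and hook point named here are illustrative, not part of this PR:

```python
import torch

from sae_lens import SAE

# The release / sae_id are illustrative; any pretrained SAE works.
sae, _, _ = SAE.from_pretrained(
    release="gpt2-small-res-jb", sae_id="blocks.8.hook_resid_pre"
)

acts = torch.randn(4, 128, sae.cfg.d_in)

full_feats = sae.encode(acts)                      # shape (..., d_sae)
some_feats = sae.encode(acts, latents=[0, 7, 42])  # shape (..., 3): only the requested latents

# Slicing before the matmul should match slicing the full result afterwards.
torch.testing.assert_close(full_feats[..., [0, 7, 42]], some_feats)
```

Passing `latents` restricts the encoder matmul to the requested columns of `W_enc`, so the output's last dimension is `len(latents)` rather than `d_sae`.
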
46 changes: 31 additions & 15 deletions sae_lens/sae.py
@@ -6,7 +6,17 @@
import os
import warnings
from dataclasses import dataclass, field
from typing import Any, Callable, Literal, Optional, Tuple, TypeVar, Union, overload
from typing import (
Any,
Callable,
Iterable,
Literal,
Optional,
Tuple,
TypeVar,
Union,
overload,
)

T = TypeVar("T", bound="SAE")
import einops
@@ -67,7 +77,6 @@ class SAEConfig:

@classmethod
def from_dict(cls, config_dict: dict[str, Any]) -> "SAEConfig":

# rename dict:
rename_dict = { # old : new
"hook_point": "hook_name",
@@ -196,7 +205,6 @@ def __init__(

# handle run time activation normalization if needed:
if self.cfg.normalize_activations == "constant_norm_rescale":

# we need to scale the norm of the input and store the scaling factor
def run_time_activation_norm_fn_in(x: torch.Tensor) -> torch.Tensor:
self.x_norm_coeff = (self.cfg.d_in**0.5) / x.norm(dim=-1, keepdim=True)
@@ -212,7 +220,6 @@ def run_time_activation_norm_fn_out(x: torch.Tensor) -> torch.Tensor: #
self.run_time_activation_norm_fn_out = run_time_activation_norm_fn_out

elif self.cfg.normalize_activations == "layer_norm":

# we need to scale the norm of the input and store the scaling factor
def run_time_activation_ln_in(
x: torch.Tensor, eps: float = 1e-5
@@ -237,7 +244,6 @@ def run_time_activation_ln_out(x: torch.Tensor, eps: float = 1e-5):
self.setup() # Required for `HookedRootModule`s

def initialize_weights_basic(self):

# no config changes encoder bias init for now.
self.b_enc = nn.Parameter(
torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
@@ -492,8 +498,12 @@ def forward(
return self.hook_sae_output(sae_out)

def encode_gated(
self, x: Float[torch.Tensor, "... d_in"]
self, x: Float[torch.Tensor, "... d_in"], latents: Iterable[int] | None = None
) -> Float[torch.Tensor, "... d_sae"]:
"""
Calculate SAE features from inputs
"""
latents_slice = slice(None) if latents is None else torch.tensor(latents)

x = x.to(self.dtype)
x = self.reshape_fn_in(x)
@@ -502,12 +512,15 @@ def encode_gated(
sae_in = x - self.b_dec * self.cfg.apply_b_dec_to_input

# Gating path
gating_pre_activation = sae_in @ self.W_enc + self.b_gate
gating_pre_activation = (
sae_in @ self.W_enc[:, latents_slice] + self.b_gate[latents_slice]
)
active_features = (gating_pre_activation > 0).to(self.dtype)

# Magnitude path with weight sharing
magnitude_pre_activation = self.hook_sae_acts_pre(
sae_in @ (self.W_enc * self.r_mag.exp()) + self.b_mag
sae_in @ (self.W_enc[:, latents_slice] * self.r_mag[latents_slice].exp())
+ self.b_mag[latents_slice]
)
feature_magnitudes = self.activation_fn(magnitude_pre_activation)

@@ -516,11 +529,12 @@
return feature_acts
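
The `latents_slice` pattern above works because selecting columns of `W_enc` before the matmul is equivalent to computing the full pre-activation and indexing it afterwards, at a fraction of the cost when only a few latents are needed. A tiny self-contained check of that identity (generic tensor names, not the SAE's parameters):

```python
import torch

d_in, d_sae = 16, 64
x = torch.randn(8, d_in)
W_enc = torch.randn(d_in, d_sae)
b_enc = torch.randn(d_sae)

latents = torch.tensor([3, 10, 57])

full = x @ W_enc + b_enc                         # (8, d_sae)
sliced = x @ W_enc[:, latents] + b_enc[latents]  # (8, 3): only the requested columns

torch.testing.assert_close(full[:, latents], sliced)
```

For a handful of latents this avoids the full `d_in x d_sae` projection, which is the point of the change.
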

def encode_jumprelu(
self, x: Float[torch.Tensor, "... d_in"]
self, x: Float[torch.Tensor, "... d_in"], latents: Iterable[int] | None = None
) -> Float[torch.Tensor, "... d_sae"]:
"""
Calculate SAE features from inputs
"""
latents_slice = slice(None) if latents is None else torch.tensor(latents)

# move x to correct dtype
x = x.to(self.dtype)
@@ -535,7 +549,9 @@ def encode_jumprelu(
sae_in = self.hook_sae_input(x - (self.b_dec * self.cfg.apply_b_dec_to_input))

# "... d_in, d_in d_sae -> ... d_sae",
hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
hidden_pre = self.hook_sae_acts_pre(
sae_in @ self.W_enc[:, latents_slice] + self.b_enc[latents_slice]
)

feature_acts = self.hook_sae_acts_post(
self.activation_fn(hidden_pre) * (hidden_pre > self.threshold)
@@ -544,11 +560,12 @@
return feature_acts

def encode_standard(
self, x: Float[torch.Tensor, "... d_in"]
self, x: Float[torch.Tensor, "... d_in"], latents: Iterable[int] | None = None
Collaborator: nit: should this also accept a tensor? It seems like the code would work with a tensor of ints as well.

Collaborator: It also looks like topk SAEs go through this codepath and will silently break if anything is passed for latents. We should make it non-silent.

(A sketch addressing both review comments follows the encode_standard hunk below.)

) -> Float[torch.Tensor, "... d_sae"]:
"""
Calculate SAE features from inputs
"""
latents_slice = slice(None) if latents is None else torch.tensor(latents)

x = x.to(self.dtype)
x = self.reshape_fn_in(x)
@@ -559,7 +576,9 @@
sae_in = x - (self.b_dec * self.cfg.apply_b_dec_to_input)

# "... d_in, d_in d_sae -> ... d_sae",
hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
hidden_pre = self.hook_sae_acts_pre(
sae_in @ self.W_enc[:, latents_slice] + self.b_enc[latents_slice]
)
feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre))

return feature_acts
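
The two collaborator comments on `encode_standard` above raise separate points: the index could accept a tensor as well as an iterable, and architectures whose activation couples latents (topk) should reject slicing loudly rather than silently returning wrong values. One possible shape for both fixes is sketched below; this is an option under assumptions, not the PR's implementation, and both the helper name and the `activation_fn_str` config field are assumptions:

```python
from typing import Iterable, Union

import torch


def _resolve_latents_index(
    latents: Union[Iterable[int], torch.Tensor, None],
) -> Union[slice, torch.Tensor]:
    """Turn the `latents` argument into something usable as a column index."""
    if latents is None:
        return slice(None)                 # keep every latent
    if isinstance(latents, torch.Tensor):
        return latents.long()              # accept tensors of ints directly
    return torch.tensor(list(latents), dtype=torch.long)


# Sketch of a loud failure for topk SAEs (inside encode_standard, before slicing).
# `activation_fn_str` is assumed to be the config field naming the activation:
#
#     if latents is not None and self.cfg.activation_fn_str == "topk":
#         raise NotImplementedError(
#             "latents slicing is not supported for topk SAEs: topk is taken over "
#             "all d_sae latents, so a sliced encode would change which features fire."
#         )
```
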
@@ -606,7 +625,6 @@ def fold_activation_norm_scaling_factor(
self.cfg.normalize_activations = "none"

def save_model(self, path: str, sparsity: Optional[torch.Tensor] = None):

if not os.path.exists(path):
os.mkdir(path)

@@ -627,7 +645,6 @@ def save_model(self, path: str, sparsity: Optional[torch.Tensor] = None):
def load_from_pretrained(
cls, path: str, device: str = "cpu", dtype: str | None = None
) -> "SAE":

# get the config
config_path = os.path.join(path, SAE_CFG_PATH)
with open(config_path, "r") as f:
@@ -752,7 +769,6 @@ def from_dict(cls, config_dict: dict[str, Any]) -> "SAE":
return cls(SAEConfig.from_dict(config_dict))

def turn_on_forward_pass_hook_z_reshaping(self):

assert self.cfg.hook_name.endswith(
"_z"
), "This method should only be called for hook_z SAEs."
18 changes: 5 additions & 13 deletions sae_lens/training/training_sae.py
@@ -5,7 +5,7 @@
import json
import os
from dataclasses import dataclass, fields
from typing import Any, Optional
from typing import Any, Iterable, Optional

import einops
import torch
@@ -38,7 +38,6 @@ class TrainStepOutput:

@dataclass(kw_only=True)
class TrainingSAEConfig(SAEConfig):

# Sparsity Loss Calculations
l1_coefficient: float
lp_norm: float
@@ -55,7 +54,6 @@ class TrainingSAEConfig(SAEConfig):
def from_sae_runner_config(
cls, cfg: LanguageModelSAERunnerConfig
) -> "TrainingSAEConfig":

return cls(
# base config
architecture=cfg.architecture,
@@ -168,7 +166,6 @@ class TrainingSAE(SAE):
device: torch.device

def __init__(self, cfg: TrainingSAEConfig, use_error_term: bool = False):

base_sae_cfg = SAEConfig.from_dict(cfg.get_base_sae_cfg_dict())
super().__init__(base_sae_cfg)
self.cfg = cfg # type: ignore
@@ -203,18 +200,20 @@ def check_cfg_compatibility(self):
assert self.use_error_term is False, "Gated SAEs do not support error terms"

def encode_standard(
self, x: Float[torch.Tensor, "... d_in"]
self, x: Float[torch.Tensor, "... d_in"], latents: Iterable[int] | None = None
) -> Float[torch.Tensor, "... d_sae"]:
"""
Calculate SAE features from inputs
"""
assert (
latents is None
), "Function `encode_standard` in training should always return activations for all latents"
feature_acts, _ = self.encode_with_hidden_pre_fn(x)
return feature_acts

def encode_with_hidden_pre(
self, x: Float[torch.Tensor, "... d_in"]
) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:

x = x.to(self.dtype)
x = self.reshape_fn_in(x) # type: ignore
x = self.hook_sae_input(x)
@@ -235,7 +234,6 @@ def encode_with_hidden_pre(
def encode_with_hidden_pre_gated(
self, x: Float[torch.Tensor, "... d_in"]
) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:

x = x.to(self.dtype)
x = self.reshape_fn_in(x) # type: ignore
x = self.hook_sae_input(x)
@@ -267,7 +265,6 @@ def forward(
self,
x: Float[torch.Tensor, "... d_in"],
) -> Float[torch.Tensor, "... d_in"]:

feature_acts, _ = self.encode_with_hidden_pre_fn(x)
sae_out = self.decode(feature_acts)

@@ -279,7 +276,6 @@ def training_forward_pass(
current_l1_coefficient: float,
dead_neuron_mask: Optional[torch.Tensor] = None,
) -> TrainStepOutput:

# do a forward pass to get SAE out, but we also need the
# hidden pre.
feature_acts, _ = self.encode_with_hidden_pre_fn(sae_in)
@@ -291,7 +287,6 @@

# GHOST GRADS
if self.cfg.use_ghost_grads and self.training and dead_neuron_mask is not None:

# first half of second forward pass
_, hidden_pre = self.encode_with_hidden_pre_fn(sae_in)
ghost_grad_loss = self.calculate_ghost_grad_loss(
@@ -362,7 +357,6 @@ def calculate_ghost_grad_loss(
hidden_pre: torch.Tensor,
dead_neuron_mask: torch.Tensor,
) -> torch.Tensor:

# 1.
residual = x - sae_out
l2_norm_residual = torch.norm(residual, dim=-1)
@@ -394,7 +388,6 @@

@torch.no_grad()
def _get_mse_loss_fn(self) -> Any:

def standard_mse_loss_fn(
preds: torch.Tensor, target: torch.Tensor
) -> torch.Tensor:
@@ -421,7 +414,6 @@ def load_from_pretrained(
device: str = "cpu",
dtype: str | None = None,
) -> "TrainingSAE":

# get the config
config_path = os.path.join(path, SAE_CFG_PATH)
with open(config_path, "r") as f:
14 changes: 13 additions & 1 deletion tests/unit/training/test_sae_basic.py
@@ -70,6 +70,19 @@ def test_sae_init(cfg: LanguageModelSAERunnerConfig):
assert sae.b_dec.shape == (cfg.d_in,)


def test_sae_encode_with_slicing(cfg: LanguageModelSAERunnerConfig):
Collaborator: nit: It would be good to test with pytest.mark.parametrize across all architectures, since it seems easy to accidentally break the implementation in one architecture and have it slip through the cracks. It could make sense to add a second test calling build_sae_cfg() directly so we ensure we hit every architecture variant explicitly rather than just the pre-defined SAEs in the cfg fixture.

(A parametrized version is sketched after this test.)

sae = SAE.from_dict(cfg.get_base_sae_cfg_dict())
assert isinstance(cfg.d_sae, int)

activations = torch.randn(10, 4, cfg.d_in, device=cfg.device)
latents = torch.randint(low=0, high=cfg.d_sae, size=(10,)).tolist()
feature_activations = sae.encode(activations)
feature_activations_slice = sae.encode(activations, latents=latents)
torch.testing.assert_close(
feature_activations[..., latents], feature_activations_slice
)
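
A sketch of the parametrized test the collaborator suggests above, assuming `build_sae_cfg` lives in the test helpers and accepts an `architecture` override; the exact helper import path and signature may differ:

```python
import pytest
import torch

from sae_lens import SAE
from tests.unit.helpers import build_sae_cfg  # assumed test helper


@pytest.mark.parametrize("architecture", ["standard", "gated", "jumprelu"])
def test_sae_encode_with_slicing_all_architectures(architecture: str):
    cfg = build_sae_cfg(architecture=architecture)
    sae = SAE.from_dict(cfg.get_base_sae_cfg_dict())

    activations = torch.randn(10, 4, cfg.d_in, device=cfg.device)
    latents = torch.randint(low=0, high=cfg.d_sae, size=(10,)).tolist()

    full = sae.encode(activations)
    sliced = sae.encode(activations, latents=latents)
    torch.testing.assert_close(full[..., latents], sliced)
```
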


def test_sae_fold_w_dec_norm(cfg: LanguageModelSAERunnerConfig):
sae = SAE.from_dict(cfg.get_base_sae_cfg_dict())
sae.turn_off_forward_pass_hook_z_reshaping() # hook z reshaping not needed here.
@@ -106,7 +119,6 @@


def test_sae_fold_norm_scaling_factor(cfg: LanguageModelSAERunnerConfig):

norm_scaling_factor = 3.0

sae = SAE.from_dict(cfg.get_base_sae_cfg_dict())