Encode with slice #334

Open · wants to merge 8 commits into main
Changes from 1 commit

110 changes: 72 additions & 38 deletions sae_lens/sae.py
@@ -6,22 +6,12 @@
import os
import warnings
from dataclasses import dataclass, field
from typing import (
Any,
Callable,
Iterable,
Literal,
Optional,
Tuple,
TypeVar,
Union,
overload,
)
from typing import Any, Callable, Literal, Optional, Tuple, TypeVar, Union, overload

T = TypeVar("T", bound="SAE")
import einops
import torch
from jaxtyping import Float
from jaxtyping import Float, Int
from safetensors.torch import save_file
from torch import nn
from transformer_lens.hook_points import HookedRootModule, HookPoint
@@ -164,17 +154,8 @@ def __init__(
self.device = torch.device(cfg.device)
self.use_error_term = use_error_term

if self.cfg.architecture == "standard":
self.initialize_weights_basic()
self.encode = self.encode_standard
elif self.cfg.architecture == "gated":
self.initialize_weights_gated()
self.encode = self.encode_gated
elif self.cfg.architecture == "jumprelu":
self.initialize_weights_jumprelu()
self.encode = self.encode_jumprelu
else:
raise (ValueError)
if self.cfg.architecture not in ["standard", "gated", "jumprelu"]:
@chanind (Collaborator) commented on Oct 20, 2024:

It looks like tests are failing because the weight initialization has been moved into its own initialize_weights() method, but that method is never called now. IMO this error can be raised in the new initialize_weights() method instead.
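
A minimal sketch of the suggested change (the initialize_weights_* helpers already exist on the class; wiring it into __init__ is an assumption about intent, not code from this PR):

```python
def initialize_weights(self):
    # Dispatch on architecture and validate in one place, as suggested above.
    if self.cfg.architecture == "standard":
        self.initialize_weights_basic()
    elif self.cfg.architecture == "gated":
        self.initialize_weights_gated()
    elif self.cfg.architecture == "jumprelu":
        self.initialize_weights_jumprelu()
    else:
        # Raise an instantiated error with a helpful message.
        raise ValueError(f"Architecture {self.cfg.architecture} not supported")


# and in __init__, simply call:
#     self.initialize_weights()
```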

raise ValueError(f"Architecture {self.cfg.architecture} not supported")

# handle presence / absence of scaling factor.
if self.cfg.finetuning_scaling_factor:
@@ -243,6 +224,16 @@ def run_time_activation_ln_out(x: torch.Tensor, eps: float = 1e-5):

self.setup() # Required for `HookedRootModule`s

def initialize_weights(self):
if self.cfg.architecture == "standard":
self.initialize_weights_basic()
elif self.cfg.architecture == "gated":
self.initialize_weights_gated()
elif self.cfg.architecture == "jumprelu":
self.initialize_weights_jumprelu()
else:
raise ValueError(f"Architecture {self.cfg.architecture} not supported")

def initialize_weights_basic(self):
# no config changes encoder bias init for now.
self.b_enc = nn.Parameter(
@@ -497,13 +488,39 @@ def forward(

return self.hook_sae_output(sae_out)

def encode(
self, x: torch.Tensor, latents: torch.Tensor | None = None
) -> torch.Tensor:
"""
Calculate SAE latents from inputs. Includes optional `latents` argument to only calculate a subset. Note that
this won't make sense for topk SAEs, because we need to compute all hidden values to apply the topk masking.
"""
if self.cfg.activation_fn_str == "topk":
assert (
latents is None
), "Computing a slice of SAE hidden values doesn't make sense in topk SAEs."

return {
"standard": self.encode_standard,
"gated": self.encode_gated,
"jumprelu": self.encode_jumprelu,
}[self.cfg.architecture](x, latents)
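
A hypothetical usage sketch of the new latents argument (assumes a non-topk SAE built the same way as in the updated test at the bottom of this PR; cfg and its fields come from that test, not from new API):

```python
# Encode all latents, then only a subset, and check they agree column-wise.
sae = SAE.from_dict(cfg.get_base_sae_cfg_dict())
acts = torch.randn(10, 4, cfg.d_in, device=cfg.device)

all_latents = sae.encode(acts)                            # shape (..., d_sae)
subset = sae.encode(acts, latents=torch.tensor([0, 7]))   # shape (..., 2)
torch.testing.assert_close(subset, all_latents[..., [0, 7]])
```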

def encode_gated(
self, x: Float[torch.Tensor, "... d_in"], latents: Iterable[int] | None = None
self,
x: Float[torch.Tensor, "... d_in"],
latents: Int[torch.Tensor, "latents"] | None = None,
) -> Float[torch.Tensor, "... d_sae"]:
"""
Calculate SAE features from inputs
Computes the latent values of the Sparse Autoencoder (SAE) using the gated architecture. The activation values are
computed as the product of the gating (masking) term and the post-activation-function magnitude term:

1[(x - b_dec) @ W_enc + b_gate > 0] * activation_fn((x - b_dec) @ (W_enc * exp(r_mag)) + b_mag)

The `latents` argument allows for the computation of a specific subset of the hidden values. If `latents` is not
provided, all latent values will be computed.
"""
latents_slice = slice(None) if latents is None else torch.tensor(latents)
latents_tensor = torch.arange(self.cfg.d_sae) if latents is None else latents
Collaborator commented:

The previous implementation you had, latents_slice = slice(None) if latents is None else torch.tensor(latents), seems better to me, since this new version creates a new tensor of size d_sae on every SAE forward pass when no specific latents are selected. This would likely reduce performance for most users, which seems counter-productive given that this PR is meant to be a performance improvement, if I understand the goal correctly. Wouldn't the old implementation have worked fine by just adding torch.Tensor to the type of latents? e.g. latents: Iterable[int] | torch.Tensor | None = None
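
A self-contained sketch of the indexing approach the reviewer describes (the helper name to_latent_index is hypothetical and not part of the PR; only the index handling is shown):

```python
from typing import Iterable, Optional, Union

import torch


def to_latent_index(
    latents: Optional[Union[Iterable[int], torch.Tensor]] = None,
) -> Union[slice, torch.Tensor]:
    """Build an index usable as W_enc[:, idx] or b_enc[idx].

    slice(None) selects every column without allocating an arange(d_sae)
    tensor on each forward pass; explicit latents are converted once.
    """
    if latents is None:
        return slice(None)
    if isinstance(latents, torch.Tensor):
        return latents.to(dtype=torch.long)
    return torch.as_tensor(list(latents), dtype=torch.long)


# Usage inside encode_* (names as in the diff above):
#     idx = to_latent_index(latents)
#     hidden_pre = sae_in @ self.W_enc[:, idx] + self.b_enc[idx]
```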


x = x.to(self.dtype)
x = self.reshape_fn_in(x)
@@ -513,14 +530,14 @@ def encode_gated(

# Gating path
gating_pre_activation = (
sae_in @ self.W_enc[:, latents_slice] + self.b_gate[latents_slice]
sae_in @ self.W_enc[:, latents_tensor] + self.b_gate[latents_tensor]
)
active_features = (gating_pre_activation > 0).to(self.dtype)

# Magnitude path with weight sharing
magnitude_pre_activation = self.hook_sae_acts_pre(
sae_in @ (self.W_enc[:, latents_slice] * self.r_mag[latents_slice].exp())
+ self.b_mag[latents_slice]
sae_in @ (self.W_enc[:, latents_tensor] * self.r_mag[latents_tensor].exp())
+ self.b_mag[latents_tensor]
)
feature_magnitudes = self.activation_fn(magnitude_pre_activation)

@@ -529,12 +546,20 @@ def encode_jumprelu(
return feature_acts

def encode_jumprelu(
self, x: Float[torch.Tensor, "... d_in"], latents: Iterable[int] | None = None
self,
x: Float[torch.Tensor, "... d_in"],
latents: Int[torch.Tensor, "latents"] | None = None,
) -> Float[torch.Tensor, "... d_sae"]:
"""
Calculate SAE features from inputs
Computes the latent values of the Sparse Autoencoder (SAE) using the JumpReLU architecture. The activation values are
computed as:

activation_fn((x - b_dec) @ W_enc + b_enc) * 1[(x - b_dec) @ W_enc + b_enc > threshold]

The `latents` argument allows for the computation of a specific subset of the hidden values. If `latents` is not
provided, all latent values will be computed.
"""
latents_slice = slice(None) if latents is None else torch.tensor(latents)
latents_tensor = torch.arange(self.cfg.d_sae) if latents is None else latents

# move x to correct dtype
x = x.to(self.dtype)
@@ -550,22 +575,31 @@

# "... d_in, d_in d_sae -> ... d_sae",
hidden_pre = self.hook_sae_acts_pre(
sae_in @ self.W_enc[:, latents_slice] + self.b_enc[latents_slice]
sae_in @ self.W_enc[:, latents_tensor] + self.b_enc[latents_tensor]
)

feature_acts = self.hook_sae_acts_post(
self.activation_fn(hidden_pre) * (hidden_pre > self.threshold)
self.activation_fn(hidden_pre)
* (hidden_pre > self.threshold[latents_tensor])
)

return feature_acts

def encode_standard(
self, x: Float[torch.Tensor, "... d_in"], latents: Iterable[int] | None = None
self,
x: Float[torch.Tensor, "... d_in"],
latents: Int[torch.Tensor, "latents"] | None = None,
) -> Float[torch.Tensor, "... d_sae"]:
"""
Calculate SAE features from inputs
Computes the latent values of the Sparse Autoencoder (SAE) using the standard architecture. The activation values are
computed as:

activation_fn((x - b_dec) @ W_enc + b_enc)

The `latents` argument allows for the computation of a specific subset of the hidden values. If `latents` is not
provided, all latent values will be computed.
"""
latents_slice = slice(None) if latents is None else torch.tensor(latents)
latents_tensor = torch.arange(self.cfg.d_sae) if latents is None else latents

x = x.to(self.dtype)
x = self.reshape_fn_in(x)
@@ -577,7 +611,7 @@

# "... d_in, d_in d_sae -> ... d_sae",
hidden_pre = self.hook_sae_acts_pre(
sae_in @ self.W_enc[:, latents_slice] + self.b_enc[latents_slice]
sae_in @ self.W_enc[:, latents_tensor] + self.b_enc[latents_tensor]
)
feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre))

6 changes: 4 additions & 2 deletions tests/unit/training/test_sae_basic.py
@@ -70,12 +70,14 @@ def test_sae_init(cfg: LanguageModelSAERunnerConfig):
assert sae.b_dec.shape == (cfg.d_in,)


def test_sae_encode_with_slicing(cfg: LanguageModelSAERunnerConfig):
@pytest.mark.parametrize("architecture", ["standard", "gated", "jumprelu"])
def test_sae_encode_with_different_architectures(architecture: str) -> None:
cfg = build_sae_cfg(architecture=architecture)
sae = SAE.from_dict(cfg.get_base_sae_cfg_dict())
assert isinstance(cfg.d_sae, int)

activations = torch.randn(10, 4, cfg.d_in, device=cfg.device)
latents = torch.randint(low=0, high=cfg.d_sae, size=(10,)).tolist()
latents = torch.randint(low=0, high=cfg.d_sae, size=(10,))
feature_activations = sae.encode(activations)
feature_activations_slice = sae.encode(activations, latents=latents)
torch.testing.assert_close(