-
Notifications
You must be signed in to change notification settings - Fork 148
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Encode with slice #334
base: main
Are you sure you want to change the base?
Encode with slice #334
Changes from 1 commit
5f17c0e
27933f7
64c6fc4
cbda3a0
aef2fc1
e1d105a
9fd5f2e
393a43a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,22 +6,12 @@ | |
import os | ||
import warnings | ||
from dataclasses import dataclass, field | ||
from typing import ( | ||
Any, | ||
Callable, | ||
Iterable, | ||
Literal, | ||
Optional, | ||
Tuple, | ||
TypeVar, | ||
Union, | ||
overload, | ||
) | ||
from typing import Any, Callable, Literal, Optional, Tuple, TypeVar, Union, overload | ||
|
||
T = TypeVar("T", bound="SAE") | ||
import einops | ||
import torch | ||
from jaxtyping import Float | ||
from jaxtyping import Float, Int | ||
from safetensors.torch import save_file | ||
from torch import nn | ||
from transformer_lens.hook_points import HookedRootModule, HookPoint | ||
|
@@ -164,17 +154,8 @@ def __init__( | |
self.device = torch.device(cfg.device) | ||
self.use_error_term = use_error_term | ||
|
||
if self.cfg.architecture == "standard": | ||
self.initialize_weights_basic() | ||
self.encode = self.encode_standard | ||
elif self.cfg.architecture == "gated": | ||
self.initialize_weights_gated() | ||
self.encode = self.encode_gated | ||
elif self.cfg.architecture == "jumprelu": | ||
self.initialize_weights_jumprelu() | ||
self.encode = self.encode_jumprelu | ||
else: | ||
raise (ValueError) | ||
if self.cfg.architecture not in ["standard", "gated", "jumprelu"]: | ||
raise ValueError(f"Architecture {self.cfg.architecture} not supported") | ||
|
||
# handle presence / absence of scaling factor. | ||
if self.cfg.finetuning_scaling_factor: | ||
|
@@ -243,6 +224,16 @@ def run_time_activation_ln_out(x: torch.Tensor, eps: float = 1e-5): | |
|
||
self.setup() # Required for `HookedRootModule`s | ||
|
||
def initialize_weights(self): | ||
if self.cfg.architecture == "standard": | ||
self.initialize_weights_basic() | ||
elif self.cfg.architecture == "gated": | ||
self.initialize_weights_gated() | ||
elif self.cfg.architecture == "jumprelu": | ||
self.initialize_weights_jumprelu() | ||
else: | ||
raise (ValueError) | ||
|
||
def initialize_weights_basic(self): | ||
# no config changes encoder bias init for now. | ||
self.b_enc = nn.Parameter( | ||
|
@@ -497,13 +488,39 @@ def forward( | |
|
||
return self.hook_sae_output(sae_out) | ||
|
||
def encode( | ||
self, x: torch.Tensor, latents: torch.Tensor | None = None | ||
) -> torch.Tensor: | ||
""" | ||
Calculate SAE latents from inputs. Includes optional `latents` argument to only calculate a subset. Note that | ||
this won't make sense for topk SAEs, because we need to compute all hidden values to apply the topk masking. | ||
""" | ||
if self.cfg.activation_fn_str == "topk": | ||
assert ( | ||
latents is None | ||
), "Computing a slice of SAE hidden values doesn't make sense in topk SAEs." | ||
|
||
return { | ||
"standard": self.encode_standard, | ||
"gated": self.encode_gated, | ||
"jumprelu": self.encode_jumprelu, | ||
}[self.cfg.architecture](x, latents) | ||
|
||
def encode_gated( | ||
self, x: Float[torch.Tensor, "... d_in"], latents: Iterable[int] | None = None | ||
self, | ||
x: Float[torch.Tensor, "... d_in"], | ||
latents: Int[torch.Tensor, "latents"] | None = None, | ||
) -> Float[torch.Tensor, "... d_sae"]: | ||
""" | ||
Calculate SAE features from inputs | ||
Computes the latent values of the Sparse Autoencoder (SAE) using a gated architecture. The activation values are | ||
computed as the product of the masking term & the post-activation function magnitude term: | ||
|
||
1[(x - b_dec) @ W_gate + b_gate > 0] * activation_fn((x - b_dec) @ W_enc + b_enc) | ||
|
||
The `latents` argument allows for the computation of a specific subset of the hidden values. If `latents` is not | ||
provided, all latent values will be computed. | ||
""" | ||
latents_slice = slice(None) if latents is None else torch.tensor(latents) | ||
latents_tensor = torch.arange(self.cfg.d_sae) if latents is None else latents | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The previous implementation you had, |
||
|
||
x = x.to(self.dtype) | ||
x = self.reshape_fn_in(x) | ||
|
@@ -513,14 +530,14 @@ def encode_gated( | |
|
||
# Gating path | ||
gating_pre_activation = ( | ||
sae_in @ self.W_enc[:, latents_slice] + self.b_gate[latents_slice] | ||
sae_in @ self.W_enc[:, latents_tensor] + self.b_gate[latents_tensor] | ||
) | ||
active_features = (gating_pre_activation > 0).to(self.dtype) | ||
|
||
# Magnitude path with weight sharing | ||
magnitude_pre_activation = self.hook_sae_acts_pre( | ||
sae_in @ (self.W_enc[:, latents_slice] * self.r_mag[latents_slice].exp()) | ||
+ self.b_mag[latents_slice] | ||
sae_in @ (self.W_enc[:, latents_tensor] * self.r_mag[latents_tensor].exp()) | ||
+ self.b_mag[latents_tensor] | ||
) | ||
feature_magnitudes = self.activation_fn(magnitude_pre_activation) | ||
|
||
|
@@ -529,12 +546,20 @@ def encode_gated( | |
return feature_acts | ||
|
||
def encode_jumprelu( | ||
self, x: Float[torch.Tensor, "... d_in"], latents: Iterable[int] | None = None | ||
self, | ||
x: Float[torch.Tensor, "... d_in"], | ||
latents: Int[torch.Tensor, "latents"] | None = None, | ||
) -> Float[torch.Tensor, "... d_sae"]: | ||
""" | ||
Calculate SAE features from inputs | ||
Computes the latent values of the Sparse Autoencoder (SAE) using a gated architecture. The activation values are | ||
computed as: | ||
|
||
activation_fn((x - b_dec) @ W_enc + b_enc) * 1[(x - b_dec) @ W_enc + b_enc > threshold] | ||
|
||
The `latents` argument allows for the computation of a specific subset of the hidden values. If `latents` is not | ||
provided, all latent values will be computed. | ||
""" | ||
latents_slice = slice(None) if latents is None else torch.tensor(latents) | ||
latents_tensor = torch.arange(self.cfg.d_sae) if latents is None else latents | ||
|
||
# move x to correct dtype | ||
x = x.to(self.dtype) | ||
|
@@ -550,22 +575,31 @@ def encode_jumprelu( | |
|
||
# "... d_in, d_in d_sae -> ... d_sae", | ||
hidden_pre = self.hook_sae_acts_pre( | ||
sae_in @ self.W_enc[:, latents_slice] + self.b_enc[latents_slice] | ||
sae_in @ self.W_enc[:, latents_tensor] + self.b_enc[latents_tensor] | ||
) | ||
|
||
feature_acts = self.hook_sae_acts_post( | ||
self.activation_fn(hidden_pre) * (hidden_pre > self.threshold) | ||
self.activation_fn(hidden_pre) | ||
* (hidden_pre > self.threshold[latents_tensor]) | ||
) | ||
|
||
return feature_acts | ||
|
||
def encode_standard( | ||
self, x: Float[torch.Tensor, "... d_in"], latents: Iterable[int] | None = None | ||
self, | ||
x: Float[torch.Tensor, "... d_in"], | ||
latents: Int[torch.Tensor, "latents"] | None = None, | ||
) -> Float[torch.Tensor, "... d_sae"]: | ||
""" | ||
Calculate SAE features from inputs | ||
Computes the latent values of the Sparse Autoencoder (SAE) using a gated architecture. The activation values are | ||
computed as: | ||
|
||
activation_fn((x - b_dec) @ W_enc + b_enc) | ||
|
||
The `latents` argument allows for the computation of a specific subset of the hidden values. If `latents` is not | ||
provided, all latent values will be computed. | ||
""" | ||
latents_slice = slice(None) if latents is None else torch.tensor(latents) | ||
latents_tensor = torch.arange(self.cfg.d_sae) if latents is None else latents | ||
|
||
x = x.to(self.dtype) | ||
x = self.reshape_fn_in(x) | ||
|
@@ -577,7 +611,7 @@ def encode_standard( | |
|
||
# "... d_in, d_in d_sae -> ... d_sae", | ||
hidden_pre = self.hook_sae_acts_pre( | ||
sae_in @ self.W_enc[:, latents_slice] + self.b_enc[latents_slice] | ||
sae_in @ self.W_enc[:, latents_tensor] + self.b_enc[latents_tensor] | ||
) | ||
feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre)) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks like tests are failing because
initialize_weights()
have been moved into their own function, but that function is never called now. IMO this error can be raised in the newinitialize_weights()
method instead.