Showing 7 changed files with 2,129 additions and 1,407 deletions.
Large diffs are not rendered by default.
164 changes: 164 additions & 0 deletions
sparse_autoencoder/autoencoder/components/linear_decoder.py
@@ -0,0 +1,164 @@
"""Linear decoder layer.""" | ||
import math | ||
from typing import final | ||
|
||
import einops | ||
from jaxtyping import Float, Int64 | ||
from pydantic import PositiveInt, validate_call | ||
import torch | ||
from torch import Tensor | ||
from torch.nn import Module, Parameter, init | ||
|
||
from sparse_autoencoder.autoencoder.types import ResetOptimizerParameterDetails | ||
from sparse_autoencoder.tensor_types import Axis | ||
from sparse_autoencoder.utils.tensor_shape import shape_with_optional_dimensions | ||
|
||
|
||
@final
class LinearDecoder(Module):
    r"""Linear decoder layer.

    Linear layer decoder, where the dictionary vectors (columns of the weight matrix) are NOT
    constrained to have unit norm.

    $$ \begin{align*}
        m &= \text{learned features dimension} \\
        n &= \text{input and output dimension} \\
        b &= \text{batch items dimension} \\
        f \in \mathbb{R}^{b \times m} &= \text{encoder output} \\
        W_d \in \mathbb{R}^{n \times m} &= \text{weight matrix} \\
        z \in \mathbb{R}^{b \times n} &= f W_d^T = \text{LinearDecoder output (pre-tied bias)}
    \end{align*} $$

    Motivation:
        TODO
    """

    _learnt_features: int
    """Number of learnt features (inputs to this layer)."""

    _decoded_features: int
    """Number of decoded features (outputs from this layer)."""

    _n_components: int | None

    weight: Float[
        Parameter,
        Axis.names(Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE, Axis.LEARNT_FEATURE),
    ]
    """Weight parameter.

    Each column in the weights matrix acts as a dictionary vector, representing a single basis
    element in the learned activation space.
    """

    @property
    def reset_optimizer_parameter_details(self) -> list[ResetOptimizerParameterDetails]:
        """Reset optimizer parameter details.

        Details of the parameters that should be reset in the optimizer, when resetting
        dictionary vectors.

        Returns:
            List of tuples of the form `(parameter, axis)`, where `parameter` is the parameter to
            reset (e.g. decoder.weight), and `axis` is the axis of the parameter to reset.
        """
        return [ResetOptimizerParameterDetails(parameter=self.weight, axis=-1)]

    @validate_call
    def __init__(
        self,
        learnt_features: PositiveInt,
        decoded_features: PositiveInt,
        n_components: PositiveInt | None,
    ) -> None:
        """Initialize the linear decoder layer.

        Args:
            learnt_features: Number of learnt features in the autoencoder.
            decoded_features: Number of decoded (output) features in the autoencoder.
            n_components: Number of source model components the SAE is trained on.
        """
        super().__init__()

        self._learnt_features = learnt_features
        self._decoded_features = decoded_features
        self._n_components = n_components

        # Create the weight parameter as per the standard PyTorch linear layer
        self.weight = Parameter(
            torch.empty(
                shape_with_optional_dimensions(n_components, decoded_features, learnt_features),
            )
        )
        self.reset_parameters()

    def update_dictionary_vectors(
        self,
        dictionary_vector_indices: Int64[
            Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE_IDX)
        ],
        updated_weights: Float[
            Tensor,
            Axis.names(Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE, Axis.LEARNT_FEATURE_IDX),
        ],
        component_idx: int | None = None,
    ) -> None:
        """Update decoder dictionary vectors.

        Updates the dictionary vectors (columns in the weight matrix) with the given values.
        Typically this is used when resampling neurons (dictionary vectors) that have died.

        Args:
            dictionary_vector_indices: Indices of the dictionary vectors to update.
            updated_weights: Updated weights for just these dictionary vectors.
            component_idx: Component index to update.

        Raises:
            ValueError: If `component_idx` is not specified when `n_components` is not None.
        """
        if dictionary_vector_indices.numel() == 0:
            return

        with torch.no_grad():
            if component_idx is None:
                if self._n_components is not None:
                    error_message = "component_idx must be specified when n_components is not None"
                    raise ValueError(error_message)

                self.weight[:, dictionary_vector_indices] = updated_weights
            else:
                self.weight[component_idx, :, dictionary_vector_indices] = updated_weights

    def reset_parameters(self) -> None:
        """Initialize or reset the parameters."""
        # Assumes a ReLU activation function (for e.g. leaky ReLU, the `a` parameter and
        # `nonlinearity` argument must be changed).
        init.kaiming_uniform_(self.weight, nonlinearity="relu")

    def forward(
        self, x: Float[Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)]
    ) -> Float[Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)]:
        """Forward pass.

        Args:
            x: Input tensor.

        Returns:
            Output of the forward pass.
        """
        return einops.einsum(
            x,
            self.weight,
            f"{Axis.BATCH} ... {Axis.LEARNT_FEATURE}, \
                ... {Axis.INPUT_OUTPUT_FEATURE} {Axis.LEARNT_FEATURE} \
                -> {Axis.BATCH} ... {Axis.INPUT_OUTPUT_FEATURE}",
        )

    def extra_repr(self) -> str:
        """String extra representation of the module."""
        return (
            f"learnt_features={self._learnt_features}, "
            f"decoded_features={self._decoded_features}, "
            f"n_components={self._n_components}"
        )
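For context, a minimal usage sketch of LinearDecoder (not part of the commit): the feature sizes are illustrative, and the import path assumes the package layout matches the file path shown above.

import torch

from sparse_autoencoder.autoencoder.components.linear_decoder import LinearDecoder

# Illustrative sizes: 512 learnt features decoded back to 64 input/output features.
decoder = LinearDecoder(learnt_features=512, decoded_features=64, n_components=None)

learnt_activations = torch.relu(torch.randn(8, 512))  # (batch, learnt_features)
reconstruction = decoder(learnt_activations)  # (batch, decoded_features)
assert reconstruction.shape == (8, 64)

# Resampling a dead neuron replaces one column of the weight matrix.
dead_indices = torch.tensor([3], dtype=torch.int64)
replacement = torch.randn(64, 1)  # (decoded_features, number of dead neurons)
decoder.update_dictionary_vectors(dead_indices, replacement)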
229 changes: 229 additions & 0 deletions
sparse_autoencoder/autoencoder/components/tanh_encoder.py
@@ -0,0 +1,229 @@
"""Linear encoder layer with tanh(ReLU()) activation.""" | ||
import math | ||
from typing import final | ||
|
||
import einops | ||
from jaxtyping import Float, Int64 | ||
from pydantic import PositiveInt, validate_call | ||
import torch | ||
from torch import Tensor | ||
from torch.nn import Module, Parameter, ReLU, init, Tanh | ||
|
||
from sparse_autoencoder.autoencoder.types import ResetOptimizerParameterDetails | ||
from sparse_autoencoder.tensor_types import Axis | ||
from sparse_autoencoder.utils.tensor_shape import shape_with_optional_dimensions | ||
|
||
class TanhReLU(Module):
    """Tanh(ReLU(x)) activation function."""

    def __init__(self) -> None:
        """Initialize the activation function."""
        super().__init__()
        self.tanh = Tanh()
        self.relu = ReLU()

    def forward(self, x: Tensor) -> Tensor:
        """Apply tanh(ReLU(x)) to the input tensor."""
        return self.tanh(self.relu(x))


@final
class TanhEncoder(Module):
    r"""Linear encoder layer with tanh(ReLU()) activation.

    Linear encoder layer (essentially `nn.Linear`, with a tanh(ReLU()) activation function).
    Designed to be used as the encoder in a sparse autoencoder (excluding any outer tied bias).

    $$
    \begin{align*}
        m &= \text{learned features dimension} \\
        n &= \text{input and output dimension} \\
        b &= \text{batch items dimension} \\
        \overline{\mathbf{x}} \in \mathbb{R}^{b \times n} &= \text{input after tied bias} \\
        W_e \in \mathbb{R}^{m \times n} &= \text{weight matrix} \\
        b_e \in \mathbb{R}^{m} &= \text{bias vector} \\
        f &= \tanh(\text{ReLU}(\overline{\mathbf{x}} W_e^T + b_e)) = \text{TanhEncoder output}
    \end{align*}
    $$
    """

    _learnt_features: int
    """Number of learnt features (outputs from this layer)."""

    _input_features: int
    """Number of input features from the source model."""

    _n_components: int | None

    weight: Float[
        Parameter,
        Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE, Axis.INPUT_OUTPUT_FEATURE),
    ]
    """Weight parameter.

    Each row in the weights matrix acts as a dictionary vector, representing a single basis
    element in the learned activation space.
    """

    bias: Float[Parameter, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)]
    """Bias parameter."""

    activation_function: TanhReLU
    """Activation function."""

    @property
    def reset_optimizer_parameter_details(self) -> list[ResetOptimizerParameterDetails]:
        """Reset optimizer parameter details.

        Details of the parameters that should be reset in the optimizer, when resetting
        dictionary vectors.

        Returns:
            List of tuples of the form `(parameter, axis)`, where `parameter` is the parameter to
            reset (e.g. encoder.weight), and `axis` is the axis of the parameter to reset.
        """
        return [
            ResetOptimizerParameterDetails(parameter=self.weight, axis=-2),
            ResetOptimizerParameterDetails(parameter=self.bias, axis=-1),
        ]

    @validate_call
    def __init__(
        self,
        input_features: PositiveInt,
        learnt_features: PositiveInt,
        n_components: PositiveInt | None,
    ) -> None:
        """Initialize the tanh(ReLU()) encoder layer.

        Args:
            input_features: Number of input features to the autoencoder.
            learnt_features: Number of learnt features in the autoencoder.
            n_components: Number of source model components the SAE is trained on.
        """
        super().__init__()

        self._learnt_features = learnt_features
        self._input_features = input_features
        self._n_components = n_components

        self.weight = Parameter(
            torch.empty(
                shape_with_optional_dimensions(n_components, learnt_features, input_features),
            )
        )
        self.bias = Parameter(
            torch.zeros(shape_with_optional_dimensions(n_components, learnt_features))
        )
        self.activation_function = TanhReLU()

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialize or reset the parameters."""
        # Assumes a ReLU activation function (for e.g. leaky ReLU, the `a` parameter and
        # `nonlinearity` argument must be changed).
        init.kaiming_uniform_(self.weight, nonlinearity="relu")

        # Bias (approach from nn.Linear). The fan-in is the number of input features, which is
        # always the last weight dimension (even when there is a components dimension).
        fan_in = self.weight.size(-1)
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        init.uniform_(self.bias, -bound, bound)

    def forward(
        self,
        x: Float[
            Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
        ],
    ) -> Float[Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)]:
        """Forward pass.

        Args:
            x: Input tensor.

        Returns:
            Output of the forward pass.
        """
        z = (
            einops.einsum(
                x,
                self.weight,
                f"{Axis.BATCH} ... {Axis.INPUT_OUTPUT_FEATURE}, \
                    ... {Axis.LEARNT_FEATURE} {Axis.INPUT_OUTPUT_FEATURE} \
                    -> {Axis.BATCH} ... {Axis.LEARNT_FEATURE}",
            )
            + self.bias
        )

        return self.activation_function(z)

    @final
    def update_dictionary_vectors(
        self,
        dictionary_vector_indices: Int64[Tensor, Axis.names(Axis.LEARNT_FEATURE_IDX)],
        updated_dictionary_weights: Float[
            Tensor, Axis.names(Axis.LEARNT_FEATURE_IDX, Axis.INPUT_OUTPUT_FEATURE)
        ],
        component_idx: int | None = None,
    ) -> None:
        """Update encoder dictionary vectors.

        Updates the dictionary vectors (rows in the weight matrix) with the given values.

        Args:
            dictionary_vector_indices: Indices of the dictionary vectors to update.
            updated_dictionary_weights: Updated weights for just these dictionary vectors.
            component_idx: Component index to update.

        Raises:
            ValueError: If there are multiple components and `component_idx` is not specified.
        """
        if dictionary_vector_indices.numel() == 0:
            return

        with torch.no_grad():
            if component_idx is None:
                if self._n_components is not None:
                    error_message = "component_idx must be specified when n_components is not None"
                    raise ValueError(error_message)

                self.weight[dictionary_vector_indices] = updated_dictionary_weights
            else:
                self.weight[component_idx, dictionary_vector_indices] = updated_dictionary_weights

    @final
    def update_bias(
        self,
        update_parameter_indices: Int64[
            Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE_IDX)
        ],
        updated_bias_features: Float[
            Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE_IDX)
        ],
        component_idx: int | None = None,
    ) -> None:
        """Update encoder bias.

        Args:
            update_parameter_indices: Indices of the bias features to update.
            updated_bias_features: Updated bias features for just these indices.
            component_idx: Component index to update.

        Raises:
            ValueError: If there are multiple components and `component_idx` is not specified.
        """
        if update_parameter_indices.numel() == 0:
            return

        with torch.no_grad():
            if component_idx is None:
                if self._n_components is not None:
                    error_message = "component_idx must be specified when n_components is not None"
                    raise ValueError(error_message)

                self.bias[update_parameter_indices] = updated_bias_features
            else:
                self.bias[component_idx, update_parameter_indices] = updated_bias_features

    def extra_repr(self) -> str:
        """String extra representation of the module."""
        return (
            f"input_features={self._input_features}, "
            f"learnt_features={self._learnt_features}, "
            f"n_components={self._n_components}"
        )
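Similarly, a minimal usage sketch of TanhEncoder (not part of the commit; sizes are illustrative, and the import path is assumed from the file path above). The tanh(ReLU()) activation keeps the learnt activations non-negative and bounded above by 1.

import torch

from sparse_autoencoder.autoencoder.components.tanh_encoder import TanhEncoder

# Illustrative sizes: 64 source-model features expanded to 512 learnt features.
encoder = TanhEncoder(input_features=64, learnt_features=512, n_components=None)

source_activations = torch.randn(8, 64)  # (batch, input_features)
learnt_activations = encoder(source_activations)  # (batch, learnt_features)

assert learnt_activations.shape == (8, 512)
# ReLU clamps negatives to zero; tanh then squashes the result into [0, 1).
assert torch.all(learnt_activations >= 0)
assert torch.all(learnt_activations <= 1)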