
Commit cf79a57

Added new architecture.
1 parent 0d93614 commit cf79a57


7 files changed: +2129 −1407 lines changed


docs/content/demo.ipynb

Lines changed: 224 additions & 7 deletions
Large diffs are not rendered by default.

poetry.lock

Lines changed: 1465 additions & 1386 deletions
Some generated files are not rendered by default.
New file — Lines changed: 164 additions & 0 deletions
@@ -0,0 +1,164 @@
"""Linear decoder layer."""
import math
from typing import final

import einops
from jaxtyping import Float, Int64
from pydantic import PositiveInt, validate_call
import torch
from torch import Tensor
from torch.nn import Module, Parameter, init

from sparse_autoencoder.autoencoder.types import ResetOptimizerParameterDetails
from sparse_autoencoder.tensor_types import Axis
from sparse_autoencoder.utils.tensor_shape import shape_with_optional_dimensions


@final
class LinearDecoder(Module):
    r"""Linear decoder layer.

    Linear decoder layer, where the dictionary vectors (columns of the weight matrix) are NOT
    constrained to have unit norm.

    $$ \begin{align*}
        m &= \text{learned features dimension} \\
        n &= \text{input and output dimension} \\
        b &= \text{batch items dimension} \\
        f \in \mathbb{R}^{b \times m} &= \text{encoder output} \\
        W_d \in \mathbb{R}^{n \times m} &= \text{weight matrix} \\
        z \in \mathbb{R}^{b \times n} &= f W_d^T = \text{LinearDecoder output (pre-tied bias)}
    \end{align*} $$

    Motivation:
        TODO
    """

    _learnt_features: int
    """Number of learnt features (inputs to this layer)."""

    _decoded_features: int
    """Number of decoded features (outputs from this layer)."""

    _n_components: int | None

    weight: Float[
        Parameter,
        Axis.names(Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE, Axis.LEARNT_FEATURE),
    ]
    """Weight parameter.

    Each column in the weights matrix acts as a dictionary vector, representing a single basis
    element in the learned activation space.
    """

    @property
    def reset_optimizer_parameter_details(self) -> list[ResetOptimizerParameterDetails]:
        """Reset optimizer parameter details.

        Details of the parameters that should be reset in the optimizer, when resetting
        dictionary vectors.

        Returns:
            List of tuples of the form `(parameter, axis)`, where `parameter` is the parameter to
                reset (e.g. encoder.weight), and `axis` is the axis of the parameter to reset.
        """
        return [ResetOptimizerParameterDetails(parameter=self.weight, axis=-1)]

    @validate_call
    def __init__(
        self,
        learnt_features: PositiveInt,
        decoded_features: PositiveInt,
        n_components: PositiveInt | None,
    ) -> None:
        """Initialize the linear decoder layer.

        Args:
            learnt_features: Number of learnt features in the autoencoder.
            decoded_features: Number of decoded (output) features in the autoencoder.
            n_components: Number of source model components the SAE is trained on.
        """
        super().__init__()

        self._learnt_features = learnt_features
        self._decoded_features = decoded_features
        self._n_components = n_components

        # Create the linear layer as per the standard PyTorch linear layer
        self.weight = Parameter(
            torch.empty(
                shape_with_optional_dimensions(n_components, decoded_features, learnt_features),
            )
        )
        self.reset_parameters()

    def update_dictionary_vectors(
        self,
        dictionary_vector_indices: Int64[
            Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE_IDX)
        ],
        updated_weights: Float[
            Tensor,
            Axis.names(Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE, Axis.LEARNT_FEATURE_IDX),
        ],
        component_idx: int | None = None,
    ) -> None:
        """Update decoder dictionary vectors.

        Updates the dictionary vectors (columns in the weight matrix) with the given values.
        Typically this is used when resampling neurons (dictionary vectors) that have died.

        Args:
            dictionary_vector_indices: Indices of the dictionary vectors to update.
            updated_weights: Updated weights for just these dictionary vectors.
            component_idx: Component index to update.

        Raises:
            ValueError: If `component_idx` is not specified when `n_components` is not None.
        """
        if dictionary_vector_indices.numel() == 0:
            return

        with torch.no_grad():
            if component_idx is None:
                if self._n_components is not None:
                    error_message = "component_idx must be specified when n_components is not None"
                    raise ValueError(error_message)

                self.weight[:, dictionary_vector_indices] = updated_weights
            else:
                self.weight[component_idx, :, dictionary_vector_indices] = updated_weights

    def reset_parameters(self) -> None:
        """Initialize or reset the parameters."""
        # Assumes we are using a ReLU activation function (for e.g. leaky ReLU, the `a` parameter
        # and `nonlinearity` must be changed).
        init.kaiming_uniform_(self.weight, nonlinearity="relu")

    def forward(
        self, x: Float[Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)]
    ) -> Float[Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)]:
        """Forward pass.

        Args:
            x: Input tensor.

        Returns:
            Output of the forward pass.
        """
        return einops.einsum(
            x,
            self.weight,
            f"{Axis.BATCH} ... {Axis.LEARNT_FEATURE}, \
                ... {Axis.INPUT_OUTPUT_FEATURE} {Axis.LEARNT_FEATURE} \
                -> {Axis.BATCH} ... {Axis.INPUT_OUTPUT_FEATURE}",
        )

    def extra_repr(self) -> str:
        """String extra representation of the module."""
        return (
            f"learnt_features={self._learnt_features}, "
            f"decoded_features={self._decoded_features}, "
            f"n_components={self._n_components}"
        )
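
For orientation, a minimal usage sketch of the decoder added above. This is not part of the commit: the import path and the dimensions below are assumptions, chosen only to illustrate the shapes involved.

# Sketch only: import path and sizes are assumptions, not taken from the commit.
import torch

from sparse_autoencoder.autoencoder.components.linear_decoder import LinearDecoder

decoder = LinearDecoder(learnt_features=512, decoded_features=64, n_components=None)

# Decode a batch of learned activations back into the source model's activation space.
learned_activations = torch.rand(8, 512)   # (batch, learnt_features)
decoded = decoder(learned_activations)     # (batch, decoded_features) == (8, 64)

# Resampling sketch: overwrite two dictionary vectors (columns of the weight matrix).
indices = torch.tensor([3, 7])             # dictionary vectors to replace
new_vectors = torch.rand(64, 2)            # (decoded_features, number of replaced vectors)
decoder.update_dictionary_vectors(indices, new_vectors)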
New file — Lines changed: 229 additions & 0 deletions
@@ -0,0 +1,229 @@
"""Linear encoder layer with tanh(ReLU()) activation."""
import math
from typing import final

import einops
from jaxtyping import Float, Int64
from pydantic import PositiveInt, validate_call
import torch
from torch import Tensor
from torch.nn import Module, Parameter, ReLU, Tanh, init

from sparse_autoencoder.autoencoder.types import ResetOptimizerParameterDetails
from sparse_autoencoder.tensor_types import Axis
from sparse_autoencoder.utils.tensor_shape import shape_with_optional_dimensions


class TanhReLU(Module):
    """Composed tanh(ReLU(x)) activation function."""

    def __init__(self) -> None:
        super().__init__()
        self.tanh = Tanh()
        self.relu = ReLU()

    def forward(self, x: Tensor) -> Tensor:
        return self.tanh(self.relu(x))


@final
class TanhEncoder(Module):
    r"""Linear encoder layer with tanh(ReLU()) activation.

    Linear encoder layer (essentially `nn.Linear`, with a tanh(ReLU()) activation function).
    Designed to be used as the encoder in a sparse autoencoder (excluding any outer tied bias).

    $$
    \begin{align*}
        m &= \text{learned features dimension} \\
        n &= \text{input and output dimension} \\
        b &= \text{batch items dimension} \\
        \overline{\mathbf{x}} \in \mathbb{R}^{b \times n} &= \text{input after tied bias} \\
        W_e \in \mathbb{R}^{m \times n} &= \text{weight matrix} \\
        b_e \in \mathbb{R}^{m} &= \text{bias vector} \\
        f &= \tanh(\text{ReLU}(\overline{\mathbf{x}} W_e^T + b_e)) = \text{TanhEncoder output}
    \end{align*}
    $$
    """

    _learnt_features: int
    """Number of learnt features (outputs from this layer)."""

    _input_features: int
    """Number of input features from the source model."""

    _n_components: int | None

    weight: Float[
        Parameter,
        Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE, Axis.INPUT_OUTPUT_FEATURE),
    ]
    """Weight parameter.

    Each row in the weights matrix acts as a dictionary vector, representing a single basis
    element in the learned activation space.
    """

    bias: Float[Parameter, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)]
    """Bias parameter."""

    @property
    def reset_optimizer_parameter_details(self) -> list[ResetOptimizerParameterDetails]:
        """Reset optimizer parameter details.

        Details of the parameters that should be reset in the optimizer, when resetting
        dictionary vectors.

        Returns:
            List of tuples of the form `(parameter, axis)`, where `parameter` is the parameter to
                reset (e.g. encoder.weight), and `axis` is the axis of the parameter to reset.
        """
        return [
            ResetOptimizerParameterDetails(parameter=self.weight, axis=-2),
            ResetOptimizerParameterDetails(parameter=self.bias, axis=-1),
        ]

    activation_function: TanhReLU
    """Activation function."""

    @validate_call
    def __init__(
        self,
        input_features: PositiveInt,
        learnt_features: PositiveInt,
        n_components: PositiveInt | None,
    ) -> None:
        """Initialize the tanh(ReLU()) encoder layer.

        Args:
            input_features: Number of input features to the autoencoder.
            learnt_features: Number of learnt features in the autoencoder.
            n_components: Number of source model components the SAE is trained on.
        """
        super().__init__()

        self._learnt_features = learnt_features
        self._input_features = input_features
        self._n_components = n_components

        self.weight = Parameter(
            torch.empty(
                shape_with_optional_dimensions(n_components, learnt_features, input_features),
            )
        )
        self.bias = Parameter(
            torch.zeros(shape_with_optional_dimensions(n_components, learnt_features))
        )
        self.activation_function = TanhReLU()

        self.reset_parameters()
    def reset_parameters(self) -> None:
        """Initialize or reset the parameters."""
        # Assumes we are using a ReLU activation function (for e.g. leaky ReLU, the `a` parameter
        # and `nonlinearity` must be changed).
        init.kaiming_uniform_(self.weight, nonlinearity="relu")

        # Bias (approach from nn.Linear): bound by 1/sqrt(fan_in), where fan_in is the number of
        # input features (the last weight dimension, which also handles the optional leading
        # component dimension).
        fan_in = self.weight.size(-1)
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        init.uniform_(self.bias, -bound, bound)

    def forward(
        self,
        x: Float[
            Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
        ],
    ) -> Float[Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)]:
        """Forward pass.

        Args:
            x: Input tensor.

        Returns:
            Output of the forward pass.
        """
        z = (
            einops.einsum(
                x,
                self.weight,
                f"{Axis.BATCH} ... {Axis.INPUT_OUTPUT_FEATURE}, \
                    ... {Axis.LEARNT_FEATURE} {Axis.INPUT_OUTPUT_FEATURE} \
                    -> {Axis.BATCH} ... {Axis.LEARNT_FEATURE}",
            )
            + self.bias
        )

        return self.activation_function(z)

    @final
    def update_dictionary_vectors(
        self,
        dictionary_vector_indices: Int64[Tensor, Axis.names(Axis.LEARNT_FEATURE_IDX)],
        updated_dictionary_weights: Float[
            Tensor, Axis.names(Axis.LEARNT_FEATURE_IDX, Axis.INPUT_OUTPUT_FEATURE)
        ],
        component_idx: int | None = None,
    ) -> None:
        """Update encoder dictionary vectors.

        Updates the dictionary vectors (rows in the weight matrix) with the given values.

        Args:
            dictionary_vector_indices: Indices of the dictionary vectors to update.
            updated_dictionary_weights: Updated weights for just these dictionary vectors.
            component_idx: Component index to update.

        Raises:
            ValueError: If there are multiple components and `component_idx` is not specified.
        """
        if dictionary_vector_indices.numel() == 0:
            return

        with torch.no_grad():
            if component_idx is None:
                if self._n_components is not None:
                    error_message = "component_idx must be specified when n_components is not None"
                    raise ValueError(error_message)

                self.weight[dictionary_vector_indices] = updated_dictionary_weights
            else:
                self.weight[component_idx, dictionary_vector_indices] = updated_dictionary_weights

    @final
    def update_bias(
        self,
        update_parameter_indices: Int64[
            Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE_IDX)
        ],
        updated_bias_features: Float[
            Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE_IDX)
        ],
        component_idx: int | None = None,
    ) -> None:
        """Update encoder bias.

        Args:
            update_parameter_indices: Indices of the bias features to update.
            updated_bias_features: Updated bias features for just these indices.
            component_idx: Component index to update.

        Raises:
            ValueError: If there are multiple components and `component_idx` is not specified.
        """
        if update_parameter_indices.numel() == 0:
            return

        with torch.no_grad():
            if component_idx is None:
                if self._n_components is not None:
                    error_message = "component_idx must be specified when n_components is not None"
                    raise ValueError(error_message)

                self.bias[update_parameter_indices] = updated_bias_features
            else:
                self.bias[component_idx, update_parameter_indices] = updated_bias_features

    def extra_repr(self) -> str:
        """String extra representation of the module."""
        return (
            f"input_features={self._input_features}, "
            f"learnt_features={self._learnt_features}, "
            f"n_components={self._n_components}"
        )
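
Similarly, a minimal usage sketch of the tanh(ReLU()) encoder (again not part of the commit; the import path and sizes are assumptions). Because the activation is tanh composed with ReLU, the learned feature activations are bounded to [0, 1), which is the main behavioural difference from a plain ReLU encoder.

# Sketch only: import path and sizes are assumptions, not taken from the commit.
import torch

from sparse_autoencoder.autoencoder.components.tanh_encoder import TanhEncoder

encoder = TanhEncoder(input_features=64, learnt_features=512, n_components=None)

source_activations = torch.rand(8, 64)   # (batch, input_features)
learned = encoder(source_activations)    # (batch, learnt_features) == (8, 512)

# tanh(ReLU(z)) maps pre-activations into [0, 1): negatives become 0, positives saturate towards 1.
assert learned.shape == (8, 512)
assert (learned >= 0).all() and (learned <= 1).all()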
