Skip to content

Commit fa282f1

Browse files
committed
Fix resampler
1 parent b40732c commit fa282f1

File tree

4 files changed

+137
-187
lines changed

4 files changed

+137
-187
lines changed

sparse_autoencoder/activation_resampler/activation_resampler.py

Lines changed: 54 additions & 57 deletions
Original file line number | Diff line number | Diff line change
@@ -6,9 +6,11 @@
66
from jaxtyping import Bool, Float, Int
77
from pydantic import Field, NonNegativeInt, PositiveInt, validate_call
88
import torch
9-
from torch import Tensor
9+
from torch import Tensor, distributed
1010
from torch.distributed import get_world_size, group
11+
from torch.nn import Parameter
1112
from torchmetrics import Metric
13+
from torchmetrics.utilities import dim_zero_cat
1214

1315
from sparse_autoencoder.activation_resampler.utils.component_slice_tensor import (
1416
get_component_slice_tensor,
@@ -83,26 +85,21 @@ class ActivationResampler(Metric):
8385

8486
# Tracking
8587
_n_activations_seen_process: int
86-
_n_times_resampled_process: int
88+
_n_times_resampled: int
8789

8890
# Settings
8991
_n_components: int
9092
_threshold_is_dead_portion_fires: float
9193
_max_n_resamples: int
94+
resample_interval: int
9295
resample_interval_process: int
9396
start_collecting_neuron_activity_process: int
9497
start_collecting_loss_process: int
9598

96-
# Encoder weight reference
97-
_encoder_weight: Float[Tensor, Axis.names(Axis.LEARNT_FEATURE, Axis.INPUT_OUTPUT_FEATURE)]
98-
9999
@validate_call
100100
def __init__(
101101
self,
102102
n_learned_features: PositiveInt,
103-
encoder_weight_reference: Float[
104-
Tensor, Axis.names(Axis.LEARNT_FEATURE, Axis.INPUT_OUTPUT_FEATURE)
105-
],
106103
n_components: NonNegativeInt = 1,
107104
resample_interval: PositiveInt = 200_000_000,
108105
max_n_resamples: NonNegativeInt = 4,
@@ -116,7 +113,6 @@ def __init__(
116113
117114
Args:
118115
n_learned_features: Number of learned features
119-
encoder_weight_reference: Reference to the encoder weight tensor.
120116
n_components: Number of components that the SAE is being trained on.
121117
resample_interval: Interval in number of autoencoder input activation vectors trained
122118
on, before resampling.
@@ -133,34 +129,41 @@ def __init__(
133129
Raises:
134130
ValueError: If any of the arguments are invalid (e.g. negative integers).
135131
"""
136-
super().__init__(sync_on_compute=False)
132+
super().__init__(
133+
sync_on_compute=False # Manually sync instead in compute, where needed
134+
)
137135

138136
# Error handling
139137
if n_activations_activity_collate > resample_interval:
140138
error_message = "Must collate less activation activity than the resample interval."
141139
raise ValueError(error_message)
142140

143141
# Number of processes
144-
world_size = get_world_size(group.WORLD)
142+
world_size = (
143+
get_world_size(group.WORLD)
144+
if distributed.is_available() and distributed.is_initialized()
145+
else 1
146+
)
145147
process_resample_dataset_size = resample_dataset_size // world_size
146148

147149
# State setup (note half precision is used as it's sufficient for resampling purposes)
148150
self.add_state(
149151
"_neuron_fired_count",
150-
torch.zeros((n_components, n_learned_features), dtype=torch.int),
152+
torch.zeros((n_components, n_learned_features), dtype=torch.bfloat16),
151153
"sum",
152154
)
153155
self.add_state("_loss", [], "cat")
154156
self.add_state("_input_activations", [], "cat")
155157

156158
# Tracking
157159
self._n_activations_seen_process = 0
158-
self._n_times_resampled_process = 0
160+
self._n_times_resampled = 0
159161

160162
# Settings
161163
self._n_components = n_components
162164
self._threshold_is_dead_portion_fires = threshold_is_dead_portion_fires
163165
self._max_n_resamples = max_n_resamples
166+
self.resample_interval = resample_interval
164167
self.resample_interval_process = resample_interval // world_size
165168
self.start_collecting_neuron_activity_process = (
166169
self.resample_interval_process - n_activations_activity_collate // world_size
@@ -169,25 +172,24 @@ def __init__(
169172
self.resample_interval_process - process_resample_dataset_size
170173
)
171174

172-
# Encoder weight reference
173-
self._encoder_weight = encoder_weight_reference
174-
175175
def update(
176176
self,
177177
input_activations: Float[
178178
Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
179179
],
180180
learned_activations: Float[
181-
Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT, Axis.LEARNT_FEATURE)
181+
Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)
182182
],
183183
loss: Float[Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL)],
184+
encoder_weight_reference: Parameter,
184185
) -> None:
185186
"""Update the collated data from forward passes.
186187
187188
Args:
188189
input_activations: Input activations to the SAE.
189190
learned_activations: Learned activations from the SAE.
190191
loss: Loss per input activation.
192+
encoder_weight_reference: Reference to the SAE encoder weight tensor.
191193
192194
Raises:
193195
TypeError: If the loss or input activations are not lists (e.g. from unsync having not
@@ -209,6 +211,7 @@ def update(
209211
self._input_activations.append(input_activations.to(dtype=torch.bfloat16))
210212

211213
self._n_activations_seen_process += len(learned_activations)
214+
self._encoder_weight = encoder_weight_reference
212215

213216
def _get_dead_neuron_indices(
214217
self,
@@ -218,34 +221,26 @@ def _get_dead_neuron_indices(
218221
Identifies any neurons that have fired less than the threshold portion of the collated
219222
sample size.
220223
221-
Example:
222-
>>> resampler = ActivationResampler(n_learned_features=6, n_components=2)
223-
>>> resampler._collated_neuron_activity = torch.tensor(
224-
... [[1, 1, 0, 0, 1, 1], [1, 1, 1, 1, 1, 0]]
225-
... )
226-
>>> resampler._get_dead_neuron_indices()
227-
[tensor([2, 3]), tensor([5])]
228-
229224
Returns:
230225
List of dead neuron indices for each component.
231226
232227
Raises:
233228
ValueError: If no neuron activity has been collated yet.
234229
"""
235230
# Check we have already collated some neuron activity
236-
if torch.all(self._collated_neuron_activity == 0):
231+
if torch.all(self._neuron_fired_count == 0):
237232
error_message = "Cannot get dead neuron indices without neuron activity."
238233
raise ValueError(error_message)
239234

240235
# Find any neurons that fire less than the threshold portion of times
241236
threshold_is_dead_n_fires: int = int(
242-
self._n_activations_collated_since_last_resample * self._threshold_is_dead_portion_fires
237+
self.resample_interval * self._threshold_is_dead_portion_fires
243238
)
244239

245240
return [
246-
torch.where(self._collated_neuron_activity[component_idx] <= threshold_is_dead_n_fires)[
247-
0
248-
].to(dtype=torch.int64)
241+
torch.where(self._neuron_fired_count[component_idx] <= threshold_is_dead_n_fires)[0].to(
242+
dtype=torch.int
243+
)
249244
for component_idx in range(self._n_components)
250245
]
251246

@@ -359,7 +354,7 @@ def sample_input(
359354
@staticmethod
360355
def renormalize_and_scale(
361356
sampled_input: Float[Tensor, Axis.names(Axis.DEAD_FEATURE, Axis.INPUT_OUTPUT_FEATURE)],
362-
neuron_activity: Int[Tensor, Axis.names(Axis.LEARNT_FEATURE)],
357+
neuron_activity: Float[Tensor, Axis.names(Axis.LEARNT_FEATURE)],
363358
encoder_weight: Float[Tensor, Axis.names(Axis.LEARNT_FEATURE, Axis.INPUT_OUTPUT_FEATURE)],
364359
) -> Float[Tensor, Axis.names(Axis.DEAD_FEATURE, Axis.INPUT_OUTPUT_FEATURE)]:
365360
"""Renormalize and scale the resampled dictionary vectors.
@@ -371,7 +366,7 @@ def renormalize_and_scale(
371366
>>> from torch.nn import Parameter
372367
>>> _seed = torch.manual_seed(0) # For reproducibility in example
373368
>>> sampled_input = torch.tensor([[3.0, 4.0]])
374-
>>> neuron_activity = torch.tensor([3, 0, 5, 0, 1, 3])
369+
>>> neuron_activity = torch.tensor([3.0, 0, 5, 0, 1, 3])
375370
>>> encoder_weight = Parameter(torch.ones((6, 2)))
376371
>>> rescaled_input = ActivationResampler.renormalize_and_scale(
377372
... sampled_input,
@@ -428,10 +423,6 @@ def compute(self) -> list[ParameterUpdateResults] | None:
428423
Returns:
429424
A list of parameter update results (for each component that the SAE is being trained
430425
on), if an update is needed.
431-
432-
Raises:
433-
TypeError: If the loss or input activations are not lists (e.g. from unsync having not
434-
been called).
435426
"""
436427
# Resample if needed
437428
if self._n_activations_seen_process >= self.resample_interval_process:
@@ -441,10 +432,8 @@ def compute(self) -> list[ParameterUpdateResults] | None:
441432

442433
# Sync & typecast
443434
self.sync()
444-
if not isinstance(self._loss, Tensor) or not isinstance(
445-
self._input_activations, Tensor
446-
):
447-
raise TypeError
435+
loss = dim_zero_cat(self._loss)
436+
input_activations = dim_zero_cat(self._input_activations)
448437

449438
dead_neuron_indices: list[
450439
Int[Tensor, Axis.names(Axis.LEARNT_FEATURE_IDX)]
@@ -454,14 +443,14 @@ def compute(self) -> list[ParameterUpdateResults] | None:
454443
# square of the autoencoder's loss on that input.
455444
sample_probabilities: Float[
456445
Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL)
457-
] = self.assign_sampling_probabilities(self._loss)
446+
] = self.assign_sampling_probabilities(loss)
458447

459448
# For each dead neuron sample an input according to these probabilities.
460449
sampled_input: list[
461450
Float[Tensor, Axis.names(Axis.DEAD_FEATURE, Axis.INPUT_OUTPUT_FEATURE)]
462451
] = self.sample_input(
463452
sample_probabilities,
464-
self._input_activations,
453+
input_activations,
465454
[len(dead) for dead in dead_neuron_indices],
466455
)
467456

@@ -505,7 +494,7 @@ def compute(self) -> list[ParameterUpdateResults] | None:
505494
)
506495

507496
# Reset
508-
self.unsync()
497+
self.unsync(should_unsync=self._is_synced)
509498
self.reset()
510499

511500
return parameter_update_results
@@ -518,16 +507,18 @@ def forward( # type: ignore[override]
518507
Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
519508
],
520509
learned_activations: Float[
521-
Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT, Axis.LEARNT_FEATURE)
510+
Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)
522511
],
523512
loss: Float[Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL)],
513+
encoder_weight_reference: Parameter,
524514
) -> list[ParameterUpdateResults] | None:
525515
"""Step the resampler, collating neuron activity and resampling if necessary.
526516
527517
Args:
528518
input_activations: Input activations to the SAE.
529519
learned_activations: Learned activations from the SAE.
530520
loss: Loss per input activation.
521+
encoder_weight_reference: Reference to the SAE encoder weight tensor.
531522
532523
Returns:
533524
Parameter update results (for each component that the SAE is being trained on) if
@@ -537,18 +528,24 @@ def forward( # type: ignore[override]
537528
if self._n_times_resampled >= self._max_n_resamples:
538529
return None
539530

540-
super().forward(
541-
input_activations=input_activations, learned_activations=learned_activations, loss=loss
531+
self.update(
532+
input_activations=input_activations,
533+
learned_activations=learned_activations,
534+
loss=loss,
535+
encoder_weight_reference=encoder_weight_reference,
542536
)
543537

544-
def __str__(self) -> str:
    """Return a string representation of the activation resampler.

    The representation lists the resampler settings as ``name=value`` pairs inside
    ``ActivationResampler(...)``. The ``resample_interval``, ``start_collecting_neuron_activity``
    and ``start_collecting_loss`` entries show the per-process attribute values (the attributes
    with the ``_process`` suffix).

    Returns:
        Human readable summary of the resampler settings.
    """
    settings = (
        ("n_components", self._n_components),
        ("threshold_is_dead_portion_fires", self._threshold_is_dead_portion_fires),
        ("max_n_resamples", self._max_n_resamples),
        ("resample_interval", self.resample_interval_process),
        ("start_collecting_neuron_activity", self.start_collecting_neuron_activity_process),
        ("start_collecting_loss", self.start_collecting_loss_process),
    )
    formatted_settings = ", ".join(f"{name}={value}" for name, value in settings)

    # Close the bracket opened after the class name (previously left unbalanced).
    return f"ActivationResampler({formatted_settings})"
538+
return self.compute()
539+
540+
def reset(self) -> None:
    """Clear the collated data and count one more completed resample.

    The loss and input activation lists are emptied, the neuron fired count is replaced with
    zeros of the same shape and dtype, and the count of activations seen by this process goes
    back to zero.

    Warning:
        This is only called when forward/compute has returned parameters to update (i.e.
        resampling is due). Every call increments the number of times resampled, so calling it
        at any other point is also counted as a resample.
    """
    # Drop everything collated since the last resample.
    self._loss, self._input_activations = [], []
    self._neuron_fired_count = torch.zeros_like(self._neuron_fired_count)
    self._n_activations_seen_process = 0

    # Record that a resample has taken place.
    self._n_times_resampled = self._n_times_resampled + 1

0 commit comments

Comments
 (0)