6
6
from jaxtyping import Bool , Float , Int
7
7
from pydantic import Field , NonNegativeInt , PositiveInt , validate_call
8
8
import torch
9
- from torch import Tensor
9
+ from torch import Tensor , distributed
10
10
from torch .distributed import get_world_size , group
11
+ from torch .nn import Parameter
11
12
from torchmetrics import Metric
13
+ from torchmetrics .utilities import dim_zero_cat
12
14
13
15
from sparse_autoencoder .activation_resampler .utils .component_slice_tensor import (
14
16
get_component_slice_tensor ,
@@ -83,26 +85,21 @@ class ActivationResampler(Metric):
83
85
84
86
# Tracking
85
87
_n_activations_seen_process : int
86
- _n_times_resampled_process : int
88
+ _n_times_resampled : int
87
89
88
90
# Settings
89
91
_n_components : int
90
92
_threshold_is_dead_portion_fires : float
91
93
_max_n_resamples : int
94
+ resample_interval : int
92
95
resample_interval_process : int
93
96
start_collecting_neuron_activity_process : int
94
97
start_collecting_loss_process : int
95
98
96
- # Encoder weight reference
97
- _encoder_weight : Float [Tensor , Axis .names (Axis .LEARNT_FEATURE , Axis .INPUT_OUTPUT_FEATURE )]
98
-
99
99
@validate_call
100
100
def __init__ (
101
101
self ,
102
102
n_learned_features : PositiveInt ,
103
- encoder_weight_reference : Float [
104
- Tensor , Axis .names (Axis .LEARNT_FEATURE , Axis .INPUT_OUTPUT_FEATURE )
105
- ],
106
103
n_components : NonNegativeInt = 1 ,
107
104
resample_interval : PositiveInt = 200_000_000 ,
108
105
max_n_resamples : NonNegativeInt = 4 ,
@@ -116,7 +113,6 @@ def __init__(
116
113
117
114
Args:
118
115
n_learned_features: Number of learned features
119
- encoder_weight_reference: Reference to the encoder weight tensor.
120
116
n_components: Number of components that the SAE is being trained on.
121
117
resample_interval: Interval in number of autoencoder input activation vectors trained
122
118
on, before resampling.
@@ -133,34 +129,41 @@ def __init__(
133
129
Raises:
134
130
ValueError: If any of the arguments are invalid (e.g. negative integers).
135
131
"""
136
- super ().__init__ (sync_on_compute = False )
132
+ super ().__init__ (
133
+ sync_on_compute = False # Manually sync instead in compute, where needed
134
+ )
137
135
138
136
# Error handling
139
137
if n_activations_activity_collate > resample_interval :
140
138
error_message = "Must collate less activation activity than the resample interval."
141
139
raise ValueError (error_message )
142
140
143
141
# Number of processes
144
- world_size = get_world_size (group .WORLD )
142
+ world_size = (
143
+ get_world_size (group .WORLD )
144
+ if distributed .is_available () and distributed .is_initialized ()
145
+ else 1
146
+ )
145
147
process_resample_dataset_size = resample_dataset_size // world_size
146
148
147
149
# State setup (note half precision is used as it's sufficient for resampling purposes)
148
150
self .add_state (
149
151
"_neuron_fired_count" ,
150
- torch .zeros ((n_components , n_learned_features ), dtype = torch .int ),
152
+ torch .zeros ((n_components , n_learned_features ), dtype = torch .bfloat16 ),
151
153
"sum" ,
152
154
)
153
155
self .add_state ("_loss" , [], "cat" )
154
156
self .add_state ("_input_activations" , [], "cat" )
155
157
156
158
# Tracking
157
159
self ._n_activations_seen_process = 0
158
- self ._n_times_resampled_process = 0
160
+ self ._n_times_resampled = 0
159
161
160
162
# Settings
161
163
self ._n_components = n_components
162
164
self ._threshold_is_dead_portion_fires = threshold_is_dead_portion_fires
163
165
self ._max_n_resamples = max_n_resamples
166
+ self .resample_interval = resample_interval
164
167
self .resample_interval_process = resample_interval // world_size
165
168
self .start_collecting_neuron_activity_process = (
166
169
self .resample_interval_process - n_activations_activity_collate // world_size
@@ -169,25 +172,24 @@ def __init__(
169
172
self .resample_interval_process - process_resample_dataset_size
170
173
)
171
174
172
- # Encoder weight reference
173
- self ._encoder_weight = encoder_weight_reference
174
-
175
175
def update (
176
176
self ,
177
177
input_activations : Float [
178
178
Tensor , Axis .names (Axis .BATCH , Axis .COMPONENT_OPTIONAL , Axis .INPUT_OUTPUT_FEATURE )
179
179
],
180
180
learned_activations : Float [
181
- Tensor , Axis .names (Axis .BATCH , Axis .COMPONENT , Axis .LEARNT_FEATURE )
181
+ Tensor , Axis .names (Axis .BATCH , Axis .COMPONENT_OPTIONAL , Axis .LEARNT_FEATURE )
182
182
],
183
183
loss : Float [Tensor , Axis .names (Axis .BATCH , Axis .COMPONENT_OPTIONAL )],
184
+ encoder_weight_reference : Parameter ,
184
185
) -> None :
185
186
"""Update the collated data from forward passes.
186
187
187
188
Args:
188
189
input_activations: Input activations to the SAE.
189
190
learned_activations: Learned activations from the SAE.
190
191
loss: Loss per input activation.
192
+ encoder_weight_reference: Reference to the SAE encoder weight tensor.
191
193
192
194
Raises:
193
195
TypeError: If the loss or input activations are not lists (e.g. from unsync having not
@@ -209,6 +211,7 @@ def update(
209
211
self ._input_activations .append (input_activations .to (dtype = torch .bfloat16 ))
210
212
211
213
self ._n_activations_seen_process += len (learned_activations )
214
+ self ._encoder_weight = encoder_weight_reference
212
215
213
216
def _get_dead_neuron_indices (
214
217
self ,
@@ -218,34 +221,26 @@ def _get_dead_neuron_indices(
218
221
Identifies any neurons that have fired less than the threshold portion of the collated
219
222
sample size.
220
223
221
- Example:
222
- >>> resampler = ActivationResampler(n_learned_features=6, n_components=2)
223
- >>> resampler._collated_neuron_activity = torch.tensor(
224
- ... [[1, 1, 0, 0, 1, 1], [1, 1, 1, 1, 1, 0]]
225
- ... )
226
- >>> resampler._get_dead_neuron_indices()
227
- [tensor([2, 3]), tensor([5])]
228
-
229
224
Returns:
230
225
List of dead neuron indices for each component.
231
226
232
227
Raises:
233
228
ValueError: If no neuron activity has been collated yet.
234
229
"""
235
230
# Check we have already collated some neuron activity
236
- if torch .all (self ._collated_neuron_activity == 0 ):
231
+ if torch .all (self ._neuron_fired_count == 0 ):
237
232
error_message = "Cannot get dead neuron indices without neuron activity."
238
233
raise ValueError (error_message )
239
234
240
235
# Find any neurons that fire less than the threshold portion of times
241
236
threshold_is_dead_n_fires : int = int (
242
- self ._n_activations_collated_since_last_resample * self ._threshold_is_dead_portion_fires
237
+ self .resample_interval * self ._threshold_is_dead_portion_fires
243
238
)
244
239
245
240
return [
246
- torch .where (self ._collated_neuron_activity [component_idx ] <= threshold_is_dead_n_fires )[
247
- 0
248
- ]. to ( dtype = torch . int64 )
241
+ torch .where (self ._neuron_fired_count [component_idx ] <= threshold_is_dead_n_fires )[0 ]. to (
242
+ dtype = torch . int
243
+ )
249
244
for component_idx in range (self ._n_components )
250
245
]
251
246
@@ -359,7 +354,7 @@ def sample_input(
359
354
@staticmethod
360
355
def renormalize_and_scale (
361
356
sampled_input : Float [Tensor , Axis .names (Axis .DEAD_FEATURE , Axis .INPUT_OUTPUT_FEATURE )],
362
- neuron_activity : Int [Tensor , Axis .names (Axis .LEARNT_FEATURE )],
357
+ neuron_activity : Float [Tensor , Axis .names (Axis .LEARNT_FEATURE )],
363
358
encoder_weight : Float [Tensor , Axis .names (Axis .LEARNT_FEATURE , Axis .INPUT_OUTPUT_FEATURE )],
364
359
) -> Float [Tensor , Axis .names (Axis .DEAD_FEATURE , Axis .INPUT_OUTPUT_FEATURE )]:
365
360
"""Renormalize and scale the resampled dictionary vectors.
@@ -371,7 +366,7 @@ def renormalize_and_scale(
371
366
>>> from torch.nn import Parameter
372
367
>>> _seed = torch.manual_seed(0) # For reproducibility in example
373
368
>>> sampled_input = torch.tensor([[3.0, 4.0]])
374
- >>> neuron_activity = torch.tensor([3, 0, 5, 0, 1, 3])
369
+ >>> neuron_activity = torch.tensor([3.0 , 0, 5, 0, 1, 3])
375
370
>>> encoder_weight = Parameter(torch.ones((6, 2)))
376
371
>>> rescaled_input = ActivationResampler.renormalize_and_scale(
377
372
... sampled_input,
@@ -428,10 +423,6 @@ def compute(self) -> list[ParameterUpdateResults] | None:
428
423
Returns:
429
424
A list of parameter update results (for each component that the SAE is being trained
430
425
on), if an update is needed.
431
-
432
- Raises:
433
- TypeError: If the loss or input activations are not lists (e.g. from unsync having not
434
- been called).
435
426
"""
436
427
# Resample if needed
437
428
if self ._n_activations_seen_process >= self .resample_interval_process :
@@ -441,10 +432,8 @@ def compute(self) -> list[ParameterUpdateResults] | None:
441
432
442
433
# Sync & typecast
443
434
self .sync ()
444
- if not isinstance (self ._loss , Tensor ) or not isinstance (
445
- self ._input_activations , Tensor
446
- ):
447
- raise TypeError
435
+ loss = dim_zero_cat (self ._loss )
436
+ input_activations = dim_zero_cat (self ._input_activations )
448
437
449
438
dead_neuron_indices : list [
450
439
Int [Tensor , Axis .names (Axis .LEARNT_FEATURE_IDX )]
@@ -454,14 +443,14 @@ def compute(self) -> list[ParameterUpdateResults] | None:
454
443
# square of the autoencoder's loss on that input.
455
444
sample_probabilities : Float [
456
445
Tensor , Axis .names (Axis .BATCH , Axis .COMPONENT_OPTIONAL )
457
- ] = self .assign_sampling_probabilities (self . _loss )
446
+ ] = self .assign_sampling_probabilities (loss )
458
447
459
448
# For each dead neuron sample an input according to these probabilities.
460
449
sampled_input : list [
461
450
Float [Tensor , Axis .names (Axis .DEAD_FEATURE , Axis .INPUT_OUTPUT_FEATURE )]
462
451
] = self .sample_input (
463
452
sample_probabilities ,
464
- self . _input_activations ,
453
+ input_activations ,
465
454
[len (dead ) for dead in dead_neuron_indices ],
466
455
)
467
456
@@ -505,7 +494,7 @@ def compute(self) -> list[ParameterUpdateResults] | None:
505
494
)
506
495
507
496
# Reset
508
- self .unsync ()
497
+ self .unsync (should_unsync = self . _is_synced )
509
498
self .reset ()
510
499
511
500
return parameter_update_results
@@ -518,16 +507,18 @@ def forward( # type: ignore[override]
518
507
Tensor , Axis .names (Axis .BATCH , Axis .COMPONENT_OPTIONAL , Axis .INPUT_OUTPUT_FEATURE )
519
508
],
520
509
learned_activations : Float [
521
- Tensor , Axis .names (Axis .BATCH , Axis .COMPONENT , Axis .LEARNT_FEATURE )
510
+ Tensor , Axis .names (Axis .BATCH , Axis .COMPONENT_OPTIONAL , Axis .LEARNT_FEATURE )
522
511
],
523
512
loss : Float [Tensor , Axis .names (Axis .BATCH , Axis .COMPONENT_OPTIONAL )],
513
+ encoder_weight_reference : Parameter ,
524
514
) -> list [ParameterUpdateResults ] | None :
525
515
"""Step the resampler, collating neuron activity and resampling if necessary.
526
516
527
517
Args:
528
518
input_activations: Input activations to the SAE.
529
519
learned_activations: Learned activations from the SAE.
530
520
loss: Loss per input activation.
521
+ encoder_weight_reference: Reference to the SAE encoder weight tensor.
531
522
532
523
Returns:
533
524
Parameter update results (for each component that the SAE is being trained on) if
@@ -537,18 +528,24 @@ def forward( # type: ignore[override]
537
528
if self ._n_times_resampled >= self ._max_n_resamples :
538
529
return None
539
530
540
- super ().forward (
541
- input_activations = input_activations , learned_activations = learned_activations , loss = loss
531
+ self .update (
532
+ input_activations = input_activations ,
533
+ learned_activations = learned_activations ,
534
+ loss = loss ,
535
+ encoder_weight_reference = encoder_weight_reference ,
542
536
)
543
537
544
- def __str__ (self ) -> str :
545
- """Return a string representation of the activation resampler."""
546
- return (
547
- f"ActivationResampler("
548
- f"n_components={ self ._n_components } , "
549
- f"threshold_is_dead_portion_fires={ self ._threshold_is_dead_portion_fires } , "
550
- f"max_n_resamples={ self ._max_n_resamples } , "
551
- f"resample_interval={ self .resample_interval_process } , "
552
- f"start_collecting_neuron_activity={ self .start_collecting_neuron_activity_process } , "
553
- f"start_collecting_loss={ self .start_collecting_loss_process } "
554
- )
538
+ return self .compute ()
539
+
540
def reset(self) -> None:
    """Clear the collated resampling data and count the completed resample.

    Warning:
        This is intended to run only after forward/compute has produced parameters to update
        (i.e. a resample was due), which is why the resample counter is advanced here rather
        than left untouched.
    """
    # Record that one more resample has taken place.
    self._n_times_resampled = self._n_times_resampled + 1

    # Discard the collated samples; both states go back to empty lists.
    self._input_activations = []
    self._loss = []

    # Replace (rather than mutate) the fired-count tensor with zeros of matching
    # shape and dtype.
    zeroed_fired_count = torch.zeros_like(self._neuron_fired_count)
    self._neuron_fired_count = zeroed_fired_count

    # Start counting seen activations from scratch for the next interval.
    self._n_activations_seen_process = 0
0 commit comments