Fix resampler
alan-cooney committed Feb 5, 2024
1 parent b40732c commit fa282f1
Showing 4 changed files with 137 additions and 187 deletions.
111 changes: 54 additions & 57 deletions sparse_autoencoder/activation_resampler/activation_resampler.py
@@ -6,9 +6,11 @@
from jaxtyping import Bool, Float, Int
from pydantic import Field, NonNegativeInt, PositiveInt, validate_call
import torch
from torch import Tensor
from torch import Tensor, distributed
from torch.distributed import get_world_size, group
from torch.nn import Parameter
from torchmetrics import Metric
from torchmetrics.utilities import dim_zero_cat

from sparse_autoencoder.activation_resampler.utils.component_slice_tensor import (
get_component_slice_tensor,
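Two of the new imports matter below: dim_zero_cat concatenates torchmetrics list states at compute time, and distributed enables a guarded world-size lookup. A minimal sketch of the list-state pattern with a toy metric (the class is illustrative, not from this repository):

import torch
from torchmetrics import Metric
from torchmetrics.utilities import dim_zero_cat

class LossCollector(Metric):
    """Toy metric showing the list-state pattern used by the resampler."""

    def __init__(self) -> None:
        super().__init__()
        # "cat" list states are appended to in update and concatenated later.
        self.add_state("_loss", [], "cat")

    def update(self, loss: torch.Tensor) -> None:
        # loss: a 1-D batch of per-item losses.
        self._loss.append(loss)

    def compute(self) -> torch.Tensor:
        # dim_zero_cat accepts both the unsynced list form and the
        # already-concatenated tensor form produced after sync().
        return dim_zero_cat(self._loss)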
@@ -83,26 +85,21 @@ class ActivationResampler(Metric):

# Tracking
_n_activations_seen_process: int
_n_times_resampled_process: int
_n_times_resampled: int

# Settings
_n_components: int
_threshold_is_dead_portion_fires: float
_max_n_resamples: int
resample_interval: int
resample_interval_process: int
start_collecting_neuron_activity_process: int
start_collecting_loss_process: int

# Encoder weight reference
_encoder_weight: Float[Tensor, Axis.names(Axis.LEARNT_FEATURE, Axis.INPUT_OUTPUT_FEATURE)]

@validate_call
def __init__(
self,
n_learned_features: PositiveInt,
encoder_weight_reference: Float[
Tensor, Axis.names(Axis.LEARNT_FEATURE, Axis.INPUT_OUTPUT_FEATURE)
],
n_components: NonNegativeInt = 1,
resample_interval: PositiveInt = 200_000_000,
max_n_resamples: NonNegativeInt = 4,
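After this change the encoder weight is supplied per update/forward call rather than at construction. A minimal construction sketch (the feature count is illustrative; the interval and resample defaults are the ones shown above):

resampler = ActivationResampler(
    n_learned_features=1024,  # illustrative size
    n_components=2,
    resample_interval=200_000_000,  # default shown above
    max_n_resamples=4,  # default shown above
)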
@@ -116,7 +113,6 @@ def __init__(
Args:
n_learned_features: Number of learned features
encoder_weight_reference: Reference to the encoder weight tensor.
n_components: Number of components that the SAE is being trained on.
resample_interval: Interval between resamples, measured in the number of autoencoder
input activation vectors trained on.
@@ -133,34 +129,41 @@ def __init__(
Raises:
ValueError: If any of the arguments are invalid (e.g. negative integers).
"""
super().__init__(sync_on_compute=False)
super().__init__(
sync_on_compute=False # Manually sync instead in compute, where needed
)

# Error handling
if n_activations_activity_collate > resample_interval:
error_message = "Must collate less activation activity than the resample interval."
raise ValueError(error_message)

# Number of processes
world_size = get_world_size(group.WORLD)
world_size = (
get_world_size(group.WORLD)
if distributed.is_available() and distributed.is_initialized()
else 1
)
process_resample_dataset_size = resample_dataset_size // world_size

# State setup (note half precision is used as it's sufficient for resampling purposes)
self.add_state(
"_neuron_fired_count",
torch.zeros((n_components, n_learned_features), dtype=torch.int),
torch.zeros((n_components, n_learned_features), dtype=torch.bfloat16),
"sum",
)
self.add_state("_loss", [], "cat")
self.add_state("_input_activations", [], "cat")

# Tracking
self._n_activations_seen_process = 0
self._n_times_resampled_process = 0
self._n_times_resampled = 0

# Settings
self._n_components = n_components
self._threshold_is_dead_portion_fires = threshold_is_dead_portion_fires
self._max_n_resamples = max_n_resamples
self.resample_interval = resample_interval
self.resample_interval_process = resample_interval // world_size
self.start_collecting_neuron_activity_process = (
self.resample_interval_process - n_activations_activity_collate // world_size
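The central fix is visible above: the world size is read from torch.distributed only when a process group is initialised, so single-process training no longer crashes in __init__. The same guard as a standalone helper (a sketch; the helper name is not from the source):

import torch.distributed as dist

def safe_world_size() -> int:
    # Only query the process group when distributed training is live;
    # otherwise behave as a single process.
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1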
@@ -169,25 +172,24 @@ def __init__(
self.resample_interval_process - process_resample_dataset_size
)

# Encoder weight reference
self._encoder_weight = encoder_weight_reference

def update(
self,
input_activations: Float[
Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
],
learned_activations: Float[
Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT, Axis.LEARNT_FEATURE)
Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)
],
loss: Float[Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL)],
encoder_weight_reference: Parameter,
) -> None:
"""Update the collated data from forward passes.
Args:
input_activations: Input activations to the SAE.
learned_activations: Learned activations from the SAE.
loss: Loss per input activation.
encoder_weight_reference: Reference to the SAE encoder weight tensor.
Raises:
TypeError: If the loss or input activations are not lists (e.g. from unsync having not
been called).
@@ -209,6 +211,7 @@ def update(
self._input_activations.append(input_activations.to(dtype=torch.bfloat16))

self._n_activations_seen_process += len(learned_activations)
self._encoder_weight = encoder_weight_reference

def _get_dead_neuron_indices(
self,
@@ -218,34 +221,26 @@ def _get_dead_neuron_indices(
Identifies any neurons that have fired in less than the threshold portion of the
activations seen over the resample interval.
Example:
>>> resampler = ActivationResampler(n_learned_features=6, n_components=2)
>>> resampler._neuron_fired_count = torch.tensor(
... [[1, 1, 0, 0, 1, 1], [1, 1, 1, 1, 1, 0]]
... )
>>> resampler._get_dead_neuron_indices()
[tensor([2, 3]), tensor([5])]
Returns:
List of dead neuron indices for each component.
Raises:
ValueError: If no neuron activity has been collated yet.
"""
# Check we have already collated some neuron activity
if torch.all(self._collated_neuron_activity == 0):
if torch.all(self._neuron_fired_count == 0):
error_message = "Cannot get dead neuron indices without neuron activity."
raise ValueError(error_message)

# Find any neurons that fire less than the threshold portion of times
threshold_is_dead_n_fires: int = int(
self._n_activations_collated_since_last_resample * self._threshold_is_dead_portion_fires
self.resample_interval * self._threshold_is_dead_portion_fires
)

return [
torch.where(self._collated_neuron_activity[component_idx] <= threshold_is_dead_n_fires)[
0
].to(dtype=torch.int64)
torch.where(self._neuron_fired_count[component_idx] <= threshold_is_dead_n_fires)[0].to(
dtype=torch.int
)
for component_idx in range(self._n_components)
]

@@ -359,7 +354,7 @@ def sample_input(
@staticmethod
def renormalize_and_scale(
sampled_input: Float[Tensor, Axis.names(Axis.DEAD_FEATURE, Axis.INPUT_OUTPUT_FEATURE)],
neuron_activity: Int[Tensor, Axis.names(Axis.LEARNT_FEATURE)],
neuron_activity: Float[Tensor, Axis.names(Axis.LEARNT_FEATURE)],
encoder_weight: Float[Tensor, Axis.names(Axis.LEARNT_FEATURE, Axis.INPUT_OUTPUT_FEATURE)],
) -> Float[Tensor, Axis.names(Axis.DEAD_FEATURE, Axis.INPUT_OUTPUT_FEATURE)]:
"""Renormalize and scale the resampled dictionary vectors.
@@ -371,7 +366,7 @@ def renormalize_and_scale(
>>> from torch.nn import Parameter
>>> _seed = torch.manual_seed(0) # For reproducibility in example
>>> sampled_input = torch.tensor([[3.0, 4.0]])
>>> neuron_activity = torch.tensor([3, 0, 5, 0, 1, 3])
>>> neuron_activity = torch.tensor([3.0, 0, 5, 0, 1, 3])
>>> encoder_weight = Parameter(torch.ones((6, 2)))
>>> rescaled_input = ActivationResampler.renormalize_and_scale(
... sampled_input,
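The doctest is truncated here. For orientation, the conventional resampling rescale is: unit-normalise each sampled input, then scale it to a fraction of the mean encoder-row norm over alive neurons. A sketch under that assumption (the 0.2 factor follows the standard resampling procedure and is not read from this diff):

import torch

def renormalize_and_scale_sketch(
    sampled_input: torch.Tensor,    # (n_dead, n_input_features)
    neuron_activity: torch.Tensor,  # (n_learned_features,) fire counts
    encoder_weight: torch.Tensor,   # (n_learned_features, n_input_features)
    scale: float = 0.2,             # assumed factor, per the standard procedure
) -> torch.Tensor:
    # Unit-normalise each sampled input vector.
    unit_inputs = torch.nn.functional.normalize(sampled_input, dim=-1)
    # Mean L2 norm of encoder rows belonging to alive neurons.
    alive_rows = encoder_weight[neuron_activity > 0]
    mean_alive_norm = alive_rows.norm(dim=-1).mean()
    return unit_inputs * (mean_alive_norm * scale)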
@@ -428,10 +423,6 @@ def compute(self) -> list[ParameterUpdateResults] | None:
Returns:
A list of parameter update results (for each component that the SAE is being trained
on), if an update is needed.
Raises:
TypeError: If the loss or input activations are not lists (e.g. from unsync having not
been called).
"""
# Resample if needed
if self._n_activations_seen_process >= self.resample_interval_process:
@@ -441,10 +432,8 @@ def compute(self) -> list[ParameterUpdateResults] | None:

# Sync & typecast
self.sync()
if not isinstance(self._loss, Tensor) or not isinstance(
self._input_activations, Tensor
):
raise TypeError
loss = dim_zero_cat(self._loss)
input_activations = dim_zero_cat(self._input_activations)

dead_neuron_indices: list[
Int[Tensor, Axis.names(Axis.LEARNT_FEATURE_IDX)]
@@ -454,14 +443,14 @@ def compute(self) -> list[ParameterUpdateResults] | None:
# square of the autoencoder's loss on that input.
sample_probabilities: Float[
Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL)
] = self.assign_sampling_probabilities(self._loss)
] = self.assign_sampling_probabilities(loss)

# For each dead neuron sample an input according to these probabilities.
sampled_input: list[
Float[Tensor, Axis.names(Axis.DEAD_FEATURE, Axis.INPUT_OUTPUT_FEATURE)]
] = self.sample_input(
sample_probabilities,
self._input_activations,
input_activations,
[len(dead) for dead in dead_neuron_indices],
)

@@ -505,7 +494,7 @@ def compute(self) -> list[ParameterUpdateResults] | None:
)

# Reset
self.unsync()
self.unsync(should_unsync=self._is_synced)
self.reset()

return parameter_update_results
@@ -518,16 +507,18 @@ def forward( # type: ignore[override]
Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
],
learned_activations: Float[
Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT, Axis.LEARNT_FEATURE)
Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)
],
loss: Float[Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL)],
encoder_weight_reference: Parameter,
) -> list[ParameterUpdateResults] | None:
"""Step the resampler, collating neuron activity and resampling if necessary.
Args:
input_activations: Input activations to the SAE.
learned_activations: Learned activations from the SAE.
loss: Loss per input activation.
encoder_weight_reference: Reference to the SAE encoder weight tensor.
Returns:
Parameter update results (for each component that the SAE is being trained on) if
@@ -537,18 +528,24 @@ def forward( # type: ignore[override]
if self._n_times_resampled >= self._max_n_resamples:
return None

super().forward(
input_activations=input_activations, learned_activations=learned_activations, loss=loss
self.update(
input_activations=input_activations,
learned_activations=learned_activations,
loss=loss,
encoder_weight_reference=encoder_weight_reference,
)

def __str__(self) -> str:
"""Return a string representation of the activation resampler."""
return (
f"ActivationResampler("
f"n_components={self._n_components}, "
f"threshold_is_dead_portion_fires={self._threshold_is_dead_portion_fires}, "
f"max_n_resamples={self._max_n_resamples}, "
f"resample_interval={self.resample_interval_process}, "
f"start_collecting_neuron_activity={self.start_collecting_neuron_activity_process}, "
f"start_collecting_loss={self.start_collecting_loss_process}"
)
return self.compute()

def reset(self) -> None:
"""Reset the activation resampler.
Warning:
This is only called when forward/compute has returned parameters to update (i.e.
resampling is due).
"""
self._n_activations_seen_process = 0
self._neuron_fired_count = torch.zeros_like(self._neuron_fired_count)
self._loss = []
self._input_activations = []
self._n_times_resampled += 1
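Finally, a usage sketch continuing the construction example above. The tensor shapes are an assumption about the expected (batch, component, feature) axes, and the encoder-weight shape in particular is a guess, not taken from this diff:

import torch
from torch.nn import Parameter

batch, n_components, n_input, n_learned = 8, 2, 64, 1024
source_activations = torch.rand(batch, n_components, n_input)
learned_activations = torch.rand(batch, n_components, n_learned)
loss_per_item = torch.rand(batch, n_components)
encoder_weight = Parameter(torch.rand(n_components, n_learned, n_input))  # assumed shape

parameter_updates = resampler.forward(
    input_activations=source_activations,
    learned_activations=learned_activations,
    loss=loss_per_item,
    encoder_weight_reference=encoder_weight,
)
if parameter_updates is not None:
    ...  # one ParameterUpdateResults per component; apply them to the SAE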