"""PyTorch Lightning module for training a sparse autoencoder."""
from functools import partial
from typing import Any

from jaxtyping import Float
from lightning.pytorch import LightningModule
from pydantic import NonNegativeFloat, NonNegativeInt, PositiveInt
from torch import Tensor
from torch.optim.optimizer import Optimizer
from torchmetrics import MetricCollection
import wandb

from sparse_autoencoder.activation_resampler.activation_resampler import (
    ActivationResampler,
    ParameterUpdateResults,
)
from sparse_autoencoder.autoencoder.model import (
    ForwardPassResult,
    SparseAutoencoder,
    SparseAutoencoderConfig,
)
from sparse_autoencoder.autoencoder.types import ResetOptimizerParameterDetails
from sparse_autoencoder.metrics.loss.l1_absolute_loss import L1AbsoluteLoss
from sparse_autoencoder.metrics.loss.l2_reconstruction_loss import L2ReconstructionLoss
from sparse_autoencoder.metrics.loss.sae_loss import SparseAutoencoderLoss
from sparse_autoencoder.metrics.train.l0_norm import L0NormMetric
from sparse_autoencoder.metrics.train.neuron_activity import NeuronActivityMetric
from sparse_autoencoder.metrics.wrappers.classwise import ClasswiseWrapperWithMean
from sparse_autoencoder.optimizer.adam_with_reset import AdamWithReset
from sparse_autoencoder.tensor_types import Axis


class LitSparseAutoencoderConfig(SparseAutoencoderConfig):
    """PyTorch Lightning Sparse Autoencoder config."""

    component_names: list[str]

    l1_coefficient: float = 0.001

    resample_interval: PositiveInt = 200_000_000

    max_n_resamples: NonNegativeInt = 4

    resample_dead_neurons_dataset_size: PositiveInt = 100_000_000

    resample_loss_dataset_size: PositiveInt = 819_200

    resample_threshold_is_dead_portion_fires: NonNegativeFloat = 0.0

    def model_post_init(self, __context: Any) -> None:  # noqa: ANN401
        """Model post init validation.

        Args:
            __context: Pydantic context.

        Raises:
            ValueError: If the number of component names does not match the number of components.
        """
        if self.n_components and len(self.component_names) != self.n_components:
            error_message = (
                f"Number of component names ({len(self.component_names)}) must match "
                f"the number of components ({self.n_components})."
            )
            raise ValueError(error_message)

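# Example (illustrative only): constructing the config for a two-component SAE trained on
# activations from two hook points. n_learned_features and n_components are used elsewhere in
# this module; n_input_features is assumed to come from the parent SparseAutoencoderConfig,
# and all values and hook point names below are hypothetical.
#
#     config = LitSparseAutoencoderConfig(
#         n_input_features=768,
#         n_learned_features=768 * 8,
#         n_components=2,
#         component_names=["blocks.0.hook_mlp_out", "blocks.1.hook_mlp_out"],
#         l1_coefficient=0.001,
#     )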

class LitSparseAutoencoder(LightningModule):
    """Lightning Sparse Autoencoder."""

    sparse_autoencoder: SparseAutoencoder

    config: LitSparseAutoencoderConfig

    loss_fn: SparseAutoencoderLoss

    train_metrics: MetricCollection

    def __init__(
        self,
        config: LitSparseAutoencoderConfig,
    ):
        """Initialise the module."""
        super().__init__()
        self.sparse_autoencoder = SparseAutoencoder(config)
        self.config = config

        num_components = config.n_components or 1
        add_component_names = partial(
            ClasswiseWrapperWithMean, component_names=config.component_names
        )

        # Create the loss & metrics
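        # SparseAutoencoderLoss combines the L2 reconstruction loss with an L1 sparsity penalty
        # on the learned activations, weighted by l1_coefficient. keep_batch_dim=True keeps
        # per-example losses, which the activation resampler consumes in training_step before
        # the mean is taken for backpropagation.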
        self.loss_fn = SparseAutoencoderLoss(
            num_components, config.l1_coefficient, keep_batch_dim=True
        )

        self.train_metrics = MetricCollection(
            {
                "l0": add_component_names(L0NormMetric(num_components), prefix="train/l0_norm"),
                "activity": add_component_names(
                    NeuronActivityMetric(config.n_learned_features, num_components),
                    prefix="train/neuron_activity",
                ),
                "l1": add_component_names(
                    L1AbsoluteLoss(num_components), prefix="loss/l1_learned_activations"
                ),
                "l2": add_component_names(
                    L2ReconstructionLoss(num_components), prefix="loss/l2_reconstruction"
                ),
                "loss": add_component_names(
                    SparseAutoencoderLoss(num_components, config.l1_coefficient),
                    prefix="loss/total",
                ),
            },
            # Share state & updates across groups (to avoid e.g. computing l1 twice for both the
            # loss and l1 metrics). Note the metric that goes first must calculate all the states
            # needed by the rest of the group.
            compute_groups=[
                ["loss", "l1", "l2"],
                ["activity"],
                ["l0"],
            ],
        )

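        # Dead neuron resampler: neuron firing is tracked over
        # resample_dead_neurons_dataset_size activations, neurons that fire on at most
        # resample_threshold_is_dead_portion_fires of them are treated as dead, and their
        # dictionary vectors are re-initialised (at most max_n_resamples times over training).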
        self.activation_resampler = ActivationResampler(
            n_learned_features=config.n_learned_features,
            n_components=num_components,
            resample_interval=config.resample_interval,
            max_n_resamples=config.max_n_resamples,
            n_activations_activity_collate=config.resample_dead_neurons_dataset_size,
            resample_dataset_size=config.resample_loss_dataset_size,
            threshold_is_dead_portion_fires=config.resample_threshold_is_dead_portion_fires,
        )

    def forward(  # type: ignore[override]
        self,
        inputs: Float[
            Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
        ],
    ) -> ForwardPassResult:
        """Forward pass."""
        return self.sparse_autoencoder.forward(inputs)

    def update_parameters(self, parameter_updates: list[ParameterUpdateResults]) -> None:
        """Update the parameters of the model from the results of the resampler.

        Args:
            parameter_updates: Parameter updates from the resampler.

        Raises:
            TypeError: If the optimizer is not an AdamWithReset.
        """
        # Get the optimizer once and check it supports resetting neuron state
        optimizer = self.optimizers(use_pl_optimizer=False)
        if not isinstance(optimizer, AdamWithReset):
            error_message = "Cannot reset the optimizer, as it is not an AdamWithReset."
            raise TypeError(error_message)

        for component_idx, component_parameter_update in enumerate(parameter_updates):
            # Update the weights and biases
            self.sparse_autoencoder.encoder.update_dictionary_vectors(
                component_parameter_update.dead_neuron_indices,
                component_parameter_update.dead_encoder_weight_updates,
                component_idx=component_idx,
            )
            self.sparse_autoencoder.encoder.update_bias(
                component_parameter_update.dead_neuron_indices,
                component_parameter_update.dead_encoder_bias_updates,
                component_idx=component_idx,
            )
            self.sparse_autoencoder.decoder.update_dictionary_vectors(
                component_parameter_update.dead_neuron_indices,
                component_parameter_update.dead_decoder_weight_updates,
                component_idx=component_idx,
            )

            # Reset the optimizer state for the dead neurons
            for parameter, axis in self.reset_optimizer_parameter_details:
                optimizer.reset_neurons_state(
                    parameter=parameter,
                    neuron_indices=component_parameter_update.dead_neuron_indices,
                    axis=axis,
                    component_idx=component_idx,
                )

    def training_step(  # type: ignore[override]
        self,
        batch: Float[
            Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
        ],
        batch_idx: int | None = None,  # noqa: ARG002
    ) -> Float[Tensor, Axis.SINGLE_ITEM]:
        """Training step."""
        # Forward pass
        output: ForwardPassResult = self.forward(batch)

        # Metrics & loss
        train_metrics = self.train_metrics.forward(
            source_activations=batch,
            learned_activations=output.learned_activations,
            decoded_activations=output.decoded_activations,
        )

        loss = self.loss_fn.forward(
            source_activations=batch,
            learned_activations=output.learned_activations,
            decoded_activations=output.decoded_activations,
        )

        if wandb.run is not None:
            self.log_dict(train_metrics)

        # Resample dead neurons
        parameter_updates = self.activation_resampler.forward(
            input_activations=batch,
            learned_activations=output.learned_activations,
            loss=loss,
            encoder_weight_reference=self.sparse_autoencoder.encoder.weight,
        )
        if parameter_updates is not None:
            self.update_parameters(parameter_updates)

        # Return the mean loss
        return loss.mean()

    def on_after_backward(self) -> None:
        """After-backward pass hook."""
        self.sparse_autoencoder.post_backwards_hook()

    def configure_optimizers(self) -> Optimizer:
        """Configure the optimizer."""
        return AdamWithReset(
            self.sparse_autoencoder.parameters(),
            named_parameters=self.sparse_autoencoder.named_parameters(),
            has_components_dim=True,
        )

    @property
    def reset_optimizer_parameter_details(self) -> list[ResetOptimizerParameterDetails]:
        """Reset optimizer parameter details."""
        return self.sparse_autoencoder.reset_optimizer_parameter_details
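
# Example usage (illustrative sketch, not part of this module). Assumes an activations
# dataloader yielding float tensors shaped [batch, n_components, n_input_features]; the
# config values and dataloader name below are hypothetical.
#
#     from lightning.pytorch import Trainer
#
#     model = LitSparseAutoencoder(config)
#     trainer = Trainer(max_steps=10_000)
#     trainer.fit(model, train_dataloaders=activations_dataloader)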