
Commit f005cc1

Use lightning for DDP training & mixed precision (#199)
Supports DDP by default on multi-gpu machines
1 parent e3b6102 commit f005cc1
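For orientation, the sketch below shows how the LitSparseAutoencoder module added in this commit (see the new lightning file further down) might be driven so that the Lightning Trainer handles DDP and mixed precision. The feature sizes, component names, random activations, and Trainer flags are illustrative assumptions that presume a multi-GPU machine; they are not values taken from this commit.

import torch
from lightning.pytorch import Trainer
from torch.utils.data import DataLoader

from sparse_autoencoder.autoencoder.lightning import (
    LitSparseAutoencoder,
    LitSparseAutoencoderConfig,
)

# Hypothetical config values, chosen only for illustration.
config = LitSparseAutoencoderConfig(
    n_input_features=768,
    n_learned_features=768 * 8,
    n_components=2,
    component_names=["blocks.0.hook_mlp_out", "blocks.1.hook_mlp_out"],
    l1_coefficient=0.001,
)
model = LitSparseAutoencoder(config)

# Placeholder data: batches of source-model activations with shape
# (batch, component, input_feature). A real run would stream activations from a source model.
activations = torch.randn(4096, 2, 768)
train_dataloader = DataLoader(activations, batch_size=256)

# Lightning owns the distributed setup: DDP across the visible GPUs plus 16-bit mixed precision.
trainer = Trainer(accelerator="gpu", devices="auto", strategy="ddp", precision="16-mixed")
trainer.fit(model, train_dataloaders=train_dataloader)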

14 files changed, +735 -767 lines


.gitignore

Lines changed: 3 additions & 0 deletions
@@ -143,6 +143,9 @@ docs/content/reference
 wandb/
 artifacts/
 
+# Lightning
+lightning_logs
+
 # Scratch files
 scratch.py
 scratch.ipynb

.vscode/cspell.json

Lines changed: 1 addition & 0 deletions
@@ -117,6 +117,7 @@
 "uncopyrighted",
 "ungraphed",
 "unsqueeze",
+"unsync",
 "venv",
 "virtualenv",
 "virtualenvs",

sparse_autoencoder/activation_resampler/activation_resampler.py

Lines changed: 220 additions & 264 deletions
Large diffs are not rendered by default.

sparse_autoencoder/activation_resampler/tests/test_activation_resampler.py

Lines changed: 92 additions & 129 deletions
Large diffs are not rendered by default.

sparse_autoencoder/autoencoder/lightning.py (new file)

Lines changed: 243 additions & 0 deletions
"""PyTorch Lightning module for training a sparse autoencoder."""
from functools import partial
from typing import Any

from jaxtyping import Float
from lightning.pytorch import LightningModule
from pydantic import NonNegativeFloat, NonNegativeInt, PositiveInt
from torch import Tensor
from torch.optim.optimizer import Optimizer
from torchmetrics import MetricCollection
import wandb

from sparse_autoencoder.activation_resampler.activation_resampler import (
    ActivationResampler,
    ParameterUpdateResults,
)
from sparse_autoencoder.autoencoder.model import (
    ForwardPassResult,
    SparseAutoencoder,
    SparseAutoencoderConfig,
)
from sparse_autoencoder.autoencoder.types import ResetOptimizerParameterDetails
from sparse_autoencoder.metrics.loss.l1_absolute_loss import L1AbsoluteLoss
from sparse_autoencoder.metrics.loss.l2_reconstruction_loss import L2ReconstructionLoss
from sparse_autoencoder.metrics.loss.sae_loss import SparseAutoencoderLoss
from sparse_autoencoder.metrics.train.l0_norm import L0NormMetric
from sparse_autoencoder.metrics.train.neuron_activity import NeuronActivityMetric
from sparse_autoencoder.metrics.wrappers.classwise import ClasswiseWrapperWithMean
from sparse_autoencoder.optimizer.adam_with_reset import AdamWithReset
from sparse_autoencoder.tensor_types import Axis


class LitSparseAutoencoderConfig(SparseAutoencoderConfig):
    """PyTorch Lightning Sparse Autoencoder config."""

    component_names: list[str]

    l1_coefficient: float = 0.001

    resample_interval: PositiveInt = 200000000

    max_n_resamples: NonNegativeInt = 4

    resample_dead_neurons_dataset_size: PositiveInt = 100000000

    resample_loss_dataset_size: PositiveInt = 819200

    resample_threshold_is_dead_portion_fires: NonNegativeFloat = 0.0

    def model_post_init(self, __context: Any) -> None:  # noqa: ANN401
        """Model post init validation.

        Args:
            __context: Pydantic context.

        Raises:
            ValueError: If the number of component names does not match the number of components.
        """
        if self.n_components and len(self.component_names) != self.n_components:
            error_message = (
                f"Number of component names ({len(self.component_names)}) must match the number of "
                f"components ({self.n_components})"
            )
            raise ValueError(error_message)


class LitSparseAutoencoder(LightningModule):
    """Lightning Sparse Autoencoder."""

    sparse_autoencoder: SparseAutoencoder

    config: LitSparseAutoencoderConfig

    loss_fn: SparseAutoencoderLoss

    train_metrics: MetricCollection

    def __init__(
        self,
        config: LitSparseAutoencoderConfig,
    ):
        """Initialise the module."""
        super().__init__()
        self.sparse_autoencoder = SparseAutoencoder(config)
        self.config = config

        num_components = config.n_components or 1
        add_component_names = partial(
            ClasswiseWrapperWithMean, component_names=config.component_names
        )

        # Create the loss & metrics
        self.loss_fn = SparseAutoencoderLoss(
            num_components, config.l1_coefficient, keep_batch_dim=True
        )

        self.train_metrics = MetricCollection(
            {
                "l0": add_component_names(L0NormMetric(num_components), prefix="train/l0_norm"),
                "activity": add_component_names(
                    NeuronActivityMetric(config.n_learned_features, num_components),
                    prefix="train/neuron_activity",
                ),
                "l1": add_component_names(
                    L1AbsoluteLoss(num_components), prefix="loss/l1_learned_activations"
                ),
                "l2": add_component_names(
                    L2ReconstructionLoss(num_components), prefix="loss/l2_reconstruction"
                ),
                "loss": add_component_names(
                    SparseAutoencoderLoss(num_components, config.l1_coefficient),
                    prefix="loss/total",
                ),
            },
            # Share state & updates across groups (to avoid e.g. computing l1 twice for both the
            # loss and l1 metrics). Note the metric that goes first must calculate all the states
            # needed by the rest of the group.
            compute_groups=[
                ["loss", "l1", "l2"],
                ["activity"],
                ["l0"],
            ],
        )

        self.activation_resampler = ActivationResampler(
            n_learned_features=config.n_learned_features,
            n_components=num_components,
            resample_interval=config.resample_interval,
            max_n_resamples=config.max_n_resamples,
            n_activations_activity_collate=config.resample_dead_neurons_dataset_size,
            resample_dataset_size=config.resample_loss_dataset_size,
            threshold_is_dead_portion_fires=config.resample_threshold_is_dead_portion_fires,
        )

    def forward(  # type: ignore[override]
        self,
        inputs: Float[
            Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
        ],
    ) -> ForwardPassResult:
        """Forward pass."""
        return self.sparse_autoencoder.forward(inputs)

    def update_parameters(self, parameter_updates: list[ParameterUpdateResults]) -> None:
        """Update the parameters of the model from the results of the resampler.

        Args:
            parameter_updates: Parameter updates from the resampler.

        Raises:
            TypeError: If the optimizer is not an AdamWithReset.
        """
        for component_idx, component_parameter_update in enumerate(parameter_updates):
            # Update the weights and biases
            self.sparse_autoencoder.encoder.update_dictionary_vectors(
                component_parameter_update.dead_neuron_indices,
                component_parameter_update.dead_encoder_weight_updates,
                component_idx=component_idx,
            )
            self.sparse_autoencoder.encoder.update_bias(
                component_parameter_update.dead_neuron_indices,
                component_parameter_update.dead_encoder_bias_updates,
                component_idx=component_idx,
            )
            self.sparse_autoencoder.decoder.update_dictionary_vectors(
                component_parameter_update.dead_neuron_indices,
                component_parameter_update.dead_decoder_weight_updates,
                component_idx=component_idx,
            )

            # Reset the optimizer
            for (
                parameter,
                axis,
            ) in self.reset_optimizer_parameter_details:
                optimizer = self.optimizers(use_pl_optimizer=False)
                if not isinstance(optimizer, AdamWithReset):
                    error_message = "Cannot reset the optimizer. "
                    raise TypeError(error_message)

                optimizer.reset_neurons_state(
                    parameter=parameter,
                    neuron_indices=component_parameter_update.dead_neuron_indices,
                    axis=axis,
                    component_idx=component_idx,
                )

    def training_step(  # type: ignore[override]
        self,
        batch: Float[
            Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
        ],
        batch_idx: int | None = None,  # noqa: ARG002
    ) -> Float[Tensor, Axis.SINGLE_ITEM]:
        """Training step."""
        # Forward pass
        output: ForwardPassResult = self.forward(batch)

        # Metrics & loss
        train_metrics = self.train_metrics.forward(
            source_activations=batch,
            learned_activations=output.learned_activations,
            decoded_activations=output.decoded_activations,
        )

        loss = self.loss_fn.forward(
            source_activations=batch,
            learned_activations=output.learned_activations,
            decoded_activations=output.decoded_activations,
        )

        if wandb.run is not None:
            self.log_dict(train_metrics)

        # Resample dead neurons
        parameter_updates = self.activation_resampler.forward(
            input_activations=batch,
            learned_activations=output.learned_activations,
            loss=loss,
            encoder_weight_reference=self.sparse_autoencoder.encoder.weight,
        )
        if parameter_updates is not None:
            self.update_parameters(parameter_updates)

        # Return the mean loss
        return loss.mean()

    def on_after_backward(self) -> None:
        """After-backward pass hook."""
        self.sparse_autoencoder.post_backwards_hook()

    def configure_optimizers(self) -> Optimizer:
        """Configure the optimizer."""
        return AdamWithReset(
            self.sparse_autoencoder.parameters(),
            named_parameters=self.sparse_autoencoder.named_parameters(),
            has_components_dim=True,
        )

    @property
    def reset_optimizer_parameter_details(self) -> list[ResetOptimizerParameterDetails]:
        """Reset optimizer parameter details."""
        return self.sparse_autoencoder.reset_optimizer_parameter_details
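
A brief note on the compute_groups argument used in the MetricCollection above: TorchMetrics lets metrics whose update states coincide share that state, so the shared quantities are only accumulated once per step. A minimal, self-contained illustration of the same mechanism using stock TorchMetrics classification metrics (not this repo's metrics):

import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import BinaryPrecision, BinaryRecall

# Precision and recall accumulate the same statistics (tp/fp/tn/fn), so the collection can
# place them in one compute group and update that shared state only once per batch.
metrics = MetricCollection({"precision": BinaryPrecision(), "recall": BinaryRecall()})

preds, target = torch.rand(16), torch.randint(0, 2, (16,))
print(metrics(preds, target))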

sparse_autoencoder/autoencoder/model.py

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@
 from sparse_autoencoder.utils.tensor_shape import shape_with_optional_dimensions
 
 
-class SparseAutoencoderConfig(BaseModel, frozen=True):
+class SparseAutoencoderConfig(BaseModel):
     """SAE model config."""
 
     n_input_features: PositiveInt

sparse_autoencoder/metrics/loss/sae_loss.py

Lines changed: 19 additions & 0 deletions
@@ -142,3 +142,22 @@ def compute(self) -> Tensor:
         )
 
         return l1 * self._l1_coefficient + l2
+
+    def forward(  # type: ignore[override] (narrowing)
+        self,
+        source_activations: Float[
+            Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
+        ],
+        learned_activations: Float[
+            Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)
+        ],
+        decoded_activations: Float[
+            Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
+        ],
+    ) -> Tensor:
+        """Forward pass."""
+        return super().forward(
+            source_activations=source_activations,
+            learned_activations=learned_activations,
+            decoded_activations=decoded_activations,
+        )

sparse_autoencoder/source_model/replace_activations_hook.py

Lines changed: 5 additions & 1 deletion
@@ -7,6 +7,7 @@
 from torch.nn.parallel import DataParallel
 from transformer_lens.hook_points import HookPoint
 
+from sparse_autoencoder.autoencoder.lightning import LitSparseAutoencoder
 from sparse_autoencoder.autoencoder.model import SparseAutoencoder
 
 
@@ -17,7 +18,10 @@
 def replace_activations_hook(
     value: Tensor,
     hook: HookPoint,  # noqa: ARG001
-    sparse_autoencoder: SparseAutoencoder | DataParallel[SparseAutoencoder] | Module,
+    sparse_autoencoder: SparseAutoencoder
+    | DataParallel[SparseAutoencoder]
+    | LitSparseAutoencoder
+    | Module,
     component_idx: int | None = None,
     n_components: int | None = None,
 ) -> Tensor:
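
Since replace_activations_hook now also accepts a LitSparseAutoencoder, the trained Lightning module can be spliced directly into a source model. A minimal sketch with transformer_lens, assuming lit_sae is the trained LitSparseAutoencoder from the earlier sketch and using GPT-2's layer-0 MLP-output hook point purely as an example:

from functools import partial

from transformer_lens import HookedTransformer

from sparse_autoencoder.source_model.replace_activations_hook import replace_activations_hook

source_model = HookedTransformer.from_pretrained("gpt2")

# Run the source model with layer 0's MLP output replaced by the sparse autoencoder's output.
loss_with_sae = source_model.run_with_hooks(
    source_model.to_tokens("An example prompt"),
    return_type="loss",
    fwd_hooks=[
        (
            "blocks.0.hook_mlp_out",
            partial(
                replace_activations_hook,
                sparse_autoencoder=lit_sae,  # assumed: a trained LitSparseAutoencoder
                component_idx=0,
                n_components=lit_sae.config.n_components,
            ),
        )
    ],
)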
