Skip to content

Commit 100278a

Browse files
committed
Add DataParallelWithModelAttributes to the resampler
1 parent a89abbc commit 100278a

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

sparse_autoencoder/activation_resampler/activation_resampler.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from pydantic import Field, NonNegativeInt, PositiveInt, validate_call
88
import torch
99
from torch import Tensor
10-
from torch.nn.parallel import DataParallel
1110
from torch.utils.data import DataLoader
1211
from torchmetrics import Metric
1312

@@ -18,6 +17,7 @@
1817
from sparse_autoencoder.autoencoder.model import SparseAutoencoder
1918
from sparse_autoencoder.tensor_types import Axis
2019
from sparse_autoencoder.train.utils.get_model_device import get_model_device
20+
from sparse_autoencoder.utils.data_parallel import DataParallelWithModelAttributes
2121

2222

2323
@dataclass
@@ -207,7 +207,7 @@ def _get_dead_neuron_indices(
207207
def compute_loss_and_get_activations(
208208
self,
209209
store: ActivationStore,
210-
autoencoder: SparseAutoencoder | DataParallel[SparseAutoencoder],
210+
autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder],
211211
loss_fn: Metric,
212212
train_batch_size: int,
213213
) -> LossInputActivationsTuple:
@@ -440,7 +440,7 @@ def renormalize_and_scale(
440440
def resample_dead_neurons(
441441
self,
442442
activation_store: ActivationStore,
443-
autoencoder: SparseAutoencoder | DataParallel[SparseAutoencoder],
443+
autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder],
444444
loss_fn: Metric,
445445
train_batch_size: int,
446446
) -> list[ParameterUpdateResults]:
@@ -530,7 +530,7 @@ def step_resampler(
530530
self,
531531
batch_neuron_activity: Int64[Tensor, Axis.names(Axis.COMPONENT, Axis.LEARNT_FEATURE)],
532532
activation_store: ActivationStore,
533-
autoencoder: SparseAutoencoder | DataParallel[SparseAutoencoder],
533+
autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder],
534534
loss_fn: Metric,
535535
train_batch_size: int,
536536
) -> list[ParameterUpdateResults] | None:

0 commit comments

Comments
 (0)