|
7 | 7 | from pydantic import Field, NonNegativeInt, PositiveInt, validate_call
|
8 | 8 | import torch
|
9 | 9 | from torch import Tensor
|
10 |
| -from torch.nn.parallel import DataParallel |
11 | 10 | from torch.utils.data import DataLoader
|
12 | 11 | from torchmetrics import Metric
|
13 | 12 |
|
|
18 | 17 | from sparse_autoencoder.autoencoder.model import SparseAutoencoder
|
19 | 18 | from sparse_autoencoder.tensor_types import Axis
|
20 | 19 | from sparse_autoencoder.train.utils.get_model_device import get_model_device
|
| 20 | +from sparse_autoencoder.utils.data_parallel import DataParallelWithModelAttributes |
21 | 21 |
|
22 | 22 |
|
23 | 23 | @dataclass
|
@@ -207,7 +207,7 @@ def _get_dead_neuron_indices(
|
207 | 207 | def compute_loss_and_get_activations(
|
208 | 208 | self,
|
209 | 209 | store: ActivationStore,
|
210 |
| - autoencoder: SparseAutoencoder | DataParallel[SparseAutoencoder], |
| 210 | + autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder], |
211 | 211 | loss_fn: Metric,
|
212 | 212 | train_batch_size: int,
|
213 | 213 | ) -> LossInputActivationsTuple:
|
@@ -440,7 +440,7 @@ def renormalize_and_scale(
|
440 | 440 | def resample_dead_neurons(
|
441 | 441 | self,
|
442 | 442 | activation_store: ActivationStore,
|
443 |
| - autoencoder: SparseAutoencoder | DataParallel[SparseAutoencoder], |
| 443 | + autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder], |
444 | 444 | loss_fn: Metric,
|
445 | 445 | train_batch_size: int,
|
446 | 446 | ) -> list[ParameterUpdateResults]:
|
@@ -530,7 +530,7 @@ def step_resampler(
|
530 | 530 | self,
|
531 | 531 | batch_neuron_activity: Int64[Tensor, Axis.names(Axis.COMPONENT, Axis.LEARNT_FEATURE)],
|
532 | 532 | activation_store: ActivationStore,
|
533 |
| - autoencoder: SparseAutoencoder | DataParallel[SparseAutoencoder], |
| 533 | + autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder], |
534 | 534 | loss_fn: Metric,
|
535 | 535 | train_batch_size: int,
|
536 | 536 | ) -> list[ParameterUpdateResults] | None:
|
|
0 commit comments