Skip to content

Commit

Permalink
Add DataParallelWithModelAttributes to the resampler
Browse files Browse the repository at this point in the history
  • Loading branch information
alan-cooney committed Jan 29, 2024
1 parent a89abbc commit 100278a
Showing 1 changed file with 4 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from pydantic import Field, NonNegativeInt, PositiveInt, validate_call
import torch
from torch import Tensor
from torch.nn.parallel import DataParallel
from torch.utils.data import DataLoader
from torchmetrics import Metric

Expand All @@ -18,6 +17,7 @@
from sparse_autoencoder.autoencoder.model import SparseAutoencoder
from sparse_autoencoder.tensor_types import Axis
from sparse_autoencoder.train.utils.get_model_device import get_model_device
from sparse_autoencoder.utils.data_parallel import DataParallelWithModelAttributes


@dataclass
Expand Down Expand Up @@ -207,7 +207,7 @@ def _get_dead_neuron_indices(
def compute_loss_and_get_activations(
self,
store: ActivationStore,
autoencoder: SparseAutoencoder | DataParallel[SparseAutoencoder],
autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder],
loss_fn: Metric,
train_batch_size: int,
) -> LossInputActivationsTuple:
Expand Down Expand Up @@ -440,7 +440,7 @@ def renormalize_and_scale(
def resample_dead_neurons(
self,
activation_store: ActivationStore,
autoencoder: SparseAutoencoder | DataParallel[SparseAutoencoder],
autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder],
loss_fn: Metric,
train_batch_size: int,
) -> list[ParameterUpdateResults]:
Expand Down Expand Up @@ -530,7 +530,7 @@ def step_resampler(
self,
batch_neuron_activity: Int64[Tensor, Axis.names(Axis.COMPONENT, Axis.LEARNT_FEATURE)],
activation_store: ActivationStore,
autoencoder: SparseAutoencoder | DataParallel[SparseAutoencoder],
autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder],
loss_fn: Metric,
train_batch_size: int,
) -> list[ParameterUpdateResults] | None:
Expand Down

0 comments on commit 100278a

Please sign in to comment.