From 802d1c34019b505227a460ae71febc5424ce36c1 Mon Sep 17 00:00:00 2001
From: David Chanin
Date: Fri, 14 Feb 2025 15:10:50 -0800
Subject: [PATCH] fix: fix typing and updating mdl for saelens >=5.4.0

---
 pyproject.toml                                   | 2 +-
 sae_bench/custom_saes/pca_sae.py                 | 8 ++++----
 sae_bench/evals/absorption/k_sparse_probing.py   | 2 +-
 sae_bench/evals/mdl/main.py                      | 8 +++++---
 sae_bench/evals/sparse_probing/probe_training.py | 2 +-
 5 files changed, 12 insertions(+), 10 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index f25b0d3..ab319d6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,7 +14,7 @@ classifiers = ["Topic :: Scientific/Engineering :: Artificial Intelligence"]
 
 [tool.poetry.dependencies]
 python = "^3.10"
-sae_lens = ">=4.4.2"
+sae_lens = ">=5.4.0"
 transformer-lens = ">=2.0.0"
 torch = ">=2.1.0"
 einops = ">=0.8.0"
diff --git a/sae_bench/custom_saes/pca_sae.py b/sae_bench/custom_saes/pca_sae.py
index 79e66f2..0cd2fa9 100644
--- a/sae_bench/custom_saes/pca_sae.py
+++ b/sae_bench/custom_saes/pca_sae.py
@@ -48,7 +48,7 @@ def save_state_dict(self, file_path: str):
         """Save the encoder and decoder to a file."""
         torch.save(
             {
-                "W_enc": self.W_enc.data,
+                "W_enc": self.W_enc.data,  # type: ignore
                 "W_dec": self.W_dec.data,
                 "mean": self.mean.data,
             },
@@ -58,7 +58,7 @@ def load_from_file(self, file_path: str):
         """Load the encoder and decoder from a file."""
         state_dict = torch.load(file_path, map_location=self.device)
-        self.W_enc.data = state_dict["W_enc"]
+        self.W_enc.data = state_dict["W_enc"]  # type: ignore
         self.W_dec.data = state_dict["W_dec"]
         self.mean.data = state_dict["mean"]
         self.normalize_decoder()
@@ -137,7 +137,7 @@ def fit_PCA(
 
     # Set the learned components
     pca.mean.data = torch.tensor(ipca.mean_, dtype=torch.float32, device="cpu")
-    pca.W_enc.data = torch.tensor(ipca.components_, dtype=torch.float32, device="cpu")
+    pca.W_enc.data = torch.tensor(ipca.components_, dtype=torch.float32, device="cpu")  # type: ignore
     pca.W_dec.data = torch.tensor(ipca.components_.T, dtype=torch.float32, device="cpu")  # type: ignore
 
     pca.save_state_dict(f"pca_{pca.cfg.model_name}_{pca.cfg.hook_name}.pt")
@@ -215,7 +215,7 @@ def fit_PCA_gpu(
 
     # Set the learned components
     pca.mean.data = pca_mean.to(dtype=torch.float32, device="cpu")
-    pca.W_enc.data = components.float().to(dtype=torch.float32, device="cpu")
+    pca.W_enc.data = components.float().to(dtype=torch.float32, device="cpu")  # type: ignore
     pca.W_dec.data = components.T.float().to(dtype=torch.float32, device="cpu")
 
     pca.save_state_dict(f"pca_{pca.cfg.model_name}_{pca.cfg.hook_name}.pt")
diff --git a/sae_bench/evals/absorption/k_sparse_probing.py b/sae_bench/evals/absorption/k_sparse_probing.py
index ec231f5..89fdb39 100644
--- a/sae_bench/evals/absorption/k_sparse_probing.py
+++ b/sae_bench/evals/absorption/k_sparse_probing.py
@@ -164,7 +164,7 @@ def train_k_sparse_probes(
             ).fit(train_k_x, (train_k_y == label).astype(np.int64))
             probe = KSparseProbe(
                 weight=torch.tensor(sk_probe.coef_[0]).float(),
-                bias=torch.tensor(sk_probe.intercept_[0]).float(),
+                bias=torch.tensor(sk_probe.intercept_[0]).float(),  # type: ignore
                 feature_ids=sparse_feat_ids,
             )
             results[k][label] = probe
diff --git a/sae_bench/evals/mdl/main.py b/sae_bench/evals/mdl/main.py
index 7b782c5..4b05d1f 100644
--- a/sae_bench/evals/mdl/main.py
+++ b/sae_bench/evals/mdl/main.py
@@ -104,7 +104,7 @@ def calculate_dl(
     float_entropy_F = torch.zeros(num_features, device=device, dtype=torch.float32)
     bool_entropy_F = torch.zeros(num_features, device=device, dtype=torch.float32)
 
-    x_BSN = activations_store.get_buffer(config.sae_batch_size)
+    x_BSN = activations_store.get_buffer(config.sae_batch_size)[0]
     feature_activations_BsF = sae.encode(x_BSN).squeeze()
 
     if feature_activations_BsF.ndim == 2:
@@ -235,7 +235,7 @@ def check_quantised_features_reach_mse_threshold(
     mse_losses: list[torch.Tensor] = []
 
     for i in range(1):
-        x_BSN = activations_store.get_buffer(config.sae_batch_size)
+        x_BSN = activations_store.get_buffer(config.sae_batch_size)[0]
         feature_activations_BSF = sae.encode(x_BSN).squeeze()
 
         if k is not None:
@@ -337,7 +337,9 @@ def get_min_max_activations() -> tuple[torch.Tensor, torch.Tensor]:
         max_activations_1F = torch.zeros(1, num_features, device=device) + 100
 
         for _ in range(10):
-            neuron_activations_BSN = activations_store.get_buffer(config.sae_batch_size)
+            neuron_activations_BSN = activations_store.get_buffer(
+                config.sae_batch_size
+            )[0]
 
             feature_activations_BsF = sae.encode(neuron_activations_BSN).squeeze()
diff --git a/sae_bench/evals/sparse_probing/probe_training.py b/sae_bench/evals/sparse_probing/probe_training.py
index af57032..ac75e85 100644
--- a/sae_bench/evals/sparse_probing/probe_training.py
+++ b/sae_bench/evals/sparse_probing/probe_training.py
@@ -253,7 +253,7 @@ def train_probe_gpu(
     print(f"Training probe with dim: {dim}, device: {device}, dtype: {model_dtype}")
 
     probe = Probe(dim, model_dtype).to(device)
-    optimizer = torch.optim.AdamW(probe.parameters(), lr=lr)
+    optimizer = torch.optim.AdamW(probe.parameters(), lr=lr)  # type: ignore
     criterion = nn.BCEWithLogitsLoss()
 
     best_test_accuracy = 0.0
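
Note on the repeated [0] indexing in the sae_bench/evals/mdl/main.py hunks: the change suggests that from sae_lens >= 5.4.0 onward, ActivationsStore.get_buffer returns a tuple whose first element is the activations tensor, whereas the code previously treated the return value as a bare tensor. Below is a minimal sketch of a version-tolerant wrapper written under that assumption; the helper name get_activations is illustrative only and is not part of this patch or of sae_lens.

    import torch


    def get_activations(activations_store, batch_size: int) -> torch.Tensor:
        """Return the activations tensor from get_buffer across sae_lens versions.

        Assumes newer sae_lens releases return a tuple with the activations
        tensor first, while older releases return the tensor directly.
        """
        buffer = activations_store.get_buffer(batch_size)
        if isinstance(buffer, tuple):
            # newer sae_lens: take the activations tensor from the tuple
            return buffer[0]
        # older sae_lens: the buffer is already the activations tensor
        return buffer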