Commit 0888d07

Merge pull request #56 from chanind/type-fixes
fix: fix typing and updating mdl for saelens >=5.4.0
2 parents 7ac7ced + 802d1c3

5 files changed: +12 -10 lines changed

pyproject.toml (+1 -1)

@@ -14,7 +14,7 @@ classifiers = ["Topic :: Scientific/Engineering :: Artificial Intelligence"]
 
 [tool.poetry.dependencies]
 python = "^3.10,"
-sae_lens = ">=4.4.2"
+sae_lens = ">=5.4.0"
 transformer-lens = ">=2.0.0"
 torch = ">=2.1.0"
 einops = ">=0.8.0"
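
The only dependency change is raising the sae_lens floor from >=4.4.2 to >=5.4.0, which is what forces the mdl call-site updates below. For scripts that want to fail fast instead of hitting those API changes at runtime, a minimal guard could mirror the floor (a hypothetical snippet, not part of this commit; it assumes the packaging package is available):

from importlib.metadata import version

from packaging.version import Version

# Hypothetical guard mirroring the new pyproject floor (not in this commit).
if Version(version("sae_lens")) < Version("5.4.0"):
    raise RuntimeError("sae_bench now requires sae_lens >= 5.4.0")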

sae_bench/custom_saes/pca_sae.py (+4 -4)

@@ -48,7 +48,7 @@ def save_state_dict(self, file_path: str):
         """Save the encoder and decoder to a file."""
         torch.save(
             {
-                "W_enc": self.W_enc.data,
+                "W_enc": self.W_enc.data,  # type: ignore
                 "W_dec": self.W_dec.data,
                 "mean": self.mean.data,
             },
@@ -58,7 +58,7 @@ def save_state_dict(self, file_path: str):
     def load_from_file(self, file_path: str):
         """Load the encoder and decoder from a file."""
         state_dict = torch.load(file_path, map_location=self.device)
-        self.W_enc.data = state_dict["W_enc"]
+        self.W_enc.data = state_dict["W_enc"]  # type: ignore
         self.W_dec.data = state_dict["W_dec"]
         self.mean.data = state_dict["mean"]
         self.normalize_decoder()
@@ -137,7 +137,7 @@ def fit_PCA(
 
     # Set the learned components
     pca.mean.data = torch.tensor(ipca.mean_, dtype=torch.float32, device="cpu")
-    pca.W_enc.data = torch.tensor(ipca.components_, dtype=torch.float32, device="cpu")
+    pca.W_enc.data = torch.tensor(ipca.components_, dtype=torch.float32, device="cpu")  # type: ignore
     pca.W_dec.data = torch.tensor(ipca.components_.T, dtype=torch.float32, device="cpu")  # type: ignore
 
     pca.save_state_dict(f"pca_{pca.cfg.model_name}_{pca.cfg.hook_name}.pt")
@@ -215,7 +215,7 @@ def fit_PCA_gpu(
 
     # Set the learned components
     pca.mean.data = pca_mean.to(dtype=torch.float32, device="cpu")
-    pca.W_enc.data = components.float().to(dtype=torch.float32, device="cpu")
+    pca.W_enc.data = components.float().to(dtype=torch.float32, device="cpu")  # type: ignore
     pca.W_dec.data = components.T.float().to(dtype=torch.float32, device="cpu")
 
     pca.save_state_dict(f"pca_{pca.cfg.model_name}_{pca.cfg.hook_name}.pt")

sae_bench/evals/absorption/k_sparse_probing.py (+1 -1)

@@ -164,7 +164,7 @@ def train_k_sparse_probes(
             ).fit(train_k_x, (train_k_y == label).astype(np.int64))
             probe = KSparseProbe(
                 weight=torch.tensor(sk_probe.coef_[0]).float(),
-                bias=torch.tensor(sk_probe.intercept_[0]).float(),
+                bias=torch.tensor(sk_probe.intercept_[0]).float(),  # type: ignore
                 feature_ids=sparse_feat_ids,
             )
             results[k][label] = probe
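
The change here only annotates the intercept_ indexing, presumably because sklearn's type annotations do not pin intercept_ down to an indexable array. A self-contained sketch of the conversion this line performs, with toy data standing in for the eval's k-sparse feature slice:

import numpy as np
import torch
from sklearn.linear_model import LogisticRegression

# Toy binary-probe data standing in for the eval's k-sparse feature slice.
X = np.random.randn(64, 4).astype(np.float32)
y = (X[:, 0] > 0).astype(np.int64)

sk_probe = LogisticRegression().fit(X, y)
weight = torch.tensor(sk_probe.coef_[0]).float()
# At runtime intercept_ is a shape-(1,) ndarray for a binary probe, so
# indexing is safe even though the annotations flag it.
bias = torch.tensor(sk_probe.intercept_[0]).float()  # type: ignore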

sae_bench/evals/mdl/main.py (+5 -3)

@@ -104,7 +104,7 @@ def calculate_dl(
     float_entropy_F = torch.zeros(num_features, device=device, dtype=torch.float32)
     bool_entropy_F = torch.zeros(num_features, device=device, dtype=torch.float32)
 
-    x_BSN = activations_store.get_buffer(config.sae_batch_size)
+    x_BSN = activations_store.get_buffer(config.sae_batch_size)[0]
     feature_activations_BsF = sae.encode(x_BSN).squeeze()
 
     if feature_activations_BsF.ndim == 2:
@@ -235,7 +235,7 @@ def check_quantised_features_reach_mse_threshold(
     mse_losses: list[torch.Tensor] = []
 
     for i in range(1):
-        x_BSN = activations_store.get_buffer(config.sae_batch_size)
+        x_BSN = activations_store.get_buffer(config.sae_batch_size)[0]
         feature_activations_BSF = sae.encode(x_BSN).squeeze()
 
         if k is not None:
@@ -337,7 +337,9 @@ def get_min_max_activations() -> tuple[torch.Tensor, torch.Tensor]:
     max_activations_1F = torch.zeros(1, num_features, device=device) + 100
 
     for _ in range(10):
-        neuron_activations_BSN = activations_store.get_buffer(config.sae_batch_size)
+        neuron_activations_BSN = activations_store.get_buffer(
+            config.sae_batch_size
+        )[0]
 
         feature_activations_BsF = sae.encode(neuron_activations_BSN).squeeze()
 
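
All three mdl hunks are the same API fix: judging by the new [0] indexing, ActivationsStore.get_buffer in sae_lens >= 5.4.0 returns a tuple whose first element is the activation tensor, where older versions returned the tensor directly. A version-tolerant wrapper (a hypothetical helper, not part of this commit) makes the shape of the change explicit:

import torch

def get_activation_buffer(store, batch_size: int) -> torch.Tensor:
    """Hypothetical helper: accept both the old (bare tensor) and the
    new (tuple) return type of ActivationsStore.get_buffer."""
    out = store.get_buffer(batch_size)
    return out[0] if isinstance(out, tuple) else out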

sae_bench/evals/sparse_probing/probe_training.py (+1 -1)

@@ -253,7 +253,7 @@ def train_probe_gpu(
     print(f"Training probe with dim: {dim}, device: {device}, dtype: {model_dtype}")
 
     probe = Probe(dim, model_dtype).to(device)
-    optimizer = torch.optim.AdamW(probe.parameters(), lr=lr)
+    optimizer = torch.optim.AdamW(probe.parameters(), lr=lr)  # type: ignore
     criterion = nn.BCEWithLogitsLoss()
 
     best_test_accuracy = 0.0
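
The final hunk only suppresses a checker warning on the optimizer construction (presumably the inferred type of lr or probe.parameters() trips the torch stubs here); the runtime behavior is the standard AdamW + BCEWithLogitsLoss setup. A self-contained sketch of one training step, with nn.Linear as a hypothetical stand-in for the eval's Probe:

import torch
import torch.nn as nn

probe = nn.Linear(16, 1)  # hypothetical stand-in for Probe(dim, model_dtype)
optimizer = torch.optim.AdamW(probe.parameters(), lr=1e-3)
criterion = nn.BCEWithLogitsLoss()

x = torch.randn(8, 16)
y = (torch.rand(8, 1) > 0.5).float()  # binary targets for BCE-with-logits

optimizer.zero_grad()
loss = criterion(probe(x), y)
loss.backward()
optimizer.step()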

Comments (0)