fix: fix typing and update mdl for saelens >=5.4.0 #56

Merged: 1 commit, Feb 15, 2025
pyproject.toml (2 changes: 1 addition & 1 deletion)

```diff
@@ -14,7 +14,7 @@ classifiers = ["Topic :: Scientific/Engineering :: Artificial Intelligence"]

 [tool.poetry.dependencies]
 python = "^3.10,"
-sae_lens = ">=4.4.2"
+sae_lens = ">=5.4.0"
 transformer-lens = ">=2.0.0"
 torch = ">=2.1.0"
 einops = ">=0.8.0"
```
sae_bench/custom_saes/pca_sae.py (8 changes: 4 additions & 4 deletions)

```diff
@@ -48,7 +48,7 @@ def save_state_dict(self, file_path: str):
         """Save the encoder and decoder to a file."""
         torch.save(
             {
-                "W_enc": self.W_enc.data,
+                "W_enc": self.W_enc.data,  # type: ignore
                 "W_dec": self.W_dec.data,
                 "mean": self.mean.data,
             },
@@ -58,7 +58,7 @@ def save_state_dict(self, file_path: str):
     def load_from_file(self, file_path: str):
         """Load the encoder and decoder from a file."""
         state_dict = torch.load(file_path, map_location=self.device)
-        self.W_enc.data = state_dict["W_enc"]
+        self.W_enc.data = state_dict["W_enc"]  # type: ignore
         self.W_dec.data = state_dict["W_dec"]
         self.mean.data = state_dict["mean"]
         self.normalize_decoder()
@@ -137,7 +137,7 @@ def fit_PCA(

     # Set the learned components
     pca.mean.data = torch.tensor(ipca.mean_, dtype=torch.float32, device="cpu")
-    pca.W_enc.data = torch.tensor(ipca.components_, dtype=torch.float32, device="cpu")
+    pca.W_enc.data = torch.tensor(ipca.components_, dtype=torch.float32, device="cpu")  # type: ignore
     pca.W_dec.data = torch.tensor(ipca.components_.T, dtype=torch.float32, device="cpu")  # type: ignore

     pca.save_state_dict(f"pca_{pca.cfg.model_name}_{pca.cfg.hook_name}.pt")
@@ -215,7 +215,7 @@ def fit_PCA_gpu(

     # Set the learned components
     pca.mean.data = pca_mean.to(dtype=torch.float32, device="cpu")
-    pca.W_enc.data = components.float().to(dtype=torch.float32, device="cpu")
+    pca.W_enc.data = components.float().to(dtype=torch.float32, device="cpu")  # type: ignore
     pca.W_dec.data = components.T.float().to(dtype=torch.float32, device="cpu")

     pca.save_state_dict(f"pca_{pca.cfg.model_name}_{pca.cfg.hook_name}.pt")
```
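All four edits in this file add the same per-line suppression where `W_enc.data` is read or assigned, which apparently no longer type-checks cleanly against the stricter annotations in `sae_lens >= 5.4.0`. The sketch below only illustrates the runtime pattern being suppressed; `TinySAE` is a hypothetical stand-in, not the library's class:

```python
import torch
import torch.nn as nn


class TinySAE(nn.Module):
    """Hypothetical minimal SAE holding an encoder weight."""

    def __init__(self, d_in: int, d_sae: int):
        super().__init__()
        self.W_enc = nn.Parameter(torch.zeros(d_in, d_sae))


sae = TinySAE(4, 8)
# Swapping the underlying storage via .data is valid at runtime; the
# per-line ignores in the diff silence a checker that cannot reconcile
# this access with how the sae_lens base class declares W_enc.
sae.W_enc.data = torch.randn(4, 8)
print(sae.W_enc.shape)  # torch.Size([4, 8])
```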
sae_bench/evals/absorption/k_sparse_probing.py (2 changes: 1 addition & 1 deletion)

```diff
@@ -164,7 +164,7 @@ def train_k_sparse_probes(
             ).fit(train_k_x, (train_k_y == label).astype(np.int64))
             probe = KSparseProbe(
                 weight=torch.tensor(sk_probe.coef_[0]).float(),
-                bias=torch.tensor(sk_probe.intercept_[0]).float(),
+                bias=torch.tensor(sk_probe.intercept_[0]).float(),  # type: ignore
                 feature_ids=sparse_feat_ids,
             )
             results[k][label] = probe
```
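One plausible reading of this lone suppression: scikit-learn's `intercept_` is an ndarray at runtime, but static analysis can infer a scalar-or-array union for it, making the `[0]` index a reported error. A self-contained sketch under that assumption (the `LogisticRegression` setup is illustrative, not the eval's actual probe configuration):

```python
import numpy as np
import torch
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)
X = rng.normal(size=(32, 4))
y = (rng.random(32) > 0.5).astype(np.int64)

sk_probe = LogisticRegression().fit(X, y)

# intercept_ is an ndarray of shape (1,) here, so [0] is safe at runtime;
# the ignore only placates a checker that models it as possibly scalar.
bias = torch.tensor(sk_probe.intercept_[0]).float()  # type: ignore
print(bias)
```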
sae_bench/evals/mdl/main.py (8 changes: 5 additions & 3 deletions)

```diff
@@ -104,7 +104,7 @@ def calculate_dl(
     float_entropy_F = torch.zeros(num_features, device=device, dtype=torch.float32)
     bool_entropy_F = torch.zeros(num_features, device=device, dtype=torch.float32)

-    x_BSN = activations_store.get_buffer(config.sae_batch_size)
+    x_BSN = activations_store.get_buffer(config.sae_batch_size)[0]
     feature_activations_BsF = sae.encode(x_BSN).squeeze()

     if feature_activations_BsF.ndim == 2:
@@ -235,7 +235,7 @@ def check_quantised_features_reach_mse_threshold(
     mse_losses: list[torch.Tensor] = []

     for i in range(1):
-        x_BSN = activations_store.get_buffer(config.sae_batch_size)
+        x_BSN = activations_store.get_buffer(config.sae_batch_size)[0]
         feature_activations_BSF = sae.encode(x_BSN).squeeze()

         if k is not None:
@@ -337,7 +337,9 @@ def get_min_max_activations() -> tuple[torch.Tensor, torch.Tensor]:
     max_activations_1F = torch.zeros(1, num_features, device=device) + 100

     for _ in range(10):
-        neuron_activations_BSN = activations_store.get_buffer(config.sae_batch_size)
+        neuron_activations_BSN = activations_store.get_buffer(
+            config.sae_batch_size
+        )[0]

         feature_activations_BsF = sae.encode(neuron_activations_BSN).squeeze()
```
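These are the only behavioral (non-typing) changes in the PR: every `get_buffer` call now takes `[0]`, which indicates that in `sae_lens >= 5.4.0` `ActivationsStore.get_buffer` returns a tuple whose first element is the activation tensor, rather than returning the tensor directly. A sketch of the calling pattern, with `FakeActivationsStore` standing in for the real class (the shapes and the second tuple element are illustrative assumptions):

```python
import torch


class FakeActivationsStore:
    """Hypothetical stand-in for sae_lens's ActivationsStore."""

    def get_buffer(self, n: int) -> tuple[torch.Tensor, torch.Tensor | None]:
        acts = torch.randn(n, 128, 512)  # [batch, seq, d_model]
        return acts, None  # second element is unused by the MDL eval


store = FakeActivationsStore()
# Matches the diff: index [0] to recover just the activations.
x_BSN = store.get_buffer(8)[0]
assert x_BSN.shape == (8, 128, 512)
```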
sae_bench/evals/sparse_probing/probe_training.py (2 changes: 1 addition & 1 deletion)

```diff
@@ -253,7 +253,7 @@ def train_probe_gpu(
     print(f"Training probe with dim: {dim}, device: {device}, dtype: {model_dtype}")

     probe = Probe(dim, model_dtype).to(device)
-    optimizer = torch.optim.AdamW(probe.parameters(), lr=lr)
+    optimizer = torch.optim.AdamW(probe.parameters(), lr=lr)  # type: ignore
     criterion = nn.BCEWithLogitsLoss()

     best_test_accuracy = 0.0
```