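"""k-sparse probing for the SAEBench absorption eval.

Trains an L1-regularized multi-probe over SAE latents to rank features for
each letter label, then, for each k, fits a k-sparse logistic-regression
probe on the top-k features. Raw per-token scores, probe metadata, and
aggregate metrics (AUC/F1/precision/recall plus feature-split annotations)
are computed against a dense activation probe and cached as parquet files.
"""
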
from collections import defaultdict
from collections.abc import Sequence
from pathlib import Path
from typing import Callable

import numpy as np
import pandas as pd
import torch
from sae_lens import SAE
from sklearn import metrics
from sklearn.linear_model import LogisticRegression
from torch import nn
from tqdm.autonotebook import tqdm
from transformer_lens import HookedTransformer

from sae_bench.evals.absorption.common import (
    PROBES_DIR,
    RESULTS_DIR,
    get_or_make_dir,
    load_df_or_run,
    load_dfs_or_run,
    load_or_train_probe,
    load_probe_data_split_or_train,
)
from sae_bench.evals.absorption.probing import LinearProbe, train_multi_probe
from sae_bench.evals.absorption.util import batchify
from sae_bench.evals.absorption.vocab import LETTERS

EPS = 1e-6
SPARSE_PROBING_EXPERIMENT_NAME = "k_sparse_probing"


class KSparseProbe(nn.Module):
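    """A logistic-regression probe restricted to k SAE features.

    Stores only the selected `feature_ids` together with their `weight`
    vector and scalar `bias`; `forward` gathers those features from the
    input activations and returns the pre-sigmoid logit.
    """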
weight: torch.Tensor # shape (k)
bias: torch.Tensor # scalar
feature_ids: torch.Tensor # shape (k)

    def __init__(
self, weight: torch.Tensor, bias: torch.Tensor, feature_ids: torch.Tensor
):
super().__init__()
self.weight = weight
self.bias = bias
self.feature_ids = feature_ids

    @property
def k(self) -> int:
return self.weight.shape[0]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
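        """Return the probe logit; `x` may be a (batch, d_sae) matrix or a single (d_sae,) vector."""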
filtered_acts = (
x[:, self.feature_ids] if len(x.shape) == 2 else x[self.feature_ids]
)
return filtered_acts @ self.weight + self.bias


def train_sparse_multi_probe(
x_train: torch.Tensor, # tensor of shape (num_samples, input_dim)
y_train: torch.Tensor, # tensor of shape (num_samples, num_probes), with values in [0, 1]
device: torch.device,
l1_decay: float = 0.01, # l1 regularization strength
num_probes: int | None = None, # inferred from y_train if None
batch_size: int = 4096,
num_epochs: int = 50,
lr: float = 0.01,
end_lr: float = 1e-5,
l2_decay: float = 1e-6,
show_progress: bool = True,
verbose: bool = False,
map_acts: Callable[[torch.Tensor], torch.Tensor] | None = None,
probe_dim: int | None = None,
) -> LinearProbe:
"""
Train a multi-probe with L1 regularization on the weights.
"""
return train_multi_probe(
x_train,
y_train,
num_probes=num_probes,
batch_size=batch_size,
num_epochs=num_epochs,
lr=lr,
end_lr=end_lr,
weight_decay=l2_decay,
show_progress=show_progress,
verbose=verbose,
device=device,
extra_loss_fn=lambda probe, _x, _y: l1_decay
* probe.weights.abs().sum(dim=-1).mean(),
map_acts=map_acts,
probe_dim=probe_dim,
)
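
# A minimal usage sketch (hypothetical `model_acts`, `one_hot_labels`, and
# `sae`; not part of the original file): train a sparse multi-probe directly
# on SAE latents by encoding activations on the fly via `map_acts`, mirroring
# the call in `train_k_sparse_probes` below.
#
#   probe = train_sparse_multi_probe(
#       x_train=model_acts,      # (num_samples, d_model)
#       y_train=one_hot_labels,  # (num_samples, num_probes)
#       device=sae.device,
#       map_acts=lambda acts: sae.encode(acts.to(sae.device, dtype=sae.dtype)),
#       probe_dim=sae.cfg.d_sae,
#   )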


def _get_sae_acts(
sae: SAE,
input_activations: torch.Tensor,
batch_size: int = 4096,
sparse_feat_ids: torch.Tensor | None = None,
) -> torch.Tensor:
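    """Encode `input_activations` through the SAE in batches, optionally
    keeping only `sparse_feat_ids`, and return the results concatenated on
    CPU."""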
batch_acts = []
for batch in batchify(input_activations, batch_size):
acts = sae.encode(batch.to(device=sae.device, dtype=sae.dtype))
if sparse_feat_ids is not None:
acts = acts[:, sparse_feat_ids]
batch_acts.append(acts.cpu())
return torch.cat(batch_acts)


def train_k_sparse_probes(
sae: SAE,
train_labels: list[tuple[str, int]], # list of (token, letter number) pairs
train_activations: torch.Tensor, # n_vocab X d_model
ks: Sequence[int],
l1_decay: float = 0.01,
batch_size: int = 4096,
num_epochs: int = 50,
) -> dict[int, dict[int, KSparseProbe]]: # dict[k, dict[letter_id, probe]]
"""
Train k-sparse probes for each k in ks.
Returns a dict of dicts, where the outer dict is indexed by k and the inner dict is the label.
"""
results: dict[int, dict[int, KSparseProbe]] = defaultdict(dict)
with torch.no_grad():
labels = {label for _, label in train_labels}
sparse_train_y = torch.nn.functional.one_hot(
torch.tensor([idx for _, idx in train_labels])
)
train_activations = train_activations.to(sae.device, dtype=sae.dtype)
l1_probe = (
train_sparse_multi_probe(
train_activations,
sparse_train_y,
l1_decay=l1_decay,
num_epochs=num_epochs,
batch_size=batch_size,
device=sae.device,
map_acts=lambda acts: sae.encode(acts.to(sae.device, dtype=sae.dtype)),
probe_dim=sae.cfg.d_sae,
)
.float()
.cpu()
)
with torch.no_grad():
train_k_y = np.array([idx for _, idx in train_labels])
with tqdm(total=len(ks) * len(labels), desc="training k-probes") as pbar:
for k in ks:
for label in labels:
# using topk and not abs() because we only want features that directly predict the label
sparse_feat_ids = l1_probe.weights[label].topk(k).indices
train_k_x = (
_get_sae_acts(
sae,
train_activations,
sparse_feat_ids=sparse_feat_ids,
batch_size=batch_size,
)
.float()
.numpy()
)
# Use SKLearn here because it's much faster than torch if the data is small
sk_probe = LogisticRegression(
max_iter=500, class_weight="balanced"
).fit(train_k_x, (train_k_y == label).astype(np.int64))
probe = KSparseProbe(
weight=torch.tensor(sk_probe.coef_[0]).float(),
bias=torch.tensor(sk_probe.intercept_[0]).float(), # type: ignore
feature_ids=sparse_feat_ids,
)
results[k][label] = probe
pbar.update(1)
return results
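
# A minimal usage sketch (hypothetical `sae`, `train_labels`, and `acts`; not
# part of the original file):
#
#   probes = train_k_sparse_probes(sae, train_labels, acts, ks=[1, 2, 3])
#   sae_acts = sae.encode(acts.to(sae.device, dtype=sae.dtype)).float().cpu()
#   logit = probes[2][0](sae_acts[0])  # k=2 probe for label 0, first token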


@torch.inference_mode()
def sae_k_sparse_metadata(
sae: SAE,
probe: LinearProbe,
k_sparse_probes: dict[int, dict[int, KSparseProbe]],
sae_name: str,
layer: int,
) -> pd.DataFrame:
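    """Build one metadata row per (letter, k): the k-sparse probe's feature
    ids, weights, and bias, plus the cosine similarity between the dense
    probe direction for that letter and each selected feature's SAE encoder
    and decoder directions."""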
norm_probe_weights = probe.weights / torch.norm(probe.weights, dim=-1, keepdim=True)
norm_W_enc = sae.W_enc / torch.norm(sae.W_enc, dim=0, keepdim=True)
norm_W_dec = sae.W_dec / torch.norm(sae.W_dec, dim=-1, keepdim=True)
probe_dec_cos = (
(
norm_probe_weights.to(dtype=norm_W_dec.dtype, device=norm_W_dec.device)
@ norm_W_dec.T
)
.cpu()
.float()
)
probe_enc_cos = (
(
norm_probe_weights.to(dtype=norm_W_enc.dtype, device=norm_W_enc.device)
@ norm_W_enc
)
.cpu()
.float()
)
    metadata: dict[str, float | str | int | np.ndarray] = {
"layer": layer,
"sae_name": sae_name,
}
rows = []
for letter_i, letter in enumerate(LETTERS):
for k, k_probes in k_sparse_probes.items():
row = {**metadata}
k_probe = k_probes[letter_i]
row["letter"] = letter
row["k"] = k
row["feats"] = k_probe.feature_ids.numpy()
row["cos_probe_sae_enc"] = probe_enc_cos[
letter_i, k_probe.feature_ids
].numpy()
row["cos_probe_sae_dec"] = probe_dec_cos[
letter_i, k_probe.feature_ids
].numpy()
row["weights"] = k_probe.weight.float().numpy()
row["bias"] = k_probe.bias.item()
rows.append(row)
return pd.DataFrame(rows)


@torch.inference_mode()
def eval_probe_and_sae_k_sparse_raw_scores(
sae: SAE,
probe: LinearProbe,
k_sparse_probes: dict[int, dict[int, KSparseProbe]],
eval_labels: list[tuple[str, int]], # list of (token, letter number) pairs
eval_activations: torch.Tensor, # n_vocab X d_model
) -> pd.DataFrame:
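    """Score each eval token with the dense probe and every k-sparse probe.

    Returns one row per token containing the per-letter dense-probe scores
    and, for each (letter, k), the k-sparse probe logit, the plain sum of the
    selected SAE activations, and the raw sparse activations themselves.
    """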
probe = probe.to("cpu")
# using a generator to avoid storing all the rows in memory
def row_generator():
for token_act, (token, answer_idx) in tqdm(
zip(eval_activations, eval_labels), total=len(eval_labels)
):
probe_scores = probe(token_act).tolist()
row: dict[str, float | str | int | np.ndarray] = {
"token": token,
"answer_letter": LETTERS[answer_idx],
}
sae_acts = (
_get_sae_acts(sae, token_act.unsqueeze(0).to(sae.device)).float().cpu()
).squeeze()
for letter_i, (letter, probe_score) in enumerate(
zip(LETTERS, probe_scores)
):
row[f"score_probe_{letter}"] = probe_score
for k, k_probes in k_sparse_probes.items():
k_probe = k_probes[letter_i]
k_probe_score = k_probe(sae_acts)
sparse_acts = sae_acts[k_probe.feature_ids]
row[f"score_sparse_sae_{letter}_k_{k}"] = k_probe_score.item()
row[f"sum_sparse_sae_{letter}_k_{k}"] = sparse_acts.sum().item()
row[f"sparse_sae_{letter}_k_{k}_acts"] = sparse_acts.numpy()
yield row
return pd.DataFrame(row_generator())


def load_and_run_eval_probe_and_sae_k_sparse_raw_scores(
sae: SAE,
model: HookedTransformer,
layer: int,
sae_name: str,
max_k_value: int,
prompt_template: str,
prompt_token_pos: int,
probes_dir: Path | str,
device: str,
verbose: bool = True,
k_sparse_probe_l1_decay: float = 0.01,
k_sparse_probe_batch_size: int = 4096,
k_sparse_probe_num_epochs: int = 50,
) -> tuple[pd.DataFrame, pd.DataFrame]:
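    """Load (or train) the dense probe and its train/test activation splits,
    train k-sparse probes for k = 1..max_k_value, and return the raw
    eval-score DataFrame together with the probe metadata DataFrame."""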
if verbose:
print("Loading probe and training data", flush=True)
probe = load_or_train_probe(
model=model,
layer=layer,
probes_dir=probes_dir,
base_template=prompt_template,
pos_idx=prompt_token_pos,
device=device,
)
train_activations, train_data = load_probe_data_split_or_train(
model,
base_template=prompt_template,
pos_idx=prompt_token_pos,
probes_dir=probes_dir,
layer=layer,
split="train",
device="cpu",
)
if verbose:
print("Training k-sparse probes", flush=True)
k_sparse_probes = train_k_sparse_probes(
sae,
train_data,
train_activations,
ks=list(range(1, max_k_value + 1)),
l1_decay=k_sparse_probe_l1_decay,
batch_size=k_sparse_probe_batch_size,
num_epochs=k_sparse_probe_num_epochs,
)
with torch.no_grad():
if verbose:
print("Loading validation data", flush=True)
eval_activations, eval_data = load_probe_data_split_or_train(
model,
base_template=prompt_template,
pos_idx=prompt_token_pos,
probes_dir=probes_dir,
layer=layer,
split="test",
device="cpu",
)
if verbose:
print("Evaluating raw k-sparse probing scores", flush=True)
df = eval_probe_and_sae_k_sparse_raw_scores(
sae,
probe,
k_sparse_probes=k_sparse_probes,
eval_labels=eval_data,
eval_activations=eval_activations,
)
if verbose:
print("Building metadata", flush=True)
metadata = sae_k_sparse_metadata(
sae,
probe,
k_sparse_probes,
sae_name=sae_name,
layer=layer,
)
return df, metadata


def build_metrics_df(
    results_df: pd.DataFrame, metadata_df: pd.DataFrame, max_k_value: int
) -> pd.DataFrame:
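    """Aggregate raw scores into one row per letter: AUC/F1/precision/recall
    for the dense probe and for each k (both the k-sparse probe logit and the
    bare sum of its sparse activations), joined with per-k probe metadata."""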
aucs = []
for letter in LETTERS:
y = (results_df["answer_letter"] == letter).values
pred_probe = results_df[f"score_probe_{letter}"].values
auc_probe = metrics.roc_auc_score(y, pred_probe)
f1_probe = metrics.f1_score(y, pred_probe > 0.0)
recall_probe = metrics.recall_score(y, pred_probe > 0.0)
precision_probe = metrics.precision_score(y, pred_probe > 0.0)
auc_info = {
"auc_probe": auc_probe,
"f1_probe": f1_probe,
"recall_probe": recall_probe,
"precision_probe": precision_probe,
"letter": letter,
"layer": metadata_df["layer"].iloc[0],
"sae_name": metadata_df["sae_name"].iloc[0],
}
for k in range(1, max_k_value + 1):
pred_sae = results_df[f"score_sparse_sae_{letter}_k_{k}"].values
auc_sae = metrics.roc_auc_score(y, pred_sae)
f1 = metrics.f1_score(y, pred_sae > 0.0)
recall = metrics.recall_score(y, pred_sae > 0.0)
precision = metrics.precision_score(y, pred_sae > 0.0)
auc_info[f"auc_sparse_sae_{k}"] = auc_sae
sum_sae_pred = results_df[f"sum_sparse_sae_{letter}_k_{k}"].values
auc_sum_sae = metrics.roc_auc_score(y, sum_sae_pred)
f1_sum_sae = metrics.f1_score(y, sum_sae_pred > EPS)
recall_sum_sae = metrics.recall_score(y, sum_sae_pred > EPS)
precision_sum_sae = metrics.precision_score(y, sum_sae_pred > EPS)
auc_info[f"f1_sparse_sae_{k}"] = f1
auc_info[f"recall_sparse_sae_{k}"] = recall
auc_info[f"precision_sparse_sae_{k}"] = precision
auc_info[f"auc_sum_sparse_sae_{k}"] = auc_sum_sae
auc_info[f"f1_sum_sparse_sae_{k}"] = f1_sum_sae
auc_info[f"recall_sum_sparse_sae_{k}"] = recall_sum_sae
auc_info[f"precision_sum_sparse_sae_{k}"] = precision_sum_sae
meta_row = metadata_df[
(metadata_df["letter"] == letter) & (metadata_df["k"] == k)
]
auc_info[f"sparse_sae_k_{k}_feats"] = meta_row["feats"].iloc[0]
auc_info[f"cos_probe_sae_enc_k_{k}"] = meta_row["cos_probe_sae_enc"].iloc[0]
auc_info[f"cos_probe_sae_dec_k_{k}"] = meta_row["cos_probe_sae_dec"].iloc[0]
auc_info[f"sparse_sae_k_{k}_weights"] = meta_row["weights"].iloc[0]
auc_info[f"sparse_sae_k_{k}_bias"] = meta_row["bias"].iloc[0]
auc_info["layer"] = meta_row["layer"].iloc[0]
auc_info["sae_name"] = meta_row["sae_name"].iloc[0]
aucs.append(auc_info)
return pd.DataFrame(aucs)


def add_feature_splits_to_metrics_df(
df: pd.DataFrame,
max_k_value: int,
f1_jump_threshold: float = 0.03,
) -> None:
"""
If a k-sparse probe has a F1 score that increases by `f1_jump_threshold` or more from the previous k-1, consider this to be feature splitting.
"""
split_feats_by_letter = {}
for letter in LETTERS:
prev_best = -100
df_letter = df[df["letter"] == letter]
for k in range(1, max_k_value + 1):
k_score = df_letter[f"f1_sparse_sae_{k}"].iloc[0] # type: ignore
k_feats = df_letter[f"sparse_sae_k_{k}_feats"].iloc[0].tolist() # type: ignore
if k_score > prev_best + f1_jump_threshold:
prev_best = k_score
split_feats_by_letter[letter] = k_feats
else:
break
df["split_feats"] = df["letter"].apply(
lambda letter: split_feats_by_letter.get(letter, [])
)
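    # The first selected feature is the letter's primary feature, so only the
    # features beyond it count as splits.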
df["num_split_features"] = df["split_feats"].apply(len) - 1


def get_sparse_probing_raw_results_filename(sae_name: str, layer: int) -> str:
    return f"layer_{layer}_{sae_name}_raw_results.parquet"


def get_sparse_probing_metadata_filename(sae_name: str, layer: int) -> str:
    return f"layer_{layer}_{sae_name}_metadata.parquet"


def get_sparse_probing_metrics_filename(sae_name: str, layer: int) -> str:
    return f"layer_{layer}_{sae_name}_metrics.parquet"


def run_k_sparse_probing_experiment(
model: HookedTransformer,
sae: SAE,
layer: int,
sae_name: str,
max_k_value: int,
prompt_template: str,
prompt_token_pos: int,
device: str,
experiment_dir: Path | str = RESULTS_DIR / SPARSE_PROBING_EXPERIMENT_NAME,
probes_dir: Path | str = PROBES_DIR,
force: bool = False,
    f1_jump_threshold: float = 0.03,
    k_sparse_probe_l1_decay: float = 0.01,
k_sparse_probe_batch_size: int = 4096,
k_sparse_probe_num_epochs: int = 50,
verbose: bool = True,
) -> pd.DataFrame:
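    """Run the full k-sparse probing pipeline for one SAE and layer, caching
    raw scores, metadata, and metrics as parquet files under
    `experiment_dir`, and return the metrics DataFrame annotated with
    feature splits."""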
task_output_dir = get_or_make_dir(experiment_dir) / sae_name
raw_results_path = task_output_dir / get_sparse_probing_raw_results_filename(
sae_name, layer
)
metadata_results_path = task_output_dir / get_sparse_probing_metadata_filename(
sae_name, layer
)
metrics_results_path = task_output_dir / get_sparse_probing_metrics_filename(
sae_name, layer
)
def get_raw_results_df():
return load_dfs_or_run(
lambda: load_and_run_eval_probe_and_sae_k_sparse_raw_scores(
sae,
model,
probes_dir=probes_dir,
verbose=verbose,
sae_name=sae_name,
layer=layer,
max_k_value=max_k_value,
prompt_template=prompt_template,
prompt_token_pos=prompt_token_pos,
k_sparse_probe_l1_decay=k_sparse_probe_l1_decay,
k_sparse_probe_batch_size=k_sparse_probe_batch_size,
k_sparse_probe_num_epochs=k_sparse_probe_num_epochs,
device=device,
),
(raw_results_path, metadata_results_path),
force=force,
)
metrics_df = load_df_or_run(
lambda: build_metrics_df(*get_raw_results_df(), max_k_value=max_k_value),
metrics_results_path,
force=force,
)
add_feature_splits_to_metrics_df(
metrics_df, max_k_value=max_k_value, f1_jump_threshold=f1_jump_threshold
)
return metrics_df
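
# A minimal end-to-end sketch (hypothetical model name, SAE release, and
# prompt settings; not part of the original file):
#
#   model = HookedTransformer.from_pretrained("pythia-70m-deduped")
#   sae = SAE.from_pretrained(
#       "pythia-70m-deduped-res-sm", "blocks.4.hook_resid_post"
#   )[0]
#   metrics_df = run_k_sparse_probing_experiment(
#       model,
#       sae,
#       layer=4,
#       sae_name="pythia-70m-L4-resid",
#       max_k_value=10,
#       prompt_template="{word} has the first letter:",
#       prompt_token_pos=-6,
#       device="cuda",
#   )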