Skip to content

Commit 60579ed

Browse files
committed
By default we don't use a threshold for custom topk SAEs
1 parent 0888d07 commit 60579ed

File tree

1 file changed: +3 additions, −4 deletions

sae_bench/custom_saes/topk_sae.py

+3 −4
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def load_dictionary_learning_topk_sae(
7777
dtype: torch.dtype,
7878
layer: int | None = None,
7979
local_dir: str = "downloaded_saes",
80+
use_threshold_at_inference: bool = False,
8081
) -> TopKSAE:
8182
assert "ae.pt" in filename
8283

@@ -122,9 +123,7 @@ def load_dictionary_learning_topk_sae(
122123
"k": "k",
123124
}
124125

125-
use_threshold = "threshold" in pt_params
126-
127-
if use_threshold:
126+
if "threshold" in pt_params:
128127
key_mapping["threshold"] = "threshold"
129128

130129
# Create a new dictionary with renamed keys
@@ -145,7 +144,7 @@ def load_dictionary_learning_topk_sae(
145144
hook_layer=layer, # type: ignore
146145
device=device,
147146
dtype=dtype,
148-
use_threshold=use_threshold,
147+
use_threshold=use_threshold_at_inference,
149148
)
150149

151150
sae.load_state_dict(renamed_params)

0 commit comments

Comments (0)