Skip to content

Commit d91a218

Browse files
committed
Remove threshold from state dict if we aren't using it
1 parent 53901a2 commit d91a218

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

sae_bench/custom_saes/topk_sae.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,10 @@ def load_dictionary_learning_topk_sae(
127127
}
128128

129129
if "threshold" in pt_params:
130-
key_mapping["threshold"] = "threshold"
130+
if use_threshold_at_inference:
131+
key_mapping["threshold"] = "threshold"
132+
else:
133+
del pt_params["threshold"]
131134

132135
# Create a new dictionary with renamed keys
133136
renamed_params = {key_mapping.get(k, k): v for k, v in pt_params.items()}

0 commit comments

Comments
 (0)