Commit 60579ed 1 parent 0888d07 commit 60579ed Copy full SHA for 60579ed
File tree 1 file changed +3
-4
lines changed
1 file changed +3
-4
lines changed Original file line number Diff line number Diff line change @@ -77,6 +77,7 @@ def load_dictionary_learning_topk_sae(
77
77
dtype : torch .dtype ,
78
78
layer : int | None = None ,
79
79
local_dir : str = "downloaded_saes" ,
80
+ use_threshold_at_inference : bool = False ,
80
81
) -> TopKSAE :
81
82
assert "ae.pt" in filename
82
83
@@ -122,9 +123,7 @@ def load_dictionary_learning_topk_sae(
122
123
"k" : "k" ,
123
124
}
124
125
125
- use_threshold = "threshold" in pt_params
126
-
127
- if use_threshold :
126
+ if "threshold" in pt_params :
128
127
key_mapping ["threshold" ] = "threshold"
129
128
130
129
# Create a new dictionary with renamed keys
@@ -145,7 +144,7 @@ def load_dictionary_learning_topk_sae(
145
144
hook_layer = layer , # type: ignore
146
145
device = device ,
147
146
dtype = dtype ,
148
- use_threshold = use_threshold ,
147
+ use_threshold = use_threshold_at_inference ,
149
148
)
150
149
151
150
sae .load_state_dict (renamed_params )
You can’t perform that action at this time.
0 commit comments