# feat: custom loading converters for `SAE.from_pretrained()` and `SAE.load_from_disk()` (#433)
base: main
**README.md**

@@ -1,6 +1,7 @@

<img width="1308" alt="Screenshot 2024-03-21 at 3 08 28 pm" src="https://github.com/jbloomAus/mats_sae_training/assets/69127271/209012ec-a779-4036-b4be-7b7739ea87f6">

# SAELens

[](https://pypi.org/project/sae-lens/)
[](https://opensource.org/licenses/MIT)
[](https://github.com/jbloomAus/SAELens/actions/workflows/build.yml)
@@ -25,20 +26,36 @@ pip install sae-lens

### Loading Sparse Autoencoders from Huggingface

```diff
-To load a pretrained sparse autoencoder, you can use `SAE.from_pretrained()` as below. Note that we return the *original cfg dict* from the huggingface repo so that it's easy to debug older configs that are being handled when we import an SAE. We also return a sparsity tensor if it is present in the repo. For an example repo structure, see [here](https://huggingface.co/jbloom/Gemma-2b-Residual-Stream-SAEs).
+To load a pretrained sparse autoencoder, you can use `SAE.from_pretrained()` as below. Note that we return the _original cfg dict_ from the huggingface repo so that it's easy to debug older configs that are being handled when we import an SAE. We also return a sparsity tensor if it is present in the repo. For an example repo structure, see [here](https://huggingface.co/jbloom/Gemma-2b-Residual-Stream-SAEs).
```
```diff
 from sae_lens import SAE

 sae, cfg_dict, sparsity = SAE.from_pretrained(
     release = "gpt2-small-res-jb", # see other options in sae_lens/pretrained_saes.yaml
     sae_id = "blocks.8.hook_resid_pre", # won't always be a hook point
-    device = device
+    device = "cuda"
 )
```
You can see other importable SAEs on [this page](https://jbloomaus.github.io/SAELens/sae_table/).

Any SAE on Huggingface that's trained using SAELens can also be loaded using `SAE.from_pretrained()`. In this case, `release` is the name of the Huggingface repo, and `sae_id` is the path to the SAE within the repo. You can see a list of SAEs on Huggingface under the [saelens tag](https://huggingface.co/models?library=saelens).
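For example, a minimal sketch (the repo name and SAE path here are hypothetical placeholders):

```python
from sae_lens import SAE

sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="your-username/your-saelens-sae",  # name of the Huggingface repo
    sae_id="blocks.0.hook_resid_pre",          # path to the SAE within the repo
    device="cuda",
)
```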
### Loading Sparse Autoencoders from Disk

To load a pretrained sparse autoencoder from disk that you've trained yourself, you can use `SAE.load_from_disk()` as below.

```python
from sae_lens import SAE

sae = SAE.load_from_disk("/path/to/your/sae", device="cuda")
```
### Importing SAEs from other libraries

You can import an SAE created with another library by writing a custom `PretrainedSaeHuggingfaceLoader` or `PretrainedSaeDiskLoader` for use with `SAE.from_pretrained()` or `SAE.load_from_disk()`, respectively. See the [pretrained_sae_loaders.py](https://github.com/jbloomAus/SAELens/blob/main/sae_lens/toolkit/pretrained_sae_loaders.py) file for more details, or ask on the [Open Source Mechanistic Interpretability Slack](https://join.slack.com/t/opensourcemechanistic/shared_invite/zt-2k0id7mv8-CsIgPLmmHd03RPJmLUcapw). If you write a good custom loader for another library, please consider contributing it back to SAELens!
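As a rough illustration, a custom disk loader might look like the sketch below. It assumes the `PretrainedSaeDiskLoader` calling convention used by `SAE.load_from_disk()` in this PR, `converter(path, device, cfg_overrides)` returning `(cfg_dict, state_dict)`; the file names and source weight keys are hypothetical stand-ins for another library's format.

```python
import json
import os
from typing import Any

import torch
from safetensors.torch import load_file

from sae_lens import SAE


def my_library_disk_loader(
    path: str,
    device: str = "cpu",
    cfg_overrides: dict[str, Any] | None = None,
) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
    # Read the other library's config (hypothetical filename) and apply overrides.
    with open(os.path.join(path, "config.json")) as f:
        cfg_dict = json.load(f)
    if cfg_overrides is not None:
        cfg_dict.update(cfg_overrides)

    # Load the weights and remap the (hypothetical) source key names
    # onto SAELens parameter names.
    weights = load_file(os.path.join(path, "weights.safetensors"), device=device)
    state_dict = {
        "W_enc": weights["encoder.weight"].T,
        "b_enc": weights["encoder.bias"],
        "W_dec": weights["decoder.weight"].T,
        "b_dec": weights["decoder.bias"],
    }
    return cfg_dict, state_dict


sae = SAE.load_from_disk("/path/to/your/sae", device="cuda", converter=my_library_disk_loader)
```

Note that the returned `cfg_dict` must contain the fields SAELens expects, since it is passed through `handle_config_defaulting` and `SAEConfig.from_dict` on the way in.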
### Background and Further Reading

We highly recommend this [tutorial](https://www.lesswrong.com/posts/LnHowHgmrMbWtpkxx/intro-to-superposition-and-sparse-autoencoders-colab).
@@ -50,13 +67,12 @@ For recent progress in SAEs, we recommend the LessWrong forum's [Sparse Autoenco

I wrote a tutorial to show users how to do some basic exploration of their SAE:

```diff
 - Loading and Analysing Pre-Trained Sparse Autoencoders [](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb)
-- Understanding SAE Features with the Logit Lens [](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/logits_lens_with_features.ipynb)
-- Training a Sparse Autoencoder [](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)
+- Understanding SAE Features with the Logit Lens [](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/logits_lens_with_features.ipynb)
+- Training a Sparse Autoencoder [](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)
```
## Example WandB Dashboard

WandB Dashboards provide lots of useful insights while training SAEs. Here's a screenshot from one training run.


**sae_lens/sae.py**
```diff
@@ -3,7 +3,6 @@
 """

 import json
 import os
 import warnings
 from contextlib import contextmanager
 from dataclasses import dataclass, field
```
```diff
@@ -16,17 +15,22 @@
 from safetensors.torch import save_file
 from torch import nn
 from transformer_lens.hook_points import HookedRootModule, HookPoint
+from typing_extensions import deprecated

 from sae_lens.config import DTYPE_MAP
 from sae_lens.toolkit.pretrained_sae_loaders import (
     NAMED_PRETRAINED_SAE_LOADERS,
+    PretrainedSaeDiskLoader,
+    PretrainedSaeHuggingfaceLoader,
     get_conversion_loader_name,
     handle_config_defaulting,
     read_sae_from_disk,
+    sae_lens_disk_loader,
 )
 from sae_lens.toolkit.pretrained_saes_directory import (
+    get_config_overrides,
     get_norm_scaling_factor,
     get_pretrained_saes_directory,
+    get_repo_id_and_folder_name,
 )

 SPARSITY_FILENAME = "sparsity.safetensors"
```
```diff
@@ -532,31 +536,29 @@ def process_state_dict_for_loading(self, state_dict: dict[str, Any]) -> None:
         pass

     @classmethod
+    @deprecated("Use load_from_disk instead")
     def load_from_pretrained(
         cls, path: str, device: str = "cpu", dtype: str | None = None
     ) -> "SAE":
-        # get the config
-        config_path = os.path.join(path, SAE_CFG_FILENAME)
-        with open(config_path) as f:
-            cfg_dict = json.load(f)
-        cfg_dict = handle_config_defaulting(cfg_dict)
-        cfg_dict["device"] = device
+        sae = cls.load_from_disk(path, device)
         if dtype is not None:
-            cfg_dict["dtype"] = dtype
-
-        weight_path = os.path.join(path, SAE_WEIGHTS_FILENAME)
-        cfg_dict, state_dict = read_sae_from_disk(
-            cfg_dict=cfg_dict,
-            weight_path=weight_path,
-            device=device,
-        )
+            sae.cfg.dtype = dtype
+            sae = sae.to(dtype)
         return sae
```
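For callers migrating off the deprecated method, the replacement is a direct swap (a sketch; the path is a placeholder):

```python
from sae_lens import SAE

# Before (now deprecated):
sae = SAE.load_from_pretrained("/path/to/your/sae", device="cpu")

# After:
sae = SAE.load_from_disk("/path/to/your/sae", device="cpu")
```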
```diff
+    @classmethod
+    def load_from_disk(
+        cls,
+        path: str,
+        device: str = "cpu",
+        converter: PretrainedSaeDiskLoader = sae_lens_disk_loader,
+    ) -> "SAE":
+        cfg_dict, state_dict = converter(path, device, cfg_overrides=None)
+        cfg_dict = handle_config_defaulting(cfg_dict)
+        sae_cfg = SAEConfig.from_dict(cfg_dict)
+
+        sae = cls(sae_cfg)
+        sae.process_state_dict_for_loading(state_dict)
+        sae.load_state_dict(state_dict)
+
+        return sae

     @classmethod
```
```diff
@@ -565,9 +567,10 @@ def from_pretrained(
         release: str,
         sae_id: str,
         device: str = "cpu",
+        force_download: bool = False,
+        converter: PretrainedSaeHuggingfaceLoader | None = None,
     ) -> Tuple["SAE", dict[str, Any], Optional[torch.Tensor]]:
         """

         Load a pretrained SAE from the Hugging Face model hub.

         Args:
```
```diff
@@ -616,19 +619,23 @@ def from_pretrained(
             f"ID {sae_id} not found in release {release}. Valid IDs are {str_valid_ids}."
             + value_suffix
         )
-        sae_info = sae_directory.get(release, None)
-        config_overrides = sae_info.config_overrides if sae_info is not None else None
-
-        conversion_loader_name = get_conversion_loader_name(sae_info)
-        conversion_loader = NAMED_PRETRAINED_SAE_LOADERS[conversion_loader_name]
+        conversion_loader_name = get_conversion_loader_name(release)
+        conversion_loader = (
+            converter or NAMED_PRETRAINED_SAE_LOADERS[conversion_loader_name]
+        )
```
> **Review comment on lines +623 to +626** (suggested change): We can avoid a call to …
```diff
+        repo_id, folder_name = get_repo_id_and_folder_name(release, sae_id)
+        config_overrides = get_config_overrides(release, sae_id)
+        config_overrides["device"] = device

         cfg_dict, state_dict, log_sparsities = conversion_loader(
-            release,
-            sae_id=sae_id,
+            repo_id=repo_id,
+            folder_name=folder_name,
             device=device,
-            force_download=False,
+            force_download=force_download,
             cfg_overrides=config_overrides,
         )
         cfg_dict = handle_config_defaulting(cfg_dict)

         sae = cls(SAEConfig.from_dict(cfg_dict))
         sae.process_state_dict_for_loading(state_dict)
```
> **Review comment:** I think we should rename `test_get_sae_config()` to match `load_sae_config_from_huggingface()`.
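To illustrate the new `converter` hook on `SAE.from_pretrained()`, here is a minimal sketch of a custom `PretrainedSaeHuggingfaceLoader` matching the call site above: it is invoked with `repo_id`, `folder_name`, `device`, `force_download`, and `cfg_overrides`, and returns `(cfg_dict, state_dict, log_sparsities)`. The release name, repo layout, and file names below are hypothetical.

```python
import json
from typing import Any

import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from sae_lens import SAE


def my_library_huggingface_loader(
    repo_id: str,
    folder_name: str,
    device: str = "cpu",
    force_download: bool = False,
    cfg_overrides: dict[str, Any] | None = None,
) -> tuple[dict[str, Any], dict[str, torch.Tensor], torch.Tensor | None]:
    # Fetch and parse the other library's config (hypothetical filename).
    cfg_path = hf_hub_download(
        repo_id, f"{folder_name}/config.json", force_download=force_download
    )
    with open(cfg_path) as f:
        cfg_dict = json.load(f)
    if cfg_overrides is not None:
        cfg_dict.update(cfg_overrides)

    # Fetch the weights; remap key names to SAELens conventions here
    # if the source format differs (omitted in this sketch).
    weights_path = hf_hub_download(
        repo_id, f"{folder_name}/sae_weights.safetensors", force_download=force_download
    )
    state_dict = load_file(weights_path, device=device)

    # This format carries no sparsity tensor, so return None for log_sparsities.
    return cfg_dict, state_dict, None


sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="some-org/some-sae-repo",  # hypothetical release / repo id
    sae_id="layer_0",                  # hypothetical folder within the repo
    device="cuda",
    converter=my_library_huggingface_loader,
)
```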