feat: custom loading converters for SAE.from_pretrained() and SAE.load_from_disk() #433

Open · wants to merge 3 commits into base: main
10 changes: 4 additions & 6 deletions benchmark/test_eval_all_loadable_saes.py
@@ -18,8 +18,7 @@
run_evaluations,
)
from sae_lens.toolkit.pretrained_sae_loaders import (
SAEConfigLoadOptions,
get_sae_config_from_hf,
load_sae_config_from_huggingface,
)
from tests.helpers import load_model_cached

@@ -40,10 +39,9 @@

def test_get_sae_config():
Collaborator:
I think we should rename test_get_sae_config() to match load_sae_config_from_huggingface().

repo_id = "jbloom/GPT2-Small-SAEs-Reformatted"
cfg = get_sae_config_from_hf(
repo_id=repo_id,
folder_name="blocks.0.hook_resid_pre",
options=SAEConfigLoadOptions(),
cfg = load_sae_config_from_huggingface(
release=repo_id,
sae_id="blocks.0.hook_resid_pre",
)
assert cfg is not None

6 changes: 2 additions & 4 deletions docs/generate_sae_table.py
@@ -9,9 +9,8 @@

from sae_lens import SAEConfig
from sae_lens.toolkit.pretrained_sae_loaders import (
SAEConfigLoadOptions,
get_sae_config,
handle_config_defaulting,
load_sae_config_from_huggingface,
)

INCLUDED_CFG = [
@@ -61,10 +60,9 @@ def generate_sae_table():
for info in tqdm(model_info["saes"]):
# can remove this by explicitly overriding config in yaml. Do this later.
sae_id = info["id"]
cfg = get_sae_config(
cfg = load_sae_config_from_huggingface(
release,
sae_id=sae_id,
options=SAEConfigLoadOptions(),
)
cfg = handle_config_defaulting(cfg)
cfg = SAEConfig.from_dict(cfg).to_dict()
28 changes: 22 additions & 6 deletions docs/index.md
@@ -1,6 +1,7 @@
<img width="1308" alt="Screenshot 2024-03-21 at 3 08 28 pm" src="https://github.com/jbloomAus/mats_sae_training/assets/69127271/209012ec-a779-4036-b4be-7b7739ea87f6">

# SAELens

[![PyPI](https://img.shields.io/pypi/v/sae-lens?color=blue)](https://pypi.org/project/sae-lens/)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
[![build](https://github.com/jbloomAus/SAELens/actions/workflows/build.yml/badge.svg)](https://github.com/jbloomAus/SAELens/actions/workflows/build.yml)
@@ -25,20 +26,36 @@ pip install sae-lens

### Loading Sparse Autoencoders from Huggingface

To load a pretrained sparse autoencoder, you can use `SAE.from_pretrained()` as below. Note that we return the *original cfg dict* from the huggingface repo so that it's easy to debug older configs that are being handled when we import an SAe. We also return a sparsity tensor if it is present in the repo. For an example repo structure, see [here](https://huggingface.co/jbloom/Gemma-2b-Residual-Stream-SAEs).
To load a pretrained sparse autoencoder, you can use `SAE.from_pretrained()` as below. Note that we return the _original cfg dict_ from the huggingface repo so that it's easy to debug older configs that are being handled when we import an SAe. We also return a sparsity tensor if it is present in the repo. For an example repo structure, see [here](https://huggingface.co/jbloom/Gemma-2b-Residual-Stream-SAEs).
Collaborator suggested change:
To load a pretrained sparse autoencoder, you can use `SAE.from_pretrained()` as below. Note that we return the _original cfg dict_ from the huggingface repo so that it's easy to debug older configs that are being handled when we import an SAe. We also return a sparsity tensor if it is present in the repo. For an example repo structure, see [here](https://huggingface.co/jbloom/Gemma-2b-Residual-Stream-SAEs).
To load a pretrained sparse autoencoder, you can use `SAE.from_pretrained()` as below. Note that we return the _original cfg dict_ from the huggingface repo so that it's easy to debug older configs that are being handled when we import an SAE. We also return a sparsity tensor if it is present in the repo. For an example repo structure, see [here](https://huggingface.co/jbloom/Gemma-2b-Residual-Stream-SAEs).


```python
from sae_lens import SAE

sae, cfg_dict, sparsity = SAE.from_pretrained(
release = "gpt2-small-res-jb", # see other options in sae_lens/pretrained_saes.yaml
sae_id = "blocks.8.hook_resid_pre", # won't always be a hook point
device = device
device = "cuda"
)
```

You can see other importable SAEs on [this page](https://jbloomaus.github.io/SAELens/sae_table/).

Any SAE on Huggingface that's trained using SAELens can also be loaded using `SAE.from_pretrained()`. In this case, `release` is the name of the Huggingface repo, and `sae_id` is the path to the SAE in the repo. You can see a list of SAEs listed on Huggingface with the [saelens tag](https://huggingface.co/models?library=saelens).
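For example (the repo name and hook point below are placeholders for illustration, not a real release):

```python
from sae_lens import SAE

# "release" is the Huggingface repo id; "sae_id" is the SAE's path within the repo.
# Both values below are hypothetical and only illustrate the call shape.
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="your-username/your-saelens-saes",
    sae_id="blocks.8.hook_resid_pre",
    device="cuda",
)
```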

### Loading Sparse Autoencoders from Disk

To load a pretrained sparse autoencoder from disk that you've trained yourself, you can use `SAE.load_from_disk()` as below.

```python
from sae_lens import SAE

sae = SAE.load_from_disk("/path/to/your/sae", device="cuda")
```

### Importing SAEs from other libraries

You can import an SAE created with another library by writing a custom `PretrainedSaeHuggingfaceLoader` or `PretrainedSaeDiskLoader` for use with `SAE.from_pretrained()` or `SAE.load_from_disk()`, respectively. See the [pretrained_sae_loaders.py](https://github.com/jbloomAus/SAELens/blob/main/sae_lens/toolkit/pretrained_sae_loaders.py) file for more details, or ask on the [Open Source Mechanistic Interpretability Slack](https://join.slack.com/t/opensourcemechanistic/shared_invite/zt-2k0id7mv8-CsIgPLmmHd03RPJmLUcapw). If you write a good custom loader for another library, please consider contributing it back to SAELens!
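As a rough sketch of what such a converter could look like, here is a hypothetical `PretrainedSaeDiskLoader` for a library that stores a `config.json` plus a `weights.safetensors` file with its own parameter names. The file layout, the other library's key names, and `my_library_disk_loader` are assumptions for illustration; the converter just needs to return a `(cfg_dict, state_dict)` pair in the format SAELens expects.

```python
import json
from pathlib import Path
from typing import Any

import torch
from safetensors.torch import load_file

from sae_lens import SAE


def my_library_disk_loader(
    path: str | Path,
    device: str = "cpu",
    cfg_overrides: dict[str, Any] | None = None,
) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
    """Hypothetical converter: reads another library's on-disk format and
    returns (cfg_dict, state_dict) in the layout SAELens expects."""
    path = Path(path)
    # Assumed file layout for the other library: config.json + weights.safetensors
    cfg_dict = json.loads((path / "config.json").read_text())
    if cfg_overrides is not None:
        cfg_dict.update(cfg_overrides)
    raw = load_file(str(path / "weights.safetensors"), device=device)
    # Assumed parameter names in the other library; rename them to SAELens conventions.
    state_dict = {
        "W_enc": raw["encoder.weight"].T,
        "b_enc": raw["encoder.bias"],
        "W_dec": raw["decoder.weight"].T,
        "b_dec": raw["decoder.bias"],
    }
    return cfg_dict, state_dict


sae = SAE.load_from_disk(
    "/path/to/your/sae", device="cuda", converter=my_library_disk_loader
)
```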

### Background and further Readings

We highly recommend this [tutorial](https://www.lesswrong.com/posts/LnHowHgmrMbWtpkxx/intro-to-superposition-and-sparse-autoencoders-colab).
@@ -50,13 +67,12 @@ For recent progress in SAEs, we recommend the LessWrong forum's [Sparse Autoenco
I wrote a tutorial to show users how to do some basic exploration of their SAE:

- Loading and Analysing Pre-Trained Sparse Autoencoders [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb)
- Understanding SAE Features with the Logit Lens [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/logits_lens_with_features.ipynb)
- Training a Sparse Autoencoder [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)

- Understanding SAE Features with the Logit Lens [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/logits_lens_with_features.ipynb)
- Training a Sparse Autoencoder [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)

## Example WandB Dashboard

WandB Dashboards provide lots of useful insights while training SAE's. Here's a screenshot from one training run.
WandB Dashboards provide lots of useful insights while training SAE's. Here's a screenshot from one training run.
Collaborator suggested change:
WandB Dashboards provide lots of useful insights while training SAE's. Here's a screenshot from one training run.
WandB Dashboards provide lots of useful insights while training SAEs. Here's a screenshot from one training run.


![screenshot](dashboard_screenshot.png)

6 changes: 6 additions & 0 deletions sae_lens/__init__.py
@@ -16,6 +16,10 @@
from .pretokenize_runner import PretokenizeRunner, pretokenize_runner
from .sae import SAE, SAEConfig
from .sae_training_runner import SAETrainingRunner
from .toolkit.pretrained_sae_loaders import (
PretrainedSaeDiskLoader,
PretrainedSaeHuggingfaceLoader,
)
from .training.activations_store import ActivationsStore
from .training.training_sae import TrainingSAE, TrainingSAEConfig
from .training.upload_saes_to_huggingface import upload_saes_to_huggingface
@@ -36,4 +40,6 @@
"pretokenize_runner",
"run_evals",
"upload_saes_to_huggingface",
"PretrainedSaeHuggingfaceLoader",
"PretrainedSaeDiskLoader",
]
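A small sketch of how these new top-level exports might be used, assuming this diff is merged; the `my_disk_loader` stub below is hypothetical:

```python
from typing import Any

import torch

from sae_lens import PretrainedSaeDiskLoader


def my_disk_loader(
    path: str, device: str = "cpu", cfg_overrides: dict[str, Any] | None = None
) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
    # Hypothetical stub: a real converter would read a config and weights from `path`.
    raise NotImplementedError


# The exported type can be used to annotate custom converters.
loader: PretrainedSaeDiskLoader = my_disk_loader
```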
59 changes: 33 additions & 26 deletions sae_lens/sae.py
@@ -3,7 +3,6 @@
"""

import json
import os
import warnings
from contextlib import contextmanager
from dataclasses import dataclass, field
@@ -16,17 +15,22 @@
from safetensors.torch import save_file
from torch import nn
from transformer_lens.hook_points import HookedRootModule, HookPoint
from typing_extensions import deprecated

from sae_lens.config import DTYPE_MAP
from sae_lens.toolkit.pretrained_sae_loaders import (
NAMED_PRETRAINED_SAE_LOADERS,
PretrainedSaeDiskLoader,
PretrainedSaeHuggingfaceLoader,
get_conversion_loader_name,
handle_config_defaulting,
read_sae_from_disk,
sae_lens_disk_loader,
)
from sae_lens.toolkit.pretrained_saes_directory import (
get_config_overrides,
get_norm_scaling_factor,
get_pretrained_saes_directory,
get_repo_id_and_folder_name,
)

SPARSITY_FILENAME = "sparsity.safetensors"
@@ -532,31 +536,29 @@ def process_state_dict_for_loading(self, state_dict: dict[str, Any]) -> None:
pass

@classmethod
@deprecated("Use load_from_disk instead")
def load_from_pretrained(
cls, path: str, device: str = "cpu", dtype: str | None = None
) -> "SAE":
# get the config
config_path = os.path.join(path, SAE_CFG_FILENAME)
with open(config_path) as f:
cfg_dict = json.load(f)
cfg_dict = handle_config_defaulting(cfg_dict)
cfg_dict["device"] = device
sae = cls.load_from_disk(path, device)
if dtype is not None:
cfg_dict["dtype"] = dtype

weight_path = os.path.join(path, SAE_WEIGHTS_FILENAME)
cfg_dict, state_dict = read_sae_from_disk(
cfg_dict=cfg_dict,
weight_path=weight_path,
device=device,
)
sae.cfg.dtype = dtype
sae = sae.to(dtype)
return sae

@classmethod
def load_from_disk(
cls,
path: str,
device: str = "cpu",
converter: PretrainedSaeDiskLoader = sae_lens_disk_loader,
) -> "SAE":
cfg_dict, state_dict = converter(path, device, cfg_overrides=None)
cfg_dict = handle_config_defaulting(cfg_dict)
sae_cfg = SAEConfig.from_dict(cfg_dict)

sae = cls(sae_cfg)
sae.process_state_dict_for_loading(state_dict)
sae.load_state_dict(state_dict)

return sae

@classmethod
@@ -565,9 +567,10 @@ def from_pretrained(
release: str,
sae_id: str,
device: str = "cpu",
force_download: bool = False,
converter: PretrainedSaeHuggingfaceLoader | None = None,
) -> Tuple["SAE", dict[str, Any], Optional[torch.Tensor]]:
"""

Load a pretrained SAE from the Hugging Face model hub.

Args:
@@ -616,19 +619,23 @@
f"ID {sae_id} not found in release {release}. Valid IDs are {str_valid_ids}."
+ value_suffix
)
sae_info = sae_directory.get(release, None)
config_overrides = sae_info.config_overrides if sae_info is not None else None

conversion_loader_name = get_conversion_loader_name(sae_info)
conversion_loader = NAMED_PRETRAINED_SAE_LOADERS[conversion_loader_name]
conversion_loader_name = get_conversion_loader_name(release)
conversion_loader = (
converter or NAMED_PRETRAINED_SAE_LOADERS[conversion_loader_name]
)
Comment on lines +623 to +626
Collaborator suggested change:
conversion_loader_name = get_conversion_loader_name(release)
conversion_loader = (
converter or NAMED_PRETRAINED_SAE_LOADERS[conversion_loader_name]
)
conversion_loader = (
converter
or NAMED_PRETRAINED_SAE_LOADERS[get_conversion_loader_name(release)]
)

We can avoid a call to get_conversion_loader_name() if converter is not None.

repo_id, folder_name = get_repo_id_and_folder_name(release, sae_id)
config_overrides = get_config_overrides(release, sae_id)
config_overrides["device"] = device

cfg_dict, state_dict, log_sparsities = conversion_loader(
release,
sae_id=sae_id,
repo_id=repo_id,
folder_name=folder_name,
device=device,
force_download=False,
force_download=force_download,
cfg_overrides=config_overrides,
)
cfg_dict = handle_config_defaulting(cfg_dict)

sae = cls(SAEConfig.from_dict(cfg_dict))
sae.process_state_dict_for_loading(state_dict)
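Based on how the updated `from_pretrained()` invokes its converter above, a custom `PretrainedSaeHuggingfaceLoader` might look roughly like the sketch below. The repo layout, file names, and `my_library_hf_loader` are assumptions for illustration; the converter returns a `(cfg_dict, state_dict, log_sparsities)` triple.

```python
import json
from typing import Any

import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from sae_lens import SAE


def my_library_hf_loader(
    repo_id: str,
    folder_name: str,
    device: str = "cpu",
    force_download: bool = False,
    cfg_overrides: dict[str, Any] | None = None,
) -> tuple[dict[str, Any], dict[str, torch.Tensor], torch.Tensor | None]:
    """Hypothetical converter for SAEs stored in another library's HF layout."""
    # Assumed repo layout: <folder_name>/config.json and <folder_name>/sae.safetensors
    cfg_path = hf_hub_download(
        repo_id, f"{folder_name}/config.json", force_download=force_download
    )
    weights_path = hf_hub_download(
        repo_id, f"{folder_name}/sae.safetensors", force_download=force_download
    )
    with open(cfg_path) as f:
        cfg_dict = json.load(f)
    if cfg_overrides is not None:
        cfg_dict.update(cfg_overrides)
    state_dict = load_file(weights_path, device=device)
    return cfg_dict, state_dict, None  # no per-feature sparsity tensor in this sketch


sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="your-username/your-sae-repo",  # hypothetical Huggingface repo
    sae_id="blocks.8.hook_resid_pre",       # hypothetical path within the repo
    device="cuda",
    converter=my_library_hf_loader,
)
```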