feat: custom loading converters for SAE.from_pretrained() and SAE.load_from_disk() #433

Open · wants to merge 3 commits into main
Conversation

@chanind commented on Feb 23, 2025

Description

This PR standardizes the interface for PretrainedSaeHuggingfaceLoader and PretrainedSaeDiskLoader to allow directly loading any custom SAE from another library as long as a correct converter param is passed to SAE.from_pretrained() or SAE.load_from_disk().

This PR also deprecates the SAE.load_from_pretrained() method in favor of SAE.load_from_disk(), to make the distinction clearer between from_pretrained(), which loads from Huggingface, and load_from_disk(), which loads a locally saved SAE.
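
For reference, here is a minimal usage sketch of the two entry points. The release and sae_id values are placeholders, and the exact return signature of from_pretrained() may vary between sae_lens versions, so treat this as an illustration rather than the definitive API:

# Loads from Huggingface; also returns the original cfg dict and an
# optional sparsity tensor when the repo provides one.
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="some-release",            # placeholder release name
    sae_id="blocks.0.hook_resid_pre",  # placeholder SAE id
    device="cpu",
)

# Loads an SAE that was previously saved locally.
sae = SAE.load_from_disk("/path/to/saved/sae", device="cpu")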

PretrainedSaeHuggingfaceLoader and PretrainedSaeDiskLoader follow these protocols:

from pathlib import Path
from typing import Any, Optional, Protocol

import torch


class PretrainedSaeHuggingfaceLoader(Protocol):
    def __call__(
        self,
        repo_id: str,
        folder_name: str,
        device: str,
        force_download: bool,
        cfg_overrides: dict[str, Any] | None,
    ) -> tuple[dict[str, Any], dict[str, torch.Tensor], Optional[torch.Tensor]]: ...


class PretrainedSaeDiskLoader(Protocol):
    def __call__(
        self,
        path: str | Path,
        device: str,
        cfg_overrides: dict[str, Any] | None,
    ) -> tuple[dict[str, Any], dict[str, torch.Tensor]]: ...

Essentially, the provided function must take either a repo_id and folder_name on Huggingface or a local path, plus optional config overrides, and return a config dict and a model state dict (the Huggingface loader also returns an optional sparsity tensor).

Our current converters don't feel general enough to work for all instances of other libraries, but we can and should write some converters for the dictionary_learning and sparsify libraries in the future.
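
As a rough illustration, a custom converter for SAE.load_from_disk() could look like the sketch below. Only the PretrainedSaeDiskLoader signature comes from this PR; the function name, the cfg.json / weights.safetensors file layout, and the converter keyword name are assumptions made for the example:

import json
from pathlib import Path
from typing import Any

import torch
from safetensors.torch import load_file


def my_custom_disk_loader(
    path: str | Path,
    device: str = "cpu",
    cfg_overrides: dict[str, Any] | None = None,
) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
    # Hypothetical on-disk layout: a cfg.json config file plus a
    # weights.safetensors state dict in the SAE directory.
    path = Path(path)
    cfg_dict: dict[str, Any] = json.loads((path / "cfg.json").read_text())
    if cfg_overrides is not None:
        cfg_dict.update(cfg_overrides)
    state_dict = load_file(path / "weights.safetensors", device=device)
    return cfg_dict, state_dict


# Assuming the converter param described above:
# sae = SAE.load_from_disk("/path/to/custom/sae", converter=my_custom_disk_loader)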

Type of change

Please delete options that are not relevant.

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

Checklist:

  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

You have tested formatting, typing and tests

  • I have run make check-ci to check format and linting. (you can run make format to format code if needed.)

@anthonyduong9 changed the title from "feat: custom loading converters for SAE.from_pretrained() and SAE.load_from_dink()" to "feat: custom loading converters for SAE.from_pretrained() and SAE.load_from_disk()" on Feb 23, 2025
@anthonyduong9 left a comment


The build is failing, BTW.

@@ -40,10 +39,9 @@

def test_get_sae_config():

I think we should rename test_get_sae_config() to match load_sae_config_from_huggingface().

@@ -25,20 +26,36 @@ pip install sae-lens

### Loading Sparse Autoencoders from Huggingface

To load a pretrained sparse autoencoder, you can use `SAE.from_pretrained()` as below. Note that we return the *original cfg dict* from the huggingface repo so that it's easy to debug older configs that are being handled when we import an SAe. We also return a sparsity tensor if it is present in the repo. For an example repo structure, see [here](https://huggingface.co/jbloom/Gemma-2b-Residual-Stream-SAEs).
To load a pretrained sparse autoencoder, you can use `SAE.from_pretrained()` as below. Note that we return the _original cfg dict_ from the huggingface repo so that it's easy to debug older configs that are being handled when we import an SAe. We also return a sparsity tensor if it is present in the repo. For an example repo structure, see [here](https://huggingface.co/jbloom/Gemma-2b-Residual-Stream-SAEs).

Suggested change
To load a pretrained sparse autoencoder, you can use `SAE.from_pretrained()` as below. Note that we return the _original cfg dict_ from the huggingface repo so that it's easy to debug older configs that are being handled when we import an SAe. We also return a sparsity tensor if it is present in the repo. For an example repo structure, see [here](https://huggingface.co/jbloom/Gemma-2b-Residual-Stream-SAEs).
To load a pretrained sparse autoencoder, you can use `SAE.from_pretrained()` as below. Note that we return the _original cfg dict_ from the huggingface repo so that it's easy to debug older configs that are being handled when we import an SAE. We also return a sparsity tensor if it is present in the repo. For an example repo structure, see [here](https://huggingface.co/jbloom/Gemma-2b-Residual-Stream-SAEs).


## Example WandB Dashboard

WandB Dashboards provide lots of useful insights while training SAE's. Here's a screenshot from one training run.

Suggested change
WandB Dashboards provide lots of useful insights while training SAE's. Here's a screenshot from one training run.
WandB Dashboards provide lots of useful insights while training SAEs. Here's a screenshot from one training run.

Comment on lines +623 to +626
conversion_loader_name = get_conversion_loader_name(release)
conversion_loader = (
    converter or NAMED_PRETRAINED_SAE_LOADERS[conversion_loader_name]
)

Suggested change
conversion_loader_name = get_conversion_loader_name(release)
conversion_loader = (
    converter or NAMED_PRETRAINED_SAE_LOADERS[conversion_loader_name]
)
conversion_loader = (
    converter
    or NAMED_PRETRAINED_SAE_LOADERS[get_conversion_loader_name(release)]
)

We can avoid a call to get_conversion_loader_name() if converter is not None.

Comment on lines +117 to +122
cfg_dict, state_dict = read_sae_components_from_disk(
    cfg_dict=cfg_dict,
    weight_path=weights_path,
    device=device,
)
return cfg_dict, state_dict

Suggested change
cfg_dict, state_dict = read_sae_components_from_disk(
    cfg_dict=cfg_dict,
    weight_path=weights_path,
    device=device,
)
return cfg_dict, state_dict
return read_sae_components_from_disk(
    cfg_dict=cfg_dict,
    weight_path=weights_path,
    device=device,
)

config_path = os.path.join(path, SAE_CFG_PATH)
with open(config_path) as f:
    cfg_dict = json.load(f)
overrides = None if dtype is None else {"dtype": dtype}

Suggested change
overrides = None if dtype is None else {"dtype": dtype}
overrides = {"dtype": dtype} if dtype is not None else None

) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
    """Loads SAEs from disk"""

    weights_path = Path(path) / "sae_weights.safetensors"

We can replace "sae_weights.safetensors" with the constant SAE_WEIGHTS_FILENAME.

    device: str | None = None,
    cfg_overrides: dict[str, Any] | None = None,
) -> dict[str, Any]:
    cfg_filename = Path(path) / "cfg.json"

We can replace "cfg.json" with the constant SAE_CFG_FILENAME.


Similar comments to #433 (comment), for tests that test load_sae_config_from_huggingface().

# Download the SAE weights
sae_path = hf_hub_download(
    repo_id=repo_id,
    filename="sae_weights.safetensors",