feat: custom loading converters for SAE.from_pretrained() and SAE.load_from_disk() #433
base: main
Conversation
The build is failing, BTW.
@@ -40,10 +39,9 @@

def test_get_sae_config():
I think we should rename `test_get_sae_config()` to match `load_sae_config_from_huggingface()`.
@@ -25,20 +26,36 @@ pip install sae-lens

### Loading Sparse Autoencoders from Huggingface

To load a pretrained sparse autoencoder, you can use `SAE.from_pretrained()` as below. Note that we return the *original cfg dict* from the huggingface repo so that it's easy to debug older configs that are being handled when we import an SAe. We also return a sparsity tensor if it is present in the repo. For an example repo structure, see [here](https://huggingface.co/jbloom/Gemma-2b-Residual-Stream-SAEs).
To load a pretrained sparse autoencoder, you can use `SAE.from_pretrained()` as below. Note that we return the _original cfg dict_ from the huggingface repo so that it's easy to debug older configs that are being handled when we import an SAe. We also return a sparsity tensor if it is present in the repo. For an example repo structure, see [here](https://huggingface.co/jbloom/Gemma-2b-Residual-Stream-SAEs).
Suggested change ("SAe" → "SAE"):
To load a pretrained sparse autoencoder, you can use `SAE.from_pretrained()` as below. Note that we return the _original cfg dict_ from the huggingface repo so that it's easy to debug older configs that are being handled when we import an SAE. We also return a sparsity tensor if it is present in the repo. For an example repo structure, see [here](https://huggingface.co/jbloom/Gemma-2b-Residual-Stream-SAEs).
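For reference, a minimal usage sketch of the loading call described in this docs passage. The release and SAE id strings are illustrative, and the exact return signature should be checked against the installed SAELens version:

```python
from sae_lens import SAE

# Illustrative release/id values; any entry from the pretrained SAE directory works.
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gpt2-small-res-jb",       # assumed example release name
    sae_id="blocks.8.hook_resid_pre",  # assumed example SAE id
    device="cpu",
)
# cfg_dict is the original cfg dict from the huggingface repo;
# sparsity is None if the repo has no sparsity tensor.
```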
## Example WandB Dashboard

WandB Dashboards provide lots of useful insights while training SAE's. Here's a screenshot from one training run.
Suggested change ("SAE's" → "SAEs"):
WandB Dashboards provide lots of useful insights while training SAEs. Here's a screenshot from one training run.
conversion_loader_name = get_conversion_loader_name(release)
conversion_loader = (
    converter or NAMED_PRETRAINED_SAE_LOADERS[conversion_loader_name]
)
Suggested change:
conversion_loader = (
    converter
    or NAMED_PRETRAINED_SAE_LOADERS[get_conversion_loader_name(release)]
)
We can avoid a call to `get_conversion_loader_name()` if `converter` is not None.
cfg_dict, state_dict = read_sae_components_from_disk(
    cfg_dict=cfg_dict,
    weight_path=weights_path,
    device=device,
)
return cfg_dict, state_dict
Suggested change:
return read_sae_components_from_disk(
    cfg_dict=cfg_dict,
    weight_path=weights_path,
    device=device,
)
config_path = os.path.join(path, SAE_CFG_PATH)
with open(config_path) as f:
    cfg_dict = json.load(f)
overrides = None if dtype is None else {"dtype": dtype}
Suggested change:
overrides = {"dtype": dtype} if dtype is not None else None
) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
    """Loads SAEs from disk"""

    weights_path = Path(path) / "sae_weights.safetensors"
We can replace `"sae_weights.safetensors"` with the constant `SAE_WEIGHTS_FILENAME`.
    device: str | None = None,
    cfg_overrides: dict[str, Any] | None = None,
) -> dict[str, Any]:
    cfg_filename = Path(path) / "cfg.json"
We can replace `"cfg.json"` with the constant `SAE_CFG_FILENAME`.
Similar comments to #433 (comment), for tests that test `load_sae_config_from_huggingface()`.
# Download the SAE weights
sae_path = hf_hub_download(
    repo_id=repo_id,
    filename="sae_weights.safetensors",
Same comment as https://github.com/jbloomAus/SAELens/pull/433/files#r1976512393.
Description
This PR standardizes the interface for `PretrainedSaeHuggingfaceLoader` and `PretrainedSaeDiskLoader` to allow directly loading any custom SAE from another library, as long as a correct `converter` param is passed to `SAE.from_pretrained()` or `SAE.load_from_disk()`.

This PR also deprecates the `SAE.load_from_pretrained()` method in favor of `SAE.load_from_disk()`, to make the difference between `from_pretrained()`, which loads from Huggingface, and `load_from_disk()`, which loads a locally saved SAE, more obvious.

These follow the protocols below:
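As a rough sketch of the shape such a loader takes (the function name `my_disk_converter`, its parameter names, the weights filename, and the placeholder cfg values are assumptions for illustration, not the PR's actual protocol definitions):

```python
from pathlib import Path
from typing import Any

import torch
from safetensors.torch import load_file


def my_disk_converter(
    path: str | Path,
    device: str = "cpu",
    cfg_overrides: dict[str, Any] | None = None,
) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
    """Sketch of a disk loader: map another library's on-disk format
    to a SAELens-style (cfg dict, state dict) pair."""
    path = Path(path)
    # Placeholder config; a real converter would read/derive this from the
    # other library's own config file.
    cfg_dict: dict[str, Any] = {"d_in": 768, "d_sae": 768 * 16, "dtype": "float32"}
    if cfg_overrides is not None:
        cfg_dict.update(cfg_overrides)
    state_dict = load_file(path / "weights.safetensors", device=device)
    return cfg_dict, state_dict


# Per this PR, the converter is passed directly to the loading call, e.g.:
# sae = SAE.load_from_disk("/path/to/other_library_sae", converter=my_disk_converter)
```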
Essentially, the provided function must take in either a `repo_id` and `folder_name` on huggingface, or a local path, plus optional config overrides, and return a config dict and model state dict.

Our current converters don't feel general enough to work for all instances of other libraries, but we can and should write some converters for the `dictionary_learning` and `sparsify` libraries in the future.

Type of change
Please delete options that are not relevant.
Checklist:
- You have tested formatting, typing and tests
- Run `make check-ci` to check format and linting. (you can run `make format` to format code if needed.)