
[Question] High MSE Loss in Released Gemma SAE Model – Am I Using It Correctly? #410

Open
yingjiahao14 opened this issue Jan 28, 2025 · 1 comment

@yingjiahao14

Hi,

I’ve noticed that the reconstruction MSE of the released Gemma SAE (Gemma Scope) is extremely high on the last few layers, and I want to check whether I am using the model correctly.

Here is the code I am using:

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformer_lens import HookedTransformer
from sae_lens import SAE

MODEL_PATHS = {
    "gemma2": "google/gemma-2-9b",
}
model_path = MODEL_PATHS["gemma2"]
tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
model_config = AutoConfig.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float32,
    config=model_config,
    device_map={"": "cuda:0"},
)

# Wrap the already-loaded HF model without weight processing
# (no LayerNorm folding or weight centering), so hook activations
# match the original model's.
hook_model = HookedTransformer.from_pretrained_no_processing(
    "google/gemma-2-9b",
    device="cuda:0",
    hf_model=model,
    tokenizer=tokenizer,
    dtype=torch.float32,
).eval()

text = "Would you be able to travel through time using a wormhole?"
tokens = hook_model.to_tokens(text)  # prepends BOS by default

_, cache = hook_model.run_with_cache(tokens)

release = "gemma-scope-9b-pt-res"
sae_id = "layer_41/width_131k/average_l0_45"
# Load the SAE on the same device as the cached activations.
sae, cfg_dict, sparsity = SAE.from_pretrained(release, sae_id, device="cuda:0")

# Drop position 0 (the BOS token) before measuring reconstruction error.
acts = cache[cfg_dict["hook_name"]][:, 1:]
print((sae.decode(sae.encode(acts)) - acts).pow(2).mean())
```

The result I got was `tensor(20.1427, device='cuda:0', grad_fn=<MeanBackward0>)`.
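Whether a raw MSE of about 20 is "extremely high" depends on the scale of the activations at that hook, so raw MSE is hard to compare across layers whose activation norms differ. A minimal sketch, assuming the original activations and the SAE reconstruction are `[n_tokens, d_model]` tensors, that reports raw MSE next to the fraction of variance unexplained (MSE divided by the activations' variance around their per-dimension mean). `reconstruction_metrics` is a hypothetical helper for illustration, not part of SAELens:

```python
import torch


def reconstruction_metrics(original: torch.Tensor, reconstruction: torch.Tensor) -> dict:
    """Raw MSE plus MSE normalized by the activations' own variance.

    Hypothetical helper. original, reconstruction: [n_tokens, d_model].
    """
    # Raw mean squared error over all tokens and dimensions.
    mse = (reconstruction - original).pow(2).mean()
    # Mean squared deviation of the activations from their per-dimension mean.
    variance = (original - original.mean(dim=0, keepdim=True)).pow(2).mean()
    # Fraction of variance unexplained: unchanged if both tensors are rescaled together.
    fvu = mse / variance
    return {"mse": mse.item(), "fvu": fvu.item()}
```

With the issue's variables it could be called as `x = cache[cfg_dict["hook_name"]][0, 1:]` followed by `reconstruction_metrics(x, sae.decode(sae.encode(x)))`. A single short prompt gives a noisy variance estimate, so running it over more text gives a steadier number.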

@jbloom-aisi

You might need to mask out the tokens that Gemma Scope was not trained on.
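One way to act on this suggestion is to compute the error only over positions whose token is not excluded. Which tokens Gemma Scope skipped during training is an assumption here (special tokens such as BOS, EOS, and padding are a common choice); `masked_mse` and `excluded_token_ids` are hypothetical names for illustration, not SAELens API. A minimal sketch:

```python
import torch


def masked_mse(
    original: torch.Tensor,       # [batch, pos, d_model] activations
    reconstruction: torch.Tensor, # [batch, pos, d_model] SAE output
    tokens: torch.Tensor,         # [batch, pos] token ids aligned with the activations
    excluded_token_ids: list,     # hypothetical: ids to leave out, chosen by the caller
) -> torch.Tensor:
    """MSE over positions whose token id is not in excluded_token_ids."""
    excluded = torch.tensor(excluded_token_ids, device=tokens.device)
    # True where the token should count toward the loss.
    keep = ~torch.isin(tokens, excluded)
    # Per-position MSE across the model dimension: [batch, pos].
    per_token_mse = (reconstruction - original).pow(2).mean(dim=-1)
    # Average only over the kept positions.
    return per_token_mse[keep].mean()
```

For a single unpadded prompt, the `[:, 1:]` slice in the issue already removes the BOS position, so a mask like this mainly matters for batched or padded inputs. With the issue's variables it could be called as `masked_mse(cache[cfg_dict["hook_name"]], sae.decode(sae.encode(cache[cfg_dict["hook_name"]])), tokens, [tokenizer.bos_token_id, tokenizer.pad_token_id])`, dropping any id that is `None`.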
