I’ve noticed an issue while using the released Gemma SAE model: the MSE loss is extremely high on the last few layers. I wanted to check if I am using the model correctly.
Here is the code I am using:
import torch
from sae_lens import SAE
from transformer_lens import HookedTransformer
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

MODEL_PATHS = {
    "gemma2": "google/gemma-2-9b/",
}
model_path = MODEL_PATHS["gemma2"]

# Load the Hugging Face model and tokenizer in float32.
tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
model_config = AutoConfig.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float32,
    config=model_config,
    device_map={"": "cuda:0"},
)

# Wrap the model in a HookedTransformer without any weight processing.
hook_model = HookedTransformer.from_pretrained_no_processing(
    "google/gemma-2-9b",
    device="cuda:0",
    hf_model=model,
    tokenizer=tokenizer,
    dtype=torch.float32,
).eval()

text = "Would you be able to travel through time using a wormhole?"
tokens = hook_model.to_tokens(text)
_, cache = hook_model.run_with_cache(tokens)

# Load the layer-41 residual-stream SAE.
release = "gemma-scope-9b-pt-res"
sae_id = "layer_41/width_131k/average_l0_45"
sae, cfg_dict, sparsity = SAE.from_pretrained(release, sae_id)

# Activations at the SAE's hook point, skipping the first (BOS) position.
acts = cache[cfg_dict["hook_name"]][:, 1:]

# Mean squared error between the SAE reconstruction and the activations.
print((sae.decode(sae.encode(acts)) - acts).pow(2).mean())
The result I got was tensor(20.1427, device='cuda:0', grad_fn=<MeanBackward0>)
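A raw MSE depends on the scale of the activations, so it can help to report a scale-free number next to it. The sketch below is not part of the original report. It is a NumPy helper with a hypothetical name, `reconstruction_metrics`, that computes the raw MSE, the fraction of variance unexplained (MSE divided by the variance of the activations around their per-dimension mean), and the MSE relative to the mean squared activation. The demo data is synthetic and does not reproduce Gemma activations.

```python
import numpy as np


def reconstruction_metrics(x, x_hat):
    """Return raw and scale-free reconstruction errors.

    x, x_hat: array-likes of shape (..., d_model). Hypothetical helper,
    not part of sae_lens or transformer_lens.
    """
    x = np.asarray(x, dtype=np.float64)
    x_hat = np.asarray(x_hat, dtype=np.float64)
    d_model = x.shape[-1]
    # Flatten batch and position dimensions into one "token" axis.
    x = x.reshape(-1, d_model)
    x_hat = x_hat.reshape(-1, d_model)

    mse = np.mean((x_hat - x) ** 2)
    # Variance of the activations around their per-dimension mean.
    variance = np.mean((x - x.mean(axis=0)) ** 2)
    # Mean squared activation (not mean-centered).
    mean_sq_act = np.mean(x ** 2)

    return {
        "mse": float(mse),
        "fvu": float(mse / variance),
        "relative_mse": float(mse / mean_sq_act),
    }


if __name__ == "__main__":
    # Synthetic example: large-scale activations plus moderate noise.
    rng = np.random.default_rng(0)
    acts = rng.normal(scale=50.0, size=(1, 12, 16))
    recon = acts + rng.normal(scale=4.5, size=acts.shape)
    print(reconstruction_metrics(acts, recon))
```

With the tensors from the snippet above, the inputs would be converted first, for example `acts.detach().cpu().numpy()` and the same for the reconstruction. A small `fvu` alongside a large `mse` would suggest the activations themselves are large at that layer. An `fvu` near or above 1 would suggest the reconstruction really is poor.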