
[Question] fail to reproduce the GPT-2-small SAE with provided hyperparameters #405

cwyoon-99 opened this issue Jan 16, 2025 · 8 comments


@cwyoon-99

Questions

Hi. I'm trying to reproduce the GPT-2 SAE for blocks.8.hook_resid_pre (https://huggingface.co/jbloom/GPT2-Small-SAEs-Reformatted/tree/main/blocks.8.hook_resid_pre).

However, even though I use exactly the same hyperparameters, I can't achieve performance similar to the uploaded HF version.
(The main metrics I monitor are L0, explained_variance, and the log10 feature density histogram.)

How can I reproduce it?

The l1_coefficient is set to 8e-05 in cfg.json, but I've read comments suggesting that a value closer to 1.0 might be more effective (see: GitHub comment). Is it possible that the way this coefficient enters the loss has changed since then?
Or do you have any other suggestions?

I tried a lot of combinations of l1_coefficient (8e-5, 8e-3, 8e-3, 0.1, 1.0, 2.0, 4.0) and lr (4e-4, 5e-5, 1e-5),
but I can't even get a decent SAE (the log10 feature density is either too sparse or too dense).
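
For what it's worth, whether the reconstruction loss is summed over d_in or averaged over all elements changes the size of the MSE term by a factor of roughly d_in, which shifts the l1_coefficient scale that gives a comparable sparsity pressure. A minimal illustration of the two conventions (my own sketch, not any specific sae_lens version):

import torch

def loss_summed_mse(x, x_hat, feature_acts, l1_coefficient):
    # MSE summed over d_in, averaged over the batch: a larger reconstruction term,
    # so a comparatively larger l1_coefficient gives the same sparsity pressure.
    mse = ((x_hat - x) ** 2).sum(dim=-1).mean()
    l1 = feature_acts.abs().sum(dim=-1).mean()
    return mse + l1_coefficient * l1

def loss_mean_mse(x, x_hat, feature_acts, l1_coefficient):
    # MSE averaged over batch and d_in: a ~d_in-times smaller reconstruction term,
    # so the equivalent l1_coefficient is correspondingly smaller.
    mse = ((x_hat - x) ** 2).mean()
    l1 = feature_acts.abs().sum(dim=-1).mean()
    return mse + l1_coefficient * l1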

Hyperparameters

Mine

cfg_dict = {
        # Data Generating Function (Model + Training Distribution)
        "model_name": args.model_name,
        "hook_name": args.hook_name,
        "hook_layer": args.hook_layer,
        "d_in": 768,
        "dataset_path": "apollo-research/Skylion007-openwebtext-tokenizer-gpt2",
        "is_dataset_tokenized": True,
        "prepend_bos": True,  # should experiment with turning this off.
        # SAE Parameters
        "expansion_factor": 32,  # determines the dimension of the SAE.
        "b_dec_init_method": "geometric_median",  # geometric median is better but slower to get started
        "apply_b_dec_to_input": False,
        # Training Parameters
        "adam_beta1": 0,
        "adam_beta2": 0.999,
        "lr": args.lr,
        "l1_coefficient": args.l1_coefficient,
        "lr_scheduler_name": "constant",
        "train_batch_size_tokens": 4096,
        "context_size": 128,
        "lr_warm_up_steps": 5000,
        # Activation Store Parameters
        "n_batches_in_buffer": 128,
        "training_tokens": 1_000_000 * 300,  # 200M tokens seems doable overnight.
        "finetuning_method": "decoder",
        "finetuning_tokens": 1_000_000 * 100,
        "store_batch_size_prompts": 32,
        # Resampling protocol
        "use_ghost_grads": args.use_ghost_grads,
        "feature_sampling_window": 1000,
        "dead_feature_window": 5000,
        "dead_feature_threshold": 1e-8,
        # WANDB
        "log_to_wandb": True,
        "wandb_project": f"{args.model_name}_{args.hook_name}",
        "wandb_entity": None,
        "wandb_id": f"{args.model_name}_{args.hook_name}_{time.strftime('%y%m%d_%H%M')}_L1-{args.l1_coefficient}_LR-{args.lr}",
        "run_name": f"{time.strftime('%y%m%d_%H%M')}_L1-{args.l1_coefficient}_LR-{args.lr}",
        "wandb_log_frequency": 100,
        # Misc
        "device": device,
        "seed": 42,
        "n_checkpoints": 10,
        "checkpoint_path": "checkpoints",
        "dtype": "float32",
        "autocast": True,
        "autocast_lm": True,
        "compile_llm": True,
        # "compile_sae": True,
    }
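
(For completeness, this is roughly how a cfg dict like the one above gets turned into a training run; a sketch assuming the sae_lens 3.x-style runner API (LanguageModelSAERunnerConfig / SAETrainingRunner), which may differ in other versions.)

from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner

# Sketch only; API names assume a sae_lens 3.x-style runner.
cfg = LanguageModelSAERunnerConfig(**cfg_dict)  # cfg_dict as defined above
sparse_autoencoder = SAETrainingRunner(cfg).run()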

jbloom/GPT2-Small-SAEs-Reformatted

{"model_name": "gpt2-small", "hook_point": "blocks.8.hook_resid_pre", "hook_point_layer": 8, "hook_point_head_index": null, "dataset_path": "Skylion007/openwebtext", "is_dataset_tokenized": false, "context_size": 128, "use_cached_activations": false, "cached_activations_path": "activations/Skylion007_openwebtext/gpt2-small/blocks.8.hook_resid_pre", "d_in": 768, "n_batches_in_buffer": 128, "total_training_tokens": 300000000, "store_batch_size": 32, "device": "mps", "seed": 42, "dtype": "torch.float32", "b_dec_init_method": "geometric_median", "expansion_factor": 32, "from_pretrained_path": null, "l1_coefficient": 8e-05, "lr": 0.0004, "lr_scheduler_name": null, "lr_warm_up_steps": 5000, "train_batch_size": 4096, "use_ghost_grads": false, "feature_sampling_window": 1000, "feature_sampling_method": null, "resample_batches": 1028, "feature_reinit_scale": 0.2, "dead_feature_window": 5000, "dead_feature_estimation_method": "no_fire", "dead_feature_threshold": 1e-08, "log_to_wandb": true, "wandb_project": "mats_sae_training_gpt2_small_resid_pre_5", "wandb_entity": null, "wandb_log_frequency": 100, "n_checkpoints": 10, "checkpoint_path": "checkpoints/ut7lhl4q", "d_sae": 24576, "tokens_per_buffer": 67108864, "run_name": "24576-L1-8e-05-LR-0.0004-Tokens-3.000e+08"}

log10 feature density

Mine

[image: log10 feature density histogram]

jbloom/GPT2-Small-SAEs-Reformatted

[image: log10 feature density histogram]
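
For reference, the log10 feature density in these histograms is just the log10 of each feature's firing rate over a token sample; a minimal sketch of such a helper (my own, not sae_lens code):

import torch

def log10_feature_density(feature_acts: torch.Tensor, eps: float = 1e-10) -> torch.Tensor:
    # feature_acts: [n_tokens, d_sae] SAE feature activations over a token sample.
    firing_rate = (feature_acts > 0).float().mean(dim=0)  # fraction of tokens each feature fires on
    return torch.log10(firing_rate + eps)                 # eps keeps dead features finite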

@NatanFreeman

Are you running locally or on Google Colab? If you're running locally, what version of Python are you using? You should be using Python 3.10.14. I accidentally used Python 3.10.16 (which is the latest release of Python 3.10), and as a result the attention on the feature dashboard was wrong. Maybe you're encountering a similar issue.

@jbloom-aisi

The default hyperparameters have shifted since I trained those. I think the thing that's most likely going wrong is that you might not be using the MSE loss calculation I used; when I trained them I had some weird variation that does help in some cases. Also, I loaded the models from tlens (TransformerLens) with pre-processing, including residual stream centring. I'm not sure whether those factors matter a lot, but they might. Sorry for the slow and very brief response (pretty busy at my new job).
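
(For reference, the pre-processing mentioned here is presumably the TransformerLens from_pretrained weight processing; a rough sketch of the relevant flags is below. Exactly which flags were used for the original runs is an assumption; they default to True in from_pretrained.)

from transformer_lens import HookedTransformer

# Rough sketch of the TransformerLens weight-processing flags that affect the
# residual stream an SAE sees.
model = HookedTransformer.from_pretrained(
    "gpt2",
    fold_ln=True,                 # fold LayerNorm scales into adjacent weights
    center_writing_weights=True,  # center weights writing to the residual stream
    center_unembed=True,          # center the unembedding weights
)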

@cwyoon-99
Author

@NatanFreeman I'm running locally, and using Python 3.10.14 did not solve the issue, but thanks for your advice!

@cwyoon-99
Author

@jbloom-aisi
How about the SAEs for gemma-2b? I'm wondering whether you applied the same variations that you used for GPT-2.
For the Gemma SAEs, I've also been following exactly the same hyperparameters, but I'm unable to reproduce the trends seen in your WandB logs.

I've read your post on LessWrong, but I can't find information about certain details such as the MSE loss calculation, the pre-processing steps, or whether residual stream centering was applied.

Since these details aren't fully covered in the post, would you be able to share any insights or the repositories you used to train the SAEs? Any guidance would be greatly appreciated!

Thanks in advance!

@NatanFreeman

@cwyoon-99 Just as well. It turns out my issue had nothing to do with my Python installation; it was simply a misunderstanding of Neuronpedia's interface.

@hijohnnylin
Collaborator

@NatanFreeman Curious what the misunderstanding of Neuronpedia's interface was? My guess is that we should be making it clearer somehow; if it was confusing to you, it was likely confusing to others as well.

@NatanFreeman

@hijohnnylin The interface for calculating the activations of a specific neuron confused me. There's a feature that allows you to copy and analyze an example text. When you have it set to only show the snippet with the activations, only the snippet is copied and analyzed. This causes a mismatch in the activations, because for the example text the whole text was analyzed, whereas for the copied text only the snippet is analyzed. Here's an example from my phone; I'm not in front of my computer right now.

[image]

@cwyoon-99
Author

cwyoon-99 commented Feb 6, 2025

After many attempts to reproduce the SAEs, I found that training with a previous version of sae_lens works. For Gemma 2, I used sae_lens version 3.5.0 and achieved good results.
(Some comments on another issue said that the MSE loss calculation is different from the previous version, so I suspect that might be the reason, though I haven't read that part of the code.)
However, I haven't been able to reproduce the SAEs for GPT-2: I struggle to achieve a reasonable L0 (<150) while maintaining explained variance.

Here is my cfg for training the gemma2-2b SAE (l1_coefficient: 2.0, lr: 8e-05).
I found that normalize_activations is quite important for reducing L0.

[image]

cfg_dict = {
        # Data Generating Function (Model + Training Distribution)
        "model_name": args.model_name,
        "hook_name": args.hook_name,
        "hook_layer": args.hook_layer,
        "d_in": 2048,
        "dataset_path": "HuggingFaceFW/fineweb",
        "is_dataset_tokenized": False,
        "prepend_bos": True,  # should experiment with turning this off.
        # SAE Parameters
        "expansion_factor": 8,  # determines the dimension of the SAE.
        "b_dec_init_method": "zeros",
        "apply_b_dec_to_input": False,
        "activation_fn": "relu",
        "normalize_sae_decoder": False,
        "normalize_activations": "expected_average_only_in", # layer_norm, expected_average_only_in, or constant_norm_rescale
        "noise_scale": 0.0,
        "decoder_orthogonal_init": False,
        "decoder_heuristic_init": True,
        "init_encoder_as_decoder_transpose": True,
        # Training Parameters
        "adam_beta1": 0.9,
        "adam_beta2": 0.999,
        "lr": args.lr,
        "l1_coefficient": args.l1_coefficient,
        "lr_scheduler_name": "constant",
        "train_batch_size_tokens": train_batch_size_tokens,
        "context_size": 1024,
        "l1_warm_up_steps": l1_warm_up_steps,
        "lr_warm_up_steps": 0,
        # Activation Store Parameters
        "n_batches_in_buffer": 16,
        "training_tokens": 1228800000,  # 200M tokens seems doable overnight.
        "finetuning_method": None,
        "finetuning_tokens": 0,
        "store_batch_size_prompts": 8,
        # Resampling protocol
        "use_ghost_grads": args.use_ghost_grads,
        "feature_sampling_window": 5000,
        "dead_feature_window": 5000,
        "dead_feature_threshold": 1e-6,
        # WANDB
        "log_to_wandb": True,
        "wandb_project": f"{args.model_name}_{args.hook_name}",
        "wandb_entity": None,
        "wandb_id": f"{args.model_name}_{args.hook_name}_{run_name}",
        "run_name": f"{run_name}_3.5.0",
        "wandb_log_frequency": 50,
        "eval_every_n_wandb_logs": 10,
        "model_from_pretrained_kwargs": {"n_devices": args.gpu_num},
        # "device": f"cuda:{int(args.gpu_num) - 1}",
        # "act_store_device": f"cuda:{args.gpu_num}",
        "device": f"cuda:{args.gpu_num}",
        "act_store_device": "cpu",
        "seed": 42,
        "n_checkpoints": 10,
        "checkpoint_path": "checkpoints",
        "dtype": "torch.float32",
        "autocast_lm": True,
        "autocast": True,
        "compile_llm": True,
        # "llm_compilation_mode": "max-autotune",
        # "compile_sae": False,
        # "sae_compilation_mode": None,
    }
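
A rough sketch of what the "expected_average_only_in" normalization above does to the inputs, as I understand it (illustration only, not the exact sae_lens code): estimate one scalar so the average activation norm becomes sqrt(d_in), and rescale the SAE inputs by it.

import torch

def estimate_norm_scaling_factor(sample_acts: torch.Tensor, d_in: int) -> float:
    # sample_acts: [n_tokens, d_in] activations drawn from the activation store.
    # Scale inputs so that the average L2 norm becomes sqrt(d_in).
    mean_norm = sample_acts.norm(dim=-1).mean().item()
    return (d_in ** 0.5) / mean_norm

# usage (illustrative): x_scaled = x * estimate_norm_scaling_factor(sample_acts, d_in=2048)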
