[Bug Report] Unusual L0 Behavior with Gemma-2b #382

Open
muyo8692 opened this issue Nov 22, 2024 · 1 comment

Comments

muyo8692 commented Nov 22, 2024

Describe the bug
When training a JumpReLU SAE on gemma-2-2b layer 12, the L0 values remain unusually high (>1500) and are largely unresponsive to adjustments of the L0 (sparsity) coefficient. This behavior persists across various hyperparameter configurations:

  • Varying the L0 coefficient (50, 100, 500) has minimal impact on L0 values
  • MSE losses remain around 110 at 600k steps (≈2.5B tokens)
  • Even with a significantly reduced learning rate (7e-8), L0 values still exceed 1500 after 2.5B tokens
  • Bandwidth and initial-threshold adjustments also had limited effect on controlling L0 (see the sketch below for where these enter the sparsity loss)
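
For context on the last point, below is a minimal sketch of where jumprelu_init_threshold and jumprelu_bandwidth enter the sparsity term of a JumpReLU SAE, following the JumpReLU paper's straight-through-estimator formulation. This is only an illustration of the mechanism, not SAE Lens's actual implementation; rectangle, Step, and jumprelu_sparsity are names made up for the sketch.

# Minimal sketch (not SAE Lens's actual code) of how the JumpReLU threshold and
# bandwidth enter the L0 sparsity term.
import torch


def rectangle(x: torch.Tensor) -> torch.Tensor:
    # Rectangle kernel used by the straight-through estimator (STE).
    return ((x > -0.5) & (x < 0.5)).to(x.dtype)


class Step(torch.autograd.Function):
    """Heaviside step H(pre - threshold), with an STE pseudo-gradient for the threshold only."""

    @staticmethod
    def forward(ctx, pre, threshold, bandwidth):
        ctx.save_for_backward(pre, threshold)
        ctx.bandwidth = bandwidth
        return (pre > threshold).to(pre.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        pre, threshold = ctx.saved_tensors
        eps = ctx.bandwidth
        # The threshold only receives gradient when a pre-activation falls inside
        # a window of width `bandwidth` around it; the pre-activations get none here.
        grad_threshold = -(1.0 / eps) * rectangle((pre - threshold) / eps) * grad_output
        return None, grad_threshold, None


def jumprelu_sparsity(pre_acts, log_threshold, bandwidth=0.001):
    # Broadcast the per-feature threshold so the custom backward can return a
    # gradient of matching shape (the expand handles the sum over the batch).
    threshold = log_threshold.exp().expand_as(pre_acts)
    gate = Step.apply(pre_acts, threshold, bandwidth)
    feature_acts = pre_acts * gate   # JumpReLU output
    l0_per_token = gate.sum(dim=-1)  # relaxed L0; scaled by the sparsity coefficient in the loss
    return feature_acts, l0_per_token

If the (normalized) pre-activations essentially never fall within `bandwidth` of the threshold, the threshold barely trains and L0 stays wherever initialization and the data put it, regardless of the coefficient, so that interaction seems worth ruling out.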

Code example
Training config as follows; I simply swept l1_coefficient and tried 1e-3, 1, 5, 5k, and 500k.
(50, 500, and 50k were also tried, but I accidentally deleted the logs for those runs.)

from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner


cfg = LanguageModelSAERunnerConfig(
    architecture="jumprelu",
    model_name="gemma-2-2b",
    model_class_name="HookedTransformer",
    hook_name="blocks.12.hook_resid_post",
    hook_eval="NOT_IN_USE",
    hook_layer=12,
    hook_head_index=None,
    dataset_path="monology/pile-uncopyrighted",
    dataset_trust_remote_code=True,
    streaming=True,
    is_dataset_tokenized=False,
    context_size=1024,
    use_cached_activations=False,
    cached_activations_path=None,
    d_in=2304,
    d_sae=16384,
    b_dec_init_method="zeros", 
    activation_fn="relu",
    normalize_sae_decoder=True, 
    noise_scale=0.0,
    from_pretrained_path=None,
    apply_b_dec_to_input=True, 
    decoder_orthogonal_init=False,
    decoder_heuristic_init=False, 
    init_encoder_as_decoder_transpose=True, 
    n_batches_in_buffer=64,
    training_tokens=4000000000,
    finetuning_tokens=0,
    store_batch_size_prompts=64,
    train_batch_size_tokens=4096,
    normalize_activations="expected_average_only_in", 
    device="cuda",
    act_store_device="cpu",
    seed=42,
    dtype="torch.float32",
    prepend_bos=True,
    jumprelu_init_threshold=0.001,  # initial value of the JumpReLU threshold
    jumprelu_bandwidth=0.001,       # bandwidth of the straight-through estimator around the threshold
    autocast=False,
    autocast_lm=False,
    compile_llm=False,
    llm_compilation_mode="default",
    compile_sae=False,
    sae_compilation_mode=None,
    adam_beta1=0,
    adam_beta2=0.999,
    mse_loss_normalization=None,
    l1_coefficient=5,  # sparsity (L0) coefficient; the value swept across runs
    lp_norm=1.0, # not used for jumprelu
    scale_sparsity_penalty_by_decoder_norm=False,
    l1_warm_up_steps=10000,
    lr=7e-08,
    lr_scheduler_name="cosineannealing",
    lr_warm_up_steps=1000,
    lr_end=7e-09,
    lr_decay_steps=0,
    finetuning_method=None,
    use_ghost_grads=False,
    feature_sampling_window=2000,
    dead_feature_window=1000,
    dead_feature_threshold=1e-08,
    n_eval_batches=10,
    eval_batch_size_prompts=1,
    log_to_wandb=True,
    log_activations_store_to_wandb=False,
    log_optimizer_state_to_wandb=False,
    wandb_project="<PROJECT>",
    wandb_id=None,
    run_name="<NAME>",
    wandb_entity=None,
    wandb_log_frequency=300,
    eval_every_n_wandb_logs=10,
    resume=False,
    n_checkpoints=0,
    checkpoint_path="<PATH>",
    verbose=True,
    model_kwargs={},
    model_from_pretrained_kwargs={},
)

sparse_autoencoder = SAETrainingRunner(cfg).run()
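
For reference, the sweep over l1_coefficient amounts to a loop like the sketch below. It assumes LanguageModelSAERunnerConfig is a standard dataclass (so dataclasses.replace can swap a single field), and the run_name pattern is just illustrative:

import dataclasses

# Sweep sketch: reuse the cfg defined above and vary only the sparsity coefficient.
for coeff in [1e-3, 1.0, 5.0, 5e3, 5e5]:
    sweep_cfg = dataclasses.replace(
        cfg,
        l1_coefficient=coeff,
        run_name=f"gemma-2-2b-L12-jumprelu-coeff-{coeff}",  # hypothetical naming scheme
    )
    SAETrainingRunner(sweep_cfg).run()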

System Info
Describe the characteristics of your environment:
All libraries were installed using uv; the relevant part of pyproject.toml is shown below.
Note: I also tried previous versions of sae-lens (4.3.4 and 4.3.5).

requires-python = ">=3.12"
dependencies = [
    "circuitsvis>=1.41.0",
    "gpustat>=1.1.1",
    "huggingface-hub>=0.26.2",
    "notebook>=7.2.2",
    "sae-lens==4.4.1",
    "torch>=2.5.1",
    "transformer-lens>=2.8.1",
]

Run log
(Two screenshots of the run logs attached, dated Nov 22, 2024.)

Checklist

  • I have checked that there is no similar issue in the repo (required)
@cwyoon-99

@muyo8692 Have you found any solution to this issue yet?
I have a similar problem where L0 does not converge to a reasonable level (<= 150).
