diff --git a/sae_lens/config.py b/sae_lens/config.py
index f3478533..207eccd8 100644
--- a/sae_lens/config.py
+++ b/sae_lens/config.py
@@ -4,9 +4,9 @@
 from typing import Any, Literal, Optional, cast
 
 import torch
-import wandb
 from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
+import wandb
 
 from sae_lens import __version__
 
 DTYPE_MAP = {
@@ -78,10 +78,10 @@ class LanguageModelSAERunnerConfig:
         adam_beta1 (float): The beta1 parameter for Adam.
         adam_beta2 (float): The beta2 parameter for Adam.
         mse_loss_normalization (str): The normalization to use for the MSE loss.
-        l1_coefficient (float): The L1 coefficient.
+        sparsity_coefficient (float): The sparsity coefficient for the L1 or L0 sparsity penalty.
         lp_norm (float): The Lp norm.
         scale_sparsity_penalty_by_decoder_norm (bool): Whether to scale the sparsity penalty by the decoder norm.
-        l1_warm_up_steps (int): The number of warm-up steps for the L1 loss.
+        coefficient_warm_up_steps (int): The number of warm-up steps for the sparsity loss.
         lr (float): The learning rate.
         lr_scheduler_name (str): The name of the learning rate scheduler to use.
         lr_warm_up_steps (int): The number of warm-up steps for the learning rate.
@@ -153,9 +153,7 @@ class LanguageModelSAERunnerConfig:
     finetuning_tokens: int = 0
     store_batch_size_prompts: int = 32
     train_batch_size_tokens: int = 4096
-    normalize_activations: str = (
-        "none"  # none, expected_average_only_in (Anthropic April Update), constant_norm_rescale (Anthropic Feb Update)
-    )
+    normalize_activations: str = "none"  # none, expected_average_only_in (Anthropic April Update), constant_norm_rescale (Anthropic Feb Update)
     seqpos_slice: tuple[int | None, ...] = (None,)
 
     # Misc
@@ -186,10 +184,10 @@ class LanguageModelSAERunnerConfig:
 
     ## Loss Function
     mse_loss_normalization: Optional[str] = None
-    l1_coefficient: float = 1e-3
+    sparsity_coefficient: float = 1.0
     lp_norm: float = 1
     scale_sparsity_penalty_by_decoder_norm: bool = False
-    l1_warm_up_steps: int = 0
+    coefficient_warm_up_steps: int = 0
 
     ## Learning Rate Schedule
     lr: float = 3e-4
@@ -264,7 +262,7 @@ def __post_init__(self):
             )
 
         if self.run_name is None:
-            self.run_name = f"{self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}"
+            self.run_name = f"{self.d_sae}-L1-{self.sparsity_coefficient}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}"
 
         if self.model_from_pretrained_kwargs is None:
             if self.model_class_name == "HookedTransformer":
@@ -318,7 +316,7 @@ def __post_init__(self):
 
         if self.verbose:
             print(
-                f"Run name: {self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}"
+                f"Run name: {self.d_sae}-L1-{self.sparsity_coefficient}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}"
             )
             # Print out some useful info:
             n_tokens_per_buffer = (
@@ -408,7 +406,7 @@ def get_base_sae_cfg_dict(self) -> dict[str, Any]:
     def get_training_sae_cfg_dict(self) -> dict[str, Any]:
         return {
             **self.get_base_sae_cfg_dict(),
-            "l1_coefficient": self.l1_coefficient,
+            "sparsity_coefficient": self.sparsity_coefficient,
             "lp_norm": self.lp_norm,
             "use_ghost_grads": self.use_ghost_grads,
             "normalize_sae_decoder": self.normalize_sae_decoder,
@@ -547,7 +545,7 @@ class ToyModelSAERunnerConfig:
     d_sae: int = 5
 
     # Training Parameters
-    l1_coefficient: float = 1e-3
+    sparsity_coefficient: float = 1.0
     lr: float = 3e-4
     train_batch_size: int = 1024
     b_dec_init_method: str = "geometric_median"
diff --git a/sae_lens/training/optim.py b/sae_lens/training/optim.py
index 6772126d..e2a12d89 100644
--- a/sae_lens/training/optim.py
+++ b/sae_lens/training/optim.py
@@ -94,62 +94,62 @@ def _get_main_lr_scheduler(
         return lr_scheduler.CosineAnnealingLR(optimizer, T_max=steps, eta_min=lr_end)  # type: ignore
     elif scheduler_name == "cosineannealingwarmrestarts":
         return lr_scheduler.CosineAnnealingWarmRestarts(
-            optimizer, T_0=steps // num_cycles, eta_min=lr_end  # type: ignore
+            optimizer,
+            T_0=steps // num_cycles,
+            eta_min=lr_end,  # type: ignore
         )
     else:
         raise ValueError(f"Unsupported scheduler: {scheduler_name}")
 
 
-class L1Scheduler:
-
+class CoefficientScheduler:
     def __init__(
         self,
-        l1_warm_up_steps: float,
+        coefficient_warm_up_steps: float,
         total_steps: int,
-        final_l1_coefficient: float,
+        final_sparsity_coefficient: float,
     ):
-
-        self.l1_warmup_steps = l1_warm_up_steps
+        self.coefficient_warm_up_steps = coefficient_warm_up_steps
         # assume using warm-up
-        if self.l1_warmup_steps != 0:
-            self.current_l1_coefficient = 0.0
+        if self.coefficient_warm_up_steps != 0:
+            self.current_sparsity_coefficient = 0.0
         else:
-            self.current_l1_coefficient = final_l1_coefficient
+            self.current_sparsity_coefficient = final_sparsity_coefficient
 
-        self.final_l1_coefficient = final_l1_coefficient
+        self.final_sparsity_coefficient = final_sparsity_coefficient
 
         self.current_step = 0
         self.total_steps = total_steps
-        assert isinstance(self.final_l1_coefficient, float | int)
+        assert isinstance(self.final_sparsity_coefficient, float | int)
 
     def __repr__(self) -> str:
         return (
-            f"L1Scheduler(final_l1_value={self.final_l1_coefficient}, "
-            f"l1_warmup_steps={self.l1_warmup_steps}, "
+            f"CoefficientScheduler(final_coefficient_value={self.final_sparsity_coefficient}, "
+            f"warm_up_steps={self.coefficient_warm_up_steps}, "
             f"total_steps={self.total_steps})"
         )
 
     def step(self):
         """
-        Updates the l1 coefficient of the sparse autoencoder.
+        Updates the coefficient of the sparse autoencoder.
""" step = self.current_step - if step < self.l1_warmup_steps: - self.current_l1_coefficient = self.final_l1_coefficient * ( - (1 + step) / self.l1_warmup_steps + if step < self.coefficient_warm_up_steps: + self.current_sparsity_coefficient = self.final_sparsity_coefficient * ( + (1 + step) / self.coefficient_warm_up_steps ) # type: ignore else: - self.current_l1_coefficient = self.final_l1_coefficient # type: ignore + self.current_sparsity_coefficient = self.final_sparsity_coefficient # type: ignore self.current_step += 1 def state_dict(self): """State dict for serializing as part of an SAETrainContext.""" return { - "l1_warmup_steps": self.l1_warmup_steps, + "coefficient_warm_up_steps": self.coefficient_warm_up_steps, "total_steps": self.total_steps, - "current_l1_coefficient": self.current_l1_coefficient, - "final_l1_coefficient": self.final_l1_coefficient, + "current_sparsity_coefficient": self.current_sparsity_coefficient, + "final_sparsity_coefficient": self.final_sparsity_coefficient, "current_step": self.current_step, } diff --git a/sae_lens/training/sae_trainer.py b/sae_lens/training/sae_trainer.py index fc59fb56..80db33b9 100644 --- a/sae_lens/training/sae_trainer.py +++ b/sae_lens/training/sae_trainer.py @@ -3,16 +3,16 @@ from typing import Any, cast import torch -import wandb from torch.optim import Adam from tqdm import tqdm from transformer_lens.hook_points import HookedRootModule +import wandb from sae_lens import __version__ from sae_lens.config import LanguageModelSAERunnerConfig from sae_lens.evals import EvalConfig, run_evals from sae_lens.training.activations_store import ActivationsStore -from sae_lens.training.optim import L1Scheduler, get_lr_scheduler +from sae_lens.training.optim import CoefficientScheduler, get_lr_scheduler from sae_lens.training.training_sae import TrainingSAE, TrainStepOutput # used to map between parameters which are updated during finetuning and the config str. 
@@ -56,7 +56,6 @@ def __init__(
         save_checkpoint_fn,  # type: ignore
         cfg: LanguageModelSAERunnerConfig,
     ) -> None:
-
         self.model = model
         self.sae = sae
         self.activation_store = activation_store
@@ -113,10 +112,11 @@ def __init__(
             lr_end=cfg.lr_end,
             num_cycles=cfg.n_restart_cycles,
         )
-        self.l1_scheduler = L1Scheduler(
-            l1_warm_up_steps=cfg.l1_warm_up_steps,  # type: ignore
+
+        self.coefficient_scheduler = CoefficientScheduler(
+            coefficient_warm_up_steps=cfg.coefficient_warm_up_steps,
             total_steps=cfg.total_training_steps,
-            final_l1_coefficient=cfg.l1_coefficient,
+            final_sparsity_coefficient=cfg.sparsity_coefficient,
         )
 
         # Setup autocast if using
@@ -154,15 +154,14 @@ def log_feature_sparsity(self) -> torch.Tensor:
         return _log_feature_sparsity(self.feature_sparsity)
 
     @property
-    def current_l1_coefficient(self) -> float:
-        return self.l1_scheduler.current_l1_coefficient
+    def current_sparsity_coefficient(self) -> float:
+        return self.coefficient_scheduler.current_sparsity_coefficient
 
     @property
     def dead_neurons(self) -> torch.Tensor:
         return (self.n_forward_passes_since_fired > self.cfg.dead_feature_window).bool()
 
     def fit(self) -> TrainingSAE:
-
         pbar = tqdm(total=self.cfg.total_training_tokens, desc="Training SAE")
 
         self._estimate_norm_scaling_factor_if_needed()
@@ -216,7 +215,6 @@ def _train_step(
         sae: TrainingSAE,
         sae_in: torch.Tensor,
     ) -> TrainStepOutput:
-
         sae.train()
         # Make sure the W_dec is still zero-norm
         if self.cfg.normalize_sae_decoder:
@@ -232,11 +230,10 @@ def _train_step(
         # for documentation on autocasting see:
         # https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html
         with self.autocast_if_enabled:
-
             train_step_output = self.sae.training_forward_pass(
                 sae_in=sae_in,
                 dead_neuron_mask=self.dead_neurons,
-                current_l1_coefficient=self.current_l1_coefficient,
+                current_sparsity_coefficient=self.current_sparsity_coefficient,
             )
 
         with torch.no_grad():
@@ -262,8 +259,7 @@ def _train_step(
             sae.remove_gradient_parallel_to_decoder_directions()
 
         self.optimizer.zero_grad()
-        self.lr_scheduler.step()
-        self.l1_scheduler.step()
+        self.coefficient_scheduler.step()
 
         return train_step_output
 
@@ -308,7 +304,7 @@ def _build_train_step_log_dict(
             "sparsity/mean_passes_since_fired": self.n_forward_passes_since_fired.mean().item(),
             "sparsity/dead_features": self.dead_neurons.sum().item(),
             "details/current_learning_rate": current_learning_rate,
-            "details/current_l1_coefficient": self.current_l1_coefficient,
+            "details/current_sparsity_coefficient": self.current_sparsity_coefficient,
             "details/n_training_tokens": n_training_tokens,
         }
         for loss_name, loss_value in output.losses.items():
@@ -316,7 +312,8 @@ def _build_train_step_log_dict(
             # special case for l1 loss, which we normalize by the l1 coefficient
             if loss_name == "l1_loss":
                 log_dict[f"losses/{loss_name}"] = (
-                    loss_item / self.current_l1_coefficient
+                    loss_item / self.current_sparsity_coefficient
+                    # loss_item / self.current_l1_coefficient
                 )
                 log_dict[f"losses/raw_{loss_name}"] = loss_item
             else:
@@ -369,7 +366,6 @@ def _run_and_log_evals(self):
 
     @torch.no_grad()
     def _build_sparsity_log_dict(self) -> dict[str, Any]:
-
         log_feature_sparsity = _log_feature_sparsity(self.feature_sparsity)
         wandb_histogram = wandb.Histogram(log_feature_sparsity.numpy())  # type: ignore
         return {
@@ -381,7 +377,6 @@ def _build_sparsity_log_dict(self) -> dict[str, Any]:
 
     @torch.no_grad()
    def _reset_running_sparsity_stats(self) -> None:
-
        self.act_freq_scores = torch.zeros(
            self.cfg.d_sae,  # type: ignore
            device=self.cfg.device,
@@ -401,8 +396,12 @@ def _checkpoint_if_needed(self):
            self.checkpoint_thresholds.pop(0)
 
    @torch.no_grad()
-    def _update_pbar(self, step_output: TrainStepOutput, pbar: tqdm, update_interval: int = 100):  # type: ignore
-
+    def _update_pbar(
+        self,
+        step_output: TrainStepOutput,
+        pbar: tqdm,  # type: ignore
+        update_interval: int = 100,  # type: ignore
+    ):  # type: ignore
        if self.n_training_steps % update_interval == 0:
            loss_strs = " | ".join(
                f"{loss_name}: {_unwrap_item(loss_value):.5f}"
diff --git a/sae_lens/training/training_sae.py b/sae_lens/training/training_sae.py
index a2e7a7a3..ba259b96 100644
--- a/sae_lens/training/training_sae.py
+++ b/sae_lens/training/training_sae.py
@@ -46,7 +46,9 @@ def setup_context(
         ctx.bandwidth = bandwidth
 
     @staticmethod
-    def backward(ctx: Any, grad_output: torch.Tensor) -> tuple[None, torch.Tensor, None]:  # type: ignore[override]
+    def backward(  # type: ignore
+        ctx: Any, grad_output: torch.Tensor
+    ) -> tuple[None, torch.Tensor, None]:  # type: ignore[override]
         x, threshold = ctx.saved_tensors
         bandwidth = ctx.bandwidth
         threshold_grad = torch.sum(
@@ -73,7 +75,9 @@ def setup_context(
         ctx.bandwidth = bandwidth
 
     @staticmethod
-    def backward(ctx: Any, grad_output: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, None]:  # type: ignore[override]
+    def backward(  # type: ignore
+        ctx: Any, grad_output: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor, None]:  # type: ignore[override]
         x, threshold = ctx.saved_tensors
         bandwidth = ctx.bandwidth
         x_grad = (x > threshold) * grad_output  # We don't apply STE to x input
@@ -97,10 +101,9 @@ class TrainStepOutput:
 
 @dataclass(kw_only=True)
 class TrainingSAEConfig(SAEConfig):
-
     # Sparsity Loss Calculations
-    l1_coefficient: float
-    lp_norm: float
+    sparsity_coefficient: float
+    lp_norm: Optional[float]
     use_ghost_grads: bool
     normalize_sae_decoder: bool
     noise_scale: float
@@ -116,7 +119,6 @@ class TrainingSAEConfig(SAEConfig):
     def from_sae_runner_config(
         cls, cfg: LanguageModelSAERunnerConfig
     ) -> "TrainingSAEConfig":
-
         return cls(
             # base config
             architecture=cfg.architecture,
@@ -138,7 +140,7 @@ def from_sae_runner_config(
             prepend_bos=cfg.prepend_bos,
             seqpos_slice=cfg.seqpos_slice,
             # Training cfg
-            l1_coefficient=cfg.l1_coefficient,
+            sparsity_coefficient=cfg.sparsity_coefficient,
             lp_norm=cfg.lp_norm,
             use_ghost_grads=cfg.use_ghost_grads,
             normalize_sae_decoder=cfg.normalize_sae_decoder,
@@ -180,7 +182,7 @@ def from_dict(cls, config_dict: dict[str, Any]) -> "TrainingSAEConfig":
     def to_dict(self) -> dict[str, Any]:
         return {
             **super().to_dict(),
-            "l1_coefficient": self.l1_coefficient,
+            "sparsity_coefficient": self.sparsity_coefficient,
             "lp_norm": self.lp_norm,
             "use_ghost_grads": self.use_ghost_grads,
             "normalize_sae_decoder": self.normalize_sae_decoder,
@@ -359,7 +361,6 @@ def forward(
         self,
         x: Float[torch.Tensor, "... d_in"],
     ) -> Float[torch.Tensor, "... d_in"]:
-
         feature_acts, _ = self.encode_with_hidden_pre_fn(x)
         sae_out = self.decode(feature_acts)
 
@@ -368,10 +369,9 @@ def training_forward_pass(
         self,
         sae_in: torch.Tensor,
-        current_l1_coefficient: float,
+        current_sparsity_coefficient: float,
         dead_neuron_mask: Optional[torch.Tensor] = None,
     ) -> TrainStepOutput:
-
         # do a forward pass to get SAE out, but we also need the
         # hidden pre.
        feature_acts, hidden_pre = self.encode_with_hidden_pre_fn(sae_in)
 
@@ -395,7 +395,7 @@ def training_forward_pass(
 
             # SFN sparsity loss - summed over the feature dimension and averaged over the batch
             l1_loss = (
-                current_l1_coefficient
+                current_sparsity_coefficient
                 * torch.sum(pi_gate_act * self.W_dec.norm(dim=1), dim=-1).mean()
             )
 
@@ -410,7 +410,8 @@ def training_forward_pass(
         elif self.cfg.architecture == "jumprelu":
             threshold = torch.exp(self.log_threshold)
             l0 = torch.sum(Step.apply(hidden_pre, threshold, self.bandwidth), dim=-1)  # type: ignore
-            l0_loss = (current_l1_coefficient * l0).mean()
+            l0_loss = (current_sparsity_coefficient * l0).mean()
+            # l0_loss = (current_l0_lambda * l0).mean()
             loss = mse_loss + l0_loss
             losses["l0_loss"] = l0_loss
         else:
@@ -422,7 +423,7 @@ def training_forward_pass(
                 p=self.cfg.lp_norm, dim=-1
             )  # sum over the feature dimension
-            l1_loss = (current_l1_coefficient * sparsity).mean()
+            l1_loss = (current_sparsity_coefficient * sparsity).mean()
             loss = mse_loss + l1_loss
             if (
                 self.cfg.use_ghost_grads
@@ -458,7 +459,6 @@ def calculate_ghost_grad_loss(
         hidden_pre: torch.Tensor,
         dead_neuron_mask: torch.Tensor,
     ) -> torch.Tensor:
-
         # 1.
         residual = x - sae_out
         l2_norm_residual = torch.norm(residual, dim=-1)
@@ -490,7 +490,6 @@ def calculate_ghost_grad_loss(
 
     @torch.no_grad()
     def _get_mse_loss_fn(self) -> Any:
-
         def standard_mse_loss_fn(
             preds: torch.Tensor, target: torch.Tensor
         ) -> torch.Tensor:
@@ -529,7 +528,6 @@ def load_from_pretrained(
         device: str = "cpu",
         dtype: str | None = None,
     ) -> "TrainingSAE":
-
         # get the config
         config_path = os.path.join(path, SAE_CFG_PATH)
         with open(config_path, "r") as f:
diff --git a/tests/unit/helpers.py b/tests/unit/helpers.py
index 78d04054..4ac00773 100644
--- a/tests/unit/helpers.py
+++ b/tests/unit/helpers.py
@@ -19,7 +19,7 @@ class LanguageModelSAERunnerConfigDict(TypedDict, total=False):
     is_dataset_tokenized: bool
     use_cached_activations: bool
     d_in: int
-    l1_coefficient: float
+    sparsity_coefficient: float
     lp_norm: float
     lr: float
     train_batch_size_tokens: int
@@ -55,7 +55,7 @@ def build_sae_cfg(**kwargs: Any) -> LanguageModelSAERunnerConfig:
         "is_dataset_tokenized": False,
         "use_cached_activations": False,
         "d_in": 64,
-        "l1_coefficient": 2e-3,
+        "sparsity_coefficient": 1.0,
         "lp_norm": 1,
         "lr": 2e-4,
         "train_batch_size_tokens": 4,
diff --git a/tests/unit/training/test_coefficient_scheduler.py b/tests/unit/training/test_coefficient_scheduler.py
new file mode 100644
index 00000000..828c37d3
--- /dev/null
+++ b/tests/unit/training/test_coefficient_scheduler.py
@@ -0,0 +1,50 @@
+from sae_lens.training.optim import CoefficientScheduler
+from tests.unit.helpers import build_sae_cfg
+
+
+def test_coefficient_scheduler_initialization():
+    cfg = build_sae_cfg(
+        sparsity_coefficient=5,
+        training_tokens=100 * 4,  # train batch size (so 100 steps)
+        coefficient_warm_up_steps=10,
+    )
+
+    coefficient_scheduler = CoefficientScheduler(
+        coefficient_warm_up_steps=cfg.coefficient_warm_up_steps,  # type: ignore
+        total_steps=cfg.training_tokens // cfg.train_batch_size_tokens,
+        final_sparsity_coefficient=cfg.sparsity_coefficient,
+    )
+
+    assert cfg.sparsity_coefficient == 5
+    assert (
+        coefficient_scheduler.current_sparsity_coefficient == 0
+    )  # the coefficient starts at 0 to begin warm-up.
+
+    # over 10 steps, we should get to the final value of 5
+    for i in range(10):
+        coefficient_scheduler.step()
+        assert coefficient_scheduler.current_sparsity_coefficient == 5 * (1 + i) / 10
+
+
+def test_coefficient_scheduler_initialization_no_warmup():
+    cfg = build_sae_cfg(
+        sparsity_coefficient=5,
+        training_tokens=100 * 4,  # train batch size (so 100 steps)
+        coefficient_warm_up_steps=0,
+    )
+
+    coefficient_scheduler = CoefficientScheduler(
+        coefficient_warm_up_steps=cfg.coefficient_warm_up_steps,  # type: ignore
+        total_steps=cfg.training_tokens // cfg.train_batch_size_tokens,
+        final_sparsity_coefficient=cfg.sparsity_coefficient,
+    )
+
+    assert cfg.sparsity_coefficient == 5
+    assert (
+        coefficient_scheduler.current_sparsity_coefficient == 5
+    )  # with no warm-up, the coefficient starts at its final value.
+
+    # the coefficient should remain at the final value of 5
+    for _ in range(10):
+        coefficient_scheduler.step()
+        assert coefficient_scheduler.current_sparsity_coefficient == coefficient_scheduler.final_sparsity_coefficient
diff --git a/tests/unit/training/test_gated_sae.py b/tests/unit/training/test_gated_sae.py
index 03dd7183..36543ef9 100644
--- a/tests/unit/training/test_gated_sae.py
+++ b/tests/unit/training/test_gated_sae.py
@@ -68,7 +68,7 @@ def test_gated_sae_loss():
 
     train_step_output = sae.training_forward_pass(
         sae_in=x,
-        current_l1_coefficient=sae.cfg.l1_coefficient,
+        current_sparsity_coefficient=sae.cfg.sparsity_coefficient,
     )
 
     assert train_step_output.sae_out.shape == (batch_size, d_in)
@@ -77,7 +77,7 @@ def test_gated_sae_loss():
     sae_in_centered = x - sae.b_dec
     via_gate_feature_magnitudes = torch.relu(sae_in_centered @ sae.W_enc + sae.b_gate)
     preactivation_l1_loss = (
-        sae.cfg.l1_coefficient * torch.sum(via_gate_feature_magnitudes, dim=-1).mean()
+        sae.cfg.sparsity_coefficient * torch.sum(via_gate_feature_magnitudes, dim=-1).mean()
     )
 
     via_gate_reconstruction = (
@@ -122,7 +122,7 @@ def test_gated_sae_training_forward_pass():
     x = torch.randn(batch_size, d_in)
     train_step_output = sae.training_forward_pass(
         sae_in=x,
-        current_l1_coefficient=sae.cfg.l1_coefficient,
+        current_sparsity_coefficient=sae.cfg.sparsity_coefficient,
     )
 
     assert train_step_output.sae_out.shape == (batch_size, d_in)
diff --git a/tests/unit/training/test_jumprelu_sae.py b/tests/unit/training/test_jumprelu_sae.py
index a295af48..d8fc6f59 100644
--- a/tests/unit/training/test_jumprelu_sae.py
+++ b/tests/unit/training/test_jumprelu_sae.py
@@ -40,7 +40,7 @@ def test_jumprelu_sae_training_forward_pass():
     x = torch.randn(batch_size, d_in)
     train_step_output = sae.training_forward_pass(
         sae_in=x,
-        current_l1_coefficient=sae.cfg.l1_coefficient,
+        current_sparsity_coefficient=sae.cfg.sparsity_coefficient,
     )
 
     assert train_step_output.sae_out.shape == (batch_size, d_in)
diff --git a/tests/unit/training/test_l1_scheduler.py b/tests/unit/training/test_l1_scheduler.py
deleted file mode 100644
index bc68c31d..00000000
--- a/tests/unit/training/test_l1_scheduler.py
+++ /dev/null
@@ -1,50 +0,0 @@
-from sae_lens.training.optim import L1Scheduler
-from tests.unit.helpers import build_sae_cfg
-
-
-def test_l1_scheduler_initialization():
-    cfg = build_sae_cfg(
-        l1_coefficient=5,
-        training_tokens=100 * 4,  # train batch size (so 100 steps)
-        l1_warm_up_steps=10,
-    )
-
-    l1_scheduler = L1Scheduler(
-        l1_warm_up_steps=cfg.l1_warm_up_steps,  # type: ignore
-        total_steps=cfg.training_tokens // cfg.train_batch_size_tokens,
-        final_l1_coefficient=cfg.l1_coefficient,
-    )
-
-    assert cfg.l1_coefficient == 5
-    assert (
-        l1_scheduler.current_l1_coefficient == 0
-    )  # the l1 coefficient is set to 0, to begin warm up.
-
-    # over 10 steps, we should get to the final value of 5
-    for i in range(10):
-        l1_scheduler.step()
-        assert l1_scheduler.current_l1_coefficient == 5 * (1 + i) / 10
-
-
-def test_l1_scheduler_initialization_no_warmup():
-    cfg = build_sae_cfg(
-        l1_coefficient=5,
-        training_tokens=100 * 4,  # train batch size (so 100 steps)
-        l1_warm_up_steps=0,
-    )
-
-    l1_scheduler = L1Scheduler(
-        l1_warm_up_steps=cfg.l1_warm_up_steps,  # type: ignore
-        total_steps=cfg.training_tokens // cfg.train_batch_size_tokens,
-        final_l1_coefficient=cfg.l1_coefficient,
-    )
-
-    assert cfg.l1_coefficient == 5
-    assert (
-        l1_scheduler.current_l1_coefficient == 5
-    )  # the l1 coefficient is set to 0, to begin warm up.
-
-    # over 10 steps, we should get to the final value of 5
-    for _ in range(10):
-        l1_scheduler.step()
-        assert l1_scheduler.current_l1_coefficient == l1_scheduler.final_l1_coefficient
diff --git a/tests/unit/training/test_sae_trainer.py b/tests/unit/training/test_sae_trainer.py
index 94bc906f..9d5525f4 100644
--- a/tests/unit/training/test_sae_trainer.py
+++ b/tests/unit/training/test_sae_trainer.py
@@ -49,7 +49,6 @@ def trainer(
     model: HookedTransformer,
     activation_store: ActivationsStore,
 ):
-
     trainer = SAETrainer(
         model=model,
         sae=training_sae,
@@ -77,7 +76,6 @@ def modified_forward(*args: Any, **kwargs: Any) -> torch.Tensor:
 def test_train_step__reduces_loss_when_called_repeatedly_on_same_acts(
     trainer: SAETrainer,
 ) -> None:
-
     layer_acts = trainer.activation_store.next_batch()
 
     # intentionally train on the same activations 5 times to ensure loss decreases
@@ -98,7 +96,6 @@
 
 
 def test_train_step__output_looks_reasonable(trainer: SAETrainer) -> None:
-
     layer_acts = trainer.activation_store.next_batch()
 
     output = trainer._train_step(
@@ -123,7 +120,6 @@ def test_train_step__output_looks_reasonable(trainer: SAETrainer) -> None:
 def test_train_step__sparsity_updates_based_on_feature_act_sparsity(
     trainer: SAETrainer,
 ) -> None:
-
     trainer._reset_running_sparsity_stats()
 
     layer_acts = trainer.activation_store.next_batch()
@@ -163,7 +159,6 @@ def test_log_feature_sparsity__handles_zeroes_by_default_fp16() -> None:
 
 
 def test_build_train_step_log_dict(trainer: SAETrainer) -> None:
-
     train_output = TrainStepOutput(
         sae_in=torch.tensor([[-1, 0], [0, 2], [1, 1]]).float(),
         sae_out=torch.tensor([[0, 0], [0, 2], [0.5, 1]]).float(),
@@ -185,7 +180,7 @@ def test_build_train_step_log_dict(trainer: SAETrainer) -> None:
     assert log_dict == {
         "losses/mse_loss": 0.25,
         # l1 loss is scaled by l1_coefficient
-        "losses/l1_loss": train_output.losses["l1_loss"] / trainer.cfg.l1_coefficient,
+        "losses/l1_loss": train_output.losses["l1_loss"] / trainer.cfg.sparsity_coefficient,
         "losses/raw_l1_loss": train_output.losses["l1_loss"],
         "losses/overall_loss": 0.5,
         "losses/ghost_grad_loss": 0.15,
@@ -195,7 +190,7 @@ def test_build_train_step_log_dict(trainer: SAETrainer) -> None:
         "sparsity/mean_passes_since_fired": trainer.n_forward_passes_since_fired.mean().item(),
         "sparsity/dead_features": trainer.dead_neurons.sum().item(),
         "details/current_learning_rate": 2e-4,
-        "details/current_l1_coefficient": trainer.cfg.l1_coefficient,
+        "details/current_sparsity_coefficient": trainer.cfg.sparsity_coefficient,
         "details/n_training_tokens": 123,
     }
diff --git a/tests/unit/training/test_sae_training.py b/tests/unit/training/test_sae_training.py
index 70bef679..bade2675 100644
--- a/tests/unit/training/test_sae_training.py
+++ b/tests/unit/training/test_sae_training.py
@@ -149,7 +149,7 @@ def test_sae_forward(training_sae: TrainingSAE):
     x = torch.randn(batch_size, d_in)
     train_step_output = training_sae.training_forward_pass(
         sae_in=x,
-        current_l1_coefficient=training_sae.cfg.l1_coefficient,
+        current_sparsity_coefficient=training_sae.cfg.sparsity_coefficient,
     )
 
     assert train_step_output.sae_out.shape == (batch_size, d_in)
@@ -188,7 +188,7 @@ def test_sae_forward(training_sae: TrainingSAE):
     )
     assert (
         pytest.approx(train_step_output.losses["l1_loss"].item(), rel=1e-3)  # type: ignore
-        == training_sae.cfg.l1_coefficient * expected_l1_loss.detach().float()
+        == training_sae.cfg.sparsity_coefficient * expected_l1_loss.detach().float()
     )
 
 
@@ -206,7 +206,7 @@ def test_sae_forward_with_mse_loss_norm(
     x = torch.randn(batch_size, d_in)
     train_step_output = training_sae.training_forward_pass(
         sae_in=x,
-        current_l1_coefficient=training_sae.cfg.l1_coefficient,
+        current_sparsity_coefficient=training_sae.cfg.sparsity_coefficient,
     )
 
     assert train_step_output.sae_out.shape == (batch_size, d_in)
@@ -248,7 +248,7 @@ def test_sae_forward_with_mse_loss_norm(
     )
     assert (
         pytest.approx(train_step_output.losses["l1_loss"].item(), rel=1e-3)  # type: ignore
-        == training_sae.cfg.l1_coefficient * expected_l1_loss.detach().float()
+        == training_sae.cfg.sparsity_coefficient * expected_l1_loss.detach().float()
     )
 
 
@@ -262,7 +262,7 @@ def test_SparseAutoencoder_forward_ghost_grad_loss_non_zero(
     x = torch.randn(batch_size, d_in)
     train_step_output = training_sae.training_forward_pass(
         sae_in=x,
-        current_l1_coefficient=training_sae.cfg.l1_coefficient,
+        current_sparsity_coefficient=training_sae.cfg.sparsity_coefficient,
         dead_neuron_mask=torch.ones_like(
             training_sae.b_enc
         ).bool(),  # all neurons are dead.
diff --git a/tests/unit/training/test_training_sae.py b/tests/unit/training/test_training_sae.py
index 16a84594..a61ebd50 100644
--- a/tests/unit/training/test_training_sae.py
+++ b/tests/unit/training/test_training_sae.py
@@ -22,7 +22,7 @@ def test_TrainingSAE_training_forward_pass_can_scale_sparsity_penalty_by_decoder
     x = torch.randn(32, 3)
     train_step_output = training_sae.training_forward_pass(
         sae_in=x,
-        current_l1_coefficient=2.0,
+        current_sparsity_coefficient=2.0,
     )
     feature_acts = train_step_output.feature_acts
     decoder_norm = training_sae.W_dec.norm(dim=1)
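
Note: the snippet below is not part of the diff; it is a minimal usage sketch of the renamed CoefficientScheduler, assuming only the constructor and step() behaviour shown in sae_lens/training/optim.py above. The values (final coefficient 5.0, 10 warm-up steps, 100 total steps) are illustrative, not taken from any config in this PR.

from sae_lens.training.optim import CoefficientScheduler

# Illustrative values; this mirrors the behaviour exercised in test_coefficient_scheduler.py.
scheduler = CoefficientScheduler(
    coefficient_warm_up_steps=10,
    total_steps=100,
    final_sparsity_coefficient=5.0,
)

# current_sparsity_coefficient starts at 0.0, reaches 5.0 * (i + 1) / 10 after the
# i-th call to step() during warm-up, and then stays at the final value of 5.0.
for i in range(12):
    print(i, scheduler.current_sparsity_coefficient)
    scheduler.step()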