Fix parameter naming: Separate l0_lambda from l1_coefficient #376

Closed
wants to merge 7 commits
Changes from 3 commits
2 changes: 1 addition & 1 deletion .vscode/settings.json
@@ -6,7 +6,7 @@
"python.testing.pytestEnabled": true,

"[python]": {
"editor.defaultFormatter": "ms-python.black-formatter",
"editor.defaultFormatter": "charliermarsh.ruff",
"editor.formatOnSave": true,
"editor.codeActionsOnSave": {
"source.organizeImports": "explicit"
9 changes: 9 additions & 0 deletions sae_lens/config.py
@@ -79,9 +79,11 @@ class LanguageModelSAERunnerConfig:
adam_beta2 (float): The beta2 parameter for Adam.
mse_loss_normalization (str): The normalization to use for the MSE loss.
l1_coefficient (float): The L1 coefficient.
l0_lambda (float): The coefficient on the L0 sparsity penalty (used by JumpReLU SAEs).
lp_norm (float): The Lp norm.
scale_sparsity_penalty_by_decoder_norm (bool): Whether to scale the sparsity penalty by the decoder norm.
l1_warm_up_steps (int): The number of warm-up steps for the L1 loss.
l0_warm_up_steps (int): The number of warm-up steps for the L0 loss.
lr (float): The learning rate.
lr_scheduler_name (str): The name of the learning rate scheduler to use.
lr_warm_up_steps (int): The number of warm-up steps for the learning rate.
@@ -187,9 +189,11 @@ class LanguageModelSAERunnerConfig:
## Loss Function
mse_loss_normalization: Optional[str] = None
l1_coefficient: float = 1e-3
l0_lambda: float = 0.0
lp_norm: float = 1
scale_sparsity_penalty_by_decoder_norm: bool = False
l1_warm_up_steps: int = 0
l0_warm_up_steps: int = 0

## Learning Rate Schedule
lr: float = 3e-4
@@ -369,6 +373,9 @@ def __post_init__(self):
f"The provided context_size is {self.context_size} is negative. Expecting positive context_size."
)

if self.architecture == "jumprelu" and self.l0_lambda == 0.0:
raise ValueError("For JumpReLU SAEs, you must specify l0_lambda.")

_validate_seqpos(seqpos=self.seqpos_slice, context_size=self.context_size)

@property
@@ -409,6 +416,7 @@ def get_training_sae_cfg_dict(self) -> dict[str, Any]:
return {
**self.get_base_sae_cfg_dict(),
"l1_coefficient": self.l1_coefficient,
"l0_lambda": self.l0_lambda,
"lp_norm": self.lp_norm,
"use_ghost_grads": self.use_ghost_grads,
"normalize_sae_decoder": self.normalize_sae_decoder,
@@ -548,6 +556,7 @@ class ToyModelSAERunnerConfig:

# Training Parameters
l1_coefficient: float = 1e-3
l0_lambda: float | None = None
lr: float = 3e-4
train_batch_size: int = 1024
b_dec_init_method: str = "geometric_median"
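For orientation, a minimal sketch of how the new fields might be set when configuring a JumpReLU SAE run. The field names come from the diff above; the values are purely illustrative and every other constructor argument is left at its default.

# Illustrative sketch only: configure a JumpReLU run with the new L0 parameters.
from sae_lens.config import LanguageModelSAERunnerConfig

cfg = LanguageModelSAERunnerConfig(
    architecture="jumprelu",
    l0_lambda=6e-4,          # L0 sparsity coefficient; __post_init__ rejects 0.0 for JumpReLU
    l0_warm_up_steps=1_000,  # ramp the L0 penalty from 0 to l0_lambda over the first 1,000 steps
    l1_coefficient=1e-3,     # still present, but no longer doubles as the L0 coefficient
)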
64 changes: 61 additions & 3 deletions sae_lens/training/optim.py
@@ -94,21 +94,21 @@ def _get_main_lr_scheduler(
return lr_scheduler.CosineAnnealingLR(optimizer, T_max=steps, eta_min=lr_end) # type: ignore
elif scheduler_name == "cosineannealingwarmrestarts":
return lr_scheduler.CosineAnnealingWarmRestarts(
optimizer, T_0=steps // num_cycles, eta_min=lr_end # type: ignore
optimizer,
T_0=steps // num_cycles,
eta_min=lr_end, # type: ignore
)
else:
raise ValueError(f"Unsupported scheduler: {scheduler_name}")


class L1Scheduler:

def __init__(
self,
l1_warm_up_steps: float,
total_steps: int,
final_l1_coefficient: float,
):

self.l1_warmup_steps = l1_warm_up_steps
# assume using warm-up
if self.l1_warmup_steps != 0:
@@ -157,3 +157,61 @@ def load_state_dict(self, state_dict: dict[str, Any]):
"""Loads all state apart from attached SAE."""
for k in state_dict:
setattr(self, k, state_dict[k])


class L0Scheduler:
def __init__(
self,
l0_warm_up_steps: float,
total_steps: int,
final_l0_lambda: float,
):
assert isinstance(final_l0_lambda, float | int)
self.l0_warmup_steps = l0_warm_up_steps
# assume using warm-up
if self.l0_warmup_steps != 0:
self.current_l0_lambda = 0.0
else:
self.current_l0_lambda = final_l0_lambda

self.final_l0_lambda = final_l0_lambda

self.current_step = 0
self.total_steps = total_steps
assert isinstance(self.final_l0_lambda, float | int)

def __repr__(self) -> str:
return (
f"L0Scheduler(final_l0_value={self.final_l0_lambda}, "
f"l0_warmup_steps={self.l0_warmup_steps}, "
f"total_steps={self.total_steps})"
)

def step(self):
"""
Updates the l0 lambda of the sparse autoencoder.
"""
step = self.current_step
if step < self.l0_warmup_steps:
self.current_l0_lambda = self.final_l0_lambda * (
(1 + step) / self.l0_warmup_steps
) # type: ignore
else:
self.current_l0_lambda = self.final_l0_lambda # type: ignore

self.current_step += 1

def state_dict(self):
"""State dict for serializing as part of an SAETrainContext."""
return {
"l0_warmup_steps": self.l0_warmup_steps,
"total_steps": self.total_steps,
"current_l0_lambda": self.current_l0_lambda,
"final_l0_lambda": self.final_l0_lambda,
"current_step": self.current_step,
}

def load_state_dict(self, state_dict: dict[str, Any]):
"""Loads all state apart from attached SAE."""
for k in state_dict:
setattr(self, k, state_dict[k])
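To make the new scheduler's behaviour concrete: it ramps the lambda linearly from 0 to final_l0_lambda over l0_warm_up_steps and holds it constant afterwards. Below is a standalone sketch that mirrors the step() logic above; it is an illustration, not the library class.

# Standalone sketch of the L0 warm-up ramp, mirroring L0Scheduler.step().
def l0_lambda_at(step: int, l0_warm_up_steps: int, final_l0_lambda: float) -> float:
    if l0_warm_up_steps != 0 and step < l0_warm_up_steps:
        return final_l0_lambda * (1 + step) / l0_warm_up_steps
    return final_l0_lambda

# With l0_warm_up_steps=4 and final_l0_lambda=1.0:
# step:    0     1     2     3     4     5 ...
# lambda:  0.25  0.50  0.75  1.00  1.00  1.00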
28 changes: 19 additions & 9 deletions sae_lens/training/sae_trainer.py
@@ -12,7 +12,7 @@
from sae_lens.config import LanguageModelSAERunnerConfig
from sae_lens.evals import EvalConfig, run_evals
from sae_lens.training.activations_store import ActivationsStore
from sae_lens.training.optim import L1Scheduler, get_lr_scheduler
from sae_lens.training.optim import L0Scheduler, L1Scheduler, get_lr_scheduler
from sae_lens.training.training_sae import TrainingSAE, TrainStepOutput

# used to map between parameters which are updated during finetuning and the config str.
@@ -56,7 +56,6 @@ def __init__(
save_checkpoint_fn, # type: ignore
cfg: LanguageModelSAERunnerConfig,
) -> None:

self.model = model
self.sae = sae
self.activation_store = activation_store
@@ -118,6 +117,11 @@ def __init__(
total_steps=cfg.total_training_steps,
final_l1_coefficient=cfg.l1_coefficient,
)
self.l0_scheduler = L0Scheduler(
l0_warm_up_steps=cfg.l0_warm_up_steps,
total_steps=cfg.total_training_steps,
final_l0_lambda=cfg.l0_lambda,
)

# Setup autocast if using
self.scaler = torch.cuda.amp.GradScaler(enabled=self.cfg.autocast)
@@ -157,12 +161,15 @@ def log_feature_sparsity(self) -> torch.Tensor:
def current_l1_coefficient(self) -> float:
return self.l1_scheduler.current_l1_coefficient

@property
def current_l0_lambda(self) -> float:
return self.l0_scheduler.current_l0_lambda

@property
def dead_neurons(self) -> torch.Tensor:
return (self.n_forward_passes_since_fired > self.cfg.dead_feature_window).bool()

def fit(self) -> TrainingSAE:

pbar = tqdm(total=self.cfg.total_training_tokens, desc="Training SAE")

self._estimate_norm_scaling_factor_if_needed()
@@ -216,7 +223,6 @@ def _train_step(
sae: TrainingSAE,
sae_in: torch.Tensor,
) -> TrainStepOutput:

sae.train()
# Make sure the W_dec is still zero-norm
if self.cfg.normalize_sae_decoder:
@@ -232,11 +238,11 @@ def _train_step(
# for documentation on autocasting see:
# https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html
with self.autocast_if_enabled:

train_step_output = self.sae.training_forward_pass(
sae_in=sae_in,
dead_neuron_mask=self.dead_neurons,
current_l1_coefficient=self.current_l1_coefficient,
current_l0_lambda=self.current_l0_lambda,
)

with torch.no_grad():
@@ -264,6 +270,7 @@ def _train_step(
self.optimizer.zero_grad()
self.lr_scheduler.step()
self.l1_scheduler.step()
self.l0_scheduler.step()

return train_step_output

@@ -309,6 +316,7 @@ def _build_train_step_log_dict(
"sparsity/dead_features": self.dead_neurons.sum().item(),
"details/current_learning_rate": current_learning_rate,
"details/current_l1_coefficient": self.current_l1_coefficient,
"details/current_l0_lambda": self.current_l0_lambda,
"details/n_training_tokens": n_training_tokens,
}
for loss_name, loss_value in output.losses.items():
@@ -369,7 +377,6 @@ def _run_and_log_evals(self):

@torch.no_grad()
def _build_sparsity_log_dict(self) -> dict[str, Any]:

log_feature_sparsity = _log_feature_sparsity(self.feature_sparsity)
wandb_histogram = wandb.Histogram(log_feature_sparsity.numpy()) # type: ignore
return {
@@ -381,7 +388,6 @@ def _build_sparsity_log_dict(self) -> dict[str, Any]:

@torch.no_grad()
def _reset_running_sparsity_stats(self) -> None:

self.act_freq_scores = torch.zeros(
self.cfg.d_sae, # type: ignore
device=self.cfg.device,
@@ -401,8 +407,12 @@ def _checkpoint_if_needed(self):
self.checkpoint_thresholds.pop(0)

@torch.no_grad()
def _update_pbar(self, step_output: TrainStepOutput, pbar: tqdm, update_interval: int = 100): # type: ignore

def _update_pbar(
self,
step_output: TrainStepOutput,
pbar: tqdm, # type: ignore
update_interval: int = 100, # type: ignore
): # type: ignore
if self.n_training_steps % update_interval == 0:
loss_strs = " | ".join(
f"{loss_name}: {_unwrap_item(loss_value):.5f}"
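As a rough usage sketch of the wiring above: the trainer builds both schedulers from the config and advances them once per training step, and the scheduled values reach the loss via current_l1_coefficient and current_l0_lambda. The snippet below is standalone and uses toy values.

# Illustrative wiring of both sparsity schedulers (standalone sketch, toy values).
from sae_lens.training.optim import L0Scheduler, L1Scheduler

total_steps = 10
l1_scheduler = L1Scheduler(
    l1_warm_up_steps=5, total_steps=total_steps, final_l1_coefficient=2e-3
)
l0_scheduler = L0Scheduler(
    l0_warm_up_steps=5, total_steps=total_steps, final_l0_lambda=6e-4
)

for _ in range(total_steps):
    # a real step would pass l1_scheduler.current_l1_coefficient and
    # l0_scheduler.current_l0_lambda into sae.training_forward_pass(...)
    l1_scheduler.step()
    l0_scheduler.step()

assert l0_scheduler.current_l0_lambda == 6e-4  # fully warmed up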
23 changes: 12 additions & 11 deletions sae_lens/training/training_sae.py
@@ -46,7 +46,9 @@ def setup_context(
ctx.bandwidth = bandwidth

@staticmethod
def backward(ctx: Any, grad_output: torch.Tensor) -> tuple[None, torch.Tensor, None]: # type: ignore[override]
def backward( # type: ignore
ctx: Any, grad_output: torch.Tensor
) -> tuple[None, torch.Tensor, None]: # type: ignore[override]
x, threshold = ctx.saved_tensors
bandwidth = ctx.bandwidth
threshold_grad = torch.sum(
@@ -73,7 +75,9 @@ def setup_context(
ctx.bandwidth = bandwidth

@staticmethod
def backward(ctx: Any, grad_output: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, None]: # type: ignore[override]
def backward( # type: ignore
ctx: Any, grad_output: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, None]: # type: ignore[override]
x, threshold = ctx.saved_tensors
bandwidth = ctx.bandwidth
x_grad = (x > threshold) * grad_output # We don't apply STE to x input
@@ -97,10 +101,10 @@ class TrainStepOutput:

@dataclass(kw_only=True)
class TrainingSAEConfig(SAEConfig):

# Sparsity Loss Calculations
l1_coefficient: float
lp_norm: float
l0_lambda: float
lp_norm: Optional[float]
use_ghost_grads: bool
normalize_sae_decoder: bool
noise_scale: float
@@ -116,7 +120,6 @@ class TrainingSAEConfig(SAEConfig):
def from_sae_runner_config(
cls, cfg: LanguageModelSAERunnerConfig
) -> "TrainingSAEConfig":

return cls(
# base config
architecture=cfg.architecture,
@@ -139,6 +142,7 @@ def from_sae_runner_config(
seqpos_slice=cfg.seqpos_slice,
# Training cfg
l1_coefficient=cfg.l1_coefficient,
l0_lambda=cfg.l0_lambda,
lp_norm=cfg.lp_norm,
use_ghost_grads=cfg.use_ghost_grads,
normalize_sae_decoder=cfg.normalize_sae_decoder,
@@ -181,6 +185,7 @@ def to_dict(self) -> dict[str, Any]:
return {
**super().to_dict(),
"l1_coefficient": self.l1_coefficient,
"l0_lambda": self.l0_lambda,
"lp_norm": self.lp_norm,
"use_ghost_grads": self.use_ghost_grads,
"normalize_sae_decoder": self.normalize_sae_decoder,
@@ -359,7 +364,6 @@ def forward(
self,
x: Float[torch.Tensor, "... d_in"],
) -> Float[torch.Tensor, "... d_in"]:

feature_acts, _ = self.encode_with_hidden_pre_fn(x)
sae_out = self.decode(feature_acts)

@@ -370,8 +374,8 @@ def training_forward_pass(
sae_in: torch.Tensor,
current_l1_coefficient: float,
dead_neuron_mask: Optional[torch.Tensor] = None,
current_l0_lambda: Optional[float] = None,
) -> TrainStepOutput:

# do a forward pass to get SAE out, but we also need the
# hidden pre.
feature_acts, hidden_pre = self.encode_with_hidden_pre_fn(sae_in)
@@ -410,7 +414,7 @@ def training_forward_pass(
elif self.cfg.architecture == "jumprelu":
threshold = torch.exp(self.log_threshold)
l0 = torch.sum(Step.apply(hidden_pre, threshold, self.bandwidth), dim=-1) # type: ignore
l0_loss = (current_l1_coefficient * l0).mean()
l0_loss = (current_l0_lambda * l0).mean()
loss = mse_loss + l0_loss
losses["l0_loss"] = l0_loss
else:
@@ -458,7 +462,6 @@ def calculate_ghost_grad_loss(
hidden_pre: torch.Tensor,
dead_neuron_mask: torch.Tensor,
) -> torch.Tensor:

# 1.
residual = x - sae_out
l2_norm_residual = torch.norm(residual, dim=-1)
@@ -490,7 +493,6 @@

@torch.no_grad()
def _get_mse_loss_fn(self) -> Any:

def standard_mse_loss_fn(
preds: torch.Tensor, target: torch.Tensor
) -> torch.Tensor:
@@ -529,7 +531,6 @@ def load_from_pretrained(
device: str = "cpu",
dtype: str | None = None,
) -> "TrainingSAE":

# get the config
config_path = os.path.join(path, SAE_CFG_PATH)
with open(config_path, "r") as f:
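The change in training_forward_pass is the substantive fix: the JumpReLU L0 penalty is now weighted by current_l0_lambda instead of reusing current_l1_coefficient. For intuition, a toy comparison of the two penalties in plain PyTorch; this is a hedged sketch and does not use the library's straight-through Step estimator.

# Toy comparison of the two sparsity penalties that now have separate coefficients.
import torch

feature_acts = torch.tensor([[0.0, 0.3, 0.0, 1.2],
                             [0.5, 0.0, 0.0, 0.0]])
threshold = 0.1
l1_coefficient = 2e-3
l0_lambda = 6e-4

# L1 penalty (standard SAEs): coefficient times the summed absolute activations.
l1_loss = (l1_coefficient * feature_acts.norm(p=1, dim=-1)).mean()

# L0 penalty (JumpReLU SAEs): coefficient times the count of active features.
l0 = (feature_acts > threshold).float().sum(dim=-1)
l0_loss = (l0_lambda * l0).mean()

print(f"l1_loss={l1_loss.item():.2e}, l0_loss={l0_loss.item():.2e}")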
2 changes: 2 additions & 0 deletions tests/unit/helpers.py
@@ -20,6 +20,7 @@ class LanguageModelSAERunnerConfigDict(TypedDict, total=False):
use_cached_activations: bool
d_in: int
l1_coefficient: float
l0_lambda: float
lp_norm: float
lr: float
train_batch_size_tokens: int
@@ -56,6 +57,7 @@ def build_sae_cfg(**kwargs: Any) -> LanguageModelSAERunnerConfig:
"use_cached_activations": False,
"d_in": 64,
"l1_coefficient": 2e-3,
"l0_lambda": 6e-4,
"lp_norm": 1,
"lr": 2e-4,
"train_batch_size_tokens": 4,
1 change: 1 addition & 0 deletions tests/unit/training/test_jumprelu_sae.py
@@ -41,6 +41,7 @@ def test_jumprelu_sae_training_forward_pass():
train_step_output = sae.training_forward_pass(
sae_in=x,
current_l1_coefficient=sae.cfg.l1_coefficient,
current_l0_lambda=sae.cfg.l0_lambda,
)

assert train_step_output.sae_out.shape == (batch_size, d_in)