From 0a57e473b63f79043d69172592ec3ee00b852339 Mon Sep 17 00:00:00 2001
From: callummcdougall
Date: Wed, 18 Sep 2024 09:13:51 +0100
Subject: [PATCH 1/5] support seqpos slicing

---
 sae_lens/config.py                     |  3 +++
 sae_lens/training/activations_store.py | 27 ++++++++++++++-------------
 2 files changed, 17 insertions(+), 13 deletions(-)

diff --git a/sae_lens/config.py b/sae_lens/config.py
index a74b7565..94bf1517 100644
--- a/sae_lens/config.py
+++ b/sae_lens/config.py
@@ -60,6 +60,7 @@ class LanguageModelSAERunnerConfig:
         store_batch_size_prompts (int): The batch size for storing activations. This controls how many prompts are in the batch of the language model when generating actiations.
         train_batch_size_tokens (int): The batch size for training. This controls the batch size of the SAE Training loop.
         normalize_activations (str): Activation Normalization Strategy. Either none, expected_average_only_in (estimate the average activation norm and divide activations by it -> this can be folded post training and set to None), or constant_norm_rescale (at runtime set activation norm to sqrt(d_in) and then scale up the SAE output).
+        seqpos_slice (tuple): Determines slicing of (batch, seq, d_in) activations when constructing batches, during training. Example: for Othello we sometimes use (5, -5).
         device (str): The device to use. Usually cuda.
         act_store_device (str): The device to use for the activation store. CPU is advised in order to save vram.
         seed (int): The seed to use.
@@ -151,6 +152,7 @@ class LanguageModelSAERunnerConfig:
     normalize_activations: str = (
         "none"  # none, expected_average_only_in (Anthropic April Update), constant_norm_rescale (Anthropic Feb Update)
     )
+    seqpos_slice: tuple[int | None, ...] = (None,)
 
     # Misc
     device: str = "cpu"
@@ -453,6 +455,7 @@ class CacheActivationsRunnerConfig:
     store_batch_size_prompts: int = 32
     train_batch_size_tokens: int = 4096
     normalize_activations: str = "none"  # should always be none for activation caching
+    seqpos_slice: tuple[int | None, ...] = (None,)
 
     # Misc
     device: str = "cpu"
diff --git a/sae_lens/training/activations_store.py b/sae_lens/training/activations_store.py
index 3a068cd2..d68aadf3 100644
--- a/sae_lens/training/activations_store.py
+++ b/sae_lens/training/activations_store.py
@@ -87,6 +87,7 @@ def from_config(
             model_kwargs=cfg.model_kwargs,
             autocast_lm=cfg.autocast_lm,
             dataset_trust_remote_code=cfg.dataset_trust_remote_code,
+            seqpos_slice=cfg.seqpos_slice,
         )
 
     @classmethod
@@ -146,6 +147,7 @@ def __init__(
         model_kwargs: dict[str, Any] | None = None,
         autocast_lm: bool = False,
         dataset_trust_remote_code: bool | None = None,
+        seqpos_slice: tuple[int | None, ...] = (None,)
     ):
         self.model = model
         if model_kwargs is None:
@@ -187,6 +189,7 @@ def __init__(
         self.dtype = DTYPE_MAP[dtype]
         self.cached_activations_path = cached_activations_path
         self.autocast_lm = autocast_lm
+        self.seqpos_slice = seqpos_slice
 
         self.n_dataset_processed = 0
 
@@ -428,7 +431,7 @@ def get_activations(self, batch_tokens: torch.Tensor):
             autocast_if_enabled = contextlib.nullcontext()
 
         with autocast_if_enabled:
-            layerwise_activations = self.model.run_with_cache(
+            layerwise_activations_cache = self.model.run_with_cache(
                 batch_tokens,
                 names_filter=[self.hook_name],
                 stop_at_layer=self.hook_layer + 1,
@@ -436,29 +439,26 @@ def get_activations(self, batch_tokens: torch.Tensor):
                 **self.model_kwargs,
             )[1]
 
-        n_batches, n_context = batch_tokens.shape
+        layerwise_activations = layerwise_activations_cache[self.hook_name][:, slice(*self.seqpos_slice)]
+        n_batches, n_context = layerwise_activations.shape[:2]
 
         stacked_activations = torch.zeros((n_batches, n_context, 1, self.d_in))
 
         if self.hook_head_index is not None:
-            stacked_activations[:, :, 0] = layerwise_activations[self.hook_name][
+            stacked_activations[:, :, 0] = layerwise_activations[
                 :, :, self.hook_head_index
             ]
         elif (
-            layerwise_activations[self.hook_name].ndim > 3
+            layerwise_activations.ndim > 3
         ):  # if we have a head dimension
             try:
-                stacked_activations[:, :, 0] = layerwise_activations[
-                    self.hook_name
-                ].view(n_batches, n_context, -1)
+                stacked_activations[:, :, 0] = layerwise_activations.view(n_batches, n_context, -1)
             except RuntimeError as e:
                 print(f"Error during view operation: {e}")
                 print("Attempting to use reshape instead...")
-                stacked_activations[:, :, 0] = layerwise_activations[
-                    self.hook_name
-                ].reshape(n_batches, n_context, -1)
+                stacked_activations[:, :, 0] = layerwise_activations.reshape(n_batches, n_context, -1)
         else:
-            stacked_activations[:, :, 0] = layerwise_activations[self.hook_name]
+            stacked_activations[:, :, 0] = layerwise_activations
 
         return stacked_activations
 
@@ -474,6 +474,7 @@ def get_buffer(
         If raise_on_epoch_end is True, when the dataset it exhausted it will automatically refill the dataset and then raise a StopIteration so that the caller has a chance to react.
         """
         context_size = self.context_size
+        training_context_size = len(range(context_size)[slice(*self.seqpos_slice)])
        batch_size = self.store_batch_size_prompts
         d_in = self.d_in
         total_size = batch_size * n_batches_in_buffer
@@ -481,7 +482,7 @@ def get_buffer(
 
         if self.cached_activations_path is not None:
            # Load the activations from disk
-            buffer_size = total_size * context_size
+            buffer_size = total_size * training_context_size
             # Initialize an empty tensor with an additional dimension for layers
             new_buffer = torch.zeros(
                 (buffer_size, num_layers, d_in),
@@ -535,7 +536,7 @@ def get_buffer(
         refill_iterator = range(0, batch_size * n_batches_in_buffer, batch_size)
         # Initialize empty tensor buffer of the maximum required size with an additional dimension for layers
         new_buffer = torch.zeros(
-            (total_size, context_size, num_layers, d_in),
+            (total_size, training_context_size, num_layers, d_in),
             dtype=self.dtype,  # type: ignore
             device=self.device,
         )

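The slicing above reuses Python's builtin slice object: seqpos_slice is unpacked as slice(*self.seqpos_slice) over the sequence dimension, so the default (None,) keeps every position while (5, -5) drops the first and last five. A standalone sketch of the arithmetic, with illustrative values that are not part of the patch:

    # seqpos_slice is unpacked into a builtin slice over the sequence dimension.
    context_size = 59                      # e.g. an OthelloGPT-style context
    seqpos_slice = (5, -5)                 # keep positions 5 .. context_size - 6

    kept = range(context_size)[slice(*seqpos_slice)]
    training_context_size = len(kept)      # 49; this becomes the buffer's sequence dimension

    assert training_context_size == context_size - 10
    assert slice(*(None,)) == slice(None)  # the default keeps every position
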
From 05543fb9eb05f2d3b74ecf00ad30903fcd0f3603 Mon Sep 17 00:00:00 2001
From: jbloomAus
Date: Fri, 20 Sep 2024 10:47:03 +0100
Subject: [PATCH 2/5] add basic tests, ensure it's in the SAE config

---
 sae_lens/config.py                            |  1 +
 sae_lens/sae.py                               |  6 +++++
 sae_lens/training/activations_store.py        |  3 ++-
 tests/unit/training/test_activations_store.py | 23 +++++++++++++++++++
 tests/unit/training/test_sae_basic.py         | 17 ++++++++++++++
 5 files changed, 49 insertions(+), 1 deletion(-)

diff --git a/sae_lens/config.py b/sae_lens/config.py
index 94bf1517..94f772ce 100644
--- a/sae_lens/config.py
+++ b/sae_lens/config.py
@@ -380,6 +380,7 @@ def get_base_sae_cfg_dict(self) -> dict[str, Any]:
             "normalize_activations": self.normalize_activations,
             "activation_fn_kwargs": self.activation_fn_kwargs,
             "model_from_pretrained_kwargs": self.model_from_pretrained_kwargs,
+            "seqpos_slice": self.seqpos_slice,
         }
 
     def get_training_sae_cfg_dict(self) -> dict[str, Any]:
diff --git a/sae_lens/sae.py b/sae_lens/sae.py
index b347990a..55cd03ed 100644
--- a/sae_lens/sae.py
+++ b/sae_lens/sae.py
@@ -62,6 +62,7 @@ class SAEConfig:
     activation_fn_kwargs: dict[str, Any] = field(default_factory=dict)
     neuronpedia_id: Optional[str] = None
     model_from_pretrained_kwargs: dict[str, Any] = field(default_factory=dict)
+    seqpos_slice: tuple[int | None, ...] = (None,)
 
     @classmethod
     def from_dict(cls, config_dict: dict[str, Any]) -> "SAEConfig":
@@ -81,6 +82,10 @@ def from_dict(cls, config_dict: dict[str, Any]) -> "SAEConfig":
             for k, v in config_dict.items()
             if k in cls.__dataclass_fields__  # pylint: disable=no-member
         }
+
+        if "seqpos_slice" in config_dict:
+            config_dict["seqpos_slice"] = tuple(config_dict["seqpos_slice"])
+
         return cls(**config_dict)
 
     # def __post_init__(self):
@@ -108,6 +113,7 @@ def to_dict(self) -> dict[str, Any]:
             "normalize_activations": self.normalize_activations,
             "neuronpedia_id": self.neuronpedia_id,
             "model_from_pretrained_kwargs": self.model_from_pretrained_kwargs,
+            "seqpos_slice": self.seqpos_slice,
         }
 
 
diff --git a/sae_lens/training/activations_store.py b/sae_lens/training/activations_store.py
index d68aadf3..605615ea 100644
--- a/sae_lens/training/activations_store.py
+++ b/sae_lens/training/activations_store.py
@@ -123,6 +123,7 @@ def from_sae(
             dataset_trust_remote_code=sae.cfg.dataset_trust_remote_code,
             dtype=sae.cfg.dtype,
             device=torch.device(device),
+            seqpos_slice=sae.cfg.seqpos_slice,
         )
 
     def __init__(
@@ -147,7 +148,7 @@ def __init__(
         model_kwargs: dict[str, Any] | None = None,
         autocast_lm: bool = False,
         dataset_trust_remote_code: bool | None = None,
-        seqpos_slice: tuple[int | None, ...] = (None,)
+        seqpos_slice: tuple[int | None, ...] = (None,),
     ):
         self.model = model
         if model_kwargs is None:
diff --git a/tests/unit/training/test_activations_store.py b/tests/unit/training/test_activations_store.py
index 20598d31..d4cf6871 100644
--- a/tests/unit/training/test_activations_store.py
+++ b/tests/unit/training/test_activations_store.py
@@ -462,3 +462,26 @@ def test_validate_pretokenized_dataset_tokenizer_does_nothing_if_the_dataset_pat
     model_tokenizer = ts_model.tokenizer
     assert model_tokenizer is not None
     validate_pretokenized_dataset_tokenizer(ds_path, model_tokenizer)
+
+
+def test_activations_store_respects_seqpos_slice(ts_model: HookedTransformer):
+    cfg = build_sae_cfg(
+        context_size=10,
+        seqpos_slice=(2, 8),  # Only consider positions 2 to 7 (inclusive)
+    )
+    dataset = Dataset.from_list(
+        [
+            {"text": "This is a test sentence for slicing."},
+        ]
+        * 100
+    )
+
+    activation_store = ActivationsStore.from_config(
+        ts_model, cfg, override_dataset=dataset
+    )
+
+    batch = activation_store.get_batch_tokens(1)
+    activations = activation_store.get_activations(batch)
+
+    assert batch.shape == (1, 10)  # Full context size
+    assert activations.shape == (1, 6, 1, cfg.d_in)  # Only 6 positions (2 to 7)
diff --git a/tests/unit/training/test_sae_basic.py b/tests/unit/training/test_sae_basic.py
index 60dfaddb..da76b754 100644
--- a/tests/unit/training/test_sae_basic.py
+++ b/tests/unit/training/test_sae_basic.py
@@ -228,6 +228,23 @@ def test_sae_save_and_load_from_pretrained_topk(tmp_path: Path) -> None:
     assert torch.allclose(sae_out_1, sae_out_2)
 
 
+def test_sae_seqpos(tmp_path: Path) -> None:
+    cfg = build_sae_cfg(
+        seqpos_slice=(1, 3),
+        device="cpu",
+    )
+    model_path = str(tmp_path)
+    sae = SAE.from_dict(cfg.get_base_sae_cfg_dict())
+
+    assert sae.cfg.seqpos_slice == (1, 3)
+
+    sae.save_model(model_path)
+
+    sae_loaded = SAE.load_from_pretrained(model_path, device="cpu")
+
+    assert sae_loaded.cfg.seqpos_slice == (1, 3)
+
+
 # TODO: Handle scaling factor in saeBase
 # def test_sae_save_and_load_from_pretrained_lacks_scaling_factor(
 #     tmp_path: Path,

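The tuple coercion added to SAEConfig.from_dict above exists because SAE configs round-trip through JSON, which has no tuple type: a seqpos_slice saved as (2, 8) comes back as [2, 8], while the tests compare against tuples. A minimal illustration in plain Python (the dict literal here is made up for the example):

    import json

    # JSON has no tuple type, so a saved config returns seqpos_slice as a list ...
    restored = json.loads(json.dumps({"seqpos_slice": (2, 8)}))
    assert restored["seqpos_slice"] == [2, 8]

    # ... which is why from_dict converts it back to a tuple before use.
    restored["seqpos_slice"] = tuple(restored["seqpos_slice"])
    assert restored["seqpos_slice"] == (2, 8)
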
From c20830a9f7585ff55cde2bfd995ebecbf8b2002f Mon Sep 17 00:00:00 2001
From: jbloomAus
Date: Fri, 20 Sep 2024 10:47:15 +0100
Subject: [PATCH 3/5] format

---
 sae_lens/training/activations_store.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/sae_lens/training/activations_store.py b/sae_lens/training/activations_store.py
index 605615ea..980034db 100644
--- a/sae_lens/training/activations_store.py
+++ b/sae_lens/training/activations_store.py
@@ -440,7 +440,9 @@ def get_activations(self, batch_tokens: torch.Tensor):
                 **self.model_kwargs,
             )[1]
 
-        layerwise_activations = layerwise_activations_cache[self.hook_name][:, slice(*self.seqpos_slice)]
+        layerwise_activations = layerwise_activations_cache[self.hook_name][
+            :, slice(*self.seqpos_slice)
+        ]
         n_batches, n_context = layerwise_activations.shape[:2]
 
         stacked_activations = torch.zeros((n_batches, n_context, 1, self.d_in))
@@ -449,15 +451,17 @@ def get_activations(self, batch_tokens: torch.Tensor):
             stacked_activations[:, :, 0] = layerwise_activations[
                 :, :, self.hook_head_index
             ]
-        elif (
-            layerwise_activations.ndim > 3
-        ):  # if we have a head dimension
+        elif layerwise_activations.ndim > 3:  # if we have a head dimension
             try:
-                stacked_activations[:, :, 0] = layerwise_activations.view(n_batches, n_context, -1)
+                stacked_activations[:, :, 0] = layerwise_activations.view(
+                    n_batches, n_context, -1
+                )
             except RuntimeError as e:
                 print(f"Error during view operation: {e}")
                 print("Attempting to use reshape instead...")
-                stacked_activations[:, :, 0] = layerwise_activations.reshape(n_batches, n_context, -1)
+                stacked_activations[:, :, 0] = layerwise_activations.reshape(
+                    n_batches, n_context, -1
+                )
         else:
             stacked_activations[:, :, 0] = layerwise_activations
 

From 493edf3d834205d28e20832231bf7eb03631c559 Mon Sep 17 00:00:00 2001
From: jbloomAus
Date: Fri, 20 Sep 2024 10:57:46 +0100
Subject: [PATCH 4/5] fix tests

---
 sae_lens/training/training_sae.py  | 1 +
 tests/unit/training/test_config.py | 1 +
 2 files changed, 2 insertions(+)

diff --git a/sae_lens/training/training_sae.py b/sae_lens/training/training_sae.py
index 217e7252..66637716 100644
--- a/sae_lens/training/training_sae.py
+++ b/sae_lens/training/training_sae.py
@@ -75,6 +75,7 @@ def from_sae_runner_config(
             context_size=cfg.context_size,
             dataset_path=cfg.dataset_path,
             prepend_bos=cfg.prepend_bos,
+            seqpos_slice=cfg.seqpos_slice,
             # Training cfg
             l1_coefficient=cfg.l1_coefficient,
             lp_norm=cfg.lp_norm,
diff --git a/tests/unit/training/test_config.py b/tests/unit/training/test_config.py
index e4cc461a..688fb1f8 100644
--- a/tests/unit/training/test_config.py
+++ b/tests/unit/training/test_config.py
@@ -67,6 +67,7 @@ def test_sae_training_runner_config_get_sae_base_parameters():
         "model_from_pretrained_kwargs": {
             "center_writing_weights": False,
         },
+        "seqpos_slice": (None,),
     }
     assert expected_config == cfg.get_base_sae_cfg_dict()
 

From 9d21daf70a4fe27b715c3050f5a354bec4e0f046 Mon Sep 17 00:00:00 2001
From: jbloomAus
Date: Fri, 20 Sep 2024 11:19:14 +0100
Subject: [PATCH 5/5] fix tests 2

---
 sae_lens/config.py                |  9 +++++++++
 sae_lens/training/training_sae.py | 12 ++++++++++++
 2 files changed, 21 insertions(+)

diff --git a/sae_lens/config.py b/sae_lens/config.py
index 94f772ce..ff066a5b 100644
--- a/sae_lens/config.py
+++ b/sae_lens/config.py
@@ -422,6 +422,15 @@ def to_json(self, path: str) -> None:
     def from_json(cls, path: str) -> "LanguageModelSAERunnerConfig":
         with open(path + "cfg.json", "r") as f:
             cfg = json.load(f)
+
+        # ensure that seqpos slices is a tuple
+        # Ensure seqpos_slice is a tuple
+        if "seqpos_slice" in cfg:
+            if isinstance(cfg["seqpos_slice"], list):
+                cfg["seqpos_slice"] = tuple(cfg["seqpos_slice"])
+            elif not isinstance(cfg["seqpos_slice"], tuple):
+                cfg["seqpos_slice"] = (cfg["seqpos_slice"],)
+
         return cls(**cfg)
 
 
diff --git a/sae_lens/training/training_sae.py b/sae_lens/training/training_sae.py
index 66637716..b7925d4e 100644
--- a/sae_lens/training/training_sae.py
+++ b/sae_lens/training/training_sae.py
@@ -100,6 +100,18 @@ def from_dict(cls, config_dict: dict[str, Any]) -> "TrainingSAEConfig":
         valid_config_dict = {
             key: val for key, val in config_dict.items() if key in valid_field_names
         }
+
+        # ensure seqpos slice is tuple
+        # ensure that seqpos slices is a tuple
+        # Ensure seqpos_slice is a tuple
+        if "seqpos_slice" in valid_config_dict:
+            if isinstance(valid_config_dict["seqpos_slice"], list):
+                valid_config_dict["seqpos_slice"] = tuple(
+                    valid_config_dict["seqpos_slice"]
+                )
+            elif not isinstance(valid_config_dict["seqpos_slice"], tuple):
+                valid_config_dict["seqpos_slice"] = (valid_config_dict["seqpos_slice"],)
+
         return TrainingSAEConfig(**valid_config_dict)
 
     def to_dict(self) -> dict[str, Any]:
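The final patch applies the same normalization in LanguageModelSAERunnerConfig.from_json and TrainingSAEConfig.from_dict, additionally wrapping a bare scalar into a one-element tuple. The intent, as a standalone sketch (the helper function is hypothetical and only mirrors the coercion above, it is not part of the library):

    def _normalize_seqpos_slice(value):
        # Illustrative helper: mirrors the coercion added in this patch.
        if isinstance(value, list):       # JSON round-trips tuples as lists
            return tuple(value)
        if not isinstance(value, tuple):  # a bare int becomes a one-element tuple
            return (value,)
        return value                      # already a tuple: leave untouched

    assert _normalize_seqpos_slice([5, -5]) == (5, -5)
    assert _normalize_seqpos_slice(3) == (3,)
    assert _normalize_seqpos_slice((None,)) == (None,)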