From dd09264723d6d145f030db1992acae0b2049649b Mon Sep 17 00:00:00 2001
From: Tom Pollak
Date: Mon, 11 Nov 2024 12:01:57 +0000
Subject: [PATCH] Cleaned up CacheActivationsRunnerConfig

Previously, `CacheActivationsRunnerConfig` carried an inconsistent set of
options, kept partly for interoperability with `LanguageModelSAERunnerConfig`.
It was unclear which parameters were necessary and which were redundant.

Simplified to the required arguments:

- `hf_dataset_path`: Tokenized or untokenized dataset
- `total_training_tokens`
- `model_name`
- `model_batch_size`
- `hook_name`
- `final_hook_layer`
- `d_in`

I think this scheme captures everything you need when caching activations and
makes it a lot easier to reason about (a usage sketch follows the diff below).

Optional:

```
activation_save_path  # defaults to "activations/{dataset}/{model}/{hook_name}"
shuffle=True
prepend_bos=True
streaming=True
seqpos_slice
buffer_size_gb=2      # Size of each buffer. Affects memory usage and saving frequency
device="cuda" or "cpu"
dtype="float32"
autocast_lm=False
compile_llm=False
hf_repo_id            # Push to HF
model_kwargs          # Passed to `run_with_cache`
model_from_pretrained_kwargs
```
---
 sae_lens/cache_activations_runner.py          | 110 ++---
 sae_lens/config.py                            | 174 +++++---
 sae_lens/training/activations_store.py        |  72 +++-
 scripts/ansible/util/cache_acts.py            |   6 +-
 scripts/caching_replication_how_train_saes.py |  50 +--
 .../test_cache_activations_runner.py          |  86 ++--
 .../training/test_cache_activations_runner.py | 377 +++++-------------
 tests/unit/training/test_config.py            |  19 +-
 8 files changed, 413 insertions(+), 481 deletions(-)

diff --git a/sae_lens/cache_activations_runner.py b/sae_lens/cache_activations_runner.py
index 6c41455b..b13af7a0 100644
--- a/sae_lens/cache_activations_runner.py
+++ b/sae_lens/cache_activations_runner.py
@@ -1,6 +1,5 @@
 import io
 import json
-import math
 import shutil
 from dataclasses import asdict
 from pathlib import Path
@@ -12,6 +11,7 @@
 from huggingface_hub import HfApi
 from jaxtyping import Float
 from tqdm import tqdm
+from transformer_lens.HookedTransformer import HookedRootModule

 from sae_lens.config import DTYPE_MAP, CacheActivationsRunnerConfig
 from sae_lens.load_model import load_model
@@ -21,28 +21,29 @@ class CacheActivationsRunner:
     def __init__(self, cfg: CacheActivationsRunnerConfig):
         self.cfg = cfg
-        self.model = load_model(
-            model_class_name=cfg.model_class_name,
-            model_name=cfg.model_name,
-            device=cfg.device,
-            model_from_pretrained_kwargs=cfg.model_from_pretrained_kwargs,
+        self.model: HookedRootModule = load_model(
+            model_class_name=self.cfg.model_class_name,
+            model_name=self.cfg.model_name,
+            device=self.cfg.device,
+            model_from_pretrained_kwargs=self.cfg.model_from_pretrained_kwargs,
         )
-        self.activations_store = ActivationsStore.from_config(
+        if self.cfg.compile_llm:
+            self.model = torch.compile(self.model, mode=self.cfg.llm_compilation_mode)  # type: ignore
+        self.activations_store = ActivationsStore._from_save_activations(
             self.model,
-            cfg,
+            self.cfg,
+        )
+        self.context_size = self._get_sliced_context_size(
+            self.cfg.context_size, self.cfg.seqpos_slice
         )
-        ctx_size = _get_sliced_context_size(self.cfg)
         self.features = Features(
             {
-                f"{self.cfg.hook_name}": Array2D(
-                    shape=(ctx_size, self.cfg.d_in), dtype=self.cfg.dtype
+                hook_name: Array2D(
+                    shape=(self.context_size, self.cfg.d_in), dtype=self.cfg.dtype
                 )
+                for hook_name in [self.cfg.hook_name]
             }
         )
-        self.tokens_in_buffer = (
-            self.cfg.n_batches_in_buffer * self.cfg.store_batch_size_prompts * ctx_size
-        )
-        self.n_buffers = math.ceil(self.cfg.training_tokens / 
self.tokens_in_buffer) def __str__(self): """ @@ -57,14 +58,14 @@ def __str__(self): if isinstance(self.cfg.dtype, torch.dtype) else DTYPE_MAP[self.cfg.dtype].itemsize ) - total_training_tokens = self.cfg.training_tokens + total_training_tokens = self.cfg.dataset_num_rows * self.context_size total_disk_space_gb = total_training_tokens * bytes_per_token / 10**9 return ( f"Activation Cache Runner:\n" f"Total training tokens: {total_training_tokens}\n" - f"Number of buffers: {self.n_buffers}\n" - f"Tokens per buffer: {self.tokens_in_buffer}\n" + f"Number of buffers: {self.cfg.n_buffers}\n" + f"Tokens per buffer: {self.cfg.tokens_in_buffer}\n" f"Disk space required: {total_disk_space_gb:.2f} GB\n" f"Configuration:\n" f"{self.cfg}" @@ -189,32 +190,13 @@ def _consolidate_shards( return Dataset.load_from_disk(output_dir) - @torch.no_grad() - def _create_shard( - self, - buffer: Float[torch.Tensor, "(bs context_size) num_layers d_in"], - ) -> Dataset: - hook_names = [self.cfg.hook_name] # allow multiple hooks in future - - buffer = einops.rearrange( - buffer, - "(bs context_size) num_layers d_in -> num_layers bs context_size d_in", - bs=self.cfg.n_batches_in_buffer * self.cfg.store_batch_size_prompts, - context_size=_get_sliced_context_size(self.cfg), - d_in=self.cfg.d_in, - num_layers=len(hook_names), - ) - shard = Dataset.from_dict( - {hook_name: act for hook_name, act in zip(hook_names, buffer)}, - features=self.features, - ) - return shard - @torch.no_grad() def run(self) -> Dataset: + activation_save_path = self.cfg.activation_save_path + assert activation_save_path is not None + ### Paths setup - assert self.cfg.new_cached_activations_path is not None - final_cached_activation_path = Path(self.cfg.new_cached_activations_path) + final_cached_activation_path = Path(activation_save_path) final_cached_activation_path.mkdir(exist_ok=True, parents=True) if any(final_cached_activation_path.iterdir()): raise Exception( @@ -226,13 +208,12 @@ def run(self) -> Dataset: ### Create temporary sharded datasets - print(f"Started caching {self.cfg.training_tokens} activations") + print(f"Started caching activations for {self.cfg.hf_dataset_path}") - for i in tqdm(range(self.n_buffers), desc="Caching activations"): + for i in tqdm(range(self.cfg.n_buffers), desc="Caching activations"): try: - # num activations in a single shard: n_batches_in_buffer * store_batch_size_prompts buffer = self.activations_store.get_buffer( - self.cfg.n_batches_in_buffer, shuffle=self.cfg.shuffle + self.cfg.batches_in_buffer, shuffle=False ) shard = self._create_shard(buffer) shard.save_to_disk( @@ -242,11 +223,11 @@ def run(self) -> Dataset: except StopIteration: print( - f"Warning: Ran out of samples while filling the buffer at batch {i} before reaching {self.n_buffers} batches. No more caching will occur." + f"Warning: Ran out of samples while filling the buffer at batch {i} before reaching {self.cfg.n_buffers} batches." 
) break - ### Concat sharded datasets together, shuffle and push to hub + ### Concatenate shards and push to Huggingface Hub dataset = self._consolidate_shards( tmp_cached_activation_path, final_cached_activation_path, copy_files=False @@ -257,10 +238,10 @@ def run(self) -> Dataset: dataset = dataset.shuffle(seed=self.cfg.seed) if self.cfg.hf_repo_id: - print("Pushing to hub...") + print("Pushing to Huggingface Hub...") dataset.push_to_hub( repo_id=self.cfg.hf_repo_id, - num_shards=self.cfg.hf_num_shards or self.n_buffers, + num_shards=self.cfg.hf_num_shards or self.cfg.n_buffers, private=self.cfg.hf_is_private_repo, revision=self.cfg.hf_revision, ) @@ -283,9 +264,30 @@ def run(self) -> Dataset: return dataset + def _create_shard( + self, + buffer: Float[torch.Tensor, "(bs context_size) num_layers d_in"], + ) -> Dataset: + hook_names = [self.cfg.hook_name] -def _get_sliced_context_size(cfg: CacheActivationsRunnerConfig) -> int: - context_size = cfg.context_size - if cfg.seqpos_slice: - context_size = len(range(context_size)[slice(*cfg.seqpos_slice)]) - return context_size + buffer = einops.rearrange( + buffer, + "(bs context_size) num_layers d_in -> num_layers bs context_size d_in", + bs=self.cfg.rows_in_buffer, + context_size=self.context_size, + d_in=self.cfg.d_in, + num_layers=len(hook_names), + ) + shard = Dataset.from_dict( + {hook_name: act for hook_name, act in zip(hook_names, buffer)}, + features=self.features, + ) + return shard + + @staticmethod + def _get_sliced_context_size( + context_size: int, seqpos_slice: tuple[int | None, ...] | None + ) -> int: + if seqpos_slice is not None: + context_size = len(range(context_size)[slice(*seqpos_slice)]) + return context_size diff --git a/sae_lens/config.py b/sae_lens/config.py index 8488a3e2..d416bc55 100644 --- a/sae_lens/config.py +++ b/sae_lens/config.py @@ -1,11 +1,19 @@ import json +import math import os from dataclasses import dataclass, field +from pathlib import Path from typing import Any, Literal, Optional, cast import torch import wandb -from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict +from datasets import ( + Dataset, + DatasetDict, + IterableDataset, + IterableDatasetDict, + load_dataset, +) from sae_lens import __version__ @@ -478,76 +486,136 @@ def from_json(cls, path: str) -> "LanguageModelSAERunnerConfig": @dataclass class CacheActivationsRunnerConfig: """ - Configuration for caching activations of an LLM. + Configuration for creating and caching activations of an LLM. + + Args: + hf_dataset_path (str): The path to the Hugging Face dataset. This may be tokenized or not. + model_name (str): The name of the model to use. + model_batch_size (int): How many prompts are in the batch of the language model when generating activations. + hook_name (str): The name of the hook to use. + final_hook_layer (int): The layer of the final hook. Currently only support a single hook, so this should be the same as hook_name. + d_in (int): Dimension of the model. + total_training_tokens (int): Total number of tokens to process. + context_size (int): Context size to process. Can be left as -1 if the dataset is tokenized. + model_class_name (str): The name of the class of the model to use. This should be either `HookedTransformer` or `HookedMamba`. + activation_save_path (str, optional): The path to save the activations. + shuffle (bool): Whether to shuffle the dataset. + seed (int): The seed to use for shuffling. + dtype (str): Datatype of activations to be stored. + device (str): The device for the model. 
+ buffer_size_gb (float): The buffer size in GB. This should be < 2GB. + hf_repo_id (str, optional): The Hugging Face repository id to save the activations to. + hf_num_shards (int, optional): The number of shards to save the activations to. + hf_revision (str): The revision to save the activations to. + hf_is_private_repo (bool): Whether the Hugging Face repository is private. + model_kwargs (dict): Keyword arguments for `model.run_with_cache`. + model_from_pretrained_kwargs (dict): Keyword arguments for the `from_pretrained` method of the model. + compile_llm (bool): Whether to compile the LLM. + llm_compilation_mode (str): The torch.compile mode to use. + prepend_bos (bool): Whether to prepend the beginning of sequence token. You should use whatever the model was trained with. + seqpos_slice (tuple): Determines slicing of activations when constructing batches during training. The slice should be (start_pos, end_pos, optional[step_size]), e.g. for Othello we sometimes use (5, -5). Note, step_size > 0. + streaming (bool): Whether to stream the dataset. Streaming large datasets is usually practical. + autocast_lm (bool): Whether to use autocast during activation fetching. + dataset_trust_remote_code (bool): Whether to trust remote code when loading datasets from Huggingface. """ - # Data Generating Function (Model + Training Distibuion) - model_name: str = "gelu-2l" - model_class_name: str = "HookedTransformer" - hook_name: str = "blocks.{layer}.hook_mlp_out" - hook_layer: int = 0 - hook_head_index: Optional[int] = None - dataset_path: str = "" - dataset_trust_remote_code: bool | None = None - streaming: bool = True - is_dataset_tokenized: bool = True - context_size: int = 128 - new_cached_activations_path: Optional[str] = ( - None # Defaults to "activations/{dataset}/{model}/{full_hook_name}_{hook_head_index}" - ) + hf_dataset_path: str - # if saving to huggingface, set hf_repo_id - hf_repo_id: Optional[str] = None - hf_num_shards: int | None = None - hf_revision: str = "main" - hf_is_private_repo: bool = False + model_name: str + model_batch_size: int - # dont' specify this since you don't want to load from disk with the cache runner. - cached_activations_path: Optional[str] = None - # SAE Parameters - d_in: int = 512 + hook_name: str + final_hook_layer: int # Layer of final hook. Same layer as hook_name + d_in: int - # Activation Store Parameters - n_batches_in_buffer: int = 20 - training_tokens: int = 2_000_000 - store_batch_size_prompts: int = 32 - train_batch_size_tokens: int = 4096 - normalize_activations: str = "none" # should always be none for activation caching - seqpos_slice: tuple[int | None, ...] 
= (None,) + total_training_tokens: int - # Misc - device: str = "cpu" - act_store_device: str = "with_model" # will be set by post init if with_model + context_size: int = -1 # Required if dataset is not tokenized + model_class_name: str = "HookedTransformer" + # defaults to "activations/{dataset}/{model}/{hook_name} + activation_save_path: str | None = None # type: ignore + + shuffle: bool = True seed: int = 42 dtype: str = "float32" - prepend_bos: bool = True - autocast_lm: bool = False # autocast lm during activation fetching + device: str = "cuda" if torch.cuda.is_available() else "cpu" + buffer_size_gb: float = 2.0 # HF datasets writer have problems with shards > 2GB - # Shuffle activations - shuffle: bool = True + # Huggingface Integration + hf_repo_id: str | None = None + hf_num_shards: int | None = None + hf_revision: str = "main" + hf_is_private_repo: bool = False + # Model model_kwargs: dict[str, Any] = field(default_factory=dict) model_from_pretrained_kwargs: dict[str, Any] = field(default_factory=dict) + compile_llm: bool = False + llm_compilation_mode: str | None = None # which torch.compile mode to use + + # Activation Store + prepend_bos: bool = True + seqpos_slice: tuple[int | None, ...] = (None,) + streaming: bool = True + autocast_lm: bool = False + dataset_trust_remote_code: bool | None = None + + # set in __post_init__ + tokens_in_buffer: int = -1 + rows_in_buffer: int = -1 + n_buffers: int = -1 + dataset_num_rows: int = -1 + batches_in_buffer: int = -1 def __post_init__(self): - # Autofill cached_activations_path unless the user overrode it - if self.new_cached_activations_path is None: - self.new_cached_activations_path = _default_cached_activations_path( - self.dataset_path, - self.model_name, - self.hook_name, - self.hook_head_index, + # Automatically determine context_size if dataset is tokenized + if self.context_size == -1: + ds = load_dataset(self.hf_dataset_path, split="train", streaming=True) + assert isinstance(ds, IterableDataset) + first_sample = next(iter(ds)) + toks = first_sample.get("tokens") or first_sample.get("input_ids") or None + if toks is None: + raise ValueError( + "Dataset is not tokenized. Please specify context_size." + ) + token_length = len(toks) + self.context_size = token_length + assert self.context_size != -1 + + if self.seqpos_slice is not None: + _validate_seqpos( + seqpos=self.seqpos_slice, + context_size=self.context_size, ) - if self.act_store_device == "with_model": - self.act_store_device = self.device - - if self.context_size < 0: - raise ValueError( - f"The provided context_size is {self.context_size} is negative. Expecting positive context_size." 
+ if self.activation_save_path is None: + self.activation_save_path = _default_cached_activations_path( # type: ignore + self.hf_dataset_path, self.model_name, self.hook_name, None ) - _validate_seqpos(seqpos=self.seqpos_slice, context_size=self.context_size) + self.activation_save_path: Path + self.activation_save_path = Path(self.activation_save_path) + + bytes_per_token = self.d_in * DTYPE_MAP[self.dtype].itemsize + # Calculate raw tokens per buffer based on memory constraints + _tokens_per_buffer = int(self.buffer_size_gb * 1e9) // bytes_per_token + _batch_token_size = self.model_batch_size * self.sliced_context_size + # Round down to nearest multiple of batch_token_size + self.tokens_in_buffer = _tokens_per_buffer - ( + _tokens_per_buffer % _batch_token_size + ) + self.rows_in_buffer = self.tokens_in_buffer // self.sliced_context_size + + self.dataset_num_rows = self.total_training_tokens // self.sliced_context_size + self.batches_in_buffer = self.tokens_in_buffer // _batch_token_size + + self.n_buffers = math.ceil(self.total_training_tokens / self.tokens_in_buffer) + + @property + def sliced_context_size(self) -> int: + if self.seqpos_slice is not None: + return len(range(self.context_size)[slice(*self.seqpos_slice)]) + return self.context_size @dataclass diff --git a/sae_lens/training/activations_store.py b/sae_lens/training/activations_store.py index 972d528e..f5e6c545 100644 --- a/sae_lens/training/activations_store.py +++ b/sae_lens/training/activations_store.py @@ -49,6 +49,67 @@ class ActivationsStore: _storage_buffer: torch.Tensor | None = None device: torch.device + @classmethod + def _from_save_activations( + cls, + model: HookedRootModule, + cfg: CacheActivationsRunnerConfig, + ) -> "ActivationsStore": + return cls( + model=model, + dataset=cfg.hf_dataset_path, + streaming=cfg.streaming, + hook_name=cfg.hook_name, + hook_layer=cfg.final_hook_layer, + hook_head_index=None, + context_size=cfg.context_size, + d_in=cfg.d_in, + n_batches_in_buffer=cfg.batches_in_buffer, + total_training_tokens=cfg.total_training_tokens, + store_batch_size_prompts=cfg.model_batch_size, + train_batch_size_tokens=-1, + prepend_bos=cfg.prepend_bos, + normalize_activations="none", + device=torch.device("cpu"), # since we're saving to disk + dtype=cfg.dtype, + cached_activations_path=None, + model_kwargs=cfg.model_kwargs, + autocast_lm=cfg.autocast_lm, + dataset_trust_remote_code=cfg.dataset_trust_remote_code, + seqpos_slice=cfg.seqpos_slice, + ) + + @classmethod + def from_cache_activations( + cls, + model: HookedRootModule, + cfg: CacheActivationsRunnerConfig, + ) -> "ActivationsStore": + return cls( + cached_activations_path=str(cfg.activation_save_path), + dtype=cfg.dtype, + hook_name=cfg.hook_name, + hook_layer=cfg.final_hook_layer, + context_size=cfg.context_size, + d_in=cfg.d_in, + n_batches_in_buffer=cfg.batches_in_buffer, + total_training_tokens=cfg.total_training_tokens, + store_batch_size_prompts=cfg.model_batch_size, # get_buffer + train_batch_size_tokens=cfg.model_batch_size, # dataloader + seqpos_slice=(None,), + device=torch.device(cfg.device), # since we're sending these to SAE + # NOOP + prepend_bos=False, + hook_head_index=None, + dataset=cfg.hf_dataset_path, + streaming=False, + model=model, + normalize_activations="none", + model_kwargs=None, + autocast_lm=False, + dataset_trust_remote_code=None, + ) + @classmethod def from_config( cls, @@ -56,6 +117,9 @@ def from_config( cfg: LanguageModelSAERunnerConfig | CacheActivationsRunnerConfig, override_dataset: HfDataset | None = None, ) 
-> "ActivationsStore": + if isinstance(cfg, CacheActivationsRunnerConfig): + return cls.from_cache_activations(model, cfg) + cached_activations_path = cfg.cached_activations_path # set cached_activations_path to None if we're not using cached activations if ( @@ -563,7 +627,7 @@ def get_buffer( total_size, context_size, num_layers, d_in, raise_on_epoch_end ) - refill_iterator = range(0, batch_size * n_batches_in_buffer, batch_size) + refill_iterator = range(0, total_size, batch_size) # Initialize empty tensor buffer of the maximum required size with an additional dimension for layers new_buffer = torch.zeros( (total_size, training_context_size, num_layers, d_in), @@ -571,7 +635,9 @@ def get_buffer( device=self.device, ) - for refill_batch_idx_start in refill_iterator: + for refill_batch_idx_start in tqdm( + refill_iterator, leave=False, desc="Refilling buffer" + ): # move batch toks to gpu for model refill_batch_tokens = self.get_batch_tokens( raise_at_epoch_end=raise_on_epoch_end @@ -583,8 +649,6 @@ def get_buffer( refill_batch_idx_start : refill_batch_idx_start + batch_size, ... ] = refill_activations - # pbar.update(1) - new_buffer = new_buffer.reshape(-1, num_layers, d_in) if shuffle: new_buffer = new_buffer[torch.randperm(new_buffer.shape[0])] diff --git a/scripts/ansible/util/cache_acts.py b/scripts/ansible/util/cache_acts.py index 376f95b1..d6cabe1e 100644 --- a/scripts/ansible/util/cache_acts.py +++ b/scripts/ansible/util/cache_acts.py @@ -44,10 +44,10 @@ if config is None: raise ValueError("Error: The config is not loaded.") -print(f"Total Training Tokens: {config.training_tokens}") +print(f"Total Training Tokens: {config.total_training_tokens}") # This is set by Ansible -new_cached_activations_path = config.new_cached_activations_path +new_cached_activations_path = config.activation_save_path if new_cached_activations_path is None: raise ValueError("Error: The new_cached_activations_path is not set.") @@ -70,5 +70,5 @@ end_time = time.time() print(f"Total time taken: {end_time - start_time:.2f} seconds") print( - f"{config.training_tokens / ((end_time - start_time)*10**6):.2f} Million Tokens / Second" + f"{config.total_training_tokens / ((end_time - start_time)*10**6):.2f} Million Tokens / Second" ) diff --git a/scripts/caching_replication_how_train_saes.py b/scripts/caching_replication_how_train_saes.py index 279a6472..4be317bd 100755 --- a/scripts/caching_replication_how_train_saes.py +++ b/scripts/caching_replication_how_train_saes.py @@ -1,12 +1,9 @@ import os -import shutil import time import torch from sae_lens.cache_activations_runner import CacheActivationsRunner - -# from pathlib import Path from sae_lens.config import CacheActivationsRunnerConfig if torch.cuda.is_available(): @@ -19,35 +16,12 @@ print("Using device:", device) os.environ["TOKENIZERS_PARALLELISM"] = "false" - -total_training_steps = 20_000 -batch_size = 4096 -total_training_tokens = total_training_steps * batch_size -print(f"Total Training Tokens: {total_training_tokens}") - # change these configs model_name = "gelu-1l" -dataset_path = "NeelNanda/c4-tokenized-2b" -new_cached_activations_path = ( - f"./cached_activations/{model_name}/{dataset_path}/{total_training_steps}" -) - -# check how much data is in the directory -if os.path.exists(new_cached_activations_path): - print("Directory exists. 
Checking how much data is in the directory.") - total_files = sum( - os.path.getsize(os.path.join(new_cached_activations_path, f)) - for f in os.listdir(new_cached_activations_path) - if os.path.isfile(os.path.join(new_cached_activations_path, f)) - ) - print(f"Total size of directory: {total_files / 1e9:.2f} GB") +model_batch_size = 16 -# If the directory exists, delete it. -if input("Delete the directory? (y/n): ") == "y" and os.path.exists( - new_cached_activations_path -): - if os.path.exists(new_cached_activations_path): - shutil.rmtree(new_cached_activations_path) +dataset_path = "NeelNanda/c4-tokenized-2b" +total_training_tokens = 100_000 if device == "cuda": torch.cuda.empty_cache() @@ -55,24 +29,16 @@ torch.mps.empty_cache() cfg = CacheActivationsRunnerConfig( - new_cached_activations_path=new_cached_activations_path, # Pick a tiny model to make this easier. model_name=model_name, + hf_dataset_path=dataset_path, ## MLP Layer 0 ## hook_name="blocks.0.hook_mlp_out", - hook_layer=0, + final_hook_layer=0, d_in=512, - dataset_path=dataset_path, - context_size=1024, - is_dataset_tokenized=True, prepend_bos=True, - training_tokens=total_training_tokens, # For initial testing I think this is a good number. - train_batch_size_tokens=4096, - # buffer details - n_batches_in_buffer=4, - store_batch_size_prompts=128, - normalize_activations="none", - # + total_training_tokens=total_training_tokens, + model_batch_size=model_batch_size, # Misc device=device, seed=42, @@ -94,5 +60,5 @@ end_time = time.time() print(f"Total time taken: {end_time - start_time:.2f} seconds") print( - f"{total_training_tokens / ((end_time - start_time)*10**6):.2f} Million Tokens / Second" + f"{cfg.total_training_tokens / ((end_time - start_time)*10**6):.2f} Million Tokens / Second" ) diff --git a/tests/benchmark/test_cache_activations_runner.py b/tests/benchmark/test_cache_activations_runner.py index e24a94e1..342e5007 100644 --- a/tests/benchmark/test_cache_activations_runner.py +++ b/tests/benchmark/test_cache_activations_runner.py @@ -1,3 +1,4 @@ +import math import os import shutil import time @@ -8,7 +9,7 @@ from tqdm import trange from sae_lens.cache_activations_runner import CacheActivationsRunner -from sae_lens.config import CacheActivationsRunnerConfig +from sae_lens.config import DTYPE_MAP, CacheActivationsRunnerConfig os.environ["WANDB_MODE"] = "offline" # turn this off if you want to see the output @@ -26,45 +27,31 @@ def test_cache_activations_runner(): print("Using device:", device) os.environ["TOKENIZERS_PARALLELISM"] = "false" - total_training_steps = 500 - batch_size = 4096 - total_training_tokens = total_training_steps * batch_size - print(f"Total Training Tokens: {total_training_tokens}") - - new_cached_activations_path = ( + activations_save_path = ( os.path.dirname(os.path.realpath(__file__)) + "/fixtures/test_activations/gelu_1l" ) # If the directory exists, delete it. - if os.path.exists(new_cached_activations_path): - shutil.rmtree(new_cached_activations_path) + if os.path.exists(activations_save_path): + shutil.rmtree(activations_save_path) torch.mps.empty_cache() cfg = CacheActivationsRunnerConfig( - new_cached_activations_path=new_cached_activations_path, + activation_save_path=activations_save_path, + total_training_tokens=16_000, # Pick a tiny model to make this easier. 
model_name="gelu-1l", - # model_name="gpt2-xl", + model_batch_size=16, ## MLP Layer 0 ## hook_name="blocks.0.hook_mlp_out", - hook_layer=0, + final_hook_layer=0, d_in=512, - # d_in=1600, - dataset_path="NeelNanda/c4-tokenized-2b", - streaming=False, + ## Dataset ## + hf_dataset_path="NeelNanda/c4-tokenized-2b", context_size=1024, - is_dataset_tokenized=True, - prepend_bos=True, - training_tokens=total_training_tokens, # For initial testing I think this is a good number. - train_batch_size_tokens=4096, - # buffer details - n_batches_in_buffer=32, - store_batch_size_prompts=16, - normalize_activations="none", - # - # Misc + ## Misc ## device=device, seed=42, dtype="float32", @@ -79,30 +66,41 @@ def test_cache_activations_runner(): def test_hf_dataset_save_vs_safetensors(tmp_path: Path): niters = 10 + context_size = 32 + dataset_num_rows = 10_000 + total_training_tokens = dataset_num_rows * context_size + model_batch_size = 8 + num_buffers = 4 * niters ### d_in = 512 - context_size = 32 - n_batches_in_buffer = 32 - batch_size = 8 - num_buffers = 4 * niters - num_tokens = batch_size * context_size * n_batches_in_buffer * num_buffers + dtype = "float32" + device = ( + "cuda" + if torch.cuda.is_available() + else "mps" if torch.backends.mps.is_available() else "cpu" + ) + + bytes_per_token = d_in * DTYPE_MAP[dtype].itemsize + tokens_per_buffer = math.ceil(dataset_num_rows * context_size / num_buffers) + buffer_size_gb = min((tokens_per_buffer * bytes_per_token) / 1_000_000_000, 2.0) cfg = CacheActivationsRunnerConfig( - new_cached_activations_path=str(tmp_path), - d_in=d_in, - context_size=context_size, + activation_save_path=str(tmp_path), + hf_dataset_path="NeelNanda/c4-tokenized-2b", model_name="gelu-1l", hook_name="blocks.0.hook_mlp_out", - dataset_path="NeelNanda/c4-tokenized-2b", - training_tokens=num_tokens, - n_batches_in_buffer=n_batches_in_buffer, - store_batch_size_prompts=batch_size, - normalize_activations="none", - device="cpu", + final_hook_layer=0, + d_in=d_in, + context_size=context_size, + total_training_tokens=total_training_tokens, + model_batch_size=model_batch_size, + buffer_size_gb=buffer_size_gb, + prepend_bos=False, + device=device, seed=42, - dtype="float32", + dtype=dtype, ) runner = CacheActivationsRunner(cfg) store = runner.activations_store @@ -117,18 +115,18 @@ def test_hf_dataset_save_vs_safetensors(tmp_path: Path): print("Warmup") for i in trange(niters // 2, leave=False): - buffer = store.get_buffer(n_batches_in_buffer) + buffer = store.get_buffer(cfg.batches_in_buffer) start_time = time.perf_counter() for i in trange(niters, leave=False): - buffer = store.get_buffer(n_batches_in_buffer) + buffer = store.get_buffer(cfg.batches_in_buffer) end_time = time.perf_counter() print(f"No saving took: {end_time - start_time:.4f}") start_time = time.perf_counter() for i in trange(niters, leave=False): - buffer = store.get_buffer(n_batches_in_buffer) + buffer = store.get_buffer(cfg.batches_in_buffer) shard = runner._create_shard(buffer) shard.save_to_disk(hf_path / str(i), num_shards=1) end_time = time.perf_counter() @@ -140,7 +138,7 @@ def test_hf_dataset_save_vs_safetensors(tmp_path: Path): start_time = time.perf_counter() for i in trange(niters, leave=False): - buffer = store.get_buffer(n_batches_in_buffer) + buffer = store.get_buffer(cfg.batches_in_buffer) save_file({"activations": buffer}, safetensors_path / f"{i}.safetensors") end_time = time.perf_counter() diff --git a/tests/unit/training/test_cache_activations_runner.py 
b/tests/unit/training/test_cache_activations_runner.py index b702a5a3..06e9749f 100644 --- a/tests/unit/training/test_cache_activations_runner.py +++ b/tests/unit/training/test_cache_activations_runner.py @@ -1,4 +1,5 @@ import dataclasses +import math import os from pathlib import Path from typing import Any, Tuple @@ -11,169 +12,109 @@ from transformer_lens import HookedTransformer from sae_lens.cache_activations_runner import CacheActivationsRunner -from sae_lens.config import CacheActivationsRunnerConfig, LanguageModelSAERunnerConfig +from sae_lens.config import ( + DTYPE_MAP, + CacheActivationsRunnerConfig, + LanguageModelSAERunnerConfig, +) from sae_lens.load_model import load_model from sae_lens.training.activations_store import ActivationsStore -def _create_dataset(tmp_path: Path) -> Dataset: - torch.manual_seed(42) - - model_name = "gelu-1l" - hook_name = "blocks.0.hook_mlp_out" - dataset_path = "chanind/c4-10k-mini-tokenized-16-ctx-gelu-1l-tests" - batch_size = 1 - batches_in_buffer = 2 - context_size = 8 - num_buffers = 4 +def _default_cfg( + tmp_path: Path, + batch_size: int = 16, + context_size: int = 8, + dataset_num_rows: int = 128, + n_buffers: int = 4, + **kwargs: Any, +) -> CacheActivationsRunnerConfig: + d_in = 512 + dtype = "float32" + device = ( + "cuda" + if torch.cuda.is_available() + else "mps" if torch.backends.mps.is_available() else "cpu" + ) - train_batch_size_tokens = 32 + sliced_context_size = kwargs.get("seqpos_slice", None) + if sliced_context_size is not None: + sliced_context_size = len(range(context_size)[slice(*sliced_context_size)]) + else: + sliced_context_size = context_size - tokens_in_buffer = batches_in_buffer * batch_size * context_size - num_tokens = tokens_in_buffer * num_buffers + # Calculate buffer_size_gb to achieve desired n_buffers + bytes_per_token = d_in * DTYPE_MAP[dtype].itemsize + tokens_per_buffer = math.ceil(dataset_num_rows * sliced_context_size / n_buffers) + buffer_size_gb = (tokens_per_buffer * bytes_per_token) / 1_000_000_000 + total_training_tokens = dataset_num_rows * sliced_context_size cfg = CacheActivationsRunnerConfig( - new_cached_activations_path=str(tmp_path), - model_name=model_name, - hook_name=hook_name, - dataset_path=dataset_path, - training_tokens=num_tokens, - shuffle=False, - store_batch_size_prompts=batch_size, - train_batch_size_tokens=train_batch_size_tokens, - n_batches_in_buffer=batches_in_buffer, - ### - hook_layer=0, - d_in=512, + activation_save_path=str(tmp_path), + hf_dataset_path="chanind/c4-10k-mini-tokenized-16-ctx-gelu-1l-tests", + model_name="gelu-1l", + hook_name="blocks.0.hook_mlp_out", + final_hook_layer=0, + ### Parameters + total_training_tokens=total_training_tokens, + model_batch_size=batch_size, + buffer_size_gb=buffer_size_gb, context_size=context_size, - is_dataset_tokenized=True, + ### + d_in=d_in, + shuffle=False, prepend_bos=False, - normalize_activations="none", - device="cpu", + device=device, seed=42, - dtype="float32", + dtype=dtype, + **kwargs, ) - - runner = CacheActivationsRunner(cfg) - return runner.run() + assert cfg.n_buffers == n_buffers + assert cfg.dataset_num_rows == dataset_num_rows + assert ( + cfg.tokens_in_buffer == cfg.batches_in_buffer * batch_size * sliced_context_size + ) + return cfg # The way to run this with this command: # poetry run py.test tests/unit/test_cache_activations_runner.py --profile-svg -s def test_cache_activations_runner(tmp_path: Path): + cfg = _default_cfg(tmp_path) + runner = CacheActivationsRunner(cfg) + dataset = runner.run() - # 
total_training_steps = 20_000 - context_size = 8 - n_batches_in_buffer = 32 - store_batch_size = 1 - n_buffers = 3 - - tokens_in_buffer = n_batches_in_buffer * store_batch_size * context_size - total_training_tokens = n_buffers * tokens_in_buffer - total_rows = store_batch_size * n_batches_in_buffer * n_buffers - - # better if we can look at the files (change tmp_path to a real path to look at the files) - # tmp_path = os.path.join(os.path.dirname(__file__), "tmp") - # tmp_path = Path("/Volumes/T7 Shield/activations/gelu_1l") - # if os.path.exists(tmp_path): - # shutil.rmtree(tmp_path) - - cfg = CacheActivationsRunnerConfig( - new_cached_activations_path=str(tmp_path), - # Pick a tiny model to make this easier. - model_name="gelu-1l", - ## MLP Layer 0 ## - hook_name="blocks.0.hook_mlp_out", - hook_layer=0, - d_in=512, - dataset_path="chanind/c4-10k-mini-tokenized-16-ctx-gelu-1l-tests", - context_size=context_size, # Speed things up. - is_dataset_tokenized=True, - prepend_bos=True, # I used to train GPT2 SAEs with a prepended-bos but no longer think we should do this. - training_tokens=total_training_tokens, # For initial testing I think this is a good number. - train_batch_size_tokens=32, - # Loss Function - ## Reconstruction Coefficient. - # Buffer details won't matter in we cache / shuffle our activations ahead of time. - n_batches_in_buffer=n_batches_in_buffer, - store_batch_size_prompts=store_batch_size, - normalize_activations="none", - # Misc - device="cpu", - seed=42, - dtype="float16", - ) - - # look at the next cell to see some instruction for what to do while this is running. - dataset = CacheActivationsRunner(cfg).run() - assert len(dataset) == total_rows + assert len(dataset) == cfg.n_buffers * (cfg.tokens_in_buffer // cfg.context_size) + assert cfg.dataset_num_rows == len(dataset) assert dataset.num_columns == 1 and dataset.column_names == [cfg.hook_name] features = dataset.features - assert isinstance(features[cfg.hook_name], datasets.Array2D) - assert features[cfg.hook_name].shape == (context_size, cfg.d_in) + for hook_name in [cfg.hook_name]: + assert isinstance(features[hook_name], datasets.Array2D) + assert features[hook_name].shape == (cfg.context_size, cfg.d_in) def test_load_cached_activations(tmp_path: Path): - - # total_training_steps = 20_000 - context_size = 8 - n_batches_in_buffer = 4 - store_batch_size = 1 - n_buffers = 4 - - tokens_in_buffer = n_batches_in_buffer * store_batch_size * context_size - total_training_tokens = n_buffers * tokens_in_buffer - - _create_dataset(tmp_path) - - cfg = LanguageModelSAERunnerConfig( - cached_activations_path=str(tmp_path), - use_cached_activations=True, - # Pick a tiny model to make this easier. - model_name="gelu-1l", - ## MLP Layer 0 ## - hook_name="blocks.0.hook_mlp_out", - hook_layer=0, - d_in=512, - dataset_path="chanind/c4-10k-mini-tokenized-16-ctx-gelu-1l-tests", - context_size=context_size, - is_dataset_tokenized=True, - prepend_bos=True, # I used to train GPT2 SAEs with a prepended-bos but no longer think we should do this. - training_tokens=total_training_tokens, # For initial testing I think this is a good number. - train_batch_size_tokens=total_training_tokens // 2, - # Loss Function - ## Reconstruction Coefficient. - # Buffer details won't matter in we cache / shuffle our activations ahead of time. 
- n_batches_in_buffer=n_batches_in_buffer, - store_batch_size_prompts=store_batch_size, - normalize_activations="none", - # shuffle_every_n_buffers=2, - # n_shuffles_with_last_section=1, - # n_shuffles_in_entire_dir=1, - # n_shuffles_final=1, - # Misc - device="cpu", - seed=42, - dtype="float16", - ) + cfg = _default_cfg(tmp_path) + runner = CacheActivationsRunner(cfg) + runner.run() model = HookedTransformer.from_pretrained(cfg.model_name) + activations_store = ActivationsStore.from_config(model, cfg) - for _ in range(n_buffers): + for _ in range(cfg.n_buffers): buffer = activations_store.get_buffer( - cfg.n_batches_in_buffer + cfg.batches_in_buffer ) # Adjusted to use n_batches_in_buffer assert buffer.shape == ( - cfg.n_batches_in_buffer * cfg.store_batch_size_prompts * cfg.context_size, + cfg.rows_in_buffer * cfg.context_size, 1, cfg.d_in, ) def test_activations_store_refreshes_dataset_when_it_runs_out(tmp_path: Path): - context_size = 8 n_batches_in_buffer = 4 store_batch_size = 1 @@ -181,7 +122,9 @@ def test_activations_store_refreshes_dataset_when_it_runs_out(tmp_path: Path): batch_size = 4 total_training_tokens = total_training_steps * batch_size - _create_dataset(tmp_path) + cache_cfg = _default_cfg(tmp_path) + runner = CacheActivationsRunner(cache_cfg) + runner.run() cfg = LanguageModelSAERunnerConfig( cached_activations_path=str(tmp_path), @@ -253,78 +196,36 @@ def test_compare_cached_activations_end_to_end_with_ground_truth(tmp_path: Path) """ torch.manual_seed(42) - - model_name = "gelu-1l" - hook_name = "blocks.0.hook_mlp_out" - dataset_path = "chanind/c4-10k-mini-tokenized-16-ctx-gelu-1l-tests" - batch_size = 8 - batches_in_buffer = 4 - context_size = 8 - num_buffers = 4 - - train_batch_size_tokens = 8 - - tokens_in_buffer = batches_in_buffer * batch_size * context_size - num_tokens = tokens_in_buffer * num_buffers - - cfg = CacheActivationsRunnerConfig( - new_cached_activations_path=str(tmp_path), - model_name=model_name, - hook_name=hook_name, - dataset_path=dataset_path, - training_tokens=num_tokens, - shuffle=False, - store_batch_size_prompts=batch_size, - train_batch_size_tokens=train_batch_size_tokens, - n_batches_in_buffer=batches_in_buffer, - ### - hook_layer=0, - d_in=512, - context_size=context_size, - is_dataset_tokenized=True, - prepend_bos=False, - normalize_activations="none", - device="cpu", - seed=42, - dtype="float32", - ) - + cfg = _default_cfg(tmp_path) runner = CacheActivationsRunner(cfg) activation_dataset = runner.run() - activation_dataset.set_format("torch", device=cfg.device) + activation_dataset.set_format("torch") dataset_acts: torch.Tensor = activation_dataset[cfg.hook_name] # type: ignore - model = HookedTransformer.from_pretrained(model_name, device=cfg.device) - token_dataset: Dataset = load_dataset(dataset_path, split=f"train[:{num_tokens}]") # type: ignore + model = HookedTransformer.from_pretrained(cfg.model_name, device=cfg.device) + token_dataset: Dataset = load_dataset(cfg.hf_dataset_path, split=f"train[:{cfg.dataset_num_rows}]") # type: ignore token_dataset.set_format("torch", device=cfg.device) - total_rows = batch_size * batches_in_buffer * num_buffers - ground_truth_acts = [] - for i in trange(0, total_rows, batch_size): - tokens = token_dataset[i : i + batch_size]["input_ids"][:, :context_size] + for i in trange(0, cfg.dataset_num_rows, cfg.model_batch_size): + tokens = token_dataset[i : i + cfg.model_batch_size]["input_ids"][ + :, : cfg.context_size + ] _, layerwise_activations = model.run_with_cache( tokens, 
names_filter=[cfg.hook_name], - stop_at_layer=cfg.hook_layer + 1, - **cfg.model_kwargs, + stop_at_layer=cfg.final_hook_layer + 1, ) acts = layerwise_activations[cfg.hook_name] ground_truth_acts.append(acts) - ground_truth_acts = torch.cat(ground_truth_acts, dim=0) + ground_truth_acts = torch.cat(ground_truth_acts, dim=0).cpu() assert torch.allclose(ground_truth_acts, dataset_acts, rtol=1e-3, atol=5e-2) def test_load_activations_store_with_nonexistent_dataset(tmp_path: Path): - cfg = CacheActivationsRunnerConfig( - model_name="gelu-1l", - hook_name="blocks.0.hook_mlp_out", - dataset_path="chanind/c4-10k-mini-tokenized-16-ctx-gelu-1l-tests", - cached_activations_path=str(tmp_path), - context_size=16, - ) + cfg = _default_cfg(tmp_path) model = load_model( model_class_name=cfg.model_class_name, @@ -342,15 +243,6 @@ def test_load_activations_store_with_nonexistent_dataset(tmp_path: Path): def test_cache_activations_runner_with_nonempty_directory(tmp_path: Path): - cfg = CacheActivationsRunnerConfig( - new_cached_activations_path=str(tmp_path), - model_name="gelu-1l", - hook_name="blocks.0.hook_mlp_out", - dataset_path="chanind/c4-10k-mini-tokenized-16-ctx-gelu-1l-tests", - context_size=16, - ) - runner = CacheActivationsRunner(cfg) - # Create a file to make the directory non-empty with open(tmp_path / "some_file.txt", "w") as f: f.write("test") @@ -358,35 +250,13 @@ def test_cache_activations_runner_with_nonempty_directory(tmp_path: Path): with pytest.raises( Exception, match="is not empty. Please delete it or specify a different path." ): + cfg = _default_cfg(tmp_path) + runner = CacheActivationsRunner(cfg) runner.run() - # Clean up - os.remove(tmp_path / "some_file.txt") - def test_cache_activations_runner_with_incorrect_d_in(tmp_path: Path): - d_in = 512 - context_size = 8 - n_batches_in_buffer = 4 - batch_size = 8 - num_buffers = 4 - num_tokens = batch_size * context_size * n_batches_in_buffer * num_buffers - - correct_cfg = CacheActivationsRunnerConfig( - new_cached_activations_path=str(tmp_path), - d_in=d_in, - context_size=context_size, - model_name="gelu-1l", - hook_name="blocks.0.hook_mlp_out", - dataset_path="chanind/c4-10k-mini-tokenized-16-ctx-gelu-1l-tests", - training_tokens=num_tokens, - n_batches_in_buffer=n_batches_in_buffer, - store_batch_size_prompts=batch_size, - normalize_activations="none", - device="cpu", - seed=42, - dtype="float32", - ) + correct_cfg = _default_cfg(tmp_path) # d_in different from hook wrong_d_in_cfg = CacheActivationsRunnerConfig( @@ -403,105 +273,57 @@ def test_cache_activations_runner_with_incorrect_d_in(tmp_path: Path): def test_cache_activations_runner_load_dataset_with_incorrect_config(tmp_path: Path): - d_in = 512 - context_size = 16 - n_batches_in_buffer = 4 - batch_size = 8 - num_buffers = 2 - num_tokens = batch_size * context_size * n_batches_in_buffer * num_buffers - - correct_cfg = CacheActivationsRunnerConfig( - new_cached_activations_path=str(tmp_path), - d_in=d_in, - context_size=context_size, - model_name="gelu-1l", - hook_name="blocks.0.hook_mlp_out", - dataset_path="chanind/c4-10k-mini-tokenized-16-ctx-gelu-1l-tests", - training_tokens=num_tokens, - n_batches_in_buffer=n_batches_in_buffer, - store_batch_size_prompts=batch_size, - normalize_activations="none", - device="cpu", - seed=42, - dtype="float32", - ) - - # Run with correct configuration first - CacheActivationsRunner(correct_cfg).run() - - ### + correct_cfg = _default_cfg(tmp_path, context_size=16) + runner = CacheActivationsRunner(correct_cfg) + runner.run() + model = 
runner.model # Context size different from dataset wrong_context_size_cfg = CacheActivationsRunnerConfig( **dataclasses.asdict(correct_cfg), ) wrong_context_size_cfg.context_size = 13 - wrong_context_size_cfg.new_cached_activations_path = None - wrong_context_size_cfg.cached_activations_path = str(tmp_path) with pytest.raises( ValueError, match=r"Given dataset of shape \(16, 512\) does not match context_size \(13\) and d_in \(512\)", ): - CacheActivationsRunner(wrong_context_size_cfg).run() + ActivationsStore.from_config(model, wrong_context_size_cfg) # d_in different from dataset wrong_d_in_cfg = CacheActivationsRunnerConfig( **dataclasses.asdict(correct_cfg), ) wrong_d_in_cfg.d_in = 513 - wrong_d_in_cfg.new_cached_activations_path = None - wrong_d_in_cfg.cached_activations_path = str(tmp_path) with pytest.raises( ValueError, match=r"Given dataset of shape \(16, 512\) does not match context_size \(16\) and d_in \(513\)", ): - CacheActivationsRunner(wrong_d_in_cfg).run() + ActivationsStore.from_config(model, wrong_d_in_cfg) # Incorrect hook_name wrong_hook_cfg = CacheActivationsRunnerConfig( **dataclasses.asdict(correct_cfg), ) wrong_hook_cfg.hook_name = "blocks.1.hook_mlp_out" - wrong_hook_cfg.new_cached_activations_path = None - wrong_hook_cfg.cached_activations_path = str(tmp_path) with pytest.raises( ValueError, match=r"Columns \['blocks.1.hook_mlp_out'\] not in the dataset. Current columns in the dataset: \['blocks.0.hook_mlp_out'\]", ): - CacheActivationsRunner(wrong_hook_cfg).run() + ActivationsStore.from_config(model, wrong_hook_cfg) def test_cache_activations_runner_with_valid_seqpos(tmp_path: Path): - context_size = 16 - seqpos_slice = (3, -3) - training_context_size = len(range(context_size)[slice(*seqpos_slice)]) - n_batches_in_buffer = 4 - store_batch_size = 1 - n_buffers = 3 - - tokens_in_buffer = n_batches_in_buffer * store_batch_size * training_context_size - total_training_tokens = n_buffers * tokens_in_buffer - - cfg = CacheActivationsRunnerConfig( - new_cached_activations_path=str(tmp_path), - d_in=512, - context_size=context_size, - model_name="gelu-1l", - hook_name="blocks.0.hook_mlp_out", - dataset_path="chanind/c4-10k-mini-tokenized-16-ctx-gelu-1l-tests", - training_tokens=total_training_tokens, - n_batches_in_buffer=n_batches_in_buffer, - store_batch_size_prompts=store_batch_size, - normalize_activations="none", - device="cpu", - seed=42, - dtype="float32", - seqpos_slice=seqpos_slice, + cfg = _default_cfg( + tmp_path, + batch_size=1, + context_size=16, + n_buffers=3, + dataset_num_rows=12, + seqpos_slice=(3, -3), ) - runner = CacheActivationsRunner(cfg) activation_dataset = runner.run() @@ -516,9 +338,8 @@ def test_cache_activations_runner_with_valid_seqpos(tmp_path: Path): for f in os.listdir(tmp_path) if f.startswith("data-") and f.endswith(".arrow") ] - assert len(buffer_files) == n_buffers + assert len(buffer_files) == cfg.n_buffers - assert len(dataset_acts) == n_buffers * n_batches_in_buffer for act in dataset_acts: # should be 16 - 3 - 3 = 10 assert act.shape == (10, cfg.d_in) diff --git a/tests/unit/training/test_config.py b/tests/unit/training/test_config.py index 154546b3..916151f9 100644 --- a/tests/unit/training/test_config.py +++ b/tests/unit/training/test_config.py @@ -160,17 +160,30 @@ def test_sae_training_runner_config_seqpos( def test_cache_activations_runner_config_seqpos( seqpos_slice: tuple[int, int], expected_error: Optional[AssertionError] ): - context_size = 10 if expected_error is AssertionError: with pytest.raises(expected_error): 
CacheActivationsRunnerConfig( + hf_dataset_path="", + model_name="", + model_batch_size=1, + hook_name="", + final_hook_layer=0, + d_in=1, + total_training_tokens=100, + context_size=10, seqpos_slice=seqpos_slice, - context_size=context_size, ) else: CacheActivationsRunnerConfig( + hf_dataset_path="", + model_name="", + model_batch_size=1, + hook_name="", + final_hook_layer=0, + d_in=1, + total_training_tokens=100, + context_size=10, seqpos_slice=seqpos_slice, - context_size=context_size, )
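
Usage sketch: a minimal example of the simplified config. This mirrors the replication script and benchmark test in this patch; the dataset, model, and token count below are illustrative values taken from those files rather than requirements.

```
from sae_lens.cache_activations_runner import CacheActivationsRunner
from sae_lens.config import CacheActivationsRunnerConfig

cfg = CacheActivationsRunnerConfig(
    # Required arguments
    hf_dataset_path="NeelNanda/c4-tokenized-2b",  # already tokenized, so context_size is inferred
    model_name="gelu-1l",
    model_batch_size=16,
    hook_name="blocks.0.hook_mlp_out",
    final_hook_layer=0,
    d_in=512,
    total_training_tokens=100_000,
    # Optional arguments keep their defaults, e.g. activation_save_path
    # ("activations/{dataset}/{model}/{hook_name}"), buffer_size_gb=2.0, dtype="float32"
)

# Writes activation shards to cfg.activation_save_path and returns the consolidated Dataset
dataset = CacheActivationsRunner(cfg).run()
```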