Cleaned up CacheActivationsRunnerConfig
Previously, `CacheActivationsRunnerConfig` was inconsistent: several parameters
existed only for interoperability with `LanguageModelSAERunnerConfig`, and it
was unclear which were necessary and which were redundant.

Simplified to the required arguments:

- `hf_dataset_path`: Tokenized or untokenized dataset
- `total_training_tokens`
- `model_name`
- `model_batch_size`
- `hook_name`
- `final_hook_layer`
- `d_in`

I think this scheme captures everything you need to cache activations and
makes the config much easier to reason about.

Optional:

```
activation_save_path # defaults to "activations/{dataset}/{model}/{hook_name}"
shuffle=True
prepend_bos=True
streaming=True
seqpos_slice
buffer_size_gb=2 # Size of each buffer. Affects memory usage and saving freq
device="cuda" or "cpu"
dtype="float32"
autocast_lm=False
compile_llm=True
hf_repo_id # Push to hf
model_kwargs # `run_with_cache`
model_from_pretrained_kwargs
```
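
For illustration, here is a minimal usage sketch with the new config. The model, dataset, hook, and argument values below are assumptions for the example, not part of this commit:

```
from sae_lens.config import CacheActivationsRunnerConfig
from sae_lens.cache_activations_runner import CacheActivationsRunner

# All concrete values here are illustrative assumptions.
cfg = CacheActivationsRunnerConfig(
    hf_dataset_path="NeelNanda/pile-10k",  # tokenized or untokenized dataset
    total_training_tokens=1_000_000,
    model_name="gpt2",
    model_batch_size=16,
    hook_name="blocks.5.hook_resid_post",
    final_hook_layer=5,
    d_in=768,
    # optional overrides
    device="cuda",
    dtype="float32",
)

dataset = CacheActivationsRunner(cfg).run()  # returns a datasets.Dataset of activations
```

`run()` writes shards to `activation_save_path`, returns the consolidated `Dataset`, and pushes to the Hub when `hf_repo_id` is set.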
tom-pollak committed Nov 25, 2024
1 parent 85c90e5 commit dd09264
Showing 8 changed files with 413 additions and 481 deletions.
110 changes: 56 additions & 54 deletions sae_lens/cache_activations_runner.py
@@ -1,6 +1,5 @@
import io
import json
import math
import shutil
from dataclasses import asdict
from pathlib import Path
@@ -12,6 +11,7 @@
from huggingface_hub import HfApi
from jaxtyping import Float
from tqdm import tqdm
from transformer_lens.HookedTransformer import HookedRootModule

from sae_lens.config import DTYPE_MAP, CacheActivationsRunnerConfig
from sae_lens.load_model import load_model
@@ -21,28 +21,29 @@
class CacheActivationsRunner:
def __init__(self, cfg: CacheActivationsRunnerConfig):
self.cfg = cfg
self.model = load_model(
model_class_name=cfg.model_class_name,
model_name=cfg.model_name,
device=cfg.device,
model_from_pretrained_kwargs=cfg.model_from_pretrained_kwargs,
self.model: HookedRootModule = load_model(
model_class_name=self.cfg.model_class_name,
model_name=self.cfg.model_name,
device=self.cfg.device,
model_from_pretrained_kwargs=self.cfg.model_from_pretrained_kwargs,
)
self.activations_store = ActivationsStore.from_config(
if self.cfg.compile_llm:
self.model = torch.compile(self.model, mode=self.cfg.llm_compilation_mode) # type: ignore
self.activations_store = ActivationsStore._from_save_activations(
self.model,
cfg,
self.cfg,
)
self.context_size = self._get_sliced_context_size(
self.cfg.context_size, self.cfg.seqpos_slice
)
ctx_size = _get_sliced_context_size(self.cfg)
self.features = Features(
{
f"{self.cfg.hook_name}": Array2D(
shape=(ctx_size, self.cfg.d_in), dtype=self.cfg.dtype
hook_name: Array2D(
shape=(self.context_size, self.cfg.d_in), dtype=self.cfg.dtype
)
for hook_name in [self.cfg.hook_name]
}
)
self.tokens_in_buffer = (
self.cfg.n_batches_in_buffer * self.cfg.store_batch_size_prompts * ctx_size
)
self.n_buffers = math.ceil(self.cfg.training_tokens / self.tokens_in_buffer)

def __str__(self):
"""
@@ -57,14 +58,14 @@ def __str__(self):
if isinstance(self.cfg.dtype, torch.dtype)
else DTYPE_MAP[self.cfg.dtype].itemsize
)
total_training_tokens = self.cfg.training_tokens
total_training_tokens = self.cfg.dataset_num_rows * self.context_size
total_disk_space_gb = total_training_tokens * bytes_per_token / 10**9

return (
f"Activation Cache Runner:\n"
f"Total training tokens: {total_training_tokens}\n"
f"Number of buffers: {self.n_buffers}\n"
f"Tokens per buffer: {self.tokens_in_buffer}\n"
f"Number of buffers: {self.cfg.n_buffers}\n"
f"Tokens per buffer: {self.cfg.tokens_in_buffer}\n"
f"Disk space required: {total_disk_space_gb:.2f} GB\n"
f"Configuration:\n"
f"{self.cfg}"
@@ -189,32 +190,13 @@ def _consolidate_shards(

return Dataset.load_from_disk(output_dir)

@torch.no_grad()
def _create_shard(
self,
buffer: Float[torch.Tensor, "(bs context_size) num_layers d_in"],
) -> Dataset:
hook_names = [self.cfg.hook_name] # allow multiple hooks in future

buffer = einops.rearrange(
buffer,
"(bs context_size) num_layers d_in -> num_layers bs context_size d_in",
bs=self.cfg.n_batches_in_buffer * self.cfg.store_batch_size_prompts,
context_size=_get_sliced_context_size(self.cfg),
d_in=self.cfg.d_in,
num_layers=len(hook_names),
)
shard = Dataset.from_dict(
{hook_name: act for hook_name, act in zip(hook_names, buffer)},
features=self.features,
)
return shard

@torch.no_grad()
def run(self) -> Dataset:
activation_save_path = self.cfg.activation_save_path
assert activation_save_path is not None

### Paths setup
assert self.cfg.new_cached_activations_path is not None
final_cached_activation_path = Path(self.cfg.new_cached_activations_path)
final_cached_activation_path = Path(activation_save_path)
final_cached_activation_path.mkdir(exist_ok=True, parents=True)
if any(final_cached_activation_path.iterdir()):
raise Exception(
@@ -226,13 +208,12 @@ def run(self) -> Dataset:

### Create temporary sharded datasets

print(f"Started caching {self.cfg.training_tokens} activations")
print(f"Started caching activations for {self.cfg.hf_dataset_path}")

for i in tqdm(range(self.n_buffers), desc="Caching activations"):
for i in tqdm(range(self.cfg.n_buffers), desc="Caching activations"):
try:
# num activations in a single shard: n_batches_in_buffer * store_batch_size_prompts
buffer = self.activations_store.get_buffer(
self.cfg.n_batches_in_buffer, shuffle=self.cfg.shuffle
self.cfg.batches_in_buffer, shuffle=False
)
shard = self._create_shard(buffer)
shard.save_to_disk(
@@ -242,11 +223,11 @@

except StopIteration:
print(
f"Warning: Ran out of samples while filling the buffer at batch {i} before reaching {self.n_buffers} batches. No more caching will occur."
f"Warning: Ran out of samples while filling the buffer at batch {i} before reaching {self.cfg.n_buffers} batches."
)
break

### Concat sharded datasets together, shuffle and push to hub
### Concatenate shards and push to Huggingface Hub

dataset = self._consolidate_shards(
tmp_cached_activation_path, final_cached_activation_path, copy_files=False
@@ -257,10 +238,10 @@
dataset = dataset.shuffle(seed=self.cfg.seed)

if self.cfg.hf_repo_id:
print("Pushing to hub...")
print("Pushing to Huggingface Hub...")
dataset.push_to_hub(
repo_id=self.cfg.hf_repo_id,
num_shards=self.cfg.hf_num_shards or self.n_buffers,
num_shards=self.cfg.hf_num_shards or self.cfg.n_buffers,
private=self.cfg.hf_is_private_repo,
revision=self.cfg.hf_revision,
)
@@ -283,9 +264,30 @@

return dataset

def _create_shard(
self,
buffer: Float[torch.Tensor, "(bs context_size) num_layers d_in"],
) -> Dataset:
hook_names = [self.cfg.hook_name]

def _get_sliced_context_size(cfg: CacheActivationsRunnerConfig) -> int:
context_size = cfg.context_size
if cfg.seqpos_slice:
context_size = len(range(context_size)[slice(*cfg.seqpos_slice)])
return context_size
buffer = einops.rearrange(
buffer,
"(bs context_size) num_layers d_in -> num_layers bs context_size d_in",
bs=self.cfg.rows_in_buffer,
context_size=self.context_size,
d_in=self.cfg.d_in,
num_layers=len(hook_names),
)
shard = Dataset.from_dict(
{hook_name: act for hook_name, act in zip(hook_names, buffer)},
features=self.features,
)
return shard
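
For intuition, a toy sketch of the rearrange used in `_create_shard`; the shapes are assumptions, not values from this commit. The flat buffer of `(bs * context_size, num_layers, d_in)` activations is regrouped so each hook gets a `(bs, context_size, d_in)` array:

```
import torch
import einops

# Toy shapes (assumed for illustration only).
bs, context_size, num_layers, d_in = 4, 8, 1, 16
buffer = torch.randn(bs * context_size, num_layers, d_in)

per_hook = einops.rearrange(
    buffer,
    "(bs context_size) num_layers d_in -> num_layers bs context_size d_in",
    bs=bs,
    context_size=context_size,
)
assert per_hook.shape == (num_layers, bs, context_size, d_in)
```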

@staticmethod
def _get_sliced_context_size(
context_size: int, seqpos_slice: tuple[int | None, ...] | None
) -> int:
if seqpos_slice is not None:
context_size = len(range(context_size)[slice(*seqpos_slice)])
return context_size
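
For intuition, the slicing above as a standalone sketch; the example values (context size 128, `seqpos_slice=(1, -1)`) are assumptions:

```
def sliced_context_size(context_size, seqpos_slice=None):
    # Mirrors CacheActivationsRunner._get_sliced_context_size above.
    if seqpos_slice is not None:
        context_size = len(range(context_size)[slice(*seqpos_slice)])
    return context_size

assert sliced_context_size(128) == 128           # no slicing: every position is cached
assert sliced_context_size(128, (1, -1)) == 126  # drop the first and last positions
```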