Move from manual sharding to HF dataset builder.
Depends on #389.

Inspired by:
https://opensourcemechanistic.slack.com/archives/C07EHMK3XC7/p1732413633220709

Instead of manually writing single Arrow shards, we can create a
dataset builder that does this more efficiently. This speeds up saving
considerably: the old method spent some time calculating the fingerprint
of each shard, which was unnecessary and required a hack to get around.
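
In outline, saving now goes through a `datasets` builder (a minimal sketch using the names introduced in the diff below):

```python
builder = CacheActivationDataset(cfg, activations_store)
builder.download_and_prepare()  # streams buffers into Arrow shards
dataset = builder.as_dataset(split="train")
```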

> Along with this change, I also switched to a 1D activation scheme.

- Previously the dataset was stored as a `(seq_len, d_in)` array per row.
- Now each row is a flat `d_in` vector, i.e. one row per token.

The primary reason for this change is shuffling activations. When
activations are stored per sequence, shuffling only permutes whole
sequences, so tokens from the same sequence stay adjacent and the
activations are never properly mixed. This is a problem with
`ActivationCache` too, but there's no great solution for it there.
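
A minimal sketch of the problem (shapes and values are illustrative, not from the codebase):

```python
import torch

n_seqs, seq_len, d_in = 4, 128, 512
acts = torch.randn(n_seqs, seq_len, d_in)

# Per-sequence rows: shuffling permutes whole sequences, so tokens
# from the same sequence always stay adjacent.
shuffled_seqs = acts[torch.randperm(n_seqs)]

# Per-token rows (the new scheme): shuffling mixes tokens across
# all sequences.
token_rows = acts.reshape(-1, d_in)
shuffled_tokens = token_rows[torch.randperm(token_rows.shape[0])]
```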

You can observe this in the SAE's loss when using small buffer sizes
with either the cache or `ActivationsStore`.
tom-pollak committed Nov 25, 2024
1 parent dd09264 commit a1da04c
Showing 4 changed files with 149 additions and 277 deletions.
313 changes: 106 additions & 207 deletions sae_lens/cache_activations_runner.py
@@ -1,21 +1,112 @@
import io
import json
import shutil
from dataclasses import asdict
from pathlib import Path
from typing import Generator

import datasets
import einops
import numpy as np
import pyarrow as pa
import torch
from datasets import Array2D, Dataset, Features
from datasets.fingerprint import generate_fingerprint
from datasets import Dataset, Features, Sequence, Value
from huggingface_hub import HfApi
from jaxtyping import Float
from tqdm import tqdm
from transformer_lens.HookedTransformer import HookedRootModule

from sae_lens.config import DTYPE_MAP, CacheActivationsRunnerConfig
from sae_lens.load_model import load_model
from sae_lens.training.activations_store import ActivationsStore
from transformer_lens.HookedTransformer import HookedRootModule


class CacheActivationDataset(datasets.ArrowBasedBuilder):
cfg: CacheActivationsRunnerConfig
activation_store: ActivationsStore
# info: datasets.DatasetInfo  # provided by DatasetBuilder

pa_dtype: pa.DataType
schema: pa.Schema

hook_names: list[str]  # currently only a single hook is supported

def __init__(
self,
cfg: CacheActivationsRunnerConfig,
activation_store: ActivationsStore,
):
self.cfg = cfg
self.activation_store = activation_store
self.hook_names = [cfg.hook_name]

if cfg.dtype == "float32":
self.pa_dtype = pa.float32()
elif cfg.dtype == "float16":
self.pa_dtype = pa.float16()
else:
raise ValueError(f"dtype {cfg.dtype} not supported")

self.schema = pa.schema(
[
pa.field(hook_name, pa.list_(self.pa_dtype, list_size=cfg.d_in))
for hook_name in self.hook_names
]
)

features = Features(
{
hook_name: Sequence(Value(dtype=cfg.dtype), length=cfg.d_in)
for hook_name in [cfg.hook_name]
}
)
cfg.activation_save_path.mkdir(parents=True, exist_ok=True)
assert cfg.activation_save_path.is_dir()
if any(cfg.activation_save_path.iterdir()):
raise ValueError(
f"Activation save path {cfg.activation_save_path} is not empty. Please delete it or specify a different path"
)
cache_dir = cfg.activation_save_path.parent
dataset_name = cfg.activation_save_path.name
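# The builder writes its cache under cache_dir / dataset_name,
# i.e. directly into the activation save path.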
super().__init__(
cache_dir=str(cache_dir),
dataset_name=dataset_name,
info=datasets.DatasetInfo(features=features),
)

def _split_generators(
self, dl_manager: datasets.DownloadManager | datasets.StreamingDownloadManager
) -> list[datasets.SplitGenerator]:
return [
datasets.SplitGenerator(name=str(datasets.Split.TRAIN)),
]

def _generate_tables(self) -> Generator[tuple[int, pa.Table], None, None]: # type: ignore
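# Each buffer becomes one Arrow table; the builder writes each table out as a shard.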
for i in range(self.cfg.n_buffers):
buffer = self.activation_store.get_buffer(
self.cfg.batches_in_buffer, shuffle=False
)
assert buffer.device.type == "cpu"
buffer = einops.rearrange(
buffer, "batch hook d_in -> hook batch d_in"
).numpy()
table = pa.Table.from_pydict(
{
hn: self.np2pa_2d(buf, d_in=self.cfg.d_in)
for hn, buf in zip(self.hook_names, buffer)
},
schema=self.schema,
)
yield i, table

@staticmethod
def np2pa_2d(data: Float[np.ndarray, "batch d_in"], d_in: int) -> pa.Array: # type: ignore
"""
Convert a 2D numpy array to a PyArrow FixedSizeListArray.
"""
assert data.ndim == 2, "Input array must be 2-dimensional."
_, d_in_found = data.shape
if d_in_found != d_in:
raise RuntimeError(f"d_in {d_in_found} does not match expected d_in {d_in}")
flat = data.ravel() # no copy if possible
pa_data = pa.array(flat)
return pa.FixedSizeListArray.from_arrays(pa_data, d_in)


class CacheActivationsRunner:
@@ -33,19 +124,8 @@ def __init__(self, cfg: CacheActivationsRunnerConfig):
self.model,
self.cfg,
)
self.context_size = self._get_sliced_context_size(
self.cfg.context_size, self.cfg.seqpos_slice
)
self.features = Features(
{
hook_name: Array2D(
shape=(self.context_size, self.cfg.d_in), dtype=self.cfg.dtype
)
for hook_name in [self.cfg.hook_name]
}
)

def __str__(self):
def summary(self):
"""
Print the number of tokens to be cached, the number of buffers,
and the number of tokens per buffer.
@@ -58,10 +138,10 @@ def __str__(self):
if isinstance(self.cfg.dtype, torch.dtype)
else DTYPE_MAP[self.cfg.dtype].itemsize
)
total_training_tokens = self.cfg.dataset_num_rows * self.context_size
total_training_tokens = self.cfg.dataset_num_rows * self.cfg.sliced_context_size
total_disk_space_gb = total_training_tokens * bytes_per_token / 10**9

return (
print(
f"Activation Cache Runner:\n"
f"Total training tokens: {total_training_tokens}\n"
f"Number of buffers: {self.cfg.n_buffers}\n"
@@ -71,168 +151,15 @@ def __str__(self):
f"{self.cfg}"
)

@staticmethod
def _consolidate_shards(
source_dir: Path, output_dir: Path, copy_files: bool = True
) -> Dataset:
"""Consolidate sharded datasets into a single directory without rewriting data.
Each of the shards must be of the same format, i.e. the full dataset must be able to
be recreated like so:
```
ds = concatenate_datasets(
[Dataset.load_from_disk(str(shard_dir)) for shard_dir in sorted(source_dir.iterdir())]
)
```
Sharded dataset format:
```
source_dir/
shard_00000/
dataset_info.json
state.json
data-00000-of-00002.arrow
data-00001-of-00002.arrow
shard_00001/
dataset_info.json
state.json
data-00000-of-00001.arrow
```
And flattens them into the format:
```
output_dir/
dataset_info.json
state.json
data-00000-of-00003.arrow
data-00001-of-00003.arrow
data-00002-of-00003.arrow
```
allowing the dataset to be loaded like so:
```
ds = datasets.load_from_disk(output_dir)
```
Args:
source_dir: Directory containing the sharded datasets
output_dir: Directory to consolidate the shards into
copy_files: If True, copy files; if False, move them and delete source_dir
"""
first_shard_dir_name = "shard_00000" # shard_{i:05d}

assert source_dir.exists() and source_dir.is_dir()
assert (
output_dir.exists()
and output_dir.is_dir()
and not any(p for p in output_dir.iterdir() if not p.name == ".tmp_shards")
)
if not (source_dir / first_shard_dir_name).exists():
raise Exception(f"No shards in {source_dir} exist!")

transfer_fn = shutil.copy2 if copy_files else shutil.move

# Move dataset_info.json from any shard (all the same)
transfer_fn(
source_dir / first_shard_dir_name / "dataset_info.json",
output_dir / "dataset_info.json",
)

arrow_files = []
file_count = 0

for shard_dir in sorted(source_dir.iterdir()):
if not shard_dir.name.startswith("shard_"):
continue

# state.json contains arrow filenames
state = json.loads((shard_dir / "state.json").read_text())

for data_file in state["_data_files"]:
src = shard_dir / data_file["filename"]
new_name = f"data-{file_count:05d}-of-{len(list(source_dir.iterdir())):05d}.arrow"
dst = output_dir / new_name
transfer_fn(src, dst)
arrow_files.append({"filename": new_name})
file_count += 1

new_state = {
"_data_files": arrow_files,
"_fingerprint": None, # temporary
"_format_columns": None,
"_format_kwargs": {},
"_format_type": None,
"_output_all_columns": False,
"_split": None,
}

# fingerprint is generated from dataset.__getstate__ (not including _fingerprint)
with open(output_dir / "state.json", "w") as f:
json.dump(new_state, f, indent=2)

ds = Dataset.load_from_disk(str(output_dir))
fingerprint = generate_fingerprint(ds)
del ds

with open(output_dir / "state.json", "r+") as f:
state = json.loads(f.read())
state["_fingerprint"] = fingerprint
f.seek(0)
json.dump(state, f, indent=2)
f.truncate()

if not copy_files: # cleanup source dir
shutil.rmtree(source_dir)

return Dataset.load_from_disk(output_dir)

@torch.no_grad()
def run(self) -> Dataset:
activation_save_path = self.cfg.activation_save_path
assert activation_save_path is not None

### Paths setup
final_cached_activation_path = Path(activation_save_path)
final_cached_activation_path.mkdir(exist_ok=True, parents=True)
if any(final_cached_activation_path.iterdir()):
raise Exception(
f"Activations directory ({final_cached_activation_path}) is not empty. Please delete it or specify a different path. Exiting the script to prevent accidental deletion of files."
)

tmp_cached_activation_path = final_cached_activation_path / ".tmp_shards/"
tmp_cached_activation_path.mkdir(exist_ok=False, parents=False)

### Create temporary sharded datasets

print(f"Started caching activations for {self.cfg.hf_dataset_path}")

for i in tqdm(range(self.cfg.n_buffers), desc="Caching activations"):
try:
buffer = self.activations_store.get_buffer(
self.cfg.batches_in_buffer, shuffle=False
)
shard = self._create_shard(buffer)
shard.save_to_disk(
f"{tmp_cached_activation_path}/shard_{i:05d}", num_shards=1
)
del buffer, shard

except StopIteration:
print(
f"Warning: Ran out of samples while filling the buffer at batch {i} before reaching {self.cfg.n_buffers} batches."
)
break
builder = CacheActivationDataset(self.cfg, self.activations_store)
builder.download_and_prepare()
dataset = builder.as_dataset(split="train") # type: ignore
assert isinstance(dataset, Dataset)

### Concatenate shards and push to Huggingface Hub

dataset = self._consolidate_shards(
tmp_cached_activation_path, final_cached_activation_path, copy_files=False
)

if self.cfg.shuffle:
print("Shuffling...")
dataset = dataset.shuffle(seed=self.cfg.seed)
@@ -241,7 +168,7 @@ def run(self) -> Dataset:
print("Pushing to Huggingface Hub...")
dataset.push_to_hub(
repo_id=self.cfg.hf_repo_id,
num_shards=self.cfg.hf_num_shards or self.cfg.n_buffers,
num_shards=self.cfg.hf_num_shards,
private=self.cfg.hf_is_private_repo,
revision=self.cfg.hf_revision,
)
@@ -263,31 +190,3 @@ def run(self) -> Dataset:
)

return dataset

def _create_shard(
self,
buffer: Float[torch.Tensor, "(bs context_size) num_layers d_in"],
) -> Dataset:
hook_names = [self.cfg.hook_name]

buffer = einops.rearrange(
buffer,
"(bs context_size) num_layers d_in -> num_layers bs context_size d_in",
bs=self.cfg.rows_in_buffer,
context_size=self.context_size,
d_in=self.cfg.d_in,
num_layers=len(hook_names),
)
shard = Dataset.from_dict(
{hook_name: act for hook_name, act in zip(hook_names, buffer)},
features=self.features,
)
return shard

@staticmethod
def _get_sliced_context_size(
context_size: int, seqpos_slice: tuple[int | None, ...] | None
) -> int:
if seqpos_slice is not None:
context_size = len(range(context_size)[slice(*seqpos_slice)])
return context_size
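
For reference, a hedged end-to-end sketch of the new flow (field values are illustrative, and other required config fields are omitted):

```python
from pathlib import Path

from sae_lens.cache_activations_runner import CacheActivationsRunner
from sae_lens.config import CacheActivationsRunnerConfig

cfg = CacheActivationsRunnerConfig(
    hook_name="blocks.0.hook_mlp_out",  # hypothetical hook
    d_in=512,
    dtype="float32",
    activation_save_path=Path("./activations"),
    # ...model and dataset fields omitted
)
runner = CacheActivationsRunner(cfg)
runner.summary()        # prints token, buffer, and disk-space estimates
dataset = runner.run()  # caches activations; shuffles/pushes per config
```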