Skip to content

Commit 11b2e38

Browse files
authored
Merge pull request #57 from chanind/fix-slash-in-model-name-autointerp
fix: gracefully handle slashes in model filename for autointerp
2 parents 60579ed + 5d6464a commit 11b2e38

File tree

3 files changed

+84
-1
lines changed

3 files changed

+84
-1
lines changed

sae_bench/evals/autointerp/main.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ def str_bool(b: bool) -> str:
5555
return "Y" if b else ""
5656

5757

58+
def escape_slash(s: str) -> str:
    """Make *s* safe to embed in a filename by turning every "/" into "_"."""
    return "_".join(s.split("/"))
60+
61+
5862
class Example:
5963
"""
6064
Data for a single example sequence.
@@ -523,7 +527,7 @@ def run_eval_single_sae(
523527

524528
os.makedirs(artifacts_folder, exist_ok=True)
525529

526-
tokens_filename = f"{config.model_name}_{config.total_tokens}_tokens_{config.llm_context_size}_ctx.pt"
530+
tokens_filename = f"{escape_slash(config.model_name)}_{config.total_tokens}_tokens_{config.llm_context_size}_ctx.pt"
527531
tokens_path = os.path.join(artifacts_folder, tokens_filename)
528532

529533
if os.path.exists(tokens_path):

tests/conftest.py

+10
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pytest
2+
import torch
23
from sae_lens import SAE
34
from transformer_lens import HookedTransformer
45
from transformers import GPT2LMHeadModel
@@ -21,6 +22,15 @@ def gpt2_l4_sae() -> SAE:
2122
)[0]
2223

2324

25+
@pytest.fixture
def gpt2_l4_sae_sparsity() -> torch.Tensor:
    """Sparsity tensor accompanying the layer-4 GPT-2 residual-stream SAE (CPU)."""
    # from_pretrained returns (sae, cfg, sparsity); we only need the third element.
    loaded = SAE.from_pretrained(
        "gpt2-small-res-jb", "blocks.4.hook_resid_pre", device="cpu"
    )
    sparsity = loaded[2]
    assert sparsity is not None
    return sparsity
32+
33+
2434
@pytest.fixture
2535
def gpt2_l5_sae() -> SAE:
2636
return SAE.from_pretrained(
+69
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
from pathlib import Path
2+
from unittest.mock import patch
3+
4+
import torch
5+
from sae_lens import SAE
6+
from transformer_lens import HookedTransformer
7+
8+
from sae_bench.evals.autointerp.eval_config import AutoInterpEvalConfig
9+
from sae_bench.evals.autointerp.main import run_eval_single_sae
10+
11+
12+
def test_run_eval_single_sae_saves_tokens_to_artifacts_folder(
    gpt2_l4_sae: SAE,
    gpt2_model: HookedTransformer,
    gpt2_l4_sae_sparsity: torch.Tensor,
    tmp_path: Path,
):
    """run_eval_single_sae should cache the tokenized dataset in the artifacts folder."""
    artifacts_dir = tmp_path / "artifacts"

    eval_config = AutoInterpEvalConfig(
        model_name="gpt2",
        dataset_name="roneneldan/TinyStories",
        n_latents=2,
        total_tokens=255,
        llm_context_size=128,
    )

    # Stub out the expensive, API-calling interpretation step; we only care
    # about the token-caching side effect here.
    with patch("sae_bench.evals.autointerp.main.AutoInterp.run", return_value={}):
        run_eval_single_sae(
            config=eval_config,
            sae=gpt2_l4_sae,
            model=gpt2_model,
            device="cpu",
            artifacts_folder=str(artifacts_dir),
            api_key="fake_api_key",
            sae_sparsity=gpt2_l4_sae_sparsity,
        )

    expected_path = artifacts_dir / "gpt2_255_tokens_128_ctx.pt"
    assert expected_path.exists()
    cached_tokens = torch.load(expected_path)
    # 255 total tokens at context size 128 -> 2 sequences of 128 tokens.
    assert cached_tokens.shape == (2, 128)
41+
42+
43+
def test_run_eval_single_sae_saves_handles_slash_in_model_name(
    gpt2_l4_sae: SAE,
    gpt2_model: HookedTransformer,
    gpt2_l4_sae_sparsity: torch.Tensor,
    tmp_path: Path,
):
    """A "/" in the model name must be escaped in the cached-tokens filename."""
    artifacts_dir = tmp_path / "artifacts"

    eval_config = AutoInterpEvalConfig(
        model_name="openai/gpt2",
        dataset_name="roneneldan/TinyStories",
        n_latents=2,
        total_tokens=255,
        llm_context_size=128,
    )

    # Stub out the actual interpretation run; only the artifact path matters.
    with patch("sae_bench.evals.autointerp.main.AutoInterp.run", return_value={}):
        run_eval_single_sae(
            config=eval_config,
            sae=gpt2_l4_sae,
            model=gpt2_model,
            device="cpu",
            artifacts_folder=str(artifacts_dir),
            api_key="fake_api_key",
            sae_sparsity=gpt2_l4_sae_sparsity,
        )

    # "openai/gpt2" -> "openai_gpt2": the slash must not create a nested directory.
    assert (artifacts_dir / "openai_gpt2_255_tokens_128_ctx.pt").exists()

0 commit comments

Comments
 (0)