Compute activations other than MLP neurons #6

Open · wants to merge 2 commits into main
75 changes: 59 additions & 16 deletions lib/activations/activations/activations_computation.py
@@ -1,27 +1,70 @@
-from typing import Literal
+from enum import Enum
+from typing import Callable
 
 import torch
+from nnsight.envoy import Envoy  # type: ignore
+from nnsight.intervention import InterventionProxy  # type: ignore
 from util.subject import Subject
 
 
-def get_activations_computing_func(subject: Subject, activation_type: Literal["MLP"], layer: int):
+class ActivationType(str, Enum):
+    RESID = "resid"
+    MLP_IN = "mlp_in"
+    MLP_OUT = "mlp_out"
+    ATTN_OUT = "attn_out"
+    NEURONS = "neurons"
+
+
+def _get_activations_funcs(
+    subject: Subject, activation_type: ActivationType, layer: int
+) -> tuple[Callable[[], Envoy], Callable[[Envoy], InterventionProxy]]:
+    if activation_type == ActivationType.RESID:
+        return (
+            lambda: subject.layers[layer],
+            lambda component: component.output[0],
+        )
+    if activation_type == ActivationType.MLP_IN:
+        return (
+            lambda: subject.mlps[layer],
+            lambda component: component.input,
+        )
+    if activation_type == ActivationType.MLP_OUT:
+        return (
+            lambda: subject.mlps[layer],
+            lambda component: component.output,
+        )
+    if activation_type == ActivationType.ATTN_OUT:
+        return (
+            lambda: subject.attns[layer],
+            lambda component: component.output[0],
+        )
+    if activation_type == ActivationType.NEURONS:
+        return (
+            lambda: subject.w_outs[layer],
+            lambda component: component.input,
+        )
+    raise ValueError(f"Unknown activation type: {activation_type}")
+
+
+def get_activations_computing_func(
+    subject: Subject, activation_type: ActivationType, layer: int
+) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
     """
     Returns a function that computes activations for a given input:
         input_ids: torch.Tensor
         attn_mask: torch.Tensor
 
     """
-    if activation_type == "MLP":
-        mlp_acts_for_layer = subject.w_outs[layer]
-
-        def get_mlp_activations(input_ids: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor:
-            with torch.no_grad():
-                with subject.model.trace(
-                    {"input_ids": input_ids, "attention_mask": attn_mask}  # type: ignore
-                ):
-                    acts = mlp_acts_for_layer.input.save()
-            return acts
-
-        return get_mlp_activations
-    else:
-        raise ValueError(f"Unknown activation type: {activation_type}")
+    get_component, get_activations = _get_activations_funcs(subject, activation_type, layer)
+
+    def activations_computing_func(
+        input_ids: torch.Tensor, attn_mask: torch.Tensor
+    ) -> torch.Tensor:
+        with torch.no_grad():
+            with subject.model.trace(
+                {"input_ids": input_ids, "attention_mask": attn_mask}  # type: ignore
+            ):
+                acts: torch.Tensor = get_activations(get_component()).save()  # type: ignore
+        return acts
+
+    return activations_computing_func
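A minimal usage sketch of the new API, for review purposes (not part of the diff): `Subject`, `get_subject_config`, and the `nnsight_lm_kwargs` argument appear in this PR's compute_exemplars.py, while the "gpt2" model id, the tokenizer call, and the shape comment are assumptions.

```python
from activations.activations_computation import (
    ActivationType,
    get_activations_computing_func,
)
from util.subject import Subject, get_subject_config

# Hypothetical subject; construction mirrors project/expgen/scripts/compute_exemplars.py.
subject = Subject(get_subject_config("gpt2"), nnsight_lm_kwargs={"dispatch": True})

# Attention-output activations at layer 3; before this PR only MLP neurons were available.
compute_acts = get_activations_computing_func(subject, ActivationType.ATTN_OUT, layer=3)

batch = subject.tokenizer(["Hello world"], return_tensors="pt", padding=True)
acts = compute_acts(batch["input_ids"], batch["attention_mask"])
print(acts.shape)  # expected (batch, seq_len, subject.D) for attn_out
```

Note the design of `_get_activations_funcs`: it returns a component getter and an activation selector as lambdas rather than the Envoy itself, so the module is only touched inside the `model.trace(...)` context where nnsight can record the access.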
28 changes: 19 additions & 9 deletions lib/activations/activations/exemplars_wrapper.py
@@ -6,6 +6,7 @@
 import matplotlib.pyplot as plt
 import numpy as np
 from activations.activations import ActivationRecord
+from activations.activations_computation import ActivationType
 from activations.dataset import (
     ChatDataset,
     HFDatasetWrapper,
@@ -321,7 +322,7 @@ class ExemplarConfig(BaseModel):
     batch_size: int = 512
     rand_seqs: int = 10
     seed: int = 64
-    activation_type: Literal["MLP"] = "MLP"
+    activation_type: ActivationType = ActivationType.NEURONS
 
 
 class ExemplarsWrapper:
@@ -347,6 +348,8 @@ def __init__(
         if subject.is_chat_model:
             folder_name_components.append("chat")
         folder_name_components.append(f"{config.seq_len}seqlen")
+        if config.activation_type != "neurons":
+            folder_name_components.append(config.activation_type)
         assert subject.tokenizer.padding_side == "left"
 
         folder_name = "_".join(folder_name_components)
@@ -430,10 +433,7 @@ def load_layer_checkpoint(self, layer: int, split: ExemplarSplit) -> (
             ExemplarSplit.RANDOM_TEST,
         )
 
-        if self.config.activation_type == "MLP":
-            num_features = self.subject.I
-        else:
-            raise ValueError(f"Invalid activation type: {self.config.activation_type}")
+        num_features = self.num_features
         num_top_feats_to_save = self.config.num_top_acts_to_save
         k, seq_len = self.config.k, self.config.seq_len
@@ -496,10 +496,7 @@ def save_layer_checkpoint(
         layer_dir = self.get_layer_dir(layer, split)
         os.makedirs(layer_dir, exist_ok=True)
 
-        if self.config.activation_type == "MLP":
-            num_features = self.subject.I
-        else:
-            raise ValueError(f"Invalid activation type: {self.config.activation_type}")
+        num_features = self.num_features
        num_top_feats_to_save = self.config.num_top_acts_to_save
         k, seq_len = self.config.k, self.config.seq_len
@@ -883,6 +880,19 @@ def visualize_neuron_exemplars(
         )
         display(HTML(html_content))  # type: ignore
 
+    @property
+    def num_features(self) -> int:
+        if self.config.activation_type == ActivationType.NEURONS:
+            return self.subject.I
+        if self.config.activation_type in (
+            ActivationType.RESID,
+            ActivationType.MLP_IN,
+            ActivationType.MLP_OUT,
+            ActivationType.ATTN_OUT,
+        ):
+            return self.subject.D
+        raise ValueError(f"Invalid activation type: {self.config.activation_type}")
+
 
 ###################
 # Example Configs #
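The new `num_features` property is now the single place that maps an activation type to a feature count. Below is a standalone sketch of that dispatch; reading `subject.I` as the MLP hidden width and `subject.D` as the residual-stream width is our inference from the diff, and the concrete dimensions are invented (GPT-2-small-like).

```python
from activations.activations_computation import ActivationType

def num_features(activation_type: ActivationType, I: int, D: int) -> int:
    # Neuron activations index into the MLP hidden space (width I); resid,
    # mlp_in, mlp_out, and attn_out all live in the model-stream space (width D).
    if activation_type == ActivationType.NEURONS:
        return I
    return D

assert num_features(ActivationType.NEURONS, I=3072, D=768) == 3072
assert num_features(ActivationType.ATTN_OUT, I=3072, D=768) == 768
```

Also worth noting: because the folder name only gains a suffix when `config.activation_type != "neurons"`, exemplar caches written before this PR presumably keep resolving unchanged for the default neurons case.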
1 change: 1 addition & 0 deletions project/expgen/.gitignore
@@ -0,0 +1 @@
+data/
24 changes: 20 additions & 4 deletions project/expgen/scripts/compute_exemplars.py
@@ -4,8 +4,10 @@
 """
 
 import argparse
+from typing import Any
 
-from activations.dataset import fineweb_dset_config, lmsys_dset_config
+from activations.activations_computation import ActivationType
+from activations.dataset import HFDatasetWrapperConfig, fineweb_dset_config, lmsys_dset_config
 from activations.exemplars import ExemplarSplit
 from activations.exemplars_computation import (
     compute_exemplars_for_layer,
@@ -15,12 +17,25 @@
 from util.subject import Subject, get_subject_config
 
 parser = argparse.ArgumentParser()
+parser.add_argument(
+    "--activation_type",
+    type=str,
+    choices=[
+        ActivationType.RESID,
+        ActivationType.MLP_IN,
+        ActivationType.MLP_OUT,
+        ActivationType.ATTN_OUT,
+        ActivationType.NEURONS,
+    ],
+    default="neurons",
+    help="Type of activations from which we pick indices to compute exemplars for.",
+)
 parser.add_argument(
     "--layer_indices",
     type=int,
     nargs="+",
     default=None,
-    help="Layers from which we pick neurons to compute exemplars for.",
+    help="Layers from which we pick indices to compute exemplars for.",
 )
 parser.add_argument(
     "--subject_hf_model_id",
@@ -87,7 +102,7 @@
 subject_config = get_subject_config(args.subject_hf_model_id)
 subject = Subject(subject_config, nnsight_lm_kwargs={"dispatch": True})
 
-hf_dataset_configs = []
+hf_dataset_configs: list[HFDatasetWrapperConfig] = []
 for hf_dataset in args.hf_datasets:
     if hf_dataset == "fineweb":
         hf_dataset_configs.append(fineweb_dset_config)
@@ -106,13 +121,14 @@
     num_top_acts_to_save=args.num_top_acts_to_save,
     batch_size=args.batch_size,
     seed=args.seed,
+    activation_type=args.activation_type,
 )
 exemplars_wrapper = ExemplarsWrapper(args.data_dir, exemplar_config, subject)
 
 layer_indices = args.layer_indices if args.layer_indices else range(subject.L)
 for layer in layer_indices:
     print(f"============ Layer {layer} ============")
-    kwargs = {
+    kwargs: dict[str, Any] = {
         "exemplars_wrapper": exemplars_wrapper,
         "layer": layer,
         "split": ExemplarSplit(args.split),
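One subtlety in the `--activation_type` flag added above: `choices` holds `ActivationType` members while `type=str` leaves the parsed value a plain string. That works because `ActivationType` subclasses `str`, so members compare equal to their raw values — the same reason the `!= "neurons"` check in exemplars_wrapper.py is safe. A quick illustration (the pydantic-coercion comment is our reading of `ExemplarConfig`, not something the diff states):

```python
from activations.activations_computation import ActivationType

# str-mixin enum members compare equal to their raw string values...
assert ActivationType.RESID == "resid"
# ...so argparse's membership test of the parsed string against choices passes.
assert "attn_out" in [ActivationType.ATTN_OUT, ActivationType.NEURONS]

# args.activation_type then arrives as a plain string (e.g. "resid"), which
# pydantic coerces back into an ActivationType when ExemplarConfig is built.
```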