-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathgenerate_acts.py
125 lines (109 loc) · 5.47 KB
/
generate_acts.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import torch as t
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer, LlamaForCausalLM
import argparse
import pandas as pd
from tqdm import tqdm
import os
import configparser
import glob
# Model weight locations (per family/size/type) are read from config.ini;
# see load_model for how the sections and keys are consumed.
config = configparser.ConfigParser()
config.read('config.ini')
class Hook:
    """Forward hook that records a module's primary output.

    HF decoder layers return a tuple whose first element is the hidden
    states; this hook keeps only that tensor in ``self.out``.
    """

    def __init__(self):
        # Latest captured output; None until the hook has fired once.
        self.out = None

    def __call__(self, module, module_inputs, module_outputs):
        hidden_states, _ = module_outputs
        self.out = hidden_states
def load_model(model_family: str, model_size: str, model_type: str, device: str):
    """Load the tokenizer and causal-LM weights for the requested model.

    The weights directory and subdirectory are looked up in the module-level
    ``config`` (config.ini). The model is converted to half precision
    (bfloat16 for Gemma2, float16 otherwise) and moved to ``device``.

    Returns (tokenizer, model); logs and re-raises on any loading failure.
    """
    subdir = config[model_family][f'{model_size}_{model_type}_subdir']
    model_path = os.path.join(config[model_family]['weights_directory'], subdir)
    try:
        if model_family == 'Llama2':
            tokenizer = LlamaTokenizer.from_pretrained(str(model_path))
            model = LlamaForCausalLM.from_pretrained(str(model_path))
            tokenizer.bos_token = '<s>'
        else:
            tokenizer = AutoTokenizer.from_pretrained(str(model_path))
            model = AutoModelForCausalLM.from_pretrained(str(model_path))
        # Gemma2 requires bfloat16 precision, which is only available on new
        # GPUs; every other family is stored fp32 and converted to fp16.
        model = model.to(t.bfloat16) if model_family == "Gemma2" else model.half()
        return tokenizer, model.to(device)
    except Exception as e:
        print(f"Error loading model: {e}")
        raise
def load_statements(dataset_name):
    """Return the 'statement' column of datasets/<dataset_name>.csv as a list of strings."""
    frame = pd.read_csv(f"datasets/{dataset_name}.csv")
    return list(frame['statement'])
def get_acts(statements, tokenizer, model, layers, device):
    """
    Get given layer activations for the statements.

    Registers a forward hook on each requested decoder layer, runs the model
    over every statement, and records the activation of the statement's final
    token position for each hooked layer.

    Returns a dict mapping layer index -> float32 tensor of stacked
    activations, one row per statement.
    """
    # attach hooks
    hooks, handles = [], []
    for layer in layers:
        hook = Hook()
        handles.append(model.model.layers[layer].register_forward_hook(hook))
        hooks.append(hook)
    try:
        # get activations
        acts = {layer: [] for layer in layers}
        for statement in tqdm(statements):
            input_ids = tokenizer.encode(statement, return_tensors="pt").to(device)
            model(input_ids)
            for layer, hook in zip(layers, hooks):
                # last-token activation of the (single-sequence) batch
                acts[layer].append(hook.out[0, -1])
        for layer, act in acts.items():
            acts[layer] = t.stack(act).float()
    finally:
        # Remove hooks even when the forward pass raises; otherwise stale
        # hooks accumulate on the model across repeated calls.
        for handle in handles:
            handle.remove()
    return acts
if __name__ == "__main__":
    """
    read statements from dataset, record activations in given layers, and save to specified files
    """
    parser = argparse.ArgumentParser(description="Generate activations for statements in a dataset")
    parser.add_argument("--model_family", default="Llama3", help="Model family to use. Options are Llama2, Llama3, Gemma, Gemma2 or Mistral.")
    parser.add_argument("--model_size", default="8B",
                        help="Size of the model to use. Options for Llama3 are 8B or 70B")
    parser.add_argument("--model_type", default="base", help="Whether to choose base or chat model. Options are base or chat.")
    parser.add_argument("--layers", nargs='+',
                        help="Layers to save embeddings from. Pass -1 for all layers.")
    parser.add_argument("--datasets", nargs='+',
                        help="Names of datasets, without .csv extension")
    parser.add_argument("--output_dir", default="acts",
                        help="Directory to save activations to.")
    parser.add_argument("--device", default="cpu")
    args = parser.parse_args()

    datasets = args.datasets
    if datasets == ['all_topic_specific']:
        # Curated list of topic-specific datasets (affirmative, negated,
        # conjunctive, disjunctive, comparative, and German variants).
        datasets = ['cities', 'sp_en_trans', 'inventors', 'animal_class', 'element_symb', 'facts',
                    'neg_cities', 'neg_sp_en_trans', 'neg_inventors', 'neg_animal_class', 'neg_element_symb', 'neg_facts',
                    'cities_conj', 'sp_en_trans_conj', 'inventors_conj', 'animal_class_conj', 'element_symb_conj', 'facts_conj',
                    'cities_disj', 'sp_en_trans_disj', 'inventors_disj', 'animal_class_disj', 'element_symb_disj', 'facts_disj',
                    'larger_than', 'smaller_than', "cities_de", "neg_cities_de", "sp_en_trans_de", "neg_sp_en_trans_de", "inventors_de", "neg_inventors_de", "animal_class_de",
                    "neg_animal_class_de", "element_symb_de", "neg_element_symb_de", "facts_de", "neg_facts_de"]
    if datasets == ['all']:
        # Every csv under datasets/, keeping any subdirectory prefix.
        datasets = [os.path.relpath(file_path, 'datasets').replace('.csv', '')
                    for file_path in glob.glob('datasets/**/*.csv', recursive=True)]

    t.set_grad_enabled(False)  # inference only; no autograd bookkeeping
    tokenizer, model = load_model(args.model_family, args.model_size, args.model_type, args.device)

    # The layer selection does not depend on the dataset: parse it once,
    # expanding the sentinel [-1] to every decoder layer.
    layers = [int(layer) for layer in args.layers]
    if layers == [-1]:
        layers = list(range(len(model.model.layers)))

    for dataset in datasets:
        statements = load_statements(dataset)
        save_dir = f"{args.output_dir}/{args.model_family}/{args.model_size}/{args.model_type}/{dataset}/"
        os.makedirs(save_dir, exist_ok=True)
        # Process in chunks of 25 statements to bound memory; each chunk's
        # activations are saved per layer, keyed by the chunk's start index.
        for idx in range(0, len(statements), 25):
            acts = get_acts(statements[idx:idx + 25], tokenizer, model, layers, args.device)
            for layer, act in acts.items():
                t.save(act, f"{save_dir}/layer_{layer}_{idx}.pt")