-
Notifications
You must be signed in to change notification settings - Fork 1
/
get_llama_layers_statistics.py
55 lines (41 loc) · 1.97 KB
/
get_llama_layers_statistics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import os
import torch
import warnings
import pandas as pd
from tqdm import tqdm
from huggingface_hub import login
from itertools import combinations
from transformers import AutoModel
from utils import _get_nodes, _get_layer_kinds
warnings.filterwarnings("ignore")
def _save_llama_layers_locally():
    """Save every per-block tensor of each llama node under ``llama_blocks/``.

    Logs in to the Hugging Face hub, loads each node returned by
    ``_get_nodes(llama=True)`` with ``AutoModel.from_pretrained`` and walks
    its ``state_dict``. Entries whose name contains ``'layers.'`` are written
    with ``torch.save`` to
    ``llama_blocks/<block_idx>/<basename(node)>-<entry name>.pt``, where
    ``<block_idx>`` is the integer that follows the last ``'layers.'`` in the
    entry name. Entries without ``'layers.'`` in their name are skipped.

    Paths are relative to the current working directory. Nothing is returned.
    """
    login()  # you should provide your own token
    nodes_ = _get_nodes(llama=True)
    # NOTE(review): each item is unpacked as (node_name, node); `node` is
    # presumably a hub repo id or a local path -- confirm against utils.
    for _node_name, node in nodes_:
        model = AutoModel.from_pretrained(node)
        node_tag = os.path.basename(node)
        for name, layer in model.state_dict().items():
            splits = name.split('layers.')
            if len(splits) == 1:
                # No 'layers.' in the name: not a per-block entry.
                continue
            # The token right after the last 'layers.' is the block index.
            block_token = splits[-1].split('.')[0]
            print(name, '->', block_token)
            block_idx_ = int(block_token)
            block_dir = f'llama_blocks/{block_idx_}'
            # exist_ok avoids the separate exists() check before creating.
            os.makedirs(block_dir, exist_ok=True)
            torch.save(layer, f'{block_dir}/{node_tag}-{name}.pt')
if __name__ == '__main__':
    # Script entry point: for every block directory and every layer kind,
    # load the saved tensors whose file name contains the kind, take the
    # element-wise difference of every pair, and record summary statistics
    # of those differences in a CSV file.
    nodes = _get_nodes(llama=True)
    layer_kinds = _get_layer_kinds(llama=True)
    res = pd.DataFrame(columns=['block_idx', 'layer_kind', 'mean', 'std', 'min', 'max'])

    for block_idx in tqdm(range(32)):
        block_dir = f'llama_blocks/{block_idx}'
        files = os.listdir(block_dir)
        for kind in layer_kinds:
            # Tensors of this kind: one per file whose name contains `kind`.
            layer_data = []
            for file_name in files:
                if kind in file_name:
                    layer_data.append(torch.load(os.path.join(block_dir, file_name)))

            # Differences of all unordered pairs, stacked into one tensor.
            pair_diffs = [first - second for first, second in combinations(layer_data, 2)]
            dist_layer_data = torch.stack(pair_diffs)

            stats_row = {
                'block_idx': block_idx,
                'layer_kind': kind,
                'mean': dist_layer_data.mean().item(),
                'std': dist_layer_data.std().item(),
                'min': dist_layer_data.min().item(),
                'max': dist_layer_data.max().item(),
            }
            # Append as a new row at the next integer label.
            res.loc[res.shape[0]] = stats_row

    res.to_csv('llama_layers_statistics.csv')