-
Notifications
You must be signed in to change notification settings - Fork 1
/
compress.py
62 lines (44 loc) · 1.8 KB
/
compress.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import copy
import brotli
import torch
import sys
from compression import quantize_tensor
from serialization import serialize_state_dict
def size(tensor):
    """Return the byte size of ``tensor``'s elements.

    Computed as the number of elements multiplied by the width in bytes of
    a single element of the tensor's dtype.
    """
    element_count = tensor.nelement()
    bytes_per_element = tensor.element_size()
    return element_count * bytes_per_element
def compress_state_dict(state_dict, *, lgwin=10):
    """Quantize and Brotli-compress every tensor in a model state dict.

    ``state_dict`` must contain two reserved entries besides the tensors:

    - ``"quantization_config"``: maps each tensor key to a dict with at
      least ``"bits"`` and ``"bound"``.
    - ``"__meta"``: arbitrary metadata.

    Every other entry is treated as a tensor. It is passed through
    ``quantize_tensor`` with its configured bits/bound, moved to CPU,
    converted to raw bytes and compressed with ``brotli.compress``.

    The input mapping is left unmodified.

    Args:
        state_dict: Mapping of key -> tensor, plus the two reserved entries.
        lgwin: Brotli window size (base-2 log) passed to
            ``brotli.compress``. Defaults to 10.

    Returns:
        A new dict containing, in order:

        - ``"quantization_config"`` (the same object as in the input);
        - one ``bytes`` entry per tensor key, in input order;
        - a deep copy of ``"__meta"``.

    Raises:
        KeyError: if a reserved entry is missing, or if a tensor key has no
            entry (or no ``"bits"``/``"bound"``) in the quantization config.
    """
    reserved_keys = ("quantization_config", "__meta")
    quantization_config = state_dict["quantization_config"]
    # Deep-copied so the returned dict does not alias the caller's metadata.
    metadata = copy.deepcopy(state_dict["__meta"])

    print("Compressing...")
    # Output order: config first, tensors in input order, metadata last.
    compressed_state_dict = {"quantization_config": quantization_config}
    total_uncompressed = 0
    total_compressed = 0
    for key, tensor in state_dict.items():
        # Skip reserved entries instead of deleting them from the caller's dict.
        if key in reserved_keys:
            continue
        tensor_config = quantization_config[key]
        quantized_tensor = quantize_tensor(tensor, tensor_config["bits"], tensor_config["bound"])
        # detach() so tensors that still require grad can be converted to numpy.
        buffer = quantized_tensor.detach().cpu().numpy().tobytes()
        compressed = brotli.compress(buffer, lgwin=lgwin)
        compressed_state_dict[key] = compressed
        # The ratio is measured against the original (pre-quantization) tensor size.
        total_uncompressed += size(tensor)
        total_compressed += len(compressed)

    print(f"Total uncompressed: {total_uncompressed}")
    print(f"Total compressed: {total_compressed}")
    if total_compressed:
        print(f"Ratio: {total_uncompressed / total_compressed}")
    else:
        # No tensors were compressed; avoid ZeroDivisionError.
        print("Ratio: n/a")
    compressed_state_dict["__meta"] = metadata
    return compressed_state_dict
if __name__ == "__main__":
    # Usage: compress.py MODEL_DUMP_PATH COMPRESSED_MODEL_DUMP_PATH
    # Fail with a readable message instead of an IndexError; extra
    # arguments are ignored, as before.
    if len(sys.argv) < 3:
        sys.exit(f"usage: {sys.argv[0]} MODEL_DUMP_PATH COMPRESSED_MODEL_DUMP_PATH")
    model_dump_path = sys.argv[1]
    compressed_model_dump_path = sys.argv[2]
    print("Loading state dict...")
    # NOTE(security): depending on the installed torch version's
    # `weights_only` default, torch.load may unpickle arbitrary objects.
    # Only load trusted model dumps.
    state_dict = torch.load(model_dump_path)
    compressed_state_dict = compress_state_dict(state_dict)
    serialize_state_dict(compressed_state_dict, compressed_model_dump_path)