-
Notifications
You must be signed in to change notification settings - Fork 0
/
res2rknn.py
67 lines (57 loc) · 2.05 KB
/
res2rknn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import argparse
import logging
import os.path
import numpy as np
def set_logging():
    """Lower the root logger's threshold to INFO.

    Only the level is changed; no handler is attached here. The module-level
    ``logging.info`` calls used elsewhere fall back to
    ``logging.basicConfig()`` when the root logger has no handlers.
    """
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)
def parse_args(argv=None):
    """Parse the command-line options.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) means
            ``sys.argv[1:]``, which matches the previous behavior. Passing a list
            makes the function usable from code and tests.

    Returns:
        argparse.Namespace with ``path_in`` (text dump to read) and
        ``dir_out`` (directory the per-tensor ``.txt`` files are written to).
    """
    parser = argparse.ArgumentParser(
        description='Split a tensor text dump into one .txt file per tensor.')
    parser.add_argument('--path_in',
                        default='/media/manu/data/sdks/sigmastar/Tiramisu_DLS00V010-20220107/ipu/SGS_IPU_SDK_vQ_0.1.0/log/output/unknown_acfree_640_fixed.sim_sgsimg.img_students_lt.bmp.txt',
                        help='text dump file to parse')
    parser.add_argument('--dir_out', default='/home/manu/tmp',
                        help='existing directory for the per-tensor .txt files')
    return parser.parse_args(argv)
def _seek(lines, i, marker):
    """Return the first index >= *i* whose line contains *marker*.

    Raises:
        ValueError: if no remaining line contains *marker* (truncated or
            malformed dump) instead of the bare IndexError an unbounded scan
            would produce.
    """
    while i < len(lines):
        if marker in lines[i]:
            return i
        i += 1
    raise ValueError('unexpected end of dump: no line containing {!r}'.format(marker))


def _parse_dump(lines):
    """Parse dump lines into ``{(tensor_name, shape_str): [value tokens]}``.

    The layout expected here is inferred only from the markers this parser
    looks for (TODO confirm against a real dump):
    a header line containing ``Tensor:`` whose first of exactly two tokens
    is the tensor name; a later ``Original shape:`` line; a
    ``tensor data:`` line; whitespace-separated value lines; and a line
    containing ``}`` that closes the tensor. Values on the ``}`` line itself
    are not collected.
    """
    db = dict()
    i = 0
    while i < len(lines):
        header = lines[i]
        if 'Tensor:' in header:
            logging.info(header)
            tensor_name, _ = header.strip().split()
            logging.info(tensor_name)
            i = _seek(lines, i, 'Original shape:')
            # Text after the last ':' minus its first and last character
            # (e.g. ' [1 2 2 3]' -> '[1 2 2 3'); only its last token is used.
            shape_str = lines[i].strip().split(':')[-1][1:-1]
            i = _seek(lines, i, 'tensor data:') + 1
            end = _seek(lines, i, '}')
            values = list()
            for data_line in lines[i:end]:
                values.extend(data_line.strip().split())
            # Tuple key: tensor names containing ';' can no longer corrupt it.
            db[(tensor_name, shape_str)] = values
            # Leave i on the '}' line; the increment below moves just past it,
            # so a header that immediately follows '}' is no longer skipped.
            i = end
        i += 1
    return db


def _write_tensor(dir_out, tensor_name, shape_str, values):
    """Write one tensor's values to ``<dir_out>/<tensor_name>.txt``.

    The last token of *shape_str* is taken as the channel count ``c``. The
    remaining elements must form a square ``wh x wh`` plane. The flat values
    are reshaped to ``(1, wh, wh, c)`` (presumably NHWC; TODO confirm) and
    transposed to ``(1, c, wh, wh)``. They are then written one value per
    line with ``%f`` formatting.

    Raises:
        ValueError: if the element count is not ``wh * wh * c``.
    """
    data = np.array(values).astype('float')
    c = int(shape_str.split()[-1])
    # Round rather than truncate the float sqrt so floating-point error cannot
    # lose a row. The explicit check then rejects non-square planes clearly.
    wh = int(round((data.size / c) ** 0.5))
    logging.info((1, wh, wh, c, tensor_name))
    if wh * wh * c != data.size:
        raise ValueError(
            'tensor {!r}: {} values cannot be reshaped to a square '
            '(1, wh, wh, {}) layout'.format(tensor_name, data.size, c))
    nchw = np.transpose(data.reshape((1, wh, wh, c)), (0, 3, 1, 2))
    path_out = os.path.join(dir_out, tensor_name + '.txt')
    np.savetxt(path_out, nchw.flatten(), fmt="%f", delimiter="\n")


def run(args):
    """Split the tensor text dump at ``args.path_in`` into per-tensor files.

    The whole file is parsed first. Each tensor found is then written to
    ``<args.dir_out>/<tensor_name>.txt``. A later tensor with the same name
    overwrites an earlier file.

    Args:
        args: namespace with ``path_in`` (dump file to read) and ``dir_out``
            (output directory; it is not created here).

    Raises:
        ValueError: if the dump is truncated or malformed, or if a tensor's
            element count does not fit a square ``(1, wh, wh, c)`` layout.
    """
    logging.info(args)
    with open(args.path_in, 'r') as f:
        lines = f.readlines()
    db = _parse_dump(lines)
    for (tensor_name, shape_str), values in db.items():
        _write_tensor(args.dir_out, tensor_name, shape_str, values)
def main():
    """Script entry point: configure logging, read the CLI options, convert."""
    set_logging()
    run(parse_args())
# Run the conversion only when executed as a script, not when imported.
if __name__ == '__main__':
    main()