This repository has been archived by the owner on May 24, 2018. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 414
/
symbol_converter.py
136 lines (120 loc) · 4.08 KB
/
symbol_converter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import re
from collections import defaultdict
import sys
import os
print("Note: Please remove all inplace setting in cxxnet configure file, only only support [a->b] format")
print("I only implmented some I needed, you have to implment what you need")
# The script reads the cxxnet config from sys.argv[1] and writes the generated
# mxnet script to sys.argv[2], so both arguments are required. Stop here
# instead of crashing later with an IndexError.
if len(sys.argv) < 3:
    print("usage: in.conf out.py")
    sys.exit(1)
# ``layer[<in ids>-><out ids>] = <type>:<name>``
# groups: source ids, target ids, layer type, layer name
LAYER_PATTERN = re.compile(r"layer\[(.*)->(.*)\]\s*=\s*(\w+):(\w*)")
# ``<key> = <value>`` parameter line; groups: key, value.
# The value also accepts '.', '+' and '-' so numbers such as ``0.01`` or
# ``-1e-3`` are captured whole. (A plain ``\w+`` truncated ``0.01`` to ``0``
# and rejected negative values outright.)
PARAM_PATTERN = re.compile(r"\s*(\w+)\s*=\s*([\w.+-]+)\s*")
# One node id from a comma-separated id list. Blanks are not part of an id,
# so ``1, 2`` yields ["1", "2"] rather than ["1", " 2"].
ID_PATTERN = re.compile(r"([^,\s]+)")
# Lines delimiting the network-definition section of the cxxnet config.
CONF_START_PATTERN = re.compile(r"\s*netconfig\s*=\s*start\s*")
CONF_END_PATTERN = re.compile(r"\s*netconfig\s*=\s*end\s*")
# Global conversion state, filled in while the cxxnet config is parsed.
# Maps a cxxnet node id (the ids inside ``layer[a->b]``) to the name of the
# symbol currently bound to it; node "0" is the network input.
id2name = dict([("0", "data")])
# Maps a symbol name to the mxnet constructor expression written for it.
name2def = dict(data="mx.symbol.Variable")
# Maps a symbol name to the list of argument strings for its constructor call.
symbol_param = defaultdict(list)
# Maps a symbol name to the names of the symbols that feed into it.
edge = defaultdict(list)
# Symbol names in declaration order; the output is emitted in this order.
seq = ["data"]
# The most recently declared layer; parameter lines that follow attach to it.
last_name = seq[0]
def ParamFactory(key, value):
    """Translate one cxxnet ``key = value`` setting into an mxnet keyword argument.

    Square-window settings (kernel_size, pad, stride) become ``(v, v)`` tuples.
    nchannel and nhidden are renamed to num_filter and num_hidden. Any other
    key is passed through unchanged as ``key=value``.

    Returns the argument as a source-code string, e.g. ``"kernel=(3, 3)"``.
    """
    renames = {
        "kernel_size": lambda v: "kernel=(%s, %s)" % (v, v),
        "nchannel": lambda v: "num_filter=%s" % v,
        "pad": lambda v: "pad=(%s, %s)" % (v, v),
        "stride": lambda v: "stride=(%s, %s)" % (v, v),
        "nhidden": lambda v: "num_hidden=%s" % v,
    }
    formatter = renames.get(key)
    if formatter is None:
        # Unknown keys are assumed to share mxnet's spelling.
        return "%s=%s" % (key, value)
    return formatter(value)
# Maps a cxxnet layer type to (mxnet constructor expression, extra keyword
# argument or None). "split" has no constructor: a layer with one input and
# several outputs is handled by InOutFactory as a set of aliases of its input,
# so no symbol of its own is emitted.
_CXXNET_LAYERS = {
    "conv": ("mx.symbol.Convolution", None),
    "max_pooling": ("mx.symbol.Pooling", "pool_type='max'"),
    "avg_pooling": ("mx.symbol.Pooling", "pool_type='avg'"),
    "sum_pooling": ("mx.symbol.Pooling", "pool_type='sum'"),
    "relu": ("mx.symbol.Activation", "act_type='relu'"),
    "sigmoid": ("mx.symbol.Activation", "act_type='sigmoid'"),
    "tanh": ("mx.symbol.Activation", "act_type='tanh'"),
    "rrelu": ("mx.symbol.LeakyReLU", "act_type='rrelu'"),
    "batch_norm": ("mx.symbol.BatchNorm", None),
    "ch_concat": ("mx.symbol.Concat", None),
    "flatten": ("mx.symbol.Flatten", None),
    "fullc": ("mx.symbol.FullyConnected", None),
    "softmax": ("mx.symbol.Softmax", None),
    "split": (None, None),
}


def SymbolFactory(layer, name):
    """Return the mxnet constructor expression for cxxnet layer type *layer*.

    Layer types that select a mode of a shared mxnet op (pooling kind,
    activation kind) also append that mode to ``symbol_param[name]``.
    Returns None for ``split``, which produces no symbol of its own.

    Raises ValueError for layer types with no known mxnet equivalent. Without
    this check, the generated script would contain a literal ``None(...)`` call.
    """
    if layer not in _CXXNET_LAYERS:
        raise ValueError(
            "unsupported cxxnet layer type %r (layer %r)" % (layer, name))
    symbol, extra_param = _CXXNET_LAYERS[layer]
    if extra_param is not None:
        symbol_param[name].append(extra_param)
    return symbol
def InOutFactory(in_ids_str, out_ids_str, name):
    """Record how layer *name* connects cxxnet node ids to symbols.

    *in_ids_str* and *out_ids_str* are the two sides of ``layer[in->out]``.

    * One input and several outputs (a split): every output id becomes an
      alias of the input's symbol. No new symbol is declared and no edges are
      recorded.
    * Otherwise: the first output id is bound to *name*, *name* is appended to
      ``seq``, and for each output id the symbols of all input ids are
      recorded in ``edge`` as inputs of that output's symbol.

    Raises ValueError if no output id is given, or if an id is used before
    any layer bound it to a symbol.
    """
    in_ids = ID_PATTERN.findall(in_ids_str)
    out_ids = ID_PATTERN.findall(out_ids_str)
    if not out_ids:
        raise ValueError("layer %r declares no output node id" % name)

    def lookup(node_id):
        # Resolve a node id to its symbol, with a readable error for unknown ids.
        try:
            return id2name[node_id]
        except KeyError:
            raise ValueError(
                "layer %r uses undefined node id %r (known ids: %s)"
                % (name, node_id, ", ".join(sorted(id2name))))

    if len(in_ids) == 1 and len(out_ids) > 1:
        # Split: the outputs are aliases of the input symbol. Recording edges
        # here would list the input symbol as an input of itself.
        source = lookup(in_ids[0])
        for out_id in out_ids:
            id2name[out_id] = source
        return

    # Only the first output id is bound to the new symbol.
    id2name[out_ids[0]] = name
    seq.append(name)
    for out_id in out_ids:
        for in_id in in_ids:
            edge[lookup(out_id)].append(lookup(in_id))
def SymbolBuilder(name):
    """Return the generated Python line that defines symbol *name*.

    The line has the form ``name = <constructor>(<args>)``. The args are
    ``symbol_param[name]`` plus the inputs recorded in ``edge[name]``:

    * no inputs:   ensure a ``name='...'`` argument is present (e.g. ``data``);
    * one input:   pass it as ``data=<input>``;
    * several:     pass them positionally as ``*[a,b,...]`` (used for Concat).

    ``symbol_param[name]`` is read but not modified, so repeated calls return
    the same line.

    Raises ValueError if no mxnet constructor is known for *name*.
    """
    sym = name2def[name]
    if sym is None:
        raise ValueError("no mxnet symbol constructor for %r" % name)
    # Copy so that the shared parameter list is not extended on every call.
    params = list(symbol_param[name])
    inputs = edge[name]
    if not inputs:
        name_param = "name='%s'" % name
        # Parsed layers already carry name=...; repeating the keyword would
        # make the generated call a SyntaxError.
        if name_param not in params:
            params.append(name_param)
    elif len(inputs) == 1:
        params.append("data=%s" % inputs[0])
    else:
        params.append("*[%s]" % ",".join(inputs))
    return "%s = %s(%s)" % (name, sym, ",".join(params))
# Parse the netconfig section (between the ``netconfig = start`` and
# ``netconfig = end`` lines) of the cxxnet config file.
# Each ``layer[in->out] = type:name`` line declares a symbol. The
# ``key = value`` lines after it become that symbol's constructor arguments.
in_conf_flag = False
with open(sys.argv[1]) as fi:
    for line in fi:
        if CONF_START_PATTERN.match(line) is not None:
            in_conf_flag = True
            continue
        if CONF_END_PATTERN.match(line) is not None:
            in_conf_flag = False
        if not in_conf_flag:
            continue
        layer_match = LAYER_PATTERN.match(line)
        if layer_match is not None:
            in_ids_str, out_ids_str, layer, name = layer_match.groups()
            last_name = name
            name2def[name] = SymbolFactory(layer, name)
            InOutFactory(in_ids_str, out_ids_str, name)
            symbol_param[name].append("name='%s'" % name)
            continue
        param_match = PARAM_PATTERN.match(line)
        if param_match is not None:
            key, value = param_match.groups()
            symbol_param[last_name].append(ParamFactory(key, value))
# Render every symbol definition first, so a conversion error does not leave
# a truncated output file behind. Then write the generated script in one go.
generated = ["import mxnet as mx\n"]
generated.extend(SymbolBuilder(name) + "\n" for name in seq)
with open(sys.argv[2], "w") as fo:
    fo.writelines(generated)