-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathpython.py
46 lines (36 loc) · 1.38 KB
/
python.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
# Generate a Graphviz DOT diagram of a layered (Input -> GRU... -> Output)
# network: one boxed cluster per layer, every node fully connected to every
# node of the next layer.


def _node_id(layer, index):
    """Return the DOT id of node *index* in 1-based *layer*.

    The underscore separator keeps ids unique: without it, layer 1 node 10
    and layer 11 node 0 would both be ``l110``.
    """
    return "l{}_{}".format(layer, index)


def build_dot(layers=(3, 5, 10, 15, 20, 1), penwidth=15, font="Hilda 10"):
    """Return the DOT source for a fully connected layered network.

    Args:
        layers: Node count of each layer, input layer first, output last.
            The first layer is labelled "Input", the last "Output", and any
            layers in between "GRU".
        penwidth: Pen width applied to the nodes of every cluster.
        font: Value for the graph-level ``fontname`` attribute.

    Returns:
        The complete ``digraph G { ... }`` text, terminated by a newline.
    """
    count = len(layers)
    lines = [
        "digraph G {",
        '\tfontname = "{}"'.format(font),
        "\trankdir=LR",
        "\tsplines=line",
        "\tnodesep=.08;",
        "\tranksep=1;",
        "\tedge [color=black, arrowsize=.5];",
        '\tnode [fixedsize=true,label="",style=filled,'
        "color=none,fillcolor=gray,shape=circle]",
        "",
    ]

    # Clusters: one subgraph per layer so each layer is drawn as a labelled
    # column. Input and output layers are filled black, hidden layers gray.
    for i, size in enumerate(layers):
        if i == 0:
            label = "Input"
        elif i == count - 1:
            label = "Output"
        else:
            label = "GRU"
        fill = "black" if i in (0, count - 1) else "gray"

        lines.append("\tsubgraph cluster_{} {{".format(i))
        lines.append("\t\tcolor=none;")
        lines.append("\t\tnode [style=filled, color=white, penwidth={}, "
                     "fillcolor={}, shape=circle];".format(penwidth, fill))
        if size > 0:
            # An empty layer would otherwise emit a bare ";" statement.
            node_ids = " ".join(_node_id(i + 1, a) for a in range(size))
            lines.append("\t\t{};".format(node_ids))
        lines.append('\t\tlabel = "{}";'.format(label))
        lines.append("\t}")
        lines.append("")

    # Edges: connect every node of layer i to every node of layer i + 1
    # (layer numbers in node ids are 1-based).
    for i in range(1, count):
        for a in range(layers[i - 1]):
            for b in range(layers[i]):
                lines.append("\t{} -> {}".format(_node_id(i, a),
                                                 _node_id(i + 1, b)))

    lines.append("}")
    return "\n".join(lines) + "\n"


def main(path="model.txt"):
    """Write the default network diagram to *path*.

    The DOT text is also echoed to stdout, so ``python script.py > out``
    style usage keeps working.
    """
    dot = build_dot()
    with open(path, "w", encoding="utf-8") as out:
        out.write(dot)
    print(dot, end="")


if __name__ == "__main__":
    main()