-
Notifications
You must be signed in to change notification settings - Fork 99
/
matplotlib_util.py
28 lines (24 loc) · 1.04 KB
/
matplotlib_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from convnet_drawer import *
import matplotlib.pyplot as plt
def _linestyle_for(dasharray):
    """Map a Line's ``dasharray`` code to a matplotlib linestyle string."""
    # 1 -> dotted, 2 -> dashed, anything else -> solid (same mapping as before).
    if dasharray == 1:
        return ":"
    elif dasharray == 2:
        return "--"
    return "-"


def _rgb255_to_unit(color):
    """Scale a color's components by 1/255 for matplotlib's 0-1 color range.

    Assumes ``color`` holds 8-bit (0-255) components -- TODO confirm against
    convnet_drawer's color convention.
    """
    return [c / 255 for c in color]


def save_model_to_file(model, filename):
    """Render ``model`` with matplotlib and save the figure to ``filename``.

    Calls ``model.build()`` first, then draws every ``Line`` and ``Text``
    object found in ``model.feature_maps + model.layers`` onto a single
    axis-less, equal-aspect axes spanning the model's bounding box
    (``model.x``, ``model.y``, ``model.width``, ``model.height``). Objects
    of any other type are skipped.

    Args:
        model: drawable model object from convnet_drawer.
        filename: destination passed to ``Figure.savefig``. The output
            format follows matplotlib's rules (typically the file extension).
    """
    # NOTE(review): build() presumably computes the drawable objects; it must
    # run before the objects are read below.
    model.build()
    fig = plt.figure()
    try:
        ax = fig.add_subplot(111, aspect='equal')
        ax.axis('off')
        ax.set_xlim(model.x, model.x + model.width)
        # Reversed y limits: larger y values are drawn lower on the image,
        # i.e. the model's coordinates use a top-left origin.
        ax.set_ylim(model.y + model.height, model.y)
        for feature_map in model.feature_maps + model.layers:
            for obj in feature_map.objects:
                if isinstance(obj, Line):
                    ax.plot([obj.x1, obj.x2], [obj.y1, obj.y2],
                            color=_rgb255_to_unit(obj.color), lw=obj.width,
                            linestyle=_linestyle_for(obj.dasharray))
                elif isinstance(obj, Text):
                    # Font size is scaled by 2/3 relative to obj.size
                    # (empirical factor kept from the original code).
                    ax.text(obj.x, obj.y, obj.body,
                            horizontalalignment="center",
                            verticalalignment="bottom",
                            size=2 * obj.size / 3,
                            color=_rgb255_to_unit(obj.color))
        fig.savefig(filename)
    finally:
        # Release the figure from pyplot's global registry; otherwise every
        # call leaks one open figure.
        plt.close(fig)