-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy patheval_vis_server.py
156 lines (131 loc) · 5.2 KB
/
eval_vis_server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import torch
from model import Model
from eval import *
from Doping.utils.Dataset import DataObj
from Doping.utils.utils import calculate_P
import argparse
from flask import Flask, request, send_from_directory
from flask_cors import CORS
from flask import render_template
import json
import glob
import os
import html
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
# Flask application; static_url_path='' is passed through unchanged.
app = Flask(__name__, static_url_path='')
app.config.from_object(__name__)
CORS(app)

# The server only runs inference, so autograd is switched off.
# A bare ``torch.no_grad()`` call merely constructs a context manager and
# throws it away; ``set_grad_enabled(False)`` actually changes the mode.
# NOTE(review): PyTorch documents grad mode as thread-local, so request
# handlers that run on other threads should still guard themselves.
torch.set_grad_enabled(False)

# Command-line options. If --matrix-path is provided, we are visualizing
# the X matrix.
parser = argparse.ArgumentParser()
parser.add_argument('--model-path', help='path to the .pt file')
parser.add_argument('--matrix-path', help='path to the X.json, L.json files')
parser.add_argument('--dataset-path', help='path to the folder containing multiple matrix path')
parser.add_argument('--port', type=int, default=8080, help='port')
parser.add_argument('--limit', type=int, default=50, help='visualize from t_0 to t_limit')
# Parsed at import time because the route definitions below depend on it.
args = parser.parse_args()
def plot(folder):
    """Build (or reuse) the ``XPM.svg`` preview image for one folder.

    The image has three panels: the X matrix, the P matrix returned by
    ``calculate_P``, and the X matrix cast to bool.

    Args:
        folder: directory expected to contain ``L.json``, ``L_freq.json``
            and ``X00000.json`` (the latter with an ``"X"`` key).

    Returns:
        The path ``<folder>/XPM.svg`` on success, or a fixed error string
        when any of the three input files cannot be read. The error is
        returned rather than raised, so callers receive a string either way.
    """
    image_name = "XPM.svg"
    image_path = os.path.join(folder, image_name)
    # Immediately return if the image was previously built.
    if os.path.isfile(image_path):
        return image_path
    try:
        with open(os.path.join(folder, "L.json"), "r") as f:
            L = json.load(f)
        with open(os.path.join(folder, "L_freq.json"), "r") as f:
            L_freq = json.load(f)
        with open(os.path.join(folder, "X00000.json"), "r") as f:
            last_X = json.load(f)["X"]
    except Exception:
        # Deliberate best-effort: a broken folder must not abort the caller.
        return "error in reading either L.json, L_freq.json, or X0***.json"
    last_P = calculate_P(last_X, L, L_freq)
    X_array = np.asarray(last_X)
    fig, (ax0, ax1, ax2) = plt.subplots(1, 3)
    try:
        # NOTE(review): interpolation=None selects matplotlib's configured
        # default; it is not the same as the string 'none'. Kept as-is.
        ax0.imshow(X_array, interpolation=None)
        ax1.imshow(np.asarray(last_P), interpolation=None)
        ax2.imshow(X_array.astype(bool), interpolation=None)
        # Save this specific figure instead of pyplot's global "current
        # figure", which another request could have replaced meanwhile.
        fig.savefig(image_path, dpi=1000, bbox_inches='tight')
    finally:
        # clf() only empties the figure; close() also drops it from
        # pyplot's registry so figures do not pile up across calls.
        plt.close(fig)
    return image_path
def detailed_plot(folder):
    """Assemble the JSON-serialisable payload describing one folder.

    Reads ``L.json``, ``L_freq.json`` and the ``"X"`` entry of
    ``X00000.json`` from ``folder``, derives P via ``calculate_P`` and
    returns a dict with the keys ``Xs``, ``Ps``, ``L_size``, ``no_of_X``
    and ``L``. If any file cannot be read, a fixed error string is returned
    instead of a dict.
    """
    contents = {}
    try:
        # Same read order as the file names listed in the error message.
        for file_name in ("L.json", "L_freq.json", "X00000.json"):
            with open(os.path.join(folder, file_name), "r") as handle:
                contents[file_name] = json.load(handle)
        labels = contents["L.json"]
        label_freq = contents["L_freq.json"]
        newest_X = contents["X00000.json"]["X"]
    except Exception:
        return "error in reading either L.json, L_freq.json, or X0***.json"
    newest_P = calculate_P(newest_X, labels, label_freq).tolist()
    # Invert the mapping in L.json (key -> value) and HTML-escape the keys.
    index_to_label = {labels[name]: html.escape(name) for name in labels}
    return {
        "Xs": [newest_X],
        "Ps": [newest_P],
        "L_size": len(labels),
        "no_of_X": 1,
        "L": index_to_label,
    }
@app.route('/static/<path:path>')
def send_img(path):
    """Serve the file at ``path`` via ``send_from_directory``.

    NOTE(review): the base directory is the empty string, so which files
    are reachable depends on how Flask resolves it - confirm the exposure
    is intended.
    """
    response = send_from_directory('', path)
    return response
if args.model_path is not None:
    # Model mode: load the model once at start-up and reuse it per request.
    # setup_model / run are presumably provided by ``from eval import *``
    # - TODO confirm.
    model, dataObj, model_metadata = setup_model(args.model_path)

    @app.route('/vis/<h_index>', methods=['GET'])
    def handle_vis(h_index):
        """Run the model and render ``vis.html`` with the resulting data.

        Args:
            h_index: URL segment, forwarded untouched to the template.
        """
        # Disable autograd explicitly for this call: a bare module-level
        # ``torch.no_grad()`` statement has no effect, and grad mode is
        # per-thread, so the guard belongs where inference happens.
        with torch.no_grad():
            json_vis_data = run(model, dataObj, model_metadata)
        return render_template('vis.html', context=json.dumps(json_vis_data), h_index=h_index)
if args.matrix_path is not None:
    @app.route('/vis', methods=['GET'])
    def handle_vis():
        """Render ``matrix_vis.html`` for the X/P snapshots in --matrix-path.

        Loads up to ``args.limit`` files matching ``X00*.json`` and
        ``P00*.json`` (in sorted order), plus ``L.json``, and passes them to
        the template as a JSON string.

        Raises:
            ValueError: if the number of loaded X and P snapshots differs.
        """
        def _load_snapshots(paths, key):
            # Read ``key`` from each JSON file, logging the path as we go.
            loaded = []
            for path in paths:
                print(path)
                with open(path, "r") as f:
                    loaded.append(json.load(f)[key])
            return loaded

        Xs = sorted(glob.glob(args.matrix_path + "/X00*.json"))
        Ps = sorted(glob.glob(args.matrix_path + "/P00*.json"))
        limit = min(len(Xs), args.limit)

        Xs_data = _load_snapshots(Xs[:limit], "X")
        Ps_data = _load_snapshots(Ps[:limit], "P")

        with open(os.path.join(args.matrix_path, "L.json"), "r") as f:
            L = json.load(f)
        # Invert the mapping in L.json (key -> value) and HTML-escape keys.
        id2L = {L[k]: html.escape(k) for k in L}

        # With no X snapshot on disk there is no matrix to measure; fall
        # back to len(L) instead of failing on Xs_data[-1].
        L_size = len(Xs_data[-1]) if Xs_data else len(L)

        # Explicit check rather than ``assert``, which vanishes under -O.
        if len(Xs_data) != len(Ps_data):
            raise ValueError(
                "mismatched snapshot counts: {} X files vs {} P files".format(
                    len(Xs_data), len(Ps_data)))

        json_vis_data = {"Xs": Xs_data, "Ps": Ps_data, "L_size": L_size, "no_of_X": len(Xs_data), "L": id2L}
        return render_template('matrix_vis.html', context=json.dumps(json_vis_data))
if args.dataset_path is not None:
    @app.route('/vis', methods=['GET'])
    def handle_vis():
        """Render ``dataset_vis.html`` with one ``plot`` result per folder.

        Folders are those matching ``<dataset-path>/*/ind_gen_files``,
        sorted, capped at ``args.limit``.
        """
        pattern = args.dataset_path + "/*/ind_gen_files"
        folders = sorted(glob.glob(pattern))
        previews = {}
        for folder in folders[:args.limit]:
            print(folder)
            # plot() hands back either an image path or an error string.
            previews[folder] = plot(folder)
        context = json.dumps(previews)
        return render_template('dataset_vis.html', context=context)
@app.route('/detailed_vis/<ind_gen_path>', methods=['GET'])
def handle_detailed_vis(ind_gen_path):
    """Render ``matrix_vis.html`` for a single folder named in the URL.

    Args:
        ind_gen_path: folder path with every ``/`` encoded as ``___``.

    Returns:
        The rendered template, or a plain-text 404 response when the
        decoded path does not lie inside ``--dataset-path``.
    """
    print(ind_gen_path)
    ind_gen_path = ind_gen_path.replace("___", "/")

    # The path comes straight from the URL, so confine it to the dataset
    # directory before any file is opened (blocks ``..`` and absolute paths).
    if args.dataset_path is None:
        return "detailed view is only available with --dataset-path", 404
    # A trailing separator on both sides keeps "/data/ds2" from passing as
    # a child of "/data/ds".
    allowed_root = os.path.join(os.path.abspath(args.dataset_path), "")
    requested = os.path.join(os.path.abspath(ind_gen_path), "")
    if not requested.startswith(allowed_root):
        return "requested folder is outside the dataset path", 404

    json_vis_data = detailed_plot(ind_gen_path)
    return render_template('matrix_vis.html', context=json.dumps(json_vis_data))
if __name__ == '__main__':
    # Start Flask's built-in server on the port given by --port,
    # bound to 0.0.0.0.
    app.run(port=args.port, host='0.0.0.0')