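"""Dash frontend for Data Clinic.

Defines the callbacks that populate model parameters from the content
registry, package experiment results for download, and submit
train/tune/prediction jobs to the compute service.
"""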
import json
import os
import pathlib
import pickle
import shutil
import tempfile
from uuid import uuid4
from dash import Input, Output, State, dcc
from dash_component_editor import JSONParameterEditor
from dotenv import load_dotenv
from file_manager.data_project import DataProject
from src.app_layout import DATA_DIR, TILED_KEY, USER, app, long_callback_manager
from src.callbacks.display import ( # noqa: F401
close_warning_modal,
open_warning_modal,
refresh_bottleneck,
refresh_image,
refresh_reconstruction,
update_slider_boundaries_new_dataset,
update_slider_boundaries_prediction,
)
from src.callbacks.download import disable_download, toggle_storage_modal # noqa: F401
from src.callbacks.execute import close_resources_popup, execute # noqa: F401
from src.callbacks.table import delete_row, update_table # noqa: F401
from src.utils.data_utils import get_input_params, prepare_directories
from src.utils.job_utils import MlexJob, str_to_dict
from src.utils.model_utils import get_gui_components, get_model_content

load_dotenv(".env")

APP_HOST = os.getenv("APP_HOST", "127.0.0.1")
APP_PORT = os.getenv("APP_PORT", "8072")
DIR_MOUNT = os.getenv("DIR_MOUNT", DATA_DIR)

server = app.server


@app.callback(
Output("app-parameters", "children"),
Input("model-selection", "value"),
Input("action", "value"),
    prevent_initial_call=True,
)
def load_parameters(model_selection, action_selection):
"""
This callback dynamically populates the parameters and contents of the website according to the
selected action & model.
Args:
model_selection: Selected model (from content registry)
action_selection: Selected action (pre-defined actions in Data Clinic)
Returns:
app-parameters: Parameters according to the selected model & action
"""
parameters = get_gui_components(model_selection, action_selection)
gui_item = JSONParameterEditor(
_id={"type": str(uuid4())},
json_blob=parameters,
)
gui_item.init_callbacks(app)
return gui_item


@app.long_callback(
Output("download-out", "data"),
Input("download-button", "n_clicks"),
State("jobs-table", "data"),
State("jobs-table", "selected_rows"),
manager=long_callback_manager,
    prevent_initial_call=True,
)
def save_results(download, job_data, row):
"""
This callback saves the experimental results as a ZIP file
Args:
download: Download button
job_data: Table of jobs
row: Selected job/row
Returns:
ZIP file with results
"""
    if download and row:
        experiment_id = job_data[row[0]]["experiment_id"]
        experiment_path = pathlib.Path(f"{DATA_DIR}/mlex_store/{USER}/{experiment_id}")
        # Build the archive inside the managed temporary directory; the original
        # code created a TemporaryDirectory but never used it, writing to the
        # shared system temp dir instead. make_archive appends ".zip" itself,
        # and dcc.send_file reads the file before the directory is cleaned up.
        with tempfile.TemporaryDirectory() as tmp_dir:
            archive_path = os.path.join(tmp_dir, "results")
            shutil.make_archive(archive_path, "zip", experiment_path)
            return dcc.send_file(f"{archive_path}.zip")
    else:
        return None


@app.long_callback(
Output("job-alert-confirm", "is_open"),
Input("submit", "n_clicks"),
State("app-parameters", "children"),
State("num-cpus", "value"),
State("num-gpus", "value"),
State("action", "value"),
State("jobs-table", "data"),
State("jobs-table", "selected_rows"),
State({"base_id": "file-manager", "name": "data-project-dict"}, "data"),
State("model-name", "value"),
State("model-selection", "value"),
State("log-transform", "value"),
State("min-max-percentile", "value"),
    running=[(Output("job-alert", "is_open"), True, False)],
manager=long_callback_manager,
prevent_initial_call=True,
)
def submit_ml_job(
submit,
children,
num_cpus,
num_gpus,
action_selection,
job_data,
row,
data_project_dict,
model_name,
model_id,
log,
percentiles,
):
"""
This callback submits a job request to the compute service according to the selected action & model
Args:
execute: Execute button
submit: Submit button
children: Model parameters
num_cpus: Number of CPUs assigned to job
num_gpus: Number of GPUs assigned to job
action_selection: Action selected
job_data: Lists of jobs
row: Selected row (job)
data_project_dict: Data project information
model_name: Model name/description assigned by the user
model_id: UID of model in content registry
log: Log toggle
percentiles: Min-Max Percentile values
Returns:
open the alert indicating that the job was submitted
"""
data_project = DataProject.from_dict(data_project_dict, api_key=TILED_KEY)
model_uri, [train_cmd, prediction_cmd, tune_cmd] = get_model_content(model_id)
experiment_id, orig_out_path, data_info = prepare_directories(
USER,
data_project,
train=(action_selection != "prediction_model"),
correct_path=(DATA_DIR == DIR_MOUNT),
)
input_params = get_input_params(children)
input_params["log"] = log
input_params["percentiles"] = percentiles
kwargs = {}
# Find the relative data directory in docker container
if DIR_MOUNT == DATA_DIR:
relative_data_dir = "/app/work/data"
out_path = "/app/work/data" + str(orig_out_path).split(DATA_DIR, 1)[-1]
data_info = "/app/work/data" + str(data_info).split(DATA_DIR, 1)[-1]
    else:
        relative_data_dir = DATA_DIR
        # out_path was otherwise undefined on this branch; keep the original
        # output path, since no mount translation applies
        out_path = orig_out_path
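    # Example of the translation above, assuming DATA_DIR="/data" is mounted at
    # /app/work/data inside the container (hypothetical paths):
    #   "/data/mlex_store/user/exp1" -> "/app/work/data/mlex_store/user/exp1"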
if action_selection == "train_model":
command = f"{train_cmd} -d {data_info} -o {out_path}"
elif action_selection == "tune_model":
training_exp_id = job_data[row[0]]["experiment_id"]
model_path = pathlib.Path(
f"{relative_data_dir}/mlex_store/{USER}/{training_exp_id}"
)
kwargs = {"train_params": job_data[row[0]]["parameters"]}
train_params = str_to_dict(job_data[row[0]]["parameters"])
# Get target size from training process
input_params["target_width"] = train_params["target_width"]
input_params["target_height"] = train_params["target_height"]
# Define command to run
command = f"{tune_cmd} -d {data_info} -m {model_path} -o {out_path}"
else:
training_exp_id = job_data[row[0]]["experiment_id"]
model_path = pathlib.Path(
f"{relative_data_dir}/mlex_store/{USER}/{training_exp_id}"
)
if job_data[row[0]]["job_type"] == "train_model":
train_params = job_data[row[0]]["parameters"]
else:
train_params = job_data[row[0]]["parameters"].split("Training Parameters:")[
-1
]
kwargs = {"train_params": train_params}
train_params = str_to_dict(train_params)
# Get target size from training process
input_params["target_width"] = train_params["target_width"]
input_params["target_height"] = train_params["target_height"]
# Define command to run
command = f"{prediction_cmd} -d {data_info} -m {model_path} -o {out_path}"
# Save data project dict
data_project_dict = data_project.to_dict()
with open(f"{orig_out_path}/.file_manager_vars.pkl", "wb") as file:
pickle.dump(
data_project_dict,
file,
)
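    # The pickled dict can later be restored to rebuild the project, e.g.:
    #   with open(f"{orig_out_path}/.file_manager_vars.pkl", "rb") as f:
    #       data_project = DataProject.from_dict(pickle.load(f), api_key=TILED_KEY)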
job = MlexJob(
service_type="backend",
description=model_name,
working_directory=DIR_MOUNT,
job_kwargs={
"uri": model_uri,
"type": "docker",
"cmd": f"{command} -p '{json.dumps(input_params)}'",
"container_kwargs": {"shm_size": "2gb"},
"kwargs": {
"job_type": action_selection,
"experiment_id": experiment_id,
"dataset": data_project.project_id,
"params": input_params,
**kwargs,
},
},
)
job.submit(USER, num_cpus, num_gpus)
return True


if __name__ == "__main__":
    app.run_server(debug=True, host=APP_HOST, port=APP_PORT)
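
# Local run sketch (defaults from the environment variables above): running
# `python frontend.py` serves the app at http://127.0.0.1:8072 unless
# APP_HOST/APP_PORT are set in .env.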