From 77332e4db0389f98cb4c70c62c7794986ab5baa3 Mon Sep 17 00:00:00 2001 From: zhuowenzhao Date: Fri, 1 Mar 2024 16:11:57 -0800 Subject: [PATCH 01/21] :wrench: added new pkg required for using dash_component_editor --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 7a2fbcc..fac494c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -33,3 +33,4 @@ scipy dash-extensions==1.0.1 dash-bootstrap-components==1.5.0 dash_auth==2.0.0 +dash_daq==0.1.0 From 0a51ed006873a81bbcd9235aadffc7860f810d65 Mon Sep 17 00:00:00 2001 From: zhuowenzhao Date: Fri, 1 Mar 2024 16:13:30 -0800 Subject: [PATCH 02/21] :sparkles: added content registry example --- utils/content_registry.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 utils/content_registry.py diff --git a/utils/content_registry.py b/utils/content_registry.py new file mode 100644 index 0000000..4e1c12e --- /dev/null +++ b/utils/content_registry.py @@ -0,0 +1,30 @@ +import json +from copy import deepcopy + +class Models: + def __init__(self, modelfile_path='./assets/mode_description.json'): + self.path = modelfile_path + f = open('./assets/mode_description.json') + + self.contents = json.load(f)['contents'] + self.modelname_list = [content['model_name'] for content in self.contents] + self.models = {} + + for i, n in enumerate(self.modelname_list): + self.models[n] = self.contents[i] + + @staticmethod + def remove_key_from_dict_list(data, key): + new_data = [] + for item in data: + if key in item: + new_item = deepcopy(item) + new_item.pop(key) + new_data.append(new_item) + else: + new_data.append(item) + + return new_data + + +models = Models() \ No newline at end of file From c02246d189e927f11f3c323e91e60f74f08ee3dc Mon Sep 17 00:00:00 2001 From: zhuowenzhao Date: Fri, 1 Mar 2024 16:14:41 -0800 Subject: [PATCH 03/21] :sparkles: added a model description example --- assets/mode_description.json | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100755 assets/mode_description.json diff --git a/assets/mode_description.json b/assets/mode_description.json new file mode 100755 index 0000000..f394cdc --- /dev/null +++ b/assets/mode_description.json @@ -0,0 +1,6 @@ +{ + "contents":[ + {"model_name": "random_forest", "version": "1.0.0", "type": "supervised", "user": "mlexchange team", "uri": "xxx", "application": ["classification", "segmentation"], "description": "xxx", "gui_parameters": [{"type": "int", "name": "num-tree", "title": "Number of Trees", "param_key": "n_estimators", "value": "30"}, {"type": "int", "name": "tree-depth", "title": "Tree Depth", "param_key": "max_depth", "value": "8"}], "cmd": ["xxx"], "reference": "Adapted from Dash Plotly image segmentation example"}, + {"model_name": "kmeans", "version": "1.0.0", "type": "unsupervised", "user": "mlexchange team", "uri": "xxx", "application": ["segmentation", "clustering"], "description": "xxx", "gui_parameters": [{"type": "int", "name": "num-cluster", "title": "Number of Clusters", "param_key": "n_clusters", "value": "2"}, {"type": "int", "name": "num-iter", "title": "Max Iteration", "param_key": "max_iter", "value": "300"}], "cmd": ["xxx", "xxxx"], "reference": "Nicholas Schwartz & Howard Yanxon"} + ] +} \ No newline at end of file From a9b9b3f8468564e077492a690aec0a7ac567af7e Mon Sep 17 00:00:00 2001 From: zhuowenzhao Date: Fri, 1 Mar 2024 16:15:30 -0800 Subject: [PATCH 04/21] :sparkles: added automatic dash gui generator --- app.py | 21 ++ components/control_bar.py | 18 ++ components/dash_component_editor.py | 408 ++++++++++++++++++++++++++++ 3 files changed, 447 insertions(+) create mode 100644 components/dash_component_editor.py diff --git a/app.py b/app.py index c8514d9..b14303d 100644 --- a/app.py +++ b/app.py @@ -10,6 +10,9 @@ from components.control_bar import layout as control_bar_layout from components.image_viewer import layout as image_viewer_layout +from utils.content_registry import models +from components.dash_component_editor import JSONParameterEditor + USER_NAME = os.getenv("USER_NAME") USER_PASSWORD = os.getenv("USER_PASSWORD") @@ -31,8 +34,26 @@ control_bar_layout(), image_viewer_layout(), dcc.Store(id="current-class-selection", data="#FFA200"), + dcc.Store(id="gui-components-values", data={}) ], ) +### automatic Dash gui callback ### +@callback( + Output("gui-layouts", "children"), + Input("model-list", "value"), +) +def update_gui_parameters(model_name): + data = models.models[model_name] + if data["gui_parameters"]: + item_list = JSONParameterEditor( _id={'type': str(uuid.uuid4())}, # pattern match _id (base id), name + json_blob=models.remove_key_from_dict_list(data["gui_parameters"], "comp_group"), + ) + item_list.init_callbacks(app) + return [html.H4("Model Parameters"), item_list] + else: + return[""] + + if __name__ == "__main__": app.run_server(debug=True) diff --git a/components/control_bar.py b/components/control_bar.py index b4ab2b0..08c65ab 100644 --- a/components/control_bar.py +++ b/components/control_bar.py @@ -7,6 +7,7 @@ from components.annotation_class import annotation_class_item from constants import ANNOT_ICONS, KEYBINDS from utils.data_utils import tiled_dataset +from utils.content_registry import models def _tooltip(text, children): @@ -603,6 +604,23 @@ def layout(): "run-model", id="model-configuration", children=[ + _control_item( + "Model Selection", + "model-selector", + dmc.Select( + id="model-list", + data=models.modelname_list, + value=( + models.modelname_list[0] + if models.modelname_list[0] + else None + ), + placeholder="Select an model...", + ), + ), + dmc.Space(h=25), + html.Div(id="gui-layouts"), + dmc.Space(h=25), dmc.Center( dmc.Button( "Run model", diff --git a/components/dash_component_editor.py b/components/dash_component_editor.py new file mode 100644 index 0000000..aecc456 --- /dev/null +++ b/components/dash_component_editor.py @@ -0,0 +1,408 @@ +import re +from typing import Callable +# noinspection PyUnresolvedReferences +from inspect import signature, _empty + +from dash import html, dcc, dash_table, Input, Output, State, MATCH, ALL +import dash_bootstrap_components as dbc +import dash_daq as daq + +import base64 +#import PIL.Image +import io +#import plotly.express as px +# Procedural dash form generation + + +""" +{'name', 'title', 'value', 'type', +""" + + +class SimpleItem(dbc.Col): + def __init__(self, + name, + base_id, + title=None, + param_key=None, + type='number', + debounce=True, + **kwargs): + + if param_key == None: + param_key = name + self.label = dbc.Label(title) + self.input = dbc.Input(type=type, + debounce=debounce, + id={**base_id, + 'name': name, + 'param_key': param_key}, + **kwargs) + + super(SimpleItem, self).__init__(children=[self.label, self.input]) + + +class FloatItem(SimpleItem): + pass + + +class IntItem(SimpleItem): + def __init__(self, *args, **kwargs): + if 'min' not in kwargs: + kwargs['min'] = -9007199254740991 + super(IntItem, self).__init__(*args, step=1, **kwargs) + + +class StrItem(SimpleItem): + def __init__(self, *args, **kwargs): + super(StrItem, self).__init__(*args, type='text', **kwargs) + + +class SliderItem(dbc.Col): + def __init__(self, + name, + base_id, + title=None, + param_key=None, + debounce=True, + visible=True, + **kwargs): + + if param_key == None: + param_key = name + self.label = dbc.Label(title) + self.input = dcc.Slider(id={**base_id, + 'name': name, + 'param_key': param_key, + 'layer': 'input'}, + tooltip={"placement": "bottom", "always_visible": True}, + **kwargs) + + style = {} + if not visible: + style['display'] = 'none' + + super(SliderItem, self).__init__(id={**base_id, + 'name': name, + 'param_key': param_key, + 'layer': 'form_group'}, + children=[self.label, self.input], + style=style) + + +class DropdownItem(dbc.Col): + def __init__(self, + name, + base_id, + title=None, + param_key=None, + debounce=True, + visible=True, + **kwargs): + + if param_key == None: + param_key = name + self.label = dbc.Label(title) + self.input = dcc.Dropdown(id={**base_id, + 'name': name, + 'param_key': param_key, + 'layer': 'input'}, + **kwargs) + + style = {} + if not visible: + style['display'] = 'none' + + super(DropdownItem, self).__init__(id={**base_id, + 'name': name, + 'param_key': param_key, + 'layer': 'form_group'}, + children=[self.label, self.input], + style=style) + + +class RadioItem(dbc.Col): + def __init__(self, + name, + base_id, + title=None, + param_key=None, + visible=True, + **kwargs): + + if param_key == None: + param_key = name + self.label = dbc.Label(title) + self.input = dbc.RadioItems(id={**base_id, + 'name': name, + 'param_key': param_key, + 'layer': 'input'}, + **kwargs) + + style = {} + if not visible: + style['display'] = 'none' + + super(RadioItem, self).__init__(id={**base_id, + 'name': name, + 'param_key': param_key, + 'layer': 'form_group'}, + children=[self.label, self.input], + style=style) + + +class BoolItem(dbc.Col): + def __init__(self, + name, + base_id, + title=None, + param_key=None, + visible=True, + **kwargs): + + if param_key == None: + param_key = name + self.label = dbc.Label(title) + self.input = daq.ToggleSwitch(id={**base_id, + 'name': name, + 'param_key': param_key, + 'layer': 'input'}, + **kwargs) + self.output_label = dbc.Label('False/True') + + style = {} + if not visible: + style['display'] = 'none' + + super(BoolItem, self).__init__(id={**base_id, + 'name': name, + 'param_key': param_key, + 'layer': 'form_group'}, + children=[self.label, self.input, self.output_label], + style=style) + + +class ImgItem(dbc.Col): + def __init__(self, + name, + src, + base_id, + title=None, + param_key=None, + width='100px', + visible=True, + **kwargs): + + if param_key == None: + param_key = name + + if not (width.endswith('px') or width.endswith('%')): + width = width + 'px' + + self.label = dbc.Label(title) + + encoded_image = base64.b64encode(open(src, 'rb').read()) + self.src = 'data:image/png;base64,{}'.format(encoded_image.decode()) + self.input_img = html.Img(id={**base_id, + 'name': name, + 'param_key': param_key, + 'layer': 'input'}, + src=self.src, + style={'height':'auto', 'width':width}, + **kwargs) + + style = {} + if not visible: + style['display'] = 'none' + + super(ImgItem, self).__init__(id={**base_id, + 'name': name, + 'param_key': param_key, + 'layer': 'form_group'}, + children=[self.label, self.input_img], + style=style) + + +# class GraphItem(dbc.Col): +# def __init__(self, +# name, +# base_id, +# title=None, +# param_key=None, +# visible=True, +# figure = None, +# **kwargs): +# +# self.name = name +# if param_key == None: +# param_key = name +# self.label = dbc.Label(title) +# self.input_graph = dcc.Graph(id={**base_id, +# 'name': name, +# 'param_key': param_key, +# 'layer': 'input'}, +# **kwargs) +# +# self.input_upload = dcc.Upload(id={**base_id, +# 'name': name+'_upload', +# 'param_key': param_key, +# 'layer': 'input'}, +# children=html.Div([ +# 'Drag and Drop or ', +# html.A('Select Files') +# ]), +# style={ +# 'width': '95%', +# 'height': '60px', +# 'lineHeight': '60px', +# 'borderWidth': '1px', +# 'borderStyle': 'dashed', +# 'borderRadius': '5px', +# 'textAlign': 'center', +# 'margin': '10px' +# }, +# multiple = False) +# +# style = {} +# if not visible: +# style['display'] = 'none' +# +# super(GraphItem, self).__init__(id={**base_id, +# 'name': name, +# 'param_key': param_key, +# 'layer': 'form_group'}, +# children=[self.label, self.input_upload, self.input_graph], +# style=style) +# +# # Issue: cannot get inputs from the callback decorator +# def return_upload(self, *args): +# print(f'before if, args {args}') +# if args: +# print(f'args {args}') +# img_bytes = base64.b64decode(contents.split(",")[1]) +# img = PIL.Image.open(io.BytesIO(img_bytes)) +# fig = px.imshow(img, binary_string=True) +# return fig +# +# def init_callbacks(self, app): +# app.callback(Output({**self.id, +# 'name': self.name, +# 'layer': 'input'}, 'figure', allow_duplicate=True), +# Input({**self.id, +# 'name': self.name+'_upload', +# 'layer': 'input'}, +# 'contents'), +# State({**self.id, +# 'name': self.name+'_upload', +# 'layer': 'input'}, 'last_modified'), +# State({**self.id, +# 'name': self.name+'_upload', +# 'layer': 'input'}, 'filename'), +# prevent_initial_call=True)(self.return_upload()) + + + +class ParameterEditor(dbc.Form): + + type_map = {float: FloatItem, + int: IntItem, + str: StrItem, + } + + def __init__(self, _id, parameters, **kwargs): + self._parameters = parameters + + super(ParameterEditor, self).__init__(id=_id, children=[], className='kwarg-editor', **kwargs) + self.children = self.build_children() + + def init_callbacks(self, app): + app.callback(Output(self.id, 'n_submit'), + Input({**self.id, + 'name': ALL}, + 'value'), + State(self.id, 'n_submit'), + ) + + for child in self.children: + if hasattr(child,"init_callbacks"): + child.init_callbacks(app) + + + @property + def values(self): + return {param['name']: param.get('value', None) for param in self._parameters} + + @property + def parameters(self): + return {param['name']: param for param in self._parameters} + + def _determine_type(self, parameter_dict): + if 'type' in parameter_dict: + if parameter_dict['type'] in self.type_map: + return parameter_dict['type'] + elif parameter_dict['type'].__name__ in self.type_map: + return parameter_dict['type'].__name__ + elif type(parameter_dict['value']) in self.type_map: + return type(parameter_dict['value']) + raise TypeError(f'No item type could be determined for this parameter: {parameter_dict}') + + def build_children(self, values=None): + children = [] + for parameter_dict in self._parameters: + parameter_dict = parameter_dict.copy() + if values and parameter_dict['name'] in values: + parameter_dict['value'] = values[parameter_dict['name']] + type = self._determine_type(parameter_dict) + parameter_dict.pop('type', None) + item = self.type_map[type](**parameter_dict, base_id=self.id) + children.append(item) + + return children + + +class JSONParameterEditor(ParameterEditor): + type_map = {'float': FloatItem, + 'int': IntItem, + 'str': StrItem, + 'slider': SliderItem, + 'dropdown': DropdownItem, + 'radio': RadioItem, + 'bool': BoolItem, + 'img': ImgItem, + #'graph': GraphItem, + } + + def __init__(self, _id, json_blob, **kwargs): + super(ParameterEditor, self).__init__(id=_id, children=[], className='kwarg-editor', **kwargs) + self._json_blob = json_blob + self.children = self.build_children() + + def build_children(self, values=None): + children = [] + for json_record in self._json_blob: + ... + # build a parameter dict from self.json_blob + ... + type = json_record.get('type', self._determine_type(json_record)) + json_record = json_record.copy() + if values and json_record['name'] in values: + json_record['value'] = values[json_record['name']] + json_record.pop('type', None) + item = self.type_map[type](**json_record, base_id=self.id) + children.append(item) + + return children + + +class KwargsEditor(ParameterEditor): + def __init__(self, instance_index, func: Callable, **kwargs): + self.func = func + self._instance_index = instance_index + + parameters = [{'name': name, 'value': param.default} for name, param in signature(func).parameters.items() + if param.default is not _empty] + + super(KwargsEditor, self).__init__(dict(index=instance_index, type='kwargs-editor'), parameters=parameters, **kwargs) + + def new_record(self): + return {name: p.default for name, p in signature(self.func).parameters.items() if p.default is not _empty} From 28b2edcd5b0ec3747214f00da502b5e6a1776ef2 Mon Sep 17 00:00:00 2001 From: zhuowenzhao Date: Fri, 1 Mar 2024 16:16:24 -0800 Subject: [PATCH 05/21] :sparkles: added callback to retrieve model paramters from gui --- callbacks/segmentation.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/callbacks/segmentation.py b/callbacks/segmentation.py index 69a206c..588c1ce 100644 --- a/callbacks/segmentation.py +++ b/callbacks/segmentation.py @@ -57,13 +57,15 @@ @callback( Output("output-details", "children"), - Output("submitted-job-id", "data"), + Output("submitted-job-id", "data"), + Output("gui-components-values", "data"), Input("run-model", "n_clicks"), State("annotation-store", "data"), State({"type": "annotation-class-store", "index": ALL}, "data"), State("project-name-src", "value"), + State("gui-layouts", "children") ) -def run_job(n_clicks, global_store, all_annotations, project_name): +def run_job(n_clicks, global_store, all_annotations, project_name, children): """ This callback collects parameters from the UI and submits a job to the computing api. If the app is run from "dev" mode, then only a placeholder job_uid will be created. @@ -72,7 +74,17 @@ def run_job(n_clicks, global_store, all_annotations, project_name): # TODO: Appropriately paramaterize the DEMO_WORKFLOW json depending on user inputs and relevant file paths """ + input_params = {} if n_clicks: + if len(children) >= 2: + params = children[1] + for param in params['props']['children']: + key = param["props"]["children"][1]["props"]["id"]["param_key"] + value = param["props"]["children"][1]["props"]["value"] + input_params[key] = value + + # return the input values in dictionary and saved to dcc.Store "gui-components-values" + print(f'input_param:\n{input_params}') if MODE == "dev": job_uid = str(uuid.uuid4()) return ( @@ -81,6 +93,7 @@ def run_job(n_clicks, global_store, all_annotations, project_name): size="sm", ), job_uid, + input_params ) else: @@ -98,6 +111,7 @@ def run_job(n_clicks, global_store, all_annotations, project_name): size="sm", ), job_uid, + input_params ) else: return ( @@ -106,8 +120,9 @@ def run_job(n_clicks, global_store, all_annotations, project_name): size="sm", ), job_uid, + input_params ) - return no_update, no_update + return no_update, no_update, input_params @callback( From a7bd9267e43b83044c7bded7fed308cd93f17f2c Mon Sep 17 00:00:00 2001 From: zhuowenzhao Date: Fri, 1 Mar 2024 16:30:56 -0800 Subject: [PATCH 06/21] :wrench: cleaned lb import --- components/dash_component_editor.py | 82 +---------------------------- 1 file changed, 2 insertions(+), 80 deletions(-) diff --git a/components/dash_component_editor.py b/components/dash_component_editor.py index aecc456..7fa677a 100644 --- a/components/dash_component_editor.py +++ b/components/dash_component_editor.py @@ -3,7 +3,7 @@ # noinspection PyUnresolvedReferences from inspect import signature, _empty -from dash import html, dcc, dash_table, Input, Output, State, MATCH, ALL +from dash import html, dcc, Input, Output, State, ALL import dash_bootstrap_components as dbc import dash_daq as daq @@ -14,6 +14,7 @@ # Procedural dash form generation + """ {'name', 'title', 'value', 'type', """ @@ -223,85 +224,6 @@ def __init__(self, style=style) -# class GraphItem(dbc.Col): -# def __init__(self, -# name, -# base_id, -# title=None, -# param_key=None, -# visible=True, -# figure = None, -# **kwargs): -# -# self.name = name -# if param_key == None: -# param_key = name -# self.label = dbc.Label(title) -# self.input_graph = dcc.Graph(id={**base_id, -# 'name': name, -# 'param_key': param_key, -# 'layer': 'input'}, -# **kwargs) -# -# self.input_upload = dcc.Upload(id={**base_id, -# 'name': name+'_upload', -# 'param_key': param_key, -# 'layer': 'input'}, -# children=html.Div([ -# 'Drag and Drop or ', -# html.A('Select Files') -# ]), -# style={ -# 'width': '95%', -# 'height': '60px', -# 'lineHeight': '60px', -# 'borderWidth': '1px', -# 'borderStyle': 'dashed', -# 'borderRadius': '5px', -# 'textAlign': 'center', -# 'margin': '10px' -# }, -# multiple = False) -# -# style = {} -# if not visible: -# style['display'] = 'none' -# -# super(GraphItem, self).__init__(id={**base_id, -# 'name': name, -# 'param_key': param_key, -# 'layer': 'form_group'}, -# children=[self.label, self.input_upload, self.input_graph], -# style=style) -# -# # Issue: cannot get inputs from the callback decorator -# def return_upload(self, *args): -# print(f'before if, args {args}') -# if args: -# print(f'args {args}') -# img_bytes = base64.b64decode(contents.split(",")[1]) -# img = PIL.Image.open(io.BytesIO(img_bytes)) -# fig = px.imshow(img, binary_string=True) -# return fig -# -# def init_callbacks(self, app): -# app.callback(Output({**self.id, -# 'name': self.name, -# 'layer': 'input'}, 'figure', allow_duplicate=True), -# Input({**self.id, -# 'name': self.name+'_upload', -# 'layer': 'input'}, -# 'contents'), -# State({**self.id, -# 'name': self.name+'_upload', -# 'layer': 'input'}, 'last_modified'), -# State({**self.id, -# 'name': self.name+'_upload', -# 'layer': 'input'}, 'filename'), -# prevent_initial_call=True)(self.return_upload()) - - - class ParameterEditor(dbc.Form): type_map = {float: FloatItem, From 10dab134c4151eb2ee4f48b587790dcde7dc7e4c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Mon, 4 Mar 2024 17:55:28 -0800 Subject: [PATCH 07/21] Apply `pre-commit run --all files` --- app.py | 19 +- assets/mode_description.json | 6 +- callbacks/segmentation.py | 20 +- components/control_bar.py | 2 +- components/dash_component_editor.py | 389 ++++++++++++++-------------- utils/content_registry.py | 21 +- 6 files changed, 231 insertions(+), 226 deletions(-) diff --git a/app.py b/app.py index b14303d..7b9538c 100644 --- a/app.py +++ b/app.py @@ -8,10 +8,9 @@ from callbacks.image_viewer import * # noqa: F403, F401 from callbacks.segmentation import * # noqa: F403, F401 from components.control_bar import layout as control_bar_layout +from components.dash_component_editor import JSONParameterEditor from components.image_viewer import layout as image_viewer_layout - from utils.content_registry import models -from components.dash_component_editor import JSONParameterEditor USER_NAME = os.getenv("USER_NAME") USER_PASSWORD = os.getenv("USER_PASSWORD") @@ -34,10 +33,11 @@ control_bar_layout(), image_viewer_layout(), dcc.Store(id="current-class-selection", data="#FFA200"), - dcc.Store(id="gui-components-values", data={}) + dcc.Store(id="gui-components-values", data={}), ], ) + ### automatic Dash gui callback ### @callback( Output("gui-layouts", "children"), @@ -45,14 +45,17 @@ ) def update_gui_parameters(model_name): data = models.models[model_name] - if data["gui_parameters"]: - item_list = JSONParameterEditor( _id={'type': str(uuid.uuid4())}, # pattern match _id (base id), name - json_blob=models.remove_key_from_dict_list(data["gui_parameters"], "comp_group"), - ) + if data["gui_parameters"]: + item_list = JSONParameterEditor( + _id={"type": str(uuid.uuid4())}, # pattern match _id (base id), name + json_blob=models.remove_key_from_dict_list( + data["gui_parameters"], "comp_group" + ), + ) item_list.init_callbacks(app) return [html.H4("Model Parameters"), item_list] else: - return[""] + return [""] if __name__ == "__main__": diff --git a/assets/mode_description.json b/assets/mode_description.json index f394cdc..7c95a20 100755 --- a/assets/mode_description.json +++ b/assets/mode_description.json @@ -1,6 +1,6 @@ -{ +{ "contents":[ {"model_name": "random_forest", "version": "1.0.0", "type": "supervised", "user": "mlexchange team", "uri": "xxx", "application": ["classification", "segmentation"], "description": "xxx", "gui_parameters": [{"type": "int", "name": "num-tree", "title": "Number of Trees", "param_key": "n_estimators", "value": "30"}, {"type": "int", "name": "tree-depth", "title": "Tree Depth", "param_key": "max_depth", "value": "8"}], "cmd": ["xxx"], "reference": "Adapted from Dash Plotly image segmentation example"}, {"model_name": "kmeans", "version": "1.0.0", "type": "unsupervised", "user": "mlexchange team", "uri": "xxx", "application": ["segmentation", "clustering"], "description": "xxx", "gui_parameters": [{"type": "int", "name": "num-cluster", "title": "Number of Clusters", "param_key": "n_clusters", "value": "2"}, {"type": "int", "name": "num-iter", "title": "Max Iteration", "param_key": "max_iter", "value": "300"}], "cmd": ["xxx", "xxxx"], "reference": "Nicholas Schwartz & Howard Yanxon"} - ] -} \ No newline at end of file + ] +} diff --git a/callbacks/segmentation.py b/callbacks/segmentation.py index 588c1ce..c477bcb 100644 --- a/callbacks/segmentation.py +++ b/callbacks/segmentation.py @@ -57,13 +57,13 @@ @callback( Output("output-details", "children"), - Output("submitted-job-id", "data"), + Output("submitted-job-id", "data"), Output("gui-components-values", "data"), Input("run-model", "n_clicks"), State("annotation-store", "data"), State({"type": "annotation-class-store", "index": ALL}, "data"), State("project-name-src", "value"), - State("gui-layouts", "children") + State("gui-layouts", "children"), ) def run_job(n_clicks, global_store, all_annotations, project_name, children): """ @@ -78,13 +78,13 @@ def run_job(n_clicks, global_store, all_annotations, project_name, children): if n_clicks: if len(children) >= 2: params = children[1] - for param in params['props']['children']: - key = param["props"]["children"][1]["props"]["id"]["param_key"] + for param in params["props"]["children"]: + key = param["props"]["children"][1]["props"]["id"]["param_key"] value = param["props"]["children"][1]["props"]["value"] input_params[key] = value - - # return the input values in dictionary and saved to dcc.Store "gui-components-values" - print(f'input_param:\n{input_params}') + + # return the input values in dictionary and saved to dcc.Store "gui-components-values" + print(f"input_param:\n{input_params}") if MODE == "dev": job_uid = str(uuid.uuid4()) return ( @@ -93,7 +93,7 @@ def run_job(n_clicks, global_store, all_annotations, project_name, children): size="sm", ), job_uid, - input_params + input_params, ) else: @@ -111,7 +111,7 @@ def run_job(n_clicks, global_store, all_annotations, project_name, children): size="sm", ), job_uid, - input_params + input_params, ) else: return ( @@ -120,7 +120,7 @@ def run_job(n_clicks, global_store, all_annotations, project_name, children): size="sm", ), job_uid, - input_params + input_params, ) return no_update, no_update, input_params diff --git a/components/control_bar.py b/components/control_bar.py index 08c65ab..c8f6bf9 100644 --- a/components/control_bar.py +++ b/components/control_bar.py @@ -6,8 +6,8 @@ from components.annotation_class import annotation_class_item from constants import ANNOT_ICONS, KEYBINDS -from utils.data_utils import tiled_dataset from utils.content_registry import models +from utils.data_utils import tiled_dataset def _tooltip(text, children): diff --git a/components/dash_component_editor.py b/components/dash_component_editor.py index 7fa677a..0ab0f14 100644 --- a/components/dash_component_editor.py +++ b/components/dash_component_editor.py @@ -1,44 +1,46 @@ +import base64 + +# import PIL.Image +import io import re -from typing import Callable + # noinspection PyUnresolvedReferences -from inspect import signature, _empty +from inspect import _empty, signature +from typing import Callable -from dash import html, dcc, Input, Output, State, ALL import dash_bootstrap_components as dbc import dash_daq as daq +from dash import ALL, Input, Output, State, dcc, html -import base64 -#import PIL.Image -import io -#import plotly.express as px +# import plotly.express as px # Procedural dash form generation - """ -{'name', 'title', 'value', 'type', +{'name', 'title', 'value', 'type', """ -class SimpleItem(dbc.Col): - def __init__(self, - name, - base_id, - title=None, - param_key=None, - type='number', - debounce=True, - **kwargs): - +class SimpleItem(dbc.Col): + def __init__( + self, + name, + base_id, + title=None, + param_key=None, + type="number", + debounce=True, + **kwargs, + ): if param_key == None: param_key = name self.label = dbc.Label(title) - self.input = dbc.Input(type=type, - debounce=debounce, - id={**base_id, - 'name': name, - 'param_key': param_key}, - **kwargs) + self.input = dbc.Input( + type=type, + debounce=debounce, + id={**base_id, "name": name, "param_key": param_key}, + **kwargs, + ) super(SimpleItem, self).__init__(children=[self.label, self.input]) @@ -49,253 +51,241 @@ class FloatItem(SimpleItem): class IntItem(SimpleItem): def __init__(self, *args, **kwargs): - if 'min' not in kwargs: - kwargs['min'] = -9007199254740991 + if "min" not in kwargs: + kwargs["min"] = -9007199254740991 super(IntItem, self).__init__(*args, step=1, **kwargs) class StrItem(SimpleItem): def __init__(self, *args, **kwargs): - super(StrItem, self).__init__(*args, type='text', **kwargs) + super(StrItem, self).__init__(*args, type="text", **kwargs) class SliderItem(dbc.Col): - def __init__(self, - name, - base_id, - title=None, - param_key=None, - debounce=True, - visible=True, - **kwargs): - + def __init__( + self, + name, + base_id, + title=None, + param_key=None, + debounce=True, + visible=True, + **kwargs, + ): if param_key == None: param_key = name self.label = dbc.Label(title) - self.input = dcc.Slider(id={**base_id, - 'name': name, - 'param_key': param_key, - 'layer': 'input'}, - tooltip={"placement": "bottom", "always_visible": True}, - **kwargs) + self.input = dcc.Slider( + id={**base_id, "name": name, "param_key": param_key, "layer": "input"}, + tooltip={"placement": "bottom", "always_visible": True}, + **kwargs, + ) style = {} if not visible: - style['display'] = 'none' + style["display"] = "none" - super(SliderItem, self).__init__(id={**base_id, - 'name': name, - 'param_key': param_key, - 'layer': 'form_group'}, - children=[self.label, self.input], - style=style) + super(SliderItem, self).__init__( + id={**base_id, "name": name, "param_key": param_key, "layer": "form_group"}, + children=[self.label, self.input], + style=style, + ) class DropdownItem(dbc.Col): - def __init__(self, - name, - base_id, - title=None, - param_key=None, - debounce=True, - visible=True, - **kwargs): - + def __init__( + self, + name, + base_id, + title=None, + param_key=None, + debounce=True, + visible=True, + **kwargs, + ): if param_key == None: param_key = name self.label = dbc.Label(title) - self.input = dcc.Dropdown(id={**base_id, - 'name': name, - 'param_key': param_key, - 'layer': 'input'}, - **kwargs) + self.input = dcc.Dropdown( + id={**base_id, "name": name, "param_key": param_key, "layer": "input"}, + **kwargs, + ) style = {} if not visible: - style['display'] = 'none' + style["display"] = "none" - super(DropdownItem, self).__init__(id={**base_id, - 'name': name, - 'param_key': param_key, - 'layer': 'form_group'}, - children=[self.label, self.input], - style=style) + super(DropdownItem, self).__init__( + id={**base_id, "name": name, "param_key": param_key, "layer": "form_group"}, + children=[self.label, self.input], + style=style, + ) class RadioItem(dbc.Col): - def __init__(self, - name, - base_id, - title=None, - param_key=None, - visible=True, - **kwargs): - + def __init__( + self, name, base_id, title=None, param_key=None, visible=True, **kwargs + ): if param_key == None: param_key = name self.label = dbc.Label(title) - self.input = dbc.RadioItems(id={**base_id, - 'name': name, - 'param_key': param_key, - 'layer': 'input'}, - **kwargs) + self.input = dbc.RadioItems( + id={**base_id, "name": name, "param_key": param_key, "layer": "input"}, + **kwargs, + ) style = {} if not visible: - style['display'] = 'none' + style["display"] = "none" - super(RadioItem, self).__init__(id={**base_id, - 'name': name, - 'param_key': param_key, - 'layer': 'form_group'}, - children=[self.label, self.input], - style=style) + super(RadioItem, self).__init__( + id={**base_id, "name": name, "param_key": param_key, "layer": "form_group"}, + children=[self.label, self.input], + style=style, + ) class BoolItem(dbc.Col): - def __init__(self, - name, - base_id, - title=None, - param_key=None, - visible=True, - **kwargs): - + def __init__( + self, name, base_id, title=None, param_key=None, visible=True, **kwargs + ): if param_key == None: param_key = name self.label = dbc.Label(title) - self.input = daq.ToggleSwitch(id={**base_id, - 'name': name, - 'param_key': param_key, - 'layer': 'input'}, - **kwargs) - self.output_label = dbc.Label('False/True') + self.input = daq.ToggleSwitch( + id={**base_id, "name": name, "param_key": param_key, "layer": "input"}, + **kwargs, + ) + self.output_label = dbc.Label("False/True") style = {} if not visible: - style['display'] = 'none' + style["display"] = "none" - super(BoolItem, self).__init__(id={**base_id, - 'name': name, - 'param_key': param_key, - 'layer': 'form_group'}, - children=[self.label, self.input, self.output_label], - style=style) + super(BoolItem, self).__init__( + id={**base_id, "name": name, "param_key": param_key, "layer": "form_group"}, + children=[self.label, self.input, self.output_label], + style=style, + ) class ImgItem(dbc.Col): - def __init__(self, - name, - src, - base_id, - title=None, - param_key=None, - width='100px', - visible=True, - **kwargs): - + def __init__( + self, + name, + src, + base_id, + title=None, + param_key=None, + width="100px", + visible=True, + **kwargs, + ): if param_key == None: param_key = name - - if not (width.endswith('px') or width.endswith('%')): - width = width + 'px' - + + if not (width.endswith("px") or width.endswith("%")): + width = width + "px" + self.label = dbc.Label(title) - - encoded_image = base64.b64encode(open(src, 'rb').read()) - self.src = 'data:image/png;base64,{}'.format(encoded_image.decode()) - self.input_img = html.Img(id={**base_id, - 'name': name, - 'param_key': param_key, - 'layer': 'input'}, - src=self.src, - style={'height':'auto', 'width':width}, - **kwargs) + + encoded_image = base64.b64encode(open(src, "rb").read()) + self.src = "data:image/png;base64,{}".format(encoded_image.decode()) + self.input_img = html.Img( + id={**base_id, "name": name, "param_key": param_key, "layer": "input"}, + src=self.src, + style={"height": "auto", "width": width}, + **kwargs, + ) style = {} if not visible: - style['display'] = 'none' + style["display"] = "none" - super(ImgItem, self).__init__(id={**base_id, - 'name': name, - 'param_key': param_key, - 'layer': 'form_group'}, - children=[self.label, self.input_img], - style=style) + super(ImgItem, self).__init__( + id={**base_id, "name": name, "param_key": param_key, "layer": "form_group"}, + children=[self.label, self.input_img], + style=style, + ) class ParameterEditor(dbc.Form): - - type_map = {float: FloatItem, - int: IntItem, - str: StrItem, - } + type_map = { + float: FloatItem, + int: IntItem, + str: StrItem, + } def __init__(self, _id, parameters, **kwargs): self._parameters = parameters - super(ParameterEditor, self).__init__(id=_id, children=[], className='kwarg-editor', **kwargs) + super(ParameterEditor, self).__init__( + id=_id, children=[], className="kwarg-editor", **kwargs + ) self.children = self.build_children() def init_callbacks(self, app): - app.callback(Output(self.id, 'n_submit'), - Input({**self.id, - 'name': ALL}, - 'value'), - State(self.id, 'n_submit'), - ) - + app.callback( + Output(self.id, "n_submit"), + Input({**self.id, "name": ALL}, "value"), + State(self.id, "n_submit"), + ) + for child in self.children: - if hasattr(child,"init_callbacks"): - child.init_callbacks(app) - - + if hasattr(child, "init_callbacks"): + child.init_callbacks(app) + @property def values(self): - return {param['name']: param.get('value', None) for param in self._parameters} + return {param["name"]: param.get("value", None) for param in self._parameters} @property def parameters(self): - return {param['name']: param for param in self._parameters} + return {param["name"]: param for param in self._parameters} def _determine_type(self, parameter_dict): - if 'type' in parameter_dict: - if parameter_dict['type'] in self.type_map: - return parameter_dict['type'] - elif parameter_dict['type'].__name__ in self.type_map: - return parameter_dict['type'].__name__ - elif type(parameter_dict['value']) in self.type_map: - return type(parameter_dict['value']) - raise TypeError(f'No item type could be determined for this parameter: {parameter_dict}') + if "type" in parameter_dict: + if parameter_dict["type"] in self.type_map: + return parameter_dict["type"] + elif parameter_dict["type"].__name__ in self.type_map: + return parameter_dict["type"].__name__ + elif type(parameter_dict["value"]) in self.type_map: + return type(parameter_dict["value"]) + raise TypeError( + f"No item type could be determined for this parameter: {parameter_dict}" + ) def build_children(self, values=None): children = [] for parameter_dict in self._parameters: parameter_dict = parameter_dict.copy() - if values and parameter_dict['name'] in values: - parameter_dict['value'] = values[parameter_dict['name']] + if values and parameter_dict["name"] in values: + parameter_dict["value"] = values[parameter_dict["name"]] type = self._determine_type(parameter_dict) - parameter_dict.pop('type', None) - item = self.type_map[type](**parameter_dict, base_id=self.id) + parameter_dict.pop("type", None) + item = self.type_map[type](**parameter_dict, base_id=self.id) children.append(item) return children - + class JSONParameterEditor(ParameterEditor): - type_map = {'float': FloatItem, - 'int': IntItem, - 'str': StrItem, - 'slider': SliderItem, - 'dropdown': DropdownItem, - 'radio': RadioItem, - 'bool': BoolItem, - 'img': ImgItem, - #'graph': GraphItem, - } + type_map = { + "float": FloatItem, + "int": IntItem, + "str": StrItem, + "slider": SliderItem, + "dropdown": DropdownItem, + "radio": RadioItem, + "bool": BoolItem, + "img": ImgItem, + #'graph': GraphItem, + } def __init__(self, _id, json_blob, **kwargs): - super(ParameterEditor, self).__init__(id=_id, children=[], className='kwarg-editor', **kwargs) + super(ParameterEditor, self).__init__( + id=_id, children=[], className="kwarg-editor", **kwargs + ) self._json_blob = json_blob self.children = self.build_children() @@ -305,11 +295,11 @@ def build_children(self, values=None): ... # build a parameter dict from self.json_blob ... - type = json_record.get('type', self._determine_type(json_record)) + type = json_record.get("type", self._determine_type(json_record)) json_record = json_record.copy() - if values and json_record['name'] in values: - json_record['value'] = values[json_record['name']] - json_record.pop('type', None) + if values and json_record["name"] in values: + json_record["value"] = values[json_record["name"]] + json_record.pop("type", None) item = self.type_map[type](**json_record, base_id=self.id) children.append(item) @@ -321,10 +311,21 @@ def __init__(self, instance_index, func: Callable, **kwargs): self.func = func self._instance_index = instance_index - parameters = [{'name': name, 'value': param.default} for name, param in signature(func).parameters.items() - if param.default is not _empty] + parameters = [ + {"name": name, "value": param.default} + for name, param in signature(func).parameters.items() + if param.default is not _empty + ] - super(KwargsEditor, self).__init__(dict(index=instance_index, type='kwargs-editor'), parameters=parameters, **kwargs) + super(KwargsEditor, self).__init__( + dict(index=instance_index, type="kwargs-editor"), + parameters=parameters, + **kwargs, + ) def new_record(self): - return {name: p.default for name, p in signature(self.func).parameters.items() if p.default is not _empty} + return { + name: p.default + for name, p in signature(self.func).parameters.items() + if p.default is not _empty + } diff --git a/utils/content_registry.py b/utils/content_registry.py index 4e1c12e..c441d98 100644 --- a/utils/content_registry.py +++ b/utils/content_registry.py @@ -1,18 +1,19 @@ import json from copy import deepcopy + class Models: - def __init__(self, modelfile_path='./assets/mode_description.json'): - self.path = modelfile_path - f = open('./assets/mode_description.json') - - self.contents = json.load(f)['contents'] - self.modelname_list = [content['model_name'] for content in self.contents] + def __init__(self, modelfile_path="./assets/mode_description.json"): + self.path = modelfile_path + f = open("./assets/mode_description.json") + + self.contents = json.load(f)["contents"] + self.modelname_list = [content["model_name"] for content in self.contents] self.models = {} for i, n in enumerate(self.modelname_list): self.models[n] = self.contents[i] - + @staticmethod def remove_key_from_dict_list(data, key): new_data = [] @@ -23,8 +24,8 @@ def remove_key_from_dict_list(data, key): new_data.append(new_item) else: new_data.append(item) - - return new_data + + return new_data -models = Models() \ No newline at end of file +models = Models() From 0b8f71bccdc2b26b5980844f386c29fda54becc9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Mon, 4 Mar 2024 18:10:40 -0800 Subject: [PATCH 08/21] :boom: Move children generation callback into control_bar From there, it cannot reference the variable `app`, but we likely do not need this and the callback can become just a function so we will deal with this later. --- app.py | 23 ----------------------- callbacks/control_bar.py | 22 ++++++++++++++++++++++ 2 files changed, 22 insertions(+), 23 deletions(-) diff --git a/app.py b/app.py index 7b9538c..1f17230 100644 --- a/app.py +++ b/app.py @@ -8,9 +8,7 @@ from callbacks.image_viewer import * # noqa: F403, F401 from callbacks.segmentation import * # noqa: F403, F401 from components.control_bar import layout as control_bar_layout -from components.dash_component_editor import JSONParameterEditor from components.image_viewer import layout as image_viewer_layout -from utils.content_registry import models USER_NAME = os.getenv("USER_NAME") USER_PASSWORD = os.getenv("USER_PASSWORD") @@ -37,26 +35,5 @@ ], ) - -### automatic Dash gui callback ### -@callback( - Output("gui-layouts", "children"), - Input("model-list", "value"), -) -def update_gui_parameters(model_name): - data = models.models[model_name] - if data["gui_parameters"]: - item_list = JSONParameterEditor( - _id={"type": str(uuid.uuid4())}, # pattern match _id (base id), name - json_blob=models.remove_key_from_dict_list( - data["gui_parameters"], "comp_group" - ), - ) - item_list.init_callbacks(app) - return [html.H4("Model Parameters"), item_list] - else: - return [""] - - if __name__ == "__main__": app.run_server(debug=True) diff --git a/callbacks/control_bar.py b/callbacks/control_bar.py index 9d7d438..0664f50 100644 --- a/callbacks/control_bar.py +++ b/callbacks/control_bar.py @@ -2,6 +2,7 @@ import os import random import time +import uuid import dash_mantine_components as dmc import plotly.express as px @@ -24,8 +25,10 @@ from dash_iconify import DashIconify from components.annotation_class import annotation_class_item +from components.dash_component_editor import JSONParameterEditor from constants import ANNOT_ICONS, ANNOT_NOTIFICATION_MSGS, KEY_MODES, KEYBINDS from utils.annotations import Annotations +from utils.content_registry import models from utils.data_utils import tiled_dataset from utils.plot_utils import generate_notification, generate_notification_bg_icon_col @@ -912,3 +915,22 @@ def update_current_annotated_slices_values(all_classes): ] disabled = True if len(dropdown_values) == 0 else False return dropdown_values, disabled + + +@callback( + Output("gui-layouts", "children"), + Input("model-list", "value"), +) +def update_gui_parameters(model_name): + data = models.models[model_name] + if data["gui_parameters"]: + item_list = JSONParameterEditor( + _id={"type": str(uuid.uuid4())}, # pattern match _id (base id), name + json_blob=models.remove_key_from_dict_list( + data["gui_parameters"], "comp_group" + ), + ) + # item_list.init_callbacks(app) + return [html.H4("Model Parameters"), item_list] + else: + return [""] From 2c6a4b504127fe3086683535b0f99f509d5c3d94 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Mon, 4 Mar 2024 18:11:38 -0800 Subject: [PATCH 09/21] Fix remaining `flake8 warnings` E711 comparison to None should be 'if cond is None:' E265 block comment should start with '# ' F401 Unused import --- components/dash_component_editor.py | 27 ++++++--------------------- 1 file changed, 6 insertions(+), 21 deletions(-) diff --git a/components/dash_component_editor.py b/components/dash_component_editor.py index 0ab0f14..dd9cd20 100644 --- a/components/dash_component_editor.py +++ b/components/dash_component_editor.py @@ -1,10 +1,4 @@ import base64 - -# import PIL.Image -import io -import re - -# noinspection PyUnresolvedReferences from inspect import _empty, signature from typing import Callable @@ -12,14 +6,6 @@ import dash_daq as daq from dash import ALL, Input, Output, State, dcc, html -# import plotly.express as px -# Procedural dash form generation - - -""" -{'name', 'title', 'value', 'type', -""" - class SimpleItem(dbc.Col): def __init__( @@ -32,7 +18,7 @@ def __init__( debounce=True, **kwargs, ): - if param_key == None: + if param_key is None: param_key = name self.label = dbc.Label(title) self.input = dbc.Input( @@ -72,7 +58,7 @@ def __init__( visible=True, **kwargs, ): - if param_key == None: + if param_key is None: param_key = name self.label = dbc.Label(title) self.input = dcc.Slider( @@ -103,7 +89,7 @@ def __init__( visible=True, **kwargs, ): - if param_key == None: + if param_key is None: param_key = name self.label = dbc.Label(title) self.input = dcc.Dropdown( @@ -126,7 +112,7 @@ class RadioItem(dbc.Col): def __init__( self, name, base_id, title=None, param_key=None, visible=True, **kwargs ): - if param_key == None: + if param_key is None: param_key = name self.label = dbc.Label(title) self.input = dbc.RadioItems( @@ -149,7 +135,7 @@ class BoolItem(dbc.Col): def __init__( self, name, base_id, title=None, param_key=None, visible=True, **kwargs ): - if param_key == None: + if param_key is None: param_key = name self.label = dbc.Label(title) self.input = daq.ToggleSwitch( @@ -181,7 +167,7 @@ def __init__( visible=True, **kwargs, ): - if param_key == None: + if param_key is None: param_key = name if not (width.endswith("px") or width.endswith("%")): @@ -279,7 +265,6 @@ class JSONParameterEditor(ParameterEditor): "radio": RadioItem, "bool": BoolItem, "img": ImgItem, - #'graph': GraphItem, } def __init__(self, _id, json_blob, **kwargs): From 132a5ea982645083eb1e8128307305b9c1e92499 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Mon, 4 Mar 2024 19:56:00 -0800 Subject: [PATCH 10/21] Move `dcc.Store` elements from app file into control bar component --- app.py | 4 +--- components/control_bar.py | 8 ++++++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/app.py b/app.py index 1f17230..d2ed8d2 100644 --- a/app.py +++ b/app.py @@ -2,7 +2,7 @@ import dash_auth import dash_mantine_components as dmc -from dash import Dash, dcc +from dash import Dash from callbacks.control_bar import * # noqa: F403, F401 from callbacks.image_viewer import * # noqa: F403, F401 @@ -30,8 +30,6 @@ children=[ control_bar_layout(), image_viewer_layout(), - dcc.Store(id="current-class-selection", data="#FFA200"), - dcc.Store(id="gui-components-values", data={}), ], ) diff --git a/components/control_bar.py b/components/control_bar.py index c8f6bf9..38f7e09 100644 --- a/components/control_bar.py +++ b/components/control_bar.py @@ -453,6 +453,9 @@ def layout(): }, className="add-class-btn", ), + dcc.Store( + id="current-class-selection", data="#FFA200" + ), dmc.Space(h=20), ], ), @@ -605,7 +608,7 @@ def layout(): id="model-configuration", children=[ _control_item( - "Model Selection", + "Model", "model-selector", dmc.Select( id="model-list", @@ -615,11 +618,12 @@ def layout(): if models.modelname_list[0] else None ), - placeholder="Select an model...", + placeholder="Select a model...", ), ), dmc.Space(h=25), html.Div(id="gui-layouts"), + dcc.Store(id="gui-components-values", data={}), dmc.Space(h=25), dmc.Center( dmc.Button( From b51c99803d63ba6b6868b16437b3d78d5abb2c3f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Mon, 4 Mar 2024 20:07:45 -0800 Subject: [PATCH 11/21] :bug: Fix (non-)use of path parameter in `Models` Additionally pretty-prints model file for easier editing --- assets/mode_description.json | 6 ---- assets/models.json | 69 ++++++++++++++++++++++++++++++++++++ utils/content_registry.py | 4 +-- 3 files changed, 71 insertions(+), 8 deletions(-) delete mode 100755 assets/mode_description.json create mode 100755 assets/models.json diff --git a/assets/mode_description.json b/assets/mode_description.json deleted file mode 100755 index 7c95a20..0000000 --- a/assets/mode_description.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "contents":[ - {"model_name": "random_forest", "version": "1.0.0", "type": "supervised", "user": "mlexchange team", "uri": "xxx", "application": ["classification", "segmentation"], "description": "xxx", "gui_parameters": [{"type": "int", "name": "num-tree", "title": "Number of Trees", "param_key": "n_estimators", "value": "30"}, {"type": "int", "name": "tree-depth", "title": "Tree Depth", "param_key": "max_depth", "value": "8"}], "cmd": ["xxx"], "reference": "Adapted from Dash Plotly image segmentation example"}, - {"model_name": "kmeans", "version": "1.0.0", "type": "unsupervised", "user": "mlexchange team", "uri": "xxx", "application": ["segmentation", "clustering"], "description": "xxx", "gui_parameters": [{"type": "int", "name": "num-cluster", "title": "Number of Clusters", "param_key": "n_clusters", "value": "2"}, {"type": "int", "name": "num-iter", "title": "Max Iteration", "param_key": "max_iter", "value": "300"}], "cmd": ["xxx", "xxxx"], "reference": "Nicholas Schwartz & Howard Yanxon"} - ] -} diff --git a/assets/models.json b/assets/models.json new file mode 100755 index 0000000..bf77000 --- /dev/null +++ b/assets/models.json @@ -0,0 +1,69 @@ +{ + "contents": [ + { + "model_name": "random_forest", + "version": "1.0.0", + "type": "supervised", + "user": "mlexchange team", + "uri": "xxx", + "application": [ + "classification", + "segmentation" + ], + "description": "xxx", + "gui_parameters": [ + { + "type": "int", + "name": "num-tree", + "title": "Number of Trees", + "param_key": "n_estimators", + "value": "30" + }, + { + "type": "int", + "name": "tree-depth", + "title": "Tree Depth", + "param_key": "max_depth", + "value": "8" + } + ], + "cmd": [ + "xxx" + ], + "reference": "Adapted from Dash Plotly image segmentation example" + }, + { + "model_name": "kmeans", + "version": "1.0.0", + "type": "unsupervised", + "user": "mlexchange team", + "uri": "xxx", + "application": [ + "segmentation", + "clustering" + ], + "description": "xxx", + "gui_parameters": [ + { + "type": "int", + "name": "num-cluster", + "title": "Number of Clusters", + "param_key": "n_clusters", + "value": "2" + }, + { + "type": "int", + "name": "num-iter", + "title": "Max Iteration", + "param_key": "max_iter", + "value": "300" + } + ], + "cmd": [ + "xxx", + "xxxx" + ], + "reference": "Nicholas Schwartz & Howard Yanxon" + } + ] +} diff --git a/utils/content_registry.py b/utils/content_registry.py index c441d98..ed9203c 100644 --- a/utils/content_registry.py +++ b/utils/content_registry.py @@ -3,9 +3,9 @@ class Models: - def __init__(self, modelfile_path="./assets/mode_description.json"): + def __init__(self, modelfile_path="./assets/models.json"): self.path = modelfile_path - f = open("./assets/mode_description.json") + f = open(self.path) self.contents = json.load(f)["contents"] self.modelname_list = [content["model_name"] for content in self.contents] From b3915e7f2f71a0ff2c1c12b932f0f5478d309f8d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Tue, 5 Mar 2024 08:28:06 -0800 Subject: [PATCH 12/21] Elevate `_control_item ` function to class `ControlItem`, delete unused code --- callbacks/control_bar.py | 7 +-- components/control_bar.py | 35 +++-------- components/dash_component_editor.py | 92 +++++++---------------------- 3 files changed, 34 insertions(+), 100 deletions(-) diff --git a/callbacks/control_bar.py b/callbacks/control_bar.py index 0664f50..9b15112 100644 --- a/callbacks/control_bar.py +++ b/callbacks/control_bar.py @@ -921,16 +921,15 @@ def update_current_annotated_slices_values(all_classes): Output("gui-layouts", "children"), Input("model-list", "value"), ) -def update_gui_parameters(model_name): +def update_model_parameters(model_name): data = models.models[model_name] if data["gui_parameters"]: item_list = JSONParameterEditor( - _id={"type": str(uuid.uuid4())}, # pattern match _id (base id), name + _id={"type": str(uuid.uuid4())}, json_blob=models.remove_key_from_dict_list( data["gui_parameters"], "comp_group" ), ) - # item_list.init_callbacks(app) - return [html.H4("Model Parameters"), item_list] + return [item_list] else: return [""] diff --git a/components/control_bar.py b/components/control_bar.py index 38f7e09..4ca57e7 100644 --- a/components/control_bar.py +++ b/components/control_bar.py @@ -5,6 +5,7 @@ from dash_iconify import DashIconify from components.annotation_class import annotation_class_item +from components.dash_component_editor import ControlItem from constants import ANNOT_ICONS, KEYBINDS from utils.content_registry import models from utils.data_utils import tiled_dataset @@ -19,24 +20,6 @@ def _tooltip(text, children): ) -def _control_item(title, title_id, item): - """ - Returns a customized layout for a control item - """ - return dmc.Grid( - [ - dmc.Text( - title, - id=title_id, - size="sm", - style={"width": "100px", "margin": "auto", "paddingRight": "5px"}, - align="right", - ), - html.Div(item, style={"width": "265px", "margin": "auto"}), - ] - ) - - def _accordion_item(title, icon, value, children, id): """ Returns a customized layout for an accordion item @@ -79,7 +62,7 @@ def layout(): id="data-selection-controls", children=[ dmc.Space(h=5), - _control_item( + ControlItem( "Dataset", "image-selector", dmc.Grid( @@ -115,7 +98,7 @@ def layout(): ), ), dmc.Space(h=25), - _control_item( + ControlItem( "Slice 1", "image-selection-text", [ @@ -178,7 +161,7 @@ def layout(): ], ), dmc.Space(h=25), - _control_item( + ControlItem( _tooltip( "Jump to your annotated slices", "Annotated slices", @@ -208,7 +191,7 @@ def layout(): children=html.Div( [ dmc.Space(h=5), - _control_item( + ControlItem( "Brightness", "bightness-text", [ @@ -252,7 +235,7 @@ def layout(): ], ), dmc.Space(h=20), - _control_item( + ControlItem( "Contrast", "contrast-text", dmc.Grid( @@ -607,7 +590,7 @@ def layout(): "run-model", id="model-configuration", children=[ - _control_item( + ControlItem( "Model", "model-selector", dmc.Select( @@ -646,7 +629,7 @@ def layout(): styles={"trackLabel": {"cursor": "pointer"}}, ), dmc.Space(h=25), - _control_item( + ControlItem( "Results", "", dmc.Select( @@ -655,7 +638,7 @@ def layout(): ), ), dmc.Space(h=25), - _control_item( + ControlItem( "Opacity", "", dmc.Slider( diff --git a/components/dash_component_editor.py b/components/dash_component_editor.py index dd9cd20..5b232a0 100644 --- a/components/dash_component_editor.py +++ b/components/dash_component_editor.py @@ -1,12 +1,30 @@ -import base64 -from inspect import _empty, signature -from typing import Callable - import dash_bootstrap_components as dbc import dash_daq as daq +import dash_mantine_components as dmc from dash import ALL, Input, Output, State, dcc, html +class ControlItem(dmc.Grid): + """ + Customized layout for a control item + """ + + def __init__(self, title, title_id, item, **kwargs): + super(ControlItem, self).__init__( + [ + dmc.Text( + title, + id=title_id, + size="sm", + style={"width": "100px", "margin": "auto", "paddingRight": "5px"}, + align="right", + ), + html.Div(item, style={"width": "265px", "margin": "auto"}), + ], + **kwargs, + ) + + class SimpleItem(dbc.Col): def __init__( self, @@ -155,46 +173,6 @@ def __init__( ) -class ImgItem(dbc.Col): - def __init__( - self, - name, - src, - base_id, - title=None, - param_key=None, - width="100px", - visible=True, - **kwargs, - ): - if param_key is None: - param_key = name - - if not (width.endswith("px") or width.endswith("%")): - width = width + "px" - - self.label = dbc.Label(title) - - encoded_image = base64.b64encode(open(src, "rb").read()) - self.src = "data:image/png;base64,{}".format(encoded_image.decode()) - self.input_img = html.Img( - id={**base_id, "name": name, "param_key": param_key, "layer": "input"}, - src=self.src, - style={"height": "auto", "width": width}, - **kwargs, - ) - - style = {} - if not visible: - style["display"] = "none" - - super(ImgItem, self).__init__( - id={**base_id, "name": name, "param_key": param_key, "layer": "form_group"}, - children=[self.label, self.input_img], - style=style, - ) - - class ParameterEditor(dbc.Form): type_map = { float: FloatItem, @@ -264,7 +242,6 @@ class JSONParameterEditor(ParameterEditor): "dropdown": DropdownItem, "radio": RadioItem, "bool": BoolItem, - "img": ImgItem, } def __init__(self, _id, json_blob, **kwargs): @@ -289,28 +266,3 @@ def build_children(self, values=None): children.append(item) return children - - -class KwargsEditor(ParameterEditor): - def __init__(self, instance_index, func: Callable, **kwargs): - self.func = func - self._instance_index = instance_index - - parameters = [ - {"name": name, "value": param.default} - for name, param in signature(func).parameters.items() - if param.default is not _empty - ] - - super(KwargsEditor, self).__init__( - dict(index=instance_index, type="kwargs-editor"), - parameters=parameters, - **kwargs, - ) - - def new_record(self): - return { - name: p.default - for name, p in signature(self.func).parameters.items() - if p.default is not _empty - } From 8b34636b5ea58c90bf1e954fb1c632a0e7168814 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Tue, 5 Mar 2024 17:24:33 -0800 Subject: [PATCH 13/21] :sparkles: Add Dlsia model parameters --- app.py | 2 +- assets/models.json | 1083 ++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 1051 insertions(+), 34 deletions(-) diff --git a/app.py b/app.py index d2ed8d2..efda09c 100644 --- a/app.py +++ b/app.py @@ -34,4 +34,4 @@ ) if __name__ == "__main__": - app.run_server(debug=True) + app.run_server(host="0.0.0.0", port=8075, debug=True) diff --git a/assets/models.json b/assets/models.json index bf77000..245638e 100755 --- a/assets/models.json +++ b/assets/models.json @@ -1,69 +1,1086 @@ { "contents": [ { - "model_name": "random_forest", - "version": "1.0.0", + "model_name": "DSLIA MSDNet", + "version": "0.0.1", "type": "supervised", "user": "mlexchange team", - "uri": "xxx", + "uri": "ghcr.io/mlexchange/mlex_dlsia_segmentation:main", "application": [ - "classification", "segmentation" ], - "description": "xxx", + "description": "MSDNets in DLSIA for image segmentation", "gui_parameters": [ { "type": "int", - "name": "num-tree", - "title": "Number of Trees", - "param_key": "n_estimators", - "value": "30" + "name": "layer_width", + "title": "Layers Width", + "param_key": "layer_width", + "value": 1, + "comp_group": "train_model" }, { "type": "int", - "name": "tree-depth", - "title": "Tree Depth", - "param_key": "max_depth", - "value": "8" + "name": "num_layers", + "title": "Number of Layers", + "param_key": "num_layers", + "value": 3, + "comp_group": "train_model" + }, + { + "type": "radio", + "name": "custom_dilation", + "title": "Custom Dilation", + "param_key": "custom_dilation", + "value": true, + "options": [ + { + "label": "True", + "value": true + }, + { + "label": "False", + "value": false + } + ], + "comp_group": "train_model" + }, + { + "type": "int", + "name": "max_dilation", + "title": "Maximum Dilation", + "param_key": "max_dilation", + "value": 5, + "comp_group": "train_model" + }, + { + "type": "str", + "name": "dilation_array", + "title": "Dilation Array", + "param_key": "dilation_array", + "value": "[1, 2, 4]", + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "num_epochs", + "title": "Number of epoch", + "param_key": "num_epochs", + "min": 1, + "max": 1000, + "value": 30, + "comp_group": "train_model" + }, + { + "type": "dropdown", + "name": "optimizer", + "title": "Optimizer", + "param_key": "optimizer", + "value": "Adam", + "options": [ + { + "label": "Adadelta", + "value": "Adadelta" + }, + { + "label": "Adagrad", + "value": "Adagrad" + }, + { + "label": "Adam", + "value": "Adam" + }, + { + "label": "AdamW", + "value": "AdamW" + }, + { + "label": "SparseAdam", + "value": "SparseAdam" + }, + { + "label": "Adamax", + "value": "Adamax" + }, + { + "label": "ASGD", + "value": "ASGD" + }, + { + "label": "LBFGS", + "value": "LBFGS" + }, + { + "label": "RMSprop", + "value": "RMSprop" + }, + { + "label": "Rprop", + "value": "Rprop" + }, + { + "label": "SGD", + "value": "SGD" + } + ], + "comp_group": "train_model" + }, + { + "type": "dropdown", + "name": "criterion", + "title": "Criterion", + "param_key": "criterion", + "value": "MSELoss", + "options": [ + { + "label": "L1Loss", + "value": "L1Loss" + }, + { + "label": "MSELoss", + "value": "MSELoss" + }, + { + "label": "CrossEntropyLoss", + "value": "CrossEntropyLoss" + }, + { + "label": "CTCLoss", + "value": "CTCLoss" + }, + { + "label": "NLLLoss", + "value": "NLLLoss" + }, + { + "label": "PoissonNLLLoss", + "value": "PoissonNLLLoss" + }, + { + "label": "GaussianNLLLoss", + "value": "GaussianNLLLoss" + }, + { + "label": "KLDivLoss", + "value": "KLDivLoss" + }, + { + "label": "BCELoss", + "value": "BCELoss" + }, + { + "label": "BCEWithLogitsLoss", + "value": "BCEWithLogitsLoss" + }, + { + "label": "MarginRankingLoss", + "value": "MarginRankingLoss" + }, + { + "label": "HingeEnbeddingLoss", + "value": "HingeEnbeddingLoss" + }, + { + "label": "MultiLabelMarginLoss", + "value": "MultiLabelMarginLoss" + }, + { + "label": "HuberLoss", + "value": "HuberLoss" + }, + { + "label": "SmoothL1Loss", + "value": "SmoothL1Loss" + }, + { + "label": "SoftMarginLoss", + "value": "SoftMarginLoss" + }, + { + "label": "MutiLabelSoftMarginLoss", + "value": "MutiLabelSoftMarginLoss" + }, + { + "label": "CosineEmbeddingLoss", + "value": "CosineEmbeddingLoss" + }, + { + "label": "MultiMarginLoss", + "value": "MultiMarginLoss" + }, + { + "label": "TripletMarginLoss", + "value": "TripletMarginLoss" + }, + { + "label": "TripletMarginWithDistanceLoss", + "value": "TripletMarginWithDistanceLoss" + } + ], + "comp_group": "train_model" + }, + { + "type": "str", + "name": "weights", + "title": "Class Weights", + "param_key": "weights", + "value": "[1.0, 1.0, 1.0]", + "comp_group": "train_model" + }, + { + "type": "float", + "name": "learning_rate", + "title": "Learning Rate", + "param_key": "learning_rate", + "value": 0.001, + "comp_group": "train_model" + }, + { + "type": "dropdown", + "name": "activation", + "title": "Activation", + "param_key": "activation", + "value": "ReLU", + "options": [ + { + "label": "ReLU", + "value": "ReLU" + }, + { + "label": "Sigmoid", + "value": "Sigmoid" + }, + { + "label": "Tanh", + "value": "Tanh" + }, + { + "label": "Softmax", + "value": "Softmax" + } + ], + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "val_pct", + "title": "Validation %", + "param_key": "val_pct", + "min": 0, + "max": 100, + "step": 5, + "value": 20, + "marks": { + "0": "0", + "100": "100" + }, + "comp_group": "train_model" + }, + { + "type": "radio", + "name": "shuffle_train", + "title": "Shuffle Training", + "param_key": "shuffle_train", + "value": true, + "options": [ + { + "label": "True", + "value": true + }, + { + "label": "False", + "value": false + } + ], + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "batch_size_train", + "title": "Training Batch Size", + "param_key": "batch_size_train", + "min": 16, + "max": 128, + "step": 16, + "value": 32, + "comp_group": "train_model" + }, + { + "type": "radio", + "name": "shuffle_val", + "title": "Shuffle Training", + "param_key": "shuffle_val", + "value": true, + "options": [ + { + "label": "True", + "value": true + }, + { + "label": "False", + "value": false + } + ], + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "batch_size_val", + "title": "Validation Batch Size", + "param_key": "batch_size_val", + "min": 16, + "max": 128, + "step": 16, + "value": 32, + "comp_group": "train_model" + }, + { + "type": "radio", + "name": "shuffle_inference", + "title": "Shuffle Inference", + "param_key": "shuffle_inference", + "value": true, + "options": [ + { + "label": "True", + "value": true + }, + { + "label": "False", + "value": false + } + ], + "comp_group": "prediction_model" + }, + { + "type": "slider", + "name": "batch_size_inference", + "title": "Inference Batch Size", + "param_key": "batch_size_inference", + "min": 16, + "max": 128, + "step": 16, + "value": 32, + "comp_group": "prediction_model" } ], "cmd": [ - "xxx" + "python3 src/train_model.py", + "python3 src/segment.py" ], - "reference": "Adapted from Dash Plotly image segmentation example" + "reference": "https://dlsia.readthedocs.io/en/latest/" }, { - "model_name": "kmeans", - "version": "1.0.0", - "type": "unsupervised", + "model_name": "DSLIA TUNet", + "version": "0.0.1", + "type": "supervised", "user": "mlexchange team", - "uri": "xxx", + "uri": "ghcr.io/mlexchange/mlex_dlsia_segmentation:main", "application": [ - "segmentation", - "clustering" + "segmentation" ], - "description": "xxx", + "description": "TUNet in DLSIA for image segmentation", "gui_parameters": [ { "type": "int", - "name": "num-cluster", - "title": "Number of Clusters", - "param_key": "n_clusters", - "value": "2" + "name": "depth", + "title": "Depth", + "param_key": "depth", + "value": 4, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "base_channels", + "title": "Base Channels", + "param_key": "base_channels", + "value": 32, + "comp_group": "train_model" }, { "type": "int", - "name": "num-iter", - "title": "Max Iteration", - "param_key": "max_iter", - "value": "300" + "name": "growth_rate", + "title": "Growth Rate", + "param_key": "growth_rate", + "value": 2, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "hidden_rate", + "title": "Hidden Rate", + "param_key": "hidden_rate", + "value": 1, + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "num_epochs", + "title": "Number of epoch", + "param_key": "num_epochs", + "min": 1, + "max": 1000, + "value": 30, + "comp_group": "train_model" + }, + { + "type": "dropdown", + "name": "optimizer", + "title": "Optimizer", + "param_key": "optimizer", + "value": "Adam", + "options": [ + { + "label": "Adadelta", + "value": "Adadelta" + }, + { + "label": "Adagrad", + "value": "Adagrad" + }, + { + "label": "Adam", + "value": "Adam" + }, + { + "label": "AdamW", + "value": "AdamW" + }, + { + "label": "SparseAdam", + "value": "SparseAdam" + }, + { + "label": "Adamax", + "value": "Adamax" + }, + { + "label": "ASGD", + "value": "ASGD" + }, + { + "label": "LBFGS", + "value": "LBFGS" + }, + { + "label": "RMSprop", + "value": "RMSprop" + }, + { + "label": "Rprop", + "value": "Rprop" + }, + { + "label": "SGD", + "value": "SGD" + } + ], + "comp_group": "train_model" + }, + { + "type": "dropdown", + "name": "criterion", + "title": "Criterion", + "param_key": "criterion", + "value": "MSELoss", + "options": [ + { + "label": "L1Loss", + "value": "L1Loss" + }, + { + "label": "MSELoss", + "value": "MSELoss" + }, + { + "label": "CrossEntropyLoss", + "value": "CrossEntropyLoss" + }, + { + "label": "CTCLoss", + "value": "CTCLoss" + }, + { + "label": "NLLLoss", + "value": "NLLLoss" + }, + { + "label": "PoissonNLLLoss", + "value": "PoissonNLLLoss" + }, + { + "label": "GaussianNLLLoss", + "value": "GaussianNLLLoss" + }, + { + "label": "KLDivLoss", + "value": "KLDivLoss" + }, + { + "label": "BCELoss", + "value": "BCELoss" + }, + { + "label": "BCEWithLogitsLoss", + "value": "BCEWithLogitsLoss" + }, + { + "label": "MarginRankingLoss", + "value": "MarginRankingLoss" + }, + { + "label": "HingeEnbeddingLoss", + "value": "HingeEnbeddingLoss" + }, + { + "label": "MultiLabelMarginLoss", + "value": "MultiLabelMarginLoss" + }, + { + "label": "HuberLoss", + "value": "HuberLoss" + }, + { + "label": "SmoothL1Loss", + "value": "SmoothL1Loss" + }, + { + "label": "SoftMarginLoss", + "value": "SoftMarginLoss" + }, + { + "label": "MutiLabelSoftMarginLoss", + "value": "MutiLabelSoftMarginLoss" + }, + { + "label": "CosineEmbeddingLoss", + "value": "CosineEmbeddingLoss" + }, + { + "label": "MultiMarginLoss", + "value": "MultiMarginLoss" + }, + { + "label": "TripletMarginLoss", + "value": "TripletMarginLoss" + }, + { + "label": "TripletMarginWithDistanceLoss", + "value": "TripletMarginWithDistanceLoss" + } + ], + "comp_group": "train_model" + }, + { + "type": "str", + "name": "weights", + "title": "Class Weights", + "param_key": "weights", + "value": "[1.0, 1.0, 1.0]", + "comp_group": "train_model" + }, + { + "type": "float", + "name": "learning_rate", + "title": "Learning Rate", + "param_key": "learning_rate", + "value": 0.001, + "comp_group": "train_model" + }, + { + "type": "dropdown", + "name": "activation", + "title": "Activation", + "param_key": "activation", + "value": "ReLU", + "options": [ + { + "label": "ReLU", + "value": "ReLU" + }, + { + "label": "Sigmoid", + "value": "Sigmoid" + }, + { + "label": "Tanh", + "value": "Tanh" + }, + { + "label": "Softmax", + "value": "Softmax" + } + ], + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "val_pct", + "title": "Validation %", + "param_key": "val_pct", + "min": 0, + "max": 100, + "step": 5, + "value": 20, + "marks": { + "0": "0", + "100": "100" + }, + "comp_group": "train_model" + }, + { + "type": "radio", + "name": "shuffle_train", + "title": "Shuffle Training", + "param_key": "shuffle_train", + "value": true, + "options": [ + { + "label": "True", + "value": true + }, + { + "label": "False", + "value": false + } + ], + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "batch_size_train", + "title": "Training Batch Size", + "param_key": "batch_size_train", + "min": 16, + "max": 128, + "step": 16, + "value": 32, + "comp_group": "train_model" + }, + { + "type": "radio", + "name": "shuffle_val", + "title": "Shuffle Training", + "param_key": "shuffle_val", + "value": true, + "options": [ + { + "label": "True", + "value": true + }, + { + "label": "False", + "value": false + } + ], + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "batch_size_val", + "title": "Validation Batch Size", + "param_key": "batch_size_val", + "min": 16, + "max": 128, + "step": 16, + "value": 32, + "comp_group": "train_model" + }, + { + "type": "radio", + "name": "shuffle_inference", + "title": "Shuffle Inference", + "param_key": "shuffle_inference", + "value": true, + "options": [ + { + "label": "True", + "value": true + }, + { + "label": "False", + "value": false + } + ], + "comp_group": "prediction_model" + }, + { + "type": "slider", + "name": "batch_size_inference", + "title": "Inference Batch Size", + "param_key": "batch_size_inference", + "min": 16, + "max": 128, + "step": 16, + "value": 32, + "comp_group": "prediction_model" + } + ], + "cmd": [ + "python3 src/train_model.py", + "python3 src/segment.py" + ], + "reference": "https://dlsia.readthedocs.io/en/latest/" + }, + { + "model_name": "DSLIA TUNet3+", + "version": "0.0.1", + "type": "supervised", + "user": "mlexchange team", + "uri": "ghcr.io/mlexchange/mlex_dlsia_segmentation:main", + "application": [ + "segmentation" + ], + "description": "TUNet3+ DLSIA for image segmentation", + "gui_parameters": [ + { + "type": "int", + "name": "depth", + "title": "Depth", + "param_key": "depth", + "value": 4, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "base_channels", + "title": "Base Channels", + "param_key": "base_channels", + "value": 32, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "growth_rate", + "title": "Growth Rate", + "param_key": "growth_rate", + "value": 2, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "hidden_rate", + "title": "Hidden Rate", + "param_key": "hidden_rate", + "value": 1, + "comp_group": "train_model" + }, + { + "type": "int", + "name": "carryover_channels", + "title": "Carryover Channels", + "param_key": "carryover_channels", + "value": 32, + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "num_epochs", + "title": "Number of epoch", + "param_key": "num_epochs", + "min": 1, + "max": 1000, + "value": 30, + "comp_group": "train_model" + }, + { + "type": "dropdown", + "name": "optimizer", + "title": "Optimizer", + "param_key": "optimizer", + "value": "Adam", + "options": [ + { + "label": "Adadelta", + "value": "Adadelta" + }, + { + "label": "Adagrad", + "value": "Adagrad" + }, + { + "label": "Adam", + "value": "Adam" + }, + { + "label": "AdamW", + "value": "AdamW" + }, + { + "label": "SparseAdam", + "value": "SparseAdam" + }, + { + "label": "Adamax", + "value": "Adamax" + }, + { + "label": "ASGD", + "value": "ASGD" + }, + { + "label": "LBFGS", + "value": "LBFGS" + }, + { + "label": "RMSprop", + "value": "RMSprop" + }, + { + "label": "Rprop", + "value": "Rprop" + }, + { + "label": "SGD", + "value": "SGD" + } + ], + "comp_group": "train_model" + }, + { + "type": "dropdown", + "name": "criterion", + "title": "Criterion", + "param_key": "criterion", + "value": "MSELoss", + "options": [ + { + "label": "L1Loss", + "value": "L1Loss" + }, + { + "label": "MSELoss", + "value": "MSELoss" + }, + { + "label": "CrossEntropyLoss", + "value": "CrossEntropyLoss" + }, + { + "label": "CTCLoss", + "value": "CTCLoss" + }, + { + "label": "NLLLoss", + "value": "NLLLoss" + }, + { + "label": "PoissonNLLLoss", + "value": "PoissonNLLLoss" + }, + { + "label": "GaussianNLLLoss", + "value": "GaussianNLLLoss" + }, + { + "label": "KLDivLoss", + "value": "KLDivLoss" + }, + { + "label": "BCELoss", + "value": "BCELoss" + }, + { + "label": "BCEWithLogitsLoss", + "value": "BCEWithLogitsLoss" + }, + { + "label": "MarginRankingLoss", + "value": "MarginRankingLoss" + }, + { + "label": "HingeEnbeddingLoss", + "value": "HingeEnbeddingLoss" + }, + { + "label": "MultiLabelMarginLoss", + "value": "MultiLabelMarginLoss" + }, + { + "label": "HuberLoss", + "value": "HuberLoss" + }, + { + "label": "SmoothL1Loss", + "value": "SmoothL1Loss" + }, + { + "label": "SoftMarginLoss", + "value": "SoftMarginLoss" + }, + { + "label": "MutiLabelSoftMarginLoss", + "value": "MutiLabelSoftMarginLoss" + }, + { + "label": "CosineEmbeddingLoss", + "value": "CosineEmbeddingLoss" + }, + { + "label": "MultiMarginLoss", + "value": "MultiMarginLoss" + }, + { + "label": "TripletMarginLoss", + "value": "TripletMarginLoss" + }, + { + "label": "TripletMarginWithDistanceLoss", + "value": "TripletMarginWithDistanceLoss" + } + ], + "comp_group": "train_model" + }, + { + "type": "str", + "name": "weights", + "title": "Class Weights", + "param_key": "weights", + "value": "[1.0, 1.0, 1.0]", + "comp_group": "train_model" + }, + { + "type": "float", + "name": "learning_rate", + "title": "Learning Rate", + "param_key": "learning_rate", + "value": 0.001, + "comp_group": "train_model" + }, + { + "type": "dropdown", + "name": "activation", + "title": "Activation", + "param_key": "activation", + "value": "ReLU", + "options": [ + { + "label": "ReLU", + "value": "ReLU" + }, + { + "label": "Sigmoid", + "value": "Sigmoid" + }, + { + "label": "Tanh", + "value": "Tanh" + }, + { + "label": "Softmax", + "value": "Softmax" + } + ], + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "val_pct", + "title": "Validation %", + "param_key": "val_pct", + "min": 0, + "max": 100, + "step": 5, + "value": 20, + "marks": { + "0": "0", + "100": "100" + }, + "comp_group": "train_model" + }, + { + "type": "radio", + "name": "shuffle_train", + "title": "Shuffle Training", + "param_key": "shuffle_train", + "value": true, + "options": [ + { + "label": "True", + "value": true + }, + { + "label": "False", + "value": false + } + ], + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "batch_size_train", + "title": "Training Batch Size", + "param_key": "batch_size_train", + "min": 16, + "max": 128, + "step": 16, + "value": 32, + "comp_group": "train_model" + }, + { + "type": "radio", + "name": "shuffle_val", + "title": "Shuffle Training", + "param_key": "shuffle_val", + "value": true, + "options": [ + { + "label": "True", + "value": true + }, + { + "label": "False", + "value": false + } + ], + "comp_group": "train_model" + }, + { + "type": "slider", + "name": "batch_size_val", + "title": "Validation Batch Size", + "param_key": "batch_size_val", + "min": 16, + "max": 128, + "step": 16, + "value": 32, + "comp_group": "train_model" + }, + { + "type": "radio", + "name": "shuffle_inference", + "title": "Shuffle Inference", + "param_key": "shuffle_inference", + "value": true, + "options": [ + { + "label": "True", + "value": true + }, + { + "label": "False", + "value": false + } + ], + "comp_group": "prediction_model" + }, + { + "type": "slider", + "name": "batch_size_inference", + "title": "Inference Batch Size", + "param_key": "batch_size_inference", + "min": 16, + "max": 128, + "step": 16, + "value": 32, + "comp_group": "prediction_model" } ], "cmd": [ - "xxx", - "xxxx" + "python3 src/train_model.py", + "python3 src/segment.py" ], - "reference": "Nicholas Schwartz & Howard Yanxon" + "reference": "https://dlsia.readthedocs.io/en/latest/" } ] } From fcc511f48deb509e761f237b45d715c51d606a49 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Tue, 5 Mar 2024 17:32:36 -0800 Subject: [PATCH 14/21] Remove shuffling for validation and inference --- assets/models.json | 108 --------------------------------------------- 1 file changed, 108 deletions(-) diff --git a/assets/models.json b/assets/models.json index 245638e..c8bd7c2 100755 --- a/assets/models.json +++ b/assets/models.json @@ -305,24 +305,6 @@ "value": 32, "comp_group": "train_model" }, - { - "type": "radio", - "name": "shuffle_val", - "title": "Shuffle Training", - "param_key": "shuffle_val", - "value": true, - "options": [ - { - "label": "True", - "value": true - }, - { - "label": "False", - "value": false - } - ], - "comp_group": "train_model" - }, { "type": "slider", "name": "batch_size_val", @@ -334,24 +316,6 @@ "value": 32, "comp_group": "train_model" }, - { - "type": "radio", - "name": "shuffle_inference", - "title": "Shuffle Inference", - "param_key": "shuffle_inference", - "value": true, - "options": [ - { - "label": "True", - "value": true - }, - { - "label": "False", - "value": false - } - ], - "comp_group": "prediction_model" - }, { "type": "slider", "name": "batch_size_inference", @@ -657,24 +621,6 @@ "value": 32, "comp_group": "train_model" }, - { - "type": "radio", - "name": "shuffle_val", - "title": "Shuffle Training", - "param_key": "shuffle_val", - "value": true, - "options": [ - { - "label": "True", - "value": true - }, - { - "label": "False", - "value": false - } - ], - "comp_group": "train_model" - }, { "type": "slider", "name": "batch_size_val", @@ -686,24 +632,6 @@ "value": 32, "comp_group": "train_model" }, - { - "type": "radio", - "name": "shuffle_inference", - "title": "Shuffle Inference", - "param_key": "shuffle_inference", - "value": true, - "options": [ - { - "label": "True", - "value": true - }, - { - "label": "False", - "value": false - } - ], - "comp_group": "prediction_model" - }, { "type": "slider", "name": "batch_size_inference", @@ -1017,24 +945,6 @@ "value": 32, "comp_group": "train_model" }, - { - "type": "radio", - "name": "shuffle_val", - "title": "Shuffle Training", - "param_key": "shuffle_val", - "value": true, - "options": [ - { - "label": "True", - "value": true - }, - { - "label": "False", - "value": false - } - ], - "comp_group": "train_model" - }, { "type": "slider", "name": "batch_size_val", @@ -1046,24 +956,6 @@ "value": 32, "comp_group": "train_model" }, - { - "type": "radio", - "name": "shuffle_inference", - "title": "Shuffle Inference", - "param_key": "shuffle_inference", - "value": true, - "options": [ - { - "label": "True", - "value": true - }, - { - "label": "False", - "value": false - } - ], - "comp_group": "prediction_model" - }, { "type": "slider", "name": "batch_size_inference", From ed16b22ea817f882b46c76a94f74c0635f8751b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Wed, 6 Mar 2024 10:53:03 -0800 Subject: [PATCH 15/21] Switch from Bootstrap to Mantine for generated control elements --- assets/models.json | 338 +++++++++++++++++++++------- components/dash_component_editor.py | 151 ++++++++----- 2 files changed, 361 insertions(+), 128 deletions(-) diff --git a/assets/models.json b/assets/models.json index c8bd7c2..df59436 100755 --- a/assets/models.json +++ b/assets/models.json @@ -14,7 +14,7 @@ { "type": "int", "name": "layer_width", - "title": "Layers Width", + "title": "Layer Width", "param_key": "layer_width", "value": 1, "comp_group": "train_model" @@ -28,21 +28,11 @@ "comp_group": "train_model" }, { - "type": "radio", + "type": "bool", "name": "custom_dilation", "title": "Custom Dilation", "param_key": "custom_dilation", - "value": true, - "options": [ - { - "label": "True", - "value": true - }, - { - "label": "False", - "value": false - } - ], + "checked": false, "comp_group": "train_model" }, { @@ -64,11 +54,25 @@ { "type": "slider", "name": "num_epochs", - "title": "Number of epoch", + "title": "# Epochs", "param_key": "num_epochs", "min": 1, "max": 1000, "value": 30, + "marks": [ + { + "value": 1, + "label": "1" + }, + { + "value": 100, + "label": "100" + }, + { + "value": 1000, + "label": "1000" + } + ], "comp_group": "train_model" }, { @@ -77,7 +81,7 @@ "title": "Optimizer", "param_key": "optimizer", "value": "Adam", - "options": [ + "data": [ { "label": "Adadelta", "value": "Adadelta" @@ -131,7 +135,7 @@ "title": "Criterion", "param_key": "criterion", "value": "MSELoss", - "options": [ + "data": [ { "label": "L1Loss", "value": "L1Loss" @@ -241,7 +245,7 @@ "title": "Activation", "param_key": "activation", "value": "ReLU", - "options": [ + "data": [ { "label": "ReLU", "value": "ReLU" @@ -270,61 +274,115 @@ "max": 100, "step": 5, "value": 20, - "marks": { - "0": "0", - "100": "100" - }, + "marks": [ + { + "value": 0, + "label": "0%" + }, + { + "value": 50, + "label": "50%" + }, + { + "value": 100, + "label": "100%" + } + ], "comp_group": "train_model" }, { - "type": "radio", + "type": "bool", "name": "shuffle_train", "title": "Shuffle Training", "param_key": "shuffle_train", - "value": true, - "options": [ - { - "label": "True", - "value": true - }, - { - "label": "False", - "value": false - } - ], + "checked": true, "comp_group": "train_model" }, { "type": "slider", "name": "batch_size_train", - "title": "Training Batch Size", + "title": "Batch Size Training", "param_key": "batch_size_train", "min": 16, "max": 128, "step": 16, "value": 32, + "marks": [ + { + "value": 16, + "label": "16" + }, + { + "value": 32, + "label": "32" + }, + { + "value": 64, + "label": "64" + }, + { + "value": 128, + "label": "128" + } + ], "comp_group": "train_model" }, { "type": "slider", "name": "batch_size_val", - "title": "Validation Batch Size", + "title": "Batch Size Validation", "param_key": "batch_size_val", "min": 16, "max": 128, "step": 16, "value": 32, + "marks": [ + { + "value": 16, + "label": "16" + }, + { + "value": 32, + "label": "32" + }, + { + "value": 64, + "label": "64" + }, + { + "value": 128, + "label": "128" + } + ], "comp_group": "train_model" }, { "type": "slider", "name": "batch_size_inference", - "title": "Inference Batch Size", + "title": "Batch Size Inference", "param_key": "batch_size_inference", "min": 16, "max": 128, "step": 16, "value": 32, + "marks": [ + { + "value": 16, + "label": "16" + }, + { + "value": 32, + "label": "32" + }, + { + "value": 64, + "label": "64" + }, + { + "value": 128, + "label": "128" + } + ], "comp_group": "prediction_model" } ], @@ -380,11 +438,25 @@ { "type": "slider", "name": "num_epochs", - "title": "Number of epoch", + "title": "# Epochs", "param_key": "num_epochs", "min": 1, "max": 1000, "value": 30, + "marks": [ + { + "value": 1, + "label": "1" + }, + { + "value": 100, + "label": "100" + }, + { + "value": 1000, + "label": "1000" + } + ], "comp_group": "train_model" }, { @@ -393,7 +465,7 @@ "title": "Optimizer", "param_key": "optimizer", "value": "Adam", - "options": [ + "data": [ { "label": "Adadelta", "value": "Adadelta" @@ -447,7 +519,7 @@ "title": "Criterion", "param_key": "criterion", "value": "MSELoss", - "options": [ + "data": [ { "label": "L1Loss", "value": "L1Loss" @@ -557,7 +629,7 @@ "title": "Activation", "param_key": "activation", "value": "ReLU", - "options": [ + "data": [ { "label": "ReLU", "value": "ReLU" @@ -586,28 +658,24 @@ "max": 100, "step": 5, "value": 20, - "marks": { - "0": "0", - "100": "100" - }, + "marks": [ + { + "value": 0, + "label": "0%" + }, + { + "value": 100, + "label": "100%" + } + ], "comp_group": "train_model" }, { - "type": "radio", + "type": "bool", "name": "shuffle_train", "title": "Shuffle Training", "param_key": "shuffle_train", - "value": true, - "options": [ - { - "label": "True", - "value": true - }, - { - "label": "False", - "value": false - } - ], + "checked": true, "comp_group": "train_model" }, { @@ -619,6 +687,24 @@ "max": 128, "step": 16, "value": 32, + "marks": [ + { + "value": 16, + "label": "16" + }, + { + "value": 32, + "label": "32" + }, + { + "value": 64, + "label": "64" + }, + { + "value": 128, + "label": "128" + } + ], "comp_group": "train_model" }, { @@ -630,6 +716,24 @@ "max": 128, "step": 16, "value": 32, + "marks": [ + { + "value": 16, + "label": "16" + }, + { + "value": 32, + "label": "32" + }, + { + "value": 64, + "label": "64" + }, + { + "value": 128, + "label": "128" + } + ], "comp_group": "train_model" }, { @@ -641,6 +745,24 @@ "max": 128, "step": 16, "value": 32, + "marks": [ + { + "value": 16, + "label": "16" + }, + { + "value": 32, + "label": "32" + }, + { + "value": 64, + "label": "64" + }, + { + "value": 128, + "label": "128" + } + ], "comp_group": "prediction_model" } ], @@ -704,11 +826,25 @@ { "type": "slider", "name": "num_epochs", - "title": "Number of epoch", + "title": "# Epochs", "param_key": "num_epochs", "min": 1, "max": 1000, "value": 30, + "marks": [ + { + "value": 1, + "label": "1" + }, + { + "value": 100, + "label": "100" + }, + { + "value": 1000, + "label": "1000" + } + ], "comp_group": "train_model" }, { @@ -717,7 +853,7 @@ "title": "Optimizer", "param_key": "optimizer", "value": "Adam", - "options": [ + "data": [ { "label": "Adadelta", "value": "Adadelta" @@ -771,7 +907,7 @@ "title": "Criterion", "param_key": "criterion", "value": "MSELoss", - "options": [ + "data": [ { "label": "L1Loss", "value": "L1Loss" @@ -881,7 +1017,7 @@ "title": "Activation", "param_key": "activation", "value": "ReLU", - "options": [ + "data": [ { "label": "ReLU", "value": "ReLU" @@ -910,28 +1046,24 @@ "max": 100, "step": 5, "value": 20, - "marks": { - "0": "0", - "100": "100" - }, + "marks": [ + { + "value": 0, + "label": "0%" + }, + { + "value": 100, + "label": "100%" + } + ], "comp_group": "train_model" }, { - "type": "radio", + "type": "bool", "name": "shuffle_train", "title": "Shuffle Training", "param_key": "shuffle_train", - "value": true, - "options": [ - { - "label": "True", - "value": true - }, - { - "label": "False", - "value": false - } - ], + "checked": true, "comp_group": "train_model" }, { @@ -943,6 +1075,24 @@ "max": 128, "step": 16, "value": 32, + "marks": [ + { + "value": 16, + "label": "16" + }, + { + "value": 32, + "label": "32" + }, + { + "value": 64, + "label": "64" + }, + { + "value": 128, + "label": "128" + } + ], "comp_group": "train_model" }, { @@ -954,6 +1104,24 @@ "max": 128, "step": 16, "value": 32, + "marks": [ + { + "value": 16, + "label": "16" + }, + { + "value": 32, + "label": "32" + }, + { + "value": 64, + "label": "64" + }, + { + "value": 128, + "label": "128" + } + ], "comp_group": "train_model" }, { @@ -965,6 +1133,24 @@ "max": 128, "step": 16, "value": 32, + "marks": [ + { + "value": 16, + "label": "16" + }, + { + "value": 32, + "label": "32" + }, + { + "value": 64, + "label": "64" + }, + { + "value": 128, + "label": "128" + } + ], "comp_group": "prediction_model" } ], diff --git a/components/dash_component_editor.py b/components/dash_component_editor.py index 5b232a0..186c987 100644 --- a/components/dash_component_editor.py +++ b/components/dash_component_editor.py @@ -1,7 +1,6 @@ import dash_bootstrap_components as dbc -import dash_daq as daq import dash_mantine_components as dmc -from dash import ALL, Input, Output, State, dcc, html +from dash import ALL, Input, Output, State, html class ControlItem(dmc.Grid): @@ -9,9 +8,9 @@ class ControlItem(dmc.Grid): Customized layout for a control item """ - def __init__(self, title, title_id, item, **kwargs): + def __init__(self, title, title_id, item, style={}): super(ControlItem, self).__init__( - [ + children=[ dmc.Text( title, id=title_id, @@ -21,67 +20,93 @@ def __init__(self, title, title_id, item, **kwargs): ), html.Div(item, style={"width": "265px", "margin": "auto"}), ], - **kwargs, + style=style, ) -class SimpleItem(dbc.Col): +class NumberItem(ControlItem): def __init__( self, name, base_id, title=None, param_key=None, - type="number", - debounce=True, + visible=True, **kwargs, ): if param_key is None: param_key = name - self.label = dbc.Label(title) - self.input = dbc.Input( - type=type, - debounce=debounce, - id={**base_id, "name": name, "param_key": param_key}, + self.input = dmc.NumberInput( + id={**base_id, "name": name, "param_key": param_key, "layer": "input"}, **kwargs, ) - super(SimpleItem, self).__init__(children=[self.label, self.input]) - + style = {} + if not visible: + style["display"] = "none" -class FloatItem(SimpleItem): - pass + super(NumberItem, self).__init__( + title=title, + title_id={ + **base_id, + "name": name, + "param_key": param_key, + "layer": "title", + }, + item=self.input, + style=style, + ) -class IntItem(SimpleItem): - def __init__(self, *args, **kwargs): - if "min" not in kwargs: - kwargs["min"] = -9007199254740991 - super(IntItem, self).__init__(*args, step=1, **kwargs) +class StrItem(ControlItem): + def __init__( + self, + name, + base_id, + title=None, + param_key=None, + visible=True, + **kwargs, + ): + if param_key is None: + param_key = name + self.input = dmc.TextInput( + id={**base_id, "name": name, "param_key": param_key, "layer": "input"}, + **kwargs, + ) + style = {} + if not visible: + style["display"] = "none" -class StrItem(SimpleItem): - def __init__(self, *args, **kwargs): - super(StrItem, self).__init__(*args, type="text", **kwargs) + super(StrItem, self).__init__( + title=title, + title_id={ + **base_id, + "name": name, + "param_key": param_key, + "layer": "title", + }, + item=self.input, + style=style, + ) -class SliderItem(dbc.Col): +class SliderItem(ControlItem): def __init__( self, name, base_id, title=None, param_key=None, - debounce=True, visible=True, **kwargs, ): if param_key is None: param_key = name - self.label = dbc.Label(title) - self.input = dcc.Slider( + self.input = dmc.Slider( id={**base_id, "name": name, "param_key": param_key, "layer": "input"}, - tooltip={"placement": "bottom", "always_visible": True}, + labelAlwaysOn=False, **kwargs, ) @@ -90,27 +115,31 @@ def __init__( style["display"] = "none" super(SliderItem, self).__init__( - id={**base_id, "name": name, "param_key": param_key, "layer": "form_group"}, - children=[self.label, self.input], + title=title, + title_id={ + **base_id, + "name": name, + "param_key": param_key, + "layer": "title", + }, + item=self.input, style=style, ) -class DropdownItem(dbc.Col): +class DropdownItem(ControlItem): def __init__( self, name, base_id, title=None, param_key=None, - debounce=True, visible=True, **kwargs, ): if param_key is None: param_key = name - self.label = dbc.Label(title) - self.input = dcc.Dropdown( + self.input = dmc.Select( id={**base_id, "name": name, "param_key": param_key, "layer": "input"}, **kwargs, ) @@ -120,20 +149,32 @@ def __init__( style["display"] = "none" super(DropdownItem, self).__init__( - id={**base_id, "name": name, "param_key": param_key, "layer": "form_group"}, - children=[self.label, self.input], + title=title, + title_id={ + **base_id, + "name": name, + "param_key": param_key, + "layer": "label", + }, + item=self.input, style=style, ) -class RadioItem(dbc.Col): +class RadioItem(ControlItem): def __init__( self, name, base_id, title=None, param_key=None, visible=True, **kwargs ): if param_key is None: param_key = name - self.label = dbc.Label(title) - self.input = dbc.RadioItems( + + options = [ + dmc.Radio(option["label"], value=option["value"]) + for option in kwargs["options"] + ] + kwargs.pop("options", None) + self.input = dmc.RadioGroup( + options, id={**base_id, "name": name, "param_key": param_key, "layer": "input"}, **kwargs, ) @@ -143,24 +184,30 @@ def __init__( style["display"] = "none" super(RadioItem, self).__init__( - id={**base_id, "name": name, "param_key": param_key, "layer": "form_group"}, - children=[self.label, self.input], + title=title, + title_id={ + **base_id, + "name": name, + "param_key": param_key, + "layer": "label", + }, + item=self.input, style=style, ) -class BoolItem(dbc.Col): +class BoolItem(dmc.Grid): def __init__( self, name, base_id, title=None, param_key=None, visible=True, **kwargs ): if param_key is None: param_key = name - self.label = dbc.Label(title) - self.input = daq.ToggleSwitch( + + self.input = dmc.Switch( id={**base_id, "name": name, "param_key": param_key, "layer": "input"}, + label=title, **kwargs, ) - self.output_label = dbc.Label("False/True") style = {} if not visible: @@ -168,15 +215,15 @@ def __init__( super(BoolItem, self).__init__( id={**base_id, "name": name, "param_key": param_key, "layer": "form_group"}, - children=[self.label, self.input, self.output_label], + children=[self.input, dmc.Space(h=25)], style=style, ) class ParameterEditor(dbc.Form): type_map = { - float: FloatItem, - int: IntItem, + float: NumberItem, + int: NumberItem, str: StrItem, } @@ -235,8 +282,8 @@ def build_children(self, values=None): class JSONParameterEditor(ParameterEditor): type_map = { - "float": FloatItem, - "int": IntItem, + "float": NumberItem, + "int": NumberItem, "str": StrItem, "slider": SliderItem, "dropdown": DropdownItem, From 5b177c467645d5351d9555035cb78e37a5b909c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Wed, 6 Mar 2024 13:34:42 -0800 Subject: [PATCH 16/21] Clean up `Models`, `JSONParameterEditor`, and parameter retrieval --- assets/models.json | 2 +- callbacks/control_bar.py | 23 +++---- callbacks/segmentation.py | 18 ++--- components/control_bar.py | 7 +- components/dash_component_editor.py | 100 ++++++++-------------------- utils/content_registry.py | 31 --------- utils/data_utils.py | 42 ++++++++++++ 7 files changed, 90 insertions(+), 133 deletions(-) delete mode 100644 utils/content_registry.py diff --git a/assets/models.json b/assets/models.json index df59436..f8f8b6e 100755 --- a/assets/models.json +++ b/assets/models.json @@ -22,7 +22,7 @@ { "type": "int", "name": "num_layers", - "title": "Number of Layers", + "title": "# Layers", "param_key": "num_layers", "value": 3, "comp_group": "train_model" diff --git a/callbacks/control_bar.py b/callbacks/control_bar.py index 9b15112..ae3ba28 100644 --- a/callbacks/control_bar.py +++ b/callbacks/control_bar.py @@ -25,11 +25,10 @@ from dash_iconify import DashIconify from components.annotation_class import annotation_class_item -from components.dash_component_editor import JSONParameterEditor +from components.dash_component_editor import ParameterItems from constants import ANNOT_ICONS, ANNOT_NOTIFICATION_MSGS, KEY_MODES, KEYBINDS from utils.annotations import Annotations -from utils.content_registry import models -from utils.data_utils import tiled_dataset +from utils.data_utils import models, tiled_dataset from utils.plot_utils import generate_notification, generate_notification_bg_icon_col # TODO - temporary local file path and user for annotation saving and exporting @@ -918,18 +917,16 @@ def update_current_annotated_slices_values(all_classes): @callback( - Output("gui-layouts", "children"), + Output("model-parameters", "children"), Input("model-list", "value"), ) def update_model_parameters(model_name): - data = models.models[model_name] - if data["gui_parameters"]: - item_list = JSONParameterEditor( - _id={"type": str(uuid.uuid4())}, - json_blob=models.remove_key_from_dict_list( - data["gui_parameters"], "comp_group" - ), + model = models[model_name] + if model["gui_parameters"]: + # TODO: Retain old parameters if they exist + item_list = ParameterItems( + _id={"type": str(uuid.uuid4())}, json_blob=model["gui_parameters"] ) - return [item_list] + return item_list else: - return [""] + return html.Div("Model has no parameters") diff --git a/callbacks/segmentation.py b/callbacks/segmentation.py index c477bcb..2029b18 100644 --- a/callbacks/segmentation.py +++ b/callbacks/segmentation.py @@ -7,7 +7,7 @@ from dash import ALL, Input, Output, State, callback, no_update from dash.exceptions import PreventUpdate -from utils.data_utils import tiled_dataset +from utils.data_utils import extract_parameters_from_html, tiled_dataset MODE = os.getenv("MODE", "") @@ -58,14 +58,14 @@ @callback( Output("output-details", "children"), Output("submitted-job-id", "data"), - Output("gui-components-values", "data"), + Output("model-parameter-values", "data"), Input("run-model", "n_clicks"), State("annotation-store", "data"), State({"type": "annotation-class-store", "index": ALL}, "data"), State("project-name-src", "value"), - State("gui-layouts", "children"), + State("model-parameters", "children"), ) -def run_job(n_clicks, global_store, all_annotations, project_name, children): +def run_job(n_clicks, global_store, all_annotations, project_name, model_parameters): """ This callback collects parameters from the UI and submits a job to the computing api. If the app is run from "dev" mode, then only a placeholder job_uid will be created. @@ -76,14 +76,8 @@ def run_job(n_clicks, global_store, all_annotations, project_name, children): """ input_params = {} if n_clicks: - if len(children) >= 2: - params = children[1] - for param in params["props"]["children"]: - key = param["props"]["children"][1]["props"]["id"]["param_key"] - value = param["props"]["children"][1]["props"]["value"] - input_params[key] = value - - # return the input values in dictionary and saved to dcc.Store "gui-components-values" + input_params = extract_parameters_from_html(model_parameters) + # return the input values in dictionary and save to the model parameter store print(f"input_param:\n{input_params}") if MODE == "dev": job_uid = str(uuid.uuid4()) diff --git a/components/control_bar.py b/components/control_bar.py index 4ca57e7..75a46ea 100644 --- a/components/control_bar.py +++ b/components/control_bar.py @@ -7,8 +7,7 @@ from components.annotation_class import annotation_class_item from components.dash_component_editor import ControlItem from constants import ANNOT_ICONS, KEYBINDS -from utils.content_registry import models -from utils.data_utils import tiled_dataset +from utils.data_utils import models, tiled_dataset def _tooltip(text, children): @@ -605,8 +604,8 @@ def layout(): ), ), dmc.Space(h=25), - html.Div(id="gui-layouts"), - dcc.Store(id="gui-components-values", data={}), + html.Div(id="model-parameters"), + dcc.Store(id="model-parameter-values", data={}), dmc.Space(h=25), dmc.Center( dmc.Button( diff --git a/components/dash_component_editor.py b/components/dash_component_editor.py index 186c987..8404820 100644 --- a/components/dash_component_editor.py +++ b/components/dash_component_editor.py @@ -1,6 +1,6 @@ import dash_bootstrap_components as dbc import dash_mantine_components as dmc -from dash import ALL, Input, Output, State, html +from dash import html class ControlItem(dmc.Grid): @@ -51,7 +51,7 @@ def __init__( **base_id, "name": name, "param_key": param_key, - "layer": "title", + "layer": "label", }, item=self.input, style=style, @@ -85,7 +85,7 @@ def __init__( **base_id, "name": name, "param_key": param_key, - "layer": "title", + "layer": "label", }, item=self.input, style=style, @@ -120,7 +120,7 @@ def __init__( **base_id, "name": name, "param_key": param_key, - "layer": "title", + "layer": "label", }, item=self.input, style=style, @@ -196,7 +196,7 @@ def __init__( ) -class BoolItem(dmc.Grid): +class BoolItem(ControlItem): def __init__( self, name, base_id, title=None, param_key=None, visible=True, **kwargs ): @@ -214,45 +214,33 @@ def __init__( style["display"] = "none" super(BoolItem, self).__init__( - id={**base_id, "name": name, "param_key": param_key, "layer": "form_group"}, - children=[self.input, dmc.Space(h=25)], + title="", # title is already in the switch + title_id={ + **base_id, + "name": name, + "param_key": param_key, + "layer": "label", + }, + item=self.input, style=style, ) -class ParameterEditor(dbc.Form): +class ParameterItems(dbc.Form): type_map = { - float: NumberItem, - int: NumberItem, - str: StrItem, + "float": NumberItem, + "int": NumberItem, + "str": StrItem, + "slider": SliderItem, + "dropdown": DropdownItem, + "radio": RadioItem, + "bool": BoolItem, } - def __init__(self, _id, parameters, **kwargs): - self._parameters = parameters - - super(ParameterEditor, self).__init__( - id=_id, children=[], className="kwarg-editor", **kwargs - ) - self.children = self.build_children() - - def init_callbacks(self, app): - app.callback( - Output(self.id, "n_submit"), - Input({**self.id, "name": ALL}, "value"), - State(self.id, "n_submit"), - ) - - for child in self.children: - if hasattr(child, "init_callbacks"): - child.init_callbacks(app) - - @property - def values(self): - return {param["name"]: param.get("value", None) for param in self._parameters} - - @property - def parameters(self): - return {param["name"]: param for param in self._parameters} + def __init__(self, _id, json_blob, values=None): + super(ParameterItems, self).__init__(id=_id, children=[]) + self._json_blob = json_blob + self.children = self.build_children(values=values) def _determine_type(self, parameter_dict): if "type" in parameter_dict: @@ -266,49 +254,17 @@ def _determine_type(self, parameter_dict): f"No item type could be determined for this parameter: {parameter_dict}" ) - def build_children(self, values=None): - children = [] - for parameter_dict in self._parameters: - parameter_dict = parameter_dict.copy() - if values and parameter_dict["name"] in values: - parameter_dict["value"] = values[parameter_dict["name"]] - type = self._determine_type(parameter_dict) - parameter_dict.pop("type", None) - item = self.type_map[type](**parameter_dict, base_id=self.id) - children.append(item) - - return children - - -class JSONParameterEditor(ParameterEditor): - type_map = { - "float": NumberItem, - "int": NumberItem, - "str": StrItem, - "slider": SliderItem, - "dropdown": DropdownItem, - "radio": RadioItem, - "bool": BoolItem, - } - - def __init__(self, _id, json_blob, **kwargs): - super(ParameterEditor, self).__init__( - id=_id, children=[], className="kwarg-editor", **kwargs - ) - self._json_blob = json_blob - self.children = self.build_children() - def build_children(self, values=None): children = [] for json_record in self._json_blob: - ... - # build a parameter dict from self.json_blob - ... + # Build a parameter dict from self.json_blob type = json_record.get("type", self._determine_type(json_record)) json_record = json_record.copy() if values and json_record["name"] in values: json_record["value"] = values[json_record["name"]] json_record.pop("type", None) + if "comp_group" in json_record: + json_record.pop("comp_group", None) item = self.type_map[type](**json_record, base_id=self.id) children.append(item) diff --git a/utils/content_registry.py b/utils/content_registry.py deleted file mode 100644 index ed9203c..0000000 --- a/utils/content_registry.py +++ /dev/null @@ -1,31 +0,0 @@ -import json -from copy import deepcopy - - -class Models: - def __init__(self, modelfile_path="./assets/models.json"): - self.path = modelfile_path - f = open(self.path) - - self.contents = json.load(f)["contents"] - self.modelname_list = [content["model_name"] for content in self.contents] - self.models = {} - - for i, n in enumerate(self.modelname_list): - self.models[n] = self.contents[i] - - @staticmethod - def remove_key_from_dict_list(data, key): - new_data = [] - for item in data: - if key in item: - new_item = deepcopy(item) - new_item.pop(key) - new_data.append(new_item) - else: - new_data.append(item) - - return new_data - - -models = Models() diff --git a/utils/data_utils.py b/utils/data_utils.py index bb54d29..43a24e9 100644 --- a/utils/data_utils.py +++ b/utils/data_utils.py @@ -159,3 +159,45 @@ def save_annotations_data(self, global_store, all_annotations, project_name): tiled_dataset = TiledDataLoader() + + +class Models: + def __init__(self, modelfile_path="./assets/models.json"): + self.path = modelfile_path + f = open(self.path) + + contents = json.load(f)["contents"] + self.modelname_list = [content["model_name"] for content in contents] + self.models = {} + + for i, n in enumerate(self.modelname_list): + self.models[n] = contents[i] + + def __getitem__(self, key): + try: + return self.models[key] + except KeyError: + raise KeyError(f"A model with name {key} does not exist.") + + +models = Models() + + +def extract_parameters_from_html(model_parameters_html): + """ + Extracts parameters from the children component of a + """ + input_params = {} + for param in model_parameters_html["props"]["children"]: + # param["props"]["children"][0] is the label + # param["props"]["children"][1] is the input + parameter_container = param["props"]["children"][1] + # The achtual parameter item is the first and only child of the parameter container + parameter_item = parameter_container["props"]["children"]["props"] + key = parameter_item["id"]["param_key"] + if "value" in parameter_item: + value = parameter_item["value"] + elif "checked" in parameter_item: + value = parameter_item["checked"] + input_params[key] = value + return input_params From 790d4efc35bb96ce2995d82c37fa27ded0df893f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Wed, 6 Mar 2024 13:41:17 -0800 Subject: [PATCH 17/21] Rename `dash_component_editor` to better represent new structure --- callbacks/control_bar.py | 2 +- components/control_bar.py | 2 +- components/{dash_component_editor.py => parameter_items.py} | 0 3 files changed, 2 insertions(+), 2 deletions(-) rename components/{dash_component_editor.py => parameter_items.py} (100%) diff --git a/callbacks/control_bar.py b/callbacks/control_bar.py index ae3ba28..6011b29 100644 --- a/callbacks/control_bar.py +++ b/callbacks/control_bar.py @@ -25,7 +25,7 @@ from dash_iconify import DashIconify from components.annotation_class import annotation_class_item -from components.dash_component_editor import ParameterItems +from components.parameter_items import ParameterItems from constants import ANNOT_ICONS, ANNOT_NOTIFICATION_MSGS, KEY_MODES, KEYBINDS from utils.annotations import Annotations from utils.data_utils import models, tiled_dataset diff --git a/components/control_bar.py b/components/control_bar.py index 75a46ea..2e09354 100644 --- a/components/control_bar.py +++ b/components/control_bar.py @@ -5,7 +5,7 @@ from dash_iconify import DashIconify from components.annotation_class import annotation_class_item -from components.dash_component_editor import ControlItem +from components.parameter_items import ControlItem from constants import ANNOT_ICONS, KEYBINDS from utils.data_utils import models, tiled_dataset diff --git a/components/dash_component_editor.py b/components/parameter_items.py similarity index 100% rename from components/dash_component_editor.py rename to components/parameter_items.py From 6193c53a0d221798edda5d334bc75f1c3b2b2e4e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Wed, 6 Mar 2024 15:51:07 -0800 Subject: [PATCH 18/21] Check if `image_shapes` was initialized --- utils/annotations.py | 4 ++-- utils/data_utils.py | 13 ++++++++++++- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/utils/annotations.py b/utils/annotations.py index c2eafcd..e7875f1 100644 --- a/utils/annotations.py +++ b/utils/annotations.py @@ -11,7 +11,7 @@ class Annotations: - def __init__(self, annotation_store, global_store): + def __init__(self, annotation_store, image_shape): if annotation_store: slices = [] for annotation_class in annotation_store: @@ -49,7 +49,7 @@ def __init__(self, annotation_store, global_store): self.annotation_classes = annotation_classes self.annotations = annotations self.annotations_hash = self.get_annotations_hash() - self.image_shape = global_store["image_shapes"][0] + self.image_shape = image_shape def get_annotations(self): return self.annotations diff --git a/utils/data_utils.py b/utils/data_utils.py index 87cc635..3311019 100644 --- a/utils/data_utils.py +++ b/utils/data_utils.py @@ -168,7 +168,18 @@ def save_annotations_data(self, global_store, all_annotations, project_name): """ Transforms annotations data to a pixelated mask and outputs to the Tiled server """ - annotations = Annotations(all_annotations, global_store) + if "image_shapes" in global_store: + image_shape = global_store["image_shapes"][0] + else: + print("Global store was not filled.") + data_shape = ( + tiled_datasets.get_data_shape_by_name(project_name) + if project_name + else None + ) + image_shape = (data_shape[1], data_shape[2]) + + annotations = Annotations(all_annotations, image_shape) # TODO: Check sparse status, it may be worthwhile to store the mask as a sparse array # if our machine learning models can handle sparse arrays annotations.create_annotation_mask(sparse=False) From d09a45a6e27e43b0d19be69bd463b0607a5528d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Wed, 6 Mar 2024 16:31:38 -0800 Subject: [PATCH 19/21] Give the generated parameters some space --- components/parameter_items.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/components/parameter_items.py b/components/parameter_items.py index 8404820..f165125 100644 --- a/components/parameter_items.py +++ b/components/parameter_items.py @@ -41,7 +41,7 @@ def __init__( **kwargs, ) - style = {} + style = {"padding": "15px 0px 0px 0px"} if not visible: style["display"] = "none" @@ -75,7 +75,7 @@ def __init__( **kwargs, ) - style = {} + style = {"padding": "15px 0px 0px 0px"} if not visible: style["display"] = "none" @@ -110,7 +110,7 @@ def __init__( **kwargs, ) - style = {} + style = {"padding": "15px 0px 15px 0px"} if not visible: style["display"] = "none" @@ -144,7 +144,7 @@ def __init__( **kwargs, ) - style = {} + style = {"padding": "15px 0px 0px 0px"} if not visible: style["display"] = "none" @@ -179,7 +179,7 @@ def __init__( **kwargs, ) - style = {} + style = {"padding": "15px 0px 0px 0px"} if not visible: style["display"] = "none" @@ -209,7 +209,7 @@ def __init__( **kwargs, ) - style = {} + style = {"padding": "15px 0px 0px 0px"} if not visible: style["display"] = "none" From 8e6e4883da2b118b32db3cb43edcec45fa97dbff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Wed, 6 Mar 2024 16:32:15 -0800 Subject: [PATCH 20/21] :whale: Add missing environment variables --- docker-compose.yml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docker-compose.yml b/docker-compose.yml index c0a7f2b..5f1a7d5 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -8,6 +8,12 @@ services: environment: DATA_TILED_URI: '${DATA_TILED_URI}' DATA_TILED_API_KEY: '${DATA_TILED_API_KEY}' + MASK_TILED_URI: '${MASK_TILED_URI}' + MASK_TILED_API_KEY: '${TILED_API_KEY}' + SEG_TILED_URI: '${SEG_TILED_URI}' + SEG_TILED_API_KEY: '${SEG_TILED_API_KEY}' + USER_NAME: '${USER_NAME}' + USER_PASSWORD: '${USER_PASSWORD}' volumes: - ./app.py:/app/app.py - ./constants.py:/app/constants.py From 02732a622565ae15e1c42988b7adcce7b39a2a16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wiebke=20K=C3=B6pp?= Date: Wed, 6 Mar 2024 16:35:53 -0800 Subject: [PATCH 21/21] Change default activation and learning rate step --- assets/models.json | 18 +++++++++++++++--- components/control_bar.py | 2 +- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/assets/models.json b/assets/models.json index f8f8b6e..3b22cb1 100755 --- a/assets/models.json +++ b/assets/models.json @@ -237,6 +237,10 @@ "title": "Learning Rate", "param_key": "learning_rate", "value": 0.001, + "min": 0, + "max": 0.1, + "step": 0.0001, + "precision": 4, "comp_group": "train_model" }, { @@ -244,7 +248,7 @@ "name": "activation", "title": "Activation", "param_key": "activation", - "value": "ReLU", + "value": "Sigmoid", "data": [ { "label": "ReLU", @@ -621,6 +625,10 @@ "title": "Learning Rate", "param_key": "learning_rate", "value": 0.001, + "min": 0, + "max": 0.1, + "step": 0.0001, + "precision": 4, "comp_group": "train_model" }, { @@ -628,7 +636,7 @@ "name": "activation", "title": "Activation", "param_key": "activation", - "value": "ReLU", + "value": "Sigmoid", "data": [ { "label": "ReLU", @@ -1009,6 +1017,10 @@ "title": "Learning Rate", "param_key": "learning_rate", "value": 0.001, + "min": 0, + "max": 0.1, + "step": 0.0001, + "precision": 4, "comp_group": "train_model" }, { @@ -1016,7 +1028,7 @@ "name": "activation", "title": "Activation", "param_key": "activation", - "value": "ReLU", + "value": "Sigmoid", "data": [ { "label": "ReLU", diff --git a/components/control_bar.py b/components/control_bar.py index c1e538b..4809e5f 100644 --- a/components/control_bar.py +++ b/components/control_bar.py @@ -603,7 +603,7 @@ def layout(): placeholder="Select a model...", ), ), - dmc.Space(h=25), + dmc.Space(h=15), html.Div(id="model-parameters"), dcc.Store(id="model-parameter-values", data={}), dmc.Space(h=25),