diff --git a/app.py b/app.py index 686b4a0..0c89058 100644 --- a/app.py +++ b/app.py @@ -2,24 +2,87 @@ import dash_core_components as dcc import dash_html_components as html from dash.dependencies import Input, Output, State - -import pandas +import base64 +import json +import os +import matplotlib.cm +import matplotlib.colors as mcolors +import numpy as np +import random import plotly.graph_objects as go import plotly.express as px from skimage import data, transform + +COLORMAP = 'plasma' +KEYPOINTS = ['Nose', 'L_Eye', 'R_Eye', 'L_Ear', 'R_Ear', 'Throat', + 'Withers', 'TailSet', 'L_F_Paw', 'R_F_Paw', 'L_F_Wrist', + 'R_F_Wrist', 'L_F_Elbow', 'R_F_Elbow', 'L_B_Paw', 'R_B_Paw', + 'L_B_Hock', 'R_B_Hock', 'L_B_Stiffle', 'R_B_Stiffle'] +N_SUBSET = 3 +IMAGE_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'full_dog.png') +encoded_image = base64.b64encode(open(IMAGE_PATH, 'rb').read()) + + img = data.chelsea() img = img[::2, ::2] images = [img, img[::-1], transform.rotate(img, 30)] +cmap = matplotlib.cm.get_cmap(COLORMAP, N_SUBSET) def make_figure_image(i): fig = px.imshow(images[i % len(images)]) - fig.update_traces(hoverinfo='none') - fig.add_trace(go.Scatter(x=[], y=[], marker_color=[], - marker_cmin=0, marker_cmax=3, marker_size=18, mode='markers')) + fig.layout.xaxis.showticklabels = False + fig.layout.yaxis.showticklabels = False + fig.update_traces(hoverinfo='none', hovertemplate='') return fig + +def draw_circle(center, radius, n_points=50): + pts = np.linspace(0, 2 * np.pi, n_points) + x = center[0] + radius * np.cos(pts) + y = center[1] + radius * np.sin(pts) + path = 'M ' + str(x[0]) + ',' + str(y[1]) + for k in range(1, x.shape[0]): + path += ' L ' + str(x[k]) + ',' + str(y[k]) + path += ' Z' + return path + + +def compute_circle_center(path): + """ + See Eqn 1 & 2 pp.12-13 in REGRESSIONS CONIQUES, QUADRIQUES + Régressions linéaires et apparentées, circulaire, sphérique + Jacquelin J., 2009. + """ + coords = [list(map(float, coords.split(','))) for coords in path.split(' ')[1::2]] + x, y = np.array(coords).T + n = len(x) + sum_x = np.sum(x) + sum_y = np.sum(y) + sum_x2 = np.sum(x * x) + sum_y2 = np.sum(y * y) + delta11 = n * np.dot(x, y) - sum_x * sum_y + delta20 = n * sum_x2 - sum_x ** 2 + delta02 = n * sum_y2 - sum_y ** 2 + delta30 = n * np.sum(x ** 3) - sum_x2 * sum_x + delta03 = n * np.sum(y ** 3) - sum_y * sum_y2 + delta21 = n * np.sum(x * x * y) - sum_x2 * sum_y + delta12 = n * np.sum(x * y * y) - sum_x * sum_y2 + + # Eqn 2, p.13 + num_a = (delta30 + delta12) * delta02 - (delta03 + delta21) * delta11 + num_b = (delta03 + delta21) * delta20 - (delta30 + delta12) * delta11 + den = 2 * (delta20 * delta02 - delta11 * delta11) + a = num_a / den + b = num_b / den + return a, b + + +def get_plotly_color(n): + return mcolors.to_hex(cmap(n)) + + fig = make_figure_image(0) external_stylesheets = ['https://codepen.io/chriddyp/pen/bWLwgP.css'] @@ -27,68 +90,155 @@ def make_figure_image(i): app = dash.Dash(__name__, external_stylesheets=external_stylesheets) server = app.server -options = ['left eye', 'right eye', 'nose'] +options = random.sample(KEYPOINTS, N_SUBSET) + +styles = { + 'pre': { + 'border': 'thin lightgrey solid', + 'overflowX': 'scroll' + } +} app.layout = html.Div([ html.Div([ dcc.Graph( - id='basic-interactions', - config={'editable':True}, - figure=fig, - )], - className="six columns" + id='canvas', + config={'editable': True}, + figure=fig) + ], + className="six columns" ), html.Div([ html.H2("Controls"), dcc.RadioItems(id='radio', - options=[{'label':opt, 'value':opt} for opt in options], - value=options[0] - ), + options=[{'label': opt, 'value': opt} for opt in options], + value=options[0] + ), html.Button('Previous', id='previous'), html.Button('Next', id='next'), - dcc.Store(id='store', data=0) - ], - className="six columns" + html.Button('Clear', id='clear'), + html.Button('Save', id='save'), + dcc.Store(id='store', data=0), + html.P([ + html.Label('Keypoint size'), + dcc.Slider(id='slider', + min=3, + max=36, + step=1, + value=12) + ], style={'width': '80%', + 'display': 'inline-block'}) + ], + className="six columns" + ), + html.Div([ + dcc.Markdown(""" + **Instructions**\n + Click on the image to add a keypoint. + """), + html.Pre(id='click-data', style=styles['pre']), + html.Img(src='data:image/png;charset=utf-8;base64,{}'.format(encoded_image)) + ], + className='six columns' ), - ]) + html.Div(id='placeholder', style={'display': 'none'}), + html.Div(id='shapes', style={'display': 'none'}) +] +) + + +@app.callback(Output('placeholder', 'children'), + [Input('save', 'n_clicks')], + [State('store', 'data')]) +def save_data(click_s, ind_image): + if click_s: + xy = {shape.name: compute_circle_center(shape.path) for shape in fig.layout.shapes} + print(xy, ind_image) @app.callback( - [Output('basic-interactions', 'figure'), - Output('store', 'data')], - [Input('next', 'n_clicks'), - Input('previous', 'n_clicks')], - [State('store', 'data')] + [Output('canvas', 'figure'), + Output('radio', 'value'), + Output('store', 'data'), + Output('shapes', 'children')], + [Input('canvas', 'clickData'), + Input('canvas', 'relayoutData'), + Input('next', 'n_clicks'), + Input('previous', 'n_clicks'), + Input('clear', 'n_clicks'), + Input('slider', 'value')], + [State('canvas', 'figure'), + State('radio', 'value'), + State('store', 'data'), + State('shapes', 'children')] ) -def display_click_data(n_clicks_n, n_clicks_p, val): - if n_clicks_n is None and n_clicks_p is None: - return dash.no_update, dash.no_update - if val is None: - val = 0 +def update_image(clickData, relayoutData, click_n, click_p, click_c, slider_val, + figure, option, ind_image, shapes): + if not any(event for event in (clickData, click_n, click_p, click_c)): + return dash.no_update, dash.no_update, dash.no_update, dash.no_update + + if ind_image is None: + ind_image = 0 + + if shapes is None: + shapes = [] + else: + shapes = json.loads(shapes) + n_bpt = options.index(option) + ctx = dash.callback_context button_id = ctx.triggered[0]['prop_id'].split('.')[0] - index = val + 1 if button_id == 'next' else val - 1 - fig = make_figure_image(index) - return fig, index - + if button_id == 'clear': + fig.layout.shapes = [] + return make_figure_image(ind_image), options[0], ind_image, '[]' + elif button_id == 'next': + ind_image = (ind_image + 1) % len(images) + return make_figure_image(ind_image), options[0], ind_image, '[]' + elif button_id == 'previous': + ind_image = (ind_image - 1) % len(images) + return make_figure_image(ind_image), options[0], ind_image, '[]' + elif button_id == 'slider': + for i in range(len(shapes)): + center = compute_circle_center(shapes[i]['path']) + new_path = draw_circle(center, slider_val) + shapes[i]['path'] = new_path -@app.callback( - [Output('basic-interactions', 'extendData'), - Output('radio', 'value')], - [Input('basic-interactions', 'clickData')], - [State('radio', 'value')] - ) -def display_click_data(clickData, option): - if clickData is None or fig is None: - return dash.no_update, dash.no_update - if clickData is None or fig is None: - return dash.no_update - x, y = clickData['points'][0]['x'], clickData['points'][0]['y'] - for i, el in enumerate(options): - if el == option: - new_option = options[(i+1)%(len(options))] - color=i - return [{'x':[[x]], 'y':[[y]], "marker.color":[[color]]}, [1]], new_option + already_labeled = [shape['name'] for shape in shapes] + key = list(relayoutData)[0] + if option not in already_labeled and button_id != 'slider': + if clickData: + x, y = clickData['points'][0]['x'], clickData['points'][0]['y'] + circle = draw_circle((x, y), slider_val) + color = get_plotly_color(n_bpt) + shape = dict(type='path', + path=circle, + line_color=color, + fillcolor=color, + layer='above', + opacity=0.8, + name=option) + shapes.append(shape) + else: + if 'path' in key and button_id != 'slider': + ind_moving = int(key.split('[')[1].split(']')[0]) + path = relayoutData.pop(key) + shapes[ind_moving]['path'] = path + fig.update_layout(shapes=shapes) + if 'range[' in key: + xrange = relayoutData['xaxis.range[0]'], relayoutData['xaxis.range[1]'] + yrange = relayoutData['yaxis.range[0]'], relayoutData['yaxis.range[1]'] + fig.update_xaxes(range=xrange, autorange=False) + fig.update_yaxes(range=yrange, autorange=False) + elif 'autorange' in key: + fig.update_xaxes(autorange=True) + fig.update_yaxes(autorange=True) + if button_id != 'slider': + n_bpt += 1 + new_option = options[min(len(options) - 1, n_bpt)] + return ({'data': figure['data'], 'layout': fig['layout']}, + new_option, + ind_image, + json.dumps(shapes)) if __name__ == '__main__':