Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GUI for webinar #4

Merged
merged 14 commits into from
Jun 4, 2020
248 changes: 199 additions & 49 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,93 +2,243 @@
import dash_core_components as dcc
import dash_html_components as html
from dash.dependencies import Input, Output, State

import pandas
import base64
import json
import os
import matplotlib.cm
import matplotlib.colors as mcolors
import numpy as np
import random
import plotly.graph_objects as go
import plotly.express as px
from skimage import data, transform


COLORMAP = 'plasma'
KEYPOINTS = ['Nose', 'L_Eye', 'R_Eye', 'L_Ear', 'R_Ear', 'Throat',
'Withers', 'TailSet', 'L_F_Paw', 'R_F_Paw', 'L_F_Wrist',
'R_F_Wrist', 'L_F_Elbow', 'R_F_Elbow', 'L_B_Paw', 'R_B_Paw',
'L_B_Hock', 'R_B_Hock', 'L_B_Stiffle', 'R_B_Stiffle']
N_SUBSET = 3
IMAGE_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'full_dog.png')
encoded_image = base64.b64encode(open(IMAGE_PATH, 'rb').read())


img = data.chelsea()
img = img[::2, ::2]
images = [img, img[::-1], transform.rotate(img, 30)]
cmap = matplotlib.cm.get_cmap(COLORMAP, N_SUBSET)


def make_figure_image(i):
fig = px.imshow(images[i % len(images)])
fig.update_traces(hoverinfo='none')
fig.add_trace(go.Scatter(x=[], y=[], marker_color=[],
marker_cmin=0, marker_cmax=3, marker_size=18, mode='markers'))
fig.layout.xaxis.showticklabels = False
fig.layout.yaxis.showticklabels = False
fig.update_traces(hoverinfo='none', hovertemplate='')
return fig


def draw_circle(center, radius, n_points=50):
pts = np.linspace(0, 2 * np.pi, n_points)
x = center[0] + radius * np.cos(pts)
y = center[1] + radius * np.sin(pts)
path = 'M ' + str(x[0]) + ',' + str(y[1])
for k in range(1, x.shape[0]):
path += ' L ' + str(x[k]) + ',' + str(y[k])
path += ' Z'
return path


def compute_circle_center(path):
"""
See Eqn 1 & 2 pp.12-13 in REGRESSIONS CONIQUES, QUADRIQUES
Régressions linéaires et apparentées, circulaire, sphérique
Jacquelin J., 2009.
"""
coords = [list(map(float, coords.split(','))) for coords in path.split(' ')[1::2]]
x, y = np.array(coords).T
n = len(x)
sum_x = np.sum(x)
sum_y = np.sum(y)
sum_x2 = np.sum(x * x)
sum_y2 = np.sum(y * y)
delta11 = n * np.dot(x, y) - sum_x * sum_y
delta20 = n * sum_x2 - sum_x ** 2
delta02 = n * sum_y2 - sum_y ** 2
delta30 = n * np.sum(x ** 3) - sum_x2 * sum_x
delta03 = n * np.sum(y ** 3) - sum_y * sum_y2
delta21 = n * np.sum(x * x * y) - sum_x2 * sum_y
delta12 = n * np.sum(x * y * y) - sum_x * sum_y2

# Eqn 2, p.13
num_a = (delta30 + delta12) * delta02 - (delta03 + delta21) * delta11
num_b = (delta03 + delta21) * delta20 - (delta30 + delta12) * delta11
den = 2 * (delta20 * delta02 - delta11 * delta11)
a = num_a / den
b = num_b / den
return a, b


def get_plotly_color(n):
return mcolors.to_hex(cmap(n))


fig = make_figure_image(0)

external_stylesheets = ['https://codepen.io/chriddyp/pen/bWLwgP.css']

app = dash.Dash(__name__, external_stylesheets=external_stylesheets)
server = app.server

options = ['left eye', 'right eye', 'nose']
options = random.sample(KEYPOINTS, N_SUBSET)

styles = {
'pre': {
'border': 'thin lightgrey solid',
'overflowX': 'scroll'
}
}

app.layout = html.Div([
html.Div([
dcc.Graph(
id='basic-interactions',
config={'editable':True},
figure=fig,
)],
className="six columns"
id='canvas',
config={'editable': True},
figure=fig)
],
className="six columns"
),
html.Div([
html.H2("Controls"),
dcc.RadioItems(id='radio',
options=[{'label':opt, 'value':opt} for opt in options],
value=options[0]
),
options=[{'label': opt, 'value': opt} for opt in options],
value=options[0]
),
html.Button('Previous', id='previous'),
html.Button('Next', id='next'),
dcc.Store(id='store', data=0)
],
className="six columns"
html.Button('Clear', id='clear'),
html.Button('Save', id='save'),
dcc.Store(id='store', data=0),
html.P([
html.Label('Keypoint size'),
dcc.Slider(id='slider',
min=3,
max=36,
step=1,
value=12)
], style={'width': '80%',
'display': 'inline-block'})
],
className="six columns"
),
html.Div([
dcc.Markdown("""
**Instructions**\n
Click on the image to add a keypoint.
"""),
html.Pre(id='click-data', style=styles['pre']),
html.Img(src='data:image/png;charset=utf-8;base64,{}'.format(encoded_image))
],
className='six columns'
),
])
html.Div(id='placeholder', style={'display': 'none'}),
html.Div(id='shapes', style={'display': 'none'})
]
)


@app.callback(Output('placeholder', 'children'),
[Input('save', 'n_clicks')],
[State('store', 'data')])
def save_data(click_s, ind_image):
if click_s:
xy = {shape.name: compute_circle_center(shape.path) for shape in fig.layout.shapes}
print(xy, ind_image)


@app.callback(
[Output('basic-interactions', 'figure'),
Output('store', 'data')],
[Input('next', 'n_clicks'),
Input('previous', 'n_clicks')],
[State('store', 'data')]
[Output('canvas', 'figure'),
Output('radio', 'value'),
Output('store', 'data'),
Output('shapes', 'children')],
[Input('canvas', 'clickData'),
Input('canvas', 'relayoutData'),
Input('next', 'n_clicks'),
Input('previous', 'n_clicks'),
Input('clear', 'n_clicks'),
Input('slider', 'value')],
[State('canvas', 'figure'),
State('radio', 'value'),
State('store', 'data'),
State('shapes', 'children')]
)
def display_click_data(n_clicks_n, n_clicks_p, val):
if n_clicks_n is None and n_clicks_p is None:
return dash.no_update, dash.no_update
if val is None:
val = 0
def update_image(clickData, relayoutData, click_n, click_p, click_c, slider_val,
figure, option, ind_image, shapes):
if not any(event for event in (clickData, click_n, click_p, click_c)):
return dash.no_update, dash.no_update, dash.no_update, dash.no_update

if ind_image is None:
ind_image = 0

if shapes is None:
shapes = []
else:
shapes = json.loads(shapes)
n_bpt = options.index(option)

ctx = dash.callback_context
button_id = ctx.triggered[0]['prop_id'].split('.')[0]
index = val + 1 if button_id == 'next' else val - 1
fig = make_figure_image(index)
return fig, index

if button_id == 'clear':
fig.layout.shapes = []
return make_figure_image(ind_image), options[0], ind_image, '[]'
elif button_id == 'next':
ind_image = (ind_image + 1) % len(images)
return make_figure_image(ind_image), options[0], ind_image, '[]'
elif button_id == 'previous':
ind_image = (ind_image - 1) % len(images)
return make_figure_image(ind_image), options[0], ind_image, '[]'
elif button_id == 'slider':
for i in range(len(shapes)):
center = compute_circle_center(shapes[i]['path'])
new_path = draw_circle(center, slider_val)
shapes[i]['path'] = new_path

@app.callback(
[Output('basic-interactions', 'extendData'),
Output('radio', 'value')],
[Input('basic-interactions', 'clickData')],
[State('radio', 'value')]
)
def display_click_data(clickData, option):
if clickData is None or fig is None:
return dash.no_update, dash.no_update
if clickData is None or fig is None:
return dash.no_update
x, y = clickData['points'][0]['x'], clickData['points'][0]['y']
for i, el in enumerate(options):
if el == option:
new_option = options[(i+1)%(len(options))]
color=i
return [{'x':[[x]], 'y':[[y]], "marker.color":[[color]]}, [1]], new_option
already_labeled = [shape['name'] for shape in shapes]
key = list(relayoutData)[0]
if option not in already_labeled and button_id != 'slider':
if clickData:
x, y = clickData['points'][0]['x'], clickData['points'][0]['y']
circle = draw_circle((x, y), slider_val)
color = get_plotly_color(n_bpt)
shape = dict(type='path',
path=circle,
line_color=color,
fillcolor=color,
layer='above',
opacity=0.8,
name=option)
shapes.append(shape)
else:
if 'path' in key and button_id != 'slider':
ind_moving = int(key.split('[')[1].split(']')[0])
path = relayoutData.pop(key)
shapes[ind_moving]['path'] = path
fig.update_layout(shapes=shapes)
if 'range[' in key:
xrange = relayoutData['xaxis.range[0]'], relayoutData['xaxis.range[1]']
yrange = relayoutData['yaxis.range[0]'], relayoutData['yaxis.range[1]']
fig.update_xaxes(range=xrange, autorange=False)
fig.update_yaxes(range=yrange, autorange=False)
elif 'autorange' in key:
fig.update_xaxes(autorange=True)
fig.update_yaxes(autorange=True)
if button_id != 'slider':
n_bpt += 1
new_option = options[min(len(options) - 1, n_bpt)]
return ({'data': figure['data'], 'layout': fig['layout']},
new_option,
ind_image,
json.dumps(shapes))


if __name__ == '__main__':
Expand Down