Skip to content

Commit

Permalink
Added annotations to all of the plots per Joey's suggestions.
Browse files Browse the repository at this point in the history
  • Loading branch information
cmccully committed Jun 6, 2024
1 parent 6d474bf commit 8263582
Showing 1 changed file with 119 additions and 44 deletions.
163 changes: 119 additions & 44 deletions banzai_floyds_ui/gui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,12 @@ def callback_dropdown_files(*args, **kwargs):
return results


def make_arc_2d_plot(arc_frame_hdu):
def make_arc_2d_plot(arc_frame_hdu, arc_filename):
zmin, zmax = np.percentile(arc_frame_hdu['SCI'].data, [1, 99])
trace = go.Heatmap(z=arc_frame_hdu['SCI'].data, colorscale=COLORMAP, zmin=zmin, zmax=zmax, hoverinfo='none')
trace = go.Heatmap(z=arc_frame_hdu['SCI'].data, colorscale=COLORMAP, zmin=zmin, zmax=zmax, hoverinfo='none',
colorbar=dict(title='Data (counts)'))

layout = dict(title='', margin=dict(t=20, b=50, l=50, r=40), height=350)
layout = dict(margin=dict(t=50, b=50, l=50, r=40), height=370)

orders = orders_from_fits(arc_frame_hdu['ORDER_COEFFS'].data, arc_frame_hdu['ORDER_COEFFS'].header,
arc_frame_hdu['SCI'].data.shape)
Expand All @@ -175,18 +176,29 @@ def make_arc_2d_plot(arc_frame_hdu):
arc_frame_hdu['WAVELENGTHS'].data[in_order])
min_wavelength = np.min(arc_frame_hdu['WAVELENGTHS'].data[in_order])
max_wavelength = np.max(arc_frame_hdu['WAVELENGTHS'].data[in_order])
for line in arc_lines:
for i, line in enumerate(arc_lines):
if min_wavelength <= line['wavelength'] <= max_wavelength:
y_line_center = np.interp(line['wavelength'], x_y_to_wavelength_intepolator(x, y_center), y_center)
y_line = np.arange(y_line_center - order_height / 2.0, y_line_center + order_height / 2.0 + 1)
plot_x = wavelength_to_x_interpolator(y_line, np.ones_like(y_line) * line['wavelength'])
if i == 0:
show_legend = True
name = 'Model'
else:
show_legend = False
name = None
figure_data.append(go.Scatter(x=plot_x, y=y_line,
marker={'color': 'salmon'},
mode='lines',
hovertext=[f'{line["wavelength"]:0.3f} {line["line_source"]}'
for _ in range(len(plot_x))],
hovertemplate='%{hovertext}<extra></extra>',
showlegend=False))
showlegend=show_legend,
name=name))
layout['legend'] = dict(x=0, y=0.95)
layout['title'] = f'Arc Frame Used in Reduction: {arc_filename}'
layout['xaxis'] = dict(title='x (pixel)')
layout['yaxis'] = dict(title='y (pixel)')
fig = dict(data=figure_data, layout=layout)
image_plot = dcc.Graph(id='image-graph1', figure=fig, style={'display': 'inline-block',
'width': '100%', 'height': '100%;'})
Expand All @@ -209,12 +221,12 @@ def download_frame(headers, url=f'{settings.ARCHIVE_URL}', params=None, list_end


def get_related_frame(frame_id, archive_header, related_frame_key):
# Get the related arc frame from the archive
# Get the related frame from the archive that matches related_frame_key in the header.
response = requests.get(f'{settings.ARCHIVE_URL}{frame_id}/headers', headers=archive_header)
response.raise_for_status()
related_frame_filename = response.json()['data'][related_frame_key]
params = {'basename_exact': related_frame_filename}
return download_frame(archive_header, params=params, list_endpoint=True)
return download_frame(archive_header, params=params, list_endpoint=True), related_frame_filename


def calculate_residuals(wavelengths, flux, flux_errors, lines):
Expand Down Expand Up @@ -250,11 +262,15 @@ def make_arc_line_plots(arc_frame_hdu):
binned_data['weights'] = 1.0
extracted_data = extract(binned_data)

fig = make_subplots(rows=2, cols=2, x_title=u'Wavelength (\u212B)', vertical_spacing=0.02,
fig = make_subplots(rows=2, cols=2, vertical_spacing=0.02,
horizontal_spacing=0.05, shared_xaxes=True)

fig.update_yaxes(title_text='Flux (counts)', row=1, col=1)
fig.update_yaxes(title_text='Residuals (\u212B)', row=2, col=1)
fig.update_xaxes(title_text='Wavelength (\u212B)', row=2, col=1, tickformat=".0f")
fig.update_xaxes(title_text='Wavelength (\u212B)', row=2, col=2, tickformat=".0f")
fig.add_annotation(xref='x domain', yref='y domain', x=0.01, y=0.97, text='Blue Order (order=2)', showarrow=False)
fig.add_annotation(xref='x2 domain', yref='y2 domain', x=0.01, y=0.97, text='Red Order (order=1)', showarrow=False)

plot_column = {2: 1, 1: 2}
for order in [2, 1]:
Expand Down Expand Up @@ -305,17 +321,23 @@ def make_arc_line_plots(arc_frame_hdu):
fig.update_yaxes(range=[np.min(residuals) - 0.1 * residual_range,
np.max(residual_range) + 0.1 * residual_range],
row=2, col=plot_column[order])
# fig.add_annotation(xref=, yref=, x=, y=, text='', showarrow=False)
fig.update_layout(showlegend=False, autosize=True, margin=dict(l=0, r=0, t=0, b=0))
line_plot = dcc.Graph(id='arc-line-graph', figure=fig, style={'display': 'inline-block',
'width': '100%', 'height': '550px;'})
return line_plot


def make_2d_sci_plot(frame):
def make_2d_sci_plot(frame, filename):
zmin, zmax = np.percentile(frame['SCI'].data, [1, 99])
trace = go.Heatmap(z=frame['SCI'].data, colorscale=COLORMAP, zmin=zmin, zmax=zmax, hoverinfo='none')

layout = dict(title='', margin=dict(t=20, b=50, l=50, r=40), height=350, showlegend=False)
trace = go.Heatmap(z=frame['SCI'].data, colorscale=COLORMAP, zmin=zmin, zmax=zmax, hoverinfo='none',
colorbar=dict(title='Data (counts)'))

layout = dict(title=f'2-D Science Frame: {filename}', margin=dict(t=40, b=50, l=50, r=40),
height=370)
layout['legend'] = dict(x=0, y=0.95)
layout['xaxis'] = dict(title='x (pixel)')
layout['yaxis'] = dict(title='y (pixel)')
figure_data = [trace]

orders = orders_from_fits(frame['ORDER_COEFFS'].data,
Expand Down Expand Up @@ -356,36 +378,54 @@ def make_2d_sci_plot(frame):

background_upper_end.append(np.max(profile_to_fit['y']))
background_lower_end.append(np.min(profile_to_fit['y']))
if order == 2:
center_name = 'Extraction Center'
center_legend = True

extraction_name = 'Extraction \u00b12\u03C3'
extreaction_legend = True

background_name = 'Background Region'
background_legend = True
else:
center_name = None
center_legend = False

extraction_name = None
extreaction_legend = False

background_name = None
background_legend = False

figure_data.append(go.Scatter(x=plot_x, y=extract_center,
mode='lines', line={'color': 'salmon'},
hoverinfo='skip'))
hoverinfo='skip', name=center_name, showlegend=center_legend))
figure_data.append(go.Scatter(x=plot_x, y=extract_high,
mode='lines', line={'color': 'salmon', 'dash': 'dash'},
hoverinfo='skip'))
hoverinfo='skip', name=extraction_name, showlegend=extreaction_legend))
figure_data.append(go.Scatter(x=plot_x, y=extract_low,
mode='lines', line={'color': 'salmon', 'dash': 'dash'},
hoverinfo='skip'))
hoverinfo='skip', showlegend=False))
figure_data.append(go.Scatter(x=plot_x, y=background_lower_start,
mode='lines', line={'color': '#8F0B0B', 'dash': 'dash'},
hoverinfo='skip'))
hoverinfo='skip', name=background_name, showlegend=background_legend))
figure_data.append(go.Scatter(x=plot_x, y=background_lower_end,
mode='lines', line={'color': '#8F0B0B', 'dash': 'dash'},
hoverinfo='skip'))
hoverinfo='skip', showlegend=False))
figure_data.append(go.Scatter(x=plot_x, y=background_upper_start,
mode='lines', line={'color': '#8F0B0B', 'dash': 'dash'},
hoverinfo='skip'))
hoverinfo='skip', showlegend=False))
figure_data.append(go.Scatter(x=plot_x, y=background_upper_end,
mode='lines', line={'color': '#8F0B0B', 'dash': 'dash'},
hoverinfo='skip'))
hoverinfo='skip', showlegend=False))

fig = dict(data=figure_data, layout=layout)
image_plot = dcc.Graph(id='sci-2d-graph', figure=fig, style={'display': 'inline-block',
'width': '100%', 'height': '550px;'})
return image_plot


def unfilled_histogram(x, y, color):
def unfilled_histogram(x, y, color, name=None):
# I didn't like how the plotly histogram looked so I wrote my own
x_avgs = (x[1:] + x[:-1]) / 2.0
x_lows = np.hstack([x[0] + x[0] - x_avgs[0], x_avgs])
Expand All @@ -399,13 +439,19 @@ def unfilled_histogram(x, y, color):
# Make the flat top at -1, 0, +1 of x
for _ in range(3):
y_plot.append(y_center / np.max(y))
return go.Scatter(x=x_plot, y=y_plot, mode='lines', line={'color': color}, hoverinfo='skip')
show_legend = name is not None
return go.Scatter(x=x_plot, y=y_plot, mode='lines', line={'color': color}, hoverinfo='skip',
name=name, showlegend=show_legend)


def make_profile_plot(sci_2d_frame):
layout = dict(title='', margin=dict(t=20, b=50, l=50, r=40), height=350, showlegend=False)
fig = make_subplots(rows=1, cols=4, vertical_spacing=0.02)
plot_column = {2: 1, 1: 3}
layout = dict(title='', margin=dict(t=20, b=20, l=50, r=40), height=720, showlegend=True)
fig = make_subplots(rows=2, cols=2, vertical_spacing=0.13,
subplot_titles=("Profile Cross Section: Blue Order (order=2)",
"Profile Center: Blue Order (order=2)",
"Profile Cross Section: Red Order (order=1)",
"Profile Center: Red Order (order=1)"))
plot_row = {2: 1, 1: 2}

orders = orders_from_fits(sci_2d_frame['ORDER_COEFFS'].data,
sci_2d_frame['ORDER_COEFFS'].header,
Expand All @@ -429,14 +475,20 @@ def make_profile_plot(sci_2d_frame):

x_plot = y2d[y_range, order_center[order]] - orders.center(order_center[order])[order - 1]

if order == 2:
model_name = 'Model'
data_name = 'Data'
else:
model_name = None
data_name = None
fig.add_trace(
unfilled_histogram(x_plot, data, '#023858'),
row=1, col=plot_column[order],
unfilled_histogram(x_plot, data, '#023858', name=data_name),
row=plot_row[order], col=1
)

fig.add_trace(
unfilled_histogram(x_plot, profile, 'salmon'),
row=1, col=plot_column[order],
unfilled_histogram(x_plot, profile, 'salmon', name=model_name),
row=plot_row[order], col=1
)

wavelengths = WavelengthSolution.from_header(sci_2d_frame['WAVELENGTHS'].header, orders)
Expand Down Expand Up @@ -475,40 +527,61 @@ def make_profile_plot(sci_2d_frame):

fig.add_trace(
go.Scatter(x=x_plot, y=y_plot, mode='markers', marker={'color': '#023858'},
hoverinfo='skip'),
row=1, col=plot_column[order] + 1,
hoverinfo='skip', showlegend=False),
row=plot_row[order], col=2,
)
fig.add_trace(
go.Scatter(x=x_plot, y=y_profile_plot, mode='lines', line={'color': 'salmon'},
hoverinfo='skip'),
row=1, col=plot_column[order] + 1,
hoverinfo='skip', showlegend=False),
row=plot_row[order], col=2,
)

fig.update_yaxes(title_text='Normalized Flux', row=1, col=1)
fig.update_xaxes(title_text='y offset from center (pixel)', row=1, col=1)
fig.update_yaxes(title_text='Normalized Flux', row=2, col=1)
fig.update_xaxes(title_text='y offset from center (pixel)', row=2, col=1)

fig.update_yaxes(title_text='y (pixel)', row=1, col=2)
fig.update_xaxes(title_text='x (pixel)', row=1, col=2)
fig.update_yaxes(title_text='y (pixel)', row=2, col=2)
fig.update_xaxes(title_text='x (pixel)', row=2, col=2)

fig.update_layout(**layout)
profile_plot = dcc.Graph(id='profile-graph', figure=fig,
style={'display': 'inline-block', 'width': '100%', 'height': '100%;'})
return profile_plot


def make_1d_sci_plot(frame_id, archive_header):
layout = dict(title='', margin=dict(t=0, b=50, l=0, r=0), height=1050, showlegend=False)
frame_1d = download_frame(url=f'{settings.ARCHIVE_URL}{frame_id}/', headers=archive_header)[1].data
fig = make_subplots(rows=3, cols=2, x_title=u'Wavelength (\u212B)', vertical_spacing=0.02, horizontal_spacing=0.07,
shared_xaxes=True)

frame_1d = download_frame(url=f'{settings.ARCHIVE_URL}{frame_id}/', headers=archive_header)
frame_data = frame_1d[1].data
title_dict = {
'text': f"1-D Extractions: {frame_1d[0].header['ORIGNAME'].replace('-e00', '-e91-1d')}",
'y': 0.985,
'x': 0.5,
'xanchor': 'center',
'yanchor': 'top'
}
layout = dict(title=title_dict, margin=dict(t=60, b=50, l=0, r=0), height=1080, showlegend=False)
fig = make_subplots(rows=3, cols=2, vertical_spacing=0.02, horizontal_spacing=0.07,
shared_xaxes=True, subplot_titles=['Blue Order (order=2)', 'Red Order (order=1)',
None, None, None, None])
plot_column = {2: 1, 1: 2}
for order in [2, 1]:
where_order = frame_1d['order'] == order
where_order = frame_data['order'] == order
fig.add_trace(
go.Scatter(x=frame_1d['wavelength'][where_order], y=frame_1d['flux'][where_order],
go.Scatter(x=frame_data['wavelength'][where_order], y=frame_data['flux'][where_order],
line_color='#023858', mode='lines'),
row=1, col=plot_column[order],
)
fig.add_trace(
go.Scatter(x=frame_1d['wavelength'][where_order], y=frame_1d['fluxraw'][where_order],
go.Scatter(x=frame_data['wavelength'][where_order], y=frame_data['fluxraw'][where_order],
line_color='#023858', mode='lines'),
row=2, col=plot_column[order],
)
fig.add_trace(
go.Scatter(x=frame_1d['wavelength'][where_order], y=frame_1d['background'][where_order],
go.Scatter(x=frame_data['wavelength'][where_order], y=frame_data['background'][where_order],
line_color='#023858', mode='lines'),
row=3, col=plot_column[order],
)
Expand All @@ -519,6 +592,8 @@ def make_1d_sci_plot(frame_id, archive_header):
fig.update_yaxes(row=2, col=2, exponentformat='power')
fig.update_yaxes(title_text='Background (counts)', row=3, col=1, exponentformat='power')
fig.update_yaxes(row=3, col=2, exponentformat='power')
fig.update_xaxes(title_text='Wavelength (\u212B)', row=3, col=1, tickformat=".0f")
fig.update_xaxes(title_text='Wavelength (\u212B)', row=3, col=2, tickformat=".0f")
fig.update_layout(**layout)

extracted_plot = dcc.Graph(id='extracted-graph', figure=fig,
Expand All @@ -539,12 +614,12 @@ def callback_make_plots(*args, **kwargs):
archive_header = None

# TODO: All of of these should be async so things load faster
arc_frame = get_related_frame(frame_id, archive_header, 'L1IDARC')
arc_image_plot = make_arc_2d_plot(arc_frame)
arc_frame, arc_filename = get_related_frame(frame_id, archive_header, 'L1IDARC')
arc_image_plot = make_arc_2d_plot(arc_frame, arc_filename)
arc_line_plot = make_arc_line_plots(arc_frame)

sci_2d_frame = get_related_frame(frame_id, archive_header, 'L1ID2D')
sci_2d_plot = make_2d_sci_plot(sci_2d_frame)
sci_2d_frame, sci_2d_filename = get_related_frame(frame_id, archive_header, 'L1ID2D')
sci_2d_plot = make_2d_sci_plot(sci_2d_frame, sci_2d_filename)

profile_plot = make_profile_plot(sci_2d_frame)
sci_1d_plot = make_1d_sci_plot(frame_id, archive_header)
Expand Down

0 comments on commit 8263582

Please sign in to comment.