Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Band depth envelope #17

Open
wants to merge 9 commits into
base: dev
Choose a base branch
from
Open
202 changes: 172 additions & 30 deletions SMHviz_plot/figures.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
from datetime import timedelta

import numpy as np
import pandas as pd

from SMHviz_plot.utils import *
Expand Down Expand Up @@ -53,7 +51,7 @@ def add_scatter_trace(fig, data, legend_name, x_col="time_value", y_col="value",
:parameter dash: Option to print the line is dash, options include 'dash', 'dot', and 'dashdot'. By default, "None",
no dash.
:type dash: str | None
:parameter custom_data: Add custom data
:parameter custom_data: Add custom data, which can be referenced in the hover text
:type dash: str | None | pandas.DataFrame
:return: a plotly.graph_objs.Figure object with an added trace
"""
Expand All @@ -70,7 +68,7 @@ def add_scatter_trace(fig, data, legend_name, x_col="time_value", y_col="value",
showlegend=show_legend,
customdata=custom_data,
hovertemplate=hover_text +
"Value: %{y:,.2f}<br>Epiweek: %{x|%Y-%m-%d}<extra></extra>"),
"Value: %{y:,.2f}<br>Epiweek: %{x|%Y-%m-%d}<extra></extra>"),
row=subplot_coord[0], col=subplot_coord[1])
if connect_gaps is not None:
fig.update_traces(connectgaps=connect_gaps)
Expand Down Expand Up @@ -216,7 +214,7 @@ def ui_ribbons(fig, df_plot, quant_sel, legend_name, x_col="target_end_date", y_
name=legend_name,
mode='lines',
line=dict(width=line_width),
marker=dict(color=re.sub(", 1\)", ", " + str(opacity) + ")", color)),
marker=dict(color=re.sub(r", 1\)", ", " + str(opacity) + ")", color)),
legendgroup=legend_name,
showlegend=show_legend,
hovertemplate=second_hover_text),
Expand All @@ -228,10 +226,10 @@ def ui_ribbons(fig, df_plot, quant_sel, legend_name, x_col="target_end_date", y_
name=legend_name,
line=dict(width=line_width),
mode='lines',
marker=dict(color=re.sub(", 1\)", ", " + str(opacity) + ")", color)),
marker=dict(color=re.sub(r", 1\)", ", " + str(opacity) + ")", color)),
legendgroup=legend_name,
showlegend=False,
fillcolor=re.sub(", 1\)", ", " + str(opacity) + ")", color),
fillcolor=re.sub(r", 1\)", ", " + str(opacity) + ")", color),
fill='tonexty',
hovertemplate=first_hover_text),
row=subplot_coord[0], col=subplot_coord[1])
Expand Down Expand Up @@ -335,7 +333,7 @@ def make_proj_plot(fig_plot, proj_data, intervals=None, intervals_dict=None, x_c
elif len(intervals) > 1:
intervals.sort(reverse=True)
for i in range(0, len(intervals)):
if i is 0 and plot_df is None:
if i == 0 and plot_df is None:
ui_show_legend = show_legend
else:
ui_show_legend = False
Expand Down Expand Up @@ -540,7 +538,7 @@ def make_scatter_plot(proj_data, truth_data, intervals=None, intervals_dict=None
else:
show_legend = False
if truth_facet is not None:
if truth_data_type is "scatter":
if truth_data_type == "scatter":
if w_delay is not None:
plot_truth_df = truth_facet[pd.to_datetime(truth_facet[x_truth_col]) <=
(max(pd.to_datetime(truth_facet[x_truth_col])) -
Expand All @@ -560,7 +558,7 @@ def make_scatter_plot(proj_data, truth_data, intervals=None, intervals_dict=None
subplot_coord=subplot_coord, x_col=x_truth_col, y_col=y_truth_col,
width=line_width, connect_gaps=connect_gaps, mode="markers",
color="rgb(200, 200, 200)", line_width=0.5)
elif truth_data_type is "bar":
elif truth_data_type == "bar":
fig_plot = add_bar_trace(fig_plot, truth_facet, truth_legend_name, show_legend=show_legend,
hover_text=truth_legend_name + "<br>", subplot_coord=subplot_coord,
x_col=x_truth_col)
Expand Down Expand Up @@ -590,7 +588,7 @@ def make_scatter_plot(proj_data, truth_data, intervals=None, intervals_dict=None
else:
fig_plot = fig_plot
if truth_data is not None:
if truth_data_type is "scatter":
if truth_data_type == "scatter":
if w_delay is not None:
plot_truth_df = truth_data[pd.to_datetime(truth_data[x_truth_col]) <=
(max(pd.to_datetime(truth_data[x_truth_col])) -
Expand All @@ -608,7 +606,7 @@ def make_scatter_plot(proj_data, truth_data, intervals=None, intervals_dict=None
hover_text=truth_legend_name + "<br>", x_col=x_truth_col,
width=line_width, connect_gaps=connect_gaps, mode="markers",
color="rgb(200, 200, 200)", show_legend=False, line_width=0.5)
elif truth_data_type is "bar":
elif truth_data_type == "bar":
fig_plot = add_bar_trace(fig_plot, truth_data, truth_legend_name,
hover_text=truth_legend_name + "<br>", x_col=x_truth_col)
else:
Expand All @@ -631,11 +629,11 @@ def make_scatter_plot(proj_data, truth_data, intervals=None, intervals_dict=None
# View update
to_vis = list()
leg_only = list()
if viz_truth_data is True:
if viz_truth_data == True:
to_vis.append(truth_legend_name)
elif viz_truth_data == "legendonly":
leg_only.append(truth_legend_name)
if ensemble_view is True:
if ensemble_view == True:
to_vis.append(ensemble_name)
leg_only = leg_only + list(proj_data[legend_col].unique())
leg_only.remove(ensemble_name)
Expand All @@ -656,7 +654,7 @@ def make_scatter_plot(proj_data, truth_data, intervals=None, intervals_dict=None
if notes is not None:
fig_plot.update_layout(legend={"title": {"text": notes + "<br>", "side": "top"}})
# Add buttons
if button is True and ensemble_name is not None:
if button == True and ensemble_name is not None:
button = make_ens_button(fig_plot, viz_truth_data=viz_truth_data, truth_legend_name=truth_legend_name,
ensemble_name=ensemble_name, button_name="Ensemble", button_opt=button_opt)
fig_plot.update_layout(
Expand Down Expand Up @@ -753,7 +751,7 @@ def add_point_scatter(fig, df, ens_name, color_dict=None, multiply=1, symbol="ci
full_model_name = "".join(list(model))
# prerequisite
color_marker = color_line_trace(color_dict, model, line_width=0)
color_marker = re.sub(", 1\)", ", " + str(opacity) + ")", color_marker[0])
color_marker = re.sub(r", 1\)", ", " + str(opacity) + ")", color_marker[0])
model_marker = dict(size=20, color=color_marker, symbol=symbol)
fig.add_trace(go.Scatter(x=df_model["full_x"],
y=df_model["rel_change"] * multi,
Expand Down Expand Up @@ -1063,8 +1061,9 @@ def add_spaghetti_plot(fig, df, color_dict, legend_dict=None,
all_traj_df.loc[pd.isna(all_traj_df['value']), 'type_id'] = np.nan

# Add single trace
color = re.sub(", 1\)", ", " + str(opacity) + ")", col_line[0])
fig = add_scatter_trace(fig, all_traj_df, legend_name, x_col="target_end_date", mode="lines", color=color,
color = re.sub(r", 1\)", ", " + str(opacity) + ")", col_line[0])
fig = add_scatter_trace(fig, all_traj_df, legend_name, x_col="target_end_date",
mode="lines", color=color,
show_legend=show_legend, subplot_coord=subplot_coord,
custom_data=all_traj_df['type_id'],
hover_text=hover_text + "Model: " + legend_name + "<br>Type ID: %{customdata}<br>")
Expand All @@ -1076,34 +1075,177 @@ def add_spaghetti_plot(fig, df, color_dict, legend_dict=None,
return fig


def make_spaghetti_plot(df, legend_col="model_name", spag_col="type_id", show_legend=True, hover_text="", opacity=0.3,
subplot=False, title="", height=1000, subplot_col=None, subplot_titles=None, palette="turbo",
share_x="all", share_y="all", x_title="", y_title="N", theme="plotly_white", color_dict=None,
add_median=False, legend_dict=None):
def add_spaghetti_plot_envelope(fig, df, color_dict, band_depth_limit, legend_dict=None,
legend_col="model_name", spag_col="type_id", show_legend=True,
hover_text="", opacity=0.3,
subplot_coord=None, add_median=False, median=0.5):
"""
:param band_depth_limit: if not None, must be a float X between 0 and 1 where the plot will
show envelope around trajectories with band depth greater than X%.
Band depth is a measure of the representativeness of one trajectory among an ensemble.
For more details, see https://ieeexplore.ieee.org/document/6875964 - Curve Boxplot: Generalization of Boxplot for Ensembles of Curves by Mirzargar et al.
"""

if add_median is True:
df_med = df[df[spag_col] == median]
df = df[df[spag_col] != median]
else:
df_med = None
for leg in df[legend_col].drop_duplicates():
# df_plot contains all data for a given model (and scenario and age group)
df_plot = df[df[legend_col] == leg].drop(legend_col, axis=1)
if legend_dict is None:
legend_name = leg
col_line = color_line_trace(color_dict, leg)
else:
legend_name = legend_dict[leg]
col_line = color_line_trace(color_dict, legend_name)

# Prepare df with all trajectories in a model, separated by null rows (which break up trajectories into different lines)
temp = pd.DataFrame()
traj_list = list(df_plot['type_id'].unique())
temp.loc[:, 'value'] = [np.nan] * len(traj_list)
temp.loc[:, 'type_id'] = traj_list
temp.loc[:, 'target_end_date'] = [pd.NaT] * len(traj_list)
all_traj_df = pd.concat([df_plot, temp], axis=0)
all_traj_df = all_traj_df.sort_values(['type_id', 'target_end_date'])
# Once Nan's are inserted between typeIDs, insert Nan in type ID col so hover text renders correctly
all_traj_df.loc[pd.isna(all_traj_df['value']), 'type_id'] = np.nan
band_depth_df = generate_band_depth_df(df_plot)
all_traj_df = all_traj_df.merge(band_depth_df, how='left', on='type_id')

# Add single trace
connect_gaps = None
color = re.sub(r", 1\)", ", " + str(opacity) + ")", col_line[0])
fig.add_trace(go.Scatter(x=all_traj_df['target_end_date'],
y=all_traj_df['value'],
name=legend_name,
mode='lines',
marker=dict(color=color, line_width=0.0001),
legendgroup=legend_name,
line=dict(width=2, dash=None),
visible=True,
showlegend=show_legend,
customdata=all_traj_df['type_id'],
text=all_traj_df['band_depth'],
hovertemplate=hover_text + f"Model: {legend_name}<br>"
"Type ID: %{customdata}<br>"
"Modified band depth: %{text:.2%}<br>"
"Value: %{y:,.2f}<br>Epiweek: %{x|%Y-%m-%d}<extra></extra>"
),
row=subplot_coord[0], col=subplot_coord[1])
if connect_gaps is not None:
fig.update_traces(connectgaps=connect_gaps)
if add_median is True and df_med is not None:
df_plot_med = df_med[df_med[legend_col] == leg]
add_scatter_trace(fig, df_plot_med, legend_name, x_col="target_end_date",
show_legend=False,
mode="lines", subplot_coord=subplot_coord, width=4,
hover_text=hover_text + spag_col.title() + ": Median <br>",
color=col_line[0])

# Add shaded region for trajectories with top X% of band depths
band_depth_filtered = \
band_depth_df.quantile(q=band_depth_limit, axis=0, interpolation='nearest').iloc[1]
df_top_x_pctile = all_traj_df.loc[all_traj_df['band_depth'] >= band_depth_filtered, :]
# shade region
min_top_x_envelope = df_top_x_pctile.groupby('target_end_date')['value'].agg(
'min').reset_index()
max_top_x_envelope = df_top_x_pctile.groupby('target_end_date')['value'].agg(
'max').reset_index()

# Add trace for min
fig.add_trace(go.Scatter(x=min_top_x_envelope['target_end_date'],
y=min_top_x_envelope['value'],
name=legend_name,
mode='lines',
legendgroup=legend_name,
marker=dict(color=color, line_width=0.0001),
line=dict(width=2, dash=None),
visible=True,
showlegend=False,
),
row=subplot_coord[0], col=subplot_coord[1])
# Add trace for max
fig.add_trace(go.Scatter(x=max_top_x_envelope['target_end_date'],
y=max_top_x_envelope['value'],
name=legend_name,
mode='lines',
legendgroup=legend_name,
marker=dict(color=color, line_width=0.0001),
line=dict(width=2, dash=None),
visible=True,
fill='tonexty',
showlegend=False,
),
row=subplot_coord[0], col=subplot_coord[1])
if connect_gaps is not None:
fig.update_traces(connectgaps=connect_gaps)

return fig


def make_spaghetti_plot(df, legend_col="model_name", spag_col="type_id", show_legend=True,
hover_text="", opacity=0.3,
subplot=False, title="", height=1000, subplot_col=None, subplot_titles=None,
palette="turbo",
share_x="all", share_y="all", x_title="", y_title="N", theme="plotly_white",
color_dict=None,
add_median=False, legend_dict=None, band_depth_limit=None):
"""
:param band_depth_limit: if not None, must be a float X between 0 and 1 where the plot will
show envelope around trajectories with band depth greater than X%.
Band depth is a measure of the representativeness of one trajectory among an ensemble.
For more details, see https://ieeexplore.ieee.org/document/6875964 - Curve Boxplot: Generalization of Boxplot for Ensembles of Curves by Mirzargar et al.
"""

# Colorscale
if color_dict is None:
color_dict = make_palette_sequential(df, legend_col, palette=palette)
# Plot
if subplot is True:
sub_var = list(df[subplot_col].unique())
fig = prep_subplot(sub_var, subplot_titles, x_title, y_title, sort=False, share_x=share_x, share_y=share_y)
fig = prep_subplot(sub_var, subplot_titles, x_title, y_title, sort=False, share_x=share_x,
share_y=share_y)
for var in sub_var:
df_var = df[df[subplot_col] == var].drop(subplot_col, axis=1)
plot_coord = subplot_row_col(sub_var, var)
if var == sub_var[0]:
show_legend = show_legend
else:
show_legend = False
add_spaghetti_plot(fig, df_var, color_dict=color_dict, legend_col=legend_col,
spag_col=spag_col, show_legend=show_legend, hover_text=hover_text,
opacity=opacity, subplot_coord=plot_coord, add_median=add_median,
legend_dict=legend_dict)
if band_depth_limit and band_depth_limit >= 0 and band_depth_limit <= 1:
add_spaghetti_plot_envelope(fig, df_var, color_dict=color_dict,
legend_col=legend_col,
spag_col=spag_col, show_legend=show_legend,
hover_text=hover_text,
opacity=opacity, subplot_coord=plot_coord,
add_median=add_median,
legend_dict=legend_dict,
band_depth_limit=band_depth_limit)

else:
add_spaghetti_plot(fig, df_var, color_dict=color_dict, legend_col=legend_col,
spag_col=spag_col, show_legend=show_legend,
hover_text=hover_text,
opacity=opacity, subplot_coord=plot_coord, add_median=add_median,
legend_dict=legend_dict)
else:
fig = go.Figure()
fig.update_layout(xaxis_title=x_title, yaxis_title=y_title)
add_spaghetti_plot(fig, df, color_dict=color_dict, legend_col=legend_col,
spag_col=spag_col, show_legend=show_legend, hover_text=hover_text,
opacity=opacity, subplot_coord=None, add_median=add_median, legend_dict=legend_dict)
if band_depth_limit and band_depth_limit >= 0 and band_depth_limit <= 1:
add_spaghetti_plot_envelope(fig, df, color_dict=color_dict, legend_col=legend_col,
spag_col=spag_col, show_legend=show_legend,
hover_text=hover_text,
opacity=opacity, subplot_coord=None,
add_median=add_median,
legend_dict=legend_dict, band_depth_limit=band_depth_limit)

else:
add_spaghetti_plot(fig, df, color_dict=color_dict, legend_col=legend_col,
spag_col=spag_col, show_legend=show_legend, hover_text=hover_text,
opacity=opacity, subplot_coord=None, add_median=add_median,
legend_dict=legend_dict)
subplot_fig_output(fig, title, subtitle="", height=height, theme=theme)
return fig

Expand Down
Loading