Skip to content

Commit

Permalink
add some stuff to the app
Browse files Browse the repository at this point in the history
  • Loading branch information
sprivite committed Jul 4, 2024
1 parent d4a491f commit 0e1d818
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 16 deletions.
72 changes: 64 additions & 8 deletions bin/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
import streamlit as st
import pyarrow as pa
import seaborn as sns

from pybalance.propensity import PropensityScoreMatcher
from pybalance.sim import generate_toy_dataset
from pybalance.visualization import plot_numeric_features, plot_categoric_features, plot_per_feature_loss
from pybalance.visualization import (
plot_numeric_features,
plot_categoric_features,
plot_per_feature_loss,
)
from pybalance.utils import BALANCE_CALCULATORS

OBJECTIVES = list(BALANCE_CALCULATORS.keys())
Expand Down Expand Up @@ -42,6 +49,13 @@
help="Restrict hyperparameter search based on time in seconds",
)
method = st.sidebar.selectbox("Method", ["greedy", "linear_sum_assignment"])
cumulative = st.sidebar.checkbox("Cumulative plots", value=False)
if cumulative:
bins = 500
else:
bins = 10

palette = sns.color_palette("colorblind")

# Update the parameters based on user input
matching_data = generate_toy_dataset(n_pool, n_target, seed)
Expand All @@ -63,29 +77,71 @@
)
matching_data.append(post_matching_data.data)

hue_order += list( set(matching_data.populations) - set(hue_order) )
balance_calculator = BALANCE_CALCULATORS[objective](pre_matching_data)
st.sidebar.write(balance_calculator.__doc__)
hue_order += list(set(matching_data.populations) - set(hue_order))

# Display the figures
if matching_data:

tab1, tab2, tab3 = st.tabs(["Numeric", "Categoric", "SMD"])

with tab1:
numeric_fig = plot_numeric_features(matching_data, col_wrap=2, height=6, hue_order=hue_order)

plot_vars = []
for i, col in enumerate(st.columns(len(matching_data.headers["numeric"]))):
with col:
col_name = matching_data.headers["numeric"][i]
if st.checkbox(col_name, value=True):
plot_vars.append(col_name)
print("streamlit", plot_vars)
numeric_fig = plot_numeric_features(
matching_data,
col_wrap=2,
height=6,
hue_order=hue_order,
cumulative=cumulative,
bins=bins,
include_only=plot_vars,
# palette=palette,
)
st.pyplot(numeric_fig)
st.write("---")
# import pdb
# pdb.set_trace()
summary = matching_data.describe_numeric().astype("object")
summary = summary[summary.index.get_level_values(0).isin(plot_vars)]
st.dataframe(summary, use_container_width=True)

with tab2:
plot_vars = []
for i, col in enumerate(st.columns(len(matching_data.headers["categoric"]))):
with col:
col_name = matching_data.headers["categoric"][i]
if st.checkbox(col_name, value=True):
plot_vars.append(col_name)

print("streamlit", plot_vars)
categoric_fig = plot_categoric_features(
matching_data, col_wrap=2, height=6, include_binary=True, hue_order=hue_order
matching_data,
col_wrap=2,
height=6,
include_binary=True,
hue_order=hue_order,
include_only=plot_vars,
# palette=palette,
)
st.pyplot(categoric_fig)
st.write("---")
summary = matching_data.describe_categoric().astype("object")
summary = summary[summary.index.get_level_values(0).isin(plot_vars)]
st.dataframe(summary, use_container_width=True)

balance_calculator = BALANCE_CALCULATORS[objective](pre_matching_data)
with tab3:
categoric_fig = plot_per_feature_loss(
matching_data,
balance_calculator,
hue_order=hue_order,
debin=False
debin=False,
# palette=palette,
)
st.pyplot(categoric_fig)
st.pyplot(categoric_fig)
15 changes: 10 additions & 5 deletions pybalance/utils/balance_calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,9 +317,11 @@ def per_feature_loss(

class BetaBalance(BaseBalanceCalculator):
"""
Convenience interface to BaseBalanceCalculator to computes the distance
between populations as the mean standardized mean difference. Uses
StandardMatchingPreprocessor as the preprocessor.
BetaBalance computes the balance between two populatiosn as the mean
absolute standardized mean difference across all features. Uses
StandardMatchingPreprocessor as the preprocessor. In this preprocessor,
numeric variables are left unchanged, while categorical variables are
one-hot encoded. See StandardMatchingPreprocessor for more details.
"""

name = "beta"
Expand All @@ -345,8 +347,11 @@ def __init__(

class BetaSquaredBalance(BaseBalanceCalculator):
"""
Same as BetaBalance, except that per-feature balances are averaged in a
mean square fashion.
BetaSquaredBalance computes the balance between two populatiosn as the mean
square standardized mean difference across all features. Uses
StandardMatchingPreprocessor as the preprocessor. In this preprocessor,
numeric variables are left unchanged, while categorical variables are
one-hot encoded. See StandardMatchingPreprocessor for more details.
"""

name = "beta_squared"
Expand Down
28 changes: 25 additions & 3 deletions pybalance/visualization/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,19 @@ def plot_categoric_features(
col_wrap: int = 2,
height: float = 6,
include_binary=True,
include_only: Optional[List[str]] = None,
**plot_params,
) -> plt.Figure:
"""
Plot the one-dimensional marginal distributions for all categoric features
and all treatment groups found in matching_data. Extra keyword arguments are
passed to seaborn.histplot and override defaults.
:param matching_data: MatchingData instance containing at least one population.
:param include_binary: Whether to include binary features in the plot.
:param include_only: List of features to consider for plotting. Otherwise,
all categoric features are plotted. If include_binary is False, binary
features are excluded, even if present in include_only.
"""
# Set up default plotting params for categoric varaibles.
default_params = {
Expand All @@ -113,7 +120,11 @@ def plot_categoric_features(
default_params.update(plot_params)

# Determine which covariates to plot.
headers = matching_data.headers["categoric"]
if include_only is None:
headers = matching_data.headers["categoric"]
else:
headers = include_only

if not include_binary:
headers = [c for c in headers if matching_data[c].nunique() > 2]

Expand All @@ -128,12 +139,20 @@ def plot_categoric_features(


def plot_numeric_features(
matching_data: MatchingData, col_wrap: int = 2, height: float = 6, **plot_params
matching_data: MatchingData,
col_wrap: int = 2,
height: float = 6,
include_only: Optional[List[str]] = None,
**plot_params,
) -> plt.Figure:
"""
Plot the one-dimensional marginal distributions for all numerical features
and all treatment groups found in matching_data. Extra keyword arguments are
passed to seaborn.histplot and override defaults.
:param matching_data: MatchingData instance containing at least one population.
:param include_only: List of features to consider for plotting. Otherwise,
all numeric features are plotted.
"""
# Set up default plotting params for numeric varaibles.
default_params = {
Expand All @@ -153,7 +172,10 @@ def plot_numeric_features(
default_params.update(plot_params)

# Determine which covariates to plot.
headers = matching_data.headers["numeric"]
if include_only is None:
headers = matching_data.headers["numeric"]
else:
headers = include_only

# PLOT!
fig = _plot_1d_marginals(matching_data, headers, col_wrap, height, **default_params)
Expand Down

0 comments on commit 0e1d818

Please sign in to comment.