Skip to content

Commit

Permalink
clean up plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
csinva committed Mar 18, 2021
1 parent 2f61913 commit 5f01346
Show file tree
Hide file tree
Showing 7 changed files with 165 additions and 440 deletions.
481 changes: 50 additions & 431 deletions notebooks/imodels_comparisons.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions notebooks/matplotlibrc
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@

figure.dpi: 250
File renamed without changes.
File renamed without changes.
103 changes: 103 additions & 0 deletions notebooks/viz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import itertools
import math
import os
import pickle as pkl
from typing import List, Dict, Any

import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.datasets import fetch_openml
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rcParams['figure.dpi'] = 250
import dvu # for visualization
dvu.set_style()

def get_x_and_y(result_data: pd.Series) -> (pd.Series, pd.Series):
complexities = result_data[result_data.index.str.contains('complexity')]
rocs = result_data[result_data.index.str.contains('ROC')]
complexity_sort_indices = complexities.argsort()
return complexities[complexity_sort_indices], rocs[complexity_sort_indices]

def viz_comparison_val_average(result: Dict[str, Any]) -> None:
    '''Plot dataset-averaged ROC AUC vs dataset-averaged complexity for different
    hyperparameter settings of a single model, including a zoomed-in plot of the
    overlapping region.
    '''
    mean_data = result['df']['mean']
    fig, (ax_full, ax_zoom) = plt.subplots(nrows=1, ncols=2, figsize=(12, 5))
    for est in np.unique(result['estimators']):
        est_data = mean_data[mean_data.index.str.contains(est)]
        x, y = get_x_and_y(est_data)
        ax_full.plot(x, y, marker='o', markersize=4, label=est.replace('_', ' '))

        # zoomed panel only shows estimators that have an AUC-of-AUC summary
        if est in result['auc_of_auc'].index:
            area = result['auc_of_auc'][est]
            label = est.split(' - ')[1] + f' AUC: {area:.3f}'
            ax_zoom.plot(x, y, marker='o', markersize=4, label=label.replace('_', ' '))

    ax_full.set_title('average ROC AUC across all comparison datasets')
    ax_zoom.set_xlim(result['auc_of_auc_lb'], result['auc_of_auc_ub'])
    ax_zoom.set_title('Overlapping, low (<30) complexity region only')

    for ax in (ax_full, ax_zoom):
        ax.set_xlabel('complexity score')
        ax.set_ylabel('ROC AUC')
        dvu.line_legend(ax=ax)
    plt.tight_layout()

def viz_comparison_test_average(results: List[Dict[str, Any]]) -> None:
    '''Plot dataset-averaged ROC AUC vs dataset-averaged complexity for different models
    '''
    for res in results:
        # one curve per result; each result carries a single estimator
        est_name = res['estimators'][0]
        xs, ys = get_x_and_y(res['df']['mean'])
        plt.plot(xs, ys, marker='o', markersize=2, linewidth=1,
                 label=est_name.replace('_', ' '))
    plt.xlim(0, 30)
    plt.xlabel('complexity score', size=8)
    plt.ylabel('ROC AUC', size=8)
    plt.title('average ROC AUC across all comparison datasets', size=8)
    dvu.line_legend(fontsize=8, adjust_text_labels=True)

def viz_comparison_datasets(result: Dict[str, Any], cols=3, figsize=(14, 10), test=False) -> None:
    '''Plot ROC AUC vs complexity for different datasets and models (not averaged)
    '''
    # when test=True, `result` is a list of per-model result dicts;
    # otherwise it is a single result dict covering several estimators
    if test:
        results_df = pd.concat([r['df'] for r in result])
        results_estimators = [r['estimators'][0] for r in result]
    else:
        results_df = result['df']
        results_estimators = np.unique(result['estimators'])

    datasets = list(results_df.columns)[:-2]  # last two columns are not datasets
    n_rows = int(math.ceil(len(datasets) / cols))
    plt.figure(figsize=figsize)
    for plot_idx, dataset in enumerate(datasets, start=1):
        plt.subplot(n_rows, cols, plot_idx)
        dataset_data = results_df[dataset]
        for est in np.unique(results_estimators):
            est_rows = dataset_data[dataset_data.index.str.contains(est)]
            xs, ys = get_x_and_y(est_rows)
            plt.plot(xs, ys, marker='o', markersize=4, label=est.replace('_', ' '))
        plt.xlim(0, 30)
        plt.xlabel('complexity score')
        plt.ylabel('ROC AUC')
        dvu.line_legend(fontsize=8,
                        adjust_text_labels=False,
                        xoffset_spacing=0,
                        extra_spacing=0)
        plt.title(f'dataset {dataset}')
    plt.tight_layout()
6 changes: 3 additions & 3 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
<a href="https://github.com/csinva/imodels/actions"><img src="https://github.com/csinva/imodels/workflows/tests/badge.svg"></a>
<img src="https://img.shields.io/github/checks-status/csinva/imodels/master">
<img src="https://img.shields.io/pypi/v/imodels?color=orange">
<img src="https://static.pepy.tech/personalized-badge/imodels?period=total&units=none&left_color=gray&right_color=orange&left_text=downloads&kill_cache=3">
<img src="https://static.pepy.tech/personalized-badge/imodels?period=total&units=none&left_color=gray&right_color=orange&left_text=downloads&kill_cache=4">
</p>


Expand Down Expand Up @@ -103,8 +103,8 @@ Demos are contained in the [notebooks](notebooks) folder.
>
> we also include some demos of posthoc analysis, which occurs after fitting models
>
> - [posthoc.ipynb](notebooks/2_posthoc.ipynb) - shows different simple analyses to interpret a trained model
> - [uncertainty.ipynb](notebooks/3_uncertainty.ipynb) - basic code to get uncertainty estimates for a model
> - [posthoc.ipynb](notebooks/posthoc_analysis.ipynb) - shows different simple analyses to interpret a trained model
> - [uncertainty.ipynb](notebooks/uncertainty_analysis.ipynb) - basic code to get uncertainty estimates for a model
## Support for different tasks

Expand Down
13 changes: 7 additions & 6 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,17 @@
url="https://github.com/csinva/imodels",
packages=setuptools.find_packages(exclude=['tests']),
install_requires=[
'numpy',
'scipy',
'matplotlib',
'pandas',
'scikit-learn>=0.23.0', # 0.23+ only works on py3.6+
'cvxpy',
'cvxopt',
'mlxtend',
'dvu', # for visualization
'matplotlib',
'mlxtend',
'numpy',
'pandas',
'pytest',
'pytest-cov',
'scipy',
'scikit-learn>=0.23.0', # 0.23+ only works on py3.6+
'tqdm'
],
python_requires='>=3.6',
Expand Down

0 comments on commit 5f01346

Please sign in to comment.