From f4210c4a516879503b8e0ff1ef9963c6651230a2 Mon Sep 17 00:00:00 2001 From: George Ho Date: Fri, 20 Jul 2018 10:30:24 -0400 Subject: [PATCH 1/4] DOC: forgot docstring for sampler_args --- bayesalpha/author_model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/bayesalpha/author_model.py b/bayesalpha/author_model.py index 04a2371..4791b02 100644 --- a/bayesalpha/author_model.py +++ b/bayesalpha/author_model.py @@ -194,6 +194,8 @@ def fit_authors(data, sampler_type : str Whether to use Markov chain Monte Carlo or variational inference. Either 'mcmc' or 'vi'. Defaults to 'mcmc'. + sampler_args : dict + Additional parameters for `pm.sample`. save_data : bool Whether to store the dataset in the result object. seed : int From b456803e33a8c370a2a5bb6c9a4575f309d9694f Mon Sep 17 00:00:00 2001 From: George Ho Date: Fri, 20 Jul 2018 11:50:37 -0400 Subject: [PATCH 2/4] MAINT: rename plotting module to returns_plotting --- bayesalpha/{plotting.py => returns_plotting.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename bayesalpha/{plotting.py => returns_plotting.py} (100%) diff --git a/bayesalpha/plotting.py b/bayesalpha/returns_plotting.py similarity index 100% rename from bayesalpha/plotting.py rename to bayesalpha/returns_plotting.py From af4993155f4aa254b38ae97c93a39c958fd30362 Mon Sep 17 00:00:00 2001 From: George Ho Date: Fri, 20 Jul 2018 11:55:53 -0400 Subject: [PATCH 3/4] BLD: initial commit of author plotting --- bayesalpha/author_plotting.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 bayesalpha/author_plotting.py diff --git a/bayesalpha/author_plotting.py b/bayesalpha/author_plotting.py new file mode 100644 index 0000000..5b0fcec --- /dev/null +++ b/bayesalpha/author_plotting.py @@ -0,0 +1,24 @@ +import numpy as np +import warnings +import functools +try: + import matplotlib.pyplot as plt + import seaborn as sns + _has_mpl = True +except ImportError: + warnings.warn('Could not import matplotlib: Plotting unavailable.') + _has_mpl = False + plt = None + sns = None + + +def _require_mpl(func): + @functools.wraps(func) + def inner(*args, **kwargs): + if not _has_mpl: + raise RuntimeError('Matplotlib is unavailable.') + return func(*args, **kwargs) + + return inner + + From 7bcc6da88e542616148ee2de488da4aab04d0d65 Mon Sep 17 00:00:00 2001 From: George Ho Date: Mon, 23 Jul 2018 16:43:56 -0400 Subject: [PATCH 4/4] ENH: added plot_trace function --- bayesalpha/author_plotting.py | 46 +++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/bayesalpha/author_plotting.py b/bayesalpha/author_plotting.py index 5b0fcec..2f59d60 100644 --- a/bayesalpha/author_plotting.py +++ b/bayesalpha/author_plotting.py @@ -22,3 +22,49 @@ def inner(*args, **kwargs): return inner +@_require_mpl +def plot_trace(trace, varname, title=None, ax=None, **kwargs): + """ + Plot samples from trace for a specific variable. + + Parameters + ---------- + trace : AuthorModelResult object + Result from ba.fit_authors + varname : str + Name of variable to plot. Must be one of ['mu_global', 'mu_author', + 'mu_algo', 'alpha_author', 'alpha_algo'] + title : str (optional) + Title of plot + ax : plt.axis object (optional) + Axis on which to plot + kwargs : dict (optional) + Additional keyword args to pass to sns.distplot + """ + + if varname not in ['mu_global', 'mu_author', 'mu_algo', + 'alpha_author', 'alpha_algo']: + raise ValueError("`varname` must be one of ['mu_global', 'mu_author', " + "'mu_algo', 'alpha_author', 'alpha_algo']") + + if ax is None: + _, ax = plt.subplots(figsize=[12, 4]) + + for i in trace.trace[varname]['chain']: + if varname == 'mu_global': + sns.distplot(trace.trace['mu_global'].sel({'chain': i}).values, + **kwargs) + else: + suffix = varname.split('_')[-1] # Either 'author' or 'algo' + for j in trace.trace[varname][suffix]: + sns.distplot(trace.trace[varname].sel({'chain': i, + suffix: j}).values, + **kwargs) + + if title: + ax.set_title(title) + + plt.xlabel(varname) + plt.ylabel('Probability') + + return ax