diff --git a/preliz/internal/distribution_helper.py b/preliz/internal/distribution_helper.py index 10351d6a..8cc7247b 100644 --- a/preliz/internal/distribution_helper.py +++ b/preliz/internal/distribution_helper.py @@ -64,11 +64,13 @@ def process_extra(input_string): name = match[0] args = match[1].split(",") arg_dict = {} - for arg in args: - key, value = arg.split("=") - arg_dict[key.strip()] = float(value) - result_dict[name] = arg_dict - + try: + for arg in args: + key, value = arg.split("=") + arg_dict[key.strip()] = float(value) + result_dict[name] = arg_dict + except ValueError: + pass return result_dict diff --git a/preliz/internal/optimization.py b/preliz/internal/optimization.py index 53e70f0f..7290c8af 100644 --- a/preliz/internal/optimization.py +++ b/preliz/internal/optimization.py @@ -315,7 +315,10 @@ def fit_to_ecdf(selected_distributions, x_vals, ecdf, mean, std, x_min, x_max, e fitted = Loss(len(selected_distributions)) for dist in selected_distributions: if dist.__class__.__name__ in extra_pros: - dist._parametrization(**extra_pros[dist.__class__.__name__]) + try: + dist._parametrization(**extra_pros[dist.__class__.__name__]) + except TypeError: + pass if dist.__class__.__name__ == "BetaScaled": update_bounds_beta_scaled(dist, x_min, x_max) diff --git a/preliz/tests/roulette.ipynb b/preliz/tests/roulette.ipynb index cef6de4e..3b7f738e 100644 --- a/preliz/tests/roulette.ipynb +++ b/preliz/tests/roulette.ipynb @@ -12,7 +12,7 @@ "import ipytest\n", "ipytest.autoconfig()\n", "\n", - "from preliz import roulette" + "from preliz import Roulette" ] }, { @@ -22,16 +22,64 @@ "metadata": {}, "outputs": [], "source": [ - "%%ipytest\n", - "\n", - "@pytest.mark.parametrize(\"x_min, x_max, nrows, ncols, figsize\", [\n", - " (0, 10, 10, 10, None), # Test default behavior\n", - " (-5, 5, 10, 10, None), # Test different domain\n", - " (0, 10, 5, 5, None), # Test different grid dimensions\n", + "@pytest.mark.parametrize(\"x_min, x_max, nrows, ncols, figsize, dist_names, params\", [\n", + " (0, 10, 10, 10, None, None, None), # Test default behavior\n", + " (-5, 5, 10, 10, None, None, None), # Test different domain\n", + " (0, 10, 5, 5, None, None, None), # Test different grid dimensions\n", " (0, 10, 10, 10, (10, 8)), # Test custom figsize\n", + " (0, 10, 10, 10, None, [\"Normal\", \"StudentT\"], \"Normal(mu=0), StudentT(nu=0.001)\"), # Test custom dist and params\n", "])\n", "def test_roulette(x_min, x_max, nrows, ncols, figsize):\n", - " roulette(x_min, x_max, nrows, ncols, figsize)" + " Roulette(x_min, x_max, nrows, ncols, figsize)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "70ae102b", + "metadata": {}, + "outputs": [], + "source": [ + "def test_roulette_initialization():\n", + " roulette = Roulette(x_min=0, x_max=10, nrows=10, ncols=11)\n", + " assert roulette._x_min == 0\n", + " assert roulette._x_max == 10\n", + " assert roulette._nrows == 10\n", + " assert roulette._ncols == 11\n", + " assert roulette._figsize == (8, 6)\n", + "\n", + "\n", + "def test_roulette_update_grid():\n", + " roulette = Roulette(x_min=0, x_max=10, nrows=10, ncols=11)\n", + " roulette._widgets['w_x_min'].value = 1\n", + " roulette._widgets['w_x_max'].value = 9\n", + " roulette._widgets['w_nrows'].value = 8\n", + " roulette._widgets['w_ncols'].value = 9\n", + " roulette._update_grid()\n", + " assert roulette._x_min == 1\n", + " assert roulette._x_max == 9\n", + " assert roulette._nrows == 8\n", + " assert roulette._ncols == 9\n", + "\n", + "\n", + "def test_roulette_weights_to_ecdf():\n", + " roulette = Roulette(x_min=0, x_max=10, nrows=10, ncols=11)\n", + " roulette._grid._weights = {0: 2, 1: 6, 2: 10, 3: 10, 4: 7, 5: 3, 6: 1, 7: 1, 8: 1, 9: 1}\n", + " x_vals, cum_sum, probabilities, mean, std, filled_columns = roulette._weights_to_ecdf()\n", + " assert len(x_vals) == 10\n", + " assert len(cum_sum) == 10\n", + " assert len(probabilities) == 10\n", + " assert filled_columns == 10\n", + "\n", + "\n", + "def test_roulette_on_leave_fig():\n", + " roulette = Roulette(x_min=0, x_max=10, nrows=10, ncols=11)\n", + " roulette._grid._weights = {0: 2, 1: 6, 2: 10, 3: 10, 4: 7, 5: 3, 6: 1, 7: 1, 8: 1, 9: 1}\n", + " roulette._widgets['w_distributions'].value = [\"Gamma\", \"LogNormal\", \"StudentT\", \"BetaScaled\", \"Normal\"]\n", + " roulette._widgets['w_repr'].value = \"pdf\"\n", + " roulette._on_leave_fig()\n", + " assert roulette.dist is not None\n", + " assert roulette.hist is not None" ] } ], diff --git a/preliz/tests/test_roulette.py b/preliz/tests/test_roulette.py index e2e7601d..2fc21850 100644 --- a/preliz/tests/test_roulette.py +++ b/preliz/tests/test_roulette.py @@ -1,28 +1,5 @@ from test_helper import run_notebook -from preliz.unidimensional.roulette import create_figure, create_grid, Rectangles, on_leave_fig def test_roulette(): run_notebook("roulette.ipynb") - - -def test_roulette_mock(): - x_min = 0 - x_max = 10 - ncols = 10 - nrows = 10 - - fig, ax_grid, ax_fit = create_figure((10, 9)) - coll = create_grid(x_min, x_max, nrows, ncols, ax=ax_grid) - grid = Rectangles(fig, coll, nrows, ncols, ax_grid) - grid.weights = {0: 2, 1: 6, 2: 10, 3: 10, 4: 7, 5: 3, 6: 1, 7: 1, 8: 1, 9: 1} - w_repr = "kde" - distributions = ["Gamma", "LogNormal", "StudentT", "BetaScaled", "Normal"] - - for idx, dist in enumerate(distributions): - w_distributions = distributions[idx:] - - fitted_dist = on_leave_fig( - fig.canvas, grid, w_distributions, w_repr, x_min, x_max, ncols, "", ax_fit - ) - assert fitted_dist.__class__.__name__ == dist diff --git a/preliz/unidimensional/__init__.py b/preliz/unidimensional/__init__.py index 58a6f596..ac013f75 100644 --- a/preliz/unidimensional/__init__.py +++ b/preliz/unidimensional/__init__.py @@ -3,6 +3,6 @@ from .mle import mle from .quartile import quartile from .quartile_int import quartile_int -from .roulette import roulette +from .roulette import Roulette -__all__ = ["beta_mode", "maxent", "mle", "roulette", "quartile", "quartile_int"] +__all__ = ["beta_mode", "maxent", "mle", "Roulette", "quartile", "quartile_int"] diff --git a/preliz/unidimensional/roulette.py b/preliz/unidimensional/roulette.py index d0112938..a7f8950d 100644 --- a/preliz/unidimensional/roulette.py +++ b/preliz/unidimensional/roulette.py @@ -1,6 +1,5 @@ +# pylint: disable=too-many-instance-attributes from math import ceil, floor - - import numpy as np import matplotlib.pyplot as plt from matplotlib import patches @@ -9,462 +8,417 @@ import ipywidgets as widgets except ImportError: pass + from ..internal.optimization import fit_to_ecdf, get_distributions from ..internal.plot_helper import check_inside_notebook, representations from ..internal.distribution_helper import process_extra from ..distributions import all_discrete, all_continuous -def roulette(x_min=0, x_max=10, nrows=10, ncols=11, dist_names=None, figsize=None): - """ - Prior elicitation for 1D distribution using the roulette method. - - Draw 1D distributions using a grid as input. - - Parameters - ---------- - x_min: Optional[float] - Minimum value for the domain of the grid and fitted distribution - x_max: Optional[float] - Maximum value for the domain of the grid and fitted distribution - nrows: Optional[int] - Number of rows for the grid. Defaults to 10. - ncols: Optional[int] - Number of columns for the grid. Defaults to 11. - dist_names: list - List of distributions names to be used in the elicitation. If None, almost all 1D - distributions available in PreliZ will be used. Some distributions like Uniform or - Cauchy are omitted by default. - figsize: Optional[Tuple[int, int]] - Figure size. If None it will be defined automatically. - - Returns - ------- - PreliZ distribution - - References - ---------- - * Morris D.E. et al. (2014) see https://doi.org/10.1016/j.envsoft.2013.10.010 - * See roulette mode http://optics.eee.nottingham.ac.uk/match/uncertainty.php - """ - - check_inside_notebook(need_widget=True) - - ( - w_x_min, - w_x_max, - w_ncols, - w_nrows, - w_extra, - w_repr, - w_distributions, - w_checkbox_cont, - w_checkbox_disc, - w_checkbox_none, - ) = get_widgets( - x_min, - x_max, - nrows, - ncols, - dist_names, - ) - - output = widgets.Output() - - with output: - x_min = w_x_min.value - x_max = w_x_max.value - nrows = w_nrows.value - ncols = w_ncols.value - - if figsize is None: - figsize = (8, 6) - - fig, ax_grid, ax_fit = create_figure(figsize) - - coll = create_grid(x_min, x_max, nrows, ncols, ax=ax_grid) - grid = Rectangles(fig, coll, nrows, ncols, ax_grid) - - def handle_checkbox_change(_): - dist_names = handle_checkbox_widget( - w_distributions.options, w_checkbox_cont, w_checkbox_disc, w_checkbox_none - ) - w_distributions.value = dist_names - - w_checkbox_none.observe(handle_checkbox_change) - w_checkbox_cont.observe(handle_checkbox_change) - w_checkbox_disc.observe(handle_checkbox_change) - - def update_grid_(_): - update_grid( - fig.canvas, - w_x_min.value, - w_x_max.value, - w_nrows.value, - w_ncols.value, - grid, - ax_grid, - ax_fit, +class Roulette: + def __init__( + self, x_min=0, x_max=10, nrows=10, ncols=11, dist_names=None, params=None, figsize=None + ): + """ + Prior elicitation for 1D distribution using the roulette method. + + Draw 1D distributions using a grid as input. + + Parameters + ---------- + x_min: Optional[float] + Minimum value for the domain of the grid and fitted distribution + x_max: Optional[float] + Maximum value for the domain of the grid and fitted distribution + nrows: Optional[int] + Number of rows for the grid. Defaults to 10. + ncols: Optional[int] + Number of columns for the grid. Defaults to 11. + dist_names: list + List of distributions names to be used in the elicitation. + For example: ["Normal", "StudentT"]. + Default to None, almost all 1D distributions available in PreliZ will be used, + with some exceptions like Uniform or Cauchy. + params: Optional[str]: + Extra parameters to be passed to the distributions. The format is a string with the + PreliZ's distribution name followed by the argument to fix. + For example: "TruncatedNormal(lower=0), StudentT(nu=8)". If you use the ``params`` + text area quotation marks are not necessary. + figsize: Optional[Tuple[int, int]] + Figure size. If None it will be defined automatically. + + Returns + ------- + PreliZ distribution + + References + ---------- + * Morris D.E. et al. (2014) see https://doi.org/10.1016/j.envsoft.2013.10.010 + * See roulette mode http://optics.eee.nottingham.ac.uk/match/uncertainty.php + """ + + self._x_min = x_min + self._x_max = x_max + self._nrows = nrows + self._ncols = ncols + self._dist_names = dist_names + self._figsize = figsize + self._w_extra = params + self.dist = None + self._hist = None + + check_inside_notebook(need_widget=True) + + self._widgets = self._get_widgets() + self._output = widgets.Output() + + with self._output: + + if self._figsize is None: + self._figsize = (8, 6) + + self._fig, self._ax_grid, self._ax_fit = self._create_figure() + self._coll = self._create_grid() + self._grid = _Rectangles(self._fig, self._coll, self._nrows, self._ncols, self._ax_grid) + + self._setup_observers() + + self._fig.canvas.mpl_connect("button_release_event", lambda event: self._on_leave_fig()) + + controls = widgets.VBox( + [ + self._widgets["w_x_min"], + self._widgets["w_x_max"], + self._widgets["w_nrows"], + self._widgets["w_ncols"], + self._widgets["w_extra"], + ] + ) + control_distribution = widgets.VBox( + [ + self._widgets["w_checkbox_cont"], + self._widgets["w_checkbox_disc"], + self._widgets["w_checkbox_none"], + ] + ) + display( # pylint:disable=undefined-variable + widgets.HBox( + [ + controls, + self._widgets["w_repr"], + self._widgets["w_distributions"], + control_distribution, + ] ) + ) - w_x_min.observe(update_grid_) - w_x_max.observe(update_grid_) - w_nrows.observe(update_grid_) - w_ncols.observe(update_grid_) - - def on_leave_fig_(_): - on_leave_fig( - fig.canvas, - grid, - w_distributions.value, - w_repr.value, - w_x_min.value, - w_x_max.value, - w_ncols.value, - w_extra.value, - ax_fit, + def _create_figure(self): + fig, axes = plt.subplots(2, 1, figsize=self._figsize, constrained_layout=True) + ax_grid = axes[0] + ax_fit = axes[1] + ax_fit.set_yticks([]) + fig.canvas.header_visible = False + fig.canvas.footer_visible = False + fig.canvas.toolbar_position = "right" + return fig, ax_grid, ax_fit + + def _create_grid(self): + xx = np.arange(self._ncols) + yy = np.arange(self._nrows) + + if self._ncols < 11: + num = self._ncols + else: + num = 11 + + self._ax_grid.set( + xticks=np.linspace(0, self._ncols - 1, num=num) + 0.5, + xticklabels=[f"{i:.1f}" for i in np.linspace(self._x_min, self._x_max, num=num)], + ) + + coll = np.zeros((self._nrows, self._ncols), dtype=object) + for idx, xi in enumerate(xx): + for idy, yi in enumerate(yy): + sq = patches.Rectangle((xi, yi), 1, 1, fill=True, facecolor="0.8", edgecolor="w") + self._ax_grid.add_patch(sq) + coll[idy, idx] = sq + + self._ax_grid.set_yticks([]) + self._ax_grid.relim() + self._ax_grid.autoscale_view() + return coll + + def _on_leave_fig(self): + extra_pros = process_extra(self._widgets["w_extra"].value) + + x_vals, ecdf, probs, mean, std, filled_columns = self._weights_to_ecdf() + + fitted_dist = None + if filled_columns > 1: + selected_distributions = get_distributions(self._widgets["w_distributions"].value) + + if selected_distributions: + self._reset_dist_panel(yticks=False) + fitted_dist = fit_to_ecdf( + selected_distributions, + x_vals, + ecdf, + mean, + std, + self._x_min, + self._x_max, + extra_pros, + ) + + if fitted_dist is None: + self._ax_fit.set_title("domain error") + else: + representations(fitted_dist, self._widgets["w_repr"].value, self._ax_fit) + else: + self._reset_dist_panel(yticks=True) + self._fig.canvas.draw() + + self.hist = (x_vals, probs) + self.dist = fitted_dist + + def _weights_to_ecdf(self): + step = (self._x_max - self._x_min) / (self._ncols - 1) + x_vals = [(k + 0.5) * step + self._x_min for k, v in self._grid._weights.items() if v != 0] + total = sum(self._grid._weights.values()) + probabilities = [v / total for v in self._grid._weights.values() if v != 0] + cum_sum = np.cumsum(probabilities) + + mean = sum(value * prob for value, prob in zip(x_vals, probabilities)) + std = (sum(prob * (value - mean) ** 2 for value, prob in zip(x_vals, probabilities))) ** 0.5 + + return x_vals, cum_sum, probabilities, mean, std, len(x_vals) + + def _update_grid(self): + self._ax_grid.cla() + self._coll = self._create_grid() + self._grid._coll = self._coll + self._grid._ncols = self._ncols + self._grid._nrows = self._nrows + self._grid._weights = {k: 0 for k in range(0, self._ncols)} + self._reset_dist_panel(yticks=True) + self._ax_grid.set_yticks([]) + self._ax_grid.relim() + self._ax_grid.autoscale_view() + self._fig.canvas.draw() + + def _reset_dist_panel(self, yticks): + self._ax_fit.cla() + if yticks: + self._ax_fit.set_yticks([]) + self._ax_fit.set_xlim(self._x_min, self._x_max) + self._ax_fit.relim() + self._ax_fit.autoscale_view() + + def _handle_checkbox_widget(self): + if self._widgets["w_checkbox_none"].value: + self._widgets["w_checkbox_disc"].value = False + self._widgets["w_checkbox_cont"].value = False + return [] + all_cls = [] + if self._widgets["w_checkbox_cont"].value: + all_cls += list( + ( + cls.__name__ + for cls in all_continuous + if cls.__name__ in self._widgets["w_distributions"].options + ) + ) + if self._widgets["w_checkbox_disc"].value: + all_cls += list( + ( + cls.__name__ + for cls in all_discrete + if cls.__name__ in self._widgets["w_distributions"].options + ) ) + return all_cls + + def _get_widgets(self): + width_entry_text = widgets.Layout(width="150px") + width_repr_text = widgets.Layout(width="250px") + width_distribution_text = widgets.Layout(width="150px", height="125px") + + w_x_min = widgets.FloatText( + value=self._x_min, + step=1, + description="x_min:", + disabled=False, + layout=width_entry_text, + ) - w_repr.observe(on_leave_fig_) - w_distributions.observe(on_leave_fig_) - - def on_value_change(change): - new_a = change["new"] - if new_a == w_x_max.value: - w_x_max.value = new_a + 1 - - w_x_min.observe(on_value_change, names="value") - - fig.canvas.mpl_connect( - "button_release_event", - lambda event: on_leave_fig( - fig.canvas, - grid, - w_distributions.value, - w_repr.value, - w_x_min.value, - w_x_max.value, - w_ncols.value, - w_extra.value, - ax_fit, - ), + w_x_max = widgets.FloatText( + value=self._x_max, + step=1, + description="x_max:", + disabled=False, + layout=width_entry_text, ) - controls = widgets.VBox([w_x_min, w_x_max, w_nrows, w_ncols, w_extra]) - control_distribution = widgets.VBox([w_checkbox_cont, w_checkbox_disc, w_checkbox_none]) - display( # pylint:disable=undefined-variable - widgets.HBox( - [ - controls, - w_repr, - w_distributions, - control_distribution, + w_nrows = widgets.BoundedIntText( + value=self._nrows, + min=2, + step=1, + description="n_rows:", + disabled=False, + layout=width_entry_text, + ) + + w_ncols = widgets.BoundedIntText( + value=self._ncols, + min=2, + step=1, + description="n_cols:", + disabled=False, + layout=width_entry_text, + ) + + w_extra = widgets.Textarea( + value=self._w_extra, + placeholder="Pass extra parameters", + description="params:", + disabled=False, + layout=width_repr_text, + ) + + w_repr = widgets.RadioButtons( + options=["pdf", "cdf", "ppf"], + value="pdf", + description="", + disabled=False, + layout=width_entry_text, + ) + + if self._dist_names is None: + default_dist = ["Normal", "BetaScaled", "Gamma", "LogNormal", "StudentT"] + dist_names = [ + "AsymmetricLaplace", + "BetaScaled", + "ChiSquared", + "ExGaussian", + "Exponential", + "Gamma", + "Gumbel", + "HalfNormal", + "HalfStudentT", + "InverseGamma", + "Laplace", + "LogNormal", + "Logistic", + # "LogitNormal", # fails if we add chips at x_value= 1 + "Moyal", + "Normal", + "Pareto", + "Rice", + "SkewNormal", + "StudentT", + "Triangular", + "VonMises", + "Wald", + "Weibull", + "BetaBinomial", + "DiscreteWeibull", + "Geometric", + "NegativeBinomial", + "Poisson", ] + else: + default_dist = self._dist_names + dist_names = self._dist_names + + w_distributions = widgets.SelectMultiple( + options=dist_names, + value=default_dist, + description="", + disabled=False, + layout=width_distribution_text, ) - ) - - -def create_figure(figsize): - """ - Initialize a matplotlib figure with two subplots - """ - fig, axes = plt.subplots(2, 1, figsize=figsize, constrained_layout=True) - ax_grid = axes[0] - ax_fit = axes[1] - ax_fit.set_yticks([]) - fig.canvas.header_visible = False - fig.canvas.footer_visible = False - fig.canvas.toolbar_position = "right" - - return fig, ax_grid, ax_fit - - -def create_grid(x_min=0, x_max=1, nrows=10, ncols=10, ax=None): - """ - Create a grid of rectangles - """ - xx = np.arange(ncols) - yy = np.arange(nrows) - - if ncols < 11: - num = ncols - else: - num = 11 - - ax.set( - xticks=np.linspace(0, ncols - 1, num=num) + 0.5, - xticklabels=[f"{i:.1f}" for i in np.linspace(x_min, x_max, num=num)], - ) - - coll = np.zeros((nrows, ncols), dtype=object) - for idx, xi in enumerate(xx): - for idy, yi in enumerate(yy): - sq = patches.Rectangle((xi, yi), 1, 1, fill=True, facecolor="0.8", edgecolor="w") - ax.add_patch(sq) - coll[idy, idx] = sq - - ax.set_yticks([]) - ax.relim() - ax.autoscale_view() - return coll - - -class Rectangles: - """ - Clickable rectangles - Clicked rectangles are highlighted - """ + w_checkbox_cont = widgets.Checkbox( + value=False, description="Continuous", disabled=False, indent=False + ) + w_checkbox_disc = widgets.Checkbox( + value=False, description="Discrete", disabled=False, indent=False + ) + w_checkbox_none = widgets.Checkbox( + value=False, description="None", disabled=False, indent=False + ) + + return { + "w_x_min": w_x_min, + "w_x_max": w_x_max, + "w_ncols": w_ncols, + "w_nrows": w_nrows, + "w_extra": w_extra, + "w_repr": w_repr, + "w_distributions": w_distributions, + "w_checkbox_cont": w_checkbox_cont, + "w_checkbox_disc": w_checkbox_disc, + "w_checkbox_none": w_checkbox_none, + } + + def _setup_observers(self): + self._widgets["w_checkbox_none"].observe(self._handle_checkbox_change) + self._widgets["w_checkbox_cont"].observe(self._handle_checkbox_change) + self._widgets["w_checkbox_disc"].observe(self._handle_checkbox_change) + + def _update_grid_(_): + self._x_min = self._widgets["w_x_min"].value + self._x_max = self._widgets["w_x_max"].value + self._nrows = self._widgets["w_nrows"].value + self._ncols = self._widgets["w_ncols"].value + self._update_grid() + + self._widgets["w_x_min"].observe(_update_grid_) + self._widgets["w_x_max"].observe(_update_grid_) + self._widgets["w_nrows"].observe(_update_grid_) + self._widgets["w_ncols"].observe(_update_grid_) + self._widgets["w_x_min"].observe(self._on_value_change, names="value") + + def _on_leave_fig_(_): + self._on_leave_fig() + + self._widgets["w_repr"].observe(_on_leave_fig_) + self._widgets["w_distributions"].observe(_on_leave_fig_) + self._widgets["w_extra"].observe(_on_leave_fig_) + + def _handle_checkbox_change(self, _): + dist_names = self._handle_checkbox_widget() + self._widgets["w_distributions"].value = dist_names + + def _on_value_change(self, change): + new_a = change["new"] + if new_a == self._widgets["w_x_max"].value: + self._widgets["w_x_max"].value = new_a + 1 + + +class _Rectangles: def __init__(self, fig, coll, nrows, ncols, ax): - self.fig = fig - self.coll = coll - self.nrows = nrows - self.ncols = ncols - self.ax = ax - self.weights = {k: 0 for k in range(0, ncols)} + self._fig = fig + self._coll = coll + self._nrows = nrows + self._ncols = ncols + self._ax = ax + self._weights = {k: 0 for k in range(0, ncols)} fig.canvas.mpl_connect("button_press_event", self) def __call__(self, event): - if event.inaxes == self.ax: + if event.inaxes == self._ax: x = event.xdata y = event.ydata idx = floor(x) idy = ceil(y) - if 0 <= idx < self.ncols and 0 <= idy <= self.nrows: - if self.weights[idx] >= idy: + if 0 <= idx < self._ncols and 0 <= idy <= self._nrows: + if self._weights[idx] >= idy: idy -= 1 - for row in range(self.nrows): - self.coll[row, idx].set_facecolor("0.8") - self.weights[idx] = idy + for row in range(self._nrows): + self._coll[row, idx].set_facecolor("0.8") + self._weights[idx] = idy for row in range(idy): - self.coll[row, idx].set_facecolor("C1") - self.fig.canvas.draw() - - -def on_leave_fig(canvas, grid, dist_names, kind_plot, x_min, x_max, ncols, extra, ax): - x_min = float(x_min) - x_max = float(x_max) - ncols = float(ncols) - x_range = x_max - x_min - extra_pros = process_extra(extra) - - x_vals, ecdf, mean, std, filled_columns = weights_to_ecdf(grid.weights, x_min, x_range, ncols) - - fitted_dist = None - if filled_columns > 1: - selected_distributions = get_distributions(dist_names) - - if selected_distributions: - reset_dist_panel(x_min, x_max, ax, yticks=False) - fitted_dist = fit_to_ecdf( - selected_distributions, - x_vals, - ecdf, - mean, - std, - x_min, - x_max, - extra_pros, - ) - - if fitted_dist is None: - ax.set_title("domain error") - else: - representations(fitted_dist, kind_plot, ax) - else: - reset_dist_panel(x_min, x_max, ax, yticks=True) - canvas.draw() - - return fitted_dist - - -def weights_to_ecdf(weights, x_min, x_range, ncols): - """ - Turn the weights (chips) into the empirical cdf - """ - step = x_range / (ncols - 1) - x_vals = [(k + 0.5) * step + x_min for k, v in weights.items() if v != 0] - total = sum(weights.values()) - probabilities = [v / total for v in weights.values() if v != 0] - cum_sum = np.cumsum(probabilities) - - mean = sum(value * prob for value, prob in zip(x_vals, probabilities)) - std = (sum(prob * (value - mean) ** 2 for value, prob in zip(x_vals, probabilities))) ** 0.5 - - return x_vals, cum_sum, mean, std, len(x_vals) - - -def update_grid(canvas, x_min, x_max, nrows, ncols, grid, ax_grid, ax_fit): - """ - Update the grid subplot - """ - ax_grid.cla() - coll = create_grid(x_min=x_min, x_max=x_max, nrows=nrows, ncols=ncols, ax=ax_grid) - grid.coll = coll - grid.ncols = ncols - grid.nrows = nrows - grid.weights = {k: 0 for k in range(0, ncols)} - reset_dist_panel(x_min, x_max, ax_fit, yticks=True) - ax_grid.set_yticks([]) - ax_grid.relim() - ax_grid.autoscale_view() - canvas.draw() - - -def reset_dist_panel(x_min, x_max, ax, yticks): - """ - Clean the distribution subplot - """ - ax.cla() - if yticks: - ax.set_yticks([]) - ax.set_xlim(x_min, x_max) - ax.relim() - ax.autoscale_view() - - -def handle_checkbox_widget(options, w_checkbox_cont, w_checkbox_disc, w_checkbox_none): - if w_checkbox_none.value: - w_checkbox_disc.value = False - w_checkbox_cont.value = False - return [] - all_cls = [] - if w_checkbox_cont.value: - all_cont_str = [ # pylint:disable=unnecessary-comprehension - dist for dist in (cls.__name__ for cls in all_continuous if cls.__name__ in options) - ] - all_cls += all_cont_str - if w_checkbox_disc.value: - all_dist_str = [ # pylint:disable=unnecessary-comprehension - dist for dist in (cls.__name__ for cls in all_discrete if cls.__name__ in options) - ] - all_cls += all_dist_str - return all_cls - - -def get_widgets(x_min, x_max, nrows, ncols, dist_names): - - width_entry_text = widgets.Layout(width="150px") - width_repr_text = widgets.Layout(width="250px") - width_distribution_text = widgets.Layout(width="150px", height="125px") - - w_x_min = widgets.FloatText( - value=x_min, - step=1, - description="x_min:", - disabled=False, - layout=width_entry_text, - ) - - w_x_max = widgets.FloatText( - value=x_max, - step=1, - description="x_max:", - disabled=False, - layout=width_entry_text, - ) - - w_nrows = widgets.BoundedIntText( - value=nrows, - min=2, - step=1, - description="n_rows:", - disabled=False, - layout=width_entry_text, - ) - - w_ncols = widgets.BoundedIntText( - value=ncols, - min=2, - step=1, - description="n_cols:", - disabled=False, - layout=width_entry_text, - ) - - w_extra = widgets.Textarea( - value="", - placeholder="Pass extra parameters", - description="params:", - disabled=False, - layout=width_repr_text, - ) - - w_repr = widgets.RadioButtons( - options=["pdf", "cdf", "ppf"], - value="pdf", - description="", - disabled=False, - layout=width_entry_text, - ) - - if dist_names is None: - - default_dist = ["Normal", "BetaScaled", "Gamma", "LogNormal", "StudentT"] - - dist_names = [ - "AsymmetricLaplace", - "BetaScaled", - "ChiSquared", - "ExGaussian", - "Exponential", - "Gamma", - "Gumbel", - "HalfNormal", - "HalfStudentT", - "InverseGamma", - "Laplace", - "LogNormal", - "Logistic", - # "LogitNormal", # fails if we add chips at x_value= 1 - "Moyal", - "Normal", - "Pareto", - "Rice", - "SkewNormal", - "StudentT", - "Triangular", - "VonMises", - "Wald", - "Weibull", - "BetaBinomial", - "DiscreteWeibull", - "Geometric", - "NegativeBinomial", - "Poisson", - ] - - else: - default_dist = dist_names - - w_distributions = widgets.SelectMultiple( - options=dist_names, - value=default_dist, - description="", - disabled=False, - layout=width_distribution_text, - ) - - w_checkbox_cont = widgets.Checkbox( - value=False, description="Continuous", disabled=False, indent=False - ) - w_checkbox_disc = widgets.Checkbox( - value=False, description="Discrete", disabled=False, indent=False - ) - w_checkbox_none = widgets.Checkbox( - value=False, description="None", disabled=False, indent=False - ) - - return ( - w_x_min, - w_x_max, - w_ncols, - w_nrows, - w_extra, - w_repr, - w_distributions, - w_checkbox_cont, - w_checkbox_disc, - w_checkbox_none, - ) + self._coll[row, idx].set_facecolor("C1") + self._fig.canvas.draw()