Skip to content

Commit

Permalink
PPA: several fixes (#549)
Browse files Browse the repository at this point in the history
* fix

* do not include target

* lint

* remove print
  • Loading branch information
aloctavodia authored Oct 1, 2024
1 parent cf3cc2b commit 636d9e1
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 18 deletions.
24 changes: 12 additions & 12 deletions preliz/internal/plot_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,29 +543,29 @@ def plot_pp_samples(pp_samples, pp_samples_idxs, references, kind="pdf", sharex=

for ax, idx in zip(axes, pp_samples_idxs):
ax.clear()
plot_references(references, ax)
ax.relim()

sample = pp_samples[idx]

if sharex:
min_ = sample.min()
max_ = sample.max()
if min_ < x_lims[0]:
x_lims[0] = min_
if max_ > x_lims[1]:
x_lims[1] = max_

if kind == "pdf":
plot_kde(sample, ax=ax, color="C0")
elif kind == "hist":
bins, *_ = ax.hist(sample, color="C0", bins="auto", alpha=0.5, density=True)
ax.set_ylim(-bins.max() * 0.05, None)

elif kind == "ecdf":
ax.hist(sample, color="C0")
ax.ecdf(sample, color="C0")

plot_pointinterval(sample, ax=ax)
plot_references(references, ax)

if sharex:
min_, max_ = ax.get_xlim()
if min_ < x_lims[0]:
x_lims[0] = min_
if max_ > x_lims[1]:
x_lims[1] = max_

ax.set_title(idx, alpha=0)
ax.set_yticks([])

Expand All @@ -586,14 +586,13 @@ def plot_pp_mean(pp_samples, selected, references=None, kind="pdf", fig_pp_mean=
ax_pp_mean = fig_pp_mean.axes[0]

ax_pp_mean.clear()
ax_pp_mean.relim()

if np.any(selected):
sample = pp_samples[selected].ravel()
else:
sample = pp_samples.ravel()

plot_references(references, ax_pp_mean)

if kind == "pdf":
plot_kde(
sample,
Expand All @@ -609,6 +608,7 @@ def plot_pp_mean(pp_samples, selected, references=None, kind="pdf", fig_pp_mean=
ax_pp_mean.ecdf(sample, color="k", linestyle="--")

plot_pointinterval(sample, ax=ax_pp_mean)
plot_references(references, ax_pp_mean)
ax_pp_mean.set_yticks([])
fig_pp_mean.canvas.draw()

Expand Down
27 changes: 21 additions & 6 deletions preliz/predictive/ppa.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def ppa(
with output:
references_widget = widgets.Text(
value=str(references),
placeholder="Int, Float or tuple",
placeholder="Int, Float, tuple or dict",
description="references: ",
disabled=False,
layout=widgets.Layout(width="230px", margin="0 20px 0 0"),
Expand All @@ -90,14 +90,26 @@ def ppa(

def kind_(_):
kind = radio_buttons_kind.value
try:
filter_dists.references = ast.literal_eval(references_widget.value)
except (ValueError, SyntaxError):
filter_dists.references = None

plot_pp_samples(
filter_dists.pp_samples,
filter_dists.display_pp_idxs,
ast.literal_eval(references_widget.value),
filter_dists.references,
kind,
check_button_sharex.value,
filter_dists.fig,
)
plot_pp_mean(
filter_dists.pp_samples,
list(filter_dists.selected),
filter_dists.references,
kind,
filter_dists.fig_pp_mean,
)

references_widget.observe(kind_, names=["value"])

Expand Down Expand Up @@ -140,6 +152,7 @@ def __init__(self, fmodel, draws, references, boundaries, target, engine):
self.prior_samples = None # prior samples used for backfitting
self.display_pp_idxs = None # indices of the pp_samples to be displayed
self.pp_octiles = None # octiles computed from pp_samples
self.ref_octiles = None # octiles computed from the target distribution
self.kdt = None # KDTree used to find similar distributions
self.model = None # parsed model used for backfitting
self.clicked = [] # axes clicked by the user
Expand Down Expand Up @@ -170,7 +183,9 @@ def __call__(self):
self.pp_octiles, self.kdt = self.compute_octiles()
self.display_pp_idxs = self.initialize_subsamples(self.target)
self.fig, self.axes = plot_pp_samples(
self.pp_samples, self.display_pp_idxs, self.references
self.pp_samples,
self.display_pp_idxs,
self.references,
)
self.fig_pp_mean = plot_pp_mean(self.pp_samples, self.selected, self.references)

Expand All @@ -180,7 +195,7 @@ def add_target_dist(self):
elif isinstance(self.target, Distribution):
ref_sample = self.target.rvs(self.pp_samples.shape[1])

self.pp_samples = np.vstack([ref_sample, self.pp_samples])
self.ref_octiles = np.quantile(ref_sample, [0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875])

def compute_octiles(self):
"""
Expand Down Expand Up @@ -235,14 +250,14 @@ def initialize_subsamples(self, target):
if new != self.draws:
pp_idxs_to_display.append(new)
else:
new = 0
new = -1
pp_idxs_to_display.append(new)

for _ in range(9):
nearest_neighbor = 2
while new in pp_idxs_to_display:
distance, new = self.kdt.query(
self.pp_octiles[pp_idxs_to_display[-1]], [nearest_neighbor], workers=-1
self.ref_octiles, [nearest_neighbor], workers=-1
)
new = new.item()
nearest_neighbor += 1
Expand Down

0 comments on commit 636d9e1

Please sign in to comment.