Skip to content

Commit

Permalink
update visualisation and returned evidence
Browse files Browse the repository at this point in the history
update plot gp to mask non-finite evidence so that the robust model does not need to filter these out
  • Loading branch information
uremes committed Jul 8, 2024
1 parent bc9e0d1 commit c79e26e
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 12 deletions.
18 changes: 9 additions & 9 deletions elfi/methods/bo/gpy_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,23 +570,23 @@ def optimize(self):
logger.warning("Numerical error in GP optimization. Stopping optimization")

@property
def n_evidence_all(self):
"""Return the number of observed samples with failed simulations included."""
def n_evidence(self):
"""Return the number of observed samples."""
return len(self._Y)

@property
def failed(self):
"""Return inputs that resulted in failed simulations."""
return self._X[~np.isfinite(self._Y).reshape(-1)]
def n_valid_evidence(self):
"""Return the number of valid observed samples."""
return np.sum(np.isfinite(self._Y))

@property
def X_all(self):
"""Return all inputs."""
def X(self):
"""Return input evidence."""
return self._X

@property
def Y_all(self):
"""Return all outputs."""
def Y(self):
"""Return output evidence."""
Y = self._Y.copy()
Y[~np.isfinite(Y)] = self.FAILED_OUTPUT
return Y
Expand Down
2 changes: 1 addition & 1 deletion elfi/methods/inference/bolfi.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def _allow_submit(self, batch_index):
return True

def _should_optimize(self):
current = self.state['n_evidence']
current = self.target_model.n_evidence + self.batch_size
next_update = self.state['last_GP_update'] + self.update_interval
return current >= self.n_initial_evidence and current >= next_update

Expand Down
13 changes: 11 additions & 2 deletions elfi/visualization/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,8 +450,10 @@ def plot_gp(gp, parameter_names, axes=None, resol=50,
shape = (n_plots, n_plots)
axes, kwargs = _create_axes(axes, shape, **kwargs)

x_evidence = gp.X
y_evidence = gp.Y
valid_inds = np.isfinite(gp.Y).squeeze()
x_evidence = gp.X[valid_inds]
y_evidence = gp.Y[valid_inds]
x_evidence_failed = gp.X[~valid_inds]
if const is None:
const = x_evidence[np.argmin(y_evidence), :]
bounds = bounds or gp.bounds
Expand Down Expand Up @@ -488,6 +490,13 @@ def plot_gp(gp, parameter_names, axes=None, resol=50,
color="red",
alpha=0.7,
s=5)
if len(x_evidence_failed) > 0:
axes[jy, ix].scatter(x_evidence_failed[:, ix],
x_evidence_failed[:, jy],
marker="^",
color="blue",
alpha=0.5,
s=7)

if true_params is not None:
axes[jy, ix].plot([true_params[parameter_names[ix]],
Expand Down

0 comments on commit c79e26e

Please sign in to comment.