diff --git a/eryn/backends/backend.py b/eryn/backends/backend.py index 47eb6e2..5f4e8cf 100644 --- a/eryn/backends/backend.py +++ b/eryn/backends/backend.py @@ -741,7 +741,7 @@ def get_gelman_rubin_convergence_diagnostic( if chains.shape[2] == 1: # If no multiple leaves, we squeeze and transpose to the # right shape to pass to the psrf function, which is (nwalkers, nsamples, ndim) - chains_in = chains.squeeze().transpose((1, 0, 2)) + chains_in = chains.squeeze(axis=2).transpose((1, 0, 2)) else: # Project onto the model dim all chains [in case of RJ and multiple leaves per branch] inds = self.get_inds(discard=discard, thin=thin)[branch][ diff --git a/eryn/ensemble.py b/eryn/ensemble.py index b1b30d2..8c91802 100644 --- a/eryn/ensemble.py +++ b/eryn/ensemble.py @@ -1020,6 +1020,8 @@ def sample( ) # update after diagnostic and stopping check + # if updating and using burn_in, need to make sure it does not use + # previous chain samples since they are not stored. if ( self.update_iterations > 0 and self.update_fn is not None @@ -1068,9 +1070,6 @@ def run_mcmc( ) initial_state = self._previous_state - # setup thin_by info - thin_by = 1 if "thin_by" not in kwargs else kwargs["thin_by"] - # run burn in if burn is not None and burn != 0: # prepare kwargs that relate to burn @@ -1079,14 +1078,6 @@ def run_mcmc( burn_kwargs["thin_by"] = 1 i = 0 for results in self.sample(initial_state, iterations=burn, **burn_kwargs): - # if updating and using burn_in, need to make sure it does not use - # previous chain samples since they are not stored. - if ( - self.update_iterations > 0 - and self.update_fn is not None - and (i + 1) % (self.update_iterations * thin_by) == 0 - ): - self.update_fn(i, results, self) i += 1 # run post-burn update diff --git a/eryn/moves/combine.py b/eryn/moves/combine.py index 5741177..3ab052d 100644 --- a/eryn/moves/combine.py +++ b/eryn/moves/combine.py @@ -45,7 +45,7 @@ def accepted(self, accepted): # set the accepted arrays for all moves assert isinstance(accepted, np.ndarray) for move in self.moves: - move.accepted = accepted + move.accepted = accepted.copy() @property def acceptance_fraction(self):