Skip to content

Commit

Permalink
Merge pull request #19 from antcc/main
Browse files Browse the repository at this point in the history
Minor bug fixes
  • Loading branch information
mikekatz04 authored Aug 26, 2024
2 parents 494f716 + 02a9d8f commit 5b55f9d
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 13 deletions.
2 changes: 1 addition & 1 deletion eryn/backends/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,7 @@ def get_gelman_rubin_convergence_diagnostic(
if chains.shape[2] == 1:
# If no multiple leaves, we squeeze and transpose to the
# right shape to pass to the psrf function, which is (nwalkers, nsamples, ndim)
chains_in = chains.squeeze().transpose((1, 0, 2))
chains_in = chains.squeeze(axis=2).transpose((1, 0, 2))
else:
# Project onto the model dim all chains [in case of RJ and multiple leaves per branch]
inds = self.get_inds(discard=discard, thin=thin)[branch][
Expand Down
13 changes: 2 additions & 11 deletions eryn/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -1020,6 +1020,8 @@ def sample(
)

# update after diagnostic and stopping check
# if updating and using burn_in, need to make sure it does not use
# previous chain samples since they are not stored.
if (
self.update_iterations > 0
and self.update_fn is not None
Expand Down Expand Up @@ -1068,9 +1070,6 @@ def run_mcmc(
)
initial_state = self._previous_state

# setup thin_by info
thin_by = 1 if "thin_by" not in kwargs else kwargs["thin_by"]

# run burn in
if burn is not None and burn != 0:
# prepare kwargs that relate to burn
Expand All @@ -1079,14 +1078,6 @@ def run_mcmc(
burn_kwargs["thin_by"] = 1
i = 0
for results in self.sample(initial_state, iterations=burn, **burn_kwargs):
# if updating and using burn_in, need to make sure it does not use
# previous chain samples since they are not stored.
if (
self.update_iterations > 0
and self.update_fn is not None
and (i + 1) % (self.update_iterations * thin_by) == 0
):
self.update_fn(i, results, self)
i += 1

# run post-burn update
Expand Down
2 changes: 1 addition & 1 deletion eryn/moves/combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def accepted(self, accepted):
# set the accepted arrays for all moves
assert isinstance(accepted, np.ndarray)
for move in self.moves:
move.accepted = accepted
move.accepted = accepted.copy()

@property
def acceptance_fraction(self):
Expand Down

0 comments on commit 5b55f9d

Please sign in to comment.