Skip to content

Commit

Permalink
Make metropolis_hastings seedable (#106)
Browse files Browse the repository at this point in the history
* Make metropolis_hastings seedable

* switching order of parameters so as to not break old code.
  • Loading branch information
odedstein authored Dec 1, 2023
1 parent 3570ae3 commit 94ab93a
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
11 changes: 8 additions & 3 deletions src/gpytoolbox/metropolis_hastings.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import numpy as np
import random

def metropolis_hastings(unnorm_distr, next_sample, x0 , num_samples=100):
def metropolis_hastings(unnorm_distr,
next_sample,
x0,
num_samples=100,
rng=np.random.default_rng()):
"""Randomly sample according to an unnormalized distribution.
Given a function which is proportional to a probabilistic density and a strategy for generating candidate points, returns a set of samples which will asymptotically tend to being a sample a random sample of the unknown distribution.
Expand All @@ -16,6 +19,8 @@ def metropolis_hastings(unnorm_distr, next_sample, x0 , num_samples=100):
Initial sample
num_samples : int
Number of samples in output (this will be *more* than the total number of considered samples or evaluations of unnorm_distr)
rng : numpy rng, optional (default: new `np.random.default_rng()`)
which numpy random number generator to use
Returns
-------
Expand Down Expand Up @@ -59,7 +64,7 @@ def unnorm_distr(x):
f1 = unnorm_distr(x1)
#print(f1)
# Generate random value between 0 and 1
r = random.uniform(0, 1)
r = rng.uniform(0, 1)
#print(f1/f0)
if r<(f1/f0):
# If accepted, update current sample with candidate sample
Expand Down
8 changes: 4 additions & 4 deletions test/test_metropolis_hastings.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

class TestMetropolisHastings(unittest.TestCase):
def test_analytic_1d(self):
np.random.seed(0)
rng = np.random.default_rng(4352)
# 1D test
# Sample next point from a normal distribution
def next_sample(x0):
Expand All @@ -18,7 +18,7 @@ def next_sample(x0):
def unnorm_distr(x):
return np.max((1-np.abs(x[0]),1e-8))

S, F = gpytoolbox.metropolis_hastings(unnorm_distr,next_sample,np.array([0.1]),1000000)
S, F = gpytoolbox.metropolis_hastings(unnorm_distr,next_sample,np.array([0.1]),1000000,rng)
# This should look like an absolute value pyramid function
hist, bin_edges = np.histogram(S,bins=np.linspace(-1,1,101), density=True)
bin_centers = (bin_edges[0:100] + bin_edges[1:101])/2.
Expand All @@ -30,7 +30,7 @@ def unnorm_distr(x):
# plt.show(block=False)

def test_analytic_2d(self):
np.random.seed(0)
rng = np.random.default_rng(8312)
# plt.pause(10)
# plt.close(plot1)
# 2D test
Expand All @@ -42,7 +42,7 @@ def next_sample(x0):
def unnorm_distr(x):
return 100*multivariate_normal.pdf(x,mean=np.array([0.0,0.0]),cov=np.array([[0.01,0.0],[0.0,0.01]]))

S, F = gpytoolbox.metropolis_hastings(unnorm_distr,next_sample,np.array([0.01,0.01]),500000)
S, F = gpytoolbox.metropolis_hastings(unnorm_distr,next_sample,np.array([0.01,0.01]),500000,rng)

nbins = 40
H, xedges, yedges = np.histogram2d(S[:,0], S[:,1], density=True, bins=nbins)
Expand Down

0 comments on commit 94ab93a

Please sign in to comment.