Skip to content

Commit

Permalink
Handle unconstrained priors in randomize_hyperparameters (#796)
Browse files Browse the repository at this point in the history
  • Loading branch information
uri-granta authored Nov 27, 2023
1 parent ae11f17 commit 45fbc37
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 13 deletions.
27 changes: 25 additions & 2 deletions tests/unit/models/gpflow/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,29 @@ def test_randomize_hyperparameters_randomizes_kernel_parameters_with_priors(
assert len(np.unique(kernel.lengthscales)) == dim


@random_seed
@pytest.mark.parametrize("compile", [False, True])
def test_randomize_hyperparameters_randomizes_kernel_parameters_with_unconstrained_priors(
dim: int, compile: bool
) -> None:
kernel = gpflow.kernels.RBF(variance=1.0, lengthscales=[0.2] * dim)
kernel.lengthscales = gpflow.Parameter(kernel.lengthscales, transform=tfp.bijectors.Exp())
kernel.lengthscales.prior = tfp.distributions.Uniform(
tf.math.log(tf.constant(0.01, dtype=tf.float64)),
tf.math.log(tf.constant(10.0, dtype=tf.float64)),
)
kernel.lengthscales.prior_on = gpflow.base.PriorOn.UNCONSTRAINED

compiler = tf.function if compile else lambda x: x
compiler(randomize_hyperparameters)(kernel)

npt.assert_allclose(1.0, kernel.variance)
npt.assert_array_equal(dim, kernel.lengthscales.shape)
npt.assert_array_less(kernel.lengthscales, [10.0] * dim)
npt.assert_raises(AssertionError, npt.assert_allclose, [0.2] * dim, kernel.lengthscales)
assert len(np.unique(kernel.lengthscales)) == dim


@random_seed
@pytest.mark.parametrize("compile", [False, True])
def test_randomize_hyperparameters_randomizes_kernel_parameters_with_const_priors(
Expand Down Expand Up @@ -207,13 +230,13 @@ def test_randomize_hyperparameters_samples_from_constraints_when_given_prior_and
kernel.lengthscales = gpflow.Parameter(
kernel.lengthscales, transform=tfp.bijectors.Sigmoid(low=lower, high=upper)
)
kernel.lengthscales.prior = tfp.distributions.Uniform(low=10.0, high=100.0)
kernel.lengthscales.prior = tfp.distributions.Uniform(low=lower, high=upper / 2)

kernel.variance.prior = tfp.distributions.LogNormal(loc=np.float64(-2.0), scale=np.float64(1.0))

randomize_hyperparameters(kernel)

npt.assert_array_less(kernel.lengthscales, [0.5] * dim)
npt.assert_array_less(kernel.lengthscales, [0.25] * dim)
npt.assert_raises(AssertionError, npt.assert_allclose, [0.2] * dim, kernel.lengthscales)


Expand Down
27 changes: 16 additions & 11 deletions trieste/models/gpflow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,21 +52,14 @@ def assert_data_is_compatible(new_data: Dataset, existing_data: Dataset) -> None

def randomize_hyperparameters(object: gpflow.Module) -> None:
"""
Sets hyperparameters to random samples from their constrained domains or (if not constraints
are available) their prior distributions.
Sets hyperparameters to random samples from their prior distributions or (for Sigmoid
constraints with no priors) their constrained domains. Note that it is up to the caller
to ensure that the prior, if defined, is compatible with the transform.
:param object: Any gpflow Module.
"""
for param in object.trainable_parameters:
if isinstance(param.bijector, tfp.bijectors.Sigmoid):
sample = tf.random.uniform(
param.bijector.low.shape,
minval=param.bijector.low,
maxval=param.bijector.high,
dtype=param.bijector.low.dtype,
)
param.assign(sample)
elif param.prior is not None:
if param.prior is not None:
# handle constant priors for multi-dimensional parameters
# Use python conditionals here to avoid creating tensorflow `tf.cond` ops,
# i.e. using `len(param.shape)` instead of `tf.rank(param)`.
Expand All @@ -76,6 +69,18 @@ def randomize_hyperparameters(object: gpflow.Module) -> None:
sample = param.prior.sample(tf.shape(param))
else:
sample = param.prior.sample()
if param.prior_on is gpflow.base.PriorOn.UNCONSTRAINED:
param.unconstrained_variable.assign(sample)
else:
param.assign(sample)

elif isinstance(param.bijector, tfp.bijectors.Sigmoid):
sample = tf.random.uniform(
param.bijector.low.shape,
minval=param.bijector.low,
maxval=param.bijector.high,
dtype=param.bijector.low.dtype,
)
param.assign(sample)


Expand Down

0 comments on commit 45fbc37

Please sign in to comment.