Skip to content

Commit

Permalink
Merge pull request #1104 from SheffieldML/1098-cant-instantiate-halft…
Browse files Browse the repository at this point in the history
…-prior

1098 cant instantiate halft prior
  • Loading branch information
MartinBubel authored Oct 28, 2024
2 parents 1fcb408 + d7da7c6 commit 9a31886
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 2 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Changelog

## Unreleased
* update prior `__new__` methods #1098 [MartinBubel]

* fix invalid escape sequence #1011 [janmayer]

## v1.13.2 (2024-07-21)
Expand Down
13 changes: 11 additions & 2 deletions GPy/core/parameterization/priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,11 @@ class DGPLVM(Prior):
domain = _REAL

def __new__(cls, sigma2, lbl, x_shape):
return super(Prior, cls).__new__(cls, sigma2, lbl, x_shape)
newfunc = super(Prior, cls).__new__
if newfunc is object.__new__:
return newfunc(cls)
else:
return newfunc(cls, sigma2, lbl, x_shape)

def __init__(self, sigma2, lbl, x_shape):
self.sigma2 = sigma2
Expand Down Expand Up @@ -1275,7 +1279,12 @@ def __new__(cls, A, nu): # Singleton:
for instance in cls._instances:
if instance().A == A and instance().nu == nu:
return instance()
o = super(Prior, cls).__new__(cls, A, nu)

newfunc = super(Prior, cls).__new__
if newfunc is object.__new__:
o = newfunc(cls)
else:
o = newfunc(cls, A, nu)
cls._instances.append(weakref.ref(o))
return cls._instances[-1]()

Expand Down
102 changes: 102 additions & 0 deletions GPy/testing/test_prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,21 @@
import pytest
import numpy as np
import GPy
from GPy.core.parameterization.priors import (
Gaussian,
Uniform,
LogGaussian,
MultivariateGaussian,
Gamma,
InverseGamma,
DGPLVM,
DGPLVM_KFDA,
DGPLVM_Lamda,
DGPLVM_T,
HalfT,
Exponential,
StudentT,
)


class TestPrior:
Expand Down Expand Up @@ -178,3 +193,90 @@ def test_fixed_domain_check1(self):
# should raise an assertionerror.
with pytest.raises(AssertionError):
m.rbf.set_prior(gaussian)


def initialize_gaussian_prior() -> None:
return Gaussian(0, 1)


def initialize_uniform_prior() -> None:
return Uniform(0, 1)


def initialize_log_gaussian_prior() -> None:
return LogGaussian(0, 1)


def initialize_multivariate_gaussian_prior() -> None:
return MultivariateGaussian(np.zeros(2), np.eye(2))


def initialize_gamma_prior() -> None:
return Gamma(1, 1)


def initialize_inverse_gamma_prior() -> None:
return InverseGamma(1, 1)


def initialize_dgplvm_prior() -> None:
# return DGPLVM(...)
raise NotImplementedError("No idea how to initialize this prior")


def initialize_dgplvm_kfda_prior() -> None:
# return DGPLVM_KFDA(...)
raise NotImplementedError("No idea how to initialize this prior")


def initialize_dgplvm_lamda_prior() -> None:
# return DGPLVM_Lamda(...)
raise NotImplementedError("No idea how to initialize this prior")


def initialize_dgplvm_t_prior() -> None:
# return DGPLVM_T(1, 1, (1, 1))
raise NotImplementedError("No idea how to initialize this prior")


def initialize_half_t_prior() -> None:
return HalfT(1, 1)


def initialize_exponential_prior() -> None:
return Exponential(1)


def initialize_student_t_prior() -> None:
return StudentT(1, 1, 1)


PRIORS = {
"Gaussian": initialize_gaussian_prior,
"Uniform": initialize_uniform_prior,
"LogGaussian": initialize_log_gaussian_prior,
"MultivariateGaussian": initialize_multivariate_gaussian_prior,
"Gamma": initialize_gamma_prior,
"InverseGamma": initialize_inverse_gamma_prior,
# "DGPLVM": initialize_dgplvm_prior,
# "DGPLVM_KFDA": initialize_dgplvm_kfda_prior,
# "DGPLVM_Lamda": initialize_dgplvm_lamda_prior,
# "DGPLVM_T": initialize_dgplvm_t_prior,
"HalfT": initialize_half_t_prior,
"Exponential": initialize_exponential_prior,
"StudentT": initialize_student_t_prior,
}


def check_prior(prior_getter: str) -> None:
prior_getter()


def test_priors() -> None:
for prior_name, prior_getter in PRIORS.items():
try:
check_prior(prior_getter)
except Exception as e:
raise RuntimeError(
f"Failed to initialize {prior_name} prior"
) from e # noqa E501

0 comments on commit 9a31886

Please sign in to comment.