From dd32c025098703b74c78f7a21d5ac8d3c75323df Mon Sep 17 00:00:00 2001 From: Erik Bodin Date: Mon, 23 Jan 2023 12:55:36 +0000 Subject: [PATCH] Added symmetry-breaking init in rbf --- jaxkern/stationary/rbf.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/jaxkern/stationary/rbf.py b/jaxkern/stationary/rbf.py index 70f8acd..497c3b0 100644 --- a/jaxkern/stationary/rbf.py +++ b/jaxkern/stationary/rbf.py @@ -64,8 +64,10 @@ def __call__( return K.squeeze() def init_params(self, key: KeyArray) -> Dict: + eps = 1e-3 + random_jitter = jax.random.uniform(key, shape=[self.ndims], minval=-eps, maxval=eps) params = { - "lengthscale": jnp.array([1.0] * self.ndims), + "lengthscale": jnp.array([1.0] * self.ndims) + random_jitter, "variance": jnp.array([1.0]), } return jax.tree_util.tree_map(lambda x: jnp.atleast_1d(x), params)