Skip to content

Commit

Permalink
Fix another resample method
Browse files Browse the repository at this point in the history
  • Loading branch information
Uri Granta committed Oct 15, 2024
1 parent 485b416 commit 7ddfb00
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions trieste/models/gpflux/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from __future__ import annotations

from abc import ABC
from itertools import cycle
from typing import Callable, cast

import gpflow.kernels
Expand Down Expand Up @@ -439,19 +440,24 @@ def __init__(self, layer: GPLayer, n_components: int):
dummy_X = inducing_points[0:1, :]

self.__call__(dummy_X)
self.b: TensorType = tf.Variable(self.b)
self.W: TensorType = tf.Variable(self.W)

def resample(self) -> None:
"""
Resample weights and biases.
"""
if not hasattr(self, "_bias_init"):
self.b.assign(self._sample_bias(tf.shape(self.b), dtype=self._dtype))
self.W.assign(self._sample_weights(tf.shape(self.W), dtype=self._dtype))
else:
if isinstance(self.b, tf.Variable):
self.b.assign(self._bias_init(tf.shape(self.b), dtype=self._dtype))
self.W.assign(self._weights_init(tf.shape(self.W), dtype=self._dtype))
else:
tf.debugging.Assert(isinstance(self.b, list), [])
for b in self.b:
b.assign(self._bias_init(tf.shape(b), dtype=self._dtype))

if isinstance(self.W, tf.Variable):
self.W.assign(self._weights_init(self.kernel)(tf.shape(self.W), dtype=self._dtype))
else:
tf.debugging.Assert(isinstance(self.W, list), [])
for W, k in zip(self.W, cycle(self.sub_kernels)):
W.assign(self._weights_init(k)(tf.shape(W), dtype=self._dtype))

def __call__(self, x: TensorType) -> TensorType: # [N, D] -> [N, L + M] or [P, N, L + M]
"""
Expand Down

0 comments on commit 7ddfb00

Please sign in to comment.