From f8a78f0c3f720351e6b7d9c00a50499a15c796b6 Mon Sep 17 00:00:00 2001
From: Uri Granta
Date: Tue, 15 Oct 2024 10:25:12 +0100
Subject: [PATCH] Start refactoring to handle active_dims in subkernels

Give each (sub-)kernel its own weight matrix and bias, built from the
kernel's active_dims slice, instead of stacking per-kernel weights into a
single tensor that assumes every sub-kernel sees the full input dimension.
---
 .../basis_functions/fourier_features/base.py  | 21 +++--
 .../fourier_features/quadrature/gaussian.py   |  4 +-
 .../fourier_features/random/base.py           | 93 +++++++++++--------
 .../fourier_features/random/orthogonal.py     | 25 +++--
 4 files changed, 87 insertions(+), 56 deletions(-)

diff --git a/gpflux/layers/basis_functions/fourier_features/base.py b/gpflux/layers/basis_functions/fourier_features/base.py
index cacd9ee9..ffa343a3 100644
--- a/gpflux/layers/basis_functions/fourier_features/base.py
+++ b/gpflux/layers/basis_functions/fourier_features/base.py
@@ -16,7 +16,8 @@
 """ Shared functionality for stationary kernel basis functions. """
 
 from abc import ABC, abstractmethod
-from typing import Mapping
+from itertools import cycle
+from typing import Mapping, Optional
 
 import tensorflow as tf
 
@@ -73,13 +74,19 @@ def call(self, inputs: TensorType) -> tf.Tensor:
         :return: A tensor with the shape ``[N, M]``, or shape ``[P, N, M]``
             in the multioutput case.
         """
+        const = self._compute_constant()  # [] or [P, 1, 1]
         if self.is_batched:
-            X = [tf.divide(inputs, k.lengthscales) for k in self.sub_kernels]
-            X = tf.stack(X, 0)  # [1, N, D] or [P, N, D]
+            # TODO: handle nested active_dims
+            bases = [
+                self._compute_bases(tf.divide(k.slice(inputs, None)[0], k.lengthscales), i)
+                # SharedIndependent repeatedly uses the same sub_kernel
+                for i, k in zip(range(self.batch_size), cycle(self.sub_kernels))
+            ]
+            bases = tf.stack(bases, axis=0)  # [P, N, M]
         else:
-            X = tf.divide(inputs, self.kernel.lengthscales)  # [N, D]
-        const = self._compute_constant()  # [] or [P, 1, 1]
-        bases = self._compute_bases(X)  # [N, M] or [P, N, M]
+            X = tf.divide(self.kernel.slice(inputs, None)[0], self.kernel.lengthscales)  # [N, D]
+            bases = self._compute_bases(X, None)  # [N, M]
+
         output = const * bases
 
         if self.is_batched and not self.is_multioutput:
@@ -139,7 +146,7 @@ def _compute_constant(self) -> tf.Tensor:
         pass
 
     @abstractmethod
-    def _compute_bases(self, inputs: TensorType) -> tf.Tensor:
+    def _compute_bases(self, inputs: TensorType, batch: Optional[int]) -> tf.Tensor:
         """
         Compute basis functions.
         """
diff --git a/gpflux/layers/basis_functions/fourier_features/quadrature/gaussian.py b/gpflux/layers/basis_functions/fourier_features/quadrature/gaussian.py
index f24c543f..b8e06d63 100644
--- a/gpflux/layers/basis_functions/fourier_features/quadrature/gaussian.py
+++ b/gpflux/layers/basis_functions/fourier_features/quadrature/gaussian.py
@@ -19,7 +19,7 @@
 """
 
 import warnings
-from typing import Mapping, Tuple, Type
+from typing import Mapping, Optional, Tuple, Type
 
 import tensorflow as tf
 
@@ -75,7 +75,7 @@ def compute_output_dim(self, input_shape: ShapeType) -> int:
         input_dim = input_shape[-1]
         return 2 * self.n_components ** input_dim
 
-    def _compute_bases(self, inputs: TensorType) -> tf.Tensor:
+    def _compute_bases(self, inputs: TensorType, batch: Optional[int]) -> tf.Tensor:
         """
         Compute basis functions.
 
diff --git a/gpflux/layers/basis_functions/fourier_features/random/base.py b/gpflux/layers/basis_functions/fourier_features/random/base.py
index 0751d10e..49b33e9e 100644
--- a/gpflux/layers/basis_functions/fourier_features/random/base.py
+++ b/gpflux/layers/basis_functions/fourier_features/random/base.py
@@ -13,13 +13,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import Mapping, Optional, Tuple, Type
+from itertools import cycle
+from typing import Callable, Mapping, Optional, Tuple, Type
 
 import numpy as np
 import tensorflow as tf
 
 import gpflow
 from gpflow.base import DType, TensorType
+from gpflow.kernels import Kernel
 
 from gpflux.layers.basis_functions.fourier_features.base import FourierFeaturesBase
 from gpflux.layers.basis_functions.fourier_features.utils import (
@@ -116,18 +118,32 @@ def build(self, input_shape: ShapeType) -> None:
         self._weights_build(input_dim, n_components=self.n_components)
         super(RandomFourierFeaturesBase, self).build(input_shape)
 
+    def _active_input_dim(self, input_dim: int, kernel: Kernel) -> int:
+        dummy_X = tf.zeros((0, input_dim), dtype=tf.float64)
+        return kernel.slice(dummy_X, None)[0].shape[-1]
+
     def _weights_build(self, input_dim: int, n_components: int) -> None:
         if self.is_batched:
-            shape = (self.batch_size, n_components, input_dim)  # [P, M, D]
+            # TODO: handle nested active_dims
+            self.W = [
+                self.add_weight(
+                    name="weights",
+                    trainable=False,
+                    shape=(n_components, self._active_input_dim(input_dim, k)),
+                    dtype=self.dtype,
+                    initializer=self._weights_init(k),
+                )
+                # SharedIndependent repeatedly uses the same sub_kernel
+                for _, k in zip(range(self.batch_size), cycle(self.sub_kernels))
+            ]
         else:
-            shape = (n_components, input_dim)  # type: ignore
-        self.W = self.add_weight(
-            name="weights",
-            trainable=False,
-            shape=shape,
-            dtype=self.dtype,
-            initializer=self._weights_init,
-        )
+            self.W = self.add_weight(
+                name="weights",
+                trainable=False,
+                shape=(n_components, self._active_input_dim(input_dim, self.kernel)),
+                dtype=self.dtype,
+                initializer=self._weights_init(self.kernel),
+            )
 
     def _weights_init_individual(
         self,
@@ -142,20 +158,11 @@ def _weights_init_individual(
         nu = 2.0 * p + 1.0  # degrees of freedom
         return _sample_students_t(nu, shape, dtype)
 
-    def _weights_init(self, shape: TensorType, dtype: Optional[DType] = None) -> TensorType:
-        if self.is_batched:
-            if isinstance(self.kernel, gpflow.kernels.SharedIndependent):
-                weights_list = [
-                    self._weights_init_individual(self.sub_kernels[0], shape[1:], dtype)
-                    for _ in range(self.batch_size)
-                ]
-            else:
-                weights_list = [
-                    self._weights_init_individual(k, shape[1:], dtype) for k in self.sub_kernels
-                ]
-            return tf.stack(weights_list, 0)  # [P, M, D]
-        else:
-            return self._weights_init_individual(self.kernel, shape, dtype)  # [M, D]
+    def _weights_init(self, kernel: Kernel) -> Callable[[TensorType, Optional[DType]], TensorType]:
+        def _initializer(shape: TensorType, dtype: Optional[DType] = None) -> TensorType:
+            return self._weights_init_individual(kernel, shape, dtype)  # [M, D]
+
+        return _initializer
 
     @staticmethod
     def rff_constant(variance: TensorType, output_dim: int) -> tf.Tensor:
@@ -207,13 +214,13 @@ def compute_output_dim(self, input_shape: ShapeType) -> int:
             dim *= self.batch_size
         return dim
 
-    def _compute_bases(self, inputs: TensorType) -> tf.Tensor:
+    def _compute_bases(self, inputs: TensorType, batch: Optional[int]) -> tf.Tensor:
         """
         Compute basis functions.
 
         :return: A tensor with the shape ``[N, 2M]`` or ``[P, N, 2M]``.
""" - return _bases_concat(inputs, self.W) + return _bases_concat(inputs, self.W if batch is None else self.W[batch]) def _compute_constant(self) -> tf.Tensor: """ @@ -272,16 +279,24 @@ def build(self, input_shape: ShapeType) -> None: def _bias_build(self, n_components: int) -> None: if self.is_batched: - shape = (self.batch_size, 1, n_components) + self.b = [ + self.add_weight( + name="bias", + trainable=False, + shape=(1, n_components), + dtype=self.dtype, + initializer=self._bias_init, + ) + for _ in range(self.batch_size) + ] else: - shape = (1, n_components) # type: ignore - self.b = self.add_weight( - name="bias", - trainable=False, - shape=shape, - dtype=self.dtype, - initializer=self._bias_init, - ) + self.b = self.add_weight( + name="bias", + trainable=False, + shape=(1, n_components), + dtype=self.dtype, + initializer=self._bias_init, + ) def _bias_init(self, shape: TensorType, dtype: Optional[DType] = None) -> TensorType: return tf.random.uniform(shape=shape, maxval=2.0 * np.pi, dtype=dtype) @@ -294,13 +309,17 @@ def compute_output_dim(self, input_shape: ShapeType) -> int: dim *= self.batch_size return dim - def _compute_bases(self, inputs: TensorType) -> tf.Tensor: + def _compute_bases(self, inputs: TensorType, batch: Optional[int]) -> tf.Tensor: """ Compute basis functions. :return: A tensor with the shape ``[N, M]`` or ``[P, N, M]``. """ - return _bases_cosine(inputs, self.W, self.b) + return _bases_cosine( + inputs, + self.W if batch is None else self.W[batch], + self.b if batch is None else self.b[batch], + ) def _compute_constant(self) -> tf.Tensor: """ diff --git a/gpflux/layers/basis_functions/fourier_features/random/orthogonal.py b/gpflux/layers/basis_functions/fourier_features/random/orthogonal.py index 395da743..def312c3 100644 --- a/gpflux/layers/basis_functions/fourier_features/random/orthogonal.py +++ b/gpflux/layers/basis_functions/fourier_features/random/orthogonal.py @@ -14,13 +14,14 @@ # limitations under the License. # -from typing import Mapping, Optional, Tuple, Type +from typing import Callable, Mapping, Optional, Tuple, Type import numpy as np import tensorflow as tf import gpflow from gpflow.base import DType, TensorType +from gpflow.kernels import Kernel from gpflux.layers.basis_functions.fourier_features.random.base import RandomFourierFeatures from gpflux.types import ShapeType @@ -73,15 +74,19 @@ def __init__(self, kernel: gpflow.kernels.Kernel, n_components: int, **kwargs: M assert isinstance(kernel, ORF_SUPPORTED_KERNELS), "Unsupported Kernel" super(OrthogonalRandomFeatures, self).__init__(kernel, n_components, **kwargs) - def _weights_init(self, shape: TensorType, dtype: Optional[DType] = None) -> TensorType: - n_components, input_dim = shape # M, D - n_reps = _ceil_divide(n_components, input_dim) # K, smallest integer s.t. K*D >= M + def _weights_init(self, kernel: Kernel) -> Callable[[TensorType, Optional[DType]], TensorType]: + def _initializer(shape: TensorType, dtype: Optional[DType] = None) -> TensorType: - W = tf.random.normal(shape=(n_reps, input_dim, input_dim), dtype=dtype) - Q, _ = tf.linalg.qr(W) # throw away R; shape [K, D, D] + n_components, input_dim = shape # M, D + n_reps = _ceil_divide(n_components, input_dim) # K, smallest integer s.t. 
 
-        s = _sample_chi(nu=input_dim, shape=(n_reps, input_dim), dtype=dtype)  # shape [K, D]
-        U = tf.expand_dims(s, axis=-1) * Q  # equiv: S @ Q where S = diag(s); shape [K, D, D]
-        V = tf.reshape(U, shape=(-1, input_dim))  # shape [K*D, D]
+            W = tf.random.normal(shape=(n_reps, input_dim, input_dim), dtype=dtype)
+            Q, _ = tf.linalg.qr(W)  # throw away R; shape [K, D, D]
 
-        return V[: self.n_components]  # shape [M, D] (throw away K*D - M rows)
+            s = _sample_chi(nu=input_dim, shape=(n_reps, input_dim), dtype=dtype)  # shape [K, D]
+            U = tf.expand_dims(s, axis=-1) * Q  # equiv: S @ Q where S = diag(s); shape [K, D, D]
+            V = tf.reshape(U, shape=(-1, input_dim))  # shape [K*D, D]
+
+            return V[: self.n_components]  # shape [M, D] (throw away K*D - M rows)
+
+        return _initializer
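
Note: what follows is a minimal, illustrative sketch (not part of the patch)
of why per-kernel weight matrices are needed once sub-kernels carry their own
active_dims. It uses only GPflow's public kernel API (slice,
SquaredExponential, SeparateIndependent); the variable names are made up for
the example.

import numpy as np
import gpflow

# Two sub-kernels that look at different input columns.
k1 = gpflow.kernels.SquaredExponential(active_dims=[0, 1])  # slices to D1 = 2
k2 = gpflow.kernels.SquaredExponential(active_dims=[2])     # slices to D2 = 1
kernel = gpflow.kernels.SeparateIndependent([k1, k2])

X = np.random.randn(5, 3)  # N = 5 points with D = 3 input columns
X1, _ = k1.slice(X, None)  # shape [5, 2]: columns 0 and 1 only
X2, _ = k2.slice(X, None)  # shape [5, 1]: column 2 only

# The per-kernel spectral weights must be [M, 2] for k1 but [M, 1] for k2, so
# they can no longer be stacked into a single [P, M, D] tensor. This is why
# _weights_build now builds a list of weights (one per sub-kernel) and
# _compute_bases takes a batch index to pick out the matching entry.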