Flatten add kernels dim into features
khurram-ghani committed Sep 20, 2024
1 parent 4e9c84b commit 569ad16
Showing 4 changed files with 33 additions and 16 deletions.
23 changes: 18 additions & 5 deletions gpflux/layers/basis_functions/fourier_features/base.py
@@ -45,14 +45,17 @@ def __init__(self, kernel: gpflow.kernels.Kernel, n_components: int, **kwargs: M
         self.n_components = n_components
         if isinstance(kernel, gpflow.kernels.MultioutputKernel):
             self.is_batched = True
+            self.is_multioutput = True
             self.batch_size = kernel.num_latent_gps
             self.sub_kernels = kernel.latent_kernels
         elif isinstance(kernel, gpflow.kernels.Combination):
             self.is_batched = True
+            self.is_multioutput = False
             self.batch_size = len(kernel.kernels)
             self.sub_kernels = kernel.kernels
         else:
             self.is_batched = False
+            self.is_multioutput = False
             self.batch_size = 1
             self.sub_kernels = []
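To make the three branches concrete, here is a minimal usage sketch (not part of the diff; the kernel choices and feature count are illustrative, and the import path is assumed from gpflux's public API):

```python
import gpflow
from gpflux.layers.basis_functions.fourier_features import RandomFourierFeatures

# Plain kernel: is_batched=False, batch_size=1.
plain = gpflow.kernels.SquaredExponential()

# Combination kernel (here a sum of two): is_batched=True,
# is_multioutput=False, batch_size=2. After this commit its features are
# flattened into the feature axis rather than kept as a leading batch axis.
combo = gpflow.kernels.Sum(
    [gpflow.kernels.SquaredExponential(), gpflow.kernels.Matern52()]
)

# Multioutput kernel: is_batched=True, is_multioutput=True, and the number
# of latent GPs P stays as a leading batch dimension in the output.
multi = gpflow.kernels.SeparateIndependent(
    [gpflow.kernels.SquaredExponential(), gpflow.kernels.Matern52()]
)

features = RandomFourierFeatures(combo, n_components=128)
```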

@@ -68,7 +71,7 @@ def call(self, inputs: TensorType) -> tf.Tensor:
         :param inputs: The evaluation points, a tensor with the shape ``[N, D]``.
-        :return: A tensor with the shape ``[N, M]``, or shape ``[P, N, M]`` in the multioutput case.
+        :return: A tensor with the shape ``[N, M]``, or shape ``[P, N, M]`` in the batched case.
         """
         if self.is_batched:
             X = [tf.divide(inputs, k.lengthscales) for k in self.sub_kernels]
@@ -78,6 +81,13 @@ def call(self, inputs: TensorType) -> tf.Tensor:
         const = self._compute_constant()  # [] or [P, 1, 1]
         bases = self._compute_bases(X)  # [N, M] or [P, N, M]
         output = const * bases

+        # For combination kernels, remove batch dimension and instead concatenate into the
+        # feature dimension.
+        if self.is_batched and not self.is_multioutput:
+            output = tf.transpose(output, perm=[1, 2, 0])  # [N, M, P]
+            output = tf.reshape(output, [output.shape[0], -1])  # [N, M*P]
+
         tf.ensure_shape(output, self.compute_output_shape(inputs.shape))
         return output
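The transpose-and-reshape step above interleaves sub-kernel features along the last axis. A toy demonstration (not from the commit) of the shape bookkeeping:

```python
import tensorflow as tf

P, N, M = 2, 3, 4  # sub-kernels, evaluation points, features per sub-kernel
output = tf.reshape(tf.range(P * N * M, dtype=tf.float32), [P, N, M])

flat = tf.transpose(output, perm=[1, 2, 0])   # [N, M, P]
flat = tf.reshape(flat, [flat.shape[0], -1])  # [N, M*P]

assert flat.shape == (N, M * P)
# Feature m of sub-kernel p lands at column m * P + p, i.e. the sub-kernel
# features are interleaved rather than concatenated block-wise.
assert flat[1, 2 * P + 1] == output[1, 1, 2]
```

Either layout is equivalent for the inner products computed downstream; interleaving is simply what transpose-then-reshape produces.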

@@ -90,12 +100,12 @@ def compute_output_shape(self, input_shape: ShapeType) -> tf.TensorShape:
         # TODO: Keras docs say "If the layer has not been built, this method
         # will call `build` on the layer." -- do we need to do so?
         tensor_shape = tf.TensorShape(input_shape).with_rank(2)
-        output_dim = self._compute_output_dim(input_shape)
+        output_dim = self.compute_output_dim(input_shape)
         trailing_shape = tensor_shape[:-1].concatenate(output_dim)
-        if self.is_batched:
+        if self.is_multioutput:
             return tf.TensorShape([self.batch_size]).concatenate(trailing_shape)  # [P, N, M]
         else:
-            return trailing_shape  # [N, M]
+            return trailing_shape  # [N, M] or [N, M*P]

     def get_config(self) -> Mapping:
         """
@@ -115,7 +125,10 @@ def get_config(self) -> Mapping:
         return config

     @abstractmethod
-    def _compute_output_dim(self, input_shape: ShapeType) -> int:
+    def compute_output_dim(self, input_shape: ShapeType) -> int:
+        """
+        Compute the output dimension of the layer.
+        """
         pass

     @abstractmethod
@@ -71,7 +71,7 @@ def build(self, input_shape: ShapeType) -> None:
         self.factors = tf.Variable(initial_value=omegas_value, trainable=False)  # (M^D,)
         super(QuadratureFourierFeatures, self).build(input_shape)

-    def _compute_output_dim(self, input_shape: ShapeType) -> int:
+    def compute_output_dim(self, input_shape: ShapeType) -> int:
         input_dim = input_shape[-1]
         return 2 * self.n_components ** input_dim
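Aside from the rename, this method is unchanged: there is no combination-kernel branch here, and quadrature features still grow exponentially with the input dimension. A quick arithmetic check of the formula, with illustrative values:

```python
n_components, input_dim = 16, 2
# One cosine and one sine feature per node of the M^D quadrature grid,
# hence the factor of 2.
assert 2 * n_components ** input_dim == 512
```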

18 changes: 14 additions & 4 deletions gpflux/layers/basis_functions/fourier_features/random/base.py
@@ -199,8 +199,13 @@ class RandomFourierFeatures(RandomFourierFeaturesBase):
     from phase-shifted cosines :class:`RandomFourierFeaturesCosine` :cite:p:`sutherland2015error`.
     """

-    def _compute_output_dim(self, input_shape: ShapeType) -> int:
-        return 2 * self.n_components
+    def compute_output_dim(self, input_shape: ShapeType) -> int:
+        # For combination kernels, the number of features is multiplied by the number of
+        # sub-kernels.
+        dim = 2 * self.n_components
+        if self.is_batched and not self.is_multioutput:
+            dim *= self.batch_size
+        return dim

     def _compute_bases(self, inputs: TensorType) -> tf.Tensor:
         """
@@ -281,8 +286,13 @@ def _bias_build(self, n_components: int) -> None:
     def _bias_init(self, shape: TensorType, dtype: Optional[DType] = None) -> TensorType:
         return tf.random.uniform(shape=shape, maxval=2.0 * np.pi, dtype=dtype)

-    def _compute_output_dim(self, input_shape: ShapeType) -> int:
-        return self.n_components
+    def compute_output_dim(self, input_shape: ShapeType) -> int:
+        # For combination kernels, the number of features is multiplied by the number of
+        # sub-kernels.
+        dim = self.n_components
+        if self.is_batched and not self.is_multioutput:
+            dim *= self.batch_size
+        return dim

     def _compute_bases(self, inputs: TensorType) -> tf.Tensor:
         """
@@ -157,9 +157,6 @@ def test_multi_random_fourier_features_can_approximate_kernel_multidim(
     v = fourier_features(y)
     approx_kernel_matrix = u @ tf.linalg.matrix_transpose(v)

-    if isinstance(multi_kernel, gpflow.kernels.Sum):
-        approx_kernel_matrix = tf.reduce_sum(approx_kernel_matrix, axis=0)
-
     if isinstance(multi_kernel, gpflow.kernels.MultioutputKernel):
         actual_kernel_matrix = multi_kernel.K(x, y, full_output_cov=False)
     else:
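The explicit reduce_sum over the batch axis removed above is no longer needed because flattening moves the sum inside the inner product: phi(x)^T phi(y) = sum_p phi_p(x)^T phi_p(y). A standalone check of that equivalence (not from the test suite; shapes are illustrative):

```python
import tensorflow as tf

# Features for P = 2 sub-kernels, stacked per-kernel as [P, N, M].
phi = tf.random.normal([2, 5, 8])

# Old layout: per-kernel Gram matrices, summed explicitly over P.
summed = tf.reduce_sum(phi @ tf.linalg.matrix_transpose(phi), axis=0)  # [N, N]

# New layout: flatten P into the feature axis, then a single inner product.
flat = tf.reshape(tf.transpose(phi, perm=[1, 2, 0]), [5, -1])  # [N, M*P]
direct = flat @ tf.linalg.matrix_transpose(flat)               # [N, N]

tf.debugging.assert_near(summed, direct)  # identical up to float error
```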
@@ -231,9 +228,6 @@ def test_multi_random_fourier_feature_layer_compute_covariance_of_inducing_variable(
     u = fourier_features(x_new)
     approx_kernel_matrix = u @ tf.linalg.matrix_transpose(u)

-    if isinstance(multi_kernel, gpflow.kernels.Sum):
-        approx_kernel_matrix = tf.reduce_sum(approx_kernel_matrix, axis=0)
-
     if isinstance(multi_kernel, gpflow.kernels.MultioutputKernel):
         actual_kernel_matrix = multi_kernel.K(x_new, x_new, full_output_cov=False)
     else:
