Start refactoring to handle active_dims in subkernels
Uri Granta committed Oct 15, 2024
1 parent 95ed3a8 commit f8a78f0
Showing 4 changed files with 87 additions and 56 deletions.
21 changes: 14 additions & 7 deletions gpflux/layers/basis_functions/fourier_features/base.py
@@ -16,7 +16,8 @@
 """ Shared functionality for stationary kernel basis functions. """
 
 from abc import ABC, abstractmethod
-from typing import Mapping
+from itertools import cycle
+from typing import Mapping, Optional
 
 import tensorflow as tf
@@ -73,13 +74,19 @@ def call(self, inputs: TensorType) -> tf.Tensor:
         :return: A tensor with the shape ``[N, M]``, or shape ``[P, N, M]`` in the multioutput case.
         """
+        const = self._compute_constant()  # [] or [P, 1, 1]
         if self.is_batched:
-            X = [tf.divide(inputs, k.lengthscales) for k in self.sub_kernels]
-            X = tf.stack(X, 0)  # [1, N, D] or [P, N, D]
+            # TODO: handle nested active dims
+            bases = [
+                self._compute_bases(tf.divide(k.slice(inputs, None)[0], k.lengthscales), i)
+                # SharedIndependent repeatedly uses the same sub_kernel
+                for i, k in zip(range(self.batch_size), cycle(self.sub_kernels))
+            ]
+            bases = tf.stack(bases, axis=0)  # [P, N, M]
         else:
-            X = tf.divide(inputs, self.kernel.lengthscales)  # [N, D]
-        const = self._compute_constant()  # [] or [P, 1, 1]
-        bases = self._compute_bases(X)  # [N, M] or [P, N, M]
+            X = tf.divide(self.kernel.slice(inputs, None)[0], self.kernel.lengthscales)  # [N, D]
+            bases = self._compute_bases(X, None)  # [N, M]
 
         output = const * bases
 
         if self.is_batched and not self.is_multioutput:
@@ -139,7 +146,7 @@ def _compute_constant(self) -> tf.Tensor:
         pass
 
     @abstractmethod
-    def _compute_bases(self, inputs: TensorType) -> tf.Tensor:
+    def _compute_bases(self, inputs: TensorType, batch: Optional[int]) -> tf.Tensor:
         """
         Compute basis functions.
         """
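The refactored `call` above relies on gpflow's `Kernel.slice` to restrict the inputs to each sub-kernel's `active_dims` before dividing by that sub-kernel's lengthscales. A minimal sketch of what that slicing does, with an illustrative kernel and data that are not part of the commit:

import gpflow
import tensorflow as tf

# A kernel that only looks at input columns 0 and 2.
kernel = gpflow.kernels.SquaredExponential(active_dims=[0, 2])
X = tf.constant([[1.0, 10.0, 3.0], [4.0, 20.0, 6.0]], dtype=tf.float64)

# Kernel.slice returns the inputs restricted to active_dims.
X_sliced, _ = kernel.slice(X, None)  # shape [2, 2]: columns 0 and 2 only
X_scaled = tf.divide(X_sliced, kernel.lengthscales)  # as in call() above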
4 changes: 2 additions & 2 deletions
@@ -19,7 +19,7 @@
 """
 
 import warnings
-from typing import Mapping, Tuple, Type
+from typing import Mapping, Optional, Tuple, Type
 
 import tensorflow as tf
 
@@ -75,7 +75,7 @@ def compute_output_dim(self, input_shape: ShapeType) -> int:
         input_dim = input_shape[-1]
         return 2 * self.n_components ** input_dim
 
-    def _compute_bases(self, inputs: TensorType) -> tf.Tensor:
+    def _compute_bases(self, inputs: TensorType, batch: Optional[int]) -> tf.Tensor:
         """
         Compute basis functions.
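For intuition, a quick check of `compute_output_dim` above with illustrative numbers; the feature count grows exponentially in the input dimension:

n_components, input_dim = 10, 2
output_dim = 2 * n_components ** input_dim  # 2 * 10**2 = 200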
93 changes: 56 additions & 37 deletions gpflux/layers/basis_functions/fourier_features/random/base.py
@@ -13,13 +13,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import Mapping, Optional, Tuple, Type
+from itertools import cycle
+from typing import Callable, Mapping, Optional, Tuple, Type
 
 import numpy as np
 import tensorflow as tf
 
 import gpflow
 from gpflow.base import DType, TensorType
+from gpflow.kernels import Kernel
 
 from gpflux.layers.basis_functions.fourier_features.base import FourierFeaturesBase
 from gpflux.layers.basis_functions.fourier_features.utils import (
@@ -116,18 +118,32 @@ def build(self, input_shape: ShapeType) -> None:
         self._weights_build(input_dim, n_components=self.n_components)
         super(RandomFourierFeaturesBase, self).build(input_shape)
 
+    def _active_input_dim(self, input_dim: int, kernel: Kernel) -> int:
+        dummy_X = tf.zeros((0, input_dim), dtype=tf.float64)
+        return kernel.slice(dummy_X, None)[0].shape[-1]
+
     def _weights_build(self, input_dim: int, n_components: int) -> None:
         if self.is_batched:
-            shape = (self.batch_size, n_components, input_dim)  # [P, M, D]
+            # TODO: handle nested active_dims
+            self.W = [
+                self.add_weight(
+                    name="weights",
+                    trainable=False,
+                    shape=(n_components, self._active_input_dim(input_dim, k)),
+                    dtype=self.dtype,
+                    initializer=self._weights_init(k),
+                )
+                # SharedIndependent repeatedly uses the same sub_kernel
+                for _, k in zip(range(self.batch_size), cycle(self.sub_kernels))
+            ]
         else:
-            shape = (n_components, input_dim)  # type: ignore
-        self.W = self.add_weight(
-            name="weights",
-            trainable=False,
-            shape=shape,
-            dtype=self.dtype,
-            initializer=self._weights_init,
-        )
+            self.W = self.add_weight(
+                name="weights",
+                trainable=False,
+                shape=(n_components, self._active_input_dim(input_dim, self.kernel)),
+                dtype=self.dtype,
+                initializer=self._weights_init(self.kernel),
+            )
 
     def _weights_init_individual(
         self,
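The new `_active_input_dim` helper probes the kernel with an empty dummy input rather than inspecting `active_dims` directly, so it works uniformly whether active dims are given as a slice, an index list, or not at all. A standalone illustration of the same trick, with a made-up kernel and sizes:

import gpflow
import tensorflow as tf

kernel = gpflow.kernels.Matern52(active_dims=[1, 3])
dummy_X = tf.zeros((0, 5), dtype=tf.float64)  # zero rows, D = 5 columns

# Slicing the empty tensor is cheap and yields the active dimensionality.
active_dim = kernel.slice(dummy_X, None)[0].shape[-1]  # 2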
@@ -142,20 +158,11 @@ def _weights_init_individual(
         nu = 2.0 * p + 1.0  # degrees of freedom
         return _sample_students_t(nu, shape, dtype)
 
-    def _weights_init(self, shape: TensorType, dtype: Optional[DType] = None) -> TensorType:
-        if self.is_batched:
-            if isinstance(self.kernel, gpflow.kernels.SharedIndependent):
-                weights_list = [
-                    self._weights_init_individual(self.sub_kernels[0], shape[1:], dtype)
-                    for _ in range(self.batch_size)
-                ]
-            else:
-                weights_list = [
-                    self._weights_init_individual(k, shape[1:], dtype) for k in self.sub_kernels
-                ]
-            return tf.stack(weights_list, 0)  # [P, M, D]
-        else:
-            return self._weights_init_individual(self.kernel, shape, dtype)  # [M, D]
+    def _weights_init(self, kernel: Kernel) -> Callable[[TensorType, Optional[DType]], TensorType]:
+        def _initializer(shape: TensorType, dtype: Optional[DType] = None) -> TensorType:
+            return self._weights_init_individual(kernel, shape, dtype)  # [M, D]
+
+        return _initializer
 
     @staticmethod
     def rff_constant(variance: TensorType, output_dim: int) -> tf.Tensor:
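`_weights_init` is now a factory: it takes a kernel and returns a closure with the `(shape, dtype)` signature that Keras' `add_weight` expects, so each per-output weight matrix can be initialized from its own sub-kernel. A minimal sketch of the pattern with hypothetical names, not gpflux code:

from typing import Callable, Optional

import tensorflow as tf

def make_initializer(scale: float) -> Callable[..., tf.Tensor]:
    # The returned closure matches the Keras initializer signature (shape, dtype).
    def _initializer(shape, dtype: Optional[tf.DType] = None) -> tf.Tensor:
        return scale * tf.random.normal(shape, dtype=dtype or tf.float32)

    return _initializer

# Each weight then gets its own initializer instance, e.g.:
# layer.add_weight(name="w", shape=(4, 3), initializer=make_initializer(0.1))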
@@ -207,13 +214,13 @@ def compute_output_dim(self, input_shape: ShapeType) -> int:
             dim *= self.batch_size
         return dim
 
-    def _compute_bases(self, inputs: TensorType) -> tf.Tensor:
+    def _compute_bases(self, inputs: TensorType, batch: Optional[int]) -> tf.Tensor:
         """
         Compute basis functions.
 
         :return: A tensor with the shape ``[N, 2M]`` or ``[P, N, 2M]``.
         """
-        return _bases_concat(inputs, self.W)
+        return _bases_concat(inputs, self.W if batch is None else self.W[batch])
 
     def _compute_constant(self) -> tf.Tensor:
         """
@@ -272,16 +279,24 @@ def build(self, input_shape: ShapeType) -> None:
 
     def _bias_build(self, n_components: int) -> None:
         if self.is_batched:
-            shape = (self.batch_size, 1, n_components)
+            self.b = [
+                self.add_weight(
+                    name="bias",
+                    trainable=False,
+                    shape=(1, n_components),
+                    dtype=self.dtype,
+                    initializer=self._bias_init,
+                )
+                for _ in range(self.batch_size)
+            ]
         else:
-            shape = (1, n_components)  # type: ignore
-        self.b = self.add_weight(
-            name="bias",
-            trainable=False,
-            shape=shape,
-            dtype=self.dtype,
-            initializer=self._bias_init,
-        )
+            self.b = self.add_weight(
+                name="bias",
+                trainable=False,
+                shape=(1, n_components),
+                dtype=self.dtype,
+                initializer=self._bias_init,
+            )
 
     def _bias_init(self, shape: TensorType, dtype: Optional[DType] = None) -> TensorType:
         return tf.random.uniform(shape=shape, maxval=2.0 * np.pi, dtype=dtype)
@@ -294,13 +309,17 @@ def compute_output_dim(self, input_shape: ShapeType) -> int:
             dim *= self.batch_size
         return dim
 
-    def _compute_bases(self, inputs: TensorType) -> tf.Tensor:
+    def _compute_bases(self, inputs: TensorType, batch: Optional[int]) -> tf.Tensor:
         """
         Compute basis functions.
 
         :return: A tensor with the shape ``[N, M]`` or ``[P, N, M]``.
         """
-        return _bases_cosine(inputs, self.W, self.b)
+        return _bases_cosine(
+            inputs,
+            self.W if batch is None else self.W[batch],
+            self.b if batch is None else self.b[batch],
+        )
 
     def _compute_constant(self) -> tf.Tensor:
         """
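`_bases_concat` and `_bases_cosine` themselves are not part of this diff; for reference, here is a hedged sketch of what a concat-style random Fourier basis computes, following the shape comments above (`[N, 2M]`) and standard RFF theory, not gpflux's actual implementation:

import tensorflow as tf

def bases_concat_sketch(X: tf.Tensor, W: tf.Tensor) -> tf.Tensor:
    # Project inputs onto the random frequencies: [N, D] @ [M, D]^T -> [N, M].
    proj = tf.matmul(X, W, transpose_b=True)
    # Concatenate cosine and sine features -> [N, 2M].
    return tf.concat([tf.cos(proj), tf.sin(proj)], axis=-1)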
25 changes: 15 additions & 10 deletions gpflux/layers/basis_functions/fourier_features/random/orthogonal.py
@@ -14,13 +14,14 @@
 # limitations under the License.
 #
 
-from typing import Mapping, Optional, Tuple, Type
+from typing import Callable, Mapping, Optional, Tuple, Type
 
 import numpy as np
 import tensorflow as tf
 
 import gpflow
 from gpflow.base import DType, TensorType
+from gpflow.kernels import Kernel
 
 from gpflux.layers.basis_functions.fourier_features.random.base import RandomFourierFeatures
 from gpflux.types import ShapeType
@@ -73,15 +74,19 @@ def __init__(self, kernel: gpflow.kernels.Kernel, n_components: int, **kwargs: Mapping) -> None:
         assert isinstance(kernel, ORF_SUPPORTED_KERNELS), "Unsupported Kernel"
         super(OrthogonalRandomFeatures, self).__init__(kernel, n_components, **kwargs)
 
-    def _weights_init(self, shape: TensorType, dtype: Optional[DType] = None) -> TensorType:
-        n_components, input_dim = shape  # M, D
-        n_reps = _ceil_divide(n_components, input_dim)  # K, smallest integer s.t. K*D >= M
+    def _weights_init(self, kernel: Kernel) -> Callable[[TensorType, Optional[DType]], TensorType]:
+        def _initializer(shape: TensorType, dtype: Optional[DType] = None) -> TensorType:
+            n_components, input_dim = shape  # M, D
+            n_reps = _ceil_divide(n_components, input_dim)  # K, smallest integer s.t. K*D >= M
 
-        W = tf.random.normal(shape=(n_reps, input_dim, input_dim), dtype=dtype)
-        Q, _ = tf.linalg.qr(W)  # throw away R; shape [K, D, D]
+            W = tf.random.normal(shape=(n_reps, input_dim, input_dim), dtype=dtype)
+            Q, _ = tf.linalg.qr(W)  # throw away R; shape [K, D, D]
 
-        s = _sample_chi(nu=input_dim, shape=(n_reps, input_dim), dtype=dtype)  # shape [K, D]
-        U = tf.expand_dims(s, axis=-1) * Q  # equiv: S @ Q where S = diag(s); shape [K, D, D]
-        V = tf.reshape(U, shape=(-1, input_dim))  # shape [K*D, D]
+            s = _sample_chi(nu=input_dim, shape=(n_reps, input_dim), dtype=dtype)  # shape [K, D]
+            U = tf.expand_dims(s, axis=-1) * Q  # equiv: S @ Q where S = diag(s); shape [K, D, D]
+            V = tf.reshape(U, shape=(-1, input_dim))  # shape [K*D, D]
 
-        return V[: self.n_components]  # shape [M, D] (throw away K*D - M rows)
+            return V[: self.n_components]  # shape [M, D] (throw away K*D - M rows)
+
+        return _initializer