From 4ed880f648d3d0c4093760ad5200388b1727ad5b Mon Sep 17 00:00:00 2001
From: gw265981 <184935895+gw265981@users.noreply.github.com>
Date: Mon, 21 Oct 2024 15:48:40 +0100
Subject: [PATCH 1/6] bugfix/update-rpcholesky-state

---
 coreax/solvers/coresubset.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/coreax/solvers/coresubset.py b/coreax/solvers/coresubset.py
index 6e32fdf9..aef3c7a9 100644
--- a/coreax/solvers/coresubset.py
+++ b/coreax/solvers/coresubset.py
@@ -485,7 +485,7 @@ def _greedy_body(
         approximation_matrix = jnp.zeros((num_data_points, self.coreset_size))
         init_state = (gramian_diagonal, approximation_matrix, coreset_indices)
         output_state = jax.lax.fori_loop(0, self.coreset_size, _greedy_body, init_state)
-        _, _, updated_coreset_indices = output_state
+        gramian_diagonal, _, updated_coreset_indices = output_state
         updated_coreset = Coresubset(updated_coreset_indices, dataset)
         return updated_coreset, RPCholeskyState(gramian_diagonal)

From 51a5c45ecc3c5f268df92c29178180043136a3ef Mon Sep 17 00:00:00 2001
From: gw265981 <184935895+gw265981@users.noreply.github.com>
Date: Mon, 21 Oct 2024 15:48:40 +0100
Subject: [PATCH 2/6] fix: update rpcholesky state

---
 coreax/solvers/coresubset.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/coreax/solvers/coresubset.py b/coreax/solvers/coresubset.py
index 6e32fdf9..aef3c7a9 100644
--- a/coreax/solvers/coresubset.py
+++ b/coreax/solvers/coresubset.py
@@ -485,7 +485,7 @@ def _greedy_body(
         approximation_matrix = jnp.zeros((num_data_points, self.coreset_size))
         init_state = (gramian_diagonal, approximation_matrix, coreset_indices)
         output_state = jax.lax.fori_loop(0, self.coreset_size, _greedy_body, init_state)
-        _, _, updated_coreset_indices = output_state
+        gramian_diagonal, _, updated_coreset_indices = output_state
         updated_coreset = Coresubset(updated_coreset_indices, dataset)
         return updated_coreset, RPCholeskyState(gramian_diagonal)

From 9bb262869228acd849d5ff80a2030e11a44dc984 Mon Sep 17 00:00:00 2001
From: gw265981 <184935895+gw265981@users.noreply.github.com>
Date: Tue, 22 Oct 2024 12:32:48 +0100
Subject: [PATCH 3/6] fix: fix rpcholesky algorithm

One of the steps in the iteration was using the non-updated version of the
approximation matrix.
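For reviewers, a minimal standalone sketch (not part of the patch; it uses only
`jax.numpy` and the three-point example introduced later in this series) of the
step being corrected: the residual diagonal must be reduced by the square of the
freshly updated approximation-matrix column. With the stale column, the first
iteration subtracts zeros, so the diagonal carried in `RPCholeskyState` never
shrinks.

    import jax.numpy as jnp

    # Gram matrix of the three-point example (SquaredExponentialKernel with
    # length_scale = 1/sqrt(2)) quoted in the analytic test added later.
    gram = jnp.array(
        [
            [1.0, 0.84366477, 0.90483737],
            [0.84366477, 1.0, 0.7788007],
            [0.90483737, 0.7788007, 1.0],
        ]
    )

    residual_diagonal = jnp.diag(gram)
    approximation_matrix = jnp.zeros((3, 2))
    i, pivot = 0, 2  # first iteration, pivot index 2 as in the worked example

    # Write the new column F[:, 0] = g / sqrt(g[pivot]).
    g = gram[:, pivot]
    updated_approximation_matrix = approximation_matrix.at[:, i].set(
        g / jnp.sqrt(g[pivot])
    )

    # Stale column (all zeros): the residual diagonal is left unchanged.
    buggy = jnp.clip(residual_diagonal - jnp.square(approximation_matrix[:, i]), min=0)
    # Updated column: the chosen pivot's diagonal entry drops to zero.
    fixed = jnp.clip(
        residual_diagonal - jnp.square(updated_approximation_matrix[:, i]), min=0
    )

    print(buggy)  # [1. 1. 1.]
    print(fixed)  # [0.18126933 0.39346947 0.        ]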
---
 coreax/solvers/coresubset.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/coreax/solvers/coresubset.py b/coreax/solvers/coresubset.py
index aef3c7a9..9c93a8a0 100644
--- a/coreax/solvers/coresubset.py
+++ b/coreax/solvers/coresubset.py
@@ -469,7 +469,8 @@ def _greedy_body(
             )
             # Track diagonal of residual matrix and ensure it remains non-negative
             updated_residual_diagonal = jnp.clip(
-                residual_diagonal - jnp.square(approximation_matrix[:, i]), min=0
+                residual_diagonal - jnp.square(updated_approximation_matrix[:, i]),
+                min=0,
             )
             if self.unique:
                 # Ensures that index selected_pivot_point can't be drawn again in future

From 2cd46350edf28416863513e50394bc28f323c206 Mon Sep 17 00:00:00 2001
From: gw265981 <184935895+gw265981@users.noreply.github.com>
Date: Tue, 22 Oct 2024 12:55:59 +0100
Subject: [PATCH 4/6] test: add analytic rpcholesky test

---
 tests/unit/test_solvers.py | 199 ++++++++++++++++++++++++++++++++++++-
 1 file changed, 198 insertions(+), 1 deletion(-)

diff --git a/tests/unit/test_solvers.py b/tests/unit/test_solvers.py
index 94c1e1f0..7f2dc777 100644
--- a/tests/unit/test_solvers.py
+++ b/tests/unit/test_solvers.py
@@ -22,7 +22,7 @@
     nullcontext as does_not_raise,
 )
 from typing import Literal, NamedTuple, Optional, Union, cast
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, patch

 import equinox as eqx
 import jax
@@ -1058,6 +1058,203 @@ def test_rpcholesky_state(self, reduce_problem: _ReduceProblem) -> None:
         expected_state = RPCholeskyState(gramian_diagonal)
         assert eqx.tree_equal(state, expected_state)

+    def test_rpcholesky_analytic_unique(self):
+        r"""
+        Test RPCholesky with a unique coreset on a small verified instance.
+
+        In this example, we have data of:
+
+        .. math::
+            x = \begin{pmatrix}
+            0.5 & 0.2 \\
+            0.4 & 0.6 \\
+            0.8 & 0.3
+            \end{pmatrix}
+
+        We choose a ``SquaredExponentialKernel`` with ``length_scale`` of
+        :math:`\frac{1}{\sqrt{2}}`, which produces the following Gram matrix:
+
+        .. math::
+            \begin{pmatrix}
+            1.0 & 0.84366477 & 0.90483737 \\
+            0.84366477 & 1.0 & 0.7788007 \\
+            0.90483737 & 0.7788007 & 1.0
+            \end{pmatrix}
+
+        Note that we do not need to precompute the full Gram matrix; the algorithm
+        only needs to evaluate the pivot column at each iteration.
+
+        The RPCholesky algorithm iteratively builds a coreset by:
+        - Sampling pivot points based on the residual diagonal of the Gram matrix
+        - Updating an approximation matrix and the residual diagonal
+
+        We ask for a coreset of size 2 in this example. We start with an empty coreset
+        and an approximation matrix :math:`F = \mathbf{0}_{N \times k}`,
+        where :math:`N = 3, k = 2` in our case.
+
+        We first compute the diagonal of the Gram matrix as:
+
+        .. math::
+            d = \begin{pmatrix}
+            1 \\
+            1 \\
+            1
+            \end{pmatrix}
+
+        For the first iteration (i=0):
+
+        1. We sample a pivot point with probability proportional to its value on the
+           diagonal. All choices are equally likely, so let us suppose we choose the
+           pivot with index = 2.
+
+        2. We now compute g, the column at index 2, as:
+
+        .. math::
+            g = \begin{pmatrix}
+            0.90483737 \\
+            0.7788007 \\
+            1.0
+            \end{pmatrix}
+
+        3. Remove overlap with previously chosen columns (not needed on the first
+           iteration).
+
+        4. Update the approximation matrix:
+
+        .. math::
+            F[:, 0] = g / \sqrt{g[2]} = \begin{pmatrix}
+            0.90483737 \\
+            0.7788007 \\
+            1.0
+            \end{pmatrix}
+
+        5. Update the residual diagonal:
+
+        .. math::
+            d = d - |F[:,0]|^2 = \begin{pmatrix}
+            0.18126933 \\
+            0.39346947 \\
+            0
+            \end{pmatrix}
+
+        For the second iteration (i=1):
+
+        1. We again sample a pivot point with probability proportional to its value
+           on the updated residual diagonal, :math:`d`. Let us suppose we choose the
+           most likely pivot here (index=1).
+
+        2. We now compute g, the column at index 1, as:
+
+        .. math::
+            g = \begin{pmatrix}
+            0.84366477 \\
+            1.0 \\
+            0.7788007
+            \end{pmatrix}
+
+        3. Remove overlap with previously chosen columns:
+
+        .. math::
+            g = g - F[:, 0] F[1, 0] = \begin{pmatrix}
+            0.13897679 \\
+            0.39346947 \\
+            0
+            \end{pmatrix}
+
+        4. Update the approximation matrix:
+
+        .. math::
+            F[:, 1] = g / \sqrt{g[1]} = \begin{pmatrix}
+            0.22155766 \\
+            0.62727145 \\
+            0
+            \end{pmatrix}
+
+        5. Update the residual diagonal:
+
+        .. math::
+            d = d - |F[:,1]|^2 = \begin{pmatrix}
+            0.13218154 \\
+            0 \\
+            0
+            \end{pmatrix}
+
+        After this iteration, the final state is:
+
+        .. math::
+            F = \begin{pmatrix}
+            0.90483737 & 0.22155766 \\
+            0.7788007 & 0.62727145 \\
+            1.0 & 0
+            \end{pmatrix}, \quad
+            d = \begin{pmatrix}
+            0.13218154 \\
+            0 \\
+            0
+            \end{pmatrix}, \quad
+            S = \{2, 1\}
+
+        This completes the coreset of size :math:`k = 2`.
+        """
+        # Setup example data
+        coreset_size = 2
+        x = jnp.array(
+            [
+                [0.5, 0.2],
+                [0.4, 0.6],
+                [0.8, 0.3],
+            ]
+        )
+
+        # Define a kernel
+        length_scale = 1.0 / jnp.sqrt(2)
+        kernel = SquaredExponentialKernel(length_scale=length_scale)
+
+        # Create a mock for the random choice function
+        def deterministic_choice(*_, p, **__):
+            """
+            Return the index of the largest element of p.
+
+            If there is a tie, return the largest index.
+            This replaces random sampling with a deterministic rule, so the test
+            selects a predictable pivot at each iteration.
+            """
+            # Find indices where the value equals the maximum
+            is_max = p == p.max()
+            # Mask out non-maximal positions with -1 and take the maximum of the
+            # remaining indices, i.e. the highest index where the maximum appears
+            indices = jnp.arange(p.shape[0])
+            return jnp.where(is_max, indices, -1).max()
+
+        # Generate the coreset
+        data = Data(x)
+        solver = RPCholesky(
+            coreset_size=coreset_size,
+            random_key=jax.random.PRNGKey(0),  # Fixed seed for reproducibility
+            kernel=kernel,
+            unique=True,
+        )
+        # Mock the random choice function
+        with patch("jax.random.choice", deterministic_choice):
+            coreset, solver_state = solver.reduce(data)
+
+        # Independently computed gramian diagonal
+        expected_gramian_diagonal = jnp.array([0.13218154, 0.0, 0.0])
+
+        # Coreset indices forced by our mock choice function
+        expected_coreset_indices = jnp.array([2, 1])
+
+        # Check output matches expected
+        np.testing.assert_array_equal(
+            coreset.unweighted_indices, expected_coreset_indices
+        )
+        np.testing.assert_array_equal(
+            coreset.coreset.data, data.data[expected_coreset_indices]
+        )
+        np.testing.assert_array_almost_equal(
+            solver_state.gramian_diagonal, expected_gramian_diagonal
+        )
+

 class TestSteinThinning(RefinementSolverTest, ExplicitSizeSolverTest):
     """Test cases for :class:`coreax.solvers.coresubset.SteinThinning`."""

From 44feec6873733ead11a588c50d7b5f3e71187f57 Mon Sep 17 00:00:00 2001
From: gw265981 <184935895+gw265981@users.noreply.github.com>
Date: Tue, 22 Oct 2024 15:26:36 +0100
Subject: [PATCH 5/6] docs: add analytic rpcholesky example

---
 .../source/examples/analytical_rpcholesky.rst | 195 ++++++++++++++++++
 documentation/source/index.rst                |   1 +
 2 files changed, 196 insertions(+)
 create mode 100644 documentation/source/examples/analytical_rpcholesky.rst
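A small self-contained check (notes here, between the diffstat and the diff, are
not applied by ``git am``) that reproduces the Gram matrix quoted in the example
page added below. It assumes the standard squared-exponential form
k(a, b) = exp(-||a - b||^2 / (2 * length_scale^2)), which with
length_scale = 1/sqrt(2) reduces to exp(-||a - b||^2):

    import jax.numpy as jnp

    # Data from the example page.
    x = jnp.array([[0.5, 0.2], [0.4, 0.6], [0.8, 0.3]])

    # Pairwise squared distances, then k(a, b) = exp(-||a - b||^2).
    sq_dists = jnp.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)
    gram = jnp.exp(-sq_dists)

    print(gram)
    # [[1.         0.84366477 0.90483737]
    #  [0.84366477 1.         0.7788007 ]
    #  [0.90483737 0.7788007  1.        ]]
    print(jnp.diag(gram))  # initial residual diagonal d = (1, 1, 1)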
diff --git a/documentation/source/examples/analytical_rpcholesky.rst b/documentation/source/examples/analytical_rpcholesky.rst
new file mode 100644
index 00000000..82636528
--- /dev/null
+++ b/documentation/source/examples/analytical_rpcholesky.rst
@@ -0,0 +1,195 @@
+Analytical example with RPCholesky
+==================================
+
+In this example, we have data of:
+
+.. math::
+    x = \begin{pmatrix}
+    0.5 & 0.2 \\
+    0.4 & 0.6 \\
+    0.8 & 0.3
+    \end{pmatrix}
+
+We choose a ``SquaredExponentialKernel`` with ``length_scale`` of :math:`\frac{1}{\sqrt{2}}`, which produces the following Gram matrix:
+
+.. math::
+    \begin{pmatrix}
+    1.0 & 0.84366477 & 0.90483737 \\
+    0.84366477 & 1.0 & 0.7788007 \\
+    0.90483737 & 0.7788007 & 1.0
+    \end{pmatrix}
+
+Note that we do not need to precompute the full Gram matrix; the algorithm
+only needs to evaluate the pivot column at each iteration.
+
+The RPCholesky algorithm iteratively builds a coreset by:
+ - Sampling pivot points based on the residual diagonal of the kernel Gram matrix
+ - Updating an approximation matrix and the residual diagonal
+
+We ask for a coreset of size 2 in this example. We start with an empty coreset
+and an approximation matrix :math:`F = \mathbf{0}_{N \times k}`,
+where :math:`N = 3, k = 2` in our case.
+
+We first compute the diagonal of the Gram matrix as:
+
+.. math::
+    d = \begin{pmatrix}
+    1 \\
+    1 \\
+    1
+    \end{pmatrix}
+
+For the first iteration (i=0):
+
+1. We sample a pivot point with probability proportional to its value on the diagonal. All choices are equally likely, so let us suppose we choose the pivot with index = 2.
+
+2. We now compute g, the column at index 2, as:
+
+.. math::
+    g = \begin{pmatrix}
+    0.90483737 \\
+    0.7788007 \\
+    1.0
+    \end{pmatrix}
+
+3. Remove overlap with previously chosen columns (not needed on the first iteration).
+
+4. Update the approximation matrix:
+
+.. math::
+    F[:, 0] = g / \sqrt{g[2]} = \begin{pmatrix}
+    0.90483737 \\
+    0.7788007 \\
+    1.0
+    \end{pmatrix}
+
+5. Update the residual diagonal:
+
+.. math::
+    d = d - |F[:,0]|^2 = \begin{pmatrix}
+    0.18126933 \\
+    0.39346947 \\
+    0
+    \end{pmatrix}
+
+For the second iteration (i=1):
+
+1. We again sample a pivot point with probability proportional to its value on the updated residual diagonal, :math:`d`. Let us suppose we choose the most likely pivot here (index=1).
+
+2. We now compute g, the column at index 1, as:
+
+.. math::
+    g = \begin{pmatrix}
+    0.84366477 \\
+    1.0 \\
+    0.7788007
+    \end{pmatrix}
+
+3. Remove overlap with previously chosen columns:
+
+.. math::
+    g = g - F[:, 0] F[1, 0] = \begin{pmatrix}
+    0.13897679 \\
+    0.39346947 \\
+    0
+    \end{pmatrix}
+
+4. Update the approximation matrix:
+
+.. math::
+    F[:, 1] = g / \sqrt{g[1]} = \begin{pmatrix}
+    0.22155766 \\
+    0.62727145 \\
+    0
+    \end{pmatrix}
+
+5. Update the residual diagonal:
+
+.. math::
+    d = d - |F[:,1]|^2 = \begin{pmatrix}
+    0.13218154 \\
+    0 \\
+    0
+    \end{pmatrix}
+
+After this iteration, the final state is:
+
+.. math::
+    F = \begin{pmatrix}
+    0.90483737 & 0.22155766 \\
+    0.7788007 & 0.62727145 \\
+    1.0 & 0
+    \end{pmatrix}, \quad
+    d = \begin{pmatrix}
+    0.13218154 \\
+    0 \\
+    0
+    \end{pmatrix}, \quad
+    S = \{2, 1\}
+
+This completes the coreset of size :math:`k = 2`.
+
+.. code-block::
+    import jax.numpy as jnp
+    import jax.random as jr
+    from unittest.mock import patch
+
+    from coreax import Data, SquaredExponentialKernel
+    from coreax.solvers import RPCholesky
+
+    # Setup example data
+    coreset_size = 2
+    x = jnp.array(
+        [
+            [0.5, 0.2],
+            [0.4, 0.6],
+            [0.8, 0.3],
+        ]
+    )
+
+    # Define a kernel
+    length_scale = 1.0 / jnp.sqrt(2)
+    kernel = SquaredExponentialKernel(length_scale=length_scale)
+
+    # Create a mock for the random choice function
+    def deterministic_choice(*_, p, **__):
+        """
+        Return the index of the largest element of p.
+
+        If there is a tie, return the largest index.
+        This replaces random sampling with a deterministic rule, so the example
+        selects a predictable pivot at each iteration.
+        """
+        # Find indices where the value equals the maximum
+        is_max = p == p.max()
+        # Mask out non-maximal positions with -1 and take the maximum of the
+        # remaining indices, i.e. the highest index where the maximum appears
+        indices = jnp.arange(p.shape[0])
+        return jnp.where(is_max, indices, -1).max()
+
+
+    # Generate the coreset
+    data = Data(x)
+    solver = RPCholesky(
+        coreset_size=coreset_size,
+        random_key=jr.PRNGKey(0),  # Fixed seed for reproducibility
+        kernel=kernel,
+        unique=True,
+    )
+
+    # Mock the random choice function
+    with patch("jax.random.choice", deterministic_choice):
+        coreset, solver_state = solver.reduce(data)
+
+    # Independently computed gramian diagonal
+    expected_gramian_diagonal = jnp.array([0.13218154, 0.0, 0.0])
+
+    # Coreset indices forced by our mock choice function
+    expected_coreset_indices = jnp.array([2, 1])
+
+    # Inspect results
+    print("Chosen coreset:")
+    print(coreset.unweighted_indices)  # The coreset_indices
+    print(coreset.coreset.data)  # The data points in the coreset
+    print("Residual diagonal:")
+    print(solver_state.gramian_diagonal)
diff --git a/documentation/source/index.rst b/documentation/source/index.rst
index 19044935..f152470e 100644
--- a/documentation/source/index.rst
+++ b/documentation/source/index.rst
@@ -54,6 +54,7 @@ Contents
    examples/pounce_map_reduce
    examples/david_map_reduce_weighted
    examples/analytical_kernel_herding
+   examples/analytical_rpcholesky

 .. toctree::
    :maxdepth: 2

From 35ac1b66c1355ad6027694c3121b35bee9fbf2f9 Mon Sep 17 00:00:00 2001
From: gw265981 <184935895+gw265981@users.noreply.github.com>
Date: Tue, 22 Oct 2024 16:20:33 +0100
Subject: [PATCH 6/6] fix: fix analytical rpcholesky example page

---
 documentation/source/examples/analytical_rpcholesky.rst | 1 +
 1 file changed, 1 insertion(+)

diff --git a/documentation/source/examples/analytical_rpcholesky.rst b/documentation/source/examples/analytical_rpcholesky.rst
index 82636528..77ca2ee6 100644
--- a/documentation/source/examples/analytical_rpcholesky.rst
+++ b/documentation/source/examples/analytical_rpcholesky.rst
@@ -130,6 +130,7 @@ After this iteration, the final state is:
 This completes the coreset of size :math:`k = 2`.
 
 .. code-block::
+
     import jax.numpy as jnp
     import jax.random as jr
     from unittest.mock import patch