Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Map reduce index bug #790

Merged
merged 18 commits into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
272da26
#779 At coreax.solvers.composite, changed the return statement of the…
qh681248 Sep 24, 2024
d9045a4
#779 At coreax.solvers.composite, changed the return statement of the…
qh681248 Sep 24, 2024
d87afde
Added a test in unit/test_solvers.py that checks if MapReduce's reduc…
qh681248 Sep 25, 2024
08f666a
Added an if statement on MapReduce.reduce method, it now only assigns…
qh681248 Sep 25, 2024
7d41a4b
replaced mapreduce by map_reduce in test_map_reduce_diverse_selection…
qh681248 Sep 26, 2024
48a1037
In coreax/solvers/composite.py, the reduce method updates indices onl…
qh681248 Sep 26, 2024
1bd8501
Removed the line plt.show() in examples/pounce.py
qh681248 Sep 27, 2024
082ada7
Made requested changes in the PR #790 (fixed type hints, comments etc.)
qh681248 Oct 3, 2024
e5f5122
Removed changed NoneType to None (NoneType is not compatible with pyt…
qh681248 Oct 3, 2024
da827a7
removed a folder that wasn't supposed to be added
qh681248 Oct 3, 2024
c187818
On composite.py changed return statement of _jit_tree from dataset[in…
qh681248 Oct 4, 2024
a4f762c
Added MapReduce bugfix to `CHANGELOG.md`
qh681248 Oct 8, 2024
bbae67d
Added double backticks on comments in `composite.py` when referring t…
qh681248 Oct 8, 2024
cf537ed
Removed a redundant comment on a test on `TestMapReduce` class
qh681248 Oct 8, 2024
c894e07
Added analytic test for `MapReduce`
qh681248 Oct 9, 2024
a5fb1be
docs: make suggested changes in the docstring
qh681248 Oct 14, 2024
b48bda4
docs: make suggested changes in the docstring
qh681248 Oct 17, 2024
2a9a1c4
Merge remote-tracking branch 'origin/main' into bugfix/MapReduce-inde…
qh681248 Oct 17, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 48 additions & 12 deletions coreax/solvers/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,20 @@
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
from jax import Array
from sklearn.neighbors import BallTree, KDTree
from typing_extensions import TypeAlias, override

from coreax.coreset import Coreset, Coresubset
from coreax.data import Data
from coreax.solvers.base import ExplicitSizeSolver, PaddingInvariantSolver, Solver
from coreax.util import tree_zero_pad_leading_axis
from coreax.util import ArrayLike, tree_zero_pad_leading_axis

BinaryTree: TypeAlias = Union[KDTree, BallTree]
_Data = TypeVar("_Data", bound=Data)
_Coreset = TypeVar("_Coreset", Coreset, Coresubset)
_State = TypeVar("_State")
_Indices = TypeVar("_Indices", ArrayLike, None)


class CompositeSolver(
Expand Down Expand Up @@ -125,22 +127,56 @@ def reduce(
# There is no obvious way to use state information here.
del solver_state

def _reduce_coreset(data: _Data) -> tuple[_Coreset, _State]:
def _reduce_coreset(
data: _Data, _indices: Optional[_Indices] = None
) -> tuple[_Coreset, _State, _Indices]:
if len(data) <= self.leaf_size:
return self.base_solver.reduce(data)
partitioned_dataset = _jit_tree(data, self.leaf_size, self.tree_type)
coreset_ensemble, _ = jax.vmap(self.base_solver.reduce)(partitioned_dataset)
coreset, state = self.base_solver.reduce(data)
if _indices is not None:
_indices = _indices[coreset.nodes.data]
return coreset, state, _indices

def wrapper(partition: _Data) -> tuple[_Data, Array]:
"""
Apply the `reduce` method of the base solver on a partition.

This is a wrapper for `reduce()` for processing a single partition.
The data is partitioned with `_jit_tree()`.
The reduction is performed on each partition via `v`map()`.
bk958178 marked this conversation as resolved.
Show resolved Hide resolved
"""
x, _ = self.base_solver.reduce(partition)
return x.coreset, x.nodes.data

partitioned_dataset, partitioned_indices = _jit_tree(
data, self.leaf_size, self.tree_type
)
# Reduce each partition and get indices from each
coreset_ensemble, ensemble_indices = jax.vmap(wrapper)(partitioned_dataset)
# Calculate the indices with respect to the original data
concatenated_indices = jax.vmap(lambda x, index: x[index])(
partitioned_indices, ensemble_indices
)
concatenated_indices = jnp.ravel(concatenated_indices)
_coreset = jtu.tree_map(jnp.concatenate, coreset_ensemble)
return _reduce_coreset(_coreset.coreset)

coreset_wrong_pre_coreset_data, output_solver_state = _reduce_coreset(dataset)
coreset = eqx.tree_at(
lambda x: x.pre_coreset_data, coreset_wrong_pre_coreset_data, dataset
)
if _indices is not None:
final_indices = _indices[concatenated_indices]
else:
final_indices = concatenated_indices
return _reduce_coreset(_coreset, final_indices)

(coreset, output_solver_state, _indices) = _reduce_coreset(dataset)
bk958178 marked this conversation as resolved.
Show resolved Hide resolved
# Correct the pre-coreset data and the indices
coreset = eqx.tree_at(lambda x: x.pre_coreset_data, coreset, dataset)
if _indices is not None:
if isinstance(coreset, Coresubset):
coreset = eqx.tree_at(lambda x: x.nodes.data, coreset, _indices)
return coreset, output_solver_state


def _jit_tree(dataset: _Data, leaf_size: int, tree_type: type[BinaryTree]) -> _Data:
def _jit_tree(
dataset: _Data, leaf_size: int, tree_type: type[BinaryTree]
) -> tuple[_Data, _Indices]:
"""
Return JIT compatible BinaryTree partitioning of 'dataset'.

Expand Down Expand Up @@ -183,4 +219,4 @@ def _binary_tree(_input_data: Data) -> np.ndarray:
return node_indices.reshape(n_leaves, -1).astype(np.int32)

indices = jax.pure_callback(_binary_tree, result_shape, padded_dataset)
return dataset[indices]
return padded_dataset[indices], jnp.arange(len(dataset))[indices]
28 changes: 28 additions & 0 deletions tests/unit/test_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,34 @@ def test_base_solver(
solver_factory.keywords["base_solver"] = base_solver
solver_factory()

def test_map_reduce_diverse_selection(self):
"""Check if MapReduce returns indices from multiple partitions."""
dataset_size = 40
data_dim = 5
coreset_size = 6
leaf_size = 12

key = jr.PRNGKey(0)
dataset = jr.normal(key, shape=(dataset_size, data_dim))

kernel = SquaredExponentialKernel()
base_solver = KernelHerding(coreset_size=coreset_size, kernel=kernel)
bk958178 marked this conversation as resolved.
Show resolved Hide resolved

solver = MapReduce(base_solver=base_solver, leaf_size=leaf_size)
coreset, _ = solver.reduce(Data(dataset))
selected_indices = coreset.nodes.data

# Check if there are indices beyond the first few
bk958178 marked this conversation as resolved.
Show resolved Hide resolved
assert jnp.any(
selected_indices >= coreset_size
), "MapReduce should select points beyond the first few"

# Check if there are indices from different partitions
partitions_represented = jnp.unique(selected_indices // leaf_size)
assert (
len(partitions_represented) > 1
), "MapReduce should select points from multiple partitions"


class TestCaratheodoryRecombination(RecombinationSolverTest):
"""Tests for :class:`coreax.solvers.recombination.CaratheodoryRecombination`."""
Expand Down