diff --git a/pyproject.toml b/pyproject.toml
index 678fd9ec83..7aece92705 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,6 +9,7 @@ requires = [
 [project]
 authors = [{ email = "allen.goodman@icloud.com", name = "Allen Goodman" }]
 dependencies = [
+    "numpy>=2.0.0",
     "pooch",
     "torch==2.2.2",
     "torchaudio",
diff --git a/src/beignet/__init__.py b/src/beignet/__init__.py
index fe3fd28ca7..7d523b5f06 100644
--- a/src/beignet/__init__.py
+++ b/src/beignet/__init__.py
@@ -194,6 +194,7 @@
 from ._linear_probabilists_hermite_polynomial import (
     linear_probabilists_hermite_polynomial,
 )
+from ._maximum_mean_discrepancy import maximum_mean_discrepancy
 from ._multiply_chebyshev_polynomial import multiply_chebyshev_polynomial
 from ._multiply_chebyshev_polynomial_by_x import multiply_chebyshev_polynomial_by_x
 from ._multiply_laguerre_polynomial import multiply_laguerre_polynomial
diff --git a/src/beignet/_maximum_mean_discrepancy.py b/src/beignet/_maximum_mean_discrepancy.py
new file mode 100644
index 0000000000..0aa440843a
--- /dev/null
+++ b/src/beignet/_maximum_mean_discrepancy.py
@@ -0,0 +1,134 @@
+from typing import Any, Callable, Optional
+
+import numpy.typing as npt
+import torch
+
+
+def maximum_mean_discrepancy(
+    X: npt.ArrayLike,
+    Y: npt.ArrayLike,
+    distance_fn: Optional[Callable[[Any, Any], Any]] = None,
+    kernel_width: Optional[float] = None,
+) -> npt.ArrayLike:
+    """
+    Compute Maximum Mean Discrepancy (MMD) between batched sample sets.
+
+    This function computes MMD between two sets of samples, supporting both
+    NumPy arrays and PyTorch tensors with arbitrary batch dimensions. It uses
+    a Gaussian kernel k(x, y) = exp(-||x - y||² / 2γ²), where the bandwidth γ
+    (kernel_width) is chosen per batch via the median heuristic when not
+    specified.
+
+    Args:
+        X: First distribution samples. Shape: (*B, N₁, D)
+        Y: Second distribution samples. Shape: (*B, N₂, D)
+        distance_fn: Optional custom distance metric. Uses Euclidean distance
+            if None. Must support broadcasting over batch dimensions.
+        kernel_width: Optional bandwidth γ for the Gaussian kernel. Uses the
+            median heuristic per batch if None.
+
+    Returns:
+        MMD distance with shape (*B) matching the input batch dimensions.
+
+    Raises:
+        ValueError: If either input has fewer than 2 samples along the N
+            dimension.
+
+    Note:
+        Memory scales as O(B·N²) for B = product(batch_dims) and
+        N = max(N₁, N₂). Operations are vectorized over all dimensions.
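+
+    Example:
+        Illustrative usage (inputs are random, so the exact value varies)::
+
+            import torch
+
+            x = torch.randn(128, 3)
+            y = torch.randn(128, 3) + 1.0
+
+            maximum_mean_discrepancy(x, y)  # positive scalar tensor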
+ """ + + if torch.is_tensor(X): + xp = torch + xp_is_torch = True + else: + xp = X.__array_namespace__() # Get array namespace for API operations + xp_is_torch = False + + if X.shape[-2] < 2 or Y.shape[-2] < 2: + raise ValueError("Each distribution must have at least 2 samples") + + if distance_fn is not None: + pass + + elif distance_fn is None and hasattr(xp, "expand_dims"): + + def distance_fn(x, y): + # Broadcasting using array API operations + diff = xp.expand_dims(x, -2) - xp.expand_dims(y, -3) + return xp.sqrt((diff**2).sum(-1)) + + elif distance_fn is None and xp_is_torch: + + def distance_fn(x, y): + diff = xp.unsqueeze(x, -2) - xp.unsqueeze(y, -3) + return xp.sqrt((diff**2).sum(-1)) + else: + raise ValueError("Array namespace does not conform to expected API") + + # Compute kernel matrices + D_XX = distance_fn(X, X) + D_YY = distance_fn(Y, Y) + D_XY = distance_fn(X, Y) + + batch_shape = D_XX.shape[:-2] + if kernel_width is None: + # Preserve all batch dimensions, flatten only distance matrices + all_distances = xp.concat( + [ + xp.reshape(D_XX, (*batch_shape, -1)), + xp.reshape(D_YY, (*batch_shape, -1)), + xp.reshape(D_XY, (*batch_shape, -1)), + ], + axis=-1, + ) + + if xp_is_torch: + kernel_width = xp.median(all_distances, dim=-1).values + else: + kernel_width = xp.median(all_distances, axis=-1) + + # Add necessary dimensions for broadcasting + kernel_width = xp.reshape(kernel_width, (*batch_shape, 1, 1)) + + # Apply RBF kernel using array API operations + sq_kernel_width = kernel_width**2 + K_XX = xp.exp(-0.5 * D_XX**2 / sq_kernel_width) + K_YY = xp.exp(-0.5 * D_YY**2 / sq_kernel_width) + K_XY = xp.exp(-0.5 * D_XY**2 / sq_kernel_width) + + m = X.shape[-2] + n = Y.shape[-2] + + # Compute MMD^2 with diagonal correction using array API operations + if xp_is_torch: + + def batched_trace(x, dim1=-2, dim2=-1): + """Compute trace along last two dims while preserving batch dims.""" + i = torch.arange(x.size(dim1)) + return ( + torch.gather(x, dim2, i.expand(x.shape[:-1]).unsqueeze(-1)) + .squeeze(-1) + .sum(-1) + ) + + # Then in the MMD computation: + mmd_squared = ( + (K_XX.sum((-1, -2)) - batched_trace(K_XX)) / (m * (m - 1)) + + (K_YY.sum((-1, -2)) - batched_trace(K_YY)) / (n * (n - 1)) + - 2 * K_XY.mean((-1, -2)) + ) + + else: + mmd_squared = ( + (xp.sum(K_XX, axis=(-1, -2)) - xp.trace(K_XX, axis1=-1, axis2=-2)) + / (m * (m - 1)) + + (xp.sum(K_YY, axis=(-1, -2)) - xp.trace(K_YY, axis1=-1, axis2=-2)) + / (n * (n - 1)) + - 2 * xp.mean(K_XY, axis=(-1, -2)) + ) + + if xp is torch: + return mmd_squared.clamp_min(0.0).sqrt() + + return xp.sqrt(xp.maximum(mmd_squared, 0.0)) diff --git a/tests/beignet/test__maximum_mean_discrepancy.py b/tests/beignet/test__maximum_mean_discrepancy.py new file mode 100644 index 0000000000..95146b1497 --- /dev/null +++ b/tests/beignet/test__maximum_mean_discrepancy.py @@ -0,0 +1,163 @@ +from typing import NamedTuple, Union + +import numpy +import pytest +import torch +from beignet import maximum_mean_discrepancy +from numpy.testing import assert_allclose + +ArrayType = Union[numpy.ndarray, torch.Tensor] + + +class TestData(NamedTuple): + """Container for test arrays to ensure consistent initialization.""" + + X: ArrayType + Y: ArrayType + X_same: ArrayType + X_large: ArrayType + Y_small: ArrayType + + +@pytest.fixture(scope="module") +def numpy_arrays() -> TestData: + """Generate test arrays once per module using vectorized operations.""" + rng = numpy.random.default_rng(42) + # Preallocate all arrays in one block + X = rng.normal(0, 1, (1024, 2)) + Y = 
+    batch_shape = D_XX.shape[:-2]
+    if kernel_width is None:
+        # Preserve all batch dimensions; flatten only the distance matrices.
+        all_distances = xp.concat(
+            [
+                xp.reshape(D_XX, (*batch_shape, -1)),
+                xp.reshape(D_YY, (*batch_shape, -1)),
+                xp.reshape(D_XY, (*batch_shape, -1)),
+            ],
+            axis=-1,
+        )
+
+        if xp_is_torch:
+            kernel_width = xp.median(all_distances, dim=-1).values
+        else:
+            kernel_width = xp.median(all_distances, axis=-1)
+
+        # Add singleton dimensions so the bandwidth broadcasts against the
+        # (*B, N, N) kernel matrices.
+        kernel_width = xp.reshape(kernel_width, (*batch_shape, 1, 1))
+
+    # Apply the Gaussian (RBF) kernel using array API operations.
+    sq_kernel_width = kernel_width**2
+    K_XX = xp.exp(-0.5 * D_XX**2 / sq_kernel_width)
+    K_YY = xp.exp(-0.5 * D_YY**2 / sq_kernel_width)
+    K_XY = xp.exp(-0.5 * D_XY**2 / sq_kernel_width)
+
+    m = X.shape[-2]
+    n = Y.shape[-2]
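+
+    # Unbiased MMD² estimator (see Gretton et al., 2012, "A Kernel
+    # Two-Sample Test"):
+    #
+    #   MMD² = Σ_{i≠j} k(xᵢ, xⱼ) / (m(m-1))
+    #        + Σ_{i≠j} k(yᵢ, yⱼ) / (n(n-1))
+    #        - 2 Σ_{i,j} k(xᵢ, yⱼ) / (mn)
+    #
+    # Subtracting the trace drops the i == j terms from the two within-set
+    # sums; the cross-term mean keeps all m·n entries.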
+    if xp_is_torch:
+        # torch.diagonal extracts the diagonal across arbitrary batch
+        # dimensions and stays on the input's device.
+        mmd_squared = (
+            (K_XX.sum((-1, -2)) - torch.diagonal(K_XX, dim1=-2, dim2=-1).sum(-1))
+            / (m * (m - 1))
+            + (K_YY.sum((-1, -2)) - torch.diagonal(K_YY, dim1=-2, dim2=-1).sum(-1))
+            / (n * (n - 1))
+            - 2 * K_XY.mean((-1, -2))
+        )
+    else:
+        mmd_squared = (
+            (xp.sum(K_XX, axis=(-1, -2)) - xp.trace(K_XX, axis1=-1, axis2=-2))
+            / (m * (m - 1))
+            + (xp.sum(K_YY, axis=(-1, -2)) - xp.trace(K_YY, axis1=-1, axis2=-2))
+            / (n * (n - 1))
+            - 2 * xp.mean(K_XY, axis=(-1, -2))
+        )
+
+    # Clamp small negative values from floating-point error before the root.
+    if xp_is_torch:
+        return mmd_squared.clamp_min(0.0).sqrt()
+
+    return xp.sqrt(xp.maximum(mmd_squared, 0.0))
diff --git a/tests/beignet/test__maximum_mean_discrepancy.py b/tests/beignet/test__maximum_mean_discrepancy.py
new file mode 100644
index 0000000000..95146b1497
--- /dev/null
+++ b/tests/beignet/test__maximum_mean_discrepancy.py
@@ -0,0 +1,163 @@
+from typing import NamedTuple, Union
+
+import numpy
+import pytest
+import torch
+from beignet import maximum_mean_discrepancy
+from numpy.testing import assert_allclose
+
+ArrayType = Union[numpy.ndarray, torch.Tensor]
+
+
+class TestData(NamedTuple):
+    """Container for test arrays to ensure consistent initialization."""
+
+    X: ArrayType
+    Y: ArrayType
+    X_same: ArrayType
+    X_large: ArrayType
+    Y_small: ArrayType
+
+
+@pytest.fixture(scope="module")
+def numpy_arrays() -> TestData:
+    """Generate test arrays once per module."""
+    rng = numpy.random.default_rng(42)
+
+    X = rng.normal(0, 1, (1024, 2))
+    Y = rng.normal(5, 1, (1024, 2))  # Different mean for clear separation
+    X_large = rng.normal(0, 1, (150, 2))
+    Y_small = rng.normal(0, 1, (50, 2))
+
+    return TestData(
+        X=X,
+        Y=Y,
+        X_same=X.copy(),  # Explicit copy for identity tests
+        X_large=X_large,
+        Y_small=Y_small,
+    )
+
+
+@pytest.fixture(scope="module")
+def torch_arrays(numpy_arrays: TestData) -> TestData:
+    """Convert the NumPy fixture arrays to torch tensors."""
+    return TestData(
+        X=torch.tensor(numpy_arrays.X),
+        Y=torch.tensor(numpy_arrays.Y),
+        X_same=torch.tensor(numpy_arrays.X_same),
+        X_large=torch.tensor(numpy_arrays.X_large),
+        Y_small=torch.tensor(numpy_arrays.Y_small),
+    )
+
+
+def manhattan_distance(x: ArrayType, y: ArrayType) -> ArrayType:
+    """Vectorized Manhattan distance using array API operations."""
+    xp = x.__array_namespace__()
+    diff = xp.subtract(xp.expand_dims(x, 1), xp.expand_dims(y, 0))
+    return xp.sum(xp.abs(diff), axis=-1)
+
+
+@pytest.mark.parametrize("arrays", ["numpy_arrays", "torch_arrays"])
+def test_mmd_basic(request, arrays):
+    """Test core MMD behavior for both array backends."""
+    data = request.getfixturevalue(arrays)
+
+    mmd = maximum_mean_discrepancy(data.X, data.Y)
+    assert mmd > 0, "MMD should be positive for different distributions"
+
+    mmd_same = maximum_mean_discrepancy(data.X, data.X_same)
+    assert mmd_same < 0.1, "MMD should be near 0 for identical distributions"
+
+    # MMD is symmetric in its arguments.
+    mmd_xy = maximum_mean_discrepancy(data.X, data.Y)
+    mmd_yx = maximum_mean_discrepancy(data.Y, data.X)
+
+    if torch.is_tensor(data.X):
+        assert torch.allclose(mmd_xy, mmd_yx, rtol=1e-5)
+    else:
+        assert_allclose(mmd_xy, mmd_yx, rtol=1e-5)
+
+
+@pytest.mark.parametrize("arrays", ["numpy_arrays", "torch_arrays"])
+def test_mmd_validation(request, arrays):
+    """Test input validation and unequal sample counts."""
+    data = request.getfixturevalue(arrays)
+
+    # Fewer than 2 samples per distribution is rejected.
+    with pytest.raises(ValueError):
+        _ = maximum_mean_discrepancy(data.X[:1], data.Y[:1])
+
+    # Unequal sample counts (150 vs. 50) are supported.
+    mmd = maximum_mean_discrepancy(data.X_large, data.Y_small)
+    assert numpy.isfinite(float(mmd))
+
+
+@pytest.mark.parametrize("arrays", ["numpy_arrays", "torch_arrays"])
+def test_mmd_broadcasting(request, arrays):
+    """Test MMD with batched inputs."""
+    data = request.getfixturevalue(arrays)
+
+    # Create batched data with batch shape (2, 3).
+    if torch.is_tensor(data.X):
+        X_batch = data.X[None, None, :, :].repeat(2, 3, 1, 1)
+        Y_batch = data.Y[None, None, :, :].repeat(2, 3, 1, 1)
+    else:
+        X_batch = numpy.ones((2, 3) + data.X.shape) * data.X
+        Y_batch = numpy.ones((2, 3) + data.Y.shape) * data.Y
+
+    mmd = maximum_mean_discrepancy(X_batch, Y_batch)
+    assert mmd.shape == (2, 3), f"Expected shape (2, 3), got {mmd.shape}"
+
+    # Each batch entry should match the unbatched computation.
+    for i in range(2):
+        for j in range(3):
+            single_mmd = maximum_mean_discrepancy(X_batch[i, j], Y_batch[i, j])
+
+            if torch.is_tensor(data.X):
+                assert torch.allclose(mmd[i, j], single_mmd)
+            else:
+                assert numpy.allclose(mmd[i, j], single_mmd)
+
+
+def test_mmd_hamming():
+    """Test MMD with Hamming distance on string arrays."""
+    rng = numpy.random.default_rng(42)
+    n_samples = 3
+
+    # Random sequences over deliberately different alphabets ("T" vs. "R"),
+    # so the two string distributions are guaranteed to differ.
+    X = numpy.array(
+        ["".join(rng.choice(["A", "T", "G", "C"], 32)) for _ in range(n_samples)]
+    ).reshape(-1, 1)
+    Y = numpy.array(
+        ["".join(rng.choice(["A", "R", "G", "C"], 32)) for _ in range(n_samples * 2)]
+    ).reshape(-1, 1)
+
+    def hamming_distance(x: numpy.ndarray, y: numpy.ndarray) -> numpy.ndarray:
+        """
+        Compute pairwise Hamming distances between arrays of strings.
+
+        Args:
+            x: First array of strings, shape (n, 1)
+            y: Second array of strings, shape (m, 1)
+
+        Returns:
+            Distance matrix of shape (n, m)
+        """
+        # Flatten the (n, 1) inputs to 1-D arrays of strings.
+        x_flat = x.reshape(-1)
+        y_flat = y.reshape(-1)
+
+        n, m = len(x_flat), len(y_flat)
+        distances = numpy.zeros((n, m), dtype=numpy.float64)
+
+        # Count character mismatches for every pair of strings.
+        for i in range(n):
+            for j in range(m):
+                distances[i, j] = sum(
+                    1.0 for a, b in zip(x_flat[i], y_flat[j], strict=False) if a != b
+                )
+
+        return distances
+
+    mmd = maximum_mean_discrepancy(X, Y, distance_fn=hamming_distance)
+    assert mmd > 0, "MMD should be positive for different string distributions"