add maximum mean discrepancy metric #56

Open · wants to merge 9 commits into main
1 change: 1 addition & 0 deletions pyproject.toml
@@ -9,6 +9,7 @@ requires = [
[project]
authors = [{ email = "[email protected]", name = "Allen Goodman" }]
dependencies = [
"numpy>=2.0.0",
Collaborator: Let torch manage the numpy dependency

Collaborator: I think torch doesn't depend on numpy now

Contributor (author): Yeah, that's what I was seeing

"pooch",
"torch==2.2.2",
"torchaudio",
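For context on the thread above: the first suggestion amounts to dropping the added pin and letting torch pull in numpy transitively, i.e. a dependency block like the sketch below (hypothetical, not what this PR ships). The follow-up comments note that recent torch releases no longer declare numpy as a hard dependency, which is why the explicit "numpy>=2.0.0" entry stays.

dependencies = [
    "pooch",
    "torch==2.2.2",
    "torchaudio",
]
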
1 change: 1 addition & 0 deletions src/beignet/__init__.py
@@ -194,6 +194,7 @@
from ._linear_probabilists_hermite_polynomial import (
    linear_probabilists_hermite_polynomial,
)
from ._maximum_mean_discrepancy import maximum_mean_discrepancy
from ._multiply_chebyshev_polynomial import multiply_chebyshev_polynomial
from ._multiply_chebyshev_polynomial_by_x import multiply_chebyshev_polynomial_by_x
from ._multiply_laguerre_polynomial import multiply_laguerre_polynomial
134 changes: 134 additions & 0 deletions src/beignet/_maximum_mean_discrepancy.py
@@ -0,0 +1,134 @@
from typing import Any, Callable, Optional

import numpy.typing as npt
import torch


def maximum_mean_discrepancy(
    X: npt.ArrayLike,
    Y: npt.ArrayLike,
    distance_fn: Optional[Callable[[Any, Any], Any]] = None,
    kernel_width: Optional[float] = None,
) -> npt.ArrayLike:
    """
    Compute Maximum Mean Discrepancy (MMD) between batched sample sets.

    This function computes MMD between two sets of samples, supporting
    both NumPy arrays and PyTorch tensors with arbitrary batch
    dimensions. Uses a Gaussian kernel k(x, y) = exp(-||x-y||²/2γ²)
    where γ (kernel_width) is calibrated via the median heuristic if
    not specified.

    Args:
        X: First distribution samples. Shape: (*B, N₁, D)
        Y: Second distribution samples. Shape: (*B, N₂, D)
        distance_fn: Optional custom distance metric. Uses Euclidean if None.
            Must support broadcasting over batch dims.
        kernel_width: Optional bandwidth γ for the Gaussian kernel. Uses the
            median heuristic per batch if None.

    Returns:
        MMD distance with shape (*B) matching input batch dimensions.

    Raises:
        ValueError: If either input has fewer than 2 samples along the N
            dimension.

    Note:
        Memory scales as O(BN²) for B = product(batch_dims) and
        N = max(N₁, N₂). Operations are vectorized over all dimensions.
    """

    if torch.is_tensor(X):
        xp = torch
        xp_is_torch = True
    else:
        xp = X.__array_namespace__()  # Array API namespace of the input
        xp_is_torch = False

    if X.shape[-2] < 2 or Y.shape[-2] < 2:
        raise ValueError("Each distribution must have at least 2 samples")

    if distance_fn is None:
        if xp_is_torch:

            def distance_fn(x, y):
                # Pairwise Euclidean distances via broadcasting:
                # (*B, N1, 1, D) - (*B, 1, N2, D) -> (*B, N1, N2)
                diff = xp.unsqueeze(x, -2) - xp.unsqueeze(y, -3)
                return xp.sqrt(xp.sum(diff**2, dim=-1))

        elif hasattr(xp, "expand_dims"):

            def distance_fn(x, y):
                # Same broadcasting pattern using array API operations
                diff = xp.expand_dims(x, -2) - xp.expand_dims(y, -3)
                return xp.sqrt(xp.sum(diff**2, axis=-1))

        else:
            raise ValueError("Array namespace does not conform to expected API")

    # Compute pairwise distance matrices
    D_XX = distance_fn(X, X)
    D_YY = distance_fn(Y, Y)
    D_XY = distance_fn(X, Y)

    batch_shape = D_XX.shape[:-2]
    if kernel_width is None:
        # Median heuristic: pool all pairwise distances per batch element,
        # flattening only the distance matrices and preserving batch dims
        all_distances = xp.concat(
            [
                xp.reshape(D_XX, (*batch_shape, -1)),
                xp.reshape(D_YY, (*batch_shape, -1)),
                xp.reshape(D_XY, (*batch_shape, -1)),
            ],
            axis=-1,
        )

        if xp_is_torch:
            kernel_width = xp.median(all_distances, dim=-1).values
        else:
            kernel_width = xp.median(all_distances, axis=-1)

        # Add singleton dimensions so the bandwidth broadcasts against
        # the (*B, N, N) kernel matrices
        kernel_width = xp.reshape(kernel_width, (*batch_shape, 1, 1))

    # Apply the RBF kernel: k(x, y) = exp(-||x - y||^2 / (2 * gamma^2))
    sq_kernel_width = kernel_width**2
    K_XX = xp.exp(-0.5 * D_XX**2 / sq_kernel_width)
    K_YY = xp.exp(-0.5 * D_YY**2 / sq_kernel_width)
    K_XY = xp.exp(-0.5 * D_XY**2 / sq_kernel_width)

    m = X.shape[-2]
    n = Y.shape[-2]

    # Unbiased MMD^2 estimate (Gretton et al., 2012): remove the diagonal
    # (i == j) kernel terms via the trace, then
    #   MMD^2 = sum_{i != j} k(x_i, x_j) / (m(m-1))
    #         + sum_{i != j} k(y_i, y_j) / (n(n-1))
    #         - 2 * mean k(x_i, y_j)
    if xp_is_torch:

        def batched_trace(x):
            """Sum the diagonal of the last two dims, preserving batch dims."""
            return torch.diagonal(x, dim1=-2, dim2=-1).sum(-1)

        mmd_squared = (
            (K_XX.sum((-1, -2)) - batched_trace(K_XX)) / (m * (m - 1))
            + (K_YY.sum((-1, -2)) - batched_trace(K_YY)) / (n * (n - 1))
            - 2 * K_XY.mean((-1, -2))
        )
    else:
        mmd_squared = (
            (xp.sum(K_XX, axis=(-1, -2)) - xp.trace(K_XX, axis1=-1, axis2=-2))
            / (m * (m - 1))
            + (xp.sum(K_YY, axis=(-1, -2)) - xp.trace(K_YY, axis1=-1, axis2=-2))
            / (n * (n - 1))
            - 2 * xp.mean(K_XY, axis=(-1, -2))
        )

    # Clamp tiny negative values from numerical noise before the square root
    if xp_is_torch:
        return mmd_squared.clamp_min(0.0).sqrt()

    return xp.sqrt(xp.maximum(mmd_squared, 0.0))
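
For orientation, here is a minimal usage sketch of the function added above; the sample sizes, seeds, and shapes are illustrative, not taken from this PR:

import numpy
import torch

from beignet import maximum_mean_discrepancy

rng = numpy.random.default_rng(0)
X = rng.normal(0.0, 1.0, (128, 2))  # (N1, D) samples from P
Y = rng.normal(3.0, 1.0, (256, 2))  # (N2, D) samples from Q

# NumPy input, bandwidth chosen per batch by the median heuristic
mmd = maximum_mean_discrepancy(X, Y)
print(float(mmd))  # clearly positive: the two means differ

# Torch input with an explicit bandwidth and batch dims: (B, N, D) -> (B,)
X_t = torch.randn(4, 64, 2)
Y_t = torch.randn(4, 32, 2) + 3.0
print(maximum_mean_discrepancy(X_t, Y_t, kernel_width=1.0).shape)  # torch.Size([4])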
163 changes: 163 additions & 0 deletions tests/beignet/test__maximum_mean_discrepancy.py
@@ -0,0 +1,163 @@
from typing import NamedTuple, Union

import numpy
import pytest
import torch
from beignet import maximum_mean_discrepancy
from numpy.testing import assert_allclose

ArrayType = Union[numpy.ndarray, torch.Tensor]


class TestData(NamedTuple):
    """Container for test arrays to ensure consistent initialization."""

    X: ArrayType
    Y: ArrayType
    X_same: ArrayType
    X_large: ArrayType
    Y_small: ArrayType


@pytest.fixture(scope="module")
def numpy_arrays() -> TestData:
"""Generate test arrays once per module using vectorized operations."""
rng = numpy.random.default_rng(42)
# Preallocate all arrays in one block
X = rng.normal(0, 1, (1024, 2))
Y = rng.normal(5, 1, (1024, 2)) # Different mean for clear separation
X_large = rng.normal(0, 1, (150, 2))
Y_small = rng.normal(0, 1, (50, 2))

return TestData(
X=X,
Y=Y,
X_same=X.copy(), # Explicit copy for identity tests
X_large=X_large,
Y_small=Y_small,
)


@pytest.fixture(scope="module")
def torch_arrays(numpy_arrays: TestData) -> TestData:
"""Convert numpy arrays to torch tensors, ensuring numpy is initialized first."""
return TestData(
X=torch.tensor(numpy_arrays.X),
Y=torch.tensor(numpy_arrays.Y),
X_same=torch.tensor(numpy_arrays.X_same),
X_large=torch.tensor(numpy_arrays.X_large),
Y_small=torch.tensor(numpy_arrays.Y_small),
)


def manhattan_distance(x: ArrayType, y: ArrayType) -> ArrayType:
    """Vectorized Manhattan distance using array API operations."""
    xp = x.__array_namespace__()
    diff = xp.subtract(xp.expand_dims(x, 1), xp.expand_dims(y, 0))
    return xp.sum(xp.abs(diff), axis=-1)


@pytest.mark.parametrize("arrays", ["numpy_arrays", "torch_arrays"])
def test_mmd_basic(request, arrays):
    """Test core MMD functionality with guaranteed initialization order."""
    data = request.getfixturevalue(arrays)

    mmd = maximum_mean_discrepancy(data.X, data.Y)
    assert mmd > 0, "MMD should be positive for different distributions"

    mmd_same = maximum_mean_discrepancy(data.X, data.X_same)
    assert mmd_same < 0.1, "MMD should be near 0 for identical distributions"

    # MMD is symmetric in its arguments
    mmd_xy = maximum_mean_discrepancy(data.X, data.Y)
    mmd_yx = maximum_mean_discrepancy(data.Y, data.X)

    if torch.is_tensor(data.X):
        assert torch.allclose(mmd_xy, mmd_yx, rtol=1e-5)
    else:
        assert_allclose(mmd_xy, mmd_yx, rtol=1e-5)


@pytest.mark.parametrize("arrays", ["numpy_arrays", "torch_arrays"])
def test_mmd_validation(request, arrays):
    """Test input validation with single points."""
    data = request.getfixturevalue(arrays)

    with pytest.raises(ValueError):
        _ = maximum_mean_discrepancy(data.X[:1], data.Y[:1])

    mmd = maximum_mean_discrepancy(data.X_large, data.Y_small)
    assert numpy.isfinite(float(mmd))


@pytest.mark.parametrize("arrays", ["numpy_arrays", "torch_arrays"])
def test_mmd_broadcasting(request, arrays):
    """Test MMD with batched inputs."""
    data = request.getfixturevalue(arrays)

    # Create batched data of shape (2, 3, N, D)
    if torch.is_tensor(data.X):
        X_batch = data.X[None, None, :, :].repeat(2, 3, 1, 1)
        Y_batch = data.Y[None, None, :, :].repeat(2, 3, 1, 1)
    else:
        X_batch = numpy.ones((2, 3) + data.X.shape) * data.X
        Y_batch = numpy.ones((2, 3) + data.Y.shape) * data.Y

    mmd = maximum_mean_discrepancy(X_batch, Y_batch)
    assert mmd.shape == (2, 3), f"Expected shape (2, 3), got {mmd.shape}"

    # Check each batch element independently matches the unbatched computation
    for i in range(2):
        for j in range(3):
            single_mmd = maximum_mean_discrepancy(X_batch[i, j], Y_batch[i, j])

            if torch.is_tensor(data.X):
                assert torch.allclose(mmd[i, j], single_mmd)
            else:
                assert numpy.allclose(mmd[i, j], single_mmd)


def test_mmd_hamming():
    """Test MMD with Hamming distance on string arrays."""
    # Create simple arrays of equal-length strings
    rng = numpy.random.default_rng(42)
    n_samples = 3

    # Generate random DNA sequences for Hamming comparison; Y draws from a
    # different alphabet ("R" in place of "T") so the two string
    # distributions differ
    X = numpy.array(
        ["".join(rng.choice(["A", "T", "G", "C"], 32)) for _ in range(n_samples)]
    ).reshape(-1, 1)
    Y = numpy.array(
        ["".join(rng.choice(["A", "R", "G", "C"], 32)) for _ in range(n_samples * 2)]
    ).reshape(-1, 1)

    def hamming_distance(x: numpy.ndarray, y: numpy.ndarray) -> numpy.ndarray:
        """
        Compute pairwise Hamming distances between arrays of strings.

        Args:
            x: First array of strings, shape (n, 1)
            y: Second array of strings, shape (m, 1)

        Returns:
            Distance matrix of shape (n, m)
        """
        # Flatten the (n, 1) inputs to 1-D arrays of strings
        x_flat = x.reshape(-1)
        y_flat = y.reshape(-1)

        n, m = len(x_flat), len(y_flat)
        distances = numpy.zeros((n, m), dtype=numpy.float64)

        # Count character mismatches for every pair of strings
        for i in range(n):
            for j in range(m):
                distances[i, j] = sum(
                    1.0 for a, b in zip(x_flat[i], y_flat[j], strict=False) if a != b
                )

        return distances

    mmd = maximum_mean_discrepancy(X, Y, distance_fn=hamming_distance)
    assert mmd > 0, "MMD should be positive for different string distributions"
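
As a further illustration of the distance_fn hook exercised by manhattan_distance and the Hamming test above, a hypothetical cosine-distance helper (a sketch, not part of this PR; like the helpers above, it handles only unbatched (N, D) float inputs with nonzero rows) might look like:

def cosine_distance(x: numpy.ndarray, y: numpy.ndarray) -> numpy.ndarray:
    """Pairwise cosine distances of shape (n, m)."""
    x_unit = x / numpy.linalg.norm(x, axis=-1, keepdims=True)
    y_unit = y / numpy.linalg.norm(y, axis=-1, keepdims=True)
    return 1.0 - x_unit @ y_unit.T

# Hypothetical usage with the module fixtures:
# mmd = maximum_mean_discrepancy(data.X, data.Y, distance_fn=cosine_distance)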