Add MaximumMeanDiscrepancy metric (#3243)
* add MaximumMeanDiscrepancy metric

* fix URL

* update formula

* modify test for MMD

* set default var value for np_mmd

* accumulate mmd2

* accumulate sum of xx, yy, and xy

* add reference paper to docstring

* fix accumulator variables

* fix test_accumulator_device
kzkadc authored May 7, 2024
1 parent 3f5febf commit 8c1912a
Showing 4 changed files with 317 additions and 0 deletions.
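For orientation before the file-by-file diff: the new metric estimates how far apart two batches of feature vectors are as distributions. A minimal standalone sketch of the intended usage (the random tensors are illustrative placeholders, not part of the commit):

    import torch

    from ignite.metrics import MaximumMeanDiscrepancy

    # two batches of feature vectors: x drawn from P, y from a mean-shifted Q
    x = torch.randn(64, 16)
    y = torch.randn(64, 16) + 1.0

    metric = MaximumMeanDiscrepancy(var=1.0)
    metric.update((x, y))
    print(metric.compute())  # near 0 when the distributions match, larger otherwise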
1 change: 1 addition & 0 deletions docs/source/metrics.rst
@@ -355,6 +355,7 @@ Complete list of metrics
    Entropy
    KLDivergence
    JSDivergence
    MaximumMeanDiscrepancy
    AveragePrecision
    CohenKappa
    GpuInfo
2 changes: 2 additions & 0 deletions ignite/metrics/__init__.py
@@ -17,6 +17,7 @@
from ignite.metrics.js_divergence import JSDivergence
from ignite.metrics.kl_divergence import KLDivergence
from ignite.metrics.loss import Loss
from ignite.metrics.maximum_mean_discrepancy import MaximumMeanDiscrepancy
from ignite.metrics.mean_absolute_error import MeanAbsoluteError
from ignite.metrics.mean_pairwise_distance import MeanPairwiseDistance
from ignite.metrics.mean_squared_error import MeanSquaredError
@@ -61,6 +62,7 @@
"JaccardIndex",
"JSDivergence",
"KLDivergence",
"MaximumMeanDiscrepancy",
"MultiLabelConfusionMatrix",
"MutualInformation",
"Precision",
138 changes: 138 additions & 0 deletions ignite/metrics/maximum_mean_discrepancy.py
@@ -0,0 +1,138 @@
from typing import Callable, Sequence

import torch

from ignite.exceptions import NotComputableError
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce

__all__ = ["MaximumMeanDiscrepancy"]


class MaximumMeanDiscrepancy(Metric):
r"""Calculates the mean of `maximum mean discrepancy (MMD)
<https://www.onurtunali.com/ml/2019/03/08/maximum-mean-discrepancy-in-machine-learning.html>`_.
.. math::
\begin{align*}
\text{MMD}^2 (P,Q) &= \underset{\| f \| \leq 1}{\text{sup}} | \mathbb{E}_{X\sim P}[f(X)]
- \mathbb{E}_{Y\sim Q}[f(Y)] |^2 \\
&\approx \frac{1}{B(B-1)} \sum_{i=1}^B \sum_{\substack{j=1 \\ j\neq i}}^B k(\mathbf{x}_i,\mathbf{x}_j)
-\frac{2}{B^2}\sum_{i=1}^B \sum_{j=1}^B k(\mathbf{x}_i,\mathbf{y}_j)
+ \frac{1}{B(B-1)} \sum_{i=1}^B \sum_{\substack{j=1 \\ j\neq i}}^B k(\mathbf{y}_i,\mathbf{y}_j)
\end{align*}
where :math:`B` is the batch size, and :math:`\mathbf{x}_i` and :math:`\mathbf{y}_j` are
feature vectors sampled from :math:`P` and :math:`Q`, respectively.
:math:`k(\mathbf{x},\mathbf{y})=\exp(-\| \mathbf{x}-\mathbf{y} \|^2/ 2\sigma^2)` is the Gaussian RBF kernel.
This metric computes the MMD for each batch and takes the average.
More details can be found in `Gretton et al. 2012`__.
__ https://jmlr.csail.mit.edu/papers/v13/gretton12a.html
- ``update`` must receive output of the form ``(x, y)``.
- ``x`` and ``y`` are expected to be in the same shape :math:`(B, \ldots)`.
Args:
var: the bandwidth :math:`\sigma^2` of the kernel. Default: 1.0
output_transform: a callable that is used to transform the
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
By default, this metric requires the output as ``(x, y)``.
device: specifies which device updates are accumulated on. Setting the
metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
non-blocking. By default, CPU.
Examples:
To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
The output of the engine's ``process_function`` needs to be in the format of
``(x, y)``. If not, ``output_tranform`` can be added
to the metric to transform the output into the form expected by the metric.
For more information on how metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`.
.. include:: defaults.rst
:start-after: :orphan:
.. testcode::
metric = MaximumMeanDiscrepancy()
metric.attach(default_evaluator, "mmd")
x = torch.tensor([[-0.80324818, -0.95768364, -0.03807209],
[-0.11059691, -0.38230813, -0.4111988],
[-0.8864329, -0.02890403, -0.60119252],
[-0.68732452, -0.12854739, -0.72095073],
[-0.62604613, -0.52368328, -0.24112842]])
y = torch.tensor([[0.0686768, 0.80502737, 0.53321717],
[0.83849465, 0.59099726, 0.76385441],
[0.68688272, 0.56833803, 0.98100778],
[0.55267761, 0.13084654, 0.45382906],
[0.0754253, 0.70317304, 0.4756805]])
state = default_evaluator.run([[x, y]])
print(state.metrics["mmd"])
.. testoutput::
1.0726975202560425
"""

    _state_dict_all_req_keys = ("_xx_sum", "_yy_sum", "_xy_sum", "_num_batches")

    def __init__(
        self, var: float = 1.0, output_transform: Callable = lambda x: x, device: torch.device = torch.device("cpu")
    ):
        self.var = var
        super().__init__(output_transform, device)

    @reinit__is_reduced
    def reset(self) -> None:
        self._xx_sum = torch.tensor(0.0, device=self._device)
        self._yy_sum = torch.tensor(0.0, device=self._device)
        self._xy_sum = torch.tensor(0.0, device=self._device)
        self._num_batches = 0

    @reinit__is_reduced
    def update(self, output: Sequence[torch.Tensor]) -> None:
        x, y = output[0].detach(), output[1].detach()
        if x.shape != y.shape:
            raise ValueError(f"x and y must have the same shape, got {x.shape} != {y.shape}.")

        if x.ndim >= 3:
            x = x.flatten(start_dim=1)
            y = y.flatten(start_dim=1)
        elif x.ndim == 1:
            raise ValueError(f"x must have shape (B, ...), got {x.shape}.")

        xx, yy, zz = torch.mm(x, x.t()), torch.mm(y, y.t()), torch.mm(x, y.t())
        rx = xx.diag().unsqueeze(0).expand_as(xx)
        ry = yy.diag().unsqueeze(0).expand_as(yy)

        # squared pairwise distances via ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * a.b
        dxx = rx.t() + rx - 2.0 * xx
        dyy = ry.t() + ry - 2.0 * yy
        dxy = rx.t() + ry - 2.0 * zz

        v = self.var
        XX = torch.exp(-0.5 * dxx / v)
        YY = torch.exp(-0.5 * dyy / v)
        XY = torch.exp(-0.5 * dxy / v)

        # unbiased estimate: drop the n diagonal terms (k(a, a) = 1) from XX and YY
        n = x.shape[0]
        XX = (XX.sum() - n) / (n * (n - 1))
        YY = (YY.sum() - n) / (n * (n - 1))
        XY = XY.sum() / (n * n)

        self._xx_sum += XX.to(self._device)
        self._yy_sum += YY.to(self._device)
        self._xy_sum += XY.to(self._device)

        self._num_batches += 1

    @sync_all_reduce("_xx_sum", "_yy_sum", "_xy_sum", "_num_batches")
    def compute(self) -> float:
        if self._num_batches == 0:
            raise NotComputableError("MaximumMeanDiscrepancy must have at least one batch before it can be computed.")

        # average the per-batch squared MMD estimates (clamped at zero, since the
        # unbiased estimator can be slightly negative) and return the square root
        mmd2 = (self._xx_sum + self._yy_sum - 2.0 * self._xy_sum).clamp(min=0.0) / self._num_batches
        return mmd2.sqrt().item()
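
Worth noting before the tests: ``update`` never loops over sample pairs; it expands ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b from Gram matrices. A small sketch verifying that identity against ``torch.cdist`` (the ``cdist`` cross-check is illustrative, not part of the commit):

    import torch

    x = torch.randn(5, 3)
    xx = x @ x.t()                             # Gram matrix: xx[i, j] = x_i . x_j
    rx = xx.diag().unsqueeze(0).expand_as(xx)  # rx[i, j] = ||x_j||^2
    dxx = rx.t() + rx - 2.0 * xx               # dxx[i, j] = ||x_i - x_j||^2
    assert torch.allclose(dxx, torch.cdist(x, x) ** 2, atol=1e-5)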
176 changes: 176 additions & 0 deletions tests/ignite/metrics/test_maximum_mean_discrepancy.py
@@ -0,0 +1,176 @@
from typing import Tuple

import numpy as np
import pytest
import torch
from torch import Tensor

import ignite.distributed as idist
from ignite.engine import Engine
from ignite.exceptions import NotComputableError
from ignite.metrics import MaximumMeanDiscrepancy


def np_mmd2(x: np.ndarray, y: np.ndarray, var: float = 1.0):
    n = x.shape[0]
    x = x.reshape(n, -1)
    y = y.reshape(n, -1)

    a = np.arange(n)
    ii, jj = np.meshgrid(a, a, indexing="ij")
    XX = np.exp(-np.square(x[ii] - x[jj]).sum(axis=2) / (var * 2))
    XX = (np.sum(XX) - n) / (n * (n - 1))

    XY = np.exp(-np.square(x[ii] - y[jj]).sum(axis=2) / (var * 2))
    XY = np.sum(XY) / (n * n)

    YY = np.exp(-np.square(y[ii] - y[jj]).sum(axis=2) / (var * 2))
    YY = (np.sum(YY) - n) / (n * (n - 1))

    mmd2 = np.clip(XX + YY - XY * 2, 0.0, None)
    return mmd2
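

# Illustrative sanity check, not part of the committed test suite: the vectorized
# np_mmd2 above should agree with a direct double loop over the Gaussian kernel.
def np_mmd2_loop(x: np.ndarray, y: np.ndarray, var: float = 1.0) -> float:
    n = x.shape[0]

    def k(a: np.ndarray, b: np.ndarray) -> float:
        return float(np.exp(-np.square(a - b).sum() / (2 * var)))

    xx = sum(k(x[i], x[j]) for i in range(n) for j in range(n) if i != j) / (n * (n - 1))
    yy = sum(k(y[i], y[j]) for i in range(n) for j in range(n) if i != j) / (n * (n - 1))
    xy = sum(k(x[i], y[j]) for i in range(n) for j in range(n)) / (n * n)
    return max(xx + yy - 2.0 * xy, 0.0)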


def test_zero_sample():
    mmd = MaximumMeanDiscrepancy()
    with pytest.raises(
        NotComputableError, match=r"MaximumMeanDiscrepancy must have at least one batch before it can be computed"
    ):
        mmd.compute()


def test_shape_mismatch():
    mmd = MaximumMeanDiscrepancy()
    x = torch.tensor([[2.0, 3.0], [-2.0, 1.0]], dtype=torch.float)
    y = torch.tensor([[-2.0, 1.0]], dtype=torch.float)
    with pytest.raises(ValueError, match=r"x and y must have the same shape, got"):
        mmd.update((x, y))


def test_invalid_shape():
    mmd = MaximumMeanDiscrepancy()
    x = torch.tensor([2.0, 3.0], dtype=torch.float)
    y = torch.tensor([4.0, 5.0], dtype=torch.float)
    with pytest.raises(ValueError, match=r"x must have shape \(B, ...\), got"):
        mmd.update((x, y))


@pytest.fixture(params=list(range(6)))
def test_case(request):
    return [
        (torch.randn((100, 10)), torch.rand((100, 10)), 10 ** np.random.uniform(-1.0, 0.0), 1),
        (torch.rand((100, 500)), torch.randn((100, 500)), 10 ** np.random.uniform(-1.0, 0.0), 1),
        # updated batches
        (torch.normal(0.0, 5.0, size=(100, 10)), torch.rand((100, 10)), 10 ** np.random.uniform(-1.0, 0.0), 16),
        (torch.normal(5.0, 3.0, size=(100, 200)), torch.rand((100, 200)), 10 ** np.random.uniform(-1.0, 0.0), 16),
        # image segmentation
        (torch.randn((100, 5, 32, 32)), torch.rand((100, 5, 32, 32)), 10 ** np.random.uniform(-1.0, 0.0), 32),
        (torch.rand((100, 5, 224, 224)), torch.randn((100, 5, 224, 224)), 10 ** np.random.uniform(-1.0, 0.0), 32),
    ][request.param]


@pytest.mark.parametrize("n_times", range(5))
def test_compute(n_times, test_case: Tuple[Tensor, Tensor, float, int]):
    x, y, var, batch_size = test_case

    mmd = MaximumMeanDiscrepancy(var=var)
    mmd.reset()

    if batch_size > 1:
        np_mmd2_sum = 0.0

        n_iters = y.shape[0] // batch_size + 1
        for i in range(n_iters):
            idx = i * batch_size
            x_batch, y_batch = x[idx : idx + batch_size], y[idx : idx + batch_size]
            mmd.update((x_batch, y_batch))

            np_mmd2_sum += np_mmd2(x_batch.cpu().numpy(), y_batch.cpu().numpy(), var)

        np_res = np.sqrt(np_mmd2_sum / n_iters)
    else:
        mmd.update((x, y))
        np_res = np.sqrt(np_mmd2(x.cpu().numpy(), y.cpu().numpy(), var))

    res = mmd.compute()

    assert isinstance(res, float)
    assert pytest.approx(np_res, abs=1e-4) == res


def test_accumulator_detached():
    mmd = MaximumMeanDiscrepancy()

    x = torch.tensor([[2.0, 3.0], [-2.0, 1.0]], dtype=torch.float)
    y = torch.tensor([[-2.0, 1.0], [2.0, 3.0]], dtype=torch.float)
    mmd.update((x, y))

    assert not any(acc.requires_grad for acc in (mmd._xx_sum, mmd._yy_sum, mmd._xy_sum))


@pytest.mark.usefixtures("distributed")
class TestDistributed:
    def test_integration(self):
        tol = 1e-4
        n_iters = 100
        batch_size = 10
        n_dims = 100

        rank = idist.get_rank()
        torch.manual_seed(12 + rank)

        device = idist.device()
        metric_devices = [torch.device("cpu")]
        if device.type != "xla":
            metric_devices.append(device)

        for metric_device in metric_devices:
            y = torch.randn((n_iters * batch_size, n_dims)).float().to(device)
            x = torch.normal(2.0, 3.0, size=(n_iters * batch_size, n_dims)).float().to(device)

            def data_loader(i):
                return x[i * batch_size : (i + 1) * batch_size], y[i * batch_size : (i + 1) * batch_size]

            engine = Engine(lambda e, i: data_loader(i))

            m = MaximumMeanDiscrepancy(device=metric_device)
            m.attach(engine, "mmd")

            data = list(range(n_iters))
            engine.run(data=data, max_epochs=1)

            x = idist.all_gather(x)
            y = idist.all_gather(y)

            assert "mmd" in engine.state.metrics
            res = engine.state.metrics["mmd"]

            # compute numpy mmd
            true_res = 0.0
            for i in range(n_iters):
                x_batch, y_batch = data_loader(i)
                x_np = x_batch.cpu().numpy()
                y_np = y_batch.cpu().numpy()
                true_res += np_mmd2(x_np, y_np)

            true_res = np.sqrt(true_res / n_iters)
            assert pytest.approx(true_res, abs=tol) == res

    def test_accumulator_device(self):
        device = idist.device()
        metric_devices = [torch.device("cpu")]
        if device.type != "xla":
            metric_devices.append(device)
        for metric_device in metric_devices:
            mmd = MaximumMeanDiscrepancy(device=metric_device)

            devices = (mmd._device, mmd._xx_sum.device, mmd._yy_sum.device, mmd._xy_sum.device)
            for dev in devices:
                assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"

            x = torch.tensor([[2.0, 3.0], [-2.0, 1.0]]).float()
            y = torch.ones(2, 2).float()
            mmd.update((x, y))

            devices = (mmd._device, mmd._xx_sum.device, mmd._yy_sum.device, mmd._xy_sum.device)
            for dev in devices:
                assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"
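
A final note on semantics (an illustrative sketch, not from the diff): because ``compute`` averages the squared MMD per batch, the reported value depends on how the same data is batched:

    import torch

    from ignite.metrics import MaximumMeanDiscrepancy

    x, y = torch.randn(100, 8), torch.randn(100, 8)

    metric = MaximumMeanDiscrepancy()
    metric.update((x, y))  # a single batch of 100
    one_shot = metric.compute()

    metric.reset()
    for i in range(0, 100, 20):  # the same data as five batches of 20
        metric.update((x[i : i + 20], y[i : i + 20]))
    batched = metric.compute()  # generally differs from one_shot

    print(one_shot, batched)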
