-
-
Notifications
You must be signed in to change notification settings - Fork 635
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add MaximumMeanDiscrepancy metric (#3243)
* add MaximumMeanDiscrepancy metric * fix URL * update formula * modify test for MMD * set default var value for np_mmd * accumulate mmd2 * accumulate sum of xx, yy, and xy * add reference paper to docstring * fix accumulator variables * fix test_accumulator_device
- Loading branch information
Showing
4 changed files
with
317 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
from typing import Callable, Sequence | ||
|
||
import torch | ||
|
||
from ignite.exceptions import NotComputableError | ||
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce | ||
|
||
__all__ = ["MaximumMeanDiscrepancy"] | ||
|
||
|
||
class MaximumMeanDiscrepancy(Metric): | ||
r"""Calculates the mean of `maximum mean discrepancy (MMD) | ||
<https://www.onurtunali.com/ml/2019/03/08/maximum-mean-discrepancy-in-machine-learning.html>`_. | ||
.. math:: | ||
\begin{align*} | ||
\text{MMD}^2 (P,Q) &= \underset{\| f \| \leq 1}{\text{sup}} | \mathbb{E}_{X\sim P}[f(X)] | ||
- \mathbb{E}_{Y\sim Q}[f(Y)] |^2 \\ | ||
&\approx \frac{1}{B(B-1)} \sum_{i=1}^B \sum_{\substack{j=1 \\ j\neq i}}^B k(\mathbf{x}_i,\mathbf{x}_j) | ||
-\frac{2}{B^2}\sum_{i=1}^B \sum_{j=1}^B k(\mathbf{x}_i,\mathbf{y}_j) | ||
+ \frac{1}{B(B-1)} \sum_{i=1}^B \sum_{\substack{j=1 \\ j\neq i}}^B k(\mathbf{y}_i,\mathbf{y}_j) | ||
\end{align*} | ||
where :math:`B` is the batch size, and :math:`\mathbf{x}_i` and :math:`\mathbf{y}_j` are | ||
feature vectors sampled from :math:`P` and :math:`Q`, respectively. | ||
:math:`k(\mathbf{x},\mathbf{y})=\exp(-\| \mathbf{x}-\mathbf{y} \|^2/ 2\sigma^2)` is the Gaussian RBF kernel. | ||
This metric computes the MMD for each batch and takes the average. | ||
More details can be found in `Gretton et al. 2012`__. | ||
__ https://jmlr.csail.mit.edu/papers/v13/gretton12a.html | ||
- ``update`` must receive output of the form ``(x, y)``. | ||
- ``x`` and ``y`` are expected to be in the same shape :math:`(B, \ldots)`. | ||
Args: | ||
var: the bandwidth :math:`\sigma^2` of the kernel. Default: 1.0 | ||
output_transform: a callable that is used to transform the | ||
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the | ||
form expected by the metric. This can be useful if, for example, you have a multi-output model and | ||
you want to compute the metric with respect to one of the outputs. | ||
By default, this metric requires the output as ``(x, y)``. | ||
device: specifies which device updates are accumulated on. Setting the | ||
metric's device to be the same as your ``update`` arguments ensures the ``update`` method is | ||
non-blocking. By default, CPU. | ||
Examples: | ||
To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine. | ||
The output of the engine's ``process_function`` needs to be in the format of | ||
``(x, y)``. If not, ``output_tranform`` can be added | ||
to the metric to transform the output into the form expected by the metric. | ||
For more information on how metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`. | ||
.. include:: defaults.rst | ||
:start-after: :orphan: | ||
.. testcode:: | ||
metric = MaximumMeanDiscrepancy() | ||
metric.attach(default_evaluator, "mmd") | ||
x = torch.tensor([[-0.80324818, -0.95768364, -0.03807209], | ||
[-0.11059691, -0.38230813, -0.4111988], | ||
[-0.8864329, -0.02890403, -0.60119252], | ||
[-0.68732452, -0.12854739, -0.72095073], | ||
[-0.62604613, -0.52368328, -0.24112842]]) | ||
y = torch.tensor([[0.0686768, 0.80502737, 0.53321717], | ||
[0.83849465, 0.59099726, 0.76385441], | ||
[0.68688272, 0.56833803, 0.98100778], | ||
[0.55267761, 0.13084654, 0.45382906], | ||
[0.0754253, 0.70317304, 0.4756805]]) | ||
state = default_evaluator.run([[x, y]]) | ||
print(state.metrics["mmd"]) | ||
.. testoutput:: | ||
1.0726975202560425 | ||
""" | ||
|
||
_state_dict_all_req_keys = ("_xx_sum", "_yy_sum", "_xy_sum", "_num_batches") | ||
|
||
def __init__( | ||
self, var: float = 1.0, output_transform: Callable = lambda x: x, device: torch.device = torch.device("cpu") | ||
): | ||
self.var = var | ||
super().__init__(output_transform, device) | ||
|
||
@reinit__is_reduced | ||
def reset(self) -> None: | ||
self._xx_sum = torch.tensor(0.0, device=self._device) | ||
self._yy_sum = torch.tensor(0.0, device=self._device) | ||
self._xy_sum = torch.tensor(0.0, device=self._device) | ||
self._num_batches = 0 | ||
|
||
@reinit__is_reduced | ||
def update(self, output: Sequence[torch.Tensor]) -> None: | ||
x, y = output[0].detach(), output[1].detach() | ||
if x.shape != y.shape: | ||
raise ValueError(f"x and y must be in the same shape, got {x.shape} != {y.shape}.") | ||
|
||
if x.ndim >= 3: | ||
x = x.flatten(start_dim=1) | ||
y = y.flatten(start_dim=1) | ||
elif x.ndim == 1: | ||
raise ValueError(f"x must be in the shape of (B, ...), got {x.shape}.") | ||
|
||
xx, yy, zz = torch.mm(x, x.t()), torch.mm(y, y.t()), torch.mm(x, y.t()) | ||
rx = xx.diag().unsqueeze(0).expand_as(xx) | ||
ry = yy.diag().unsqueeze(0).expand_as(yy) | ||
|
||
dxx = rx.t() + rx - 2.0 * xx | ||
dyy = ry.t() + ry - 2.0 * yy | ||
dxy = rx.t() + ry - 2.0 * zz | ||
|
||
v = self.var | ||
XX = torch.exp(-0.5 * dxx / v) | ||
YY = torch.exp(-0.5 * dyy / v) | ||
XY = torch.exp(-0.5 * dxy / v) | ||
|
||
# unbiased | ||
n = x.shape[0] | ||
XX = (XX.sum() - n) / (n * (n - 1)) | ||
YY = (YY.sum() - n) / (n * (n - 1)) | ||
XY = XY.sum() / (n * n) | ||
|
||
self._xx_sum += XX.to(self._device) | ||
self._yy_sum += YY.to(self._device) | ||
self._xy_sum += XY.to(self._device) | ||
|
||
self._num_batches += 1 | ||
|
||
@sync_all_reduce("_xx_sum", "_yy_sum", "_xy_sum", "_num_batches") | ||
def compute(self) -> float: | ||
if self._num_batches == 0: | ||
raise NotComputableError("MaximumMeanDiscrepacy must have at least one batch before it can be computed.") | ||
mmd2 = (self._xx_sum + self._yy_sum - 2.0 * self._xy_sum).clamp(min=0.0) / self._num_batches | ||
return mmd2.sqrt().item() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,176 @@ | ||
from typing import Tuple | ||
|
||
import numpy as np | ||
import pytest | ||
import torch | ||
from torch import Tensor | ||
|
||
import ignite.distributed as idist | ||
from ignite.engine import Engine | ||
from ignite.exceptions import NotComputableError | ||
from ignite.metrics import MaximumMeanDiscrepancy | ||
|
||
|
||
def np_mmd2(x: np.ndarray, y: np.ndarray, var: float = 1.0): | ||
n = x.shape[0] | ||
x = x.reshape(n, -1) | ||
y = y.reshape(n, -1) | ||
|
||
a = np.arange(n) | ||
ii, jj = np.meshgrid(a, a, indexing="ij") | ||
XX = np.exp(-np.square(x[ii] - x[jj]).sum(axis=2) / (var * 2)) | ||
XX = (np.sum(XX) - n) / (n * (n - 1)) | ||
|
||
XY = np.exp(-np.square(x[ii] - y[jj]).sum(axis=2) / (var * 2)) | ||
XY = np.sum(XY) / (n * n) | ||
|
||
YY = np.exp(-np.square(y[ii] - y[jj]).sum(axis=2) / (var * 2)) | ||
YY = (np.sum(YY) - n) / (n * (n - 1)) | ||
|
||
mmd2 = np.clip(XX + YY - XY * 2, 0.0, None) | ||
return mmd2 | ||
|
||
|
||
def test_zero_sample(): | ||
mmd = MaximumMeanDiscrepancy() | ||
with pytest.raises( | ||
NotComputableError, match=r"MaximumMeanDiscrepacy must have at least one batch before it can be computed" | ||
): | ||
mmd.compute() | ||
|
||
|
||
def test_shape_mismatch(): | ||
mmd = MaximumMeanDiscrepancy() | ||
x = torch.tensor([[2.0, 3.0], [-2.0, 1.0]], dtype=torch.float) | ||
y = torch.tensor([[-2.0, 1.0]], dtype=torch.float) | ||
with pytest.raises(ValueError, match=r"x and y must be in the same shape, got"): | ||
mmd.update((x, y)) | ||
|
||
|
||
def test_invalid_shape(): | ||
mmd = MaximumMeanDiscrepancy() | ||
x = torch.tensor([2.0, 3.0], dtype=torch.float) | ||
y = torch.tensor([4.0, 5.0], dtype=torch.float) | ||
with pytest.raises(ValueError, match=r"x must be in the shape of \(B, ...\), got"): | ||
mmd.update((x, y)) | ||
|
||
|
||
@pytest.fixture(params=list(range(4))) | ||
def test_case(request): | ||
return [ | ||
(torch.randn((100, 10)), torch.rand((100, 10)), 10 ** np.random.uniform(-1.0, 0.0), 1), | ||
(torch.rand((100, 500)), torch.randn((100, 500)), 10 ** np.random.uniform(-1.0, 0.0), 1), | ||
# updated batches | ||
(torch.normal(0.0, 5.0, size=(100, 10)), torch.rand((100, 10)), 10 ** np.random.uniform(-1.0, 0.0), 16), | ||
(torch.normal(5.0, 3.0, size=(100, 200)), torch.rand((100, 200)), 10 ** np.random.uniform(-1.0, 0.0), 16), | ||
# image segmentation | ||
(torch.randn((100, 5, 32, 32)), torch.rand((100, 5, 32, 32)), 10 ** np.random.uniform(-1.0, 0.0), 32), | ||
(torch.rand((100, 5, 224, 224)), torch.randn((100, 5, 224, 224)), 10 ** np.random.uniform(-1.0, 0.0), 32), | ||
][request.param] | ||
|
||
|
||
@pytest.mark.parametrize("n_times", range(5)) | ||
def test_compute(n_times, test_case: Tuple[Tensor, Tensor, float, int]): | ||
x, y, var, batch_size = test_case | ||
|
||
mmd = MaximumMeanDiscrepancy(var=var) | ||
mmd.reset() | ||
|
||
if batch_size > 1: | ||
np_mmd2_sum = 0.0 | ||
n_iters = y.shape[0] // batch_size + 1 | ||
for i in range(n_iters): | ||
idx = i * batch_size | ||
x_batch, y_batch = x[idx : idx + batch_size], y[idx : idx + batch_size] | ||
mmd.update((x_batch, y_batch)) | ||
|
||
np_mmd2_sum += np_mmd2(x_batch.cpu().numpy(), y_batch.cpu().numpy(), var) | ||
|
||
np_res = np.sqrt(np_mmd2_sum / n_iters) | ||
else: | ||
mmd.update((x, y)) | ||
np_res = np.sqrt(np_mmd2(x.cpu().numpy(), y.cpu().numpy(), var)) | ||
|
||
res = mmd.compute() | ||
|
||
assert isinstance(res, float) | ||
assert pytest.approx(np_res, abs=1e-4) == res | ||
|
||
|
||
def test_accumulator_detached(): | ||
mmd = MaximumMeanDiscrepancy() | ||
|
||
x = torch.tensor([[2.0, 3.0], [-2.0, 1.0]], dtype=torch.float) | ||
y = torch.tensor([[-2.0, 1.0], [2.0, 3.0]], dtype=torch.float) | ||
mmd.update((x, y)) | ||
|
||
assert not any(acc.requires_grad for acc in (mmd._xx_sum, mmd._yy_sum, mmd._xy_sum)) | ||
|
||
|
||
@pytest.mark.usefixtures("distributed") | ||
class TestDistributed: | ||
def test_integration(self): | ||
tol = 1e-4 | ||
n_iters = 100 | ||
batch_size = 10 | ||
n_dims = 100 | ||
|
||
rank = idist.get_rank() | ||
torch.manual_seed(12 + rank) | ||
|
||
device = idist.device() | ||
metric_devices = [torch.device("cpu")] | ||
if device.type != "xla": | ||
metric_devices.append(device) | ||
|
||
for metric_device in metric_devices: | ||
y = torch.randn((n_iters * batch_size, n_dims)).float().to(device) | ||
x = torch.normal(2.0, 3.0, size=(n_iters * batch_size, n_dims)).float().to(device) | ||
|
||
def data_loader(i): | ||
return x[i * batch_size : (i + 1) * batch_size], y[i * batch_size : (i + 1) * batch_size] | ||
|
||
engine = Engine(lambda e, i: data_loader(i)) | ||
|
||
m = MaximumMeanDiscrepancy(device=metric_device) | ||
m.attach(engine, "mmd") | ||
|
||
data = list(range(n_iters)) | ||
engine.run(data=data, max_epochs=1) | ||
|
||
x = idist.all_gather(x) | ||
y = idist.all_gather(y) | ||
|
||
assert "mmd" in engine.state.metrics | ||
res = engine.state.metrics["mmd"] | ||
|
||
# compute numpy mmd | ||
true_res = 0.0 | ||
for i in range(n_iters): | ||
x_batch, y_batch = data_loader(i) | ||
x_np = x_batch.cpu().numpy() | ||
y_np = y_batch.cpu().numpy() | ||
true_res += np_mmd2(x_np, y_np) | ||
|
||
true_res = np.sqrt(true_res / n_iters) | ||
assert pytest.approx(true_res, abs=tol) == res | ||
|
||
def test_accumulator_device(self): | ||
device = idist.device() | ||
metric_devices = [torch.device("cpu")] | ||
if device.type != "xla": | ||
metric_devices.append(device) | ||
for metric_device in metric_devices: | ||
mmd = MaximumMeanDiscrepancy(device=metric_device) | ||
|
||
devices = (mmd._device, mmd._xx_sum.device, mmd._yy_sum.device, mmd._xy_sum.device) | ||
for dev in devices: | ||
assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}" | ||
|
||
x = torch.tensor([[2.0, 3.0], [-2.0, 1.0]]).float() | ||
y = torch.ones(2, 2).float() | ||
mmd.update((x, y)) | ||
|
||
devices = (mmd._device, mmd._xx_sum.device, mmd._yy_sum.device, mmd._xy_sum.device) | ||
for dev in devices: | ||
assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}" |