Skip to content

Commit

Permalink
accumulate sum of xx, yy, and xy
Browse files Browse the repository at this point in the history
  • Loading branch information
kzkadc committed May 7, 2024
1 parent 32ad1db commit cc5d555
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 8 deletions.
13 changes: 8 additions & 5 deletions ignite/metrics/maximum_mean_discrepancy.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,9 @@ def __init__(

@reinit__is_reduced
def reset(self) -> None:
    # Reinitialize all state accumulated across update() calls.
    # NOTE(review): this span is rendered diff residue — the `_sum_of_mmd2`
    # line below appears to be the pre-change accumulator that this commit
    # replaces with the three kernel-term sums; confirm against the final file.
    self._sum_of_mmd2 = torch.tensor(0.0, device=self._device)
    # Running sums of the XX, YY and XY kernel terms, kept on the metric's
    # device so they can be combined into MMD^2 once in compute().
    self._xx_sum = torch.tensor(0.0, device=self._device)
    self._yy_sum = torch.tensor(0.0, device=self._device)
    self._xy_sum = torch.tensor(0.0, device=self._device)
    # Number of batches folded in since the last reset.
    self._num_batches = 0

@reinit__is_reduced
Expand Down Expand Up @@ -118,14 +120,15 @@ def update(self, output: Sequence[torch.Tensor]) -> None:
YY = (YY.sum() - n) / (n * (n - 1))
XY = XY.sum() / (n * n)

# mmd cannot be negative
mmd2 = (XX - 2.0 * XY + YY).clamp(min=0.0)
self._xx_sum += XX.to(self._device)
self._yy_sum += YY.to(self._device)
self._xy_sum += XY.to(self._device)

self._sum_of_mmd2 += mmd2.to(self._device)
self._num_batches += 1

@sync_all_reduce("_xx_sum", "_yy_sum", "_xy_sum", "_num_batches")
def compute(self) -> float:
    """Return the accumulated maximum mean discrepancy (MMD).

    Averages the accumulated XX, YY and XY kernel sums over the number of
    batches, clamps the resulting squared-MMD estimate at zero (the finite
    sample / floating-point estimate can dip slightly negative), and returns
    its square root.

    Raises:
        NotComputableError: if no batch has been processed since the last reset.
    """
    # Fix: the decorator must reduce the new per-term accumulators
    # (_xx_sum/_yy_sum/_xy_sum), not the stale _sum_of_mmd2, otherwise
    # distributed runs would combine an attribute no longer updated.
    if self._num_batches == 0:
        raise NotComputableError("MaximumMeanDiscrepancy must have at least one batch before it can be computed.")
    # MMD^2 = E[k(x,x')] - 2 E[k(x,y)] + E[k(y,y')], clamped so sqrt is defined.
    mmd2 = (self._xx_sum + self._yy_sum - 2.0 * self._xy_sum).clamp(min=0.0) / self._num_batches
    return mmd2.sqrt().item()
8 changes: 5 additions & 3 deletions tests/ignite/metrics/test_maximum_mean_discrepancy.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def test_accumulator_detached():
y = torch.tensor([[-2.0, 1.0], [2.0, 3.0]], dtype=torch.float)
mmd.update((x, y))

assert not mmd._sum_of_mmd2.requires_grad
assert not any(acc.requires_grad for acc in (mmd._xx_sum, mmd._yy_sum, mmd._xy_sum))


@pytest.mark.usefixtures("distributed")
Expand Down Expand Up @@ -163,12 +163,14 @@ def test_accumulator_device(self):
for metric_device in metric_devices:
mmd = MaximumMeanDiscrepancy(device=metric_device)

for dev in (mmd._device, mmd._sum_of_mmd2.device):
devices = (mmd._device, mmd._xx_sum, mmd._yy_sum, mmd._xy_sum)
for dev in devices:
assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"

x = torch.tensor([[2.0, 3.0], [-2.0, 1.0]]).float()
y = torch.ones(2, 2).float()
mmd.update((x, y))

for dev in (mmd._device, mmd._sum_of_mmd2.device):
devices = (mmd._device, mmd._xx_sum, mmd._yy_sum, mmd._xy_sum)
for dev in devices:
assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"

0 comments on commit cc5d555

Please sign in to comment.