diff --git a/ignite/metrics/maximum_mean_discrepancy.py b/ignite/metrics/maximum_mean_discrepancy.py index e0a18362b36..d92dd5448ce 100644 --- a/ignite/metrics/maximum_mean_discrepancy.py +++ b/ignite/metrics/maximum_mean_discrepancy.py @@ -78,7 +78,7 @@ class MaximumMeanDiscrepancy(Metric): 1.0726975202560425 """ - _state_dict_all_req_keys = ("_sum_of_mmd2", "_num_batches") + _state_dict_all_req_keys = ("_xx_sum", "_yy_sum", "_xy_sum", "_num_batches") def __init__( self, var: float = 1.0, output_transform: Callable = lambda x: x, device: torch.device = torch.device("cpu") @@ -130,7 +130,7 @@ def update(self, output: Sequence[torch.Tensor]) -> None: self._num_batches += 1 - @sync_all_reduce("_sum_of_mmd2", "_num_batches") + @sync_all_reduce("_xx_sum", "_yy_sum", "_xy_sum", "_num_batches") def compute(self) -> float: if self._num_batches == 0: raise NotComputableError("MaximumMeanDiscrepacy must have at least one batch before it can be computed.")