Skip to content

Commit

Permalink
fix accumulator variables
Browse files Browse the repository at this point in the history
  • Loading branch information
kzkadc committed May 7, 2024
1 parent 2884f70 commit 9f2ca1b
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions ignite/metrics/maximum_mean_discrepancy.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class MaximumMeanDiscrepancy(Metric):
1.0726975202560425
"""

_state_dict_all_req_keys = ("_sum_of_mmd2", "_num_batches")
_state_dict_all_req_keys = ("_xx_sum", "_yy_sum", "_xy_sum", "_num_batches")

def __init__(
self, var: float = 1.0, output_transform: Callable = lambda x: x, device: torch.device = torch.device("cpu")
Expand Down Expand Up @@ -130,7 +130,7 @@ def update(self, output: Sequence[torch.Tensor]) -> None:

self._num_batches += 1

@sync_all_reduce("_sum_of_mmd2", "_num_batches")
@sync_all_reduce("_xx_sum", "_yy_sum", "_xy_sum", "_num_batches")
def compute(self) -> float:
if self._num_batches == 0:
raise NotComputableError("MaximumMeanDiscrepacy must have at least one batch before it can be computed.")
Expand Down

0 comments on commit 9f2ca1b

Please sign in to comment.