From e5f474c6997145b5ee243872d6aee79dc366d4c0 Mon Sep 17 00:00:00 2001
From: Kazuki Adachi
Date: Sat, 13 Apr 2024 13:58:51 +0900
Subject: [PATCH 1/8] add KLDivergence metric

---
 docs/source/metrics.rst                    |   1 +
 ignite/metrics/__init__.py                 |   2 +
 ignite/metrics/kl_divergence.py            | 102 +++++++++++++
 tests/ignite/metrics/test_kl_divergence.py | 158 +++++++++++++++++++++
 4 files changed, 263 insertions(+)
 create mode 100644 ignite/metrics/kl_divergence.py
 create mode 100644 tests/ignite/metrics/test_kl_divergence.py

diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst
index f6742f73be5..ecc2682728a 100644
--- a/docs/source/metrics.rst
+++ b/docs/source/metrics.rst
@@ -353,6 +353,7 @@ Complete list of metrics
     FID
     CosineSimilarity
     Entropy
+    KLDivergence
     AveragePrecision
     CohenKappa
     GpuInfo
diff --git a/ignite/metrics/__init__.py b/ignite/metrics/__init__.py
index 05ce97c4066..65aabf8bc7b 100644
--- a/ignite/metrics/__init__.py
+++ b/ignite/metrics/__init__.py
@@ -14,6 +14,7 @@
 from ignite.metrics.gan.fid import FID
 from ignite.metrics.gan.inception_score import InceptionScore
 from ignite.metrics.gpu_info import GpuInfo
+from ignite.metrics.kl_divergence import KLDivergence
 from ignite.metrics.loss import Loss
 from ignite.metrics.mean_absolute_error import MeanAbsoluteError
 from ignite.metrics.mean_pairwise_distance import MeanPairwiseDistance
@@ -57,6 +58,7 @@
     "InceptionScore",
     "mIoU",
     "JaccardIndex",
+    "KLDivergence",
     "MultiLabelConfusionMatrix",
     "MutualInformation",
     "Precision",
diff --git a/ignite/metrics/kl_divergence.py b/ignite/metrics/kl_divergence.py
new file mode 100644
index 00000000000..47723f6d5b0
--- /dev/null
+++ b/ignite/metrics/kl_divergence.py
@@ -0,0 +1,102 @@
+from typing import Sequence
+
+import torch
+import torch.nn.functional as F
+
+from ignite.exceptions import NotComputableError
+from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce
+
+__all__ = ["KLDivergence"]
+
+
+class KLDivergence(Metric):
+    r"""Calculates the mean of `Kullback-Leibler (KL) divergence
+    <https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence>`_.
+
+    .. math:: D_\text{KL}(\mathbf{p}_i \| \mathbf{q}_i) = \sum_{c=1}^C p_{i,c} \log \frac{p_{i,c}}{q_{i,c}}
+
+    where :math:`\mathbf{p}_i` and :math:`\mathbf{q}_i` are the prediction and ground true probability tensors.
+
+    - ``update`` must receive output of the form ``(y_pred, y)``.
+    - ``y_pred`` and ``y`` are expected to be the unnormalized logits for each class. :math:`(B, C)` (classification)
+      or :math:`(B, C, ...)` (e.g., image segmentation) shapes are allowed.
+
+    Args:
+        output_transform: a callable that is used to transform the
+            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
+            form expected by the metric. This can be useful if, for example, you have a multi-output model and
+            you want to compute the metric with respect to one of the outputs.
+            By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
+        device: specifies which device updates are accumulated on. Setting the
+            metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
+            non-blocking. By default, CPU.
+
+    Examples:
+        To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
+        The output of the engine's ``process_function`` needs to be in the format of
+        ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``. If not, ``output_transform`` can be added
+        to the metric to transform the output into the form expected by the metric.
+
+        For more information on how metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`.
+
+        .. include:: defaults.rst
+            :start-after: :orphan:
+
+        .. testcode::
+
+            metric = KLDivergence()
+            metric.attach(default_evaluator, 'kl-div')
+            y_true = torch.tensor([
+                [ 0.0000, -2.3026, -2.3026],
+                [ 1.3863, 1.6094, 1.6094],
+                [ 0.0000, 0.6931, 1.0986]
+            ])
+            y_pred = torch.tensor([
+                [ 0.0000, 0.6931, 1.0986],
+                [ 1.3863, 1.6094, 1.6094],
+                [ 0.0000, -2.3026, -2.3026]
+            ])
+            state = default_evaluator.run([[y_pred, y_true]])
+            print(state.metrics['kl-div'])
+
+        .. testoutput::
+
+            0.7220296859741211
+    """
+
+    _state_dict_all_req_keys = ("_sum_of_kl", "_num_examples")
+
+    @reinit__is_reduced
+    def reset(self) -> None:
+        self._sum_of_kl = torch.tensor(0.0, device=self._device)
+        self._num_examples = 0
+
+    @reinit__is_reduced
+    def update(self, output: Sequence[torch.Tensor]) -> None:
+        y_pred, y = output[0].detach(), output[1].detach()
+        if y_pred.shape != y.shape:
+            raise ValueError(f"y_pred and y must be in the same shape, got {y_pred.shape} != {y.shape}.")
+
+        if y_pred.ndim >= 3:
+            num_classes = y_pred.shape[1]
+            # (B, C, ...) -> (B, ..., C) -> (B*..., C)
+            # regarding as B*... predictions
+            y_pred = y_pred.movedim(1, -1).reshape(-1, num_classes)
+            y = y.movedim(1, -1).reshape(-1, num_classes)
+        elif y_pred.ndim == 1:
+            raise ValueError(f"y_pred must be in the shape of (B, C) or (B, C, ...), got {y_pred.shape}.")
+
+        self._num_examples += y_pred.shape[0]
+        self._update(y_pred, y)
+
+    def _update(self, y_pred: torch.Tensor, y: torch.Tensor) -> None:
+        y_pred = F.log_softmax(y_pred, dim=1)
+        y = F.log_softmax(y, dim=1)
+        kl_sum = F.kl_div(y, y_pred, log_target=True, reduction="sum")
+        self._sum_of_kl += kl_sum.to(self._device)
+
+    @sync_all_reduce("_sum_of_kl", "_num_examples")
+    def compute(self) -> float:
+        if self._num_examples == 0:
+            raise NotComputableError("KLDivergence must have at least one example before it can be computed.")
+        return self._sum_of_kl.item() / self._num_examples
diff --git a/tests/ignite/metrics/test_kl_divergence.py b/tests/ignite/metrics/test_kl_divergence.py
new file mode 100644
index 00000000000..b92a336ace1
--- /dev/null
+++ b/tests/ignite/metrics/test_kl_divergence.py
@@ -0,0 +1,158 @@
+from typing import Tuple
+
+import numpy as np
+import pytest
+import torch
+from scipy.special import softmax
+from scipy.stats import entropy
+from torch import Tensor
+
+import ignite.distributed as idist
+from ignite.engine import Engine
+from ignite.exceptions import NotComputableError
+from ignite.metrics import KLDivergence
+
+
+def scipy_kl_div(np_y_pred: np.ndarray, np_y: np.ndarray) -> float:
+    y_pred_prob = softmax(np_y_pred, axis=1)
+    y_prob = softmax(np_y, axis=1)
+    kl_mean = entropy(y_pred_prob, y_prob, axis=1).mean()
+    return kl_mean
+
+
+def test_zero_sample():
+    kl_div = KLDivergence()
+    with pytest.raises(
+        NotComputableError, match=r"KLDivergence must have at least one example before it can be computed"
+    ):
+        kl_div.compute()
+
+
+def test_shape_mismatch():
+    kl_div = KLDivergence()
+    y_pred = torch.tensor([[2.0, 3.0], [-2.0, 1.0]], dtype=torch.float)
+    y = torch.tensor([[-2.0, 1.0]], dtype=torch.float)
+    with pytest.raises(ValueError, match=r"y_pred and y must be in the same shape, got"):
+        kl_div.update((y_pred, y))
+
+
+def test_invalid_shape():
+    kl_div = KLDivergence()
+    y_pred = torch.tensor([2.0, 3.0], dtype=torch.float)
+    y = torch.tensor([4.0, 5.0], dtype=torch.float)
+    with pytest.raises(ValueError, match=r"y_pred must be in the shape of \(B, C\) or \(B, C, ...\), got"):
+        kl_div.update((y_pred, y))
+
+
+@pytest.fixture(params=list(range(6)))
+def test_case(request):
+    return [
+        (torch.randn((100, 10)), torch.rand((100, 10)), 1),
+        (torch.rand((100, 500)), torch.randn((100, 500)), 1),
+        # updated batches
+        (torch.normal(0.0, 5.0, size=(100, 10)), torch.rand((100, 10)), 16),
+        (torch.normal(5.0, 3.0, size=(100, 200)), torch.rand((100, 200)), 16),
+        # image segmentation
+        (torch.randn((100, 5, 32, 32)), torch.rand((100, 5, 32, 32)), 16),
+        (torch.rand((100, 5, 224, 224)), torch.randn((100, 5, 224, 224)), 16),
+    ][request.param]
+
+
+@pytest.mark.parametrize("n_times", range(5))
+def test_compute(n_times, test_case: Tuple[Tensor, Tensor, int]):
+    y_pred, y, batch_size = test_case
+
+    kl_div = KLDivergence()
+
+    kl_div.reset()
+    if batch_size > 1:
+        n_iters = y.shape[0] // batch_size + 1
+        for i in range(n_iters):
+            idx = i * batch_size
+            kl_div.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
+    else:
+        kl_div.update((y_pred, y))
+
+    res = kl_div.compute()
+
+    np_y_pred = y_pred.numpy()
+    np_y = y.numpy()
+
+    np_res = scipy_kl_div(np_y_pred, np_y)
+
+    assert isinstance(res, float)
+    assert pytest.approx(np_res, rel=1e-4) == res
+
+
+def test_accumulator_detached():
+    kl_div = KLDivergence()
+
+    y_pred = torch.tensor([[2.0, 3.0], [-2.0, 1.0]], dtype=torch.float)
+    y = torch.tensor([[-2.0, 1.0], [2.0, 3.0]], dtype=torch.float)
+    kl_div.update((y_pred, y))
+
+    assert not kl_div._sum_of_kl.requires_grad
+
+
+@pytest.mark.usefixtures("distributed")
+class TestDistributed:
+    def test_integration(self):
+        tol = 1e-4
+        n_iters = 100
+        batch_size = 10
+        n_dims = 100
+
+        rank = idist.get_rank()
+        torch.manual_seed(12 + rank)
+
+        device = idist.device()
+        metric_devices = [torch.device("cpu")]
+        if device.type != "xla":
+            metric_devices.append(device)
+
+        for metric_device in metric_devices:
+            y_true = torch.randn((n_iters * batch_size, n_dims)).float().to(device)
+            y_preds = torch.normal(2.0, 3.0, size=(n_iters * batch_size, n_dims)).float().to(device)
+
+            engine = Engine(
+                lambda e, i: (
+                    y_preds[i * batch_size : (i + 1) * batch_size],
+                    y_true[i * batch_size : (i + 1) * batch_size],
+                )
+            )
+
+            m = KLDivergence(device=metric_device)
+            m.attach(engine, "kl_div")
+
+            data = list(range(n_iters))
+            engine.run(data=data, max_epochs=1)
+
+            y_preds = idist.all_gather(y_preds)
+            y_true = idist.all_gather(y_true)
+
+            assert "kl_div" in engine.state.metrics
+            res = engine.state.metrics["kl_div"]
+
+            y_true_np = y_true.cpu().numpy()
+            y_preds_np = y_preds.cpu().numpy()
+            true_res = scipy_kl_div(y_preds_np, y_true_np)
+
+            assert pytest.approx(true_res, rel=tol) == res
+
+    def test_accumulator_device(self):
+        device = idist.device()
+        metric_devices = [torch.device("cpu")]
+        if device.type != "xla":
+            metric_devices.append(device)
+        for metric_device in metric_devices:
+            kl_div = KLDivergence(device=metric_device)
+
+            for dev in (kl_div._device, kl_div._sum_of_kl.device):
+                assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"
+
+            y_pred = torch.tensor([[2.0, 3.0], [-2.0, 1.0]]).float()
+            y = torch.ones(2, 2).float()
+            kl_div.update((y_pred, y))
+
+            for dev in (kl_div._device, kl_div._sum_of_kl.device):
+                assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"

From a455192bf0f2ce70402b3e46bbc3fa1b74c4adbf Mon Sep 17 00:00:00 2001
From: Kazuki Adachi
Date: Sat, 13 Apr 2024 13:59:21 +0900
Subject: [PATCH 2/8] add JSDivergence

---
 docs/source/metrics.rst                    |   1 +
 ignite/metrics/__init__.py                 |   2 +
 ignite/metrics/js_divergence.py            |  87 +++++++++++
 tests/ignite/metrics/test_js_divergence.py | 159 +++++++++++++++++++++
 4 files changed, 249 insertions(+)
 create mode 100644 ignite/metrics/js_divergence.py
 create mode 100644 tests/ignite/metrics/test_js_divergence.py

diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst
index ecc2682728a..a7f90b754d9 100644
--- a/docs/source/metrics.rst
+++ b/docs/source/metrics.rst
@@ -354,6 +354,7 @@ Complete list of metrics
     CosineSimilarity
     Entropy
     KLDivergence
+    JSDivergence
    AveragePrecision
     CohenKappa
     GpuInfo
diff --git a/ignite/metrics/__init__.py b/ignite/metrics/__init__.py
index 65aabf8bc7b..2cc55aace66 100644
--- a/ignite/metrics/__init__.py
+++ b/ignite/metrics/__init__.py
@@ -14,6 +14,7 @@
 from ignite.metrics.gan.fid import FID
 from ignite.metrics.gan.inception_score import InceptionScore
 from ignite.metrics.gpu_info import GpuInfo
+from ignite.metrics.js_divergence import JSDivergence
 from ignite.metrics.kl_divergence import KLDivergence
 from ignite.metrics.loss import Loss
 from ignite.metrics.mean_absolute_error import MeanAbsoluteError
@@ -58,6 +59,7 @@
     "InceptionScore",
     "mIoU",
     "JaccardIndex",
+    "JSDivergence",
     "KLDivergence",
     "MultiLabelConfusionMatrix",
     "MutualInformation",
diff --git a/ignite/metrics/js_divergence.py b/ignite/metrics/js_divergence.py
new file mode 100644
index 00000000000..6e757afad09
--- /dev/null
+++ b/ignite/metrics/js_divergence.py
@@ -0,0 +1,87 @@
+import torch
+import torch.nn.functional as F
+
+from ignite.exceptions import NotComputableError
+from ignite.metrics.kl_divergence import KLDivergence
+from ignite.metrics.metric import sync_all_reduce
+
+__all__ = ["JSDivergence"]
+
+
+class JSDivergence(KLDivergence):
+    r"""Calculates the mean of `Jensen-Shannon (JS) divergence
+    <https://en.wikipedia.org/wiki/Jensen%E2%80%93Shannon_divergence>`_.
+
+    .. math::
+        \begin{align*}
+            D_\text{JS}(\mathbf{p}_i \| \mathbf{q}_i) &= \frac{1}{2} D_\text{KL}(\mathbf{p}_i \| \mathbf{m}_i)
+            + \frac{1}{2} D_\text{KL}(\mathbf{q}_i \| \mathbf{m}_i), \\
+            \mathbf{m}_i &= \frac{1}{2}(\mathbf{p}_i + \mathbf{q}_i), \\
+            D_\text{KL}(\mathbf{p}_i, \mathbf{q}_i) &= \sum_{c=1}^C p_{i,c} \log \frac{p_{i,c}}{q_{i,c}}.
+        \end{align*}
+
+    where :math:`\mathbf{p}_i` and :math:`\mathbf{q}_i` are the prediction and ground true probability tensors,
+    and :math:`D_\text{KL}` is the KL-divergence.
+
+    - ``update`` must receive output of the form ``(y_pred, y)``.
+    - ``y_pred`` and ``y`` are expected to be the unnormalized logits for each class. :math:`(B, C)` (classification)
+      or :math:`(B, C, ...)` (e.g., image segmentation) shapes are allowed.
+
+    Args:
+        output_transform: a callable that is used to transform the
+            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
+            form expected by the metric. This can be useful if, for example, you have a multi-output model and
+            you want to compute the metric with respect to one of the outputs.
+            By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
+        device: specifies which device updates are accumulated on. Setting the
+            metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
+            non-blocking. By default, CPU.
+
+    Examples:
+        To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
+        The output of the engine's ``process_function`` needs to be in the format of
+        ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``. If not, ``output_transform`` can be added
+        to the metric to transform the output into the form expected by the metric.
+
+        For more information on how metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`.
+
+        .. include:: defaults.rst
+            :start-after: :orphan:
+
+        .. testcode::
+
+            metric = JSDivergence()
+            metric.attach(default_evaluator, 'js-div')
+            y_true = torch.tensor([
+                [ 0.0000, -2.3026, -2.3026],
+                [ 1.3863, 1.6094, 1.6094],
+                [ 0.0000, 0.6931, 1.0986]
+            ])
+            y_pred = torch.tensor([
+                [ 0.0000, 0.6931, 1.0986],
+                [ 1.3863, 1.6094, 1.6094],
+                [ 0.0000, -2.3026, -2.3026]
+            ])
+            state = default_evaluator.run([[y_pred, y_true]])
+            print(state.metrics['js-div'])
+
+        .. testoutput::
+
+            0.16266516844431558
+    """
+
+    def _update(self, y_pred: torch.Tensor, y: torch.Tensor) -> None:
+        m_prob = (F.softmax(y_pred, dim=1) + F.softmax(y, dim=1)) / 2
+        m_log = m_prob.log()
+        y_pred = F.log_softmax(y_pred, dim=1)
+        y = F.log_softmax(y, dim=1)
+        self._sum_of_kl += (
+            F.kl_div(m_log, y_pred, log_target=True, reduction="sum")
+            + F.kl_div(m_log, y, log_target=True, reduction="sum")
+        ).to(self._device)
+
+    @sync_all_reduce("_sum_of_kl", "_num_examples")
+    def compute(self) -> float:
+        if self._num_examples == 0:
+            raise NotComputableError("JSDivergence must have at least one example before it can be computed.")
+        return self._sum_of_kl.item() / (self._num_examples * 2)
diff --git a/tests/ignite/metrics/test_js_divergence.py b/tests/ignite/metrics/test_js_divergence.py
new file mode 100644
index 00000000000..581d9de2f0c
--- /dev/null
+++ b/tests/ignite/metrics/test_js_divergence.py
@@ -0,0 +1,159 @@
+from typing import Tuple
+
+import numpy as np
+import pytest
+import torch
+from scipy.spatial.distance import jensenshannon
+from scipy.special import softmax
+from torch import Tensor
+
+import ignite.distributed as idist
+from ignite.engine import Engine
+from ignite.exceptions import NotComputableError
+from ignite.metrics import JSDivergence
+
+
+def scipy_js_div(np_y_pred: np.ndarray, np_y: np.ndarray) -> float:
+    y_pred_prob = softmax(np_y_pred, axis=1)
+    y_prob = softmax(np_y, axis=1)
+    # jensenshannon computes the sqrt of the JS divergence
+    js_mean = np.mean(np.square(jensenshannon(y_pred_prob, y_prob, axis=1)))
+    return js_mean
+
+
+def test_zero_sample():
+    js_div = JSDivergence()
+    with pytest.raises(
+        NotComputableError, match=r"JSDivergence must have at least one example before it can be computed"
+    ):
+        js_div.compute()
+
+
+def test_shape_mismatch():
+    js_div = JSDivergence()
+    y_pred = torch.tensor([[2.0, 3.0], [-2.0, 1.0]], dtype=torch.float)
+    y = torch.tensor([[-2.0, 1.0]], dtype=torch.float)
+    with pytest.raises(ValueError, match=r"y_pred and y must be in the same shape, got"):
+        js_div.update((y_pred, y))
+
+
+def test_invalid_shape():
+    js_div = JSDivergence()
+    y_pred = torch.tensor([2.0, 3.0], dtype=torch.float)
+    y = torch.tensor([4.0, 5.0], dtype=torch.float)
+    with pytest.raises(ValueError, match=r"y_pred must be in the shape of \(B, C\) or \(B, C, ...\), got"):
+        js_div.update((y_pred, y))
+
+
+@pytest.fixture(params=list(range(6)))
+def test_case(request):
+    return [
+        (torch.randn((100, 10)), torch.rand((100, 10)), 1),
+        (torch.rand((100, 500)), torch.randn((100, 500)), 1),
+        # updated batches
+        (torch.normal(0.0, 5.0, size=(100, 10)), torch.rand((100, 10)), 16),
+        (torch.normal(5.0, 3.0, size=(100, 200)), torch.rand((100, 200)), 16),
+        # image segmentation
+        (torch.randn((100, 5, 32, 32)), torch.rand((100, 5, 32, 32)), 16),
+        (torch.rand((100, 5, 224, 224)), torch.randn((100, 5, 224, 224)), 16),
+    ][request.param]
+
+
+@pytest.mark.parametrize("n_times", range(5))
+def test_compute(n_times, test_case: Tuple[Tensor, Tensor, int]):
+    y_pred, y, batch_size = test_case
+
+    js_div = JSDivergence()
+
+    js_div.reset()
+    if batch_size > 1:
+        n_iters = y.shape[0] // batch_size + 1
+        for i in range(n_iters):
+            idx = i * batch_size
+            js_div.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
+    else:
+        js_div.update((y_pred, y))
+
+    res = js_div.compute()
+
+    np_y_pred = y_pred.numpy()
+    np_y = y.numpy()
+
+    np_res = scipy_js_div(np_y_pred, np_y)
+
+    assert isinstance(res, float)
+    assert pytest.approx(np_res, rel=1e-4) == res
+
+
+def test_accumulator_detached():
+    js_div = JSDivergence()
+
+    y_pred = torch.tensor([[2.0, 3.0], [-2.0, 1.0]], dtype=torch.float)
+    y = torch.tensor([[-2.0, 1.0], [2.0, 3.0]], dtype=torch.float)
+    js_div.update((y_pred, y))
+
+    assert not js_div._sum_of_kl.requires_grad
+
+
+@pytest.mark.usefixtures("distributed")
+class TestDistributed:
+    def test_integration(self):
+        tol = 1e-4
+        n_iters = 100
+        batch_size = 10
+        n_dims = 100
+
+        rank = idist.get_rank()
+        torch.manual_seed(12 + rank)
+
+        device = idist.device()
+        metric_devices = [torch.device("cpu")]
+        if device.type != "xla":
+            metric_devices.append(device)
+
+        for metric_device in metric_devices:
+            y_true = torch.randn((n_iters * batch_size, n_dims)).float().to(device)
+            y_preds = torch.normal(2.0, 3.0, size=(n_iters * batch_size, n_dims)).float().to(device)
+
+            engine = Engine(
+                lambda e, i: (
+                    y_preds[i * batch_size : (i + 1) * batch_size],
+                    y_true[i * batch_size : (i + 1) * batch_size],
+                )
+            )
+
+            m = JSDivergence(device=metric_device)
+            m.attach(engine, "js_div")
+
+            data = list(range(n_iters))
+            engine.run(data=data, max_epochs=1)
+
+            y_preds = idist.all_gather(y_preds)
+            y_true = idist.all_gather(y_true)
+
+            assert "js_div" in engine.state.metrics
+            res = engine.state.metrics["js_div"]
+
+            y_true_np = y_true.cpu().numpy()
+            y_preds_np = y_preds.cpu().numpy()
+            true_res = scipy_js_div(y_preds_np, y_true_np)
+
+            assert pytest.approx(true_res, rel=tol) == res
+
+    def test_accumulator_device(self):
+        device = idist.device()
+        metric_devices = [torch.device("cpu")]
+        if device.type != "xla":
+            metric_devices.append(device)
+        for metric_device in metric_devices:
+            kl_div = JSDivergence(device=metric_device)
+
+            for dev in (kl_div._device, kl_div._sum_of_kl.device):
+                assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"
+
+            y_pred = torch.tensor([[2.0, 3.0], [-2.0, 1.0]]).float()
+            y = torch.ones(2, 2).float()
+            kl_div.update((y_pred, y))
+
+            for dev in (kl_div._device, kl_div._sum_of_kl.device):
+                assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"

From def3d0ca1693b4dc40bf3db7d10882ed3a9ab6e9 Mon Sep 17 00:00:00 2001
From: Kazuki Adachi
Date: Sat, 13 Apr 2024 14:03:25 +0900
Subject: [PATCH 3/8] fix variable name

---
 tests/ignite/metrics/test_js_divergence.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tests/ignite/metrics/test_js_divergence.py b/tests/ignite/metrics/test_js_divergence.py
index 581d9de2f0c..ac4721ef99d 100644
--- a/tests/ignite/metrics/test_js_divergence.py
+++ b/tests/ignite/metrics/test_js_divergence.py
@@ -146,14 +146,14 @@ def test_accumulator_device(self):
         if device.type != "xla":
             metric_devices.append(device)
         for metric_device in metric_devices:
-            kl_div = JSDivergence(device=metric_device)
+            js_div = JSDivergence(device=metric_device)
 
-            for dev in (kl_div._device, kl_div._sum_of_kl.device):
+            for dev in (js_div._device, js_div._sum_of_kl.device):
                 assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"
 
             y_pred = torch.tensor([[2.0, 3.0], [-2.0, 1.0]]).float()
             y = torch.ones(2, 2).float()
-            kl_div.update((y_pred, y))
+            js_div.update((y_pred, y))
 
-            for dev in (kl_div._device, kl_div._sum_of_kl.device):
+            for dev in (js_div._device, js_div._sum_of_kl.device):
                 assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"

From c8f75b65748ed9ee66d5cfdc81c75168df2e0fe5 Mon Sep 17 00:00:00 2001
From: Kazuki Adachi
Date: Sat, 13 Apr 2024 14:12:42 +0900
Subject: [PATCH 4/8] update docstring for JSDivergence

---
 ignite/metrics/js_divergence.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ignite/metrics/js_divergence.py b/ignite/metrics/js_divergence.py
index 6e757afad09..92ff2dc1304 100644
--- a/ignite/metrics/js_divergence.py
+++ b/ignite/metrics/js_divergence.py
@@ -17,7 +17,7 @@ class JSDivergence(KLDivergence):
             D_\text{JS}(\mathbf{p}_i \| \mathbf{q}_i) &= \frac{1}{2} D_\text{KL}(\mathbf{p}_i \| \mathbf{m}_i)
             + \frac{1}{2} D_\text{KL}(\mathbf{q}_i \| \mathbf{m}_i), \\
             \mathbf{m}_i &= \frac{1}{2}(\mathbf{p}_i + \mathbf{q}_i), \\
-            D_\text{KL}(\mathbf{p}_i, \mathbf{q}_i) &= \sum_{c=1}^C p_{i,c} \log \frac{p_{i,c}}{q_{i,c}}.
+            D_\text{KL}(\mathbf{p}_i \| \mathbf{q}_i) &= \sum_{c=1}^C p_{i,c} \log \frac{p_{i,c}}{q_{i,c}}.
         \end{align*}
 
     where :math:`\mathbf{p}_i` and :math:`\mathbf{q}_i` are the prediction and ground true probability tensors,

From ab72665ddcf95e133d9165dcab63849c799d22f6 Mon Sep 17 00:00:00 2001
From: Kazuki Adachi
Date: Sun, 14 Apr 2024 23:28:02 +0900
Subject: [PATCH 5/8] Update ignite/metrics/js_divergence.py

Co-authored-by: vfdev

---
 ignite/metrics/js_divergence.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ignite/metrics/js_divergence.py b/ignite/metrics/js_divergence.py
index 92ff2dc1304..ffe22191413 100644
--- a/ignite/metrics/js_divergence.py
+++ b/ignite/metrics/js_divergence.py
@@ -20,7 +20,7 @@ class JSDivergence(KLDivergence):
             D_\text{KL}(\mathbf{p}_i \| \mathbf{q}_i) &= \sum_{c=1}^C p_{i,c} \log \frac{p_{i,c}}{q_{i,c}}.
         \end{align*}
 
-    where :math:`\mathbf{p}_i` and :math:`\mathbf{q}_i` are the prediction and ground true probability tensors,
+    where :math:`\mathbf{p}_i` and :math:`\mathbf{q}_i` are the prediction and ground truth probability tensors,
     and :math:`D_\text{KL}` is the KL-divergence.
 
     - ``update`` must receive output of the form ``(y_pred, y)``.

From b026e717c7f1595c2b4589081655b116ed887ca0 Mon Sep 17 00:00:00 2001
From: Kazuki Adachi
Date: Sun, 14 Apr 2024 23:28:15 +0900
Subject: [PATCH 6/8] Update ignite/metrics/kl_divergence.py

Co-authored-by: vfdev

---
 ignite/metrics/kl_divergence.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ignite/metrics/kl_divergence.py b/ignite/metrics/kl_divergence.py
index 47723f6d5b0..4efe42f23ae 100644
--- a/ignite/metrics/kl_divergence.py
+++ b/ignite/metrics/kl_divergence.py
@@ -15,7 +15,7 @@ class KLDivergence(Metric):
 
     .. math:: D_\text{KL}(\mathbf{p}_i \| \mathbf{q}_i) = \sum_{c=1}^C p_{i,c} \log \frac{p_{i,c}}{q_{i,c}}
 
-    where :math:`\mathbf{p}_i` and :math:`\mathbf{q}_i` are the prediction and ground true probability tensors.
+    where :math:`\mathbf{p}_i` and :math:`\mathbf{q}_i` are the prediction and ground truth probability tensors.
 
     - ``update`` must receive output of the form ``(y_pred, y)``.
     - ``y_pred`` and ``y`` are expected to be the unnormalized logits for each class. :math:`(B, C)` (classification)

From 273dece7a924365e7cc4679b8eed1a46c75d7af9 Mon Sep 17 00:00:00 2001
From: Kazuki Adachi
Date: Mon, 15 Apr 2024 23:28:50 +0900
Subject: [PATCH 7/8] swap ground truth and prediction

---
 ignite/metrics/kl_divergence.py            | 4 ++--
 tests/ignite/metrics/test_kl_divergence.py | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/ignite/metrics/kl_divergence.py b/ignite/metrics/kl_divergence.py
index 4efe42f23ae..99f6cbcfa84 100644
--- a/ignite/metrics/kl_divergence.py
+++ b/ignite/metrics/kl_divergence.py
@@ -15,7 +15,7 @@ class KLDivergence(Metric):
 
     .. math:: D_\text{KL}(\mathbf{p}_i \| \mathbf{q}_i) = \sum_{c=1}^C p_{i,c} \log \frac{p_{i,c}}{q_{i,c}}
 
-    where :math:`\mathbf{p}_i` and :math:`\mathbf{q}_i` are the prediction and ground truth probability tensors.
+    where :math:`\mathbf{p}_i` and :math:`\mathbf{q}_i` are the ground truth and prediction probability tensors.
 
     - ``update`` must receive output of the form ``(y_pred, y)``.
     - ``y_pred`` and ``y`` are expected to be the unnormalized logits for each class. :math:`(B, C)` (classification)
@@ -92,7 +92,7 @@ def update(self, output: Sequence[torch.Tensor]) -> None:
     def _update(self, y_pred: torch.Tensor, y: torch.Tensor) -> None:
         y_pred = F.log_softmax(y_pred, dim=1)
         y = F.log_softmax(y, dim=1)
-        kl_sum = F.kl_div(y, y_pred, log_target=True, reduction="sum")
+        kl_sum = F.kl_div(y_pred, y, log_target=True, reduction="sum")
         self._sum_of_kl += kl_sum.to(self._device)
 
     @sync_all_reduce("_sum_of_kl", "_num_examples")
diff --git a/tests/ignite/metrics/test_kl_divergence.py b/tests/ignite/metrics/test_kl_divergence.py
index b92a336ace1..6c9512d4231 100644
--- a/tests/ignite/metrics/test_kl_divergence.py
+++ b/tests/ignite/metrics/test_kl_divergence.py
@@ -16,7 +16,7 @@
 def scipy_kl_div(np_y_pred: np.ndarray, np_y: np.ndarray) -> float:
     y_pred_prob = softmax(np_y_pred, axis=1)
     y_prob = softmax(np_y, axis=1)
-    kl_mean = entropy(y_pred_prob, y_prob, axis=1).mean()
+    kl_mean = entropy(y_prob, y_pred_prob, axis=1).mean()
     return kl_mean
 
 

From 414b25a35cd38bc4e30becffbe27d0e72ed82ecf Mon Sep 17 00:00:00 2001
From: Kazuki Adachi
Date: Wed, 17 Apr 2024 22:29:30 +0900
Subject: [PATCH 8/8] swap the definitions of p and q

---
 ignite/metrics/js_divergence.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ignite/metrics/js_divergence.py b/ignite/metrics/js_divergence.py
index ffe22191413..1bd37cfedc6 100644
--- a/ignite/metrics/js_divergence.py
+++ b/ignite/metrics/js_divergence.py
@@ -20,7 +20,7 @@ class JSDivergence(KLDivergence):
             D_\text{KL}(\mathbf{p}_i \| \mathbf{q}_i) &= \sum_{c=1}^C p_{i,c} \log \frac{p_{i,c}}{q_{i,c}}.
         \end{align*}
 
-    where :math:`\mathbf{p}_i` and :math:`\mathbf{q}_i` are the prediction and ground truth probability tensors,
+    where :math:`\mathbf{p}_i` and :math:`\mathbf{q}_i` are the ground truth and prediction probability tensors,
     and :math:`D_\text{KL}` is the KL-divergence.
 
     - ``update`` must receive output of the form ``(y_pred, y)``.
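
Note: both ``testoutput`` values in the docstrings above can be re-derived with plain tensor operations directly from the stated definitions. The following is a minimal sketch outside the patch series (not part of the PR), assuming the final PATCH 7/8 convention that ``p`` denotes the ground-truth distribution and ``q`` the prediction:

    import torch
    import torch.nn.functional as F

    # the same logits used in the KLDivergence / JSDivergence docstring examples
    y_true = torch.tensor([
        [0.0000, -2.3026, -2.3026],
        [1.3863, 1.6094, 1.6094],
        [0.0000, 0.6931, 1.0986],
    ])
    y_pred = torch.tensor([
        [0.0000, 0.6931, 1.0986],
        [1.3863, 1.6094, 1.6094],
        [0.0000, -2.3026, -2.3026],
    ])

    p = F.softmax(y_true, dim=1)  # ground truth probabilities
    q = F.softmax(y_pred, dim=1)  # predicted probabilities

    def kl(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # D_KL(a || b) = sum_c a_c * log(a_c / b_c), computed per sample
        return (a * (a / b).log()).sum(dim=1)

    m = (p + q) / 2  # mixture distribution used by the JS divergence
    print(kl(p, q).mean().item())                     # ~0.7220, the KLDivergence testoutput
    print(((kl(p, m) + kl(q, m)) / 2).mean().item())  # ~0.1627, the JSDivergence testoutput

The second print uses the identity D_JS(p || q) = (D_KL(p || m) + D_KL(q || m)) / 2 with m = (p + q) / 2, which is exactly how ``JSDivergence._update`` accumulates the sum before ``compute`` divides by twice the number of examples.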