From dba3049bd43725103161923f5afeddce2789d3d0 Mon Sep 17 00:00:00 2001
From: Kazuki Adachi
Date: Wed, 20 Mar 2024 18:20:50 +0900
Subject: [PATCH 01/14] add PearsonCorrelation metric

---
 docs/source/contrib/metrics.rst               |   1 +
 ignite/contrib/metrics/regression/__init__.py |   1 +
 .../metrics/regression/pearson_correlation.py | 120 +++++++++
 .../regression/test_pearson_correlation.py    | 244 ++++++++++++++++++
 4 files changed, 366 insertions(+)
 create mode 100644 ignite/contrib/metrics/regression/pearson_correlation.py
 create mode 100644 tests/ignite/contrib/metrics/regression/test_pearson_correlation.py

diff --git a/docs/source/contrib/metrics.rst b/docs/source/contrib/metrics.rst
index eccaf9e7808..740dcafacaa 100644
--- a/docs/source/contrib/metrics.rst
+++ b/docs/source/contrib/metrics.rst
@@ -52,5 +52,6 @@ Complete list of metrics:
     MedianAbsoluteError
     MedianAbsolutePercentageError
     MedianRelativeAbsoluteError
+    PearsonCorrelation
     R2Score
     WaveHedgesDistance

diff --git a/ignite/contrib/metrics/regression/__init__.py b/ignite/contrib/metrics/regression/__init__.py
index fbee310e3b2..625099e0038 100644
--- a/ignite/contrib/metrics/regression/__init__.py
+++ b/ignite/contrib/metrics/regression/__init__.py
@@ -11,5 +11,6 @@
 from ignite.contrib.metrics.regression.median_absolute_error import MedianAbsoluteError
 from ignite.contrib.metrics.regression.median_absolute_percentage_error import MedianAbsolutePercentageError
 from ignite.contrib.metrics.regression.median_relative_absolute_error import MedianRelativeAbsoluteError
+from ignite.contrib.metrics.regression.pearson_correlation import PearsonCorrelation
 from ignite.contrib.metrics.regression.r2_score import R2Score
 from ignite.contrib.metrics.regression.wave_hedges_distance import WaveHedgesDistance

diff --git a/ignite/contrib/metrics/regression/pearson_correlation.py b/ignite/contrib/metrics/regression/pearson_correlation.py
new file mode 100644
index 00000000000..b678564132e
--- /dev/null
+++ b/ignite/contrib/metrics/regression/pearson_correlation.py
@@ -0,0 +1,120 @@
+from typing import Callable, Tuple, Union
+
+import torch
+
+from ignite.contrib.metrics.regression._base import _BaseRegression
+from ignite.exceptions import NotComputableError
+from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce
+
+
+class PearsonCorrelation(_BaseRegression):
+    r"""Calculates the `Pearson correlation coefficient <https://en.wikipedia.org/wiki/Pearson_correlation_coefficient>`_.
+
+    .. math::
+        r = \frac{\sum_{i=1}^N (x_i-\bar{x})(y_i-\bar{y})}{\max (\sqrt{\sum_{i=1}^N (x_i-\bar{x})^2 \sum_{i=1}^N (y_i-\bar{y})^2}, \epsilon)},
+        \quad \bar{x}=\frac{1}{N}\sum_{i=1}^N x_i, \quad \bar{y}=\frac{1}{N}\sum_{i=1}^N y_i
+
+    where :math:`x_i` is the ground truth and :math:`y_i` is the predicted value.
+
+    - ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
+    - `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`.
+
+    Parameters are inherited from ``Metric.__init__``.
+
+    Args:
+        eps: a small value to avoid division by zero. Default: 1e-8
+        output_transform: a callable that is used to transform the
+            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
+            form expected by the metric. This can be useful if, for example, you have a multi-output model and
+            you want to compute the metric with respect to one of the outputs.
+            By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
+        device: specifies which device updates are accumulated on. Setting the
+            metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
+            non-blocking. By default, CPU.
+
+    Examples:
+        To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
+        The output of the engine's ``process_function`` needs to be in format of
+        ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``.
+
+        .. include:: defaults.rst
+            :start-after: :orphan:
+
+        .. testcode::
+
+            metric = PearsonCorrelation()
+            metric.attach(default_evaluator, 'corr')
+            y_true = torch.tensor([0., 1., 2., 3., 4., 5.])
+            y_pred = torch.tensor([0.5, 1.3, 1.9, 2.8, 4.1, 6.0])
+            state = default_evaluator.run([[y_pred, y_true]])
+            print(state.metrics['corr'])
+
+        .. testoutput::
+
+            0.9768688678741455
+    """
+
+    def __init__(
+        self,
+        eps: float = 1e-8,
+        output_transform: Callable = lambda x: x,
+        device: Union[str, torch.device] = torch.device("cpu"),
+    ):
+        super().__init__(output_transform, device)
+
+        self.eps = eps
+
+    _state_dict_all_req_keys = (
+        "_sum_of_y_preds",
+        "_sum_of_ys",
+        "_sum_of_y_pred_squares",
+        "_sum_of_y_squares",
+        "_sum_of_products",
+        "_num_examples",
+    )
+
+    @reinit__is_reduced
+    def reset(self) -> None:
+        self._sum_of_y_preds = torch.tensor(0.0, device=self._device)
+        self._sum_of_ys = torch.tensor(0.0, device=self._device)
+        self._sum_of_y_pred_squares = torch.tensor(0.0, device=self._device)
+        self._sum_of_y_squares = torch.tensor(0.0, device=self._device)
+        self._sum_of_products = torch.tensor(0.0, device=self._device)
+        self._num_examples = 0
+
+    def _update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None:
+        y_pred, y = output[0].detach(), output[1].detach()
+        self._sum_of_y_preds += y_pred.sum()
+        self._sum_of_ys += y.sum()
+        self._sum_of_y_pred_squares += y_pred.square().sum()
+        self._sum_of_y_squares += y.square().sum()
+        self._sum_of_products += (y_pred * y).sum()
+        self._num_examples += y.shape[0]
+
+    @sync_all_reduce(
+        "_sum_of_y_preds",
+        "_sum_of_ys",
+        "_sum_of_y_pred_squares",
+        "_sum_of_y_squares",
+        "_sum_of_products",
+        "_num_examples",
+    )
+    def compute(self) -> float:
+        n = self._num_examples
+        if n == 0:
+            raise NotComputableError("PearsonCorrelation must have at least one example before it can be computed.")
+
+        # cov = E[xy] - E[x]*E[y]
+        cov = self._sum_of_products / n - self._sum_of_y_preds * self._sum_of_ys / (n * n)
+
+        # var = E[x^2] - E[x]^2
+        y_pred_mean = self._sum_of_y_preds / n
+        y_pred_var = self._sum_of_y_pred_squares / n - y_pred_mean * y_pred_mean
+        y_pred_var = torch.clamp(y_pred_var, min=0.0)
+
+        y_mean = self._sum_of_ys / n
+        y_var = self._sum_of_y_squares / n - y_mean * y_mean
+        y_var = torch.clamp(y_var, min=0.0)
+
+        r = cov / torch.clamp(torch.sqrt(y_pred_var * y_var), min=self.eps)
+        return float(r.item())

diff --git a/tests/ignite/contrib/metrics/regression/test_pearson_correlation.py b/tests/ignite/contrib/metrics/regression/test_pearson_correlation.py
new file mode 100644
index 00000000000..aaa93420799
--- /dev/null
+++ b/tests/ignite/contrib/metrics/regression/test_pearson_correlation.py
@@ -0,0 +1,244 @@
+import os
+
+import numpy as np
+import pytest
+import torch
+
+import ignite.distributed as idist
+from ignite.contrib.metrics.regression import PearsonCorrelation
+from ignite.engine import Engine
+from ignite.exceptions import NotComputableError
+
+
+def np_corr(np_y_pred: np.ndarray, np_y: np.ndarray, eps: float = 1e-8):
+    cov = np.cov(np_y_pred, np_y, ddof=0)[0, 1]
+    std_y_pred = np.std(np_y_pred, ddof=0)
+    std_y = np.std(np_y, ddof=0)
+    corr = cov / np.clip(std_y_pred * std_y, eps, None)
+    return corr
+
+
+def test_zero_sample():
+    m = PearsonCorrelation()
+    with pytest.raises(
+        NotComputableError, match=r"PearsonCorrelation must have at least one example before it can be computed"
+    ):
+        m.compute()
+
+
+def test_wrong_input_shapes():
+    m = PearsonCorrelation()
+
+    with pytest.raises(ValueError, match=r"Input data shapes should be the same, but given"):
+        m.update((torch.rand(4), torch.rand(4, 1)))
+
+    with pytest.raises(ValueError, match=r"Input data shapes should be the same, but given"):
+        m.update((torch.rand(4, 1), torch.rand(4)))
+
+
+def test_pearson_correlation():
+    a = np.random.randn(4).astype(np.float32)
+    b = np.random.randn(4).astype(np.float32)
+    c = np.random.randn(4).astype(np.float32)
+    d = np.random.randn(4).astype(np.float32)
+    ground_truth = np.random.randn(4).astype(np.float32)
+
+    m = PearsonCorrelation()
+
+    m.update((torch.from_numpy(a), torch.from_numpy(ground_truth)))
+    np_ans = np_corr(a, ground_truth)
+    assert m.compute() == pytest.approx(np_ans, rel=1e-4)
+
+    m.update((torch.from_numpy(b), torch.from_numpy(ground_truth)))
+    np_ans = np_corr(np.concatenate([a, b]), np.concatenate([ground_truth] * 2))
+    assert m.compute() == pytest.approx(np_ans, rel=1e-4)
+
+    m.update((torch.from_numpy(c), torch.from_numpy(ground_truth)))
+    np_ans = np_corr(np.concatenate([a, b, c]), np.concatenate([ground_truth] * 3))
+    assert m.compute() == pytest.approx(np_ans, rel=1e-4)
+
+    m.update((torch.from_numpy(d), torch.from_numpy(ground_truth)))
+    np_ans = np_corr(np.concatenate([a, b, c, d]), np.concatenate([ground_truth] * 4))
+    assert m.compute() == pytest.approx(np_ans, rel=1e-4)
+
+
+def test_integration():
+    def _test(y_pred, y, eps, batch_size):
+        def update_fn(engine, batch):
+            idx = (engine.state.iteration - 1) * batch_size
+            y_true_batch = np_y[idx : idx + batch_size]
+            y_pred_batch = np_y_pred[idx : idx + batch_size]
+            return torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)
+
+        engine = Engine(update_fn)
+
+        m = PearsonCorrelation(eps=eps)
+        m.attach(engine, "corr")
+
+        np_y = y.numpy().ravel()
+        np_y_pred = y_pred.numpy().ravel()
+
+        data = list(range(y_pred.shape[0] // batch_size))
+        corr = engine.run(data, max_epochs=1).metrics["corr"]
+
+        np_ans = np_corr(np_y_pred, np_y, eps=eps)
+
+        assert pytest.approx(np_ans, rel=1e-4) == corr
+
+    def get_test_cases():
+        test_cases = [
+            (torch.rand(size=(50,)).float(), torch.rand(size=(50,)).float(), 10 ** np.random.normal(-8, 0), 1),
+            (torch.rand(size=(50, 1)).float(), torch.rand(size=(50, 1)).float(), 10 ** np.random.normal(-8, 0), 10),
+        ]
+        return test_cases
+
+    for _ in range(5):
+        test_cases = get_test_cases()
+        for y_pred, y, eps, batch_size in test_cases:
+            _test(y_pred, y, eps, batch_size)
+
+
+def _test_distrib_compute(device):
+    rank = idist.get_rank()
+
+    def _test(metric_device):
+        metric_device = torch.device(metric_device)
+        m = PearsonCorrelation(device=metric_device)
+
+        y_pred = torch.rand(size=(100,), device=device)
+        y = torch.rand(size=(100,), device=device)
+
+        m.update((y_pred, y))
+
+        y_pred = idist.all_gather(y_pred)
+        y = idist.all_gather(y)
+
+        np_y = y.cpu().numpy()
+        np_y_pred = y_pred.cpu().numpy()
+
+        np_ans = np_corr(np_y_pred, np_y)
+
+        assert pytest.approx(np_ans) == m.compute()
+
+    for i in range(3):
+        torch.manual_seed(10 + rank + i)
+        _test("cpu")
+        if device.type != "xla":
+            _test(idist.device())
+
+
+def _test_distrib_integration(device, tol=1e-5):
+    rank = idist.get_rank()
+
+    def _test(n_epochs, metric_device):
+        metric_device = torch.device(metric_device)
+        n_iters = 80
+        batch_size = 16
+
+        y_true = torch.rand(size=(n_iters * batch_size,)).to(device)
+        y_preds = torch.rand(size=(n_iters * batch_size,)).to(device)
+
+        def update(engine, i):
+            return (
+                y_preds[i * batch_size : (i + 1) * batch_size],
+                y_true[i * batch_size : (i + 1) * batch_size],
+            )
+
+        engine = Engine(update)
+
+        corr = PearsonCorrelation(device=metric_device)
+        corr.attach(engine, "corr")
+
+        data = list(range(n_iters))
+        engine.run(data=data, max_epochs=n_epochs)
+
+        y_preds = idist.all_gather(y_preds)
+        y_true = idist.all_gather(y_true)
+
+        assert "corr" in engine.state.metrics
+
+        res = engine.state.metrics["corr"]
+
+        np_y = y_true.cpu().numpy()
+        np_y_pred = y_preds.cpu().numpy()
+
+        np_ans = np_corr(np_y_pred, np_y)
+
+        assert pytest.approx(np_ans, rel=tol) == res
+
+    metric_devices = ["cpu"]
+    if device.type != "xla":
+        metric_devices.append(idist.device())
+    for metric_device in metric_devices:
+        for i in range(2):
+            torch.manual_seed(12 + rank + i)
+            _test(n_epochs=1, metric_device=metric_device)
+            _test(n_epochs=2, metric_device=metric_device)
+
+
+@pytest.mark.distributed
+@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
+@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
+def test_distrib_nccl_gpu(distributed_context_single_node_nccl):
+    device = idist.device()
+    _test_distrib_compute(device)
+    _test_distrib_integration(device)
+
+
+@pytest.mark.distributed
+@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
+def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo):
+    device = idist.device()
+    _test_distrib_compute(device)
+    _test_distrib_integration(device)
+
+
+@pytest.mark.distributed
+@pytest.mark.skipif(not idist.has_hvd_support, reason="Skip if no Horovod dist support")
+@pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc")
+def test_distrib_hvd(gloo_hvd_executor):
+    device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
+    nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count()
+
+    gloo_hvd_executor(_test_distrib_compute, (device,), np=nproc, do_init=True)
+    gloo_hvd_executor(_test_distrib_integration, (device,), np=nproc, do_init=True)
+
+
+@pytest.mark.multinode_distributed
+@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
+@pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
+def test_multinode_distrib_gloo_cpu_or_gpu(distributed_context_multi_node_gloo):
+    device = idist.device()
+    _test_distrib_compute(device)
+    _test_distrib_integration(device)
+
+
+@pytest.mark.multinode_distributed
+@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
+@pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
+def test_multinode_distrib_nccl_gpu(distributed_context_multi_node_nccl):
+    device = idist.device()
+    _test_distrib_compute(device)
+
+
+@pytest.mark.tpu
+@pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars")
+@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
+def test_distrib_single_device_xla():
+    device = idist.device()
+    _test_distrib_compute(device)
+    _test_distrib_integration(device)
+
+
+def _test_distrib_xla_nprocs(index):
+    device = idist.device()
+    _test_distrib_compute(device)
+    _test_distrib_integration(device)
+
+
+@pytest.mark.tpu
+@pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars")
+@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
+def test_distrib_xla_nprocs(xmp_executor):
+    n = int(os.environ["NUM_TPU_WORKERS"])
+    xmp_executor(_test_distrib_xla_nprocs, args=(), nprocs=n)

From 1f80edec11fbd113ed6e4111ee3c812af578e115 Mon Sep 17 00:00:00 2001
From: Kazuki Adachi
Date: Wed, 20 Mar 2024 18:42:13 +0900
Subject: [PATCH 02/14] match the notation of the docstring with the other metrics

---
 ignite/contrib/metrics/regression/pearson_correlation.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/ignite/contrib/metrics/regression/pearson_correlation.py b/ignite/contrib/metrics/regression/pearson_correlation.py
index b678564132e..1ea18f33e6b 100644
--- a/ignite/contrib/metrics/regression/pearson_correlation.py
+++ b/ignite/contrib/metrics/regression/pearson_correlation.py
@@ -11,10 +11,10 @@ class PearsonCorrelation(_BaseRegression):
     r"""Calculates the `Pearson correlation coefficient <https://en.wikipedia.org/wiki/Pearson_correlation_coefficient>`_.
 
     .. math::
-        r = \frac{\sum_{i=1}^N (x_i-\bar{x})(y_i-\bar{y})}{\max (\sqrt{\sum_{i=1}^N (x_i-\bar{x})^2 \sum_{i=1}^N (y_i-\bar{y})^2}, \epsilon)},
-        \quad \bar{x}=\frac{1}{N}\sum_{i=1}^N x_i, \quad \bar{y}=\frac{1}{N}\sum_{i=1}^N y_i
+        r = \frac{\sum_{j=1}^n (P_j-\bar{P})(A_j-\bar{A})}{\max (\sqrt{\sum_{j=1}^n (P_j-\bar{P})^2 \sum_{j=1}^n (A_j-\bar{A})^2}, \epsilon)},
+        \quad \bar{P}=\frac{1}{n}\sum_{j=1}^n P_j, \quad \bar{A}=\frac{1}{n}\sum_{j=1}^n A_j
 
-    where :math:`x_i` is the ground truth and :math:`y_i` is the predicted value.
+    where :math:`A_j` is the ground truth and :math:`P_j` is the predicted value.
 
     - ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
     - `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`.
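[Aside — illustrative, not part of the patch series] The `compute()` method introduced in PATCH 01 never stores raw samples: it keeps five running sums and recovers the covariance and the two variances through the one-pass identities cov = E[xy] - E[x]E[y] and var = E[x^2] - E[x]^2, clamping the denominator by `eps`. A minimal sketch checking that streaming formulation against a two-pass NumPy reference (assumes only `torch` and `numpy`; all names here are made up for the example):

    import numpy as np
    import torch

    torch.manual_seed(0)
    y_pred = torch.rand(1000, dtype=torch.float64)
    y = torch.rand(1000, dtype=torch.float64)
    n = y.shape[0]

    # The five sufficient statistics the metric accumulates across update() calls.
    s_p, s_y = y_pred.sum(), y.sum()
    s_pp, s_yy = y_pred.square().sum(), y.square().sum()
    s_py = (y_pred * y).sum()

    # cov = E[xy] - E[x]E[y]; var = E[x^2] - E[x]^2, as in compute().
    cov = s_py / n - s_p * s_y / (n * n)
    var_p = s_pp / n - (s_p / n) ** 2
    var_y = s_yy / n - (s_y / n) ** 2
    r = cov / torch.clamp((var_p * var_y).sqrt(), min=1e-8)

    # Matches the two-pass population (ddof=0) Pearson coefficient.
    ref = np.corrcoef(y_pred.numpy(), y.numpy())[0, 1]
    assert abs(float(r) - ref) < 1e-6

Note that the E[x^2] - E[x]^2 form can dip slightly negative in floating point, which is why `compute()` clamps both variances at zero before taking the square root.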
From 69f7f5ed46f5e006c4ed56d9f51117c87070e7c4 Mon Sep 17 00:00:00 2001
From: Kazuki Adachi
Date: Sat, 23 Mar 2024 16:53:47 +0900
Subject: [PATCH 03/14] move PearsonCorrelation metric from contrib.metrics.regression to metrics.regression

---
 ignite/{contrib => }/metrics/regression/pearson_correlation.py | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 rename ignite/{contrib => }/metrics/regression/pearson_correlation.py (100%)

diff --git a/ignite/contrib/metrics/regression/pearson_correlation.py b/ignite/metrics/regression/pearson_correlation.py
similarity index 100%
rename from ignite/contrib/metrics/regression/pearson_correlation.py
rename to ignite/metrics/regression/pearson_correlation.py

From 855a5052e46a311ef66e40355cc83d912b23a8c8 Mon Sep 17 00:00:00 2001
From: Kazuki Adachi
Date: Sat, 23 Mar 2024 18:17:58 +0900
Subject: [PATCH 04/14] update test for PearsonCorrelation metric

---
 .../regression/test_pearson_correlation.py    | 179 ++++++------------
 1 file changed, 57 insertions(+), 122 deletions(-)

diff --git a/tests/ignite/contrib/metrics/regression/test_pearson_correlation.py b/tests/ignite/contrib/metrics/regression/test_pearson_correlation.py
index aaa93420799..59d8b49e537 100644
--- a/tests/ignite/contrib/metrics/regression/test_pearson_correlation.py
+++ b/tests/ignite/contrib/metrics/regression/test_pearson_correlation.py
@@ -98,147 +98,82 @@ def get_test_cases():
             _test(y_pred, y, eps, batch_size)
 
 
-def _test_distrib_compute(device):
-    rank = idist.get_rank()
+@pytest.mark.usefixtures("distributed")
+class TestDistributed:
+    def test_compute(self):
+        rank = idist.get_rank()
+        device = idist.device()
 
-    def _test(metric_device):
-        metric_device = torch.device(metric_device)
-        m = PearsonCorrelation(device=metric_device)
+        def _test(metric_device):
+            metric_device = torch.device(metric_device)
+            m = PearsonCorrelation(device=metric_device)
 
-        y_pred = torch.rand(size=(100,), device=device)
-        y = torch.rand(size=(100,), device=device)
+            y_pred = torch.rand(size=(100,), device=device)
+            y = torch.rand(size=(100,), device=device)
 
-        m.update((y_pred, y))
+            m.update((y_pred, y))
 
-        y_pred = idist.all_gather(y_pred)
-        y = idist.all_gather(y)
+            y_pred = idist.all_gather(y_pred)
+            y = idist.all_gather(y)
 
-        np_y = y.cpu().numpy()
-        np_y_pred = y_pred.cpu().numpy()
+            np_y = y.cpu().numpy()
+            np_y_pred = y_pred.cpu().numpy()
 
-        np_ans = np_corr(np_y_pred, np_y)
+            np_ans = np_corr(np_y_pred, np_y)
 
-        assert pytest.approx(np_ans) == m.compute()
+            assert pytest.approx(np_ans) == m.compute()
 
-    for i in range(3):
-        torch.manual_seed(10 + rank + i)
-        _test("cpu")
-        if device.type != "xla":
-            _test(idist.device())
-
-
-def _test_distrib_integration(device, tol=1e-5):
-    rank = idist.get_rank()
-
-    def _test(n_epochs, metric_device):
-        metric_device = torch.device(metric_device)
-        n_iters = 80
-        batch_size = 16
-
-        y_true = torch.rand(size=(n_iters * batch_size,)).to(device)
-        y_preds = torch.rand(size=(n_iters * batch_size,)).to(device)
-
-        def update(engine, i):
-            return (
-                y_preds[i * batch_size : (i + 1) * batch_size],
-                y_true[i * batch_size : (i + 1) * batch_size],
-            )
-
-        engine = Engine(update)
-
-        corr = PearsonCorrelation(device=metric_device)
-        corr.attach(engine, "corr")
-
-        data = list(range(n_iters))
-        engine.run(data=data, max_epochs=n_epochs)
-
-        y_preds = idist.all_gather(y_preds)
-        y_true = idist.all_gather(y_true)
+        for i in range(3):
+            torch.manual_seed(10 + rank + i)
+            _test("cpu")
+            if device.type != "xla":
+                _test(idist.device())
 
-        assert "corr" in engine.state.metrics
+    def test_integration(self, tol=1e-5):
+        rank = idist.get_rank()
+        device = idist.device()
 
-        res = engine.state.metrics["corr"]
+        def _test(n_epochs, metric_device):
+            metric_device = torch.device(metric_device)
+            n_iters = 80
+            batch_size = 16
 
-        np_y = y_true.cpu().numpy()
-        np_y_pred = y_preds.cpu().numpy()
+            y_true = torch.rand(size=(n_iters * batch_size,)).to(device)
+            y_preds = torch.rand(size=(n_iters * batch_size,)).to(device)
 
-        np_ans = np_corr(np_y_pred, np_y)
+            def update(engine, i):
+                return (
+                    y_preds[i * batch_size : (i + 1) * batch_size],
+                    y_true[i * batch_size : (i + 1) * batch_size],
+                )
 
-        assert pytest.approx(np_ans, rel=tol) == res
+            engine = Engine(update)
 
-    metric_devices = ["cpu"]
-    if device.type != "xla":
-        metric_devices.append(idist.device())
-    for metric_device in metric_devices:
-        for i in range(2):
-            torch.manual_seed(12 + rank + i)
-            _test(n_epochs=1, metric_device=metric_device)
-            _test(n_epochs=2, metric_device=metric_device)
+            corr = PearsonCorrelation(device=metric_device)
+            corr.attach(engine, "corr")
 
+            data = list(range(n_iters))
+            engine.run(data=data, max_epochs=n_epochs)
 
-@pytest.mark.distributed
-@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
-@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
-def test_distrib_nccl_gpu(distributed_context_single_node_nccl):
-    device = idist.device()
-    _test_distrib_compute(device)
-    _test_distrib_integration(device)
+            y_preds = idist.all_gather(y_preds)
+            y_true = idist.all_gather(y_true)
 
+            assert "corr" in engine.state.metrics
 
-@pytest.mark.distributed
-@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
-def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo):
-    device = idist.device()
-    _test_distrib_compute(device)
-    _test_distrib_integration(device)
+            res = engine.state.metrics["corr"]
 
+            np_y = y_true.cpu().numpy()
+            np_y_pred = y_preds.cpu().numpy()
 
-@pytest.mark.distributed
-@pytest.mark.skipif(not idist.has_hvd_support, reason="Skip if no Horovod dist support")
-@pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc")
-def test_distrib_hvd(gloo_hvd_executor):
-    device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
-    nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count()
+            np_ans = np_corr(np_y_pred, np_y)
 
-    gloo_hvd_executor(_test_distrib_compute, (device,), np=nproc, do_init=True)
-    gloo_hvd_executor(_test_distrib_integration, (device,), np=nproc, do_init=True)
+            assert pytest.approx(np_ans, rel=tol) == res
 
-
-@pytest.mark.multinode_distributed
-@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
-@pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
-def test_multinode_distrib_gloo_cpu_or_gpu(distributed_context_multi_node_gloo):
-    device = idist.device()
-    _test_distrib_compute(device)
-    _test_distrib_integration(device)
-
-
-@pytest.mark.multinode_distributed
-@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
-@pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
-def test_multinode_distrib_nccl_gpu(distributed_context_multi_node_nccl):
-    device = idist.device()
-    _test_distrib_compute(device)
-
-
-@pytest.mark.tpu
-@pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars")
-@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
-def test_distrib_single_device_xla():
-    device = idist.device()
-    _test_distrib_compute(device)
-    _test_distrib_integration(device)
-
-
-def _test_distrib_xla_nprocs(index):
-    device = idist.device()
-    _test_distrib_compute(device)
-    _test_distrib_integration(device)
-
-
-@pytest.mark.tpu
-@pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars")
-@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
-def test_distrib_xla_nprocs(xmp_executor):
-    n = int(os.environ["NUM_TPU_WORKERS"])
-    xmp_executor(_test_distrib_xla_nprocs, args=(), nprocs=n)
+        metric_devices = ["cpu"]
+        if device.type != "xla":
+            metric_devices.append(idist.device())
+        for metric_device in metric_devices:
+            for i in range(2):
+                torch.manual_seed(12 + rank + i)
+                _test(n_epochs=1, metric_device=metric_device)
+                _test(n_epochs=2, metric_device=metric_device)

From 6b17c36639b44286c0183c97cd6de913a1c9e798 Mon Sep 17 00:00:00 2001
From: Kazuki Adachi
Date: Sun, 24 Mar 2024 01:39:51 +0900
Subject: [PATCH 05/14] update test

---
 ignite/contrib/metrics/regression/__init__.py |   1 -
 ignite/metrics/regression/__init__.py         |   1 +
 .../regression/test_pearson_correlation.py    | 179 --------------
 .../regression/test_pearson_correlation.py    | 226 ++++++++++++++++++
 4 files changed, 227 insertions(+), 180 deletions(-)
 create mode 100644 ignite/metrics/regression/__init__.py
 delete mode 100644 tests/ignite/contrib/metrics/regression/test_pearson_correlation.py
 create mode 100644 tests/ignite/metrics/regression/test_pearson_correlation.py

diff --git a/ignite/contrib/metrics/regression/__init__.py b/ignite/contrib/metrics/regression/__init__.py
index 625099e0038..fbee310e3b2 100644
--- a/ignite/contrib/metrics/regression/__init__.py
+++ b/ignite/contrib/metrics/regression/__init__.py
@@ -11,6 +11,5 @@
 from ignite.contrib.metrics.regression.median_absolute_error import MedianAbsoluteError
 from ignite.contrib.metrics.regression.median_absolute_percentage_error import MedianAbsolutePercentageError
 from ignite.contrib.metrics.regression.median_relative_absolute_error import MedianRelativeAbsoluteError
-from ignite.contrib.metrics.regression.pearson_correlation import PearsonCorrelation
 from ignite.contrib.metrics.regression.r2_score import R2Score
 from ignite.contrib.metrics.regression.wave_hedges_distance import WaveHedgesDistance

diff --git a/ignite/metrics/regression/__init__.py b/ignite/metrics/regression/__init__.py
new file mode 100644
index 00000000000..50452d28dc4
--- /dev/null
+++ b/ignite/metrics/regression/__init__.py
@@ -0,0 +1 @@
+from ignite.metrics.regression.pearson_correlation import PearsonCorrelation

diff --git a/tests/ignite/contrib/metrics/regression/test_pearson_correlation.py b/tests/ignite/contrib/metrics/regression/test_pearson_correlation.py
deleted file mode 100644
index 59d8b49e537..00000000000
--- a/tests/ignite/contrib/metrics/regression/test_pearson_correlation.py
+++ /dev/null
@@ -1,179 +0,0 @@
-import os
-
-import numpy as np
-import pytest
-import torch
-
-import ignite.distributed as idist
-from ignite.contrib.metrics.regression import PearsonCorrelation
-from ignite.engine import Engine
-from ignite.exceptions import NotComputableError
-
-
-def np_corr(np_y_pred: np.ndarray, np_y: np.ndarray, eps: float = 1e-8):
-    cov = np.cov(np_y_pred, np_y, ddof=0)[0, 1]
-    std_y_pred = np.std(np_y_pred, ddof=0)
-    std_y = np.std(np_y, ddof=0)
-    corr = cov / np.clip(std_y_pred * std_y, eps, None)
-    return corr
-
-
-def test_zero_sample():
-    m = PearsonCorrelation()
-    with pytest.raises(
-        NotComputableError, match=r"PearsonCorrelation must have at least one example before it can be computed"
-    ):
-        m.compute()
-
-
-def test_wrong_input_shapes():
-    m = PearsonCorrelation()
-
-    with pytest.raises(ValueError, match=r"Input data shapes should be the same, but given"):
-        m.update((torch.rand(4), torch.rand(4, 1)))
-
-    with pytest.raises(ValueError, match=r"Input data shapes should be the same, but given"):
-        m.update((torch.rand(4, 1), torch.rand(4)))
-
-
-def test_pearson_correlation():
-    a = np.random.randn(4).astype(np.float32)
-    b = np.random.randn(4).astype(np.float32)
-    c = np.random.randn(4).astype(np.float32)
-    d = np.random.randn(4).astype(np.float32)
-    ground_truth = np.random.randn(4).astype(np.float32)
-
-    m = PearsonCorrelation()
-
-    m.update((torch.from_numpy(a), torch.from_numpy(ground_truth)))
-    np_ans = np_corr(a, ground_truth)
-    assert m.compute() == pytest.approx(np_ans, rel=1e-4)
-
-    m.update((torch.from_numpy(b), torch.from_numpy(ground_truth)))
-    np_ans = np_corr(np.concatenate([a, b]), np.concatenate([ground_truth] * 2))
-    assert m.compute() == pytest.approx(np_ans, rel=1e-4)
-
-    m.update((torch.from_numpy(c), torch.from_numpy(ground_truth)))
-    np_ans = np_corr(np.concatenate([a, b, c]), np.concatenate([ground_truth] * 3))
-    assert m.compute() == pytest.approx(np_ans, rel=1e-4)
-
-    m.update((torch.from_numpy(d), torch.from_numpy(ground_truth)))
-    np_ans = np_corr(np.concatenate([a, b, c, d]), np.concatenate([ground_truth] * 4))
-    assert m.compute() == pytest.approx(np_ans, rel=1e-4)
-
-
-def test_integration():
-    def _test(y_pred, y, eps, batch_size):
-        def update_fn(engine, batch):
-            idx = (engine.state.iteration - 1) * batch_size
-            y_true_batch = np_y[idx : idx + batch_size]
-            y_pred_batch = np_y_pred[idx : idx + batch_size]
-            return torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)
-
-        engine = Engine(update_fn)
-
-        m = PearsonCorrelation(eps=eps)
-        m.attach(engine, "corr")
-
-        np_y = y.numpy().ravel()
-        np_y_pred = y_pred.numpy().ravel()
-
-        data = list(range(y_pred.shape[0] // batch_size))
-        corr = engine.run(data, max_epochs=1).metrics["corr"]
-
-        np_ans = np_corr(np_y_pred, np_y, eps=eps)
-
-        assert pytest.approx(np_ans, rel=1e-4) == corr
-
-    def get_test_cases():
-        test_cases = [
-            (torch.rand(size=(50,)).float(), torch.rand(size=(50,)).float(), 10 ** np.random.normal(-8, 0), 1),
-            (torch.rand(size=(50, 1)).float(), torch.rand(size=(50, 1)).float(), 10 ** np.random.normal(-8, 0), 10),
-        ]
-        return test_cases
-
-    for _ in range(5):
-        test_cases = get_test_cases()
-        for y_pred, y, eps, batch_size in test_cases:
-            _test(y_pred, y, eps, batch_size)
-
-
-@pytest.mark.usefixtures("distributed")
-class TestDistributed:
-    def test_compute(self):
-        rank = idist.get_rank()
-        device = idist.device()
-
-        def _test(metric_device):
-            metric_device = torch.device(metric_device)
-            m = PearsonCorrelation(device=metric_device)
-
-            y_pred = torch.rand(size=(100,), device=device)
-            y = torch.rand(size=(100,), device=device)
-
-            m.update((y_pred, y))
-
-            y_pred = idist.all_gather(y_pred)
-            y = idist.all_gather(y)
-
-            np_y = y.cpu().numpy()
-            np_y_pred = y_pred.cpu().numpy()
-
-            np_ans = np_corr(np_y_pred, np_y)
-
-            assert pytest.approx(np_ans) == m.compute()
-
-        for i in range(3):
-            torch.manual_seed(10 + rank + i)
-            _test("cpu")
-            if device.type != "xla":
-                _test(idist.device())
-
-    def test_integration(self, tol=1e-5):
-        rank = idist.get_rank()
-        device = idist.device()
-
-        def _test(n_epochs, metric_device):
-            metric_device = torch.device(metric_device)
-            n_iters = 80
-            batch_size = 16
-
-            y_true = torch.rand(size=(n_iters * batch_size,)).to(device)
-            y_preds = torch.rand(size=(n_iters * batch_size,)).to(device)
-
-            def update(engine, i):
-                return (
-                    y_preds[i * batch_size : (i + 1) * batch_size],
-                    y_true[i * batch_size : (i + 1) * batch_size],
-                )
-
-            engine = Engine(update)
-
-            corr = PearsonCorrelation(device=metric_device)
-            corr.attach(engine, "corr")
-
-            data = list(range(n_iters))
-            engine.run(data=data, max_epochs=n_epochs)
-
-            y_preds = idist.all_gather(y_preds)
-            y_true = idist.all_gather(y_true)
-
-            assert "corr" in engine.state.metrics
-
-            res = engine.state.metrics["corr"]
-
-            np_y = y_true.cpu().numpy()
-            np_y_pred = y_preds.cpu().numpy()
-
-            np_ans = np_corr(np_y_pred, np_y)
-
-            assert pytest.approx(np_ans, rel=tol) == res
-
-        metric_devices = ["cpu"]
-        if device.type != "xla":
-            metric_devices.append(idist.device())
-        for metric_device in metric_devices:
-            for i in range(2):
-                torch.manual_seed(12 + rank + i)
-                _test(n_epochs=1, metric_device=metric_device)
-                _test(n_epochs=2, metric_device=metric_device)

diff --git a/tests/ignite/metrics/regression/test_pearson_correlation.py b/tests/ignite/metrics/regression/test_pearson_correlation.py
new file mode 100644
index 00000000000..93c45ef4d10
--- /dev/null
+++ b/tests/ignite/metrics/regression/test_pearson_correlation.py
@@ -0,0 +1,226 @@
+from typing import Tuple
+
+import numpy as np
+import pytest
+import torch
+from torch import Tensor
+
+import ignite.distributed as idist
+from ignite.engine import Engine
+from ignite.exceptions import NotComputableError
+from ignite.metrics.regression import PearsonCorrelation
+
+
+def np_corr(np_y_pred: np.ndarray, np_y: np.ndarray, eps: float = 1e-8):
+    cov = np.cov(np_y_pred, np_y, ddof=0)[0, 1]
+    std_y_pred = np.std(np_y_pred, ddof=0)
+    std_y = np.std(np_y, ddof=0)
+    corr = cov / np.clip(std_y_pred * std_y, eps, None)
+    return corr
+
+
+def test_zero_sample():
+    m = PearsonCorrelation()
+    with pytest.raises(
+        NotComputableError, match=r"PearsonCorrelation must have at least one example before it can be computed"
+    ):
+        m.compute()
+
+
+def test_wrong_input_shapes():
+    m = PearsonCorrelation()
+
+    with pytest.raises(ValueError, match=r"Input data shapes should be the same, but given"):
+        m.update((torch.rand(4), torch.rand(4, 1)))
+
+    with pytest.raises(ValueError, match=r"Input data shapes should be the same, but given"):
+        m.update((torch.rand(4, 1), torch.rand(4)))
+
+
+def test_pearson_correlation():
+    a = np.random.randn(4).astype(np.float32)
+    b = np.random.randn(4).astype(np.float32)
+    c = np.random.randn(4).astype(np.float32)
+    d = np.random.randn(4).astype(np.float32)
+    ground_truth = np.random.randn(4).astype(np.float32)
+
+    m = PearsonCorrelation()
+
+    m.update((torch.from_numpy(a), torch.from_numpy(ground_truth)))
+    np_ans = np_corr(a, ground_truth)
+    assert m.compute() == pytest.approx(np_ans, rel=1e-4)
+
+    m.update((torch.from_numpy(b), torch.from_numpy(ground_truth)))
+    np_ans = np_corr(np.concatenate([a, b]), np.concatenate([ground_truth] * 2))
+    assert m.compute() == pytest.approx(np_ans, rel=1e-4)
+
+    m.update((torch.from_numpy(c), torch.from_numpy(ground_truth)))
+    np_ans = np_corr(np.concatenate([a, b, c]), np.concatenate([ground_truth] * 3))
+    assert m.compute() == pytest.approx(np_ans, rel=1e-4)
+
+    m.update((torch.from_numpy(d), torch.from_numpy(ground_truth)))
+    np_ans = np_corr(np.concatenate([a, b, c, d]), np.concatenate([ground_truth] * 4))
+    assert m.compute() == pytest.approx(np_ans, rel=1e-4)
+
+
+@pytest.fixture(params=list(range(2)))
+def test_case(request):
+    return [
+        (torch.rand(size=(50,)).float(), torch.rand(size=(50,)).float(), 10 ** np.random.uniform(-8, 0), 1),
+        (torch.rand(size=(50, 1)).float(), torch.rand(size=(50, 1)).float(), 10 ** np.random.uniform(-8, 0), 10),
+    ][request.param]
+
+
+@pytest.mark.parametrize("n_times", range(5))
+def test_integration(n_times, test_case: Tuple[Tensor, Tensor, float, int]):
+    y_pred, y, eps, batch_size = test_case
+
+    def update_fn(engine: Engine, batch):
+        idx = (engine.state.iteration - 1) * batch_size
+        y_true_batch = np_y[idx : idx + batch_size]
+        y_pred_batch = np_y_pred[idx : idx + batch_size]
+        return torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)
+
+    engine = Engine(update_fn)
+
+    m = PearsonCorrelation(eps=eps)
+    m.attach(engine, "corr")
+
+    np_y = y.ravel().numpy()
+    np_y_pred = y_pred.ravel().numpy()
+
+    data = list(range(y_pred.shape[0] // batch_size))
+    corr = engine.run(data, max_epochs=1).metrics["corr"]
+
+    np_ans = np_corr(np_y_pred, np_y, eps=eps)
+
+    assert pytest.approx(np_ans, rel=2e-4) == corr
+
+
+def test_accumulator_detached():
+    corr = PearsonCorrelation()
+
+    y_pred = torch.tensor([2.0, 3.0], requires_grad=True)
+    y = torch.tensor([-2.0, -1.0])
+    corr.update((y_pred, y))
+
+    assert all(
+        (not accumulator.requires_grad)
+        for accumulator in (
+            corr._sum_of_products,
+            corr._sum_of_y_pred_squares,
+            corr._sum_of_y_preds,
+            corr._sum_of_y_squares,
+            corr._sum_of_ys,
+        )
+    )
+
+
+@pytest.mark.usefixtures("distributed")
+class TestDistributed:
+    def test_compute(self):
+        rank = idist.get_rank()
+        device = idist.device()
+        metric_devices = [torch.device("cpu")]
+        if device.type != "xla":
+            metric_devices.append(device)
+
+        for i in range(3):
+            torch.manual_seed(10 + rank + i)
+            for metric_device in metric_devices:
+                m = PearsonCorrelation(device=metric_device)
+
+                y_pred = torch.rand(size=[100], device=device)
+                y = torch.rand(size=[100], device=device)
+
+                m.update((y_pred, y))
+
+                y_pred = idist.all_gather(y_pred)
+                y = idist.all_gather(y)
+
+                np_y = y.cpu().numpy()
+                np_y_pred = y_pred.cpu().numpy()
+
+                np_ans = np_corr(np_y_pred, np_y)
+
+                assert pytest.approx(np_ans) == m.compute()
+
+    @pytest.mark.parametrize("n_epochs", [1, 2])
+    def test_integration(self, n_epochs: int):
+        tol = 2e-4
+        rank = idist.get_rank()
+        device = idist.device()
+        metric_devices = [torch.device("cpu")]
+        if device.type != "xla":
+            metric_devices.append(device)
+
+        n_iters = 80
+        batch_size = 16
+
+        for metric_device in metric_devices:
+            for i in range(2):
+                torch.manual_seed(12 + rank + i)
+
+                y_true = torch.rand(size=(n_iters * batch_size,)).to(device)
+                y_preds = torch.rand(size=(n_iters * batch_size,)).to(device)
+
+                engine = Engine(
+                    lambda e, i: (
+                        y_preds[i * batch_size : (i + 1) * batch_size],
+                        y_true[i * batch_size : (i + 1) * batch_size],
+                    )
+                )
+
+                corr = PearsonCorrelation(device=metric_device)
+                corr.attach(engine, "corr")
+
+                data = list(range(n_iters))
+                engine.run(data=data, max_epochs=n_epochs)
+
+                y_preds = idist.all_gather(y_preds)
+                y_true = idist.all_gather(y_true)
+
+                assert "corr" in engine.state.metrics
+
+                res = engine.state.metrics["corr"]
+
+                np_y = y_true.cpu().numpy()
+                np_y_pred = y_preds.cpu().numpy()
+
+                np_ans = np_corr(np_y_pred, np_y)
+
+                assert pytest.approx(np_ans, rel=tol) == res
+
+    def test_accumulator_device(self):
+        device = idist.device()
+        metric_devices = [torch.device("cpu")]
+        if device.type != "xla":
+            metric_devices.append(device)
+        for metric_device in metric_devices:
+            corr = PearsonCorrelation(device=metric_device)
+
+            devices = (
+                corr._device,
+                corr._sum_of_products.device,
+                corr._sum_of_y_pred_squares.device,
+                corr._sum_of_y_preds.device,
+                corr._sum_of_y_squares.device,
+                corr._sum_of_ys.device,
+            )
+            for dev in devices:
+                assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"
+
+            y_pred = torch.tensor([2.0, 3.0])
+            y = torch.tensor([-1.0, 1.0])
+            corr.update((y_pred, y))
+
+            devices = (
+                corr._device,
+                corr._sum_of_products.device,
+                corr._sum_of_y_pred_squares.device,
+                corr._sum_of_y_preds.device,
+                corr._sum_of_y_squares.device,
+                corr._sum_of_ys.device,
+            )
+            for dev in devices:
+                assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"

From b931ec9030b0f34309bfa8b273cde09bf64d97b3 Mon Sep 17 00:00:00 2001
From: Kazuki Adachi
Date: Fri, 29 Mar 2024 00:39:50 +0900
Subject: [PATCH 06/14] modify doc for PearsonCorrelation metric

---
 docs/source/metrics.rst | 1 +
 1 file changed, 1 insertion(+)

diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst
index ee660e99d20..0696cc3070a 100644
--- a/docs/source/metrics.rst
+++ b/docs/source/metrics.rst
@@ -371,6 +371,7 @@ Complete list of metrics
     regression.MedianAbsoluteError
     regression.MedianAbsolutePercentageError
     regression.MedianRelativeAbsoluteError
+    regression.PearsonCorrelation
     regression.R2Score
     regression.WaveHedgesDistance

From 66099258ff788325e95c7861001aaf1127d2a1d9 Mon Sep 17 00:00:00 2001
From: Kazuki Adachi
Date: Fri, 29 Mar 2024 00:40:34 +0900
Subject: [PATCH 07/14] fix import

---
 ignite/metrics/regression/__init__.py            | 1 +
 ignite/metrics/regression/pearson_correlation.py | 3 ++-
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/ignite/metrics/regression/__init__.py b/ignite/metrics/regression/__init__.py
index 7a3fc3e56a9..7be1f18d0f3 100644
--- a/ignite/metrics/regression/__init__.py
+++ b/ignite/metrics/regression/__init__.py
@@ -11,5 +11,6 @@
 from ignite.metrics.regression.median_absolute_error import MedianAbsoluteError
 from ignite.metrics.regression.median_absolute_percentage_error import MedianAbsolutePercentageError
 from ignite.metrics.regression.median_relative_absolute_error import MedianRelativeAbsoluteError
+from ignite.metrics.regression.pearson_correlation import PearsonCorrelation
 from ignite.metrics.regression.r2_score import R2Score
 from ignite.metrics.regression.wave_hedges_distance import WaveHedgesDistance

diff --git a/ignite/metrics/regression/pearson_correlation.py b/ignite/metrics/regression/pearson_correlation.py
index 1ea18f33e6b..aa3149effa7 100644
--- a/ignite/metrics/regression/pearson_correlation.py
+++ b/ignite/metrics/regression/pearson_correlation.py
@@ -2,10 +2,11 @@
 
 import torch
 
-from ignite.contrib.metrics.regression._base import _BaseRegression
 from ignite.exceptions import NotComputableError
 from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce
 
+from ignite.metrics.regression._base import _BaseRegression
+
 
 class PearsonCorrelation(_BaseRegression):
     r"""Calculates the `Pearson correlation coefficient <https://en.wikipedia.org/wiki/Pearson_correlation_coefficient>`_.
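[Aside — illustrative, not part of the patch series] One consequence of the `eps` clamp worth noting here: for degenerate inputs where one side has zero variance, the denominator becomes `eps` rather than zero, so the metric returns 0.0 instead of NaN. A small sketch, assuming an ignite build that already contains this metric under the post-PATCH 03/07 import path:

    import torch
    from ignite.metrics.regression import PearsonCorrelation

    # A constant prediction has zero variance, so the covariance and the
    # denominator are both zero; the eps clamp turns the 0/0 into 0.0, not NaN.
    m = PearsonCorrelation(eps=1e-8)
    m.update((torch.ones(10), torch.arange(10, dtype=torch.float32)))
    print(m.compute())  # 0.0

PATCH 12 below adds `test_degenerated_sample` to pin down exactly this behaviour.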
From 90697797f38279cbe85032213cd5dfe47f816ada Mon Sep 17 00:00:00 2001
From: Kazuki Adachi
Date: Fri, 29 Mar 2024 19:48:53 +0900
Subject: [PATCH 08/14] resolve code formatting issue

---
 ignite/metrics/regression/pearson_correlation.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/ignite/metrics/regression/pearson_correlation.py b/ignite/metrics/regression/pearson_correlation.py
index aa3149effa7..5018f70e30a 100644
--- a/ignite/metrics/regression/pearson_correlation.py
+++ b/ignite/metrics/regression/pearson_correlation.py
@@ -9,10 +9,12 @@
 
 
 class PearsonCorrelation(_BaseRegression):
-    r"""Calculates the `Pearson correlation coefficient <https://en.wikipedia.org/wiki/Pearson_correlation_coefficient>`_.
+    r"""Calculates the `Pearson correlation coefficient
+    <https://en.wikipedia.org/wiki/Pearson_correlation_coefficient>`_.
 
     .. math::
-        r = \frac{\sum_{j=1}^n (P_j-\bar{P})(A_j-\bar{A})}{\max (\sqrt{\sum_{j=1}^n (P_j-\bar{P})^2 \sum_{j=1}^n (A_j-\bar{A})^2}, \epsilon)},
+        r = \frac{\sum_{j=1}^n (P_j-\bar{P})(A_j-\bar{A})}
+        {\max (\sqrt{\sum_{j=1}^n (P_j-\bar{P})^2 \sum_{j=1}^n (A_j-\bar{A})^2}, \epsilon)},
         \quad \bar{P}=\frac{1}{n}\sum_{j=1}^n P_j, \quad \bar{A}=\frac{1}{n}\sum_{j=1}^n A_j
 
     where :math:`A_j` is the ground truth and :math:`P_j` is the predicted value.

From 4a0096d3b0c865da79d3a9fe3c2cbfb16b43c0b4 Mon Sep 17 00:00:00 2001
From: Kazuki Adachi
Date: Fri, 29 Mar 2024 19:58:48 +0900
Subject: [PATCH 09/14] remove loop from test

---
 .../regression/test_pearson_correlation.py    | 66 +++++++++----------
 1 file changed, 32 insertions(+), 34 deletions(-)

diff --git a/tests/ignite/metrics/regression/test_pearson_correlation.py b/tests/ignite/metrics/regression/test_pearson_correlation.py
index 93c45ef4d10..c08f8a6ff6a 100644
--- a/tests/ignite/metrics/regression/test_pearson_correlation.py
+++ b/tests/ignite/metrics/regression/test_pearson_correlation.py
@@ -125,25 +125,24 @@ class TestDistributed:
         if device.type != "xla":
             metric_devices.append(device)
 
-        for i in range(3):
-            torch.manual_seed(10 + rank + i)
-            for metric_device in metric_devices:
-                m = PearsonCorrelation(device=metric_device)
+        torch.manual_seed(10 + rank)
+        for metric_device in metric_devices:
+            m = PearsonCorrelation(device=metric_device)
 
-                y_pred = torch.rand(size=[100], device=device)
-                y = torch.rand(size=[100], device=device)
+            y_pred = torch.rand(size=[100], device=device)
+            y = torch.rand(size=[100], device=device)
 
-                m.update((y_pred, y))
+            m.update((y_pred, y))
 
-                y_pred = idist.all_gather(y_pred)
-                y = idist.all_gather(y)
+            y_pred = idist.all_gather(y_pred)
+            y = idist.all_gather(y)
 
-                np_y = y.cpu().numpy()
-                np_y_pred = y_pred.cpu().numpy()
+            np_y = y.cpu().numpy()
+            np_y_pred = y_pred.cpu().numpy()
 
-                np_ans = np_corr(np_y_pred, np_y)
+            np_ans = np_corr(np_y_pred, np_y)
 
-                assert pytest.approx(np_ans) == m.compute()
+            assert pytest.approx(np_ans) == m.compute()
 
     @pytest.mark.parametrize("n_epochs", [1, 2])
     def test_integration(self, n_epochs: int):
@@ -158,38 +157,37 @@ def test_integration(self, n_epochs: int):
         batch_size = 16
 
         for metric_device in metric_devices:
-            for i in range(2):
-                torch.manual_seed(12 + rank + i)
+            torch.manual_seed(12 + rank)
 
-                y_true = torch.rand(size=(n_iters * batch_size,)).to(device)
-                y_preds = torch.rand(size=(n_iters * batch_size,)).to(device)
+            y_true = torch.rand(size=(n_iters * batch_size,)).to(device)
+            y_preds = torch.rand(size=(n_iters * batch_size,)).to(device)
 
-                engine = Engine(
-                    lambda e, i: (
-                        y_preds[i * batch_size : (i + 1) * batch_size],
-                        y_true[i * batch_size : (i + 1) * batch_size],
-                    )
-                )
+            engine = Engine(
+                lambda e, i: (
+                    y_preds[i * batch_size : (i + 1) * batch_size],
+                    y_true[i * batch_size : (i + 1) * batch_size],
+                )
+            )
 
-                corr = PearsonCorrelation(device=metric_device)
-                corr.attach(engine, "corr")
+            corr = PearsonCorrelation(device=metric_device)
+            corr.attach(engine, "corr")
 
-                data = list(range(n_iters))
-                engine.run(data=data, max_epochs=n_epochs)
+            data = list(range(n_iters))
+            engine.run(data=data, max_epochs=n_epochs)
 
-                y_preds = idist.all_gather(y_preds)
-                y_true = idist.all_gather(y_true)
+            y_preds = idist.all_gather(y_preds)
+            y_true = idist.all_gather(y_true)
 
-                assert "corr" in engine.state.metrics
+            assert "corr" in engine.state.metrics
 
-                res = engine.state.metrics["corr"]
+            res = engine.state.metrics["corr"]
 
-                np_y = y_true.cpu().numpy()
-                np_y_pred = y_preds.cpu().numpy()
+            np_y = y_true.cpu().numpy()
+            np_y_pred = y_preds.cpu().numpy()
 
-                np_ans = np_corr(np_y_pred, np_y)
+            np_ans = np_corr(np_y_pred, np_y)
 
-                assert pytest.approx(np_ans, rel=tol) == res
+            assert pytest.approx(np_ans, rel=tol) == res
 
     def test_accumulator_device(self):
         device = idist.device()

From ad1e0901ab0f14eb021e563c974a0311d60ad6d2 Mon Sep 17 00:00:00 2001
From: Kazuki Adachi
Date: Sat, 30 Mar 2024 00:43:08 +0900
Subject: [PATCH 10/14] Update ignite/metrics/regression/pearson_correlation.py

Co-authored-by: vfdev
---
 ignite/metrics/regression/pearson_correlation.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/ignite/metrics/regression/pearson_correlation.py b/ignite/metrics/regression/pearson_correlation.py
index 5018f70e30a..16058b48584 100644
--- a/ignite/metrics/regression/pearson_correlation.py
+++ b/ignite/metrics/regression/pearson_correlation.py
@@ -9,8 +9,8 @@
 
 
 class PearsonCorrelation(_BaseRegression):
-    r"""Calculates the `Pearson correlation coefficient
-    <https://en.wikipedia.org/wiki/Pearson_correlation_coefficient>`_.
+    r"""Calculates the
+    `Pearson correlation coefficient <https://en.wikipedia.org/wiki/Pearson_correlation_coefficient>`_.
 
     .. math::
         r = \frac{\sum_{j=1}^n (P_j-\bar{P})(A_j-\bar{A})}

From 4d6b18e18b5aeb186df791cf4468beec79ed6ca8 Mon Sep 17 00:00:00 2001
From: vfdev
Date: Fri, 29 Mar 2024 16:50:06 +0100
Subject: [PATCH 11/14] Update pearson_correlation.py

---
 ignite/metrics/regression/pearson_correlation.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ignite/metrics/regression/pearson_correlation.py b/ignite/metrics/regression/pearson_correlation.py
index 16058b48584..f91da83e8b6 100644
--- a/ignite/metrics/regression/pearson_correlation.py
+++ b/ignite/metrics/regression/pearson_correlation.py
@@ -9,7 +9,7 @@
 
 
 class PearsonCorrelation(_BaseRegression):
-    r"""Calculates the 
+    r"""Calculates the
     `Pearson correlation coefficient <https://en.wikipedia.org/wiki/Pearson_correlation_coefficient>`_.
 
     .. math::

From ad41bbcc06d236a6ab729b175e9faed8072cdbec Mon Sep 17 00:00:00 2001
From: Kazuki Adachi
Date: Sat, 30 Mar 2024 01:18:41 +0900
Subject: [PATCH 12/14] update test for PearsonCorrelation

---
 .../regression/test_pearson_correlation.py    | 60 +++++++++++++++----
 1 file changed, 47 insertions(+), 13 deletions(-)

diff --git a/tests/ignite/metrics/regression/test_pearson_correlation.py b/tests/ignite/metrics/regression/test_pearson_correlation.py
index c08f8a6ff6a..b6e6d2173fb 100644
--- a/tests/ignite/metrics/regression/test_pearson_correlation.py
+++ b/tests/ignite/metrics/regression/test_pearson_correlation.py
@@ -3,6 +3,7 @@
 import numpy as np
 import pytest
 import torch
+from scipy.stats import pearsonr
 from torch import Tensor
 
 import ignite.distributed as idist
@@ -11,13 +12,18 @@
 from ignite.metrics.regression import PearsonCorrelation
 
 
-def np_corr(np_y_pred: np.ndarray, np_y: np.ndarray, eps: float = 1e-8):
+def np_corr_eps(np_y_pred: np.ndarray, np_y: np.ndarray, eps: float = 1e-8):
     cov = np.cov(np_y_pred, np_y, ddof=0)[0, 1]
     std_y_pred = np.std(np_y_pred, ddof=0)
     std_y = np.std(np_y, ddof=0)
     corr = cov / np.clip(std_y_pred * std_y, eps, None)
     return corr
 
 
+def scipy_corr(np_y_pred: np.ndarray, np_y: np.ndarray):
+    corr = pearsonr(np_y_pred, np_y)
+    return corr.statistic
+
+
 def test_zero_sample():
     m = PearsonCorrelation()
     with pytest.raises(
@@ -37,6 +43,30 @@ def test_wrong_input_shapes():
         m.update((torch.rand(4, 1), torch.rand(4)))
 
 
+def test_degenerated_sample():
+    # one sample
+    m = PearsonCorrelation()
+    y_pred = torch.tensor([1.0])
+    y = torch.tensor([1.0])
+    m.update((y_pred, y))
+
+    np_y_pred = y_pred.numpy()
+    np_y = y.numpy()
+    np_res = np_corr_eps(np_y_pred, np_y)
+    assert pytest.approx(np_res) == m.compute()
+
+    # constant samples
+    m.reset()
+    y_pred = torch.ones(10).float()
+    y = torch.zeros(10).float()
+    m.update((y_pred, y))
+
+    np_y_pred = y_pred.numpy()
+    np_y = y.numpy()
+    np_res = np_corr_eps(np_y_pred, np_y)
+    assert pytest.approx(np_res) == m.compute()
+
+
 def test_pearson_correlation():
     a = np.random.randn(4).astype(np.float32)
     b = np.random.randn(4).astype(np.float32)
@@ -47,19 +77,19 @@ def test_pearson_correlation():
     m = PearsonCorrelation()
 
     m.update((torch.from_numpy(a), torch.from_numpy(ground_truth)))
-    np_ans = np_corr(a, ground_truth)
+    np_ans = scipy_corr(a, ground_truth)
     assert m.compute() == pytest.approx(np_ans, rel=1e-4)
 
     m.update((torch.from_numpy(b), torch.from_numpy(ground_truth)))
-    np_ans = np_corr(np.concatenate([a, b]), np.concatenate([ground_truth] * 2))
+    np_ans = scipy_corr(np.concatenate([a, b]), np.concatenate([ground_truth] * 2))
     assert m.compute() == pytest.approx(np_ans, rel=1e-4)
 
     m.update((torch.from_numpy(c), torch.from_numpy(ground_truth)))
-    np_ans = np_corr(np.concatenate([a, b, c]), np.concatenate([ground_truth] * 3))
+    np_ans = scipy_corr(np.concatenate([a, b, c]), np.concatenate([ground_truth] * 3))
     assert m.compute() == pytest.approx(np_ans, rel=1e-4)
 
     m.update((torch.from_numpy(d), torch.from_numpy(ground_truth)))
-    np_ans = np_corr(np.concatenate([a, b, c, d]), np.concatenate([ground_truth] * 4))
+    np_ans = scipy_corr(np.concatenate([a, b, c, d]), np.concatenate([ground_truth] * 4))
     assert m.compute() == pytest.approx(np_ans, rel=1e-4)
 
 
 @pytest.fixture(params=list(range(2)))
 def test_case(request):
+    # correlated sample
+    x = torch.randn(size=[50]).float()
+    y = x + torch.randn_like(x) * 0.1
+
     return [
-        (torch.rand(size=(50,)).float(), torch.rand(size=(50,)).float(), 10 ** np.random.uniform(-8, 0), 1),
-        (torch.rand(size=(50, 1)).float(), torch.rand(size=(50, 1)).float(), 10 ** np.random.uniform(-8, 0), 10),
+        (x, y, 1),
+        (torch.rand(size=(50, 1)).float(), torch.rand(size=(50, 1)).float(), 10),
     ][request.param]
 
 
 @pytest.mark.parametrize("n_times", range(5))
-def test_integration(n_times, test_case: Tuple[Tensor, Tensor, float, int]):
-    y_pred, y, eps, batch_size = test_case
+def test_integration(n_times, test_case: Tuple[Tensor, Tensor, int]):
+    y_pred, y, batch_size = test_case
 
     def update_fn(engine: Engine, batch):
         idx = (engine.state.iteration - 1) * batch_size
@@ -117,7 +117,7 @@ def update_fn(engine: Engine, batch):
 
     engine = Engine(update_fn)
 
-    m = PearsonCorrelation(eps=eps)
+    m = PearsonCorrelation()
     m.attach(engine, "corr")
 
     np_y = y.ravel().numpy()
@@ -126,7 +126,7 @@ def update_fn(engine: Engine, batch):
     data = list(range(y_pred.shape[0] // batch_size))
     corr = engine.run(data, max_epochs=1).metrics["corr"]
 
-    np_ans = np_corr(np_y_pred, np_y, eps=eps)
+    np_ans = scipy_corr(np_y_pred, np_y)
 
     assert pytest.approx(np_ans, rel=2e-4) == corr
 
@@ -174,7 +174,7 @@ def test_compute(self):
             np_y = y.cpu().numpy()
             np_y_pred = y_pred.cpu().numpy()
 
-            np_ans = np_corr(np_y_pred, np_y)
+            np_ans = scipy_corr(np_y_pred, np_y)
 
             assert pytest.approx(np_ans) == m.compute()
 
@@ -219,7 +219,7 @@ def test_integration(self, n_epochs: int):
             np_y = y_true.cpu().numpy()
             np_y_pred = y_preds.cpu().numpy()
 
-            np_ans = np_corr(np_y_pred, np_y)
+            np_ans = scipy_corr(np_y_pred, np_y)
 
             assert pytest.approx(np_ans, rel=tol) == res
 

From 30f168411022a859243c7ce42113618701ba1aa1 Mon Sep 17 00:00:00 2001
From: vfdev
Date: Fri, 29 Mar 2024 18:31:23 +0100
Subject: [PATCH 13/14] Update tests/ignite/metrics/regression/test_pearson_correlation.py

---
 tests/ignite/metrics/regression/test_pearson_correlation.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/ignite/metrics/regression/test_pearson_correlation.py b/tests/ignite/metrics/regression/test_pearson_correlation.py
index b6e6d2173fb..1c89e0bbec0 100644
--- a/tests/ignite/metrics/regression/test_pearson_correlation.py
+++ b/tests/ignite/metrics/regression/test_pearson_correlation.py
@@ -176,7 +176,7 @@ def test_compute(self):
 
             np_ans = scipy_corr(np_y_pred, np_y)
 
-            assert pytest.approx(np_ans) == m.compute()
+            assert pytest.approx(np_ans, rel=1e-6) == m.compute()
 
     @pytest.mark.parametrize("n_epochs", [1, 2])
     def test_integration(self, n_epochs: int):

From c6b6d9299857310287328cfb67c4c7aa3f19d2d2 Mon Sep 17 00:00:00 2001
From: Kazuki Adachi
Date: Sat, 30 Mar 2024 17:25:02 +0900
Subject: [PATCH 14/14] relax pytest.approx

---
 tests/ignite/metrics/regression/test_pearson_correlation.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/ignite/metrics/regression/test_pearson_correlation.py b/tests/ignite/metrics/regression/test_pearson_correlation.py
index 1c89e0bbec0..1a330b3a67e 100644
--- a/tests/ignite/metrics/regression/test_pearson_correlation.py
+++ b/tests/ignite/metrics/regression/test_pearson_correlation.py
@@ -176,7 +176,7 @@ def test_compute(self):
 
             np_ans = scipy_corr(np_y_pred, np_y)
 
-            assert pytest.approx(np_ans, rel=1e-6) == m.compute()
+            assert pytest.approx(np_ans, rel=2e-4) == m.compute()
 
     @pytest.mark.parametrize("n_epochs", [1, 2])
     def test_integration(self, n_epochs: int):
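[Aside — illustrative, not part of the patch series] As a final sanity check, the docstring's example can be reproduced with SciPy's two-pass implementation, which the updated tests use as their reference (assumes `scipy>=1.9`, where `pearsonr` returns a result object with a `statistic` attribute):

    import torch
    from scipy.stats import pearsonr

    # Same data as the docstring's testcode block in PATCH 01.
    y_true = torch.tensor([0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
    y_pred = torch.tensor([0.5, 1.3, 1.9, 2.8, 4.1, 6.0])

    print(pearsonr(y_pred.numpy(), y_true.numpy()).statistic)  # ~0.9768689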