From f277af7132e926a13a2afeeee377d31d3000b4fc Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Thu, 13 Jun 2024 11:45:11 +0530 Subject: [PATCH] Implement median helper Co-authored-by: Ricardo Vieira <28983449+ricardov94@users.noreply.github.com> --- pytensor/tensor/math.py | 43 +++++++++++++++++++++++++++++++++++++++ tests/tensor/test_math.py | 31 ++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+) diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index 57d0c0364b..d1e4dc6195 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -1566,6 +1566,48 @@ def std(input, axis=None, ddof=0, keepdims=False, corrected=False): return ret +def median(x: TensorLike, axis=None) -> TensorVariable: + """ + Computes the median along the given axis(es) of a tensor `input`. + + Parameters + ---------- + x: TensorVariable + The input tensor. + axis: None or int or (list of int) (see `Sum`) + Compute the median along this axis of the tensor. + None means all axes (like numpy). + """ + from pytensor.ifelse import ifelse + + x = as_tensor_variable(x) + x_ndim = x.type.ndim + if axis is None: + axis = list(range(x_ndim)) + else: + axis = list(normalize_axis_tuple(axis, x_ndim)) + + non_axis = [i for i in range(x_ndim) if i not in axis] + non_axis_shape = [x.shape[i] for i in non_axis] + + # Put axis at the end and unravel them + x_raveled = x.transpose(*non_axis, *axis) + if len(axis) > 1: + x_raveled = x_raveled.reshape((*non_axis_shape, -1)) + raveled_size = x_raveled.shape[-1] + k = raveled_size // 2 + + # Sort the input tensor along the specified axis and pick median value + x_sorted = x_raveled.sort(axis=-1) + k_values = x_sorted[..., k] + km1_values = x_sorted[..., k - 1] + + even_median = (k_values + km1_values) / 2.0 + odd_median = k_values.astype(even_median.type.dtype) + even_k = eq(mod(raveled_size, 2), 0) + return ifelse(even_k, even_median, odd_median, name="median") + + @scalar_elemwise(symbolname="scalar_maximum") def maximum(x, y): """elemwise maximum. See max for the maximum in one tensor""" @@ -3015,6 +3057,7 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None): "sum", "prod", "mean", + "median", "var", "std", "std", diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py index 6cee6d9125..14bc2614e3 100644 --- a/tests/tensor/test_math.py +++ b/tests/tensor/test_math.py @@ -93,6 +93,7 @@ max_and_argmax, maximum, mean, + median, min, minimum, mod, @@ -3735,3 +3736,33 @@ def test_nan_to_num(nan, posinf, neginf): out, np.nan_to_num(y, nan=nan, posinf=posinf, neginf=neginf), ) + + +@pytest.mark.parametrize( + "ndim, axis", + [ + (2, None), + (2, 1), + (2, (0, 1)), + (3, None), + (3, (1, 2)), + (4, (1, 3, 0)), + ], +) +def test_median(ndim, axis): + # Generate random data with both odd and even lengths + shape_even = np.arange(1, ndim + 1) * 2 + shape_odd = shape_even - 1 + + data_even = np.random.rand(*shape_even) + data_odd = np.random.rand(*shape_odd) + + x = tensor(dtype="float64", shape=(None,) * ndim) + f = function([x], median(x, axis=axis)) + result_odd = f(data_odd) + result_even = f(data_even) + expected_odd = np.median(data_odd, axis=axis) + expected_even = np.median(data_even, axis=axis) + + assert np.allclose(result_odd, expected_odd) + assert np.allclose(result_even, expected_even)