diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 9cdbe9acc..02676df0a 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -132,7 +132,7 @@ jobs: python-version: "3.9" markers: -m 'not slow' os: ubuntu-latest - - tox-env: py310-coverage # No markers -- includes slow tests + - tox-env: py310-coverage-lmoments # No markers -- includes slow tests python-version: "3.10" os: ubuntu-latest - tox-env: py311-coverage-sbck diff --git a/CHANGES.rst b/CHANGES.rst index f88226eb0..2deb3f076 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -22,8 +22,10 @@ New features and enhancements * New ``xclim.core.calendar.stack_periods`` and ``unstack_periods`` for performing ``rolling(time=...).construct(..., stride=...)`` but with non-uniform temporal periods like years or months. They replace ``xclim.sdba.processing.construct_moving_yearly_window`` and ``unpack_moving_yearly_window`` which are deprecated and will be removed in a future release. * New ``as_dataset`` options for ``xclim.set_options``. When True, indicators will output Datasets instead of DataArrays. (:issue:`1257`, :pull:`1625`). * Added new option for UTCI calculation to cap low wind velocities to a minimum of 0.5 m/s following Bröde (2012) guidelines. (:issue:`1634`, :pull:`1635`). + * Added option ``never_reached`` to ``degree_days_exceedance_date`` to assign a custom value when the sum threshold is never reached. (:issue:`1459`, :pull:`1647`). * Added option ``min_members`` to ensemble statistics to mask elements when the number of valid members is under a threshold. (:issue:`1459`, :pull:`1647`). +* Distribution instances can now be passed to the ``dist`` argument of most statistical indices. (:pull:`1644`). Breaking changes ^^^^^^^^^^^^^^^^ @@ -42,6 +44,7 @@ Breaking changes * The indice and indicator for ``winter_storm`` has been removed (deprecated since `xclim` v0.46.0 in favour of ``snd_storm_days``). (:pull:`1565`). 
* `xclim` has dropped support for `scipy` version below v1.9.0 and `numpy` versions below v1.20.0. (:pull:`1565`). * For generic function ``select_resample_op`` and ``core.units.to_agg_units``, operation "sum" will now return the same units as the input, and not implicitly be translated to an "integral". (:issue:`1645`, :pull:`1649`). +* `lmoments3` was removed as a dependency of `xclim` due to incompatible licensing (GPLv3 vs `xclim`'s Apache 2.0). Depending on the outcome of efforts to modify the licensing of `lmoments3`, this change may eventually be reverted. See `Ouranosinc/lmoments3#12 `_. See also the "frequency analysis" notebook for an example on how to continue using the probability weighted moments method for fitting distributions. (:issue:`1620`, :pull:`1644`). Bug fixes ^^^^^^^^^ @@ -49,7 +52,7 @@ Bug fixes * Fix wrong `window` attributes in ``xclim.indices.standardized_precipitation_index``, ``xclim.indices.standardized_precipitation_evapotranspiration_index``. (:issue:`1552` :pull:`1554`). * Fix the daily case `freq='D'` of ``xclim.stats.preprocess_standardized_index`` (:issue:`1602` :pull:`1607`). * Several spelling mistakes have been corrected within the documentation and codebase. (:pull:`1576`). -* Added missing ``xclim.ensembles.robustness_fractions`` and ``xclim.ensembles.robistness_categoris`` in api doc section. (:pull:`1630`). +* Added missing ``xclim.ensembles.robustness_fractions`` and ``xclim.ensembles.robustness_categories`` in api doc section. (:pull:`1630`). * Fixed an issue that can occur when fetching the testing data and running tests on Windows systems. Adapted a few existing tests for Windows support. (:pull:`1648`). 
Internal changes diff --git a/docs/notebooks/frequency_analysis.ipynb b/docs/notebooks/frequency_analysis.ipynb index 0e4c0d086..0e68b1a18 100644 --- a/docs/notebooks/frequency_analysis.ipynb +++ b/docs/notebooks/frequency_analysis.ipynb @@ -103,7 +103,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The next step is to fit the statistical distribution on these maxima. This is done by the `.fit` method, which takes as argument the sample series, the distribution's name and the parameter estimation `method`. The fit is done by default using the Maximum Likelihood algorithm (`method=\"ML\"`). For some extreme value distributions, however, the maximum likelihood is not always robust, and `xclim` offers the possibility to use Probability Weighted Moments (`method=\"PWM\"`) to estimate parameters. Note that the `lmoments3` package which is used by `xclim` to compute the PWM only supports `expon`, `gamma`, `genextreme`, `genpareto`, `gumbel_r`, `pearson3` and `weibull_min`. Parameters can also be estimated using the method of moments (`method=\"MM\"`)." + "The next step is to fit the statistical distribution on these maxima. This is done by the `.fit` method, which takes as argument the sample series, the distribution's name and the parameter estimation `method`. The fit is done by default using the Maximum Likelihood algorithm (`method=\"ML\"`). Parameters can also be estimated using the method of moments (`method=\"MM\"`).\n", + "\n", + "`xclim` can also accept a distribution instance instead of a name (i.e. an instance of a subclass of `scipy.stats.rv_continuous`). For example, for some extreme value distributions, the maximum likelihood is not always robust. Using the \"Probability Weighted Moments\" (`method=\"PWM\"`) method can help in that case. This is possible by passing a distribution object from the `lmoments3` package together with `method=\"PWM\"`. 
That package currently only supports `expon`, `gamma`, `genextreme`, `genpareto`, `gumbel_r`, `pearson3`, and `weibull_min` (with other names, see [the documentation](https://lmoments3.readthedocs.io/en/stable/distributions.html)). In the following example, we fit using the \"Generalized extreme value\" distribution from `lmoments3`." ] }, { @@ -112,8 +114,10 @@ "metadata": {}, "outputs": [], "source": [ + "from lmoments3.distr import gev\n", + "\n", "# The fitting dimension is hard-coded as `time`.\n", - "params = fit(sub, dist=\"genextreme\")\n", + "params = fit(sub, dist=gev, method=\"PWM\")\n", "params" ] }, @@ -186,9 +190,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.9" + "version": "3.11.7" } }, "nbformat": 4, - "nbformat_minor": 1 + "nbformat_minor": 4 } diff --git a/environment.yml b/environment.yml index ecb6c521c..b920d038f 100644 --- a/environment.yml +++ b/environment.yml @@ -13,7 +13,6 @@ dependencies: - Click >=8.1 - dask >=2.6.0 - jsonpickle - - lmoments3 - numba - numpy >=1.20.0 - pandas >=2.2.0 diff --git a/pyproject.toml b/pyproject.toml index c1fc36237..3705c0edd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,6 @@ dependencies = [ "Click>=8.1", "dask[array]>=2.6", "jsonpickle", - "lmoments3>=1.0.5", "numba", "numpy>=1.20.0", "pandas>=2.2", diff --git a/tests/test_stats.py b/tests/test_stats.py index 091c95fc7..adbb6b4c4 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -166,10 +166,11 @@ def test_fa(fitda): np.testing.assert_array_equal(q[0, 0, 0], q0) -def test_fa_gamma(fitda): +def test_fa_gamma_lmom(fitda): + lmom = pytest.importorskip("lmoments3.distr") T = 10 q = stats.fa(fitda, T, "lognorm", method="MM") - q1 = stats.fa(fitda, T, "gamma", method="PWM") + q1 = stats.fa(fitda, T, lmom.gam, method="PWM") np.testing.assert_allclose(q1, q, rtol=0.2) @@ -192,6 +193,21 @@ def test_dims_order(fitda): assert p.dims[-1] == "dparams" +lm3_dist_map = { + "expon": 
"exp", + "gamma": "gam", + "genextreme": "gev", + # "genlogistic": "glo", + # "gennorm": "gno", + "genpareto": "gpa", + "gumbel_r": "gum", + # "kappa4": "kap", + "norm": "nor", + "pearson3": "pe3", + "weibull_min": "wei", +} + + class TestPWMFit: params = { "expon": {"loc": 0.9527273, "scale": 2.2836364}, @@ -213,20 +229,23 @@ class TestPWMFit: } inputs_pdf = [4, 5, 6, 7] - @pytest.mark.parametrize("dist", stats._lm3_dist_map.keys()) + @pytest.mark.parametrize("dist", lm3_dist_map.keys()) def test_get_lm3_dist(self, dist): """Check that parameterization for lmoments3 and scipy is identical.""" + lmom = pytest.importorskip("lmoments3.distr") + lm3dc = getattr(lmom, lm3_dist_map[dist]) dc = stats.get_dist(dist) - lm3dc = stats.get_lm3_dist(dist) par = self.params[dist] expected = dc(**par).pdf(self.inputs_pdf) values = lm3dc(**par).pdf(self.inputs_pdf) np.testing.assert_array_almost_equal(values, expected) - @pytest.mark.parametrize("dist", stats._lm3_dist_map.keys()) + @pytest.mark.parametrize("dist", lm3_dist_map.keys()) @pytest.mark.parametrize("use_dask", [True, False]) def test_pwm_fit(self, dist, use_dask, random): """Test that the fitted parameters match parameters used to generate a random sample.""" + lmom = pytest.importorskip("lmoments3.distr") + lm3dc = getattr(lmom, lm3_dist_map[dist]) n = 500 dc = stats.get_dist(dist) par = self.params[dist] @@ -237,11 +256,10 @@ def test_pwm_fit(self, dist, use_dask, random): ) if use_dask: da = da.chunk() - out = stats.fit(da, dist=dist, method="PWM").compute() + out = stats.fit(da, dist=lm3dc, method="PWM").compute() # Check that values are identical to lmoments3's output dict - l3dc = stats.get_lm3_dist(dist) - expected = l3dc.lmom_fit(da.values) + expected = lm3dc.lmom_fit(da.values) for key, val in expected.items(): np.testing.assert_array_equal(out.sel(dparams=key), val, 1) @@ -270,9 +288,21 @@ def test_frequency_analysis(ndq_series, use_dask): q.transpose(), mode="max", t=2, dist="genextreme", window=6, 
freq="YS" ) - # Test with PWM fitting method + +@pytest.mark.parametrize("use_dask", [True, False]) +@pytest.mark.filterwarnings("ignore::RuntimeWarning") +def test_frequency_analysis_lmoments(ndq_series, use_dask): + lmom = pytest.importorskip("lmoments3.distr") + q = ndq_series.copy() + q[:, 0, 0] = np.nan + if use_dask: + q = q.chunk() + + out = stats.frequency_analysis( + q, mode="max", t=2, dist="genextreme", window=6, freq="YS" + ) out1 = stats.frequency_analysis( - q, mode="max", t=2, dist="genextreme", window=6, freq="YS", method="PWM" + q, mode="max", t=2, dist=lmom.gev, window=6, freq="YS", method="PWM" ) np.testing.assert_allclose( out1, diff --git a/tox.ini b/tox.ini index b126a8e6d..264aee431 100644 --- a/tox.ini +++ b/tox.ini @@ -7,7 +7,7 @@ env_list = offline-prefetch py39-upstream-doctest py310 - py311 + py311-lmoments py312-numba labels = test = py39, py310-upstream-doctest, py311, notebooks_doctests, offline-prefetch @@ -105,6 +105,8 @@ deps = coverage: coveralls upstream: -rrequirements_upstream.txt sbck: pybind11 + lmoments: lmoments3 + notebooks_doctests: lmoments3 install_command = python -m pip install --no-user {opts} {packages} download = True commands_pre = diff --git a/xclim/core/utils.py b/xclim/core/utils.py index 60010b25c..75bd62918 100644 --- a/xclim/core/utils.py +++ b/xclim/core/utils.py @@ -680,6 +680,9 @@ def infer_kind_from_parameter(param) -> InputKind: if param.name == "freq": return InputKind.FREQ_STR + if param.kind == param.VAR_KEYWORD: + return InputKind.KWARGS + if annot == {"Quantified"}: return InputKind.QUANTIFIED @@ -692,7 +695,7 @@ def infer_kind_from_parameter(param) -> InputKind: if annot.issubset({"int", "float", "Sequence[int]", "Sequence[float]"}): return InputKind.NUMBER_SEQUENCE - if annot == {"str"}: + if annot.issuperset({"str"}): return InputKind.STRING if annot == {"DateStr"}: @@ -704,9 +707,6 @@ def infer_kind_from_parameter(param) -> InputKind: if annot == {"Dataset"}: return InputKind.DATASET - if 
param.kind == param.VAR_KEYWORD: - return InputKind.KWARGS - return InputKind.OTHER_PARAMETER diff --git a/xclim/indices/stats.py b/xclim/indices/stats.py index 262a6065b..8c5a559a6 100644 --- a/xclim/indices/stats.py +++ b/xclim/indices/stats.py @@ -6,8 +6,8 @@ from collections.abc import Sequence from typing import Any -import lmoments3.distr import numpy as np +import scipy.stats import xarray as xr from xclim.core.calendar import compare_offsets, resample_doy, select_time @@ -19,13 +19,11 @@ __all__ = [ "_fit_start", - "_lm3_dist_map", "dist_method", "fa", "fit", "frequency_analysis", "get_dist", - "get_lm3_dist", "parametric_cdf", "parametric_quantile", "preprocess_standardized_index", @@ -34,22 +32,6 @@ ] -# Map the scipy distribution name to the lmoments3 name. Distributions with mismatched parameters are excluded. -_lm3_dist_map = { - "expon": "exp", - "gamma": "gam", - "genextreme": "gev", - # "genlogistic": "glo", - # "gennorm": "gno", - "genpareto": "gpa", - "gumbel_r": "gum", - # "kappa4": "kap", - "norm": "nor", - "pearson3": "pe3", - "weibull_min": "wei", -} - - # Fit the parameters. # This would also be the place to impose constraints on the series minimum length if needed. def _fitfunc_1d(arr, *, dist, nparams, method, **fitkwargs): @@ -88,7 +70,7 @@ def _fitfunc_1d(arr, *, dist, nparams, method, **fitkwargs): def fit( da: xr.DataArray, - dist: str = "norm", + dist: str | scipy.stats.rv_continuous = "norm", method: str = "ML", dim: str = "time", **fitkwargs: Any, @@ -99,13 +81,12 @@ def fit( ---------- da : xr.DataArray Time series to be fitted along the time dimension. - dist : str + dist : str or rv_continuous distribution object Name of the univariate distribution, such as beta, expon, genextreme, gamma, gumbel_r, lognorm, norm - (see :py:mod:scipy.stats for full list). If the PWM method is used, only the following distributions are - currently supported: 'expon', 'gamma', 'genextreme', 'genpareto', 'gumbel_r', 'pearson3', 'weibull_min'. 
+ (see :py:mod:`scipy.stats` for full list) or the distribution object itself. method : {"ML" or "MLE", "MM", "PWM", "APP"} - Fitting method, either maximum likelihood (ML or MLE), method of moments (MM), - probability weighted moments (PWM), also called L-Moments, or approximate method (APP). + Fitting method, either maximum likelihood (ML or MLE), method of moments (MM) or approximate method (APP). + Can also be the probability weighted moments (PWM), also called L-Moments, if a compatible `dist` object is passed. The PWM method is usually more robust to outliers. dim : str The dimension upon which to perform the indexing (default: "time"). @@ -134,13 +115,14 @@ def fit( raise ValueError(f"Fitting method not recognized: {method}") # Get the distribution - dc = get_dist(dist) - if method == "PWM": - lm3dc = get_lm3_dist(dist) - else: - lm3dc = None + dist = get_dist(dist) + + if method == "PWM" and not hasattr(dist, "lmom_fit"): + raise ValueError( + f"The given distribution {dist} does not implement the PWM fitting method. Please pass an instance from the lmoments3 package." 
+ ) - shape_params = [] if dc.shapes is None else dc.shapes.split(",") + shape_params = [] if dist.shapes is None else dist.shapes.split(",") dist_params = shape_params + ["loc", "scale"] data = xr.apply_ufunc( @@ -154,7 +136,7 @@ def fit( keep_attrs=True, kwargs=dict( # Don't know how APP should be included, this works for now - dist=dc if method in ["ML", "MLE", "MM", "APP"] else lm3dc, + dist=dist, nparams=len(dist_params), method=method, **fitkwargs, @@ -170,11 +152,11 @@ def fit( da.attrs, ["standard_name", "long_name", "units", "description"], "original_" ) attrs = dict( - long_name=f"{dist} parameters", - description=f"Parameters of the {dist} distribution", + long_name=f"{dist.name} parameters", + description=f"Parameters of the {dist.name} distribution", method=method, estimator=method_name[method].capitalize(), - scipy_dist=dist, + scipy_dist=dist.name, units="", history=update_history( f"Estimate distribution parameters by {method_name[method]} method along dimension {dim}.", @@ -186,7 +168,11 @@ def fit( return out -def parametric_quantile(p: xr.DataArray, q: float | Sequence[float]) -> xr.DataArray: +def parametric_quantile( + p: xr.DataArray, + q: float | Sequence[float], + dist: str | scipy.stats.rv_continuous | None = None, +) -> xr.DataArray: """Return the value corresponding to the given distribution parameters and quantile. Parameters @@ -197,6 +183,8 @@ def parametric_quantile(p: xr.DataArray, q: float | Sequence[float]) -> xr.DataA and attribute `scipy_dist`, storing the name of the distribution. q : float or Sequence of float Quantile to compute, which must be between `0` and `1`, inclusive. + dist: str, rv_continuous instance, optional + The distribution name or instance if the `scipy_dist` attribute is not available on `p`. 
Returns ------- @@ -209,20 +197,18 @@ def parametric_quantile(p: xr.DataArray, q: float | Sequence[float]) -> xr.DataA """ q = np.atleast_1d(q) - # Get the distribution - dist = p.attrs["scipy_dist"] - dc = get_dist(dist) + dist = get_dist(dist or p.attrs["scipy_dist"]) # Create a lambda function to facilitate passing arguments to dask. There is probably a better way to do this. if np.all(q > 0.5): def func(x): - return dc.isf(1 - q, *x) + return dist.isf(1 - q, *x) else: def func(x): - return dc.ppf(q, *x) + return dist.ppf(q, *x) data = xr.apply_ufunc( func, @@ -242,8 +228,8 @@ def func(x): out.attrs = unprefix_attrs(p.attrs, ["units", "standard_name"], "original_") attrs = dict( - long_name=f"{dist} quantiles", - description=f"Quantiles estimated by the {dist} distribution", + long_name=f"{dist.name} quantiles", + description=f"Quantiles estimated by the {dist.name} distribution", cell_methods="dparams: ppf", history=update_history( "Compute parametric quantiles from distribution parameters", @@ -255,7 +241,11 @@ def func(x): return out -def parametric_cdf(p: xr.DataArray, v: float | Sequence[float]) -> xr.DataArray: +def parametric_cdf( + p: xr.DataArray, + v: float | Sequence[float], + dist: str | scipy.stats.rv_continuous | None = None, +) -> xr.DataArray: """Return the cumulative distribution function corresponding to the given distribution parameters and value. Parameters @@ -266,6 +256,8 @@ def parametric_cdf(p: xr.DataArray, v: float | Sequence[float]) -> xr.DataArray: and attribute `scipy_dist`, storing the name of the distribution. v : float or Sequence of float Value to compute the CDF. + dist: str, rv_continuous instance, optional + The distribution name or instance if the `scipy_dist` attribute is not available on `p`. 
Returns ------- @@ -274,13 +266,11 @@ def parametric_cdf(p: xr.DataArray, v: float | Sequence[float]) -> xr.DataArray: """ v = np.atleast_1d(v) - # Get the distribution - dist = p.attrs["scipy_dist"] - dc = get_dist(dist) + dist = get_dist(dist or p.attrs["scipy_dist"]) # Create a lambda function to facilitate passing arguments to dask. There is probably a better way to do this. def func(x): - return dc.cdf(v, *x) + return dist.cdf(v, *x) data = xr.apply_ufunc( func, @@ -300,8 +290,8 @@ def func(x): out.attrs = unprefix_attrs(p.attrs, ["units", "standard_name"], "original_") attrs = dict( - long_name=f"{dist} cdf", - description=f"CDF estimated by the {dist} distribution", + long_name=f"{dist.name} cdf", + description=f"CDF estimated by the {dist.name} distribution", cell_methods="dparams: cdf", history=update_history( "Compute parametric cdf from distribution parameters", @@ -316,7 +306,7 @@ def func(x): def fa( da: xr.DataArray, t: int | Sequence, - dist: str = "norm", + dist: str | scipy.stats.rv_continuous = "norm", mode: str = "max", method: str = "ML", ) -> xr.DataArray: @@ -329,14 +319,15 @@ def fa( t : int or Sequence of int Return period. The period depends on the resolution of the input data. If the input array's resolution is yearly, then the return period is in years. - dist : str + dist : str or rv_continuous instance Name of the univariate distribution, such as: `beta`, `expon`, `genextreme`, `gamma`, `gumbel_r`, `lognorm`, `norm` + Or the distribution instance itself. mode : {'min', 'max} Whether we are looking for a probability of exceedance (max) or a probability of non-exceedance (min). method : {"ML", "MLE", "MOM", "PWM", "APP"} - Fitting method, either maximum likelihood (ML or MLE), method of moments (MOM), - probability weighted moments (PWM), also called L-Moments, or approximate method (APP). + Fitting method, either maximum likelihood (ML or MLE), method of moments (MOM) or approximate method (APP). 
+ Also accepts probability weighted moments (PWM), also called L-Moments, if `dist` is an instance from the lmoments3 library. The PWM method is usually more robust to outliers. Returns @@ -363,7 +354,7 @@ def fa( # Compute the quantiles out = ( - parametric_quantile(p, q) + parametric_quantile(p, q, dist) .rename({"quantile": "return_period"}) .assign_coords(return_period=t) ) @@ -375,7 +366,7 @@ def frequency_analysis( da: xr.DataArray, mode: str, t: int | Sequence[int], - dist: str, + dist: str | scipy.stats.rv_continuous, window: int = 1, freq: str | None = None, method: str = "ML", @@ -392,16 +383,17 @@ def frequency_analysis( t : int or sequence Return period. The period depends on the resolution of the input data. If the input array's resolution is yearly, then the return period is in years. - dist : str + dist : str or rv_continuous Name of the univariate distribution, e.g. `beta`, `expon`, `genextreme`, `gamma`, `gumbel_r`, `lognorm`, `norm`. + Or an instance of the distribution. window : int Averaging window length (days). freq : str, optional Resampling frequency. If None, the frequency is assumed to be 'YS' unless the indexer is season='DJF', in which case `freq` would be set to `YS-DEC`. method : {"ML" or "MLE", "MOM", "PWM", "APP"} - Fitting method, either maximum likelihood (ML or MLE), method of moments (MOM), - probability weighted moments (PWM), also called L-Moments, or approximate method (APP). + Fitting method, either maximum likelihood (ML or MLE), method of moments (MOM) or approximate method (APP). + Also accepts probability weighted moments (PWM), also called L-Moments, if `dist` is an instance from the lmoments3 library. The PWM method is usually more robust to outliers. \*\*indexer Time attribute and values over which to subset the array. 
For example, use season='DJF' to select winter values, @@ -435,28 +427,18 @@ def frequency_analysis( return fa(sel, t, dist=dist, mode=mode, method=method) -def get_dist(dist: str): +def get_dist(dist: str | scipy.stats.rv_continuous): """Return a distribution object from `scipy.stats`.""" - from scipy import stats # pylint: disable=import-outside-toplevel + if isinstance(dist, scipy.stats.rv_continuous): + return dist - dc = getattr(stats, dist, None) + dc = getattr(scipy.stats, dist, None) if dc is None: e = f"Statistical distribution `{dist}` is not found in scipy.stats." raise ValueError(e) return dc -def get_lm3_dist(dist: str): - """Return a distribution object from `lmoments3.distr`.""" - if dist not in _lm3_dist_map: - raise ValueError( - f"The PWM fitting method cannot be used with the {dist} distribution, as it is not supported " - f"by `lmoments3`." - ) - - return getattr(lmoments3.distr, _lm3_dist_map[dist]) - - def _fit_start(x, dist: str, **fitkwargs: Any) -> tuple[tuple, dict]: r"""Return initial values for distribution parameters. @@ -532,7 +514,9 @@ def _fit_start(x, dist: str, **fitkwargs: Any) -> tuple[tuple, dict]: return (), {} -def _dist_method_1D(*args, dist: str, function: str, **kwargs: Any) -> xr.DataArray: +def _dist_method_1D( + *args, dist: str | scipy.stats.rv_continuous, function: str, **kwargs: Any +) -> xr.DataArray: r"""Statistical function for given argument on given distribution initialized with params. See :py:ref:`scipy:scipy.stats.rv_continuous` for all available functions and their arguments. @@ -561,6 +545,7 @@ def dist_method( function: str, fit_params: xr.DataArray, arg: xr.DataArray | None = None, + dist: str | scipy.stats.rv_continuous | None = None, **kwargs: Any, ) -> xr.DataArray: r"""Vectorized statistical function for given argument on given distribution initialized with params. @@ -574,9 +559,10 @@ def dist_method( The name of the function to call. 
fit_params : xr.DataArray Distribution parameters are along `dparams`, in the same order as given by :py:func:`fit`. - Must have a `scipy_dist` attribute with the name of the distribution fitted. arg : array_like, optional The first argument for the requested function if different from `fit_params`. + dist : str or rv_continuous, optional + The distribution name or instance. Defaults to the `scipy_dist` attribute of `fit_params`. \*\*kwargs Other parameters to pass to the function call. @@ -595,7 +581,11 @@ def dist_method( return xr.apply_ufunc( _dist_method_1D, *args, - kwargs={"dist": fit_params.attrs["scipy_dist"], "function": function, **kwargs}, + kwargs={ + "dist": dist or fit_params.attrs["scipy_dist"], + "function": function, + **kwargs, + }, output_dtypes=[float], dask="parallelized", ) @@ -668,7 +658,7 @@ def standardized_index_fit_params( da: xr.DataArray, freq: str | None, window: int, - dist: str, + dist: str | scipy.stats.rv_continuous, method: str, offset: Quantified | None = None, **indexer, @@ -690,7 +680,7 @@ def standardized_index_fit_params( window : int Averaging window length relative to the resampling frequency. For example, if `freq="MS"`, i.e. a monthly resampling, the window is an integer number of months. - dist : {'gamma', 'fisk'} + dist : {'gamma', 'fisk'} or rv_continuous instance Name of the univariate distribution. (see :py:mod:`scipy.stats`). method : {'ML', 'APP', 'PWM'} Name of the fitting method, such as `ML` (maximum likelihood), `APP` (approximate). 
The approximate method @@ -715,11 +705,12 @@ def standardized_index_fit_params( """ # "WPM" method doesn't seem to work for gamma or pearson3 dist_and_methods = {"gamma": ["ML", "APP", "PWM"], "fisk": ["ML", "APP"]} - if dist not in dist_and_methods: - raise NotImplementedError(f"The distribution `{dist}` is not supported.") - if method not in dist_and_methods[dist]: + dist = get_dist(dist) + if dist.name not in dist_and_methods: + raise NotImplementedError(f"The distribution `{dist.name}` is not supported.") + if method not in dist_and_methods[dist.name]: raise NotImplementedError( - f"The method `{method}` is not supported for distribution `{dist}`." + f"The method `{method}` is not supported for distribution `{dist.name}`." ) if offset is not None: @@ -736,7 +727,7 @@ def standardized_index_fit_params( "calibration_period": cal_range, "freq": freq or "", "window": window, - "scipy_dist": dist, + "scipy_dist": dist.name, "method": method, "group": group, "units": "", @@ -748,7 +739,11 @@ def standardized_index_fit_params( return params -def standardized_index(da: xr.DataArray, params: xr.DataArray): +def standardized_index( + da: xr.DataArray, + params: xr.DataArray, + dist: str | scipy.stats.rv_continuous | None = None, +): """Compute standardized index for given fit parameters. This computes standardized indices which measure the deviation of variables in the dataset compared @@ -764,6 +759,8 @@ def standardized_index(da: xr.DataArray, params: xr.DataArray): ``xclim.indices.preprocess_standardized_index``. params : xarray.DataArray Fit parameters computed using ``xclim.indices.stats.standardized_index_fit_params``. + dist : str or rv_continuous, optional + Name of distribution or instance. Defaults to the "scipy_dist" attribute of `params`. 
""" group = params.attrs["group"] @@ -779,7 +776,7 @@ def reindex_time(da, da_ref): lambda x: (x == 0).sum("time") / x.notnull().sum("time") ) params, probs_of_zero = (reindex_time(dax, da) for dax in [params, probs_of_zero]) - dist_probs = dist_method("cdf", params, da) + dist_probs = dist_method("cdf", params, da, dist=dist) probs = probs_of_zero + ((1 - probs_of_zero) * dist_probs) params_norm = xr.DataArray(