From 7986348133fed3f817ff73454848dcc5b2c5a65c Mon Sep 17 00:00:00 2001 From: Felix Divo <4403130+felixdivo@users.noreply.github.com> Date: Tue, 12 Mar 2024 09:13:00 +0100 Subject: [PATCH] Add `ForecastingModel.supports_probabilistic_prediction` (#2259) (#2269) * Remove unnessesary `pass` statements * Rename ForecastingModel_is_probabilistic to supports_probabilistic_prediction, rearrange some documentation * Remove redundant overrides * Reformat * Add CHANGELOG entry --------- Co-authored-by: Dennis Bader --- CHANGELOG.md | 2 ++ darts/explainability/shap_explainer.py | 2 +- darts/models/forecasting/arima.py | 2 +- darts/models/forecasting/catboost_model.py | 2 +- darts/models/forecasting/croston.py | 4 ---- darts/models/forecasting/ensemble_model.py | 17 +++++++++++++---- .../models/forecasting/exponential_smoothing.py | 2 +- darts/models/forecasting/forecasting_model.py | 17 +++++++---------- .../forecasting/global_baseline_models.py | 2 +- darts/models/forecasting/kalman_forecaster.py | 2 +- darts/models/forecasting/lgbm.py | 2 +- .../forecasting/linear_regression_model.py | 2 +- .../models/forecasting/pl_forecasting_module.py | 2 +- darts/models/forecasting/prophet_model.py | 2 +- .../forecasting/regression_ensemble_model.py | 8 +++++--- darts/models/forecasting/sf_auto_arima.py | 2 +- darts/models/forecasting/sf_auto_ces.py | 4 ---- darts/models/forecasting/sf_auto_ets.py | 2 +- darts/models/forecasting/sf_auto_theta.py | 2 +- darts/models/forecasting/tbats_model.py | 2 +- .../forecasting/torch_forecasting_model.py | 4 ++-- darts/models/forecasting/varima.py | 2 +- darts/models/forecasting/xgboost.py | 2 +- darts/tests/models/forecasting/test_TFT.py | 2 +- .../models/forecasting/test_ensemble_models.py | 2 +- .../test_global_forecasting_models.py | 4 ++-- .../test_regression_ensemble_model.py | 12 ++++++------ 27 files changed, 55 insertions(+), 53 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 47ae74c1dc..a96914385f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,8 @@ but cannot always guarantee backwards compatibility. Changes that may **break co ### For users of the library: **Improved** +- Improvements to `ForecastingModel`: + - Renamed the private `_is_probabilistic` property to a public `supports_probabilistic_prediction`. [#2269](https://github.com/unit8co/darts/pull/2269) by [Felix Divo](https://github.com/felixdivo). **Fixed** diff --git a/darts/explainability/shap_explainer.py b/darts/explainability/shap_explainer.py index a31a844ca4..c6dc313081 100644 --- a/darts/explainability/shap_explainer.py +++ b/darts/explainability/shap_explainer.py @@ -162,7 +162,7 @@ def __init__( test_stationarity=True, ) - if model._is_probabilistic: + if model.supports_probabilistic_prediction: logger.warning( "The model is probabilistic, but num_samples=1 will be used for explainability." ) diff --git a/darts/models/forecasting/arima.py b/darts/models/forecasting/arima.py index 489cf338e4..4a6760d430 100644 --- a/darts/models/forecasting/arima.py +++ b/darts/models/forecasting/arima.py @@ -233,7 +233,7 @@ def _predict( return self._build_forecast_series(forecast) @property - def _is_probabilistic(self) -> bool: + def supports_probabilistic_prediction(self) -> bool: return True @property diff --git a/darts/models/forecasting/catboost_model.py b/darts/models/forecasting/catboost_model.py index 104ce9d602..26bb976dde 100644 --- a/darts/models/forecasting/catboost_model.py +++ b/darts/models/forecasting/catboost_model.py @@ -326,7 +326,7 @@ def _likelihood_components_names( return None @property - def _is_probabilistic(self) -> bool: + def supports_probabilistic_prediction(self) -> bool: return self.likelihood is not None @property diff --git a/darts/models/forecasting/croston.py b/darts/models/forecasting/croston.py index 4737aec4b7..0a5f239728 100644 --- a/darts/models/forecasting/croston.py +++ b/darts/models/forecasting/croston.py @@ -164,7 +164,3 @@ def min_train_series_length(self) -> int: @property def _supports_range_index(self) -> bool: return True - - @property - def _is_probabilistic(self) -> bool: - return False diff --git a/darts/models/forecasting/ensemble_model.py b/darts/models/forecasting/ensemble_model.py index 30d36ff2ba..3ac8877410 100644 --- a/darts/models/forecasting/ensemble_model.py +++ b/darts/models/forecasting/ensemble_model.py @@ -119,7 +119,9 @@ def __init__( raise_if( train_num_samples is not None and train_num_samples > 1 - and all([not m._is_probabilistic for m in forecasting_models]), + and all( + [not m.supports_probabilistic_prediction for m in forecasting_models] + ), "`train_num_samples` is greater than 1 but the `RegressionEnsembleModel` " "contains only deterministic `forecasting_models`.", logger, @@ -261,7 +263,9 @@ def _make_multiple_predictions( future_covariates=( future_covariates if model.supports_future_covariates else None ), - num_samples=num_samples if model._is_probabilistic else 1, + num_samples=( + num_samples if model.supports_probabilistic_prediction else 1 + ), predict_likelihood_parameters=predict_likelihood_parameters, ) for model in self.forecasting_models @@ -432,7 +436,12 @@ def output_chunk_length(self) -> Optional[int]: @property def _models_are_probabilistic(self) -> bool: - return all([model._is_probabilistic for model in self.forecasting_models]) + return all( + [ + model.supports_probabilistic_prediction + for model in self.forecasting_models + ] + ) @property def _models_same_likelihood(self) -> bool: @@ -480,7 +489,7 @@ def supports_likelihood_parameter_prediction(self) -> bool: ) @property - def _is_probabilistic(self) -> bool: + def supports_probabilistic_prediction(self) -> bool: return self._models_are_probabilistic @property diff --git a/darts/models/forecasting/exponential_smoothing.py b/darts/models/forecasting/exponential_smoothing.py index 217e5485d7..ef42f8c4cb 100644 --- a/darts/models/forecasting/exponential_smoothing.py +++ b/darts/models/forecasting/exponential_smoothing.py @@ -159,7 +159,7 @@ def supports_multivariate(self) -> bool: return False @property - def _is_probabilistic(self) -> bool: + def supports_probabilistic_prediction(self) -> bool: return True @property diff --git a/darts/models/forecasting/forecasting_model.py b/darts/models/forecasting/forecasting_model.py index beeb3b3327..41b5433641 100644 --- a/darts/models/forecasting/forecasting_model.py +++ b/darts/models/forecasting/forecasting_model.py @@ -192,9 +192,10 @@ def _supports_range_index(self) -> bool: return True @property - def _is_probabilistic(self) -> bool: + def supports_probabilistic_prediction(self) -> bool: """ - Checks if the forecasting model supports probabilistic predictions. + Checks if the forecasting model with this configuration supports probabilistic predictions. + By default, returns False. Needs to be overwritten by models that do support probabilistic predictions. """ @@ -204,7 +205,9 @@ def _is_probabilistic(self) -> bool: def _supports_non_retrainable_historical_forecasts(self) -> bool: """ Checks if the forecasting model supports historical forecasts without retraining - the model. By default, returns False. Needs to be overwritten by models that do + the model. + + By default, returns False. Needs to be overwritten by models that do support historical forecasts without retraining. """ return False @@ -250,7 +253,6 @@ def supports_transferrable_series_prediction(self) -> bool: """ Whether the model supports prediction for any input `series`. """ - pass @property def uses_past_covariates(self) -> bool: @@ -347,7 +349,7 @@ def predict( logger=logger, ) - if not self._is_probabilistic and num_samples > 1: + if not self.supports_probabilistic_prediction and num_samples > 1: raise_log( ValueError( "`num_samples > 1` is only supported for probabilistic models." @@ -488,7 +490,6 @@ def extreme_lags( >>> model.extreme_lags (-10, 6, None, None, 4, 6, 0) """ - pass @property def _training_sample_time_index_length(self) -> int: @@ -1870,7 +1871,6 @@ def _model_encoder_settings( Must return Tuple (input_chunk_length, output_chunk_length, takes_past_covariates, takes_future_covariates, lags_past_covariates, lags_future_covariates). """ - pass @classmethod def _sample_params(model_class, params, n_random_samples): @@ -2481,7 +2481,6 @@ def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = Non """Fits/trains the model on the provided series. DualCovariatesModels must implement the fit logic in this method. """ - pass def predict( self, @@ -2575,7 +2574,6 @@ def _predict( """Forecasts values for a certain number of time steps after the end of the series. DualCovariatesModels must implement the predict logic in this method. """ - pass @property def _model_encoder_settings( @@ -2778,7 +2776,6 @@ def _predict( """Forecasts values for a certain number of time steps after the end of the series. TransferableFutureCovariatesLocalForecastingModel must implement the predict logic in this method. """ - pass @property def supports_transferrable_series_prediction(self) -> bool: diff --git a/darts/models/forecasting/global_baseline_models.py b/darts/models/forecasting/global_baseline_models.py index dc81fe49aa..860e44609c 100644 --- a/darts/models/forecasting/global_baseline_models.py +++ b/darts/models/forecasting/global_baseline_models.py @@ -235,7 +235,7 @@ def min_train_series_length(self) -> int: def supports_likelihood_parameter_prediction(self) -> bool: return False - def _is_probabilistic(self) -> bool: + def supports_probabilistic_prediction(self) -> bool: return False @property diff --git a/darts/models/forecasting/kalman_forecaster.py b/darts/models/forecasting/kalman_forecaster.py index 34ef91a0a4..595f71c443 100644 --- a/darts/models/forecasting/kalman_forecaster.py +++ b/darts/models/forecasting/kalman_forecaster.py @@ -171,5 +171,5 @@ def supports_multivariate(self) -> bool: return True @property - def _is_probabilistic(self) -> bool: + def supports_probabilistic_prediction(self) -> bool: return True diff --git a/darts/models/forecasting/lgbm.py b/darts/models/forecasting/lgbm.py index 602fcac978..5812c12faa 100644 --- a/darts/models/forecasting/lgbm.py +++ b/darts/models/forecasting/lgbm.py @@ -310,7 +310,7 @@ def _predict_and_sample( ) @property - def _is_probabilistic(self) -> bool: + def supports_probabilistic_prediction(self) -> bool: return self.likelihood is not None @property diff --git a/darts/models/forecasting/linear_regression_model.py b/darts/models/forecasting/linear_regression_model.py index 032ab0c460..7bdffff8d6 100644 --- a/darts/models/forecasting/linear_regression_model.py +++ b/darts/models/forecasting/linear_regression_model.py @@ -305,5 +305,5 @@ def _predict_and_sample( ) @property - def _is_probabilistic(self) -> bool: + def supports_probabilistic_prediction(self) -> bool: return self.likelihood is not None diff --git a/darts/models/forecasting/pl_forecasting_module.py b/darts/models/forecasting/pl_forecasting_module.py index 6be20f5da0..7a7524a0bc 100644 --- a/darts/models/forecasting/pl_forecasting_module.py +++ b/darts/models/forecasting/pl_forecasting_module.py @@ -468,7 +468,7 @@ def set_mc_dropout(self, active: bool): module.mc_dropout_enabled = active @property - def _is_probabilistic(self) -> bool: + def supports_probabilistic_prediction(self) -> bool: return self.likelihood is not None or len(self._get_mc_dropout_modules()) > 0 def _produce_predict_output(self, x: Tuple) -> torch.Tensor: diff --git a/darts/models/forecasting/prophet_model.py b/darts/models/forecasting/prophet_model.py index b6674463fa..78ce395cae 100644 --- a/darts/models/forecasting/prophet_model.py +++ b/darts/models/forecasting/prophet_model.py @@ -386,7 +386,7 @@ def supports_multivariate(self) -> bool: return False @property - def _is_probabilistic(self) -> bool: + def supports_probabilistic_prediction(self) -> bool: return True def _stochastic_samples(self, predict_df, n_samples) -> np.ndarray: diff --git a/darts/models/forecasting/regression_ensemble_model.py b/darts/models/forecasting/regression_ensemble_model.py index aeac43b04c..149d6376a9 100644 --- a/darts/models/forecasting/regression_ensemble_model.py +++ b/darts/models/forecasting/regression_ensemble_model.py @@ -222,7 +222,9 @@ def _make_multiple_historical_forecasts( ), forecast_horizon=model.output_chunk_length, stride=model.output_chunk_length, - num_samples=num_samples if model._is_probabilistic else 1, + num_samples=( + num_samples if model.supports_probabilistic_prediction else 1 + ), start=-start_hist_forecasts, start_format="position", retrain=False, @@ -486,9 +488,9 @@ def supports_multivariate(self) -> bool: ) @property - def _is_probabilistic(self) -> bool: + def supports_probabilistic_prediction(self) -> bool: """ A RegressionEnsembleModel is probabilistic if its regression model is probabilistic (ensembling layer) """ - return self.regression_model._is_probabilistic + return self.regression_model.supports_probabilistic_prediction diff --git a/darts/models/forecasting/sf_auto_arima.py b/darts/models/forecasting/sf_auto_arima.py index c036a80b80..cd8569aede 100644 --- a/darts/models/forecasting/sf_auto_arima.py +++ b/darts/models/forecasting/sf_auto_arima.py @@ -134,5 +134,5 @@ def _supports_range_index(self) -> bool: return True @property - def _is_probabilistic(self) -> bool: + def supports_probabilistic_prediction(self) -> bool: return True diff --git a/darts/models/forecasting/sf_auto_ces.py b/darts/models/forecasting/sf_auto_ces.py index 4b79aa111d..5ec8fc1a44 100644 --- a/darts/models/forecasting/sf_auto_ces.py +++ b/darts/models/forecasting/sf_auto_ces.py @@ -84,7 +84,3 @@ def min_train_series_length(self) -> int: @property def _supports_range_index(self) -> bool: return True - - @property - def _is_probabilistic(self) -> bool: - return False diff --git a/darts/models/forecasting/sf_auto_ets.py b/darts/models/forecasting/sf_auto_ets.py index 9636436e0a..95572c42fe 100644 --- a/darts/models/forecasting/sf_auto_ets.py +++ b/darts/models/forecasting/sf_auto_ets.py @@ -164,5 +164,5 @@ def _supports_range_index(self) -> bool: return True @property - def _is_probabilistic(self) -> bool: + def supports_probabilistic_prediction(self) -> bool: return True diff --git a/darts/models/forecasting/sf_auto_theta.py b/darts/models/forecasting/sf_auto_theta.py index 53a6400cca..626c570665 100644 --- a/darts/models/forecasting/sf_auto_theta.py +++ b/darts/models/forecasting/sf_auto_theta.py @@ -99,5 +99,5 @@ def _supports_range_index(self) -> bool: return True @property - def _is_probabilistic(self) -> bool: + def supports_probabilistic_prediction(self) -> bool: return True diff --git a/darts/models/forecasting/tbats_model.py b/darts/models/forecasting/tbats_model.py index ec1cc62cfd..debab5060f 100644 --- a/darts/models/forecasting/tbats_model.py +++ b/darts/models/forecasting/tbats_model.py @@ -248,7 +248,7 @@ def supports_multivariate(self) -> bool: return False @property - def _is_probabilistic(self) -> bool: + def supports_probabilistic_prediction(self) -> bool: return True @property diff --git a/darts/models/forecasting/torch_forecasting_model.py b/darts/models/forecasting/torch_forecasting_model.py index b77989092a..af7f0b19f2 100644 --- a/darts/models/forecasting/torch_forecasting_model.py +++ b/darts/models/forecasting/torch_forecasting_model.py @@ -2051,9 +2051,9 @@ def output_chunk_shift(self) -> int: ) @property - def _is_probabilistic(self) -> bool: + def supports_probabilistic_prediction(self) -> bool: return ( - self.model._is_probabilistic + self.model.supports_probabilistic_prediction if self.model_created else True # all torch models can be probabilistic (via Dropout) ) diff --git a/darts/models/forecasting/varima.py b/darts/models/forecasting/varima.py index 3b6e4e9c05..cce35bb7e1 100644 --- a/darts/models/forecasting/varima.py +++ b/darts/models/forecasting/varima.py @@ -254,7 +254,7 @@ def min_train_series_length(self) -> int: return 30 @property - def _is_probabilistic(self) -> bool: + def supports_probabilistic_prediction(self) -> bool: return True @property diff --git a/darts/models/forecasting/xgboost.py b/darts/models/forecasting/xgboost.py index e62a37d065..99f38fc59a 100644 --- a/darts/models/forecasting/xgboost.py +++ b/darts/models/forecasting/xgboost.py @@ -328,7 +328,7 @@ def _predict_and_sample( ) @property - def _is_probabilistic(self) -> bool: + def supports_probabilistic_prediction(self) -> bool: return self.likelihood is not None @property diff --git a/darts/tests/models/forecasting/test_TFT.py b/darts/tests/models/forecasting/test_TFT.py index b0eb2e4bdf..5758bef8f3 100644 --- a/darts/tests/models/forecasting/test_TFT.py +++ b/darts/tests/models/forecasting/test_TFT.py @@ -381,7 +381,7 @@ def helper_fit_predict( series=series, past_covariates=past_covariates, future_covariates=future_covariates, - num_samples=(100 if model._is_probabilistic else 1), + num_samples=(100 if model.supports_probabilistic_prediction else 1), ) if isinstance(y_hat, TimeSeries): diff --git a/darts/tests/models/forecasting/test_ensemble_models.py b/darts/tests/models/forecasting/test_ensemble_models.py index 6197c3113f..79d3f5d762 100644 --- a/darts/tests/models/forecasting/test_ensemble_models.py +++ b/darts/tests/models/forecasting/test_ensemble_models.py @@ -199,7 +199,7 @@ def test_stochastic_naive_ensemble(self): # only probabilistic forecasting models naive_ensemble_proba = NaiveEnsembleModel([model_proba_1, model_proba_2]) - assert naive_ensemble_proba._is_probabilistic + assert naive_ensemble_proba.supports_probabilistic_prediction naive_ensemble_proba.fit(self.series1 + self.series2) # by default, only 1 sample diff --git a/darts/tests/models/forecasting/test_global_forecasting_models.py b/darts/tests/models/forecasting/test_global_forecasting_models.py index b12ae6d764..b8b020f342 100644 --- a/darts/tests/models/forecasting/test_global_forecasting_models.py +++ b/darts/tests/models/forecasting/test_global_forecasting_models.py @@ -447,7 +447,7 @@ def test_covariates(self, config): ) # when model is fit using 1 training and 1 covariate series, time series args are optional - if model._is_probabilistic: + if model.supports_probabilistic_prediction: return model = model_cls( input_chunk_length=IN_LEN, output_chunk_length=OUT_LEN, **kwargs @@ -661,7 +661,7 @@ def test_same_result_with_different_n_jobs(self, config): model.fit(multiple_ts) # safe random state for two successive identical predictions - if model._is_probabilistic: + if model.supports_probabilistic_prediction: random_state = deepcopy(model._random_instance) else: random_state = None diff --git a/darts/tests/models/forecasting/test_regression_ensemble_model.py b/darts/tests/models/forecasting/test_regression_ensemble_model.py index e9fa0b9578..bb979955d2 100644 --- a/darts/tests/models/forecasting/test_regression_ensemble_model.py +++ b/darts/tests/models/forecasting/test_regression_ensemble_model.py @@ -610,7 +610,7 @@ def test_stochastic_regression_ensemble_model(self): ) assert ensemble_allproba._models_are_probabilistic - assert ensemble_allproba._is_probabilistic + assert ensemble_allproba.supports_probabilistic_prediction ensemble_allproba.fit(self.ts_random_walk[:100]) # probabilistic forecasting is supported pred = ensemble_allproba.predict(5, num_samples=10) @@ -627,7 +627,7 @@ def test_stochastic_regression_ensemble_model(self): ) assert not ensemble_mixproba._models_are_probabilistic - assert ensemble_mixproba._is_probabilistic + assert ensemble_mixproba.supports_probabilistic_prediction ensemble_mixproba.fit(self.ts_random_walk[:100]) # probabilistic forecasting is supported pred = ensemble_mixproba.predict(5, num_samples=10) @@ -647,7 +647,7 @@ def test_stochastic_regression_ensemble_model(self): ) assert not ensemble_mixproba2._models_are_probabilistic - assert ensemble_mixproba2._is_probabilistic + assert ensemble_mixproba2.supports_probabilistic_prediction ensemble_mixproba2.fit(self.ts_random_walk[:100]) pred = ensemble_mixproba2.predict(5, num_samples=10) assert pred.n_samples == 10 @@ -663,7 +663,7 @@ def test_stochastic_regression_ensemble_model(self): ) assert not ensemble_proba_reg._models_are_probabilistic - assert ensemble_proba_reg._is_probabilistic + assert ensemble_proba_reg.supports_probabilistic_prediction ensemble_proba_reg.fit(self.ts_random_walk[:100]) # probabilistic forecasting is supported pred = ensemble_proba_reg.predict(5, num_samples=10) @@ -680,7 +680,7 @@ def test_stochastic_regression_ensemble_model(self): ) assert ensemble_dete_reg._models_are_probabilistic - assert not ensemble_dete_reg._is_probabilistic + assert not ensemble_dete_reg.supports_probabilistic_prediction ensemble_dete_reg.fit(self.ts_random_walk[:100]) # deterministic forecasting is supported ensemble_dete_reg.predict(5, num_samples=1) @@ -699,7 +699,7 @@ def test_stochastic_regression_ensemble_model(self): ) assert not ensemble_alldete._models_are_probabilistic - assert not ensemble_alldete._is_probabilistic + assert not ensemble_alldete.supports_probabilistic_prediction ensemble_alldete.fit(self.ts_random_walk[:100]) # deterministic forecasting is supported ensemble_alldete.predict(5, num_samples=1)