From 120c7b60920d869ba211bd04732ccaa0844dd814 Mon Sep 17 00:00:00 2001 From: dennisbader Date: Thu, 21 Sep 2023 16:28:21 +0200 Subject: [PATCH 1/3] make tcn loss only compute on output chunk --- .../forecasting/pl_forecasting_module.py | 21 +++--------- darts/models/forecasting/tcn_model.py | 32 ++----------------- .../forecasting/torch_forecasting_model.py | 7 ---- docs/source/conf.py | 2 +- 4 files changed, 7 insertions(+), 55 deletions(-) diff --git a/darts/models/forecasting/pl_forecasting_module.py b/darts/models/forecasting/pl_forecasting_module.py index 821e35745c..0af865eaf7 100644 --- a/darts/models/forecasting/pl_forecasting_module.py +++ b/darts/models/forecasting/pl_forecasting_module.py @@ -185,13 +185,6 @@ def __init__( self.pred_batch_size: Optional[int] = None self.pred_n_jobs: Optional[int] = None - @property - def first_prediction_index(self) -> int: - """ - Returns the index of the first predicted within the output of self.model. - """ - return 0 - @abstractmethod def forward(self, *args, **kwargs) -> Any: super().forward(*args, **kwargs) @@ -592,9 +585,7 @@ def _get_batch_prediction( dim=dim_component, ) - out = self._produce_predict_output(x=(input_past, static_covariates))[ - :, self.first_prediction_index :, : - ] + out = self._produce_predict_output(x=(input_past, static_covariates)) batch_prediction = [out[:, :roll_size, :]] prediction_length = roll_size @@ -641,9 +632,7 @@ def _get_batch_prediction( ] = future_past_covariates[:, left_past:right_past, :] # take only last part of the output sequence where needed - out = self._produce_predict_output(x=(input_past, static_covariates))[ - :, self.first_prediction_index :, : - ] + out = self._produce_predict_output(x=(input_past, static_covariates)) batch_prediction.append(out) prediction_length += self.output_chunk_length @@ -775,9 +764,7 @@ def _get_batch_prediction( ) ) - out = self._produce_predict_output(x=(input_past, input_future, input_static))[ - :, self.first_prediction_index :, : - ] + out = self._produce_predict_output(x=(input_past, input_future, input_static)) batch_prediction = [out[:, :roll_size, :]] prediction_length = roll_size @@ -845,7 +832,7 @@ def _get_batch_prediction( # take only last part of the output sequence where needed out = self._produce_predict_output( x=(input_past, input_future, input_static) - )[:, self.first_prediction_index :, :] + ) batch_prediction.append(out) prediction_length += self.output_chunk_length diff --git a/darts/models/forecasting/tcn_model.py b/darts/models/forecasting/tcn_model.py index 3b9795b033..647101e489 100644 --- a/darts/models/forecasting/tcn_model.py +++ b/darts/models/forecasting/tcn_model.py @@ -4,7 +4,7 @@ """ import math -from typing import Optional, Sequence, Tuple +from typing import Optional, Tuple import torch import torch.nn as nn @@ -16,8 +16,6 @@ io_processor, ) from darts.models.forecasting.torch_forecasting_model import PastCovariatesTorchModel -from darts.timeseries import TimeSeries -from darts.utils.data import PastCovariatesShiftedDataset from darts.utils.torch import MonteCarloDropout logger = get_logger(__name__) @@ -139,7 +137,6 @@ def __init__( weight_norm: bool, target_size: int, nr_params: int, - target_length: int, dropout: float, **kwargs ): @@ -155,8 +152,6 @@ def __init__( The dimensionality of the output time series. nr_params The number of parameters of the likelihood (or 1 if no likelihood is used). - target_length - Number of time steps the torch module will predict into the future at once. kernel_size The size of every kernel in a convolutional layer. num_filters @@ -191,7 +186,6 @@ def __init__( self.input_size = input_size self.n_filters = num_filters self.kernel_size = kernel_size - self.target_length = target_length self.target_size = target_size self.nr_params = nr_params self.dilation_base = dilation_base @@ -249,11 +243,7 @@ def forward(self, x_in: Tuple): batch_size, self.input_chunk_length, self.target_size, self.nr_params ) - return x - - @property - def first_prediction_index(self) -> int: - return -self.output_chunk_length + return x[:, -self.output_chunk_length :, :, :] class TCNModel(PastCovariatesTorchModel): @@ -510,25 +500,7 @@ def _create_model(self, train_sample: Tuple[torch.Tensor]) -> torch.nn.Module: num_filters=self.num_filters, num_layers=self.num_layers, dilation_base=self.dilation_base, - target_length=self.output_chunk_length, dropout=self.dropout, weight_norm=self.weight_norm, **self.pl_module_params, ) - - def _build_train_dataset( - self, - target: Sequence[TimeSeries], - past_covariates: Optional[Sequence[TimeSeries]], - future_covariates: Optional[Sequence[TimeSeries]], - max_samples_per_ts: Optional[int], - ) -> PastCovariatesShiftedDataset: - - return PastCovariatesShiftedDataset( - target_series=target, - covariates=past_covariates, - length=self.input_chunk_length, - shift=self.output_chunk_length, - max_samples_per_ts=max_samples_per_ts, - use_static_covariates=self.uses_static_covariates, - ) diff --git a/darts/models/forecasting/torch_forecasting_model.py b/darts/models/forecasting/torch_forecasting_model.py index 347890ceec..3b27bbfbaf 100644 --- a/darts/models/forecasting/torch_forecasting_model.py +++ b/darts/models/forecasting/torch_forecasting_model.py @@ -1478,13 +1478,6 @@ def predict_from_dataset( # flatten and return return [ts for batch in predictions for ts in batch] - @property - def first_prediction_index(self) -> int: - """ - Returns the index of the first predicted within the output of self.model. - """ - return 0 - @property def min_train_series_length(self) -> int: """ diff --git a/docs/source/conf.py b/docs/source/conf.py index a957818653..2c25b7d8ef 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -53,7 +53,7 @@ + "PastCovariatesTorchModel,FutureCovariatesTorchModel,DualCovariatesTorchModel,MixedCovariatesTorchModel," + "SplitCovariatesTorchModel,TorchParametricProbabilisticForecastingModel," + "min_train_series_length," - + "untrained_model,first_prediction_index,future_covariate_series,past_covariate_series," + + "untrained_model,future_covariate_series,past_covariate_series," + "initialize_encoders,register_datapipe_as_function,register_function,functions," + "SplitTimeSeriesSequence,randint,AnomalyModel", } From f18881f19dbd711ad33ef98fea86b2a50206e0ca Mon Sep 17 00:00:00 2001 From: dennisbader Date: Thu, 21 Sep 2023 16:32:44 +0200 Subject: [PATCH 2/3] update docs --- darts/models/forecasting/tcn_model.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/darts/models/forecasting/tcn_model.py b/darts/models/forecasting/tcn_model.py index 647101e489..2cf2628b4b 100644 --- a/darts/models/forecasting/tcn_model.py +++ b/darts/models/forecasting/tcn_model.py @@ -143,7 +143,6 @@ def __init__( """PyTorch module implementing a dilated TCN module used in `TCNModel`. - Parameters ---------- input_size @@ -174,10 +173,8 @@ def __init__( Outputs ------- - y of shape `(batch_size, input_chunk_length, target_size, nr_params)` - Tensor containing the predictions of the next 'output_chunk_length' points in the last - 'output_chunk_length' entries of the tensor. The entries before contain the data points - leading up to the first prediction, all in chronological order. + y of shape `(batch_size, output_chunk_length, target_size, nr_params)` + Tensor containing the predictions of the next 'output_chunk_length' points. """ super().__init__(**kwargs) From b1581b66e742d683b6426626310402ab2d9a4233 Mon Sep 17 00:00:00 2001 From: dennisbader Date: Thu, 21 Sep 2023 16:41:22 +0200 Subject: [PATCH 3/3] update changelog --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index d4441fed65..f35303e123 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,10 @@ but cannot always guarantee backwards compatibility. Changes that may **break co [Full Changelog](https://github.com/unit8co/darts/compare/0.26.0...master) ### For users of the library: + +**Fixed** +- Fixed an issue where `TCNModel` training included the last (input_chunk_length - output_chunk_length) target points in the loss computation. [#2006](https://github.com/unit8co/darts/pull/2006) by [Dennis Bader](https://github.com/dennisbader). + ### For developers of the library: ## [0.26.0](https://github.com/unit8co/darts/tree/0.26.0) (2023-09-16)