diff --git a/CHANGELOG.md b/CHANGELOG.md index 83b1d002..6dcafe7a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,12 +5,15 @@ ### Added - Adding a filter functionality to the timeseries datasset (#329) +- Add simple models such as LSTM, GRU and a MLP on the decoder (#380) +- Allow usage of any torch optimizer such as SGD (#380) ### Fixed - Moving predictions to CPU to avoid running out of memory (#329) - Correct determination of `output_size` for multi-target forecasting with the TemporalFusionTransformer (#328) - Tqdm autonotebook fix to work outside of Jupyter (#338) +- Fix issue with yaml serialization for TensorboardLogger (#379) ### Contributors diff --git a/README.md b/README.md index 55392dee..a1165590 100755 --- a/README.md +++ b/README.md @@ -49,6 +49,8 @@ documentation with detailed tutorials. methods in the M4 competition. The M4 competition is arguably the most important benchmark for univariate time series forecasting. - [DeepAR: Probabilistic forecasting with autoregressive recurrent networks](https://www.sciencedirect.com/science/article/pii/S0169207019301888) which is the one of the most popular forecasting algorithms and is often used as a baseline +- A baseline model that always predicts the latest known value +- Simple standard networks for baselining: LSTM and GRU networks as well as a MLP on the decoder To implement new models, see the [How to implement new models tutorial](https://pytorch-forecasting.readthedocs.io/en/latest/tutorials/building.html). It covers basic as well as advanced architectures. diff --git a/docs/source/models.rst b/docs/source/models.rst index 8320eb1f..39281895 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -24,6 +24,8 @@ and you should take into account. Here is an overview over the pros and cons of .. csv-table:: Model comparison :header: "Name", "Covariates", "Multiple targets", "Regression", "Classification", "Probabilistic", "Uncertainty", "Interactions between series", "Flexible history length", "Cold-start", "Required computational resources (1-5, 5=most)" + :py:class:`~pytorch_forecasting.models.rnn.RecurrentNetwork`, "x", "x", "x", "", "", "", "", "x", "", 2 + :py:class:`~pytorch_forecasting.models.mlp.DecoderMLP`, "x", "x", "x", "x", "", "x", "", "x", "x", 1 :py:class:`~pytorch_forecasting.models.nbeats.NBeats`, "", "", "x", "", "", "", "", "", "", 1 :py:class:`~pytorch_forecasting.models.deepar.DeepAR`, "x", "x", "x", "", "x", "x", "", "x", "", 3 :py:class:`~pytorch_forecasting.models.temporal_fusion_transformer.TemporalFusionTransformer`, "x", "x", "x", "x", "", "x", "", "x", "x", 4 diff --git a/pytorch_forecasting/__init__.py b/pytorch_forecasting/__init__.py index ad014e09..d7bea384 100644 --- a/pytorch_forecasting/__init__.py +++ b/pytorch_forecasting/__init__.py @@ -36,9 +36,11 @@ Baseline, BaseModel, BaseModelWithCovariates, + DecoderMLP, DeepAR, MultiEmbedding, NBeats, + RecurrentNetwork, TemporalFusionTransformer, get_rnn, ) @@ -85,6 +87,8 @@ "get_embedding_size", "create_mask", "to_list", + "RecurrentNetwork", + "DecoderMLP", ] __version__ = "0.0.0" diff --git a/pytorch_forecasting/models/__init__.py b/pytorch_forecasting/models/__init__.py index 67db8902..496a3691 100644 --- a/pytorch_forecasting/models/__init__.py +++ b/pytorch_forecasting/models/__init__.py @@ -9,13 +9,16 @@ ) from pytorch_forecasting.models.baseline import Baseline from pytorch_forecasting.models.deepar import DeepAR +from pytorch_forecasting.models.mlp import DecoderMLP from pytorch_forecasting.models.nbeats import NBeats from pytorch_forecasting.models.nn import GRU, LSTM, MultiEmbedding, get_rnn +from pytorch_forecasting.models.rnn import RecurrentNetwork from pytorch_forecasting.models.temporal_fusion_transformer import TemporalFusionTransformer __all__ = [ "NBeats", "TemporalFusionTransformer", + "RecurrentNetwork", "DeepAR", "BaseModel", "Baseline", @@ -26,4 +29,5 @@ "LSTM", "GRU", "MultiEmbedding", + "DecoderMLP", ] diff --git a/pytorch_forecasting/models/base_model.py b/pytorch_forecasting/models/base_model.py index f76b6adc..0028e2d5 100644 --- a/pytorch_forecasting/models/base_model.py +++ b/pytorch_forecasting/models/base_model.py @@ -20,7 +20,16 @@ from pytorch_forecasting.data import TimeSeriesDataSet from pytorch_forecasting.data.encoders import EncoderNormalizer, GroupNormalizer, MultiNormalizer, NaNLabelEncoder -from pytorch_forecasting.metrics import MASE, SMAPE, DistributionLoss, Metric, MultiLoss +from pytorch_forecasting.metrics import ( + MAE, + MASE, + SMAPE, + DistributionLoss, + Metric, + MultiHorizonMetric, + MultiLoss, + QuantileLoss, +) from pytorch_forecasting.optim import Ranger from pytorch_forecasting.utils import apply_to_list, create_mask, get_embedding_size, groupby_apply, to_list @@ -154,6 +163,7 @@ def __init__( reduce_on_plateau_patience: int = 1000, reduce_on_plateau_min_lr: float = 1e-5, weight_decay: float = 0.0, + optimizer_params: Dict[str, Any] = None, monotone_constaints: Dict[str, int] = {}, output_transformer: Callable = None, optimizer="ranger", @@ -177,6 +187,7 @@ def __init__( reduce_on_plateau_min_lr (float): minimum learning rate for reduce on plateua learning rate scheduler. Defaults to 1e-5 weight_decay (float): weight decay. Defaults to 0.0. + optimizer_params (Dict[str, Any]): additional parameters for the optimizer. Defaults to {}. monotone_constaints (Dict[str, int]): dictionary of monotonicity constraints for continuous decoder variables mapping position (e.g. ``"0"`` for first position) to constraint (``-1`` for negative and ``+1`` for positive, @@ -184,7 +195,8 @@ def __init__( This constraint significantly slows down training. Defaults to {}. output_transformer (Callable): transformer that takes network output and transforms it to prediction space. Defaults to None which is equivalent to ``lambda out: out["prediction"]``. - optimizer (str): Optimizer, "ranger", "adam" or "adamw". Defaults to "ranger". + optimizer (str): Optimizer, "ranger", "sgd", "adam", "adamw" or class name of optimizer in ``torch.optim``. + Defaults to "ranger". """ super().__init__() # update hparams @@ -203,6 +215,21 @@ def __init__( if not hasattr(self, "output_transformer"): self.output_transformer = output_transformer + @property + def n_targets(self) -> int: + """ + Number of targets to forecast. + + Based on loss function. + + Returns: + int: number of targets + """ + if isinstance(self.loss, MultiLoss): + return len(self.loss.metrics) + else: + return 1 + def transform_output(self, out: Dict[str, torch.Tensor]) -> torch.Tensor: """ Extract prediction from network output and rescale it to real space / de-normalize it. @@ -251,6 +278,52 @@ def transform_output(self, out: Dict[str, torch.Tensor]) -> torch.Tensor: out = self.output_transformer(out) return out + @staticmethod + def deduce_default_output_parameters( + dataset: TimeSeriesDataSet, kwargs: Dict[str, Any], default_loss: MultiHorizonMetric = None + ) -> Dict[str, Any]: + """ + Deduce default parameters for output for `from_dataset()` method. + + Determines ``output_size`` and ``loss`` parameters. + + Args: + dataset (TimeSeriesDataSet): timeseries dataset + kwargs (Dict[str, Any]): current hyperparameters + default_loss (MultiHorizonMetric, optional): default loss function. + Defaults to :py:class:`~pytorch_forecasting.metrics.MAE`. + + Returns: + Dict[str, Any]: dictionary with ``output_size`` and ``loss``. + """ + # infer output size + def get_output_size(normalizer, loss): + if isinstance(loss, QuantileLoss): + return len(loss.quantiles) + elif isinstance(normalizer, NaNLabelEncoder): + return len(normalizer.classes_) + else: + return 1 + + # handle multiple targets + new_kwargs = {} + n_targets = len(dataset.target_names) + if default_loss is None: + default_loss = MAE() + loss = kwargs.get("loss", default_loss) + if n_targets > 1: # try to infer number of ouput sizes + if not isinstance(loss, MultiLoss): + loss = MultiLoss([deepcopy(loss)] * n_targets) + new_kwargs["loss"] = loss + if isinstance(loss, MultiLoss) and "output_size" not in kwargs: + new_kwargs["output_size"] = [ + get_output_size(normalizer, l) + for normalizer, l in zip(dataset.target_normalizer.normalizers, loss.metrics) + ] + elif "output_size" not in kwargs: + new_kwargs["output_size"] = get_output_size(dataset.target_normalizer, loss) + return new_kwargs + def size(self) -> int: """ get number of parameters in model @@ -673,6 +746,10 @@ def configure_optimizers(self): Tuple[List]: first entry is list of optimizers and second is list of schedulers """ # either set a schedule of lrs or find it dynamically + if self.hparams.optimizer_params is None: + optimizer_params = {} + else: + optimizer_params = self.hparams.optimizer_params if isinstance(self.hparams.learning_rate, (list, tuple)): # set schedule lrs = self.hparams.learning_rate if self.hparams.optimizer == "adam": @@ -681,8 +758,17 @@ def configure_optimizers(self): optimizer = torch.optim.AdamW(self.parameters(), lr=lrs[0]) elif self.hparams.optimizer == "ranger": optimizer = Ranger(self.parameters(), lr=lrs[0], weight_decay=self.hparams.weight_decay) + elif self.hparams.optimizer == "sgd": + optimizer = torch.optim.SGD( + self.parameters(), lr=lrs[0], weight_decay=self.hparams.weight_decay, **optimizer_params + ) else: - raise ValueError(f"Optimizer of self.hparams.optimizer={self.hparams.optimizer} unknown") + try: + optimizer = getattr(torch.optim, self.hparams.optimizer)( + self.parameters(), lr=lrs[0], weight_decay=self.hparams.weight_decay, **optimizer_params + ) + except AttributeError: + raise ValueError(f"Optimizer of self.hparams.optimizer={self.hparams.optimizer} unknown") # normalize lrs lrs = np.array(lrs) / lrs[0] schedulers = [ diff --git a/pytorch_forecasting/models/deepar/__init__.py b/pytorch_forecasting/models/deepar/__init__.py index 12d67aca..02ef8dd8 100644 --- a/pytorch_forecasting/models/deepar/__init__.py +++ b/pytorch_forecasting/models/deepar/__init__.py @@ -176,7 +176,6 @@ def from_dataset( Returns: DeepAR network """ - # assert fixed encoder and decoder length for the moment new_kwargs = {} if dataset.multi_target: new_kwargs.setdefault("loss", MultiLoss([NormalDistributionLoss()] * len(dataset.target_names))) diff --git a/pytorch_forecasting/models/mlp/__init__.py b/pytorch_forecasting/models/mlp/__init__.py new file mode 100644 index 00000000..dcc6db74 --- /dev/null +++ b/pytorch_forecasting/models/mlp/__init__.py @@ -0,0 +1,155 @@ +""" +Simple models based on fully connected networks +""" + + +from typing import Dict, List, Tuple, Union + +import numpy as np +import torch +from torch import nn + +from pytorch_forecasting.data import TimeSeriesDataSet +from pytorch_forecasting.metrics import MAE, MAPE, MASE, RMSE, SMAPE, MultiHorizonMetric, QuantileLoss +from pytorch_forecasting.models.base_model import BaseModelWithCovariates +from pytorch_forecasting.models.mlp.submodules import FullyConnectedModule +from pytorch_forecasting.models.nn.embeddings import MultiEmbedding + + +class DecoderMLP(BaseModelWithCovariates): + """ + MLP on the decoder. + + MLP that predicts output only based on information available in the decoder. + """ + + def __init__( + self, + activation_class: str = "ReLU", + hidden_size: int = 300, + n_hidden_layers: int = 3, + dropout: float = 0.1, + norm: bool = True, + static_categoricals: List[str] = [], + static_reals: List[str] = [], + time_varying_categoricals_encoder: List[str] = [], + time_varying_categoricals_decoder: List[str] = [], + categorical_groups: Dict[str, List[str]] = {}, + time_varying_reals_encoder: List[str] = [], + time_varying_reals_decoder: List[str] = [], + embedding_sizes: Dict[str, Tuple[int, int]] = {}, + embedding_paddings: List[str] = [], + embedding_labels: Dict[str, np.ndarray] = {}, + x_reals: List[str] = [], + x_categoricals: List[str] = [], + output_size: Union[int, List[int]] = 1, + target: Union[str, List[str]] = None, + loss: MultiHorizonMetric = None, + logging_metrics: nn.ModuleList = None, + **kwargs, + ): + """ + Args: + activation_class (str, optional): PyTorch activation class. Defaults to "ReLU". + hidden_size (int, optional): hidden recurrent size - the most important hyperparameter along with + ``n_hidden_layers``. Defaults to 10. + n_hidden_layers (int, optional): Number of hidden layers - important hyperparameter. Defaults to 2. + dropout (float, optional): Dropout. Defaults to 0.1. + norm (bool, optional): if to use normalization in the MLP. Defaults to True. + static_categoricals: integer of positions of static categorical variables + static_reals: integer of positions of static continuous variables + time_varying_categoricals_encoder: integer of positions of categorical variables for encoder + time_varying_categoricals_decoder: integer of positions of categorical variables for decoder + time_varying_reals_encoder: integer of positions of continuous variables for encoder + time_varying_reals_decoder: integer of positions of continuous variables for decoder + categorical_groups: dictionary where values + are list of categorical variables that are forming together a new categorical + variable which is the key in the dictionary + x_reals: order of continuous variables in tensor passed to forward function + x_categoricals: order of categorical variables in tensor passed to forward function + embedding_sizes: dictionary mapping (string) indices to tuple of number of categorical classes and + embedding size + embedding_paddings: list of indices for embeddings which transform the zero's embedding to a zero vector + embedding_labels: dictionary mapping (string) indices to list of categorical labels + output_size (Union[int, List[int]], optional): number of outputs (e.g. number of quantiles for + QuantileLoss and one target or list of output sizes). + target (str, optional): Target variable or list of target variables. Defaults to None. + loss (MultiHorizonMetric, optional): loss: loss function taking prediction and targets. + Defaults to QuantileLoss. + logging_metrics (nn.ModuleList, optional): Metrics to log during training. + Defaults to nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]). + """ + if loss is None: + loss = QuantileLoss() + if logging_metrics is None: + logging_metrics = nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]) + self.save_hyperparameters() + # store loss function separately as it is a module + super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs) + + self.input_embeddings = MultiEmbedding( + embedding_sizes={ + name: val + for name, val in embedding_sizes.items() + if name in self.decoder_variables + self.static_variables + }, + embedding_paddings=embedding_paddings, + categorical_groups=categorical_groups, + x_categoricals=x_categoricals, + ) + # define network + if isinstance(self.hparams.output_size, int): + mlp_output_size = self.hparams.output_size + else: + mlp_output_size = sum(self.hparams.output_size) + + cont_size = len(self.decoder_reals_positions) + cat_size = sum([emb.embedding_dim for emb in self.input_embeddings.values()]) + input_size = cont_size + cat_size + + self.mlp = FullyConnectedModule( + dropout=dropout, + norm=self.hparams.norm, + activation_class=getattr(nn, self.hparams.activation_class), + input_size=input_size, + output_size=mlp_output_size, + hidden_size=self.hparams.hidden_size, + n_hidden_layers=self.hparams.n_hidden_layers, + ) + + @property + def decoder_reals_positions(self) -> List[int]: + return [ + self.hparams.x_reals.index(name) + for name in self.reals + if name in self.decoder_variables + self.static_variables + ] + + def forward(self, x: Dict[str, torch.Tensor], n_samples: int = None) -> Dict[str, torch.Tensor]: + """ + Forward network + """ + # x is a batch generated based on the TimeSeriesDataset + batch_size = x["decoder_lengths"].size(0) + embeddings = self.input_embeddings(x["decoder_cat"]) # returns dictionary with embedding tensors + network_input = torch.cat( + [x["decoder_cont"][..., self.decoder_reals_positions]] + list(embeddings.values()), + dim=-1, + ) + prediction = self.mlp(network_input.view(-1, self.mlp.input_size)).view( + batch_size, network_input.size(1), self.mlp.output_size + ) + + # cut prediction into pieces for multiple targets + if self.n_targets > 1: + prediction = torch.split(prediction, self.hparams.output_size, dim=-1) + + # We need to return a dictionary that at least contains the prediction and the target_scale. + # The parameter can be directly forwarded from the input. + return dict(prediction=prediction, target_scale=x["target_scale"]) + + @classmethod + def from_dataset(cls, dataset: TimeSeriesDataSet, **kwargs): + new_kwargs = cls.deduce_default_output_parameters(dataset, kwargs, QuantileLoss()) + kwargs.update(new_kwargs) + return super().from_dataset(dataset, **kwargs) diff --git a/pytorch_forecasting/models/mlp/submodules.py b/pytorch_forecasting/models/mlp/submodules.py new file mode 100644 index 00000000..469eebbd --- /dev/null +++ b/pytorch_forecasting/models/mlp/submodules.py @@ -0,0 +1,49 @@ +""" +MLP implementation +""" +import torch +from torch import nn + + +class FullyConnectedModule(nn.Module): + def __init__( + self, + input_size: int, + output_size: int, + hidden_size: int, + n_hidden_layers: int, + activation_class: nn.ReLU, + dropout: float = None, + norm: bool = True, + ): + super().__init__() + self.input_size = input_size + self.output_size = output_size + self.hidden_size = hidden_size + self.n_hidden_layers = n_hidden_layers + self.activation_class = activation_class + self.dropout = dropout + self.norm = norm + + # input layer + module_list = [nn.Linear(input_size, hidden_size), activation_class()] + if dropout is not None: + module_list.append(nn.Dropout(dropout)) + if norm: + module_list.append(nn.LayerNorm(hidden_size)) + # hidden layers + for _ in range(n_hidden_layers): + module_list.extend([nn.Linear(hidden_size, hidden_size), activation_class()]) + if dropout is not None: + module_list.append(nn.Dropout(dropout)) + if norm: + module_list.append(nn.LayerNorm(hidden_size)) + # output layer + module_list.append(nn.Linear(hidden_size, output_size)) + + self.sequential = nn.Sequential(*module_list) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x of shape: batch_size x n_timesteps_in + # output of shape batch_size x n_timesteps_out + return self.sequential(x) diff --git a/pytorch_forecasting/models/rnn/__init__.py b/pytorch_forecasting/models/rnn/__init__.py new file mode 100644 index 00000000..d69d8284 --- /dev/null +++ b/pytorch_forecasting/models/rnn/__init__.py @@ -0,0 +1,295 @@ +""" +Simple recurrent model - either with LSTM or GRU cells. +""" +from typing import Dict, List, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn + +from pytorch_forecasting.data.encoders import MultiNormalizer, NaNLabelEncoder +from pytorch_forecasting.data.timeseries import TimeSeriesDataSet +from pytorch_forecasting.metrics import MAE, MAPE, MASE, RMSE, SMAPE, MultiHorizonMetric, MultiLoss, QuantileLoss +from pytorch_forecasting.models.base_model import AutoRegressiveBaseModelWithCovariates +from pytorch_forecasting.models.nn import HiddenState, MultiEmbedding, get_rnn +from pytorch_forecasting.utils import apply_to_list, to_list + + +class RecurrentNetwork(AutoRegressiveBaseModelWithCovariates): + def __init__( + self, + cell_type: str = "LSTM", + hidden_size: int = 10, + rnn_layers: int = 2, + dropout: float = 0.1, + static_categoricals: List[str] = [], + static_reals: List[str] = [], + time_varying_categoricals_encoder: List[str] = [], + time_varying_categoricals_decoder: List[str] = [], + categorical_groups: Dict[str, List[str]] = {}, + time_varying_reals_encoder: List[str] = [], + time_varying_reals_decoder: List[str] = [], + embedding_sizes: Dict[str, Tuple[int, int]] = {}, + embedding_paddings: List[str] = [], + embedding_labels: Dict[str, np.ndarray] = {}, + x_reals: List[str] = [], + x_categoricals: List[str] = [], + output_size: Union[int, List[int]] = 1, + target: Union[str, List[str]] = None, + target_lags: Dict[str, List[int]] = {}, + loss: MultiHorizonMetric = None, + logging_metrics: nn.ModuleList = None, + **kwargs, + ): + """ + Recurrent Network. + + Simple LSTM or GRU layer followed by output layer + + Args: + cell_type (str, optional): Recurrent cell type ["LSTM", "GRU"]. Defaults to "LSTM". + hidden_size (int, optional): hidden recurrent size - the most important hyperparameter along with + ``rnn_layers``. Defaults to 10. + rnn_layers (int, optional): Number of RNN layers - important hyperparameter. Defaults to 2. + dropout (float, optional): Dropout in RNN layers. Defaults to 0.1. + static_categoricals: integer of positions of static categorical variables + static_reals: integer of positions of static continuous variables + time_varying_categoricals_encoder: integer of positions of categorical variables for encoder + time_varying_categoricals_decoder: integer of positions of categorical variables for decoder + time_varying_reals_encoder: integer of positions of continuous variables for encoder + time_varying_reals_decoder: integer of positions of continuous variables for decoder + categorical_groups: dictionary where values + are list of categorical variables that are forming together a new categorical + variable which is the key in the dictionary + x_reals: order of continuous variables in tensor passed to forward function + x_categoricals: order of categorical variables in tensor passed to forward function + embedding_sizes: dictionary mapping (string) indices to tuple of number of categorical classes and + embedding size + embedding_paddings: list of indices for embeddings which transform the zero's embedding to a zero vector + embedding_labels: dictionary mapping (string) indices to list of categorical labels + output_size (Union[int, List[int]], optional): number of outputs (e.g. number of quantiles for + QuantileLoss and one target or list of output sizes). + target (str, optional): Target variable or list of target variables. Defaults to None. + target_lags (Dict[str, Dict[str, int]]): dictionary of target names mapped to list of time steps by + which the variable should be lagged. + Lags can be useful to indicate seasonality to the models. If you know the seasonalit(ies) of your data, + add at least the target variables with the corresponding lags to improve performance. + Defaults to no lags, i.e. an empty dictionary. + loss (MultiHorizonMetric, optional): loss: loss function taking prediction and targets. + logging_metrics (nn.ModuleList, optional): Metrics to log during training. + Defaults to nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]). + """ + if loss is None: + loss = MAE() + if logging_metrics is None: + logging_metrics = nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]) + self.save_hyperparameters() + # store loss function separately as it is a module + super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs) + + self.embeddings = MultiEmbedding( + embedding_sizes=embedding_sizes, + embedding_paddings=embedding_paddings, + categorical_groups=categorical_groups, + x_categoricals=x_categoricals, + ) + + lagged_target_names = [l for lags in target_lags.values() for l in lags] + assert set(self.encoder_variables) - set(to_list(target)) - set(lagged_target_names) == set( + self.decoder_variables + ), "Encoder and decoder variables have to be the same apart from target variable" + for targeti in to_list(target): + assert ( + targeti in time_varying_reals_encoder + ), f"target {targeti} has to be real" # todo: remove this restriction + assert (isinstance(target, str) and isinstance(loss, MultiHorizonMetric)) or ( + isinstance(target, (list, tuple)) and isinstance(loss, MultiLoss) and len(loss) == len(target) + ), "number of targets should be equivalent to number of loss metrics" + + rnn_class = get_rnn(cell_type) + cont_size = len(self.reals) + cat_size = sum([size[1] for size in self.hparams.embedding_sizes.values()]) + input_size = cont_size + cat_size + self.rnn = rnn_class( + input_size=input_size, + hidden_size=self.hparams.hidden_size, + num_layers=self.hparams.rnn_layers, + dropout=self.hparams.dropout if self.hparams.rnn_layers > 1 else 0, + batch_first=True, + ) + + # add linear layers for argument projects + if isinstance(target, str): # single target + self.output_projector = nn.Linear(self.hparams.hidden_size, self.hparams.output_size) + assert not isinstance(self.loss, QuantileLoss), "QuantileLoss does not work with recurrent network" + else: # multi target + self.output_projector = nn.ModuleList( + [nn.Linear(self.hparams.hidden_size, size) for size in self.hparams.output_size] + ) + for l in self.loss: + assert not isinstance(l, QuantileLoss), "QuantileLoss does not work with recurrent network" + + @classmethod + def from_dataset( + cls, + dataset: TimeSeriesDataSet, + allowed_encoder_known_variable_names: List[str] = None, + **kwargs, + ): + """ + Create model from dataset. + + Args: + dataset: timeseries dataset + allowed_encoder_known_variable_names: List of known variables that are allowed in encoder, defaults to all + **kwargs: additional arguments such as hyperparameters for model (see ``__init__()``) + + Returns: + Recurrent network + """ + new_kwargs = cls.deduce_default_output_parameters(dataset=dataset, kwargs=kwargs, default_loss=MAE()) + assert not isinstance(dataset.target_normalizer, NaNLabelEncoder) and ( + not isinstance(dataset.target_normalizer, MultiNormalizer) + or all([not isinstance(normalizer, NaNLabelEncoder) for normalizer in dataset.target_normalizer]) + ), "target(s) should be continuous - categorical targets are not supported" # todo: remove this restriction + return super().from_dataset( + dataset, allowed_encoder_known_variable_names=allowed_encoder_known_variable_names, **new_kwargs + ) + + def construct_input_vector( + self, x_cat: torch.Tensor, x_cont: torch.Tensor, one_off_target: torch.Tensor = None + ) -> torch.Tensor: + """ + Create input vector into RNN network + + Args: + one_off_target: tensor to insert into first position of target. If None (default), remove first time step. + """ + # create input vector + if len(self.categoricals) > 0: + embeddings = self.embeddings(x_cat) + flat_embeddings = torch.cat([emb for emb in embeddings.values()], dim=-1) + input_vector = flat_embeddings + + if len(self.reals) > 0: + input_vector = x_cont + + if len(self.reals) > 0 and len(self.categoricals) > 0: + input_vector = torch.cat([x_cont, flat_embeddings], dim=-1) + + # shift target by one + input_vector[..., self.target_positions] = torch.roll( + input_vector[..., self.target_positions], shifts=1, dims=1 + ) + + if one_off_target is not None: # set first target input (which is rolled over) + input_vector[:, 0, self.target_positions] = one_off_target + else: + input_vector = input_vector[:, 1:] + + # shift target + return input_vector + + def encode(self, x: Dict[str, torch.Tensor]) -> HiddenState: + """ + Encode sequence into hidden state + """ + # encode using rnn + assert x["encoder_lengths"].min() > 0 + encoder_lengths = x["encoder_lengths"] - 1 + input_vector = self.construct_input_vector(x["encoder_cat"], x["encoder_cont"]) + _, hidden_state = self.rnn( + input_vector, lengths=encoder_lengths, enforce_sorted=False + ) # second ouput is not needed (hidden state) + return hidden_state + + def decode_all( + self, + x: torch.Tensor, + hidden_state: HiddenState, + lengths: torch.Tensor = None, + ): + decoder_output, hidden_state = self.rnn(x, hidden_state, lengths=lengths, enforce_sorted=False) + if isinstance(self.hparams.target, str): # single target + output = self.output_projector(decoder_output) + else: + output = [projector(decoder_output) for projector in self.output_projector] + return output, hidden_state + + def decode( + self, + input_vector: torch.Tensor, + target_scale: torch.Tensor, + decoder_lengths: torch.Tensor, + hidden_state: HiddenState, + n_samples: int = None, + ) -> Tuple[torch.Tensor, bool]: + """ + Decode hidden state of RNN into prediction. If n_smaples is given, + decode not by using actual values but rather by + sampling new targets from past predictions iteratively + """ + if self.training: + output, _ = self.decode_all(input_vector, hidden_state, lengths=decoder_lengths) + output_transformation = True + else: + # run in eval, i.e. simulation mode + target_pos = self.target_positions + lagged_target_positions = self.lagged_target_positions + + # define function to run at every decoding step + def decode_one( + idx, + lagged_targets, + hidden_state, + ): + x = input_vector[:, [idx]] + x[:, 0, target_pos] = lagged_targets[-1] + for lag, lag_positions in lagged_target_positions.items(): + if idx > lag: + x[:, 0, lag_positions] = lagged_targets[-lag] + prediction, hidden_state = self.decode_all(x, hidden_state) + prediction = apply_to_list(prediction, lambda x: x[:, 0]) # select first time step + return prediction, hidden_state + + # make predictions which are fed into next step + output = self.decode_autoregressive( + decode_one, + first_target=input_vector[:, 0, target_pos], + first_hidden_state=hidden_state, + target_scale=target_scale, + n_decoder_steps=input_vector.size(1), + ) + output_transformation = None + return output, output_transformation + + def forward(self, x: Dict[str, torch.Tensor], n_samples: int = None) -> Dict[str, torch.Tensor]: + """ + Forward network + """ + hidden_state = self.encode(x) + # decode + input_vector = self.construct_input_vector( + x["decoder_cat"], + x["decoder_cont"], + one_off_target=x["encoder_cont"][ + torch.arange(x["encoder_cont"].size(0), device=x["encoder_cont"].device), + x["encoder_lengths"] - 1, + self.target_positions.unsqueeze(-1), + ].T, + ) + + output, output_transformation = self.decode( + input_vector, + decoder_lengths=x["decoder_lengths"], + target_scale=x["target_scale"], + hidden_state=hidden_state, + ) + # return relevant part + return dict( + prediction=output, + output_transformation=output_transformation, + groups=x["groups"], + decoder_time_idx=x["decoder_time_idx"], + target_scale=x["target_scale"], + ) diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py b/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py index 27b02d48..557fd397 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py @@ -33,7 +33,6 @@ def __init__( lstm_layers: int = 1, dropout: float = 0.1, output_size: Union[int, List[int]] = 7, - n_targets: int = 1, loss: MultiHorizonMetric = None, attention_head_size: int = 4, max_encoder_length: int = 10, @@ -91,7 +90,6 @@ def __init__( dropout: dropout rate output_size: number of outputs (e.g. number of quantiles for QuantileLoss and one target or list of output sizes). - n_targets: number of targets. Defaults to 1. loss: loss function taking prediction and targets attention_head_size: number of attention heads (4 is a good default) max_encoder_length: length to encode (can be far longer than the decoder length but does not have to be) @@ -318,7 +316,7 @@ def __init__( # output processing -> no dropout at this late stage self.pre_output_gate_norm = GateAddNorm(self.hparams.hidden_size, dropout=None, trainable_add=False) - if self.hparams.n_targets > 1: # if to run with multiple targets + if self.n_targets > 1: # if to run with multiple targets self.output_layer = nn.ModuleList( [nn.Linear(self.hparams.hidden_size, output_size) for output_size in self.hparams.output_size] ) @@ -345,30 +343,7 @@ def from_dataset( """ # add maximum encoder length new_kwargs = dict(max_encoder_length=dataset.max_encoder_length) - - # infer output size - def get_output_size(normalizer, loss): - if isinstance(loss, QuantileLoss): - return len(loss.quantiles) - elif isinstance(normalizer, NaNLabelEncoder): - return len(normalizer.classes_) - else: - return 1 - - loss = kwargs.get("loss", QuantileLoss()) - # handle multiple targets - new_kwargs["n_targets"] = len(dataset.target_names) - if new_kwargs["n_targets"] > 1: # try to infer number of ouput sizes - if not isinstance(loss, MultiLoss): - loss = MultiLoss([deepcopy(loss)] * new_kwargs["n_targets"]) - new_kwargs["loss"] = loss - if isinstance(loss, MultiLoss) and "output_size" not in kwargs: - new_kwargs["output_size"] = [ - get_output_size(normalizer, l) - for normalizer, l in zip(dataset.target_normalizer.normalizers, loss.metrics) - ] - elif "output_size" not in kwargs: - new_kwargs["output_size"] = get_output_size(dataset.target_normalizer, loss) + new_kwargs.update(cls.deduce_default_output_parameters(dataset, kwargs, QuantileLoss())) # update defaults new_kwargs.update(kwargs) @@ -518,7 +493,7 @@ def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: # skip connection over temporal fusion decoder (not LSTM decoder despite the LSTM output contains # a skip from the variable selection network) output = self.pre_output_gate_norm(output, lstm_output[:, max_encoder_length:]) - if self.hparams.n_targets > 1: # if to use multi-target architecture + if self.n_targets > 1: # if to use multi-target architecture output = [output_layer(output) for output_layer in self.output_layer] else: output = self.output_layer(output) diff --git a/tests/test_models/test_mlp.py b/tests/test_models/test_mlp.py new file mode 100644 index 00000000..6d58619a --- /dev/null +++ b/tests/test_models/test_mlp.py @@ -0,0 +1,100 @@ +import pickle +import shutil + +import pytest +import pytorch_lightning as pl +from pytorch_lightning.callbacks import EarlyStopping +from pytorch_lightning.loggers import TensorBoardLogger +from test_models.conftest import make_dataloaders + +from pytorch_forecasting.metrics import MAE, CrossEntropy, MultiLoss, QuantileLoss +from pytorch_forecasting.models import DecoderMLP + + +def _integration(data_with_covariates, tmp_path, gpus, data_loader_kwargs={}, **kwargs): + data_loader_default_kwargs = dict( + target="target", + time_varying_known_reals=["price_actual"], + time_varying_unknown_reals=["target"], + static_categoricals=["agency"], + add_relative_time_idx=True, + ) + data_loader_default_kwargs.update(data_loader_kwargs) + dataloaders_with_covariates = make_dataloaders(data_with_covariates, **data_loader_default_kwargs) + train_dataloader = dataloaders_with_covariates["train"] + val_dataloader = dataloaders_with_covariates["val"] + early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min") + + logger = TensorBoardLogger(tmp_path) + trainer = pl.Trainer( + max_epochs=3, + gpus=gpus, + weights_summary="top", + gradient_clip_val=0.1, + callbacks=[early_stop_callback], + checkpoint_callback=True, + default_root_dir=tmp_path, + limit_train_batches=2, + limit_val_batches=2, + logger=logger, + ) + + net = DecoderMLP.from_dataset( + train_dataloader.dataset, learning_rate=0.15, log_gradient_flow=True, log_interval=1000, **kwargs + ) + net.size() + try: + trainer.fit( + net, + train_dataloader=train_dataloader, + val_dataloaders=val_dataloader, + ) + # check loading + net = DecoderMLP.load_from_checkpoint(trainer.checkpoint_callback.best_model_path) + + # check prediction + net.predict(val_dataloader, fast_dev_run=True, return_index=True, return_decoder_lengths=True) + finally: + shutil.rmtree(tmp_path, ignore_errors=True) + + net.predict(val_dataloader, fast_dev_run=True, return_index=True, return_decoder_lengths=True) + + +@pytest.mark.parametrize( + "kwargs", + [ + {}, + dict( + loss=MultiLoss([QuantileLoss(), MAE()]), + data_loader_kwargs=dict( + time_varying_unknown_reals=["volume", "discount"], + target=["volume", "discount"], + ), + ), + dict( + loss=CrossEntropy(), + data_loader_kwargs=dict( + target="agency", + ), + ), + ], +) +def test_integration(data_with_covariates, tmp_path, gpus, kwargs): + _integration(data_with_covariates.assign(target=lambda x: x.volume), tmp_path, gpus, **kwargs) + + +@pytest.fixture +def model(dataloaders_with_covariates): + dataset = dataloaders_with_covariates["train"].dataset + net = DecoderMLP.from_dataset( + dataset, + learning_rate=0.15, + log_gradient_flow=True, + log_interval=1000, + ) + return net + + +def test_pickle(model): + pkl = pickle.dumps(model) + pickle.loads(pkl) diff --git a/tests/test_models/test_rnn_model.py b/tests/test_models/test_rnn_model.py new file mode 100644 index 00000000..81458f9a --- /dev/null +++ b/tests/test_models/test_rnn_model.py @@ -0,0 +1,117 @@ +import pickle +import shutil + +import pytest +import pytorch_lightning as pl +from pytorch_lightning.callbacks import EarlyStopping +from pytorch_lightning.loggers import TensorBoardLogger +from test_models.conftest import make_dataloaders +from torch import nn + +from pytorch_forecasting.data.encoders import GroupNormalizer +from pytorch_forecasting.metrics import MAE, CrossEntropy, QuantileLoss +from pytorch_forecasting.models import RecurrentNetwork + + +def _integration( + data_with_covariates, tmp_path, gpus, cell_type="LSTM", data_loader_kwargs={}, clip_target: bool = False, **kwargs +): + if clip_target: + data_with_covariates["target"] = data_with_covariates["volume"].clip(1e-3, 1.0) + else: + data_with_covariates["target"] = data_with_covariates["volume"] + data_loader_default_kwargs = dict( + target="target", + time_varying_known_reals=["price_actual"], + time_varying_unknown_reals=["target"], + static_categoricals=["agency"], + add_relative_time_idx=True, + ) + data_loader_default_kwargs.update(data_loader_kwargs) + dataloaders_with_covariates = make_dataloaders(data_with_covariates, **data_loader_default_kwargs) + train_dataloader = dataloaders_with_covariates["train"] + val_dataloader = dataloaders_with_covariates["val"] + early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min") + + logger = TensorBoardLogger(tmp_path) + trainer = pl.Trainer( + max_epochs=3, + gpus=gpus, + weights_summary="top", + gradient_clip_val=0.1, + callbacks=[early_stop_callback], + checkpoint_callback=True, + default_root_dir=tmp_path, + limit_train_batches=2, + limit_val_batches=2, + logger=logger, + ) + + net = RecurrentNetwork.from_dataset( + train_dataloader.dataset, + cell_type=cell_type, + learning_rate=0.15, + log_gradient_flow=True, + log_interval=1000, + n_plotting_samples=100, + **kwargs + ) + net.size() + try: + trainer.fit( + net, + train_dataloader=train_dataloader, + val_dataloaders=val_dataloader, + ) + # check loading + net = RecurrentNetwork.load_from_checkpoint(trainer.checkpoint_callback.best_model_path) + + # check prediction + net.predict(val_dataloader, fast_dev_run=True, return_index=True, return_decoder_lengths=True) + finally: + shutil.rmtree(tmp_path, ignore_errors=True) + + net.predict(val_dataloader, fast_dev_run=True, return_index=True, return_decoder_lengths=True) + + +@pytest.mark.parametrize( + "kwargs", + [ + {}, + {"cell_type": "GRU"}, + dict( + data_loader_kwargs=dict(target_normalizer=GroupNormalizer(groups=["agency", "sku"], center=False)), + ), + dict( + data_loader_kwargs=dict( + lags={"volume": [2, 5]}, target="volume", time_varying_unknown_reals=["volume"], min_encoder_length=10 + ) + ), + dict( + data_loader_kwargs=dict( + time_varying_unknown_reals=["volume", "discount"], + target=["volume", "discount"], + lags={"volume": [2], "discount": [2]}, + ) + ), + ], +) +def test_integration(data_with_covariates, tmp_path, gpus, kwargs): + _integration(data_with_covariates, tmp_path, gpus, **kwargs) + + +@pytest.fixture +def model(dataloaders_with_covariates): + dataset = dataloaders_with_covariates["train"].dataset + net = RecurrentNetwork.from_dataset( + dataset, + learning_rate=0.15, + log_gradient_flow=True, + log_interval=1000, + ) + return net + + +def test_pickle(model): + pkl = pickle.dumps(model) + pickle.loads(pkl)