From a1dd84fcd522ff5b59c34507d1db04718cd016c0 Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Fri, 3 May 2024 11:34:38 +0200 Subject: [PATCH 1/8] add code for torch mqcnn --- src/gluonts/transform/split.py | 243 ++++++++++++++++++++++++++++ test/torch/model/test_estimators.py | 20 +++ test/torch/model/test_modules.py | 33 ++++ 3 files changed, 296 insertions(+) diff --git a/src/gluonts/transform/split.py b/src/gluonts/transform/split.py index f435945dce..99b15b1f2e 100644 --- a/src/gluonts/transform/split.py +++ b/src/gluonts/transform/split.py @@ -11,9 +11,11 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. +from collections import Counter from typing import Iterator, List, Optional, Tuple import numpy as np +from numpy.lib.stride_tricks import as_strided from pandas.tseries.offsets import BaseOffset from gluonts.core.component import validated @@ -573,3 +575,244 @@ def flatmap_transform( d[self.forecast_start_field] = d[self.start_field] + i + lt yield d + + +class ForkingSequenceSplitter(FlatMapTransformation): + """Forking sequence splitter used by MQ-CNN Model""" + + @validated() + def __init__( + self, + instance_sampler, + enc_len: int, + dec_len: int, + num_forking: Optional[int] = None, + target_field: str = FieldName.TARGET, + encoder_series_fields: Optional[List[str]] = None, + decoder_series_fields: Optional[List[str]] = None, + encoder_disabled_fields: Optional[List[str]] = None, + decoder_disabled_fields: Optional[List[str]] = None, + prediction_time_decoder_exclude: Optional[List[str]] = None, + is_pad_out: str = FieldName.IS_PAD, + start_input_field: str = FieldName.TARGET, + lead_time: int = 0, + ) -> None: + """Creates forking sequences + + Args: + instance_sampler ([type]): + sampler + enc_len (int): + length of the encoder + dec_len (int): + length of the decoder + num_forking (Optional[int], optional): + number of forked sequences to produce. + (default: enc_len if None) + Defaults to None. + target_field (str, optional): + name of the target field. + Defaults to FieldName.TARGET. + encoder_series_fields (Optional[List[str]], optional): + List of the encoder enabled fields. + Defaults to None. + decoder_series_fields (Optional[List[str]], optional): + List of the decoder enabled fields. + Defaults to None. + encoder_disabled_fields (Optional[List[str]], optional): + List of the encoder disabled fields. + Defaults to None. + decoder_disabled_fields (Optional[List[str]], optional): + List of the decoder disabled fields. + Defaults to None. + prediction_time_decoder_exclude (Optional[List[str]], optional): + List of fields that are not used at prediction time for the decoder + Defaults to None. + is_pad_out (str, optional): + encodder padding + Defaults to FieldName.IS_PAD. + start_input_field (str, optional): + start forecast field + Defaults to FieldName.START. 
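+
+        Example:
+            A rough usage sketch (illustrative only; ``dataset`` stands for
+            any iterable of ``DataEntry`` dicts holding a 1-D ``"target"``
+            array)::
+
+                from gluonts.transform import ExpectedNumInstanceSampler
+
+                splitter = ForkingSequenceSplitter(
+                    instance_sampler=ExpectedNumInstanceSampler(
+                        num_instances=1.0, min_future=2
+                    ),
+                    enc_len=3,
+                    dec_len=2,
+                )
+                # during training, each sampled index yields one entry with
+                # "past_target" of shape (enc_len, 1) and "future_target" of
+                # shape (num_forking, dec_len)
+                instances = list(splitter.apply(dataset, is_train=True))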
+ """ + super().__init__() + + assert enc_len > 0, "The value of `enc_len` should be > 0" + assert dec_len > 0, "The value of `dec_len` should be > 0" + + self.instance_sampler = instance_sampler + self.enc_len = enc_len + self.dec_len = dec_len + self.num_forking = ( + num_forking if num_forking is not None else self.enc_len + ) + self.target_field = target_field + + self.encoder_series_fields = ( + encoder_series_fields + [self.target_field] + if encoder_series_fields is not None + else [self.target_field] + ) + self.decoder_series_fields = ( + decoder_series_fields + [self.target_field] + if decoder_series_fields is not None + else [self.target_field] + ) + + self.encoder_disabled_fields = ( + encoder_disabled_fields + if encoder_disabled_fields is not None + else [] + ) + + self.decoder_disabled_fields = ( + decoder_disabled_fields + if decoder_disabled_fields is not None + else [] + ) + + # Fields that are not used at prediction time for the decoder + self.prediction_time_decoder_exclude = ( + prediction_time_decoder_exclude + [self.target_field] + if prediction_time_decoder_exclude is not None + else [self.target_field] + ) + + self.is_pad_out = is_pad_out + self.start_in = start_input_field + self.lead_time = lead_time + + def _past(self, col_name): + return f"past_{col_name}" + + def _future(self, col_name): + return f"future_{col_name}" + + def flatmap_transform( + self, data: DataEntry, is_train: bool + ) -> Iterator[DataEntry]: + target = data[self.target_field] + + if is_train: + # We currently cannot handle time series that are shorter than the + # prediction length during training, so we just skip these. + # If we want to include them we would need to pad and to mask + # the loss. + if len(target) < self.dec_len: + return + + sampling_indices = self.instance_sampler(target) + else: + sampling_indices = [len(target)] + + # Loops over all encoder and decoder fields even those that are disabled to + # set to dummy zero fields in those cases + ts_fields_counter = Counter( + set(self.encoder_series_fields + self.decoder_series_fields) + ) + + for sampling_idx in sampling_indices: + # irrelevant data should have been removed by now in the + # transformation chain, so copying everything is ok + out = data.copy() + + enc_len_diff = sampling_idx - self.enc_len + dec_len_diff = sampling_idx - self.num_forking + + # ensure start indices are not negative + start_idx_enc = max(0, enc_len_diff) + start_idx_dec = max(0, dec_len_diff) + + # Define pad length indices for shorter time series of variable length being updated in place + pad_length_enc = max(0, -enc_len_diff) + pad_length_dec = max(0, -dec_len_diff) + + # Define pad length for indices that extend into the future beyond the known data due + # to forecast date that is closer to the end of the time-series + pad_length_future_unknown = max( + 0, sampling_idx + self.dec_len - len(target) + ) + + for ts_field in ts_fields_counter.keys(): + + # target is 1d, this ensures ts is always 2d + ts = np.atleast_2d(out[ts_field]).T + ts_len = ts.shape[1] + + if ts_fields_counter[ts_field] == 1: + del out[ts_field] + else: + ts_fields_counter[ts_field] -= 1 + + out[self._past(ts_field)] = np.zeros( + shape=(self.enc_len, ts_len), dtype=ts.dtype + ) + if ts_field not in self.encoder_disabled_fields: + out[self._past(ts_field)][pad_length_enc:] = ts[ + start_idx_enc:sampling_idx, : + ] + + # exclude some fields at prediction time + if ( + not is_train + and ts_field in self.prediction_time_decoder_exclude + ): + continue + + if ts_field in 
self.decoder_series_fields: + + # Adding zeros to the end of all time-series to accommodate a forecast date, + # from which the decoder length may extend into the future beyond the known data + ts = np.concatenate( + [ + ts, + np.zeros( + shape=[pad_length_future_unknown, ts.shape[1]], + dtype=ts.dtype, + ), + ], + axis=0, + ) + + out[self._future(ts_field)] = np.zeros( + shape=(self.num_forking, self.dec_len, ts_len), + dtype=ts.dtype, + ) + if ts_field not in self.decoder_disabled_fields: + # This is where some of the forking magic happens: + # For each of the num_forking time-steps at which the decoder is applied we slice the + # corresponding inputs called decoder_fields to the appropriate dec_len + decoder_fields = ts[ + start_idx_dec + 1 : sampling_idx + 1, : + ] + # For default row-major arrays, strides = (dtype*n_cols, dtype). Since this array is transposed, + # it is stored in column-major (Fortran) ordering with strides = (dtype, dtype*n_rows) + stride = decoder_fields.strides + out[self._future(ts_field)][pad_length_dec:] = ( + as_strided( + decoder_fields, + shape=( + self.num_forking - pad_length_dec, + self.dec_len, + ts_len, + ), + # strides for 2D array expanded to 3D array of shape (dim1, dim2, dim3) = + # (1, n_rows, n_cols). For transposed data, strides = + # (dtype, dtype * dim1, dtype*dim1*dim2) = (dtype, dtype, dtype*n_rows). + strides=stride[0:1] + stride, + ) + ) + + # edge case for prediction_length = 1 + if out[self._future(ts_field)].shape[-1] == 1: + out[self._future(ts_field)] = np.squeeze( + out[self._future(ts_field)], axis=-1 + ) + + pad_indicator = np.zeros(self.enc_len) + pad_indicator[:pad_length_enc] = True + out[self._past(self.is_pad_out)] = pad_indicator + + out[FieldName.FORECAST_START] = out[self.start_in] + sampling_idx + + yield out diff --git a/test/torch/model/test_estimators.py b/test/torch/model/test_estimators.py index faea91b313..e894589911 100644 --- a/test/torch/model/test_estimators.py +++ b/test/torch/model/test_estimators.py @@ -39,6 +39,7 @@ from gluonts.torch.model.lag_tst import LagTSTEstimator from gluonts.torch.model.tft import TemporalFusionTransformerEstimator from gluonts.torch.model.wavenet import WaveNetEstimator +from gluonts.torch.model.mq_cnn import MQCNNEstimator from gluonts.torch.distributions import ImplicitQuantileNetworkOutput @@ -163,6 +164,14 @@ num_batches_per_epoch=3, trainer_kwargs=dict(max_epochs=2), ), + # lambda dataset: MQCNNEstimator( + # freq=dataset.metadata.freq, + # distr_output=QuantileOutput(quantiles=[0.1, 0.6, 0.85]), + # prediction_length=dataset.metadata.prediction_length, + # batch_size=4, + # num_batches_per_epoch=3, + # trainer_kwargs=dict(max_epochs=2), + # ), ], ) @pytest.mark.parametrize("use_validation_data", [False, True]) @@ -334,6 +343,17 @@ def test_estimator_constant_dataset( cardinality=[2, 2], trainer_kwargs=dict(max_epochs=2), ), + # lambda freq, prediction_length: MQCNNEstimator( + # freq=freq, + # prediction_length=prediction_length, + # batch_size=4, + # num_batches_per_epoch=3, + # num_feat_dynamic_real=3, + # num_feat_static_real=1, + # num_feat_static_cat=2, + # cardinality=[2, 2], + # trainer_kwargs=dict(max_epochs=2), + # ), ], ) def test_estimator_with_features(estimator_constructor): diff --git a/test/torch/model/test_modules.py b/test/torch/model/test_modules.py index 2ed728e4fb..5093055f4e 100644 --- a/test/torch/model/test_modules.py +++ b/test/torch/model/test_modules.py @@ -19,6 +19,7 @@ from gluonts.torch.model.mqf2 import MQF2MultiHorizonModel from 
gluonts.torch.model.simple_feedforward import SimpleFeedForwardModel from gluonts.torch.model.tft import TemporalFusionTransformerModel +from gluonts.torch.model.mq_cnn import MQCNNModel def assert_shapes_and_dtypes(tensors, shapes, dtypes): @@ -90,6 +91,38 @@ def assert_shapes_and_dtypes(tensors, shapes, dtypes): [[(4, 12, 5)], (4, 1), (4, 1)], [[torch.float], torch.float, torch.float], ), + ( + MQCNNModel( + context_length=24, + prediction_length=12, + num_forking=8, + distr_output=QuantileOutput([0.2, 0.25, 0.5, 0.9, 0.95]), + past_feat_dynamic_real_dim=4, + feat_dynamic_real_dim=2, + feat_static_real_dim=2, + cardinality_dynamic=[2], + cardinality_static=[2, 2], + scaling=False, + scaling_decoder_dynamic_feature=False, + embedding_dimension_dynamic=[2], + embedding_dimension_static=[2, 2], + encoder_cnn_init_dim=8, + dilation_seq=[1, 3, 9], + kernel_size_seq=[7, 3, 3], + channels_seq=[30, 30, 30], + joint_embedding_dimension=30, + encoder_mlp_init_dim=7, + encoder_mlp_dim_seq=[30], + use_residual=True, + decoder_mlp_dim_seq=[30], + decoder_hidden_dim=60, + decoder_future_init_dim=4, + decoder_future_embedding_dim=50, + ), + 4, + [[(4, 8, 12, 5)], (4, 1), (4, 1)], + [[torch.float], torch.float, torch.float], + ), ], ) def test_module_smoke(module, batch_size, expected_shapes, expected_dtypes): From e630224183ccdba417461003e4e0ebc1b747171a Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Fri, 3 May 2024 11:36:35 +0200 Subject: [PATCH 2/8] add missing files --- src/gluonts/torch/model/mq_cnn/__init__.py | 22 + src/gluonts/torch/model/mq_cnn/estimator.py | 788 ++++++++++++++++++ src/gluonts/torch/model/mq_cnn/layers.py | 488 +++++++++++ .../torch/model/mq_cnn/lightning_module.py | 177 ++++ src/gluonts/torch/model/mq_cnn/module.py | 490 +++++++++++ test/torch/model/test_mq_cnn.py | 155 ++++ 6 files changed, 2120 insertions(+) create mode 100644 src/gluonts/torch/model/mq_cnn/__init__.py create mode 100644 src/gluonts/torch/model/mq_cnn/estimator.py create mode 100644 src/gluonts/torch/model/mq_cnn/layers.py create mode 100644 src/gluonts/torch/model/mq_cnn/lightning_module.py create mode 100644 src/gluonts/torch/model/mq_cnn/module.py create mode 100644 test/torch/model/test_mq_cnn.py diff --git a/src/gluonts/torch/model/mq_cnn/__init__.py b/src/gluonts/torch/model/mq_cnn/__init__.py new file mode 100644 index 0000000000..69ff062990 --- /dev/null +++ b/src/gluonts/torch/model/mq_cnn/__init__.py @@ -0,0 +1,22 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +from .module import MQCNNModel +from .lightning_module import MQCNNLightningModule +from .estimator import MQCNNEstimator + +__all__ = [ + "MQCNNModel", + "MQCNNLightningModule", + "MQCNNEstimator", +] diff --git a/src/gluonts/torch/model/mq_cnn/estimator.py b/src/gluonts/torch/model/mq_cnn/estimator.py new file mode 100644 index 0000000000..aea88d4675 --- /dev/null +++ b/src/gluonts/torch/model/mq_cnn/estimator.py @@ -0,0 +1,788 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +from typing import Any, Optional, List, Dict, Iterable + +import numpy as np +import torch + +from gluonts.core.component import validated +from gluonts.dataset.common import Dataset +from gluonts.dataset.field_names import FieldName +from gluonts.dataset.loader import as_stacked_batches +from gluonts.itertools import Cyclic +from gluonts.torch.model.estimator import PyTorchLightningEstimator +from gluonts.torch.model.predictor import PyTorchPredictor +from gluonts.time_feature import time_features_from_frequency_str +from gluonts.torch.distributions import Output, QuantileOutput + +from gluonts.transform import ( + Chain, + RemoveFields, + Transformation, + AsNumpyArray, + VstackFeatures, + AddConstFeature, + AddAgeFeature, + AddTimeFeatures, + AddObservedValuesIndicator, + Chain, + RenameFields, + SetField, + ExpectedNumInstanceSampler, + TestSplitSampler, + ValidationSplitSampler, +) +from gluonts.transform.sampler import InstanceSampler +from gluonts.transform.split import ForkingSequenceSplitter + +from .lightning_module import MQCNNLightningModule + +PREDICTION_INPUT_NAMES = [ + "past_target", + "past_feat_dynamic", + "future_feat_dynamic", + "feat_static_real", + "feat_static_cat", + "past_observed_values", + "past_feat_dynamic_cat", + "future_feat_dynamic_cat", +] + +TRAINING_INPUT_NAMES = PREDICTION_INPUT_NAMES + [ + "future_target", + "future_observed_values", +] + + +class MQCNNEstimator(PyTorchLightningEstimator): + """ + Args: + + freq (str): + Time granularity of the data. + prediction_length (int): + Length of the prediction, also known as 'horizon'. + context_length (int, optional): + Number of time units that condition the predictions, also known as 'lookback period'. + Defaults to None. + num_forking (int, optional): + Decides how much forking to do in the decoder. + (default: context_length if None) + Defaults to None. + quantiles (List[float], optional): + The list of quantiles that will be optimized for, and predicted by, the model. + (default: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] if None) + Defaults to None. + use_past_feat_dynamic_real (bool, optional): + Whether to use the ``past_feat_dynamic_real`` field from the data. + Defaults to False. + use_feat_dynamic_real (bool, optional): + Whether to use the ``feat_dynamic_real`` field from the data. + Defaults to False. + use_feat_dynamic_cat (bool, optional): + Whether to use the ``feat_dynamic_cat`` field from the data. + Defaults to False. + use_feat_static_real (bool, optional): + Whether to use the ``feat_static_real`` field from the data. + Defaults to False. + use_feat_static_cat (bool, optional): + Whether to use the ``feat_static_cat`` field from the data. + Defaults to False. + add_time_feature (bool, optional): + Adds a set of time features. + Defaults to True. + add_age_feature (bool, optional): + Adds an age feature. + The age feature starts with a small value at the start of the time series and grows over time. + Defaults to False. 
+ enable_encoder_dynamic_feature (bool, optional): + Whether the encoder should be provided with the dynamic features (``age``, ``time`` + and ``feat_dynamic_real/cat`` if enabled respectively) + Defaults to True. + enable_decoder_dynamic_feature (bool, optional): + Whether the decoder should be provided with the dynamic features (``age``, ``time`` + and ``feat_dynamic_real/cat`` if enabled respectively). + Defaults to True. + feat_dynamic_real_dim (int, optional): + Dimension of real dynamic features. + Defaults to None. + past_feat_dynamic_real_dim (int, optional): + Dimension of past real dynamic features + Defaults to None. + cardinality_dynamic (List[int], optional): + Number of values of each dynamic categorical feature. + This must be set if ``use_feat_dynamic_cat == True`` + Defaults to None. + embedding_dimension_dynamic (List[int], optional): + Dimension of the embeddings for dynamic categorical features. + (default: [cat for cat in cardinality_dinamic] if None) + Defaults to None. + feat_static_real_dim (int, optional): + Dimension of real static features. + Defaults to None. + cardinality_static (List[int], optional): + Number of values of each static categorical feature. + This must be set if ``use_feat_static_cat == True`` + Defaults to None. + embedding_dimension_static (List[int], optional): + Dimension of the embeddings for categorical features. + (default: [min(50, (cat+1)//2) for cat in cardinality_static] if None) + Defaults to None. + scaling (bool, optional): + Whether to automatically scale the target values. + Defaults to None. + scaling_decoder_dynamic_feature (bool, optional): + Whether to automatically scale the dynamic features for the decoder. + Defaults to False. + joint_embedding_dimension (int, optional): + Dimension of the joint embedding for all static features (real and categorical) as the end of the encoder + (default: if None, channels_seq[-1] * sqrt(feat_static_dim)), + where feat_static_dim is appx sum(embedding_dimension_static)) + Defaults to None. + encoder_mlp_dim_seq (List[int], optional): + The dimensionalities of the MLP layers of the encoder for static features (default: [] if None) + Defaults to None. + decoder_mlp_dim_seq (List[int], optional): + The dimensionalities of the layers of the local MLP decoder. (default: [30] if None) + Defaults to None. + decoder_hidden_dim (int, optional): + Hidden dimension of the decoder used to produce horizon agnostic and horizon specific encodings of the input. + (default: 30 if None) + Defaults to None. + decoder_future_embedding_dim (int, optional): + Size of the embeddings used to globally encode future dynamic features. + (default: 50 if None) + Defaults to None. + channels_seq (List[int], optional): + The number of channels (i.e. filters or convolutions) for each layer of the HierarchicalCausalConv1DEncoder. + More channels usually correspond to better performance and larger network size. + (default: [30, 30, 30] if None) + Defaults to None. + dilation_seq (List[int], optional): + The dilation of the convolutions in each layer of the HierarchicalCausalConv1DEncoder. + Greater numbers correspond to a greater receptive field of the network, which is usually + better with longer context_length. (Same length as channels_seq) (default: [1, 3, 5] if None) + Defaults to None. + kernel_size_seq (List[int], optional): + The kernel sizes (i.e. window size) of the convolutions in each layer of the HierarchicalCausalConv1DEncoder. + (Same length as channels_seq) (default: [7, 3, 3] if None) + Defaults to None. 
+ use_residual (bool, optional): + Whether the hierarchical encoder should additionally pass the unaltered + past target to the decoder. + Defaults to True. + batch_size (int, optional): + The size of the batches to be used training and prediction. + Defaults to 32. + val_batch_size(int, optional): + batch size for validation. + If None, will use the same batch size + Defaults to None. + lr (float, optional) + Learning rate, by default 1e-3 + learning_rate_decay_factor (float, optional): + Learning rate decay factor, by default 0.1. + minimum_learning_rate (float, optional): + Minimum learning rate, by default 1e-6. + clip_gradient (float, optional): + Clip gradient level, by default 10.0. + weight_decay (float, optional) + Weight decay, by default 1e-8 + patience (int, optional): + Patience applied to learning rate scheduling, by deafult 10. + num_batches_per_epoch (int, optional): + Number of batches to be processed in each training epoch, + by default 50 + trainer_kwargs (Dict, optional) + Additional arguments to provide to ``pl.Trainer`` for construction, + by default None + train_sampler (InstanceSampler, optional): + Controls the sampling of windows during training. + Defaults to None. + validation_sampler (InstanceSampler, optional): + Controls the sampling of windows during validation. + Defaults to None. + """ + + @validated() + def __init__( + self, + freq: str, + prediction_length: int, + context_length: Optional[int] = None, + num_forking: Optional[int] = None, + quantiles: Optional[List[float]] = None, + distr_output: Optional[Output] = None, + use_past_feat_dynamic_real: bool = False, + use_feat_dynamic_real: bool = False, + use_feat_dynamic_cat: bool = False, + use_feat_static_real: bool = False, + use_feat_static_cat: bool = False, + add_time_feature: bool = True, + add_age_feature: bool = False, + enable_encoder_dynamic_feature: bool = True, + enable_decoder_dynamic_feature: bool = True, + feat_dynamic_real_dim: Optional[int] = None, + past_feat_dynamic_real_dim: Optional[int] = None, + cardinality_dynamic: Optional[List[int]] = None, + embedding_dimension_dynamic: Optional[List[int]] = None, + feat_static_real_dim: Optional[int] = None, + cardinality_static: Optional[List[int]] = None, + embedding_dimension_static: Optional[List[int]] = None, + scaling: Optional[bool] = None, + scaling_decoder_dynamic_feature: bool = False, + joint_embedding_dimension: Optional[int] = None, + encoder_mlp_dim_seq: Optional[List[int]] = None, + decoder_mlp_dim_seq: Optional[List[int]] = None, + decoder_hidden_dim: Optional[int] = None, + decoder_future_embedding_dim: Optional[int] = None, + channels_seq: Optional[List[int]] = None, + dilation_seq: Optional[List[int]] = None, + kernel_size_seq: Optional[List[int]] = None, + use_residual: bool = True, + batch_size: int = 32, + val_batch_size: Optional[int] = None, + lr: float = 1e-3, + learning_rate_decay_factor: float = 0.1, + minimum_learning_rate: float = 1e-6, + clip_gradient: float = 10.0, + weight_decay: float = 1e-8, + patience: int = 10, + num_batches_per_epoch: int = 50, + trainer_kwargs: Dict[str, Any] = None, + train_sampler: Optional[InstanceSampler] = None, + validation_sampler: Optional[InstanceSampler] = None, + ) -> None: + + torch.set_default_tensor_type(torch.FloatTensor) + + default_trainer_kwargs = { + "max_epochs": 100, + "gradient_clip_val": clip_gradient, + } + if trainer_kwargs is not None: + default_trainer_kwargs.update(trainer_kwargs) + super().__init__(trainer_kwargs=default_trainer_kwargs) + + self.freq = freq + 
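+        # `num_forking` is capped at `context_length` below, since the decoder
+        # can be forked at most once per encoder time step.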
self.prediction_length = prediction_length + self.context_length = context_length + self.num_forking = ( + min(num_forking, self.context_length) + if num_forking is not None + else self.context_length + ) + + # Model architecture + if distr_output is not None and quantiles is not None: + raise ValueError( + "Only one of `distr_output` and `quantiles` must be specified" + ) + elif distr_output is not None: + self.distr_output = distr_output + else: + if quantiles is None: + quantiles = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] + self.distr_output = QuantileOutput(quantiles=quantiles) + + if time_features is None: + time_features = time_features_from_frequency_str(self.freq) + self.time_features = time_features + + self.use_past_feat_dynamic_real = use_past_feat_dynamic_real + self.use_feat_dynamic_real = use_feat_dynamic_real + self.use_feat_dynamic_cat = use_feat_dynamic_cat + self.use_feat_static_real = use_feat_static_real + self.use_feat_static_cat = use_feat_static_cat + + self.add_time_feature = add_time_feature + self.add_age_feature = add_age_feature + self.use_dynamic_feat = ( + use_feat_dynamic_real + or add_age_feature + or add_time_feature + or use_feat_dynamic_cat + ) + + self.enable_encoder_dynamic_feature = enable_encoder_dynamic_feature + self.enable_decoder_dynamic_feature = enable_decoder_dynamic_feature + + self.scaling = scaling if scaling is not None else False + self.scaling_decoder_dynamic_feature = scaling_decoder_dynamic_feature + + self.train_sampler = train_sampler or ExpectedNumInstanceSampler( + num_instances=1.0, min_future=prediction_length + ) + self.validation_sampler = validation_sampler or ValidationSplitSampler( + min_future=prediction_length + ) + + self.enc_cnn_init_dim = 3 # target, observed, const + self.dec_future_init_dim = 1 # observed + if add_time_feature: + self.enc_cnn_init_dim += 1 + self.dec_future_init_dim += 1 + if add_age_feature: + self.enc_cnn_init_dim += 1 + self.dec_future_init_dim += 1 + + # Training procedure + self.lr = lr + self.batch_size = batch_size + self.val_batch_size = ( + val_batch_size if val_batch_size is not None else batch_size + ) + self.learning_rate_decay_factor = learning_rate_decay_factor + self.minimum_learning_rate = minimum_learning_rate + self.weight_decay = weight_decay + self.patience = patience + self.num_batches_per_epoch = num_batches_per_epoch + + self.encoder_mlp_dim_seq = ( + encoder_mlp_dim_seq if encoder_mlp_dim_seq is not None else [] + ) + self.decoder_mlp_dim_seq = ( + decoder_mlp_dim_seq if decoder_mlp_dim_seq is not None else [30] + ) + self.decoder_hidden_dim = ( + decoder_hidden_dim if decoder_hidden_dim is not None else 30 + ) + self.decoder_future_embedding_dim = ( + decoder_future_embedding_dim + if decoder_future_embedding_dim is not None + else 50 + ) + self.channels_seq = ( + channels_seq if channels_seq is not None else [30, 30, 30] + ) + self.dilation_seq = ( + dilation_seq if dilation_seq is not None else [1, 3, 9] + ) + self.kernel_size_seq = ( + kernel_size_seq if kernel_size_seq is not None else [7, 3, 3] + ) + + assert ( + len(channels_seq) == len(dilation_seq) == len(kernel_size_seq) + ), ( + f"mismatch CNN configurations: {len(channels_seq)} vs. " + f"{len(dilation_seq)} vs. 
{len(kernel_size_seq)}" + ) + + self.use_residual = use_residual + + if self.use_feat_dynamic_cat: + self.cardinality_dynamic = cardinality_dynamic + self.embedding_dimension_dynamic = ( + embedding_dimension_dynamic + if embedding_dimension_dynamic is not None + else [cat for cat in cardinality_dynamic] + ) + + self.enc_cnn_init_dim += sum(self.embedding_dimension_dynamic) + self.dec_future_init_dim += sum(self.embedding_dimension_dynamic) + else: + self.cardinality_dynamic = 0 + + if self.use_past_feat_dynamic_real: + assert ( + past_feat_dynamic_real_dim is not None + ), "past_feat_dynamic_real should be provided" + self.enc_cnn_init_dim += past_feat_dynamic_real_dim + self.past_feat_dynamic_real_dim = past_feat_dynamic_real_dim + else: + self.past_feat_dynamic_real_dim = 0 + + if self.use_feat_dynamic_real: + assert ( + feat_dynamic_real_dim is not None + ), "dim_feat_dynamic_real should be provided" + self.enc_cnn_init_dim += feat_dynamic_real_dim + self.dec_future_init_dim += feat_dynamic_real_dim + self.feat_dynamic_real_dim = feat_dynamic_real_dim + else: + self.feat_dynamic_real_dim = 0 + + self.enc_mlp_init_dim = 1 # start with 1 because of scaler + if self.use_feat_static_cat: + self.cardinality_static = cardinality_static + self.embedding_dimension_static = ( + embedding_dimension_static + if embedding_dimension_static is not None + else [min(50, (cat + 1) // 2) for cat in cardinality_static] + ) + self.enc_mlp_init_dim += sum(self.embedding_dimension_static) + else: + self.cardinality_static = 0 + self.embedding_dimension_static = 0 + + self.joint_embedding_dimension = joint_embedding_dimension + if self.joint_embedding_dimension is None: + feat_static_dim = sum(self.embedding_dimension_static) + self.joint_embedding_dimension = int( + channels_seq[-1] * max(np.sqrt(feat_static_dim), 1) + ) + + if self.use_feat_static_real: + assert ( + feat_static_real_dim is not None + ), "feat_static_real should be provided" + self.enc_mlp_init_dim += feat_static_real_dim + self.feat_static_real_dim = feat_static_real_dim + else: + self.feat_static_real_dim = 0 + + def create_transformation(self) -> Chain: + """Creates transformation to be applied to input dataset + + Returns: + Chain: + transformation chain to be applied to the input data + """ + + dynamic_feat_fields = [] + remove_field_names = [] + + if not self.use_past_feat_dynamic_real: + remove_field_names.append(FieldName.PAST_FEAT_DYNAMIC_REAL) + if not self.use_feat_dynamic_real: + remove_field_names.append(FieldName.FEAT_DYNAMIC_REAL) + if not self.use_feat_dynamic_cat: + remove_field_names.append(FieldName.FEAT_DYNAMIC_CAT) + if not self.use_feat_static_real: + remove_field_names.append(FieldName.FEAT_STATIC_REAL) + if not self.use_feat_static_cat: + remove_field_names.append(FieldName.FEAT_STATIC_CAT) + + transforms = [ + RemoveFields(field_names=remove_field_names), + AsNumpyArray(field=FieldName.TARGET, expected_ndim=1), + AddObservedValuesIndicator( + target_field=FieldName.TARGET, + output_field=FieldName.OBSERVED_VALUES, + ), + ] + + if len(self.time_features) > 0: + transforms.append( + AddTimeFeatures( + start_field=FieldName.START, + target_field=FieldName.TARGET, + output_field=FieldName.FEAT_TIME, + time_features=self.time_features, + pred_length=self.prediction_length, + ) + ) + dynamic_feat_fields.append(FieldName.FEAT_TIME) + + if self.add_age_feature: + transforms.append( + AddAgeFeature( + target_field=FieldName.TARGET, + output_field=FieldName.FEAT_AGE, + pred_length=self.prediction_length, + ) + ) + 
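+            # record the generated field so it gets stacked into
+            # FieldName.FEAT_DYNAMIC further below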
dynamic_feat_fields.append(FieldName.FEAT_AGE) + + if self.use_feat_dynamic_real: + # Backwards compatibility: + transforms.append( + RenameFields({"dynamic_feat": FieldName.FEAT_DYNAMIC_REAL}) + ) + dynamic_feat_fields.append(FieldName.FEAT_DYNAMIC_REAL) + + # we need to make sure that there is always some dynamic input + # we will however disregard it in the hybrid forward. + # the time feature is empty for yearly freq so also adding a dummy feature + # in the case that the time feature is the only one on + if len(dynamic_feat_fields) == 0 or ( + not self.add_age_feature + and not self.add_time_feature + and not self.use_feat_dynamic_real + ): + transforms.append( + AddConstFeature( + target_field=FieldName.TARGET, + output_field=FieldName.FEAT_CONST, + pred_length=self.prediction_length, + const=0.0, + ) + ) + dynamic_feat_fields.append(FieldName.FEAT_CONST) + + # now we map all the dynamic input of length context_length + prediction_length onto FieldName.FEAT_DYNAMIC + # we exclude past_feat_dynamic_real since its length is only context_length + if len(dynamic_feat_fields) > 1: + transforms.append( + VstackFeatures( + output_field=FieldName.FEAT_DYNAMIC, + input_fields=dynamic_feat_fields, + ) + ) + elif len(dynamic_feat_fields) == 1: + transforms.append( + RenameFields({dynamic_feat_fields[0]: FieldName.FEAT_DYNAMIC}) + ) + + if not self.use_feat_dynamic_cat: + transforms.append( + AddConstFeature( + target_field=FieldName.TARGET, + output_field=FieldName.FEAT_DYNAMIC_CAT, + pred_length=self.prediction_length, + const=0, + ) + ) + + if not self.use_feat_static_cat: + transforms.append( + SetField(output_field=FieldName.FEAT_STATIC_CAT, value=[0]) + ) + transforms.append( + AsNumpyArray( + field=FieldName.FEAT_STATIC_CAT, + expected_ndim=1, + dtype=np.int64, + ) + ) + + if not self.use_feat_static_real: + transforms.append( + SetField(output_field=FieldName.FEAT_STATIC_REAL, value=[0.0]) + ) + transforms.append( + AsNumpyArray(field=FieldName.FEAT_STATIC_REAL, expected_ndim=1) + ) + + return Chain(transforms) + + def _create_instance_splitter(self, mode: str) -> Chain: + """Creates instance splitter to be applied to the dataset + + Args: + mode (str): `training`, `validation` or `test` + + Returns: + Chain: + transformation chain to split input data along the time + dimension before processing + """ + + assert mode in ["training", "validation", "test"] + + instance_sampler = { + "training": self.train_sampler, + "validation": self.validation_sampler, + "test": TestSplitSampler(), + }[mode] + + chain = [] + + encoder_series_fields = [ + FieldName.OBSERVED_VALUES, + FieldName.FEAT_DYNAMIC, + FieldName.FEAT_DYNAMIC_CAT, + ] + ( + [FieldName.PAST_FEAT_DYNAMIC_REAL] + if self.use_past_feat_dynamic_real + else [] + ) + + encoder_disabled_fields = ( + [FieldName.FEAT_DYNAMIC, FieldName.FEAT_DYNAMIC_CAT] + if not self.enable_encoder_dynamic_feature + else [] + ) + ( + [FieldName.PAST_FEAT_DYNAMIC_REAL] + if not self.enable_encoder_dynamic_feature + and self.use_past_feat_dynamic_real + else [] + ) + + decoder_series_fields = ( + [ + FieldName.FEAT_DYNAMIC, + FieldName.FEAT_DYNAMIC_CAT, + ] + + ([FieldName.OBSERVED_VALUES] if mode != "test" else []), + ) + + decoder_disabled_fields = ( + [FieldName.FEAT_DYNAMIC, FieldName.FEAT_DYNAMIC_CAT] + if not self.enable_decoder_dynamic_feature + else [] + ) + + chain.append( + # because of how the forking decoder works, every time step + # in context is used for splitting, which is why we use the TestSplitSampler + ForkingSequenceSplitter( + 
instance_sampler=instance_sampler, + enc_len=self.context_length, + dec_len=self.prediction_length, + num_forking=self.num_forking, + target_field=FieldName.TARGET, + encoder_series_fields=encoder_series_fields, + encoder_disabled_fields=encoder_disabled_fields, + decoder_series_fields=decoder_series_fields, + decoder_disabled_fields=decoder_disabled_fields, + prediction_time_decoder_exclude=[FieldName.OBSERVED_VALUES], + is_pad_out=FieldName.IS_PAD, + start_input_field=FieldName.START, + ) + ) + + # past_feat_dynamic features generated above in ForkingSequenceSplitter from those under feat_dynamic - we need + # to stack with the other short related time series from the system labeled as past_past_feat_dynamic_real. + # The system labels them as past_feat_dynamic_real and the additional past_ is added to the string + # in the ForkingSequenceSplitter + if self.use_past_feat_dynamic_real: + # Stack features from ForkingSequenceSplitter horizontally since they were transposed + # so shape is now (enc_len, num_past_feature_dynamic) + chain.append( + VstackFeatures( + output_field=FieldName.PAST_FEAT_DYNAMIC, + input_fields=[ + "past_" + FieldName.PAST_FEAT_DYNAMIC_REAL, + FieldName.PAST_FEAT_DYNAMIC, + ], + h_stack=True, + ) + ) + + return Chain(chain) + + def create_training_data_loader( + self, + data: Dataset, + module: MQCNNLightningModule, + shuffle_buffer_length: Optional[int] = None, + **kwargs, + ) -> Iterable: + """Creates data loader for the training dataset + + Args: + data (Dataset): training dataset + + Returns: + DataLoader: training data loader + """ + train_transformation = ( + self.create_transformation() + + self._create_instance_splitter("training") + ) + + data = Cyclic(data).stream() + transformed_data = train_transformation.apply(data) + return as_stacked_batches( + transformed_data, + batch_size=self.batch_size, + shuffle_buffer_length=shuffle_buffer_length, + field_names=TRAINING_INPUT_NAMES, + output_type=torch.tensor, + num_batches_per_epoch=self.num_batches_per_epoch, + ) + + def create_validation_data_loader( + self, + data: Dataset, + module: MQCNNLightningModule, + **kwargs, + ) -> Iterable: + """Creates data loader for the validation dataset + + Args: + data (Dataset): validation dataset + + Returns: + DataLoader: validation data loader + """ + + train_transformation = ( + self.create_transformation() + + self._create_instance_splitter("validation") + ) + + transformed_data = train_transformation.apply(data) + return as_stacked_batches( + transformed_data, + batch_size=self.val_batch_size, + field_names=TRAINING_INPUT_NAMES, + output_type=torch.tensor, + ) + + def create_lightning_module(self) -> MQCNNLightningModule: + return MQCNNLightningModule( + lr=self.lr, + learning_rate_decay_factor=self.learning_rate_decay_factor, + minimum_learning_rate=self.minimum_learning_rate, + weight_decay=self.weight_decay, + patience=self.patience, + model_kwargs={ + "context_length": self.context_length, + "prediction_length": self.prediction_length, + "num_forking": self.num_forking, + "past_feat_dynamic_real_dim": self.past_feat_dynamic_real_dim, + "feat_dynamic_real_dim": self.feat_dynamic_real_dim, + "cardinality_dynamic": self.cardinality_dynamic, + "embedding_dimension_dynamic": self.embedding_dimension_dynamic, + "feat_static_real_dim": self.feat_static_real_dim, + "cardinality_static": self.cardinality_static, + "embedding_dimension_static": self.embedding_dimension_static, + "scaling": self.scaling, + "scaling_decoder_dynamic_feature": 
self.scaling_decoder_dynamic_feature, + "encoder_cnn_init_dim": self.enc_cnn_init_dim, + "dilation_seq": self.dilation_seq, + "kernel_size_seq": self.kernel_size_seq, + "channels_seq": self.channels_seq, + "joint_embedding_dimension": self.joint_embedding_dimension, + "encoder_mlp_init_dim": self.enc_mlp_init_dim, + "encoder_mlp_dim_seq": self.encoder_mlp_dim_seq, + "use_residual": self.use_residual, + "decoder_mlp_dim_seq": self.decoder_mlp_dim_seq, + "decoder_hidden_dim": self.decoder_hidden_dim, + "decoder_future_init_dim": self.dec_future_init_dim, + "decoder_future_embedding_dim": self.decoder_future_embedding_dim, + "distr_output": self.distr_output, + }, + ) + + def create_predictor( + self, + transformation: Transformation, + module: MQCNNLightningModule, + ) -> PyTorchPredictor: + """Creates predictor for inference + + Args: + transformation (Transformation): transformation to be applied to data input to predictor + trained_network (MQCnnModel): trained network + + Returns: + Predictor: + """ + + # For inference, the transformation is needed to add "future" parts of the time-series features + prediction_splitter = self._create_instance_splitter("test") + + return PyTorchPredictor( + input_transform=transformation + prediction_splitter, + input_names=PREDICTION_INPUT_NAMES, + prediction_net=module, + batch_size=self.batch_size, + prediction_length=self.prediction_length, + device="auto", + forecast_generator=self.distr_output.forecast_generator, + ) diff --git a/src/gluonts/torch/model/mq_cnn/layers.py b/src/gluonts/torch/model/mq_cnn/layers.py new file mode 100644 index 0000000000..d4351baa79 --- /dev/null +++ b/src/gluonts/torch/model/mq_cnn/layers.py @@ -0,0 +1,488 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +from typing import List, Tuple, Union + +import torch +from torch import nn, Tensor + +from gluonts.core.component import validated +from gluonts.torch.modules.lambda_layer import LambdaLayer + + +def _get_int(a: Union[int, List[int], Tuple[int]]) -> int: + if isinstance(a, (list, tuple)): + assert len(a) == 1 + return a[0] + return a + + +class CausalConv1D(nn.Module): + """ + 1D causal temporal convolution, where the term causal means that output[t] + does not depend on input[t+1:]. Notice that Conv1D is not implemented in + Gluon. + + This is the basic structure used in Wavenet [ODZ+16]_ and Temporal + Convolution Network [BKK18]_. + + The output has the same shape as the input, while we always left-pad zeros. + + Parameters + ---------- + + layer_no + layer number + init_dim + input dimension into the layer + channels + The dimensionality of the output space, i.e. the number of output + channels (filters) in the convolution. + kernel_size + Specifies the dimensions of the convolution window. + dilation + Specifies the dilation rate to use for dilated convolution. 
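+
+    Examples
+    --------
+    A minimal shape check (illustrative; input follows the NCW layout
+    described in ``forward``)::
+
+        import torch
+
+        conv = CausalConv1D(
+            layer_no=0, init_dim=4, channels=8, kernel_size=3, dilation=2
+        )
+        x = torch.randn(2, 4, 24)  # (batch, features, time)
+        y = conv(x)                # -> (2, 8, 24); output[t] only sees x[:t+1]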
+ """ + + def __init__( + self, + layer_no: int, + init_dim: int, + channels: int, + kernel_size: Union[int, Tuple[int], List[int]], + dilation: Union[int, Tuple[int], List[int]] = 1, + **kwargs, + ): + super().__init__() + + self.dilation = _get_int(dilation) + self.kernel_size = _get_int(kernel_size) + self.padding = self.dilation * (self.kernel_size - 1) + self.conv1d = nn.Sequential( + nn.Conv1d( + in_channels=init_dim if layer_no == 0 else channels, + out_channels=channels, + kernel_size=self.kernel_size, + dilation=self.dilation, + padding=self.padding, + ), + nn.ReLU(), + ) + + def forward(self, data: Tensor) -> Tensor: + """ + In Gluon's conv1D implementation, input has dimension NCW where N is + batch_size, C is channel, and W is time (sequence_length). + + Parameters + ---------- + data + Shape (batch_size, num_features, sequence_length) + + Returns + ------- + Tensor + causal conv1d output. Shape (batch_size, num_features, + sequence_length) + """ + ct = self.conv1d(data) + if self.kernel_size > 0: + ct = ct[:, :, : ct.shape[2] - self.padding] + return ct + + +class HierarchicalCausalConv1DEncoder(nn.Module): + """ + Defines a stack of dilated convolutions as the encoder. + See the following paper for details: + 1. Van Den Oord, A., Dieleman, S., Zen, H., Simonyan, K., Vinyals, O., Graves, A., Kalchbrenner, + N., Senior, A.W. and Kavukcuoglu, K., 2016, September. WaveNet: A generative model for raw audio. In SSW (p. 125). + Parameters + ---------- + cnn_init_dim + input dimension into CNN encoder + dilation_seq + dilation for each convolution in the stack. + kernel_size_seq + kernel size for each convolution in the stack. + channels_seq + number of channels for each convolution in the stack. + joint_embedding_dimension (int): + final dimension to embed all static features + hidden_dimension_sequence (List[int], optional): + list of hidden dimensions for the MLP used to embed static features. Defaults to []. + use_residual + flag to toggle using residual connections. 
+ use_static_feat + flag to toggle whether to use use_static_feat as input to the encoder + use_dynamic_feat + flag to toggle whether to use use_dynamic_feat as input to the encoder + """ + + @validated() + def __init__( + self, + cnn_init_dim: int, + dilation_seq: List[int], + kernel_size_seq: List[int], + channels_seq: List[int], + joint_embedding_dimension: int, + mlp_init_dim: int, + hidden_dimension_sequence: List[int] = [], + use_residual: bool = False, + use_static_feat: bool = False, + use_dynamic_feat: bool = False, + **kwargs, + ) -> None: + + assert all( + [x > 0 for x in dilation_seq] + ), "`dilation_seq` values must be greater than zero" + assert all( + [x > 0 for x in kernel_size_seq] + ), "`kernel_size_seq` values must be greater than zero" + assert all( + [x > 0 for x in channels_seq] + ), "`channel_dim_seq` values must be greater than zero" + + super().__init__() + + self.use_residual = use_residual + self.use_static_feat = use_static_feat + self.use_dynamic_feat = use_dynamic_feat + + # CNN for dynamic features (and/or target) + self.cnn = nn.Sequential() + + # swap axes because Conv1D expects NCT + self.cnn.append(LambdaLayer(lambda x: torch.transpose(x, 2, 1))) + + it = zip(channels_seq, kernel_size_seq, dilation_seq) + for layer_no, (channels, kernel_size, dilation) in enumerate(it): + + convolution = CausalConv1D( + layer_no=layer_no, + init_dim=cnn_init_dim, + channels=channels, + kernel_size=kernel_size, + dilation=dilation, + ) + self.cnn.append(convolution) + + # swap axes to get back to NTC + self.cnn.append(LambdaLayer(lambda x: torch.transpose(x, 2, 1))) + + # MLP for static features + modules: List[nn.Module] = [] + mlp_dimension_sequence = ( + [mlp_init_dim] + + hidden_dimension_sequence + + [joint_embedding_dimension] + ) + if use_static_feat: + + for in_features, out_features in zip( + mlp_dimension_sequence[:-1], mlp_dimension_sequence[1:] + ): + layer = nn.Linear( + in_features, + out_features, + ) + modules += [layer, nn.ReLU()] + + self.static = nn.Sequential(*modules) + + def forward( + self, + target: Tensor, + static_features: Tensor, + dynamic_features: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Parameters + ---------- + target + target time series, + shape (batch_size, sequence_length, 1) + static_features + static features, + shape (batch_size, num_feat_static) + dynamic_features + dynamic_features, + shape (batch_size, sequence_length, num_feat_dynamic) + Returns + ------- + Tensor + static code, + shape (batch_size, channel_seqs + (1) if use_residual) + Tensor + dynamic code, + shape (batch_size, sequence_length, channel_seqs + (1) if use_residual) + """ + + if self.use_dynamic_feat: + dynamic_inputs = torch.cat( + (target, dynamic_features), dim=2 + ) # (N, T, C) + else: + dynamic_inputs = target + + dynamic_encoded = self.cnn(dynamic_inputs.float()) + + if self.use_residual: + dynamic_encoded = torch.cat((dynamic_encoded, target), dim=2) + + if self.use_static_feat: + static_encoded = self.static(static_features) + else: + static_encoded = None + + # we return them separately so that static features can be replicated + return static_encoded, dynamic_encoded + + +class Enc2Dec(nn.Module): + """ + Integrates the encoder_output_static, encoder_output_dynamic and future_features_dynamic + and passes them through as the dynamic input to the decoder. 
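+    The static encoding is repeated along the forking axis and concatenated
+    to the dynamic encoding, so the decoder receives one combined feature
+    vector per forked time step.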
+ + Parameters: + ------------ + num_forking [int]: + number of forks + """ + + @validated() + def __init__( + self, + num_forking: int, + **kwargs, + ) -> None: + + super().__init__() + self.num_forking = num_forking + + def forward( + self, + encoder_output_static: torch.Tensor, + encoder_output_dynamic: torch.Tensor, + ) -> torch.Tensor: + """ + Parameters + ---------- + encoder_output_static + shape (batch_size, num_feat_static) or (N, C) + encoder_output_dynamic + shape (batch_size, sequence_length, channels_seq[-1] + 1) or (N, T, C) + Returns + ------- + Tensor + shape (batch_size, sequence_length, channels_seq[-1] + 1 + num_feat_static) or (N, C) + """ + + encoder_output_static = torch.unsqueeze(encoder_output_static, dim=1) + encoder_output_static_expanded = torch.repeat_interleave( + encoder_output_static, repeats=self.num_forking, dim=1 + ) + + # concatenate static and dynamic output of the encoder + # => (batch_size, sequence_length, num_enc_output_dynamic + num_enc_output_static) + encoder_output = torch.cat( + (encoder_output_dynamic, encoder_output_static_expanded), dim=2 + ) + + return encoder_output + + +class ForkingMLPDecoder(nn.Module): + """ + Multilayer perceptron decoder for sequence-to-sequence models. + See [WTN+17]_ for details. + Parameters + ---------- + dec_len + length of the decoder (usually the number of forecasted time steps). + encoded_input_dim (int): + input dimension out of encoder + local_mlp_final_dim (int): + final dimension of the local mlp (output of the decoder before quantile output layer) + global_mlp_final_dim (int): + final dimension of the horizon agnostic part of the global mlp. Note that horizon specific part will use //2 dimension. + future_feat_embedding_dim (int): + dimension of the embedding layer for global encoding of future features. + local_mlp_hidden_dim_sequence (List[int], optional): + dimensions of local mlp hidden layers. Defaults to []. 
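+
+    Examples
+    --------
+    A shape-only sketch (dimensions are illustrative, not defaults)::
+
+        import torch
+
+        decoder = ForkingMLPDecoder(
+            dec_len=12,
+            encoded_input_dim=31,
+            local_mlp_final_dim=30,
+            global_mlp_final_dim=60,
+            future_feat_init_dim=4,
+            future_feat_embedding_dim=50,
+        )
+        encoded = torch.randn(2, 8, 31)    # (N, T, encoded_input_dim)
+        future = torch.randn(2, 8, 12, 4)  # (N, T, dec_len, future_feat_init_dim)
+        out = decoder(encoded, future)     # -> (2, 8, 12, 30)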
+ """ + + @validated() + def __init__( + self, + dec_len: int, + encoded_input_dim: int, + local_mlp_final_dim: int, + global_mlp_final_dim: int, + future_feat_init_dim: int, + future_feat_embedding_dim: int, + local_mlp_hidden_dim_sequence: List[int] = [], + **kwargs, + ) -> None: + + super().__init__() + + self.dec_len = dec_len + + # Global embeddings for future dynamic features [N, T, C], + # where N - batch_size, T - number of forks, + # C - joint embedding dimension for future features + self.global_future_layer = self._get_global_future_layer( + future_feat_init_dim, future_feat_embedding_dim + ) + + # Local embeddings for future dynamic features [N, T, K, C] + # where N - batch_size, T - number of forks, K - number of horizons (dec_len), + # C - number of future dynamic features per horizon (same dimensions as input to decoder) + self.local_future_layer = self._get_local_future_layer() + + # Horizon specific global MLP outputs [N, T, K, C], + # C - number of outputs per horizon (global mlp final dimension // 2) + input_dim = encoded_input_dim + future_feat_embedding_dim + horizon_specific_dim = global_mlp_final_dim // 2 + self.horizon_specific = self._get_horizon_specific( + input_dim, horizon_specific_dim + ) + + # Horizon agnostic global MLP outputs [N, T, K, C], + # C - number of identical outputs per horizon (global mlp final dimension) + self.horizon_agnostic = self._get_horizon_agnostic( + input_dim, global_mlp_final_dim + ) + + # Local MLP outputs [N, T, K, C], + # C - number of outputs per horizon (local mlp final dimension) + local_mlp_init_dim = ( + horizon_specific_dim + global_mlp_final_dim + future_feat_init_dim + ) + self.local_mlp = self._get_local_mlp( + local_mlp_init_dim, + local_mlp_final_dim, + local_mlp_hidden_dim_sequence, + ) + + def _get_global_future_layer(self, input_size, embedding_dim): + layer = nn.Sequential() + layer.append( + LambdaLayer( + lambda x: torch.reshape(x, (x.shape[0], x.shape[1], -1)) + ) + ) ## [N, T, K, C] where T is number of forks, K number of horizons (dec_len) + + layer.append(nn.Linear(input_size * self.dec_len, embedding_dim)) + layer.append(nn.Tanh()) # [N, T, embedding_dim] + return layer + + def _get_local_future_layer(self): + layer = nn.Sequential() + layer.append(nn.Tanh()) ##[N, T, K, C] + return layer + + def _get_horizon_specific(self, input_size, units_per_horizon): + mlp = nn.Sequential() + mlp.append(nn.Linear(input_size, self.dec_len * units_per_horizon)) + mlp.append(nn.ReLU()) + mlp.append( + LambdaLayer( + lambda x: torch.reshape( + x, (x.shape[0], x.shape[1], self.dec_len, -1) + ) + ) + ) + return mlp + + def _get_horizon_agnostic(self, input_size, hidden_size): + mlp = nn.Sequential() + mlp.append(nn.Linear(input_size, hidden_size)) + mlp.append(nn.ReLU()) + mlp.append(LambdaLayer(lambda x: torch.unsqueeze(x, dim=2))) + mlp.append( + LambdaLayer( + lambda x: torch.repeat_interleave( + x, repeats=self.dec_len, dim=2 + ) + ) + ) + return mlp + + def _get_local_mlp(self, init_dim, final_dim, hidden_dimension_seq): + modules: List[nn.Module] = [] + dimensions = [init_dim] + hidden_dimension_seq + + for in_features, out_features in zip(dimensions[:-1], dimensions[1:]): + layer = nn.Linear( + in_features, + out_features, + ) + modules += [layer, nn.ReLU()] + + modules += [nn.Linear(dimensions[-1], final_dim), nn.Softplus()] + + local_mlp = nn.Sequential(*modules) + return local_mlp + + def forward(self, encoded_input: Tensor, future_input: Tensor) -> Tensor: + """Forward pass for MQCNN decoder + + Args: + encoded_input 
(Tensor): + decoder input from the output of the MQCNN encoder, including static and past dynamic feature encoding + future_input (Tensor): + decoder input from future dynamic features + + Returns: + output of the decoder (MQCNN) + """ + + # Embed future features globally + global_future_embedded = self.global_future_layer(future_input.float()) + + # Encode future features locally for each timestep/feature + local_future_encoded = self.local_future_layer(future_input) + + # Combine encoded historical dynamic and static features and globally embedded future features + encoded_input_and_future = torch.cat( + (encoded_input, global_future_embedded), dim=-1 + ) + + # Produce horizon specific encoding (c_t for K horizons in the paper) + horizon_specific_encoded = self.horizon_specific( + encoded_input_and_future + ) + + # Produce horizon agnostic encoding (c_a in the paper) + horizon_agnostic_encoded = self.horizon_agnostic( + encoded_input_and_future + ) + + # Combine horizon agnostic, horizon specific and future local encodings + encoded = torch.cat( + ( + horizon_specific_encoded, + horizon_agnostic_encoded, + local_future_encoded, + ), + dim=-1, + ) + + # Train local mlp on each fork and horizon (weights are shared) + output = self.local_mlp(encoded.float()) + + return output diff --git a/src/gluonts/torch/model/mq_cnn/lightning_module.py b/src/gluonts/torch/model/mq_cnn/lightning_module.py new file mode 100644 index 0000000000..687bc73e7a --- /dev/null +++ b/src/gluonts/torch/model/mq_cnn/lightning_module.py @@ -0,0 +1,177 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +from pytorch_lightning import LightningModule +import torch +from torch.optim.lr_scheduler import ReduceLROnPlateau +from gluonts.torch.model.lightning_util import has_validation_loop + +from gluonts.core.component import validated + +from .module import MQCNNModel + + +class MQCNNLightningModule(LightningModule): + """ + A ``pl.LightningModule`` class that can be used to train a + ``MQCNNModel`` with PyTorch Lightning. + This is a thin layer around a (wrapped) ``MQCNNModel`` object, + that exposes the methods to evaluate training and validation loss. + Parameters + ---------- + model_kwargs + Keyword arguments for the ``MQCNNModel`` object. + lr + Learning rate. + learning_rate_decay_factor + Learning rate decay factor. + minimum_learning_rate + Minimum learning rate. + weight_decay + Weight decay regularization parameter. + patience + Patience parameter for learning rate scheduler. 
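+
+    Examples
+    --------
+    The module is normally built by ``MQCNNEstimator.create_lightning_module``;
+    constructing it by hand amounts to (sketch, with ``model_kwargs`` assumed
+    to hold the full set of ``MQCNNModel`` constructor arguments)::
+
+        module = MQCNNLightningModule(
+            model_kwargs=model_kwargs,
+            lr=1e-3,
+            weight_decay=1e-8,
+            patience=10,
+        )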
+ """ + + @validated() + def __init__( + self, + model_kwargs: dict, + lr: float = 1e-3, + learning_rate_decay_factor: float = 0.9, + minimum_learning_rate: float = 1e-6, + weight_decay: float = 1e-8, + patience: int = 10, + ) -> None: + super().__init__() + self.save_hyperparameters() + self.model = MQCNNModel(**model_kwargs) + self.lr = lr + self.learning_rate_decay_factor = learning_rate_decay_factor + self.minimum_learning_rate = minimum_learning_rate + self.weight_decay = weight_decay + self.patience = patience + + def forward(self, *args, **kwargs): + return self.model(*args, **kwargs) + + def training_step(self, batch, batch_idx: int): # type: ignore + """ + Execute training step. + """ + past_target = batch["past_target"] + future_target = batch["future_target"] + past_feat_dynamic = batch["past_feat_dynamic"] + future_feat_dynamic = batch["future_feat_dynamic"] + feat_static_real = batch["feat_static_real"] + feat_static_cat = batch["feat_static_cat"] + past_observed_values = batch["past_observed_values"] + future_observed_values = batch["future_observed_values"] + past_feat_dynamic_cat = batch["past_feat_dynamic_cat"] + future_feat_dynamic_cat = batch["future_feat_dynamic_cat"] + + loss = self.model.loss( + past_target=past_target, + future_target=future_target, + past_feat_dynamic=past_feat_dynamic, + future_feat_dynamic=future_feat_dynamic, + feat_static_real=feat_static_real, + feat_static_cat=feat_static_cat, + past_observed_values=past_observed_values, + future_observed_values=future_observed_values, + past_feat_dynamic_cat=past_feat_dynamic_cat, + future_feat_dynamic_cat=future_feat_dynamic_cat, + ) + + # Log every step and epoch, synchronize every epoch + train_loss = loss.mean() + self.log( + "train/loss", + train_loss, + on_epoch=True, + on_step=True, + prog_bar=True, + sync_dist=True, + logger=True, + ) + + return train_loss + + def validation_step(self, batch, batch_idx: int): # type: ignore + """ + Execute validation step. + """ + past_target = batch["past_target"] + future_target = batch["future_target"] + past_feat_dynamic = batch["past_feat_dynamic"] + future_feat_dynamic = batch["future_feat_dynamic"] + feat_static_real = batch["feat_static_real"] + feat_static_cat = batch["feat_static_cat"] + past_observed_values = batch["past_observed_values"] + future_observed_values = batch["future_observed_values"] + past_feat_dynamic_cat = batch["past_feat_dynamic_cat"] + future_feat_dynamic_cat = batch["future_feat_dynamic_cat"] + + loss = self.model.loss( + past_target=past_target, + future_target=future_target, + past_feat_dynamic=past_feat_dynamic, + future_feat_dynamic=future_feat_dynamic, + feat_static_real=feat_static_real, + feat_static_cat=feat_static_cat, + past_observed_values=past_observed_values, + future_observed_values=future_observed_values, + past_feat_dynamic_cat=past_feat_dynamic_cat, + future_feat_dynamic_cat=future_feat_dynamic_cat, + ) + + # Log and synchronize every epoch + val_loss = loss.mean() + self.log( + "val/loss", + val_loss, + on_epoch=True, + on_step=False, + prog_bar=True, + sync_dist=True, + logger=True, + ) + + return val_loss + + def configure_optimizers(self): + """ + Returns the optimizer to use. 
+ """ + optimizer = torch.optim.Adam( + self.model.parameters(), + lr=self.lr, + weight_decay=self.weight_decay, + ) + monitor = ( + "val/loss" if has_validation_loop(self.trainer) else "train/loss" + ) + + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": ReduceLROnPlateau( + optimizer=optimizer, + mode="min", + factor=self.learning_rate_decay_factor, + patience=self.patience, + min_lr=self.minimum_learning_rate, + ), + "monitor": monitor, + }, + } diff --git a/src/gluonts/torch/model/mq_cnn/module.py b/src/gluonts/torch/model/mq_cnn/module.py new file mode 100644 index 0000000000..ab9f6ed69e --- /dev/null +++ b/src/gluonts/torch/model/mq_cnn/module.py @@ -0,0 +1,490 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +from typing import Optional, Tuple, List + +import torch +from torch import nn +from gluonts.core.component import validated +from gluonts.torch.distributions import Output +from gluonts.model import Input, InputSpec + +from gluonts.torch.modules.feature import FeatureEmbedder +from gluonts.torch.scaler import MeanScaler, NOPScaler +from gluonts.torch.distributions.quantile_output import QuantileOutput +from gluonts.torch.util import weighted_average + +from .layers import ( + HierarchicalCausalConv1DEncoder, + Enc2Dec, + ForkingMLPDecoder, +) + + +class MQCNNModel(nn.Module): + """ + Base network for the :class:`MQCnnEstimator`. + Parameters + ---------- + context_length (int): + length of the encoding sequence. + prediction_length (int): + prediction length + num_forking (int): + number of forks to do in the decoder. 
+    past_feat_dynamic_real_dim (int):
+        dimension of past real dynamic features
+    feat_static_real_dim (int):
+        dimension of real static features
+    feat_dynamic_real_dim (int):
+        dimension of real dynamic features
+    cardinality_dynamic (List[int]):
+        cardinalities of dynamic categorical features
+    embedding_dimension_dynamic (List[int]):
+        embedding dimensions of dynamic categorical features
+    cardinality_static (List[int]):
+        cardinalities of static categorical features
+    embedding_dimension_static (List[int]):
+        embedding dimensions of static categorical features
+    scaling (bool):
+        if True, scale the target values
+    scaling_decoder_dynamic_feature (bool):
+        if True, scale the dynamic features for the decoder
+    encoder_cnn_init_dim (int):
+        input dimensions of encoder CNN
+    dilation_seq (List[int]):
+        dilation sequence of encoder CNN
+    kernel_size_seq (List[int]):
+        kernel sizes of encoder CNN
+    channels_seq (List[int]):
+        number of channels of encoder CNN
+    joint_embedding_dimension (int):
+        joint embedding dimension of the encoder
+    encoder_mlp_init_dim (int):
+        input dimension of static features encoder MLP
+    encoder_mlp_dim_seq (List[int]):
+        sequence of hidden layer dimensions of encoder MLP
+    use_residual (bool):
+        if True, target is added to encoder CNN output
+    decoder_mlp_dim_seq (List[int]):
+        sequence of layer dimensions of decoder MLP
+    decoder_hidden_dim (int):
+        decoder MLP hidden dimension
+    decoder_future_init_dim (int):
+        decoder init dimension for embedding future dynamic features
+    decoder_future_embedding_dim (int):
+        decoder embedding dimension for future dynamic features
+    distr_output (Optional[Output]):
+        distribution output block. Defaults to None.
+    kwargs: dict
+        dictionary of parameters
+    """
+
+    @validated()
+    def __init__(
+        self,
+        context_length: int,
+        prediction_length: int,
+        num_forking: int,
+        past_feat_dynamic_real_dim: int,
+        feat_static_real_dim: int,
+        feat_dynamic_real_dim: int,
+        cardinality_dynamic: List[int],
+        embedding_dimension_dynamic: List[int],
+        cardinality_static: List[int],
+        embedding_dimension_static: List[int],
+        scaling: bool,
+        scaling_decoder_dynamic_feature: bool,
+        encoder_cnn_init_dim: int,
+        dilation_seq: List[int],
+        kernel_size_seq: List[int],
+        channels_seq: List[int],
+        joint_embedding_dimension: int,
+        encoder_mlp_init_dim: int,
+        encoder_mlp_dim_seq: List[int],
+        use_residual: bool,
+        decoder_mlp_dim_seq: List[int],
+        decoder_hidden_dim: int,
+        decoder_future_init_dim: int,
+        decoder_future_embedding_dim: int,
+        distr_output: Optional[Output] = None,
+        **kwargs,
+    ) -> None:
+
+        super().__init__(**kwargs)
+
+        self.context_length = context_length
+        self.prediction_length = prediction_length
+        self.num_forking = num_forking
+
+        self.feat_static_real_dim = feat_static_real_dim
+        self.feat_dynamic_real_dim = feat_dynamic_real_dim
+        self.past_feat_dynamic_real_dim = past_feat_dynamic_real_dim
+        self.cardinality_dynamic = cardinality_dynamic
+        self.cardinality_static = cardinality_static
+
+        self.embedder_dynamic = (
+            FeatureEmbedder(
+                cardinalities=cardinality_dynamic,
+                embedding_dims=embedding_dimension_dynamic,
+            )
+            if len(cardinality_dynamic) > 0
+            else None
+        )
+
+        self.embedder_dynamic_future = (
+            FeatureEmbedder(
+                cardinalities=cardinality_dynamic,
+                embedding_dims=embedding_dimension_dynamic,
+            )
+            if len(cardinality_dynamic) > 0
+            else None
+        )
+
+        self.embedder_static = (
+            FeatureEmbedder(
+                cardinalities=cardinality_static,
+                embedding_dims=embedding_dimension_static,
+            )
+            if len(cardinality_static) > 0
+            
else None + ) + + self.scaling = scaling + self.scaling_decoder_dynamic_feature = scaling_decoder_dynamic_feature + + if self.scaling: + self.scaler = MeanScaler(dim=1) + else: + self.scaler = NOPScaler(dim=1) + + if self.scaling_decoder_dynamic_feature: + self.scaler_decoder_dynamic_feature = MeanScaler(dim=1) + else: + self.scaler_decoder_dynamic_feature = NOPScaler(dim=1) + + self.encoder = HierarchicalCausalConv1DEncoder( + cnn_init_dim=encoder_cnn_init_dim, + dilation_seq=dilation_seq, + kernel_size_seq=kernel_size_seq, + channels_seq=channels_seq, + joint_embedding_dimension=joint_embedding_dimension, + mlp_init_dim=encoder_mlp_init_dim, + hidden_dimension_sequence=encoder_mlp_dim_seq, + use_residual=use_residual, + use_static_feat=True, + use_dynamic_feat=True, + ) + + self.enc2dec = Enc2Dec(num_forking=num_forking) + + self.decoder = ForkingMLPDecoder( + dec_len=prediction_length, + encoded_input_dim=channels_seq[-1] + joint_embedding_dimension + 1, + local_mlp_final_dim=decoder_mlp_dim_seq[-1], + global_mlp_final_dim=decoder_hidden_dim, + future_feat_init_dim=decoder_future_init_dim, + future_feat_embedding_dim=decoder_future_embedding_dim, + local_mlp_hidden_dim_sequence=decoder_mlp_dim_seq[:-1], + ) + + if distr_output is None: + distr_output = QuantileOutput( + [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] + ) + self.distr_output = distr_output + self.distr_proj = self.distr_output.get_args_proj( + in_features=decoder_mlp_dim_seq[-1] + ) + + def describe_inputs(self, batch_size=1) -> InputSpec: + return InputSpec( + { + "past_target": Input( + shape=(batch_size, self.context_length, 1), + dtype=torch.float, + ), + "past_feat_dynamic": Input( + shape=( + batch_size, + self.context_length, + self.past_feat_dynamic_real_dim, + ), + dtype=torch.float, + ), + "future_feat_dynamic": Input( + shape=( + batch_size, + self.num_forking, + self.prediction_length, + self.feat_dynamic_real_dim, + ), + dtype=torch.float, + ), + "feat_static_real": Input( + shape=(batch_size, self.feat_static_real_dim), + dtype=torch.float, + ), + "feat_static_cat": Input( + shape=(batch_size, len(self.cardinality_static)), + dtype=torch.long, + ), + "past_observed_values": Input( + shape=(batch_size, self.context_length, 1), + dtype=torch.float, + ), + "past_feat_dynamic_cat": Input( + shape=( + batch_size, + self.context_length, + len(self.cardinality_dynamic), + ), + dtype=torch.long, + ), + "future_feat_dynamic_cat": Input( + shape=( + batch_size, + self.num_forking, + self.prediction_length, + len(self.cardinality_dynamic), + ), + dtype=torch.long, + ), + }, + torch.zeros, + ) + + # this method connects the sub-networks and returns the decoder output + def get_decoder_network_output( + self, + past_target: torch.Tensor, + past_feat_dynamic: torch.Tensor, + future_feat_dynamic: torch.Tensor, + feat_static: torch.Tensor, + past_observed_values: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Parameters + ---------- + past_target: torch.Tensor + shape (batch_size, encoder_length, 1) + past_feat_dynamic + shape (batch_size, encoder_length, num_past_feat_dynamic) + future_feat_dynamic + shape (batch_size, num_forking, decoder_length, num_feat_dynamic) + feat_static + shape (batch_size, num_feat_static_real) + past_observed_values: torch.Tensor + shape (batch_size, encoder_length, 1) + Returns + ------- + decoder output, loc and scale + """ + + # scale shape: (batch_size, 1, 1) + scaled_past_target, loc, scale = self.scaler( + past_target, past_observed_values + ) + + # in 
addition to embedding features, use the log scale as it can help prediction too + # (batch_size, num_feat_static = sum(embedding_dimension) + 1) + feat_static = torch.cat((feat_static, torch.log(scale)), dim=1) + + # Passing past_observed_values as a feature would allow the network to + # make that distinction and possibly ignore the masked values. + past_feat_dynamic_extended = torch.cat( + (past_feat_dynamic, past_observed_values), dim=-1 + ) + + # arguments: target, static_features, dynamic_features + # enc_output_static shape: (batch_size, channels_seq[-1] + 1) + # enc_output_dynamic shape: (batch_size, encoder_length, channels_seq[-1] + 1) + enc_output_static, enc_output_dynamic = self.encoder( + scaled_past_target, feat_static, past_feat_dynamic_extended + ) + + # arguments: encoder_output_static, encoder_output_dynamic, future_features + # dec_input_static shape: (batch_size, channels_seq[-1] + 1) + # dec_input_dynamic shape:(batch_size, num_forking, channels_seq[-1] + 1 + decoder_length * num_feat_dynamic) + dec_input_encoded = self.enc2dec( + enc_output_static, + # slice axis 1 from encoder_length = context_length to num_forking + enc_output_dynamic[ + :, -self.num_forking : enc_output_dynamic.shape[1], ... + ], + ) + + scaled_future_feat_dynamic, _, _ = self.scaler_decoder_dynamic_feature( + future_feat_dynamic, torch.ones_like(future_feat_dynamic) + ) + + # arguments: dynamic_input, static_input + dec_output = self.decoder( + dec_input_encoded, scaled_future_feat_dynamic + ) + + # the output shape should be: (batch_size, num_forking, dec_len, decoder_mlp_dim_seq[0]) + return dec_output, loc, scale + + # noinspection PyMethodOverriding + def forward( + self, + past_target: torch.Tensor, + past_feat_dynamic: torch.Tensor, + future_feat_dynamic: torch.Tensor, + feat_static_real: torch.Tensor, + feat_static_cat: torch.Tensor, + past_observed_values: torch.Tensor, + past_feat_dynamic_cat: torch.Tensor, + future_feat_dynamic_cat: torch.Tensor, + ) -> Tuple[Tuple[torch.Tensor, ...], torch.Tensor, torch.Tensor]: + """ + Parameters + ---------- + past_target: torch.Tensor + shape (batch_size, encoder_length, 1) + past_feat_dynamic + shape (batch_size, encoder_length, past_feat_dynamic_dim) + future_feat_dynamic + shape (batch_size, num_forking, decoder_length, feat_dynamic_dim) + feat_static_real + shape (batch_size, feat_static_real_dim) + feat_static_cat + shape (batch_size, feat_static_cat_dim) + past_observed_values: torch.Tensor + shape (batch_size, encoder_length, 1) + future_observed_values: torch.Tensor + shape (batch_size, num_forking, decoder_length) + past_feat_dynamic_cat: torch.Tensor, + shape (batch_size, encoder_length, feature_dynamic_cat_dim) + future_feat_dynamic_cat: torch.Tensor, + shape (batch_size, num_forking, decoder_length, feature_dynamic_cat_dim) + Returns + ------- + distr_args, loc, scale + """ + + if self.embedder_static is not None: + embedded_cat = self.embedder_static(feat_static_cat) + feat_static = torch.cat((embedded_cat, feat_static_real), dim=1) + else: + feat_static = torch.add(feat_static_real, feat_static_cat) + + if self.embedder_dynamic is not None: + + # Embed dynamic categorical features + embedded_past_feature_dynamic_cat = self.embedder_dynamic( + past_feat_dynamic_cat + ) + embedded_future_feature_dynamic_cat = self.embedder_dynamic_future( + future_feat_dynamic_cat + ) + # Combine all dynamic features + past_feat_dynamic = torch.cat( + (past_feat_dynamic, embedded_past_feature_dynamic_cat), dim=-1 + ) + future_feat_dynamic = torch.cat( + 
(future_feat_dynamic, embedded_future_feature_dynamic_cat), + dim=-1, + ) + + else: + # Make sure that future_feat_dynamic_cat also has [N, T, H, C] dimensions + future_feat_dynamic_cat = torch.reshape( + future_feat_dynamic_cat, [0, 0, 0, -1] + ) + + past_feat_dynamic = torch.add( + past_feat_dynamic, past_feat_dynamic_cat + ) + future_feat_dynamic = torch.add( + future_feat_dynamic, future_feat_dynamic_cat + ) + + # shape: (batch_size, num_forking, decoder_length, decoder_mlp_dim_seq[0]) + dec_output, loc, scale = self.get_decoder_network_output( + past_target, + past_feat_dynamic, + future_feat_dynamic, + feat_static, + past_observed_values, + ) + + # shape: (batch_size, num_forking, decoder_length, len(quantiles)) + distr_args = self.distr_proj(dec_output) + return distr_args, loc, scale + + # noinspection PyMethodOverriding + def loss( + self, + past_target: torch.Tensor, + future_target: torch.Tensor, + past_feat_dynamic: torch.Tensor, + future_feat_dynamic: torch.Tensor, + feat_static_real: torch.Tensor, + feat_static_cat: torch.Tensor, + past_observed_values: torch.Tensor, + future_observed_values: torch.Tensor, + past_feat_dynamic_cat: torch.Tensor, + future_feat_dynamic_cat: torch.Tensor, + ) -> torch.Tensor: + """ + Parameters + ---------- + past_target: torch.Tensor + shape (batch_size, encoder_length, 1) + future_target: torch.Tensor + shape (batch_size, num_forking, decoder_length) + past_feat_dynamic + shape (batch_size, encoder_length, past_feat_dynamic_dim) + future_feat_dynamic + shape (batch_size, num_forking, decoder_length, feat_dynamic_dim) + feat_static_real + shape (batch_size, feat_static_real_dim) + feat_static_cat + shape (batch_size, feat_static_cat_dim) + past_observed_values: torch.Tensor + shape (batch_size, encoder_length, 1) + future_observed_values: torch.Tensor + shape (batch_size, num_forking, decoder_length) + past_feat_dynamic_cat: torch.Tensor, + shape (batch_size, encoder_length, feature_dynamic_cat_dim) + future_feat_dynamic_cat: torch.Tensor, + shape (batch_size, num_forking, decoder_length, feature_dynamic_cat_dim) + Returns + ------- + loss with shape (batch_size, prediction_length) + """ + + distr_args, loc, scale = self( + past_target, + past_feat_dynamic, + future_feat_dynamic, + feat_static_real, + feat_static_cat, + past_observed_values, + past_feat_dynamic_cat, + future_feat_dynamic_cat, + ) + + # shape: (batch_size, num_forking, decoder_length = prediction_length) + loss = self.distr_output.loss( + target=future_target, + distr_args=distr_args, + loc=loc.unsqueeze(-1), + scale=scale.unsqueeze(-1), + ) + + # mask the loss based on observed indicator + # shape: (batch_size, decoder_length) + return weighted_average(x=loss, weights=future_observed_values, dim=1) diff --git a/test/torch/model/test_mq_cnn.py b/test/torch/model/test_mq_cnn.py new file mode 100644 index 0000000000..c31708efd4 --- /dev/null +++ b/test/torch/model/test_mq_cnn.py @@ -0,0 +1,155 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. 
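Note on the objective used above: the `loss` method combines a per-quantile loss from `QuantileOutput` with an observed-value mask applied through `weighted_average`. A minimal, self-contained sketch of those two ingredients follows; the helper names are illustrative and are not the actual GluonTS implementations.

import torch


def pinball_loss(target, pred, quantiles):
    # target: (batch, num_forking, dec_len)
    # pred:   (batch, num_forking, dec_len, num_quantiles)
    q = torch.tensor(quantiles).view(1, 1, 1, -1)
    diff = target.unsqueeze(-1) - pred
    # classic pinball (quantile) loss, averaged over the quantile levels
    return torch.maximum(q * diff, (q - 1) * diff).mean(dim=-1)


def masked_average(loss, observed, dim=1):
    # average the per-position loss, ignoring entries whose observed
    # indicator is zero (padded or missing values)
    return (loss * observed).sum(dim=dim) / observed.sum(dim=dim).clamp(min=1.0)
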
+ +from typing import List + +import torch +import pytest + +from gluonts.torch.distributions import QuantileOutput +from gluonts.torch.model.mq_cnn import MQCNNLightningModule + + +@pytest.mark.parametrize( + "past_feat_dynamic_real_dim, feat_dynamic_real_dim, cardinality_dynamic, feat_static_real_dim, cardinality_static, quantiles", + [ + (3, 2, [5, 5, 5], 3, [5, 2], [0.1, 0.5, 0.9]), + (2, 0, [4, 2], 2, [4, 2, 2], [0.05, 0.25]), + ], +) +def test_mq_cnn_modules( + past_feat_dynamic_real_dim: int, + feat_dynamic_real_dim: int, + cardinality_dynamic: List[int], + feat_static_real_dim: int, + cardinality_static: List[int], + quantiles: List[float], +): + batch_size = 4 + prediction_length = 6 + context_length = 12 + num_forking = 8 + + enc_cnn_init_dim = ( + 2 + past_feat_dynamic_real_dim + sum(cardinality_dynamic) + ) # with target, observed + enc_mlp_init_dim = ( + 1 + feat_static_real_dim + sum(cardinality_static) + ) # with scaler + dec_future_init_dim = feat_dynamic_real_dim + sum(cardinality_dynamic) + joint_embedding_dim = 30 + encoder_mlp_dim_seq = [30] + decoder_mlp_dim_seq = [30] + decoder_hidden_dim = 64 + decoder_future_embedding_dim = 50 + channels_seq = [30, 30, 30] + dilation_seq = [1, 3, 9] + kernel_size_seq = [7, 3, 3] + + lightning_module = MQCNNLightningModule( + { + "context_length": context_length, + "prediction_length": prediction_length, + "num_forking": num_forking, + "past_feat_dynamic_real_dim": past_feat_dynamic_real_dim, + "feat_dynamic_real_dim": feat_dynamic_real_dim, + "cardinality_dynamic": cardinality_dynamic, + "embedding_dimension_dynamic": cardinality_dynamic, + "feat_static_real_dim": feat_static_real_dim, + "cardinality_static": cardinality_static, + "embedding_dimension_static": cardinality_static, + "scaling": False, + "scaling_decoder_dynamic_feature": False, + "encoder_cnn_init_dim": enc_cnn_init_dim, + "dilation_seq": dilation_seq, + "kernel_size_seq": kernel_size_seq, + "channels_seq": channels_seq, + "joint_embedding_dimension": joint_embedding_dim, + "encoder_mlp_init_dim": enc_mlp_init_dim, + "encoder_mlp_dim_seq": encoder_mlp_dim_seq, + "use_residual": True, + "decoder_mlp_dim_seq": decoder_mlp_dim_seq, + "decoder_hidden_dim": decoder_hidden_dim, + "decoder_future_init_dim": dec_future_init_dim, + "decoder_future_embedding_dim": decoder_future_embedding_dim, + "distr_output": QuantileOutput(quantiles), + } + ) + model = lightning_module.model + + feat_static_cat = torch.zeros( + batch_size, len(cardinality_static), dtype=torch.long + ) + feat_static_real = torch.ones(batch_size, feat_static_real_dim) + future_feat_dynamic_cat = torch.zeros( + batch_size, + num_forking, + prediction_length, + len(cardinality_dynamic), + dtype=torch.long, + ) + past_feat_dynamic_cat = torch.zeros( + batch_size, + context_length, + len(cardinality_dynamic), + dtype=torch.long, + ) + future_feat_dynamic = torch.ones( + batch_size, + num_forking, + prediction_length, + feat_dynamic_real_dim, + ) + past_feat_dynamic = torch.ones( + batch_size, context_length, past_feat_dynamic_real_dim + ) + past_target = torch.ones(batch_size, context_length, 1) + past_observed_values = torch.ones(batch_size, context_length, 1) + future_target = torch.ones(batch_size, num_forking, prediction_length) + future_observed_values = torch.ones( + batch_size, num_forking, prediction_length + ) + output, loc, scale = model( + past_target=past_target, + past_feat_dynamic=past_feat_dynamic, + future_feat_dynamic=future_feat_dynamic, + feat_static_real=feat_static_real, + 
feat_static_cat=feat_static_cat, + past_observed_values=past_observed_values, + past_feat_dynamic_cat=past_feat_dynamic_cat, + future_feat_dynamic_cat=future_feat_dynamic_cat, + ) + + assert output[0].shape == ( + batch_size, + num_forking, + prediction_length, + len(quantiles), + ) + assert loc.shape == scale.shape == (batch_size, 1) + + batch = dict( + past_target=past_target, + future_target=future_target, + past_feat_dynamic=past_feat_dynamic, + future_feat_dynamic=future_feat_dynamic, + feat_static_real=feat_static_real, + feat_static_cat=feat_static_cat, + past_observed_values=past_observed_values, + future_observed_values=future_observed_values, + past_feat_dynamic_cat=past_feat_dynamic_cat, + future_feat_dynamic_cat=future_feat_dynamic_cat, + ) + + assert lightning_module.training_step(batch, batch_idx=0).shape == () + assert lightning_module.validation_step(batch, batch_idx=0).shape == () From 24756567871f52146a4c7943f6d112fcb9b47ad4 Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Mon, 27 May 2024 12:12:21 +0200 Subject: [PATCH 3/8] update --- src/gluonts/torch/model/mq_cnn/estimator.py | 7 +++- test/torch/model/test_estimators.py | 38 ++++++++++----------- 2 files changed, 25 insertions(+), 20 deletions(-) diff --git a/src/gluonts/torch/model/mq_cnn/estimator.py b/src/gluonts/torch/model/mq_cnn/estimator.py index aea88d4675..d6583d72d5 100644 --- a/src/gluonts/torch/model/mq_cnn/estimator.py +++ b/src/gluonts/torch/model/mq_cnn/estimator.py @@ -36,7 +36,6 @@ AddAgeFeature, AddTimeFeatures, AddObservedValuesIndicator, - Chain, RenameFields, SetField, ExpectedNumInstanceSampler, @@ -150,6 +149,11 @@ class MQCNNEstimator(PyTorchLightningEstimator): (default: if None, channels_seq[-1] * sqrt(feat_static_dim)), where feat_static_dim is appx sum(embedding_dimension_static)) Defaults to None. + time_features (list, optional): + List of time features, from :py:mod:`gluonts.time_feature`, to use as + inputs to the model in addition to the provided data. + Defaults to None, in which case these are automatically determined based + on freq. encoder_mlp_dim_seq (List[int], optional): The dimensionalities of the MLP layers of the encoder for static features (default: [] if None) Defaults to None. 
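The `time_features` argument documented above defaults to features derived from `freq`. As a small usage sketch (not part of this patch), explicit features can be built with the existing GluonTS helper and passed via the keyword added in the following hunk:

from gluonts.time_feature import time_features_from_frequency_str

# for hourly data this yields features such as hour-of-day and day-of-week
time_features = time_features_from_frequency_str("H")
# estimator = MQCNNEstimator(..., time_features=time_features)
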
@@ -243,6 +247,7 @@ def __init__( scaling: Optional[bool] = None, scaling_decoder_dynamic_feature: bool = False, joint_embedding_dimension: Optional[int] = None, + time_features: Optional[list] = None, encoder_mlp_dim_seq: Optional[List[int]] = None, decoder_mlp_dim_seq: Optional[List[int]] = None, decoder_hidden_dim: Optional[int] = None, diff --git a/test/torch/model/test_estimators.py b/test/torch/model/test_estimators.py index e894589911..cf40a87cb9 100644 --- a/test/torch/model/test_estimators.py +++ b/test/torch/model/test_estimators.py @@ -164,14 +164,14 @@ num_batches_per_epoch=3, trainer_kwargs=dict(max_epochs=2), ), - # lambda dataset: MQCNNEstimator( - # freq=dataset.metadata.freq, - # distr_output=QuantileOutput(quantiles=[0.1, 0.6, 0.85]), - # prediction_length=dataset.metadata.prediction_length, - # batch_size=4, - # num_batches_per_epoch=3, - # trainer_kwargs=dict(max_epochs=2), - # ), + lambda dataset: MQCNNEstimator( + freq=dataset.metadata.freq, + distr_output=QuantileOutput(quantiles=[0.1, 0.6, 0.85]), + prediction_length=dataset.metadata.prediction_length, + batch_size=4, + num_batches_per_epoch=3, + trainer_kwargs=dict(max_epochs=2), + ), ], ) @pytest.mark.parametrize("use_validation_data", [False, True]) @@ -343,17 +343,17 @@ def test_estimator_constant_dataset( cardinality=[2, 2], trainer_kwargs=dict(max_epochs=2), ), - # lambda freq, prediction_length: MQCNNEstimator( - # freq=freq, - # prediction_length=prediction_length, - # batch_size=4, - # num_batches_per_epoch=3, - # num_feat_dynamic_real=3, - # num_feat_static_real=1, - # num_feat_static_cat=2, - # cardinality=[2, 2], - # trainer_kwargs=dict(max_epochs=2), - # ), + lambda freq, prediction_length: MQCNNEstimator( + freq=freq, + prediction_length=prediction_length, + batch_size=4, + num_batches_per_epoch=3, + num_feat_dynamic_real=3, + num_feat_static_real=1, + num_feat_static_cat=2, + cardinality=[2, 2], + trainer_kwargs=dict(max_epochs=2), + ), ], ) def test_estimator_with_features(estimator_constructor): From cf1a7c3659fbc3ca6409ee3d3ce080c18f5dc004 Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Mon, 27 May 2024 12:13:26 +0200 Subject: [PATCH 4/8] fix --- src/gluonts/torch/model/mq_cnn/estimator.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/gluonts/torch/model/mq_cnn/estimator.py b/src/gluonts/torch/model/mq_cnn/estimator.py index d6583d72d5..9b2d8c2b8e 100644 --- a/src/gluonts/torch/model/mq_cnn/estimator.py +++ b/src/gluonts/torch/model/mq_cnn/estimator.py @@ -379,10 +379,12 @@ def __init__( ) assert ( - len(channels_seq) == len(dilation_seq) == len(kernel_size_seq) + len(self.channels_seq) + == len(self.dilation_seq) + == len(self.kernel_size_seq) ), ( - f"mismatch CNN configurations: {len(channels_seq)} vs. " - f"{len(dilation_seq)} vs. {len(kernel_size_seq)}" + f"mismatch CNN configurations: {len(self.channels_seq)} vs. " + f"{len(self.dilation_seq)} vs. 
{len(self.kernel_size_seq)}" ) self.use_residual = use_residual From d7b8f0fe04ec954a7c806b46cba2837b8046c972 Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Mon, 27 May 2024 13:33:29 +0200 Subject: [PATCH 5/8] fixes --- src/gluonts/torch/model/mq_cnn/estimator.py | 19 ++++++++++++------- .../torch/model/mq_cnn/lightning_module.py | 4 ++-- test/torch/model/test_estimators.py | 8 ++++---- 3 files changed, 18 insertions(+), 13 deletions(-) diff --git a/src/gluonts/torch/model/mq_cnn/estimator.py b/src/gluonts/torch/model/mq_cnn/estimator.py index 9b2d8c2b8e..553a0322bc 100644 --- a/src/gluonts/torch/model/mq_cnn/estimator.py +++ b/src/gluonts/torch/model/mq_cnn/estimator.py @@ -74,7 +74,7 @@ class MQCNNEstimator(PyTorchLightningEstimator): Length of the prediction, also known as 'horizon'. context_length (int, optional): Number of time units that condition the predictions, also known as 'lookback period'. - Defaults to None. + Defaults to `4 * prediction_length`. num_forking (int, optional): Decides how much forking to do in the decoder. (default: context_length if None) @@ -282,7 +282,11 @@ def __init__( self.freq = freq self.prediction_length = prediction_length - self.context_length = context_length + self.context_length = self.context_length = ( + context_length + if context_length is not None + else 4 * self.prediction_length + ) self.num_forking = ( min(num_forking, self.context_length) if num_forking is not None @@ -400,7 +404,8 @@ def __init__( self.enc_cnn_init_dim += sum(self.embedding_dimension_dynamic) self.dec_future_init_dim += sum(self.embedding_dimension_dynamic) else: - self.cardinality_dynamic = 0 + self.cardinality_dynamic = [0] + self.embedding_dimension_dynamic = [0] if self.use_past_feat_dynamic_real: assert ( @@ -431,14 +436,14 @@ def __init__( ) self.enc_mlp_init_dim += sum(self.embedding_dimension_static) else: - self.cardinality_static = 0 - self.embedding_dimension_static = 0 + self.cardinality_static = [0] + self.embedding_dimension_static = [0] self.joint_embedding_dimension = joint_embedding_dimension if self.joint_embedding_dimension is None: feat_static_dim = sum(self.embedding_dimension_static) self.joint_embedding_dimension = int( - channels_seq[-1] * max(np.sqrt(feat_static_dim), 1) + self.channels_seq[-1] * max(np.sqrt(feat_static_dim), 1) ) if self.use_feat_static_real: @@ -623,7 +628,7 @@ def _create_instance_splitter(self, mode: str) -> Chain: FieldName.FEAT_DYNAMIC, FieldName.FEAT_DYNAMIC_CAT, ] - + ([FieldName.OBSERVED_VALUES] if mode != "test" else []), + + ([FieldName.OBSERVED_VALUES] if mode != "test" else []) ) decoder_disabled_fields = ( diff --git a/src/gluonts/torch/model/mq_cnn/lightning_module.py b/src/gluonts/torch/model/mq_cnn/lightning_module.py index 687bc73e7a..f4602ee06e 100644 --- a/src/gluonts/torch/model/mq_cnn/lightning_module.py +++ b/src/gluonts/torch/model/mq_cnn/lightning_module.py @@ -11,7 +11,7 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -from pytorch_lightning import LightningModule +import lightning.pytorch as pl import torch from torch.optim.lr_scheduler import ReduceLROnPlateau from gluonts.torch.model.lightning_util import has_validation_loop @@ -21,7 +21,7 @@ from .module import MQCNNModel -class MQCNNLightningModule(LightningModule): +class MQCNNLightningModule(pl.LightningModule): """ A ``pl.LightningModule`` class that can be used to train a ``MQCNNModel`` with PyTorch Lightning. 
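Two of the defaults touched in this patch are easiest to read as arithmetic. A worked example using the expressions shown above (the concrete values are illustrative only):

import numpy as np

prediction_length = 6
context_length = 4 * prediction_length  # default when context_length is None -> 24

channels_seq = [30, 30, 30]
feat_static_dim = 4  # sum of the static embedding dimensions
joint_embedding_dimension = int(
    channels_seq[-1] * max(np.sqrt(feat_static_dim), 1)
)  # -> int(30 * 2) = 60
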
diff --git a/test/torch/model/test_estimators.py b/test/torch/model/test_estimators.py index cf40a87cb9..f2005f4957 100644 --- a/test/torch/model/test_estimators.py +++ b/test/torch/model/test_estimators.py @@ -348,10 +348,10 @@ def test_estimator_constant_dataset( prediction_length=prediction_length, batch_size=4, num_batches_per_epoch=3, - num_feat_dynamic_real=3, - num_feat_static_real=1, - num_feat_static_cat=2, - cardinality=[2, 2], + use_feat_dynamic_real=True, + use_feat_static_real=True, + use_feat_static_cat=True, + cardinality_static=[2, 2], trainer_kwargs=dict(max_epochs=2), ), ], From 9044883d8f82836d82d4cb81dec8fcf0b9a7ba00 Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Mon, 27 May 2024 13:40:46 +0200 Subject: [PATCH 6/8] more fixes --- src/gluonts/torch/model/mq_cnn/estimator.py | 6 +----- src/gluonts/transform/field.py | 2 +- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/src/gluonts/torch/model/mq_cnn/estimator.py b/src/gluonts/torch/model/mq_cnn/estimator.py index 553a0322bc..f678619dca 100644 --- a/src/gluonts/torch/model/mq_cnn/estimator.py +++ b/src/gluonts/torch/model/mq_cnn/estimator.py @@ -536,17 +536,13 @@ def create_transformation(self) -> Chain: # now we map all the dynamic input of length context_length + prediction_length onto FieldName.FEAT_DYNAMIC # we exclude past_feat_dynamic_real since its length is only context_length - if len(dynamic_feat_fields) > 1: + if len(dynamic_feat_fields) > 0: transforms.append( VstackFeatures( output_field=FieldName.FEAT_DYNAMIC, input_fields=dynamic_feat_fields, ) ) - elif len(dynamic_feat_fields) == 1: - transforms.append( - RenameFields({dynamic_feat_fields[0]: FieldName.FEAT_DYNAMIC}) - ) if not self.use_feat_dynamic_cat: transforms.append( diff --git a/src/gluonts/transform/field.py b/src/gluonts/transform/field.py index bfd519c65e..788969d196 100644 --- a/src/gluonts/transform/field.py +++ b/src/gluonts/transform/field.py @@ -41,7 +41,7 @@ def transform(self, data: DataEntry): for key, new_key in self.mapping.items(): if key in data: # no implicit overriding - assert new_key not in data + assert new_key not in data, f"Key {new_key} is already present" data[new_key] = data[key] del data[key] return data From dad0a02362250f65622eb10da9c059368bedbc41 Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Mon, 27 May 2024 13:41:56 +0200 Subject: [PATCH 7/8] black --- src/gluonts/torch/model/mq_cnn/estimator.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/gluonts/torch/model/mq_cnn/estimator.py b/src/gluonts/torch/model/mq_cnn/estimator.py index f678619dca..b0d779a151 100644 --- a/src/gluonts/torch/model/mq_cnn/estimator.py +++ b/src/gluonts/torch/model/mq_cnn/estimator.py @@ -619,13 +619,10 @@ def _create_instance_splitter(self, mode: str) -> Chain: else [] ) - decoder_series_fields = ( - [ - FieldName.FEAT_DYNAMIC, - FieldName.FEAT_DYNAMIC_CAT, - ] - + ([FieldName.OBSERVED_VALUES] if mode != "test" else []) - ) + decoder_series_fields = [ + FieldName.FEAT_DYNAMIC, + FieldName.FEAT_DYNAMIC_CAT, + ] + ([FieldName.OBSERVED_VALUES] if mode != "test" else []) decoder_disabled_fields = ( [FieldName.FEAT_DYNAMIC, FieldName.FEAT_DYNAMIC_CAT] From 6aa577377eef0af6e8605f80bb6df1a49e6bf729 Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Mon, 27 May 2024 13:47:19 +0200 Subject: [PATCH 8/8] docformatter --- src/gluonts/torch/model/mq_cnn/estimator.py | 15 ++++++++++----- src/gluonts/torch/model/mq_cnn/layers.py | 8 +++++--- src/gluonts/transform/split.py | 7 +++++-- 3 files 
changed, 20 insertions(+), 10 deletions(-) diff --git a/src/gluonts/torch/model/mq_cnn/estimator.py b/src/gluonts/torch/model/mq_cnn/estimator.py index b0d779a151..1dfdf87d75 100644 --- a/src/gluonts/torch/model/mq_cnn/estimator.py +++ b/src/gluonts/torch/model/mq_cnn/estimator.py @@ -456,7 +456,8 @@ def __init__( self.feat_static_real_dim = 0 def create_transformation(self) -> Chain: - """Creates transformation to be applied to input dataset + """ + Creates transformation to be applied to input dataset. Returns: Chain: @@ -577,7 +578,8 @@ def create_transformation(self) -> Chain: return Chain(transforms) def _create_instance_splitter(self, mode: str) -> Chain: - """Creates instance splitter to be applied to the dataset + """ + Creates instance splitter to be applied to the dataset. Args: mode (str): `training`, `validation` or `test` @@ -676,7 +678,8 @@ def create_training_data_loader( shuffle_buffer_length: Optional[int] = None, **kwargs, ) -> Iterable: - """Creates data loader for the training dataset + """ + Creates data loader for the training dataset. Args: data (Dataset): training dataset @@ -706,7 +709,8 @@ def create_validation_data_loader( module: MQCNNLightningModule, **kwargs, ) -> Iterable: - """Creates data loader for the validation dataset + """ + Creates data loader for the validation dataset. Args: data (Dataset): validation dataset @@ -769,7 +773,8 @@ def create_predictor( transformation: Transformation, module: MQCNNLightningModule, ) -> PyTorchPredictor: - """Creates predictor for inference + """ + Creates predictor for inference. Args: transformation (Transformation): transformation to be applied to data input to predictor diff --git a/src/gluonts/torch/model/mq_cnn/layers.py b/src/gluonts/torch/model/mq_cnn/layers.py index d4351baa79..7c381cd95c 100644 --- a/src/gluonts/torch/model/mq_cnn/layers.py +++ b/src/gluonts/torch/model/mq_cnn/layers.py @@ -253,8 +253,9 @@ def forward( class Enc2Dec(nn.Module): """ - Integrates the encoder_output_static, encoder_output_dynamic and future_features_dynamic - and passes them through as the dynamic input to the decoder. + Integrates the encoder_output_static, encoder_output_dynamic and + future_features_dynamic and passes them through as the dynamic input to the + decoder. Parameters: ------------ @@ -439,7 +440,8 @@ def _get_local_mlp(self, init_dim, final_dim, hidden_dimension_seq): return local_mlp def forward(self, encoded_input: Tensor, future_input: Tensor) -> Tensor: - """Forward pass for MQCNN decoder + """ + Forward pass for MQCNN decoder. Args: encoded_input (Tensor): diff --git a/src/gluonts/transform/split.py b/src/gluonts/transform/split.py index 99b15b1f2e..b0ed1f302c 100644 --- a/src/gluonts/transform/split.py +++ b/src/gluonts/transform/split.py @@ -578,7 +578,9 @@ def flatmap_transform( class ForkingSequenceSplitter(FlatMapTransformation): - """Forking sequence splitter used by MQ-CNN Model""" + """ + Forking sequence splitter used by MQ-CNN Model. + """ @validated() def __init__( @@ -597,7 +599,8 @@ def __init__( start_input_field: str = FieldName.TARGET, lead_time: int = 0, ) -> None: - """Creates forking sequences + """ + Creates forking sequences. Args: instance_sampler ([type]):
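Taken together, these patches add a complete MQ-CNN workflow for the torch backend. The sketch below is an assumed end-to-end usage example, not part of the patch series: the constructor arguments mirror those exercised in test_estimators.py, and it assumes `MQCNNEstimator` is exported from the new `gluonts.torch.model.mq_cnn` package and follows the standard `PyTorchLightningEstimator` train/predict interface. The dataset is synthetic and purely illustrative.

import numpy as np

from gluonts.dataset.common import ListDataset
from gluonts.torch.distributions import QuantileOutput
from gluonts.torch.model.mq_cnn import MQCNNEstimator

freq = "H"
prediction_length = 24

# toy hourly series; any GluonTS-compatible dataset would do
train_ds = ListDataset(
    [{"start": "2021-01-01 00:00", "target": np.sin(np.arange(400) / 10.0)}],
    freq=freq,
)

estimator = MQCNNEstimator(
    freq=freq,
    prediction_length=prediction_length,
    distr_output=QuantileOutput(quantiles=[0.1, 0.5, 0.9]),
    batch_size=32,
    num_batches_per_epoch=50,
    trainer_kwargs=dict(max_epochs=5),
)

predictor = estimator.train(train_ds)
forecasts = list(predictor.predict(train_ds))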