

Feature/transformer refactorisation #1915

Open · wants to merge 18 commits into base: master
122 changes: 86 additions & 36 deletions darts/models/forecasting/transformer_model.py
@@ -8,6 +8,8 @@

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Transformer
Collaborator:
maybe import just generate_square_subsequent_mask

Contributor Author:
After reading up a little and checking out the implementation, it turns out that generate_square_subsequent_mask is a static method of Transformer. While it is possible to import just that method (https://stackoverflow.com/questions/48178011/import-static-method-of-a-class-without-importing-the-whole-class), I don't think it's worth it. That said, I definitely agree that this import is a little unintuitive, and I think a nice middle ground would be adding an implementation of generate_square_subsequent_mask to darts/utils, as it's a very small function. What do you think?

Collaborator:
WDYT @dennisbader?
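For illustration, a minimal sketch of what such a helper in darts/utils could look like (the function name and location are hypothetical, following the suggestion above); it mirrors the behaviour of torch.nn.Transformer.generate_square_subsequent_mask, i.e. an upper-triangular mask with -inf above the diagonal and 0 elsewhere:

```python
import torch


def generate_square_subsequent_mask(
    sz: int,
    device: torch.device = torch.device("cpu"),
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Causal mask of shape (sz, sz): 0 on and below the diagonal, -inf above it,
    so that position i can only attend to positions <= i."""
    return torch.triu(
        torch.full((sz, sz), float("-inf"), device=device, dtype=dtype), diagonal=1
    )
```

It would then be called as, e.g., generate_square_subsequent_mask(tgt.shape[0], device=tgt.device, dtype=tgt.dtype).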


from darts.logging import get_logger, raise_if, raise_if_not, raise_log
from darts.models.components import glu_variants, layer_norm_variants
@@ -188,6 +190,7 @@ def __init__(
self.target_size = output_size
self.nr_params = nr_params
self.target_length = self.output_chunk_length
self.d_model = d_model

self.encoder = nn.Linear(input_size, d_model)
self.positional_encoding = _PositionalEncoding(
@@ -281,48 +284,105 @@ def __init__(
custom_decoder=custom_decoder,
)

self.decoder = nn.Linear(
d_model, self.target_length * self.target_size * self.nr_params
)

def _create_transformer_inputs(self, data):
# '_TimeSeriesSequentialDataset' stores time series in the
# (batch_size, input_chunk_length, input_size) format. PyTorch's nn.Transformer
# module needs it the (input_chunk_length, batch_size, input_size) format.
# Therefore, the first two dimensions need to be swapped.
src = data.permute(1, 0, 2)
tgt = src[-1:, :, :]

return src, tgt
self.decoder = nn.Linear(d_model, self.target_size * self.nr_params)

@io_processor
def forward(self, x_in: Tuple):
data, _ = x_in
# Here we create 'src' and 'tgt', the inputs for the encoder and decoder
# side of the Transformer architecture
src, tgt = self._create_transformer_inputs(data)
"""
During training (teacher forcing) x_in = tuple(past_target + past_covariates, static_covariates, future_targets)
During inference x_in = tuple(past_target + past_covariates, static_covariates)

'_TimeSeriesSequentialDataset' stores time series in the
(batch_size, input_chunk_length, input_size) format. PyTorch's nn.Transformer
module needs it in the (input_chunk_length, batch_size, input_size) format.
Therefore, the first two dimensions need to be swapped.
"""
src = x_in[0].permute(1, 0, 2)
pad_size = (0, self.input_size - self.target_size)

# start token consists only of target series, past covariates are substituted with 0 padding
Collaborator:
Why would you not include the past covariates in the start token?

start_token = src[-1:, :, : self.target_size]
start_token_padded = F.pad(start_token, pad_size)

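# teacher forcing (training): future targets were provided, so build the full
# decoder input (start token + zero-padded future targets) and predict all steps at once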
if len(x_in) == 3:
tgt = x_in[-1].permute(1, 0, 2)
tgt = F.pad(tgt, pad_size)
tgt = torch.cat([start_token_padded, tgt], dim=0)
return self._prediction_step(src, tgt)[:, :-1, :, :]

tgt = start_token_padded

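# inference: no future targets are available, so decode autoregressively,
# feeding the sample-averaged prediction back into the decoder input at each step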
predictions = []
for _ in range(self.target_length):
pred = self._prediction_step(src, tgt)[:, -1, :, :]
predictions.append(pred)
tgt = torch.cat(
[tgt, F.pad(pred.mean(dim=2).unsqueeze(dim=0), pad_size)],
dim=0,
) # take average of samples
return torch.stack(predictions, dim=1)

def _prediction_step(self, src: torch.Tensor, tgt: torch.Tensor):
target_length = tgt.shape[0]
device, tensor_type = src.device, src.dtype
# "math.sqrt(self.input_size)" is a normalization factor
# see section 3.2.1 in 'Attention is All you Need' by Vaswani et al. (2017)
src = self.encoder(src) * math.sqrt(self.input_size)
src = self.positional_encoding(src)
src = self.encoder(src) * math.sqrt(self.d_model)
tgt = self.encoder(tgt) * math.sqrt(self.d_model)

tgt = self.encoder(tgt) * math.sqrt(self.input_size)
src = self.positional_encoding(src)
tgt = self.positional_encoding(tgt)

x = self.transformer(src=src, tgt=tgt)
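# causal mask so that each decoder position can only attend to itself and earlier positions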
tgt_mask = Transformer.generate_square_subsequent_mask(
target_length, device
).to(dtype=tensor_type)

x = self.transformer(src=src, tgt=tgt, tgt_mask=tgt_mask)
out = self.decoder(x)

# Here we change the data format
# from (1, batch_size, output_chunk_length * output_size)
# to (batch_size, output_chunk_length, output_size, nr_params)
predictions = out[0, :, :]
predictions = out.permute(1, 0, 2)
predictions = predictions.view(
-1, self.target_length, self.target_size, self.nr_params
-1, target_length, self.target_size, self.nr_params
)

return predictions

def training_step(self, train_batch, batch_idx) -> torch.Tensor:
"""performs the training step"""
train_batch = list(train_batch)
future_targets = train_batch[-1]
train_batch.append(future_targets)
return super().training_step(train_batch, batch_idx)

def _produce_train_output(self, input_batch: Tuple):
"""
Feeds PastCovariatesTorchModel with input and output chunks of a PastCovariatesSequentialDataset for
training.

Parameters:
input_batch
``(past_target, past_covariates, static_covariates, future_target)`` during training

``(past_target, past_covariates, static_covariates)`` during validation (not teacher forced)
"""

past_target, past_covariates, static_covariates = input_batch[:3]
# Currently all our PastCovariates models require past target and covariates concatenated
inpt = [
torch.cat([past_target, past_covariates], dim=2)
if past_covariates is not None
else past_target,
static_covariates,
]

# add future targets when training (teacher forcing)
if len(input_batch) == 4:
inpt.append(input_batch[-1])
return self(inpt)


class TransformerModel(PastCovariatesTorchModel):
def __init__(
@@ -351,7 +411,7 @@ def __init__(
The multi-head attention mechanism is highly parallelizable, which makes the transformer architecture
very suitable to be trained with GPUs.

The transformer architecture implemented here is based on [1]_.
The transformer architecture implemented here is based on [1]_ and uses teacher forcing [4]_.

This model supports past covariates (known for `input_chunk_length` points before prediction time).

@@ -388,7 +448,7 @@ def __init__(
Fraction of neurons affected by Dropout (default=0.1).
activation
The activation function of encoder/decoder intermediate layer, (default='relu').
can be one of the GLU variants' FeedForward Networks (FFN) [2]. A feedforward network is a
can be one of the GLU variants' FeedForward Networks (FFN) [3]. A feedforward network is a
fully-connected layer with an activation. The GLU variants' FeedForward Networks are a series
of FFNs designed to work better with Transformer-based models. ["GLU", "Bilinear", "ReGLU", "GEGLU",
"SwiGLU", "ReLU", "GELU"] or one of the PyTorch internal activations ["relu", "gelu"]
@@ -541,17 +601,7 @@ def encode_year(idx):
.. [2] Shazeer, Noam, "GLU Variants Improve Transformer", 2020. arXiv https://arxiv.org/abs/2002.05202.
.. [3] T. Kim et al. "Reversible Instance Normalization for Accurate Time-Series Forecasting against
Distribution Shift", https://openreview.net/forum?id=cGDAkQo1C0p

Notes
-----
Disclaimer:
This current implementation is fully functional and can already produce some good predictions. However,
it is still limited in how it uses the Transformer architecture because the `tgt` input of
`torch.nn.Transformer` is not utilized to its full extent. Currently, we simply pass the last value of the
`src` input to `tgt`. To get closer to the way the Transformer is usually used in language models, we
should allow the model to consume its own output as part of the `tgt` argument, such that when predicting
sequences of values, the input to the `tgt` argument would grow as outputs of the transformer model would be
added to it. Of course, the training of the model would have to be adapted accordingly.
.. [4] Teacher Forcing PyTorch tutorial: https://github.com/pytorch/examples/tree/main/word_language_model

Examples
--------
@@ -142,6 +142,7 @@
"output_chunk_length": 5,
"n_epochs": 20,
"random_state": 0,
"norm_type": "LayerNorm",
"likelihood": GaussianLikelihood(),
**tfm_kwargs,
},
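For context, a minimal usage sketch of TransformerModel with hyperparameters mirroring the test configuration above; the dataset choice and parameter values are illustrative only, not part of this PR:

```python
from darts.datasets import AirPassengersDataset
from darts.models import TransformerModel
from darts.utils.likelihood_models import GaussianLikelihood

# load a small example series (illustrative choice)
series = AirPassengersDataset().load().astype("float32")

model = TransformerModel(
    input_chunk_length=12,
    output_chunk_length=5,
    n_epochs=20,
    random_state=0,
    norm_type="LayerNorm",
    likelihood=GaussianLikelihood(),
)
model.fit(series)

# probabilistic forecast: sample from the fitted Gaussian likelihood
pred = model.predict(n=10, num_samples=100)
```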