Skip to content

Commit

Permalink
Merge pull request #433 from WenjieDu/(feat)add_reformer
Browse files Browse the repository at this point in the history
Add Reformer as an imputation model
  • Loading branch information
WenjieDu authored Jun 18, 2024
2 parents 487f8f8 + 4a9163a commit ba116eb
Show file tree
Hide file tree
Showing 11 changed files with 1,756 additions and 0 deletions.
2 changes: 2 additions & 0 deletions pypots/imputation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from .crossformer import Crossformer
from .informer import Informer
from .autoformer import Autoformer
from .reformer import Reformer
from .dlinear import DLinear
from .patchtst import PatchTST
from .usgan import USGAN
Expand Down Expand Up @@ -54,6 +55,7 @@
"DLinear",
"Informer",
"Autoformer",
"Reformer",
"NonstationaryTransformer",
"Pyraformer",
"BRITS",
Expand Down
25 changes: 25 additions & 0 deletions pypots/imputation/reformer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""
The package of the partially-observed time-series imputation model Reformer.
Refer to the paper
`Kitaev, Nikita, Łukasz Kaiser, and Anselm Levskaya.
Reformer: The Efficient Transformer.
International Conference on Learning Representations, 2020.
<https://openreview.net/pdf?id=rkgNKkHtvB>`_
Notes
-----
This implementation is inspired by the official one https://github.com/google/trax/tree/master/trax/models/reformer and
https://github.com/lucidrains/reformer-pytorch
"""

# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause


from .model import Reformer

__all__ = [
"Reformer",
]
88 changes: 88 additions & 0 deletions pypots/imputation/reformer/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""
The core wrapper assembles the submodules of Reformer imputation model
and takes over the forward progress of the algorithm.
"""

# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause

import torch.nn as nn

from ...nn.modules.reformer import ReformerEncoder
from ...nn.modules.saits import SaitsLoss, SaitsEmbedding


class _Reformer(nn.Module):
def __init__(
self,
n_steps,
n_features,
n_layers,
d_model,
n_heads,
bucket_size,
n_hashes,
causal,
d_ffn,
dropout,
ORT_weight: float = 1,
MIT_weight: float = 1,
):
super().__init__()

self.n_steps = n_steps

self.saits_embedding = SaitsEmbedding(
n_features * 2,
d_model,
with_pos=False,
dropout=dropout,
)
self.encoder = ReformerEncoder(
n_steps,
n_layers,
d_model,
n_heads,
bucket_size,
n_hashes,
causal,
d_ffn,
dropout,
)

# for the imputation task, the output dim is the same as input dim
self.output_projection = nn.Linear(d_model, n_features)
self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight)

def forward(self, inputs: dict, training: bool = True) -> dict:
X, missing_mask = inputs["X"], inputs["missing_mask"]

# WDU: the original Reformer paper isn't proposed for imputation task. Hence the model doesn't take
# the missing mask into account, which means, in the process, the model doesn't know which part of
# the input data is missing, and this may hurt the model's imputation performance. Therefore, I apply the
# SAITS embedding method to project the concatenation of features and masks into a hidden space, as well as
# the output layers to project back from the hidden space to the original space.
enc_out = self.saits_embedding(X, missing_mask)

# Reformer encoder processing
enc_out = self.encoder(enc_out)
# project back the original data space
reconstruction = self.output_projection(enc_out)

imputed_data = missing_mask * X + (1 - missing_mask) * reconstruction
results = {
"imputed_data": imputed_data,
}

# if in training mode, return results with losses
if training:
X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"]
loss, ORT_loss, MIT_loss = self.saits_loss_func(
reconstruction, X_ori, missing_mask, indicating_mask
)
results["ORT_loss"] = ORT_loss
results["MIT_loss"] = MIT_loss
# `loss` is always the item for backward propagating to update the model
results["loss"] = loss

return results
24 changes: 24 additions & 0 deletions pypots/imputation/reformer/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""
Dataset class for Reformer.
"""

# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause

from typing import Union

from ..saits.data import DatasetForSAITS


class DatasetForReformer(DatasetForSAITS):
"""Actually Reformer uses the same data strategy as SAITS, needs MIT for training."""

def __init__(
self,
data: Union[dict, str],
return_X_ori: bool,
return_y: bool,
file_type: str = "hdf5",
rate: float = 0.2,
):
super().__init__(data, return_X_ori, return_y, file_type, rate)
Loading

0 comments on commit ba116eb

Please sign in to comment.