Skip to content

Commit

Permalink
Add SegRNN model
Browse files Browse the repository at this point in the history
  • Loading branch information
lss-1138 committed Oct 10, 2024
1 parent a4f1a72 commit 366a584
Show file tree
Hide file tree
Showing 8 changed files with 508 additions and 0 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ The paper references and links are all listed at the bottom of this file.
| Neural Net | iTransformer🧑‍🔧[^24] || | | | | `2024 - ICLR` |
| Neural Net | ModernTCN[^38] || | | | | `2024 - ICLR` |
| Neural Net | ImputeFormer🧑‍🔧[^34] || | | | | `2024 - KDD` |
| Neural Net | SegRNN[^42] || | | | | `2023 - arXiv` |
| Neural Net | SAITS[^1] || | | | | `2023 - ESWA` |
| Neural Net | FreTS🧑‍🔧[^23] || | | | | `2023 - NeurIPS` |
| Neural Net | Koopa🧑‍🔧[^29] || | | | | `2023 - NeurIPS` |
Expand Down Expand Up @@ -509,3 +510,6 @@ Time-Series.AI</a>
[^41]: Xu, Z., Zeng, A., & Xu, Q. (2024).
[FITS: Modeling Time Series with 10k parameters](https://openreview.net/forum?id=bWcnvZ3qMb).
*ICLR 2024*.
[^42]: Lin, S., Lin, W., Wu, W., Zhao, F., Mo, R., & Zhang, H. (2023).
[Segrnn: Segment recurrent neural network for long-term time series forecasting](https://github.com/lss-1138/SegRNN)
*arXiv 2023*.
2 changes: 2 additions & 0 deletions pypots/imputation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from .imputeformer import ImputeFormer
from .timemixer import TimeMixer
from .moderntcn import ModernTCN
from .segrnn import SegRNN

# naive imputation methods
from .locf import LOCF
Expand Down Expand Up @@ -87,4 +88,5 @@
"Lerp",
"TEFN",
"CSAI",
"SegRNN",
]
24 changes: 24 additions & 0 deletions pypots/imputation/segrnn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""
The package including the modules of SegRNN.
Refer to the paper
`Lin, Shengsheng and Lin, Weiwei and Wu, Wentai and Zhao, Feiyu and Mo, Ruichao and Zhang, Haotong.
Segrnn: Segment recurrent neural network for long-term time series forecasting.
arXiv preprint arXiv:2308.11200.
<https://arxiv.org/abs/2308.11200>`_
Notes
-----
This implementation is inspired by the official one https://github.com/lss-1138/SegRNN
"""

# Created by Shengsheng Lin



from .model import SegRNN

__all__ = [
"SegRNN",
]
59 changes: 59 additions & 0 deletions pypots/imputation/segrnn/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""
The core wrapper assembles the submodules of SegRNN imputation model
and takes over the forward progress of the algorithm.
"""

# Created by Shengsheng Lin

from typing import Optional

from typing import Callable
import torch.nn as nn

from ...nn.modules.segrnn import BackboneSegRNN
from ...nn.modules.saits import SaitsLoss

class _SegRNN(nn.Module):
def __init__(
self,
n_steps: int,
n_features: int,
seg_len: int = 24,
d_model: int = 512,
dropout: float = 0.5,
ORT_weight: float = 1,
MIT_weight: float = 1,
):
super().__init__()

self.n_steps = n_steps
self.n_features = n_features
self.seg_len = seg_len
self.d_model = d_model
self.dropout = dropout

self.backbone = BackboneSegRNN(n_steps, n_features, seg_len, d_model, dropout)

# apply SAITS loss function to Transformer on the imputation task
self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight)

def forward(self, inputs: dict, training: bool = True) -> dict:
X, missing_mask = inputs["X"], inputs["missing_mask"]

reconstruction = self.backbone(X)

imputed_data = missing_mask * X + (1 - missing_mask) * reconstruction
results = {
"imputed_data": imputed_data,
}

# if in training mode, return results with losses
if training:
X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"]
loss, ORT_loss, MIT_loss = self.saits_loss_func(reconstruction, X_ori, missing_mask, indicating_mask)
results["ORT_loss"] = ORT_loss
results["MIT_loss"] = MIT_loss
# `loss` is always the item for backward propagating to update the model
results["loss"] = loss

return results
21 changes: 21 additions & 0 deletions pypots/imputation/segrnn/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""
Dataset class for the imputation model SegRNN.
"""

# Created by Shengsheng lin

from typing import Union

from pypots.imputation.saits.data import DatasetForSAITS


class DatasetForSegRNN(DatasetForSAITS):
def __init__(
self,
data: Union[dict, str],
return_X_ori: bool,
return_y: bool,
file_type: str = "hdf5",
rate: float = 0.2,
):
super().__init__(data, return_X_ori, return_y, file_type, rate)
Loading

0 comments on commit 366a584

Please sign in to comment.