Skip to content

Commit

Permalink
Merge pull request #505 from ztxtech/main
Browse files Browse the repository at this point in the history
Add TEFN model
  • Loading branch information
WenjieDu authored Sep 6, 2024
2 parents 66da59c + a8ad2df commit b9cba18
Show file tree
Hide file tree
Showing 12 changed files with 993 additions and 181 deletions.
260 changes: 185 additions & 75 deletions README.md

Large diffs are not rendered by default.

310 changes: 205 additions & 105 deletions README_zh.md

Large diffs are not rendered by default.

7 changes: 7 additions & 0 deletions docs/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -763,3 +763,10 @@ @article{bai2018tcn
journal={arXiv preprint arXiv:1803.01271},
year={2018}
}

@article{zhan2024tefn,
title={Time Evidence Fusion Network: Multi-source View in Long-Term Time Series Forecasting},
author={Zhan, Tianxiang and He, Yuanpeng and Li, Zhen and Deng, Yong},
journal={arXiv preprint arXiv:2405.06419},
year={2024}
}
3 changes: 2 additions & 1 deletion pypots/imputation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause

# neural network imputation methods
from .brits import BRITS
from .csdi import CSDI
from .gpvae import GPVAE
Expand Down Expand Up @@ -44,6 +43,7 @@
from .mean import Mean
from .median import Median
from .lerp import Lerp
from .tefn import TEFN

__all__ = [
# neural network imputation methods
Expand Down Expand Up @@ -84,4 +84,5 @@
"Mean",
"Median",
"Lerp",
"TEFN"
]
24 changes: 24 additions & 0 deletions pypots/imputation/tefn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""
The package of the forecasting model TEFN.
Refer to the paper
`Tianxiang Zhan, Yuanpeng He, Yong Deng, and Zhen Li.
Time Evidence Fusion Network: Multi-source View in Long-Term Time Series Forecasting.
In Arxiv, 2024.
<https://arxiv.org/abs/2405.06419>`_
Notes
-----
This implementation is transfered from the official one https://github.com/ztxtech/Time-Evidence-Fusion-Network
"""

# Created by Tianxiang Zhan <[email protected]>
# License: BSD-3-Clause


from .model import TEFN

__all__ = [
"TEFN",
]
59 changes: 59 additions & 0 deletions pypots/imputation/tefn/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""
"""

# Created by Tianxiang Zhan <[email protected]>
# License: BSD-3-Clause

import torch.nn as nn

from ...nn.functional import nonstationary_norm, nonstationary_denorm
from ...nn.modules.tefn import BackboneTEFN
from ...utils.metrics import calc_mse


class _TEFN(nn.Module):
def __init__(
self,
n_steps,
n_features,
n_fod,
apply_nonstationary_norm,
):
super().__init__()

self.seq_len = n_steps
self.n_fod = n_fod
self.apply_nonstationary_norm = apply_nonstationary_norm

self.model = BackboneTEFN(
n_steps,
n_features,
n_fod,
)

def forward(self, inputs: dict, training: bool = True) -> dict:
X, missing_mask = inputs["X"], inputs["missing_mask"]

if self.apply_nonstationary_norm:
# Normalization from Non-stationary Transformer
X, means, stdev = nonstationary_norm(X, missing_mask)

# TEFN processing
out = self.model(X)

if self.apply_nonstationary_norm:
# De-Normalization from Non-stationary Transformer
out = nonstationary_denorm(out, means, stdev)

imputed_data = missing_mask * X + (1 - missing_mask) * out
results = {
"imputed_data": imputed_data,
}

if training:
# `loss` is always the item for backward propagating to update the model
loss = calc_mse(out, inputs["X_ori"], inputs["indicating_mask"])
results["loss"] = loss

return results
24 changes: 24 additions & 0 deletions pypots/imputation/tefn/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""
Dataset class for the imputation model TEFN.
"""

# Created by Tianxiang Zhan <[email protected]>
# License: BSD-3-Clause

from typing import Union

from ..saits.data import DatasetForSAITS


class DatasetForTEFN(DatasetForSAITS):
"""Actually TEFN uses the same data strategy as SAITS, needs MIT for training."""

def __init__(
self,
data: Union[dict, str],
return_X_ori: bool,
return_y: bool,
file_type: str = "hdf5",
rate: float = 0.2,
):
super().__init__(data, return_X_ori, return_y, file_type, rate)
Loading

0 comments on commit b9cba18

Please sign in to comment.