Merge pull request #467 from WenjieDu/(feat)add_tcn

Add TCN as an imputation model

WenjieDu authored Jul 20, 2024
2 parents 38fa5a6 + e6393ef commit cc539e5
Showing 10 changed files with 723 additions and 0 deletions.
7 changes: 7 additions & 0 deletions docs/references.bib
@@ -756,3 +756,10 @@ @article{nie2024imputeformer
doi = {10.1145/3637528.3671751},
url = {https://doi.org/10.1145/3637528.3671751},
}

@article{bai2018tcn,
title = {An empirical evaluation of generic convolutional and recurrent networks for sequence modeling},
author = {Bai, Shaojie and Kolter, J. Zico and Koltun, Vladlen},
journal = {arXiv preprint arXiv:1803.01271},
year = {2018},
}
2 changes: 2 additions & 0 deletions pypots/imputation/__init__.py
@@ -23,6 +23,7 @@
from .crossformer import Crossformer
from .informer import Informer
from .autoformer import Autoformer
from .tcn import TCN
from .reformer import Reformer
from .dlinear import DLinear
from .patchtst import PatchTST
@@ -57,6 +58,7 @@
"DLinear",
"Informer",
"Autoformer",
"TCN",
"Reformer",
"NonstationaryTransformer",
"Pyraformer",
24 changes: 24 additions & 0 deletions pypots/imputation/tcn/__init__.py
@@ -0,0 +1,24 @@
"""
The package of the partially-observed time-series imputation model TCN.
Refer to the paper
`Shaojie Bai, J. Zico Kolter, and Vladlen Koltun.
"An Empirical Evaluation of Generic Convolutional and Recurrent Networks for Sequence Modeling".
arXiv preprint arXiv:1803.01271.
<https://arxiv.org/pdf/1803.01271>`_
Notes
-----
This implementation is inspired by the official one https://github.com/locuslab/TCN
"""

# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause


from .model import TCN

__all__ = [
"TCN",
]
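To illustrate how the newly exported class fits into the PyPOTS workflow, here is a minimal usage sketch. It assumes the high-level TCN wrapper in pypots/imputation/tcn/model.py (not shown in this excerpt) exposes the hyperparameters defined in core.py below and follows the usual PyPOTS fit/predict API; any training-related arguments the wrapper also accepts (e.g. epochs, batch_size) are omitted here.

# Hypothetical usage sketch, not part of this diff. Assumes the TCN wrapper mirrors
# the _TCN hyperparameters from core.py and the standard PyPOTS fit/predict workflow.
import numpy as np
from pypots.imputation import TCN

n_samples, n_steps, n_features = 100, 48, 10
X = np.random.randn(n_samples, n_steps, n_features)
X[np.random.rand(n_samples, n_steps, n_features) < 0.1] = np.nan  # simulate missing values

tcn = TCN(
    n_steps=n_steps,
    n_features=n_features,
    n_levels=3,        # number of stacked TCN blocks
    d_hidden=64,       # channels per block
    kernel_size=3,
    dropout=0.1,
)
tcn.fit({"X": X})  # trained with the SAITS-style ORT + MIT losses
imputation = tcn.predict({"X": X})["imputation"]  # imputed array, same shape as X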
81 changes: 81 additions & 0 deletions pypots/imputation/tcn/core.py
@@ -0,0 +1,81 @@
"""
The core wrapper assembles the submodules of TCN imputation model
and takes over the forward progress of the algorithm.
"""

# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause

import torch.nn as nn

from ...nn.modules.saits import SaitsLoss, SaitsEmbedding
from ...nn.modules.tcn import BackboneTCN


class _TCN(nn.Module):
def __init__(
self,
n_steps: int,
n_features: int,
n_levels: int,
d_hidden: int,
kernel_size: int,
dropout: float = 0,
ORT_weight: float = 1,
MIT_weight: float = 1,
):
super().__init__()
self.n_steps = n_steps
channel_sizes = [d_hidden] * n_levels

self.saits_embedding = SaitsEmbedding(
n_features * 2,
n_features,
with_pos=False,
dropout=dropout,
)
self.backbone = BackboneTCN(
n_features,
channel_sizes,
kernel_size,
dropout,
)

# for the imputation task, the output dim is the same as input dim
self.output_projection = nn.Linear(channel_sizes[-1], n_features)
self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight)

def forward(self, inputs: dict, training: bool = True) -> dict:
X, missing_mask = inputs["X"], inputs["missing_mask"]

# WDU: TCN wasn't originally proposed for the imputation task, so the model doesn't take
# the missing mask into account. This means that, during processing, the model doesn't know which parts of
# the input data are missing, which may hurt its imputation performance. Therefore, I apply the
# SAITS embedding method to project the concatenation of features and masks into a hidden space, as well as
# the output layers to project back from the hidden space to the original space.
enc_out = self.saits_embedding(X, missing_mask)
enc_out = enc_out.permute(0, 2, 1)  # (batch, n_steps, n_features) -> (batch, n_features, n_steps), channel-first for the TCN backbone

# TCN encoder processing
enc_out = self.backbone(enc_out)
enc_out = enc_out.permute(0, 2, 1)  # back to (batch, n_steps, channels)
# project back to the original data space
reconstruction = self.output_projection(enc_out)

imputed_data = missing_mask * X + (1 - missing_mask) * reconstruction
results = {
"imputed_data": imputed_data,
}

# if in training mode, return results with losses
if training:
X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"]
loss, ORT_loss, MIT_loss = self.saits_loss_func(
reconstruction, X_ori, missing_mask, indicating_mask
)
results["ORT_loss"] = ORT_loss
results["MIT_loss"] = MIT_loss
# `loss` is always the term backpropagated to update the model
results["loss"] = loss

return results
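For orientation, the sketch below shows what one level of the TCN backbone conceptually computes: a dilated causal 1D convolution with a residual connection, following Bai et al. (2018). It is illustrative only and is not the BackboneTCN used above (that module lives in pypots/nn/modules/tcn and, like the paper, stacks two weight-normalized convolutions per block); tensor shapes follow the channel-first layout produced by the permute calls in the forward pass.

# Conceptual sketch of one dilated causal convolution block (simplified TCN level).
# NOT the actual BackboneTCN implementation; for illustration only.
import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, dilation, dropout=0.0):
        super().__init__()
        self.left_pad = (kernel_size - 1) * dilation  # pad only the past, never the future
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation)
        self.dropout = nn.Dropout(dropout)
        # 1x1 convolution aligns channel counts for the residual connection
        self.downsample = (
            nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else None
        )

    def forward(self, x):  # x: (batch, in_channels, n_steps)
        out = F.pad(x, (self.left_pad, 0))            # causal (left-only) padding
        out = self.dropout(torch.relu(self.conv(out)))
        res = x if self.downsample is None else self.downsample(x)
        return torch.relu(out + res)                  # residual sum, sequence length preserved

# Stacking blocks with dilation 2**i gives an exponentially growing receptive field, e.g.
# levels = [CausalConvBlock(c_in if i == 0 else 64, 64, 3, 2 ** i) for i in range(3)]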
24 changes: 24 additions & 0 deletions pypots/imputation/tcn/data.py
@@ -0,0 +1,24 @@
"""
Dataset class for TCN.
"""

# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause

from typing import Union

from ..saits.data import DatasetForSAITS


class DatasetForTCN(DatasetForSAITS):
"""Actually TCN uses the same data strategy as SAITS, needs MIT for training."""

def __init__(
self,
data: Union[dict, str],
return_X_ori: bool,
return_y: bool,
file_type: str = "hdf5",
rate: float = 0.2,
):
super().__init__(data, return_X_ori, return_y, file_type, rate)
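For context, the MIT (masked imputation task) strategy inherited from DatasetForSAITS artificially hides a fraction `rate` (0.2 by default) of the observed values during training; the model must reconstruct them, and the MIT loss in core.py is computed only on those held-out positions via `indicating_mask`. The helper below is a rough, hypothetical illustration of that masking, not the actual implementation in DatasetForSAITS.

# Hypothetical illustration of SAITS-style MIT masking -- not the pypots implementation.
import numpy as np

def mit_mask(X_ori, rate=0.2, seed=0):
    """Artificially mask a fraction `rate` of the observed values in X_ori."""
    rng = np.random.default_rng(seed)
    observed = ~np.isnan(X_ori)
    held_out = observed & (rng.random(X_ori.shape) < rate)  # values hidden from the model
    X = X_ori.copy()
    X[held_out] = np.nan
    missing_mask = (~np.isnan(X)).astype(float)   # 1 = observed in the model input X
    indicating_mask = held_out.astype(float)      # 1 = artificially masked, scored by the MIT loss
    return X, missing_mask, indicating_mask

X_ori = np.random.randn(4, 48, 10)
X_ori[np.random.rand(4, 48, 10) < 0.1] = np.nan   # original missingness
X, missing_mask, indicating_mask = mit_mask(X_ori)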