Skip to content

Commit

Permalink
Merge branch 'dev' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
WenjieDu authored Dec 12, 2023
2 parents 35c2aa0 + be8ee4d commit 62def0b
Show file tree
Hide file tree
Showing 60 changed files with 825 additions and 204 deletions.
7 changes: 7 additions & 0 deletions docs/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -466,3 +466,10 @@ @inproceedings{wu2023timesnet
year={2023},
url={https://openreview.net/forum?id=ju_Uqw384Oq}
}

@inproceedings{liu2022nonstationary,
title={Non-stationary Transformers: Exploring the Stationarity in Time Series Forecasting},
author={Liu, Yong and Wu, Haixu and Wang, Jianmin and Long, Mingsheng},
booktitle={Advances in Neural Information Processing Systems},
year={2022}
}
2 changes: 1 addition & 1 deletion pypots/classification/grud/modules/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import torch.nn as nn
import torch.nn.functional as F

from ....modules.rnn import TemporalDecay
from ....nn.modules.rnn import TemporalDecay


class _GRUD(nn.Module):
Expand Down
6 changes: 3 additions & 3 deletions pypots/clustering/crli/modules/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from sklearn.cluster import KMeans

from .submodules import Generator, Decoder, Discriminator
from ....utils.metrics import cal_mse
from ....utils.metrics import calc_mse


class _CRLI(nn.Module):
Expand Down Expand Up @@ -89,8 +89,8 @@ def forward(
l_G = F.binary_cross_entropy_with_logits(
inputs["discrimination"], 1 - missing_mask, weight=1 - missing_mask
)
l_pre = cal_mse(inputs["imputation_latent"], X, missing_mask)
l_rec = cal_mse(inputs["reconstruction"], X, missing_mask)
l_pre = calc_mse(inputs["imputation_latent"], X, missing_mask)
l_rec = calc_mse(inputs["reconstruction"], X, missing_mask)
HTH = torch.matmul(inputs["fcn_latent"], inputs["fcn_latent"].permute(1, 0))

if (
Expand Down
4 changes: 2 additions & 2 deletions pypots/clustering/vader/modules/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
PeepholeLSTMCell,
ImplicitImputation,
)
from ....utils.metrics import cal_mse
from ....utils.metrics import calc_mse


class _VaDER(nn.Module):
Expand Down Expand Up @@ -184,7 +184,7 @@ def forward(
}

# calculate the reconstruction loss
unscaled_reconstruction_loss = cal_mse(X_reconstructed, X, missing_mask)
unscaled_reconstruction_loss = calc_mse(X_reconstructed, X, missing_mask)
reconstruction_loss = (
unscaled_reconstruction_loss
* self.n_steps
Expand Down
4 changes: 2 additions & 2 deletions pypots/data/load_specific_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def load_specific_dataset(dataset_name: str, use_cache: bool = True) -> dict:
"""
logger.info(
f"Loading the dataset {dataset_name} with TSDB (https://github.com/WenjieDu/Time_Series_Database)..."
f"Loading the dataset {dataset_name} with TSDB (https://github.com/WenjieDu/Time_Series_Data_Beans)..."
)
assert dataset_name in SUPPORTED_DATASETS, (
f"Dataset {dataset_name} is not supported. "
Expand All @@ -66,6 +66,6 @@ def load_specific_dataset(dataset_name: str, use_cache: bool = True) -> dict:
f"https://github.com/WenjieDu/PyPOTS/issues"
)
logger.info(f"Starting preprocessing {dataset_name}...")
data = tsdb.load_dataset(dataset_name, use_cache)
data = tsdb.load(dataset_name, use_cache)
data = PREPROCESSING_FUNC[dataset_name](data)
return data
4 changes: 2 additions & 2 deletions pypots/imputation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from ..base import BaseModel, BaseNNModel
from ..utils.logging import logger
from ..utils.metrics import cal_mae
from ..utils.metrics import calc_mse

try:
import nni
Expand Down Expand Up @@ -299,7 +299,7 @@ def _train_model(
imputation_collector = torch.cat(imputation_collector)
imputation_collector = imputation_collector.cpu().detach().numpy()

mean_val_loss = cal_mae(
mean_val_loss = calc_mse(
imputation_collector,
val_loader.dataset.data["X_intact"],
val_loader.dataset.data["indicating_mask"],
Expand Down
10 changes: 5 additions & 5 deletions pypots/imputation/brits/modules/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
import torch.nn as nn

from .submodules import FeatureRegression
from ....modules.rnn import TemporalDecay
from ....utils.metrics import cal_mae
from ....nn.modules.rnn import TemporalDecay
from ....utils.metrics import calc_mae


class RITS(nn.Module):
Expand Down Expand Up @@ -150,17 +150,17 @@ def impute(

hidden_states = hidden_states * gamma_h # decay hidden states
x_h = self.hist_reg(hidden_states)
reconstruction_loss += cal_mae(x_h, x, m)
reconstruction_loss += calc_mae(x_h, x, m)

x_c = m * x + (1 - m) * x_h

z_h = self.feat_reg(x_c)
reconstruction_loss += cal_mae(z_h, x, m)
reconstruction_loss += calc_mae(z_h, x, m)

alpha = torch.sigmoid(self.combining_weight(torch.cat([gamma_x, m], dim=1)))

c_h = alpha * z_h + (1 - alpha) * x_h
reconstruction_loss += cal_mae(c_h, x, m)
reconstruction_loss += calc_mae(c_h, x, m)

c_c = m * x + (1 - m) * c_h
estimations.append(c_h.unsqueeze(dim=1))
Expand Down
48 changes: 24 additions & 24 deletions pypots/imputation/csdi/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,6 @@ def __init__(
rate: float = 0.1,
):
super().__init__(data, return_labels, file_type)
self.time_points = (
None if "time_points" not in data.keys() else data["time_points"]
)
# _, self.time_points = self._check_input(self.X, time_points)
self.for_pattern_mask = (
None if "for_pattern_mask" not in data.keys() else data["for_pattern_mask"]
)
# _, self.for_pattern_mask = self._check_input(self.X, for_pattern_mask)
self.cut_length = (
None if "cut_length" not in data.keys() else data["cut_length"]
)
# _, self.cut_length = self._check_input(self.X, cut_length)
self.rate = rate

def _fetch_data_from_array(self, idx: int) -> Iterable:
Expand Down Expand Up @@ -69,21 +57,23 @@ def _fetch_data_from_array(self, idx: int) -> Iterable:
X = self.X[idx].to(torch.float32)
X_intact, X, missing_mask, indicating_mask = mcar(X, p=self.rate)

observed_data = X_intact # i.e. originally observed data
observed_mask = missing_mask + indicating_mask # i.e. originally missing masks
observed_data = X_intact
observed_mask = missing_mask + indicating_mask
gt_mask = missing_mask
observed_tp = (
torch.arange(0, self.n_steps, dtype=torch.float32)
if self.time_points is None
else self.time_points[idx].to(torch.float32)
if "time_points" not in self.data.keys()
else torch.from_numpy(self.data["time_points"][idx]).to(torch.float32)
)
gt_mask = missing_mask # missing mask with ground truth masked for validation
for_pattern_mask = (
gt_mask if self.for_pattern_mask is None else self.for_pattern_mask[idx]
gt_mask
if "for_pattern_mask" not in self.data.keys()
else torch.from_numpy(self.data["for_pattern_mask"][idx]).to(torch.float32)
)
cut_length = (
torch.zeros(len(observed_data)).long()
if self.cut_length is None
else self.cut_length[idx]
if "cut_length" not in self.data.keys()
else torch.from_numpy(self.data["cut_length"][idx]).to(torch.float32)
)

sample = [
Expand Down Expand Up @@ -124,15 +114,25 @@ def _fetch_data_from_file(self, idx: int) -> Iterable:

observed_data = X_intact
observed_mask = missing_mask + indicating_mask
observed_tp = self.time_points[idx].to(torch.float32)
gt_mask = indicating_mask
observed_tp = (
torch.arange(0, self.n_steps, dtype=torch.float32)
if "time_points" not in self.file_handle.keys()
else torch.from_numpy(self.file_handle["time_points"][idx]).to(
torch.float32
)
)
for_pattern_mask = (
gt_mask if self.for_pattern_mask is None else self.for_pattern_mask[idx]
gt_mask
if "for_pattern_mask" not in self.file_handle.keys()
else torch.from_numpy(self.file_handle["for_pattern_mask"][idx]).to(
torch.float32
)
)
cut_length = (
torch.zeros(len(observed_data)).long()
if self.cut_length is None
else self.cut_length[idx]
if "cut_length" not in self.file_handle.keys()
else torch.from_numpy(self.file_handle["cut_length"][idx]).to(torch.float32)
)

sample = [
Expand Down
113 changes: 113 additions & 0 deletions pypots/imputation/csdi/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,19 @@
# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause

import os
from typing import Union, Optional

import h5py
import numpy as np
import torch
from torch.utils.data import DataLoader

try:
import nni
except ImportError:
pass

from .data import DatasetForCSDI
from .modules import _CSDI
from ..base import BaseNNImputer
Expand Down Expand Up @@ -205,6 +211,113 @@ def _assemble_input_for_validating(self, data) -> dict:
def _assemble_input_for_testing(self, data) -> dict:
return self._assemble_input_for_validating(data)

def _train_model(
    self,
    training_loader: DataLoader,
    val_loader: DataLoader = None,
) -> None:
    """Run the training loop, with optional per-epoch validation and early stopping.

    For every epoch the model is optimised over ``training_loader``. When
    ``val_loader`` is given, the model is then evaluated on it with
    ``training=False, n_sampling_times=0`` and the mean validation loss is used
    for model selection; otherwise the mean training loss is used. Whenever the
    selection loss improves on ``self.best_loss``, a snapshot of the model
    weights is stored in ``self.best_model_dict`` and ``self.patience`` is reset
    to ``self.original_patience``; otherwise ``self.patience`` is decremented and
    training stops once it reaches 0.

    Parameters
    ----------
    training_loader :
        Loader yielding the training batches.

    val_loader :
        Loader yielding the validation batches. If None, no validation is run
        and the training loss drives model selection.

    Raises
    ------
    RuntimeError
        If an exception interrupts training before any best model was recorded.

    ValueError
        If ``self.best_loss`` is NaN once training has finished.
    """
    # Function-scope imports keep this block self-contained.
    import copy
    import warnings

    # each training starts from the very beginning, so reset the loss and model dict here
    self.best_loss = float("inf")
    self.best_model_dict = None

    try:
        training_step = 0
        for epoch in range(self.epochs):
            self.model.train()
            epoch_train_loss_collector = []
            for data in training_loader:
                training_step += 1
                inputs = self._assemble_input_for_training(data)
                self.optimizer.zero_grad()
                results = self.model.forward(inputs)
                # use sum() before backward() in case of multi-gpu training
                results["loss"].sum().backward()
                self.optimizer.step()
                epoch_train_loss_collector.append(results["loss"].sum().item())

                # save training loss logs into the tensorboard file for every step if in need
                if self.summary_writer is not None:
                    self._save_log_into_tb_file(training_step, "training", results)

            # mean training loss of the current epoch
            mean_train_loss = np.mean(epoch_train_loss_collector)

            if val_loader is not None:
                self.model.eval()
                val_loss_collector = []
                with torch.no_grad():
                    for data in val_loader:
                        inputs = self._assemble_input_for_validating(data)
                        results = self.model.forward(
                            inputs, training=False, n_sampling_times=0
                        )
                        # sum() first, mirroring the training path, so that a
                        # multi-element loss does not make .item() fail
                        val_loss_collector.append(results["loss"].sum().item())

                mean_val_loss = np.asarray(val_loss_collector).mean()

                # save validating loss logs into the tensorboard file for every epoch if in need
                if self.summary_writer is not None:
                    val_loss_dict = {
                        "imputation_loss": mean_val_loss,
                    }
                    self._save_log_into_tb_file(epoch, "validating", val_loss_dict)

                logger.info(
                    f"Epoch {epoch} - "
                    f"training loss: {mean_train_loss:.4f}, "
                    f"validating loss: {mean_val_loss:.4f}"
                )
                mean_loss = mean_val_loss
            else:
                logger.info(f"Epoch {epoch} - training loss: {mean_train_loss:.4f}")
                mean_loss = mean_train_loss

            if np.isnan(mean_loss):
                logger.warning(
                    f"‼️ Attention: got NaN loss in Epoch {epoch}. This may lead to unexpected errors."
                )

            # a NaN loss never compares as smaller, so it counts against patience
            if mean_loss < self.best_loss:
                self.best_loss = mean_loss
                # state_dict() hands back references to the live parameters, so
                # take a deep copy; otherwise later optimizer steps would
                # silently overwrite the "best" weights
                self.best_model_dict = copy.deepcopy(self.model.state_dict())
                self.patience = self.original_patience
                # save the model if necessary
                self._auto_save_model_if_necessary(
                    training_finished=False,
                    saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}",
                )
            else:
                self.patience -= 1

            # any non-empty value of the env variable enables NNI reporting
            if os.getenv("enable_tuning", False):
                nni.report_intermediate_result(mean_loss)
                if epoch == self.epochs - 1 or self.patience == 0:
                    nni.report_final_result(self.best_loss)

            if self.patience == 0:
                logger.info(
                    "Exceeded the training patience. Terminating the training procedure..."
                )
                break

    except Exception as e:
        logger.error(f"Exception: {e}")
        if self.best_model_dict is None:
            # nothing usable was trained, so surface the failure and keep its cause
            raise RuntimeError(
                "Training got interrupted. Model was not trained. Please investigate the error printed above."
            ) from e
        else:
            # actually emit the warning instead of only constructing the object
            warnings.warn(
                "Training got interrupted. Please investigate the error printed above.\n"
                "Model got trained and will load the best checkpoint so far for testing.\n"
                "If you don't want it, please try fit() again.",
                RuntimeWarning,
            )

    if np.isnan(self.best_loss):
        raise ValueError("Something is wrong. best_loss is Nan after training.")

    logger.info("Finished training.")

def fit(
self,
train_set: Union[dict, str],
Expand Down
14 changes: 9 additions & 5 deletions pypots/imputation/csdi/modules/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,12 +259,16 @@ def forward(self, inputs, training=True, n_sampling_times=1):
loss = loss_func(observed_data, cond_mask, observed_mask, side_info, training)
results = {"loss": loss}

if not training:
if not training and n_sampling_times > 0:
samples = self.impute(
observed_data, cond_mask, side_info, n_sampling_times
) # (B,bz,K,L)
imputation = samples.mean(dim=1) # (B,K,L)
imputed_data = observed_data + imputation * (1 - gt_mask)
results["imputed_data"] = imputed_data.permute(0, 2, 1) # (B,L,K)
) # (bz,n_sampling,K,L)
repeated_obs = observed_data.unsqueeze(1).repeat(1, n_sampling_times, 1, 1)
repeated_mask = gt_mask.unsqueeze(1).repeat(1, n_sampling_times, 1, 1)
imputed_data = repeated_obs + samples * (1 - repeated_mask)

results["imputed_data"] = imputed_data.permute(
0, 1, 3, 2
) # (bz,n_sampling,L,K)

return results
4 changes: 2 additions & 2 deletions pypots/imputation/mrnn/modules/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import torch.nn as nn

from .submodules import FCN_Regression
from ....utils.metrics import cal_rmse
from ....utils.metrics import calc_rmse


class _MRNN(nn.Module):
Expand Down Expand Up @@ -74,7 +74,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict:
FCN_estimation = self.fcn_regression(
x, m, RNN_imputed_data
) # FCN estimation is output estimation
reconstruction_loss += cal_rmse(FCN_estimation, x, m) + cal_rmse(
reconstruction_loss += calc_rmse(FCN_estimation, x, m) + calc_rmse(
RNN_estimation, x, m
)
estimations.append(FCN_estimation.unsqueeze(dim=1))
Expand Down
4 changes: 2 additions & 2 deletions pypots/imputation/saits/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from ...optim.adam import Adam
from ...optim.base import Optimizer
from ...utils.logging import logger
from ...utils.metrics import cal_mae
from ...utils.metrics import calc_mae


class SAITS(BaseNNImputer):
Expand Down Expand Up @@ -145,7 +145,7 @@ def __init__(
batch_size: int = 32,
epochs: int = 100,
patience: Optional[int] = None,
customized_loss_func: Callable = cal_mae,
customized_loss_func: Callable = calc_mae,
optimizer: Optional[Optimizer] = Adam(),
num_workers: int = 0,
device: Optional[Union[str, torch.device, list]] = None,
Expand Down
Loading

0 comments on commit 62def0b

Please sign in to comment.