From 6dad10574b7d1caa6b3fd7e15d089dd8bc43d046 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Wed, 6 Dec 2023 21:37:26 +0800 Subject: [PATCH 01/16] refactor: using only one GPU during testing if there are more than one available; --- tests/global_test_config.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/global_test_config.py b/tests/global_test_config.py index 3258ab6d..6d307427 100644 --- a/tests/global_test_config.py +++ b/tests/global_test_config.py @@ -7,6 +7,7 @@ import os +import numpy as np import torch from pypots.data.generating import gene_random_walk @@ -33,10 +34,11 @@ # set DEVICES to None if no cuda device is available, to avoid initialization failed while importing test classes -cuda_devices = [torch.device(i) for i in range(torch.cuda.device_count())] -if len(cuda_devices) > 2: +n_cuda_devices = torch.cuda.device_count() +cuda_devices = [torch.device(i) for i in range(n_cuda_devices)] +if n_cuda_devices > 1: logger.info("❗️Detected multiple cuda devices, using all of them to run testing.") - DEVICE = cuda_devices + DEVICE = cuda_devices[np.random.randint(n_cuda_devices)] else: # if having no multiple cuda devices, leave it as None to use the default device DEVICE = None From f403b3edc3b4845b021a8aed10778a54e9dfd096 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Wed, 6 Dec 2023 21:42:46 +0800 Subject: [PATCH 02/16] refactor: rewrite some comments; --- pypots/data/load_specific_datasets.py | 2 +- pypots/imputation/csdi/data.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pypots/data/load_specific_datasets.py b/pypots/data/load_specific_datasets.py index 174b148c..c43864cd 100644 --- a/pypots/data/load_specific_datasets.py +++ b/pypots/data/load_specific_datasets.py @@ -57,7 +57,7 @@ def load_specific_dataset(dataset_name: str, use_cache: bool = True) -> dict: """ logger.info( - f"Loading the dataset {dataset_name} with TSDB (https://github.com/WenjieDu/Time_Series_Database)..." + f"Loading the dataset {dataset_name} with TSDB (https://github.com/WenjieDu/Time_Series_Data_Beans)..." ) assert dataset_name in SUPPORTED_DATASETS, ( f"Dataset {dataset_name} is not supported. " diff --git a/pypots/imputation/csdi/data.py b/pypots/imputation/csdi/data.py index b2798ca0..e0cfc894 100644 --- a/pypots/imputation/csdi/data.py +++ b/pypots/imputation/csdi/data.py @@ -69,14 +69,14 @@ def _fetch_data_from_array(self, idx: int) -> Iterable: X = self.X[idx].to(torch.float32) X_intact, X, missing_mask, indicating_mask = mcar(X, p=self.rate) - observed_data = X_intact # i.e. originally observed data - observed_mask = missing_mask + indicating_mask # i.e. originally missing masks + observed_data = X_intact + observed_mask = missing_mask + indicating_mask observed_tp = ( torch.arange(0, self.n_steps, dtype=torch.float32) if self.time_points is None else self.time_points[idx].to(torch.float32) ) - gt_mask = missing_mask # missing mask with ground truth masked for validation + gt_mask = missing_mask for_pattern_mask = ( gt_mask if self.for_pattern_mask is None else self.for_pattern_mask[idx] ) From 6de9d838560ca3b83f59d2e3a62dca5d2aea16d3 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Thu, 7 Dec 2023 16:04:10 +0800 Subject: [PATCH 03/16] fix: return all imputation samples rather than averaged over sampling times; --- pypots/imputation/csdi/model.py | 121 +++++++++++++++++++++++++ pypots/imputation/csdi/modules/core.py | 12 ++- 2 files changed, 129 insertions(+), 4 deletions(-) diff --git a/pypots/imputation/csdi/model.py b/pypots/imputation/csdi/model.py index 70bf7734..71b76918 100644 --- a/pypots/imputation/csdi/model.py +++ b/pypots/imputation/csdi/model.py @@ -13,6 +13,7 @@ # Created by Wenjie Du # License: BSD-3-Clause +import os from typing import Union, Optional import h5py @@ -20,12 +21,18 @@ import torch from torch.utils.data import DataLoader +try: + import nni +except ImportError: + pass + from .data import DatasetForCSDI from .modules import _CSDI from ..base import BaseNNImputer from ...optim.adam import Adam from ...optim.base import Optimizer from ...utils.logging import logger +from ...utils.metrics import cal_mse class CSDI(BaseNNImputer): @@ -205,6 +212,120 @@ def _assemble_input_for_validating(self, data) -> dict: def _assemble_input_for_testing(self, data) -> dict: return self._assemble_input_for_validating(data) + def _train_model( + self, + training_loader: DataLoader, + val_loader: DataLoader = None, + ) -> None: + # each training starts from the very beginning, so reset the loss and model dict here + self.best_loss = float("inf") + self.best_model_dict = None + + try: + training_step = 0 + for epoch in range(self.epochs): + self.model.train() + epoch_train_loss_collector = [] + for idx, data in enumerate(training_loader): + training_step += 1 + inputs = self._assemble_input_for_training(data) + self.optimizer.zero_grad() + results = self.model.forward(inputs) + # use sum() before backward() in case of multi-gpu training + results["loss"].sum().backward() + self.optimizer.step() + epoch_train_loss_collector.append(results["loss"].sum().item()) + + # save training loss logs into the tensorboard file for every step if in need + if self.summary_writer is not None: + self._save_log_into_tb_file(training_step, "training", results) + + # mean training loss of the current epoch + mean_train_loss = np.mean(epoch_train_loss_collector) + + if val_loader is not None: + self.model.eval() + imputation_collector = [] + with torch.no_grad(): + for idx, data in enumerate(val_loader): + inputs = self._assemble_input_for_validating(data) + results = self.model.forward(inputs, training=False) + imputed_data = results["imputed_data"].mean(axis=1) + imputation_collector.append(imputed_data) + + imputation_collector = torch.cat(imputation_collector) + imputation_collector = imputation_collector.cpu().detach().numpy() + + mean_val_loss = cal_mse( + imputation_collector, + val_loader.dataset.data["X_intact"], + val_loader.dataset.data["indicating_mask"], + # the above val_loader.dataset.data is a dict containing the validation dataset + ) + + # save validating loss logs into the tensorboard file for every epoch if in need + if self.summary_writer is not None: + val_loss_dict = { + "imputation_loss": mean_val_loss, + } + self._save_log_into_tb_file(epoch, "validating", val_loss_dict) + + logger.info( + f"Epoch {epoch} - " + f"training loss: {mean_train_loss:.4f}, " + f"validating loss: {mean_val_loss:.4f}" + ) + mean_loss = mean_val_loss + else: + logger.info(f"Epoch {epoch} - training loss: {mean_train_loss:.4f}") + mean_loss = mean_train_loss + + if np.isnan(mean_loss): + logger.warning( + f"‼️ Attention: got NaN loss in Epoch {epoch}. This may lead to unexpected errors." + ) + + if mean_loss < self.best_loss: + self.best_loss = mean_loss + self.best_model_dict = self.model.state_dict() + self.patience = self.original_patience + # save the model if necessary + self._auto_save_model_if_necessary( + training_finished=False, + saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}", + ) + else: + self.patience -= 1 + + if os.getenv("enable_tuning", False): + nni.report_intermediate_result(mean_loss) + if epoch == self.epochs - 1 or self.patience == 0: + nni.report_final_result(self.best_loss) + + if self.patience == 0: + logger.info( + "Exceeded the training patience. Terminating the training procedure..." + ) + break + + except Exception as e: + logger.error(f"Exception: {e}") + if self.best_model_dict is None: + raise RuntimeError( + "Training got interrupted. Model was not trained. Please investigate the error printed above." + ) + else: + RuntimeWarning( + "Training got interrupted. Please investigate the error printed above.\n" + "Model got trained and will load the best checkpoint so far for testing.\n" + "If you don't want it, please try fit() again." + ) + + if np.isnan(self.best_loss): + raise ValueError("Something is wrong. best_loss is Nan after training.") + + logger.info("Finished training.") + def fit( self, train_set: Union[dict, str], diff --git a/pypots/imputation/csdi/modules/core.py b/pypots/imputation/csdi/modules/core.py index 958fb65d..4f0ab608 100644 --- a/pypots/imputation/csdi/modules/core.py +++ b/pypots/imputation/csdi/modules/core.py @@ -262,9 +262,13 @@ def forward(self, inputs, training=True, n_sampling_times=1): if not training: samples = self.impute( observed_data, cond_mask, side_info, n_sampling_times - ) # (B,bz,K,L) - imputation = samples.mean(dim=1) # (B,K,L) - imputed_data = observed_data + imputation * (1 - gt_mask) - results["imputed_data"] = imputed_data.permute(0, 2, 1) # (B,L,K) + ) # (bz,n_sampling,K,L) + repeated_obs = observed_data.unsqueeze(1).repeat(1, n_sampling_times, 1, 1) + repeated_mask = gt_mask.unsqueeze(1).repeat(1, n_sampling_times, 1, 1) + imputed_data = repeated_obs + samples * (1 - repeated_mask) + + results["imputed_data"] = imputed_data.permute( + 0, 1, 3, 2 + ) # (bz,n_sampling,K,L) return results From 856706cf161ee57cd61d0068c98e23f4e4747ba8 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Thu, 7 Dec 2023 16:06:45 +0800 Subject: [PATCH 04/16] fix: update testing case; --- tests/global_test_config.py | 4 +++- tests/imputation/csdi.py | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/global_test_config.py b/tests/global_test_config.py index 6d307427..d9dd3b27 100644 --- a/tests/global_test_config.py +++ b/tests/global_test_config.py @@ -37,8 +37,10 @@ n_cuda_devices = torch.cuda.device_count() cuda_devices = [torch.device(i) for i in range(n_cuda_devices)] if n_cuda_devices > 1: - logger.info("❗️Detected multiple cuda devices, using all of them to run testing.") DEVICE = cuda_devices[np.random.randint(n_cuda_devices)] + logger.info( + f"❗️Detected multiple cuda devices, using one of them {DEVICE} to run testing." + ) else: # if having no multiple cuda devices, leave it as None to use the default device DEVICE = None diff --git a/tests/imputation/csdi.py b/tests/imputation/csdi.py index 3bfcd888..a5d3a73f 100644 --- a/tests/imputation/csdi.py +++ b/tests/imputation/csdi.py @@ -63,6 +63,7 @@ def test_0_fit(self): @pytest.mark.xdist_group(name="imputation-csdi") def test_1_impute(self): imputed_X = self.csdi.predict(TEST_SET)["imputation"] + imputed_X = imputed_X.mean(axis=1) # mean over sampling times assert not np.isnan( imputed_X ).any(), "Output still has missing values after running impute()." From 58a224a892feb954bd5bd687237550478b3a0ab2 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Thu, 7 Dec 2023 16:12:05 +0800 Subject: [PATCH 05/16] feat: use mse rather than mae as default to evaluate imputation models during validation; --- pypots/imputation/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pypots/imputation/base.py b/pypots/imputation/base.py index 87edd423..27e2c063 100644 --- a/pypots/imputation/base.py +++ b/pypots/imputation/base.py @@ -16,7 +16,7 @@ from ..base import BaseModel, BaseNNModel from ..utils.logging import logger -from ..utils.metrics import cal_mae +from ..utils.metrics import cal_mse try: import nni @@ -299,7 +299,7 @@ def _train_model( imputation_collector = torch.cat(imputation_collector) imputation_collector = imputation_collector.cpu().detach().numpy() - mean_val_loss = cal_mae( + mean_val_loss = cal_mse( imputation_collector, val_loader.dataset.data["X_intact"], val_loader.dataset.data["indicating_mask"], From dbfa00a045c4ef0f9201605c4adb428a0fc4fa2b Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Thu, 7 Dec 2023 16:13:17 +0800 Subject: [PATCH 06/16] fix: use updated tsdb dataset load func; --- pypots/data/load_specific_datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pypots/data/load_specific_datasets.py b/pypots/data/load_specific_datasets.py index c43864cd..9db37175 100644 --- a/pypots/data/load_specific_datasets.py +++ b/pypots/data/load_specific_datasets.py @@ -66,6 +66,6 @@ def load_specific_dataset(dataset_name: str, use_cache: bool = True) -> dict: f"https://github.com/WenjieDu/PyPOTS/issues" ) logger.info(f"Starting preprocessing {dataset_name}...") - data = tsdb.load_dataset(dataset_name, use_cache) + data = tsdb.load(dataset_name, use_cache) data = PREPROCESSING_FUNC[dataset_name](data) return data From 36e333703adb776cf93e79fa4474b4f28221b556 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Fri, 8 Dec 2023 13:47:34 +0800 Subject: [PATCH 07/16] feat: change the metric of CSDI validation stage back to the same as the original... using validation denoising loss; --- pypots/imputation/csdi/model.py | 22 +++++++--------------- pypots/imputation/csdi/modules/core.py | 4 ++-- 2 files changed, 9 insertions(+), 17 deletions(-) diff --git a/pypots/imputation/csdi/model.py b/pypots/imputation/csdi/model.py index 71b76918..17de692e 100644 --- a/pypots/imputation/csdi/model.py +++ b/pypots/imputation/csdi/model.py @@ -32,7 +32,6 @@ from ...optim.adam import Adam from ...optim.base import Optimizer from ...utils.logging import logger -from ...utils.metrics import cal_mse class CSDI(BaseNNImputer): @@ -245,23 +244,16 @@ def _train_model( if val_loader is not None: self.model.eval() - imputation_collector = [] + val_loss_collector = [] with torch.no_grad(): for idx, data in enumerate(val_loader): inputs = self._assemble_input_for_validating(data) - results = self.model.forward(inputs, training=False) - imputed_data = results["imputed_data"].mean(axis=1) - imputation_collector.append(imputed_data) - - imputation_collector = torch.cat(imputation_collector) - imputation_collector = imputation_collector.cpu().detach().numpy() - - mean_val_loss = cal_mse( - imputation_collector, - val_loader.dataset.data["X_intact"], - val_loader.dataset.data["indicating_mask"], - # the above val_loader.dataset.data is a dict containing the validation dataset - ) + results = self.model.forward( + inputs, training=False, n_sampling_times=0 + ) + val_loss_collector.append(results["loss"].item()) + + mean_val_loss = np.asarray(val_loss_collector).mean() # save validating loss logs into the tensorboard file for every epoch if in need if self.summary_writer is not None: diff --git a/pypots/imputation/csdi/modules/core.py b/pypots/imputation/csdi/modules/core.py index 4f0ab608..7c811539 100644 --- a/pypots/imputation/csdi/modules/core.py +++ b/pypots/imputation/csdi/modules/core.py @@ -259,7 +259,7 @@ def forward(self, inputs, training=True, n_sampling_times=1): loss = loss_func(observed_data, cond_mask, observed_mask, side_info, training) results = {"loss": loss} - if not training: + if not training and n_sampling_times > 0: samples = self.impute( observed_data, cond_mask, side_info, n_sampling_times ) # (bz,n_sampling,K,L) @@ -269,6 +269,6 @@ def forward(self, inputs, training=True, n_sampling_times=1): results["imputed_data"] = imputed_data.permute( 0, 1, 3, 2 - ) # (bz,n_sampling,K,L) + ) # (bz,n_sampling,L,K) return results From dee9acc25ef5f090082188ade5f6503cfa6cf753 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Fri, 8 Dec 2023 22:16:15 +0800 Subject: [PATCH 08/16] feat: add get_random_seed(); --- pypots/utils/random.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/pypots/utils/random.py b/pypots/utils/random.py index b3654752..9ec3af57 100644 --- a/pypots/utils/random.py +++ b/pypots/utils/random.py @@ -1,5 +1,5 @@ """ -Transformer model for time-series imputation. +PyPOTS util module about random seed setting. """ # Created by Wenjie Du @@ -7,6 +7,7 @@ import numpy as np import torch + from .logging import logger RANDOM_SEED = 2204 @@ -21,7 +22,19 @@ def set_random_seed(random_seed: int = RANDOM_SEED) -> None: The seed to be set for generating random numbers in PyPOTS. """ - - np.random.seed(RANDOM_SEED) + globals()["RANDOM_SEED"] = random_seed + np.random.seed(random_seed) torch.manual_seed(random_seed) logger.info(f"Have set the random seed as {random_seed} for numpy and pytorch.") + + +def get_random_seed() -> int: + """Get the random seed used in PyPOTS. + + Returns + ------- + random_seed : + The random seed used in PyPOTS. + + """ + return RANDOM_SEED From 8cd435f98c7726d1aa3f3517819d11dd24f5991b Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Sat, 9 Dec 2023 00:33:57 +0800 Subject: [PATCH 09/16] feat: add calc_quantile_crps() and rename metric func into calc_*; --- pypots/clustering/crli/modules/core.py | 6 +- pypots/clustering/vader/modules/core.py | 4 +- pypots/imputation/base.py | 4 +- pypots/imputation/brits/modules/core.py | 8 +- pypots/imputation/csdi/model.py | 4 +- pypots/imputation/mrnn/modules/core.py | 4 +- pypots/imputation/saits/model.py | 4 +- pypots/imputation/saits/modules/core.py | 4 +- pypots/imputation/transformer/modules/core.py | 6 +- pypots/utils/metrics/__init__.py | 65 ++++++++-- pypots/utils/metrics/classification.py | 55 ++++++-- pypots/utils/metrics/clustering.py | 90 ++++++++++--- pypots/utils/metrics/error.py | 119 +++++++++++++++--- tests/classification/brits.py | 4 +- tests/classification/grud.py | 4 +- tests/classification/raindrop.py | 4 +- tests/clustering/crli.py | 12 +- tests/clustering/vader.py | 8 +- tests/forecasting/bttf.py | 4 +- tests/imputation/brits.py | 4 +- tests/imputation/csdi.py | 9 +- tests/imputation/gpvae.py | 4 +- tests/imputation/locf.py | 14 +-- tests/imputation/mrnn.py | 4 +- tests/imputation/saits.py | 4 +- tests/imputation/timesnet.py | 4 +- tests/imputation/transformer.py | 4 +- tests/imputation/usgan.py | 4 +- tests/optim/adadelta.py | 4 +- tests/optim/adagrad.py | 4 +- tests/optim/adam.py | 4 +- tests/optim/adamw.py | 4 +- tests/optim/lr_schedulers.py | 16 +-- tests/optim/rmsprop.py | 4 +- tests/optim/sgd.py | 4 +- tests/utils/random.py | 12 +- 36 files changed, 377 insertions(+), 135 deletions(-) diff --git a/pypots/clustering/crli/modules/core.py b/pypots/clustering/crli/modules/core.py index a4c16a2a..8cbd45b0 100644 --- a/pypots/clustering/crli/modules/core.py +++ b/pypots/clustering/crli/modules/core.py @@ -18,7 +18,7 @@ from sklearn.cluster import KMeans from .submodules import Generator, Decoder, Discriminator -from ....utils.metrics import cal_mse +from ....utils.metrics import calc_mse class _CRLI(nn.Module): @@ -89,8 +89,8 @@ def forward( l_G = F.binary_cross_entropy_with_logits( inputs["discrimination"], 1 - missing_mask, weight=1 - missing_mask ) - l_pre = cal_mse(inputs["imputation_latent"], X, missing_mask) - l_rec = cal_mse(inputs["reconstruction"], X, missing_mask) + l_pre = calc_mse(inputs["imputation_latent"], X, missing_mask) + l_rec = calc_mse(inputs["reconstruction"], X, missing_mask) HTH = torch.matmul(inputs["fcn_latent"], inputs["fcn_latent"].permute(1, 0)) if ( diff --git a/pypots/clustering/vader/modules/core.py b/pypots/clustering/vader/modules/core.py index 8ff2f4ac..41c5019f 100644 --- a/pypots/clustering/vader/modules/core.py +++ b/pypots/clustering/vader/modules/core.py @@ -21,7 +21,7 @@ PeepholeLSTMCell, ImplicitImputation, ) -from ....utils.metrics import cal_mse +from ....utils.metrics import calc_mse class _VaDER(nn.Module): @@ -184,7 +184,7 @@ def forward( } # calculate the reconstruction loss - unscaled_reconstruction_loss = cal_mse(X_reconstructed, X, missing_mask) + unscaled_reconstruction_loss = calc_mse(X_reconstructed, X, missing_mask) reconstruction_loss = ( unscaled_reconstruction_loss * self.n_steps diff --git a/pypots/imputation/base.py b/pypots/imputation/base.py index 27e2c063..488d5d6f 100644 --- a/pypots/imputation/base.py +++ b/pypots/imputation/base.py @@ -16,7 +16,7 @@ from ..base import BaseModel, BaseNNModel from ..utils.logging import logger -from ..utils.metrics import cal_mse +from ..utils.metrics import calc_mse try: import nni @@ -299,7 +299,7 @@ def _train_model( imputation_collector = torch.cat(imputation_collector) imputation_collector = imputation_collector.cpu().detach().numpy() - mean_val_loss = cal_mse( + mean_val_loss = calc_mse( imputation_collector, val_loader.dataset.data["X_intact"], val_loader.dataset.data["indicating_mask"], diff --git a/pypots/imputation/brits/modules/core.py b/pypots/imputation/brits/modules/core.py index 83b48f95..689d5582 100644 --- a/pypots/imputation/brits/modules/core.py +++ b/pypots/imputation/brits/modules/core.py @@ -21,7 +21,7 @@ from .submodules import FeatureRegression from ....modules.rnn import TemporalDecay -from ....utils.metrics import cal_mae +from ....utils.metrics import calc_mae class RITS(nn.Module): @@ -150,17 +150,17 @@ def impute( hidden_states = hidden_states * gamma_h # decay hidden states x_h = self.hist_reg(hidden_states) - reconstruction_loss += cal_mae(x_h, x, m) + reconstruction_loss += calc_mae(x_h, x, m) x_c = m * x + (1 - m) * x_h z_h = self.feat_reg(x_c) - reconstruction_loss += cal_mae(z_h, x, m) + reconstruction_loss += calc_mae(z_h, x, m) alpha = torch.sigmoid(self.combining_weight(torch.cat([gamma_x, m], dim=1))) c_h = alpha * z_h + (1 - alpha) * x_h - reconstruction_loss += cal_mae(c_h, x, m) + reconstruction_loss += calc_mae(c_h, x, m) c_c = m * x + (1 - m) * c_h estimations.append(c_h.unsqueeze(dim=1)) diff --git a/pypots/imputation/csdi/model.py b/pypots/imputation/csdi/model.py index 71b76918..25a455fb 100644 --- a/pypots/imputation/csdi/model.py +++ b/pypots/imputation/csdi/model.py @@ -32,7 +32,7 @@ from ...optim.adam import Adam from ...optim.base import Optimizer from ...utils.logging import logger -from ...utils.metrics import cal_mse +from ...utils.metrics import calc_mse class CSDI(BaseNNImputer): @@ -256,7 +256,7 @@ def _train_model( imputation_collector = torch.cat(imputation_collector) imputation_collector = imputation_collector.cpu().detach().numpy() - mean_val_loss = cal_mse( + mean_val_loss = calc_mse( imputation_collector, val_loader.dataset.data["X_intact"], val_loader.dataset.data["indicating_mask"], diff --git a/pypots/imputation/mrnn/modules/core.py b/pypots/imputation/mrnn/modules/core.py index e4936ec8..ba0ba2cf 100644 --- a/pypots/imputation/mrnn/modules/core.py +++ b/pypots/imputation/mrnn/modules/core.py @@ -12,7 +12,7 @@ import torch.nn as nn from .submodules import FCN_Regression -from ....utils.metrics import cal_rmse +from ....utils.metrics import calc_rmse class _MRNN(nn.Module): @@ -74,7 +74,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: FCN_estimation = self.fcn_regression( x, m, RNN_imputed_data ) # FCN estimation is output estimation - reconstruction_loss += cal_rmse(FCN_estimation, x, m) + cal_rmse( + reconstruction_loss += calc_rmse(FCN_estimation, x, m) + calc_rmse( RNN_estimation, x, m ) estimations.append(FCN_estimation.unsqueeze(dim=1)) diff --git a/pypots/imputation/saits/model.py b/pypots/imputation/saits/model.py index 378062a4..99952645 100644 --- a/pypots/imputation/saits/model.py +++ b/pypots/imputation/saits/model.py @@ -27,7 +27,7 @@ from ...optim.adam import Adam from ...optim.base import Optimizer from ...utils.logging import logger -from ...utils.metrics import cal_mae +from ...utils.metrics import calc_mae class SAITS(BaseNNImputer): @@ -145,7 +145,7 @@ def __init__( batch_size: int = 32, epochs: int = 100, patience: Optional[int] = None, - customized_loss_func: Callable = cal_mae, + customized_loss_func: Callable = calc_mae, optimizer: Optional[Optimizer] = Adam(), num_workers: int = 0, device: Optional[Union[str, torch.device, list]] = None, diff --git a/pypots/imputation/saits/modules/core.py b/pypots/imputation/saits/modules/core.py index eb062709..51c2dfd5 100644 --- a/pypots/imputation/saits/modules/core.py +++ b/pypots/imputation/saits/modules/core.py @@ -20,7 +20,7 @@ import torch.nn.functional as F from ....modules.transformer import EncoderLayer, PositionalEncoding -from ....utils.metrics import cal_mae +from ....utils.metrics import calc_mae class _SAITS(nn.Module): @@ -39,7 +39,7 @@ def __init__( diagonal_attention_mask: bool = True, ORT_weight: float = 1, MIT_weight: float = 1, - customized_loss_func: Callable = cal_mae, + customized_loss_func: Callable = calc_mae, ): super().__init__() self.n_layers = n_layers diff --git a/pypots/imputation/transformer/modules/core.py b/pypots/imputation/transformer/modules/core.py index 34750da8..7044358d 100644 --- a/pypots/imputation/transformer/modules/core.py +++ b/pypots/imputation/transformer/modules/core.py @@ -19,7 +19,7 @@ import torch.nn as nn from ....modules.transformer import EncoderLayer, PositionalEncoding -from ....utils.metrics import cal_mae +from ....utils.metrics import calc_mae class _TransformerEncoder(nn.Module): @@ -89,8 +89,8 @@ def forward(self, inputs: dict, training: bool = True) -> dict: # if in training mode, return results with losses if training: - ORT_loss = cal_mae(learned_presentation, X, masks) - MIT_loss = cal_mae( + ORT_loss = calc_mae(learned_presentation, X, masks) + MIT_loss = calc_mae( learned_presentation, inputs["X_intact"], inputs["indicating_mask"] ) results["ORT_loss"] = ORT_loss diff --git a/pypots/utils/metrics/__init__.py b/pypots/utils/metrics/__init__.py index 9856a49b..22e155de 100644 --- a/pypots/utils/metrics/__init__.py +++ b/pypots/utils/metrics/__init__.py @@ -6,6 +6,12 @@ # License: BSD-3-Clause from .classification import ( + calc_binary_classification_metrics, + calc_precision_recall_f1, + calc_pr_auc, + calc_roc_auc, + calc_acc, + # deprecated cal_binary_classification_metrics, cal_precision_recall_f1, cal_pr_auc, @@ -13,6 +19,16 @@ cal_acc, ) from .clustering import ( + calc_rand_index, + calc_adjusted_rand_index, + calc_cluster_purity, + calc_nmi, + calc_chs, + calc_dbs, + calc_silhouette, + calc_internal_cluster_validation_metrics, + calc_external_cluster_validation_metrics, + # deprecated cal_rand_index, cal_adjusted_rand_index, cal_cluster_purity, @@ -23,21 +39,49 @@ cal_internal_cluster_validation_metrics, cal_external_cluster_validation_metrics, ) -from .error import cal_mae, cal_mse, cal_rmse, cal_mre +from .error import ( + calc_mae, + calc_mse, + calc_rmse, + calc_mre, + calc_quantile_crps, + calc_quantile_crps_sum, + # deprecated + cal_mae, + cal_mse, + cal_rmse, + cal_mre, +) __all__ = [ # error + "calc_mae", + "calc_mse", + "calc_rmse", + "calc_mre", + "calc_quantile_crps", + "calc_quantile_crps_sum", + # classification + "calc_binary_classification_metrics", + "calc_precision_recall_f1", + "calc_pr_auc", + "calc_roc_auc", + "calc_acc", + # clustering + "calc_rand_index", + "calc_adjusted_rand_index", + "calc_cluster_purity", + "calc_nmi", + "calc_chs", + "calc_dbs", + "calc_silhouette", + "calc_internal_cluster_validation_metrics", + "calc_external_cluster_validation_metrics", + # deprecated "cal_mae", "cal_mse", "cal_rmse", "cal_mre", - # classification - "cal_binary_classification_metrics", - "cal_precision_recall_f1", - "cal_pr_auc", - "cal_roc_auc", - "cal_acc", - # clustering "cal_rand_index", "cal_adjusted_rand_index", "cal_cluster_purity", @@ -47,4 +91,9 @@ "cal_silhouette", "cal_internal_cluster_validation_metrics", "cal_external_cluster_validation_metrics", + "cal_binary_classification_metrics", + "cal_precision_recall_f1", + "cal_pr_auc", + "cal_roc_auc", + "cal_acc", ] diff --git a/pypots/utils/metrics/classification.py b/pypots/utils/metrics/classification.py index 8d45e22f..d9867f6f 100644 --- a/pypots/utils/metrics/classification.py +++ b/pypots/utils/metrics/classification.py @@ -10,8 +10,10 @@ import numpy as np from sklearn import metrics +from ..logging import logger -def cal_binary_classification_metrics( + +def calc_binary_classification_metrics( prob_predictions: np.ndarray, targets: np.ndarray, pos_label: int = 1, @@ -85,7 +87,7 @@ def cal_binary_classification_metrics( raise f"predictions dimensions should be 1 or 2, but got predictions.shape: {prob_predictions.shape}" # accuracy score doesn't have to be of binary classification - acc_score = cal_acc(prediction_categories, targets) + acc_score = calc_acc(prediction_categories, targets) # turn targets into binary targets mask_val = -1 if pos_label == 0 else 0 @@ -93,13 +95,13 @@ def cal_binary_classification_metrics( binary_targets = np.copy(targets) binary_targets[~mask] = mask_val - precision, recall, f1 = cal_precision_recall_f1( + precision, recall, f1 = calc_precision_recall_f1( binary_prediction_categories, binary_targets, pos_label ) - pr_auc, precisions, recalls, _ = cal_pr_auc( + pr_auc, precisions, recalls, _ = calc_pr_auc( binary_predictions, binary_targets, pos_label ) - ROC_AUC, fprs, tprs, _ = cal_roc_auc(binary_predictions, binary_targets, pos_label) + ROC_AUC, fprs, tprs, _ = calc_roc_auc(binary_predictions, binary_targets, pos_label) PR_AUC = metrics.auc(recalls, precisions) classification_metrics = { "predictions": prediction_categories, @@ -117,7 +119,7 @@ def cal_binary_classification_metrics( return classification_metrics -def cal_precision_recall_f1( +def calc_precision_recall_f1( prob_predictions: np.ndarray, targets: np.ndarray, pos_label: int = 1, @@ -154,7 +156,7 @@ def cal_precision_recall_f1( return precision, recall, f1 -def cal_pr_auc( +def calc_pr_auc( prob_predictions: np.ndarray, targets: np.ndarray, pos_label: int = 1, @@ -195,7 +197,7 @@ def cal_pr_auc( return pr_auc, precisions, recalls, thresholds -def cal_roc_auc( +def calc_roc_auc( prob_predictions: np.ndarray, targets: np.ndarray, pos_label: int = 1, @@ -235,7 +237,7 @@ def cal_roc_auc( return roc_auc, fprs, tprs, thresholds -def cal_acc(class_predictions: np.ndarray, targets: np.ndarray) -> float: +def calc_acc(class_predictions: np.ndarray, targets: np.ndarray) -> float: """Calculate accuracy score of model predictions. Parameters @@ -254,3 +256,38 @@ def cal_acc(class_predictions: np.ndarray, targets: np.ndarray) -> float: """ acc_score = metrics.accuracy_score(targets, class_predictions) return acc_score + + +######################################################################################################################## +# Deprecated functions +######################################################################################################################## + + +def cal_binary_classification_metrics(**kwargs): + logger.warning( + "🚨 cal_binary_classification_metrics() is deprecated, " + "use calc_binary_classification_metrics() instead." + ) + return calc_binary_classification_metrics(**kwargs) + + +def cal_precision_recall_f1(**kwargs): + logger.warning( + "🚨 cal_precision_recall_f1() is deprecated, use calc_precision_recall_f1() instead." + ) + return calc_precision_recall_f1(**kwargs) + + +def cal_pr_auc(**kwargs): + logger.warning("🚨 cal_pr_auc() is deprecated, use calc_pr_auc() instead.") + return calc_pr_auc(**kwargs) + + +def cal_roc_auc(**kwargs): + logger.warning("🚨 cal_roc_auc() is deprecated, use calc_roc_auc() instead.") + return calc_roc_auc(**kwargs) + + +def cal_acc(**kwargs): + logger.warning("🚨 cal_acc() is deprecated, use calc_acc() instead.") + return calc_acc(**kwargs) diff --git a/pypots/utils/metrics/clustering.py b/pypots/utils/metrics/clustering.py index bf05195d..f417d295 100644 --- a/pypots/utils/metrics/clustering.py +++ b/pypots/utils/metrics/clustering.py @@ -8,8 +8,10 @@ import numpy as np from sklearn import metrics +from ..logging import logger -def cal_rand_index( + +def calc_rand_index( class_predictions: np.ndarray, targets: np.ndarray, ) -> float: @@ -60,7 +62,7 @@ def cal_rand_index( return RI -def cal_adjusted_rand_index( +def calc_adjusted_rand_index( class_predictions: np.ndarray, targets: np.ndarray, ) -> float: @@ -96,7 +98,7 @@ def cal_adjusted_rand_index( return aRI -def cal_nmi( +def calc_nmi( class_predictions: np.ndarray, targets: np.ndarray, ) -> float: @@ -121,7 +123,7 @@ def cal_nmi( return NMI -def cal_cluster_purity( +def calc_cluster_purity( class_predictions: np.ndarray, targets: np.ndarray, ) -> float: @@ -152,7 +154,7 @@ def cal_cluster_purity( return cluster_purity -def cal_external_cluster_validation_metrics(class_predictions, targets): +def calc_external_cluster_validation_metrics(class_predictions, targets): """Computer all external cluster validation metrics available in PyPOTS and return as a dictionary. Parameters @@ -169,10 +171,10 @@ def cal_external_cluster_validation_metrics(class_predictions, targets): A dictionary contains all external cluster validation metrics available in PyPOTS. """ - ri = cal_rand_index(class_predictions, targets) - ari = cal_adjusted_rand_index(class_predictions, targets) - nmi = cal_nmi(class_predictions, targets) - cp = cal_cluster_purity(class_predictions, targets) + ri = calc_rand_index(class_predictions, targets) + ari = calc_adjusted_rand_index(class_predictions, targets) + nmi = calc_nmi(class_predictions, targets) + cp = calc_cluster_purity(class_predictions, targets) external_cluster_validation_metrics = { "rand_index": ri, @@ -183,7 +185,7 @@ def cal_external_cluster_validation_metrics(class_predictions, targets): return external_cluster_validation_metrics -def cal_silhouette(X: np.ndarray, predicted_labels: np.ndarray) -> float: +def calc_silhouette(X: np.ndarray, predicted_labels: np.ndarray) -> float: """Compute the mean Silhouette Coefficient of all samples. Parameters @@ -214,7 +216,7 @@ def cal_silhouette(X: np.ndarray, predicted_labels: np.ndarray) -> float: return silhouette_score -def cal_chs(X: np.ndarray, predicted_labels: np.ndarray) -> float: +def calc_chs(X: np.ndarray, predicted_labels: np.ndarray) -> float: """Compute the Calinski and Harabasz score (also known as the Variance Ratio Criterion). X : array-like of shape (n_samples_a, n_features) @@ -239,7 +241,7 @@ def cal_chs(X: np.ndarray, predicted_labels: np.ndarray) -> float: return calinski_harabasz_score -def cal_dbs(X: np.ndarray, predicted_labels: np.ndarray) -> float: +def calc_dbs(X: np.ndarray, predicted_labels: np.ndarray) -> float: """Compute the Davies-Bouldin score. Parameters @@ -268,7 +270,7 @@ def cal_dbs(X: np.ndarray, predicted_labels: np.ndarray) -> float: return davies_bouldin_score -def cal_internal_cluster_validation_metrics(X, predicted_labels): +def calc_internal_cluster_validation_metrics(X, predicted_labels): """Computer all internal cluster validation metrics available in PyPOTS and return as a dictionary. Parameters @@ -285,9 +287,9 @@ def cal_internal_cluster_validation_metrics(X, predicted_labels): A dictionary contains all internal cluster validation metrics available in PyPOTS. """ - silhouette_score = cal_silhouette(X, predicted_labels) - calinski_harabasz_score = cal_chs(X, predicted_labels) - davies_bouldin_score = cal_dbs(X, predicted_labels) + silhouette_score = calc_silhouette(X, predicted_labels) + calinski_harabasz_score = calc_chs(X, predicted_labels) + davies_bouldin_score = calc_dbs(X, predicted_labels) internal_cluster_validation_metrics = { "silhouette_score": silhouette_score, @@ -295,3 +297,59 @@ def cal_internal_cluster_validation_metrics(X, predicted_labels): "davies_bouldin_score": davies_bouldin_score, } return internal_cluster_validation_metrics + + +######################################################################################################################## +# Deprecated functions +######################################################################################################################## + + +def cal_rand_index(**kwargs): + logger.warning("🚨 Deprecated function, please use `calc_rand_index` instead.") + return calc_rand_index(**kwargs) + + +def cal_adjusted_rand_index(**kwargs): + logger.warning( + "🚨 Deprecated function, please use `calc_adjusted_rand_index` instead." + ) + return calc_adjusted_rand_index(**kwargs) + + +def cal_nmi(**kwargs): + logger.warning("🚨 Deprecated function, please use `calc_nmi` instead.") + return calc_nmi(**kwargs) + + +def cal_cluster_purity(**kwargs): + logger.warning("🚨 Deprecated function, please use `calc_cluster_purity` instead.") + return calc_cluster_purity(**kwargs) + + +def cal_external_cluster_validation_metrics(**kwargs): + logger.warning( + "🚨 Deprecated function, please use `calc_external_cluster_validation_metrics` instead." + ) + return calc_external_cluster_validation_metrics(**kwargs) + + +def cal_silhouette(**kwargs): + logger.warning("🚨 Deprecated function, please use `calc_silhouette` instead.") + return calc_silhouette(**kwargs) + + +def cal_chs(**kwargs): + logger.warning("🚨 Deprecated function, please use `calc_chs` instead.") + return calc_chs(**kwargs) + + +def cal_dbs(**kwargs): + logger.warning("🚨 Deprecated function, please use `calc_dbs` instead.") + return calc_dbs(**kwargs) + + +def cal_internal_cluster_validation_metrics(**kwargs): + logger.warning( + "🚨 Deprecated function, please use `calc_internal_cluster_validation_metrics` instead." + ) + return calc_internal_cluster_validation_metrics(**kwargs) diff --git a/pypots/utils/metrics/error.py b/pypots/utils/metrics/error.py index fc8133d5..8cafcb61 100644 --- a/pypots/utils/metrics/error.py +++ b/pypots/utils/metrics/error.py @@ -10,8 +10,10 @@ import numpy as np import torch +from ..logging import logger -def cal_mae( + +def calc_mae( predictions: Union[np.ndarray, torch.Tensor, list], targets: Union[np.ndarray, torch.Tensor, list], masks: Optional[Union[np.ndarray, torch.Tensor, list]] = None, @@ -36,10 +38,10 @@ def cal_mae( -------- >>> import numpy as np - >>> from pypots.utils.metrics import cal_mae + >>> from pypots.utils.metrics import calc_mae >>> targets = np.array([1, 2, 3, 4, 5]) >>> predictions = np.array([1, 2, 1, 4, 6]) - >>> mae = cal_mae(predictions, targets) + >>> mae = calc_mae(predictions, targets) mae = 0.6 here, the error is from the 3rd and 5th elements and is :math:`|3-1|+|5-6|=3`, so the result is 3/5=0.6. @@ -47,7 +49,7 @@ def cal_mae( we can use ``masks`` to filter out them: >>> masks = np.array([0, 0, 0, 1, 1]) - >>> mae = cal_mae(predictions, targets, masks) + >>> mae = calc_mae(predictions, targets, masks) mae = 0.5 here, the first three elements are ignored, the error is from the 5th element and is :math:`|5-6|=1`, so the result is 1/2=0.5. @@ -66,7 +68,7 @@ def cal_mae( return lib.mean(lib.abs(predictions - targets)) -def cal_mse( +def calc_mse( predictions: Union[np.ndarray, torch.Tensor, list], targets: Union[np.ndarray, torch.Tensor, list], masks: Optional[Union[np.ndarray, torch.Tensor, list]] = None, @@ -91,10 +93,10 @@ def cal_mse( -------- >>> import numpy as np - >>> from pypots.utils.metrics import cal_mse + >>> from pypots.utils.metrics import calc_mse >>> targets = np.array([1, 2, 3, 4, 5]) >>> predictions = np.array([1, 2, 1, 4, 6]) - >>> mse = cal_mse(predictions, targets) + >>> mse = calc_mse(predictions, targets) mse = 1 here, the error is from the 3rd and 5th elements and is :math:`|3-1|^2+|5-6|^2=5`, so the result is 5/5=1. @@ -102,7 +104,7 @@ def cal_mse( we can use ``masks`` to filter out them: >>> masks = np.array([0, 0, 0, 1, 1]) - >>> mse = cal_mse(predictions, targets, masks) + >>> mse = calc_mse(predictions, targets, masks) mse = 0.5 here, the first three elements are ignored, the error is from the 5th element and is :math:`|5-6|^2=1`, so the result is 1/2=0.5. @@ -122,7 +124,7 @@ def cal_mse( return lib.mean(lib.square(predictions - targets)) -def cal_rmse( +def calc_rmse( predictions: Union[np.ndarray, torch.Tensor, list], targets: Union[np.ndarray, torch.Tensor, list], masks: Optional[Union[np.ndarray, torch.Tensor, list]] = None, @@ -147,10 +149,10 @@ def cal_rmse( -------- >>> import numpy as np - >>> from pypots.utils.metrics import cal_rmse + >>> from pypots.utils.metrics import calc_rmse >>> targets = np.array([1, 2, 3, 4, 5]) >>> predictions = np.array([1, 2, 1, 4, 6]) - >>> rmse = cal_rmse(predictions, targets) + >>> rmse = calc_rmse(predictions, targets) rmse = 1 here, the error is from the 3rd and 5th elements and is :math:`|3-1|^2+|5-6|^2=5`, so the result is :math:`\\sqrt{5/5}=1`. @@ -159,7 +161,7 @@ def cal_rmse( we can use ``masks`` to filter out them: >>> masks = np.array([0, 0, 0, 1, 1]) - >>> rmse = cal_rmse(predictions, targets, masks) + >>> rmse = calc_rmse(predictions, targets, masks) rmse = 0.707 here, the first three elements are ignored, the error is from the 5th element and is :math:`|5-6|^2=1`, so the result is :math:`\\sqrt{1/2}=0.5`. @@ -170,10 +172,10 @@ def cal_rmse( f"type(inputs)={type(predictions)}, type(target)={type(targets)}" ) lib = np if isinstance(predictions, np.ndarray) else torch - return lib.sqrt(cal_mse(predictions, targets, masks)) + return lib.sqrt(calc_mse(predictions, targets, masks)) -def cal_mre( +def calc_mre( predictions: Union[np.ndarray, torch.Tensor, list], targets: Union[np.ndarray, torch.Tensor, list], masks: Optional[Union[np.ndarray, torch.Tensor, list]] = None, @@ -198,10 +200,10 @@ def cal_mre( -------- >>> import numpy as np - >>> from pypots.utils.metrics import cal_mre + >>> from pypots.utils.metrics import calc_mre >>> targets = np.array([1, 2, 3, 4, 5]) >>> predictions = np.array([1, 2, 1, 4, 6]) - >>> mre = cal_mre(predictions, targets) + >>> mre = calc_mre(predictions, targets) mre = 0.2 here, the error is from the 3rd and 5th elements and is :math:`|3-1|+|5-6|=3`, so the result is :math:`\\sqrt{3/(1+2+3+4+5)}=1`. @@ -210,7 +212,7 @@ def cal_mre( we can use ``masks`` to filter out them: >>> masks = np.array([0, 0, 0, 1, 1]) - >>> mre = cal_mre(predictions, targets, masks) + >>> mre = calc_mre(predictions, targets, masks) mre = 0.111 here, the first three elements are ignored, the error is from the 5th element and is :math:`|5-6|^2=1`, so the result is :math:`\\sqrt{1/2}=0.5`. @@ -229,3 +231,86 @@ def cal_mre( return lib.sum(lib.abs(predictions - targets)) / ( lib.sum(lib.abs(targets)) + 1e-12 ) + + +def calc_quantile_loss(predictions, targets, q: float, eval_points) -> float: + quantile_loss = 2 * torch.sum( + torch.abs( + (predictions - targets) * eval_points * ((targets <= predictions) * 1.0 - q) + ) + ) + return quantile_loss + + +def calc_quantile_crps(predictions, targets, eval_points, mean_scaler=0, scaler=1): + """Continuous rank probability score for distributional predictions.""" + if isinstance(predictions, np.ndarray): + predictions = torch.from_numpy(predictions) + if isinstance(targets, np.ndarray): + targets = torch.from_numpy(targets) + if isinstance(eval_points, np.ndarray): + eval_points = torch.from_numpy(eval_points) + + targets = targets * scaler + mean_scaler + predictions = predictions * scaler + mean_scaler + + quantiles = np.arange(0.05, 1.0, 0.05) + denominator = torch.sum(torch.abs(targets * eval_points)) + CRPS = 0 + for i in range(len(quantiles)): + q_pred = [] + for j in range(len(predictions)): + q_pred.append(torch.quantile(predictions[j : j + 1], quantiles[i], dim=1)) + q_pred = torch.cat(q_pred, 0) + q_loss = calc_quantile_loss(targets, q_pred, quantiles[i], eval_points) + CRPS += q_loss / denominator + return CRPS.item() / len(quantiles) + + +def calc_quantile_crps_sum(predictions, targets, eval_points, mean_scaler=0, scaler=1): + """Continuous rank probability score for distributional predictions.""" + if isinstance(predictions, np.ndarray): + predictions = torch.from_numpy(predictions) + if isinstance(targets, np.ndarray): + targets = torch.from_numpy(targets) + if isinstance(eval_points, np.ndarray): + eval_points = torch.from_numpy(eval_points) + + eval_points = eval_points.mean(-1) + targets = targets * scaler + mean_scaler + targets = targets.sum(-1) + predictions = predictions * scaler + mean_scaler + + quantiles = np.arange(0.05, 1.0, 0.05) + denominator = torch.sum(torch.abs(targets * eval_points)) + CRPS = 0 + for i in range(len(quantiles)): + q_pred = torch.quantile(predictions.sum(-1), quantiles[i], dim=1) + q_loss = calc_quantile_loss(targets, q_pred, quantiles[i], eval_points) + CRPS += q_loss / denominator + return CRPS.item() / len(quantiles) + + +######################################################################################################################## +# Deprecated functions +######################################################################################################################## + + +def cal_mae(**kwargs): + logger.warning("🚨 cal_mae() is deprecated, use calc_mae() instead.") + return calc_mae(**kwargs) + + +def cal_rmse(**kwargs): + logger.warning("🚨 cal_rmse() is deprecated, use calc_rmse() instead.") + return calc_rmse(**kwargs) + + +def cal_mse(**kwargs): + logger.warning("🚨 cal_mse() is deprecated, use calc_mse() instead.") + return calc_mse(**kwargs) + + +def cal_mre(**kwargs): + logger.warning("🚨 cal_mre() is deprecated, use calc_mre() instead.") + return calc_mre(**kwargs) diff --git a/tests/classification/brits.py b/tests/classification/brits.py index 78e8c042..c7815f5c 100644 --- a/tests/classification/brits.py +++ b/tests/classification/brits.py @@ -13,7 +13,7 @@ from pypots.classification import BRITS from pypots.optim import Adam from pypots.utils.logging import logger -from pypots.utils.metrics import cal_binary_classification_metrics +from pypots.utils.metrics import calc_binary_classification_metrics from tests.classification.config import ( EPOCHS, TRAIN_SET, @@ -58,7 +58,7 @@ def test_0_fit(self): @pytest.mark.xdist_group(name="classification-brits") def test_1_classify(self): predictions = self.brits.classify(TEST_SET) - metrics = cal_binary_classification_metrics(predictions, DATA["test_y"]) + metrics = calc_binary_classification_metrics(predictions, DATA["test_y"]) logger.info( f'ROC_AUC: {metrics["roc_auc"]}, \n' f'PR_AUC: {metrics["pr_auc"]},\n' diff --git a/tests/classification/grud.py b/tests/classification/grud.py index abc94f63..37bad931 100644 --- a/tests/classification/grud.py +++ b/tests/classification/grud.py @@ -13,7 +13,7 @@ from pypots.classification import GRUD from pypots.optim import Adam from pypots.utils.logging import logger -from pypots.utils.metrics import cal_binary_classification_metrics +from pypots.utils.metrics import calc_binary_classification_metrics from tests.classification.config import ( EPOCHS, TRAIN_SET, @@ -57,7 +57,7 @@ def test_0_fit(self): @pytest.mark.xdist_group(name="classification-grud") def test_1_classify(self): predictions = self.grud.classify(TEST_SET) - metrics = cal_binary_classification_metrics(predictions, DATA["test_y"]) + metrics = calc_binary_classification_metrics(predictions, DATA["test_y"]) logger.info( f'ROC_AUC: {metrics["roc_auc"]}, \n' f'PR_AUC: {metrics["pr_auc"]},\n' diff --git a/tests/classification/raindrop.py b/tests/classification/raindrop.py index a7c42234..967f73ec 100644 --- a/tests/classification/raindrop.py +++ b/tests/classification/raindrop.py @@ -12,7 +12,7 @@ from pypots.classification import Raindrop from pypots.utils.logging import logger -from pypots.utils.metrics import cal_binary_classification_metrics +from pypots.utils.metrics import calc_binary_classification_metrics from tests.classification.config import ( EPOCHS, TRAIN_SET, @@ -60,7 +60,7 @@ def test_0_fit(self): @pytest.mark.xdist_group(name="classification-raindrop") def test_1_classify(self): predictions = self.raindrop.classify(TEST_SET) - metrics = cal_binary_classification_metrics(predictions, DATA["test_y"]) + metrics = calc_binary_classification_metrics(predictions, DATA["test_y"]) logger.info( f'ROC_AUC: {metrics["roc_auc"]}, \n' f'PR_AUC: {metrics["pr_auc"]},\n' diff --git a/tests/clustering/crli.py b/tests/clustering/crli.py index 2b849b79..63960619 100644 --- a/tests/clustering/crli.py +++ b/tests/clustering/crli.py @@ -15,8 +15,8 @@ from pypots.optim import Adam from pypots.utils.logging import logger from pypots.utils.metrics import ( - cal_external_cluster_validation_metrics, - cal_internal_cluster_validation_metrics, + calc_external_cluster_validation_metrics, + calc_internal_cluster_validation_metrics, ) from tests.clustering.config import ( EPOCHS, @@ -125,10 +125,10 @@ def test_1_parameters(self): def test_2_cluster(self): # GRU cell clustering_results = self.crli_gru.predict(TEST_SET, return_latent_vars=True) - external_metrics = cal_external_cluster_validation_metrics( + external_metrics = calc_external_cluster_validation_metrics( clustering_results["clustering"], DATA["test_y"] ) - internal_metrics = cal_internal_cluster_validation_metrics( + internal_metrics = calc_internal_cluster_validation_metrics( clustering_results["latent_vars"]["clustering_latent"], DATA["test_y"] ) logger.info(f"CRLI-GRU: {external_metrics}") @@ -136,10 +136,10 @@ def test_2_cluster(self): # LSTM cell clustering_results = self.crli_lstm.predict(TEST_SET, return_latent_vars=True) - external_metrics = cal_external_cluster_validation_metrics( + external_metrics = calc_external_cluster_validation_metrics( clustering_results["clustering"], DATA["test_y"] ) - internal_metrics = cal_internal_cluster_validation_metrics( + internal_metrics = calc_internal_cluster_validation_metrics( clustering_results["latent_vars"]["clustering_latent"], DATA["test_y"] ) logger.info(f"CRLI-LSTM: {external_metrics}") diff --git a/tests/clustering/vader.py b/tests/clustering/vader.py index cbdae092..d5143367 100644 --- a/tests/clustering/vader.py +++ b/tests/clustering/vader.py @@ -16,8 +16,8 @@ from pypots.optim import Adam from pypots.utils.logging import logger from pypots.utils.metrics import ( - cal_external_cluster_validation_metrics, - cal_internal_cluster_validation_metrics, + calc_external_cluster_validation_metrics, + calc_internal_cluster_validation_metrics, ) from tests.clustering.config import ( EPOCHS, @@ -65,10 +65,10 @@ def test_0_fit(self): def test_1_cluster(self): try: clustering_results = self.vader.predict(TEST_SET, return_latent_vars=True) - external_metrics = cal_external_cluster_validation_metrics( + external_metrics = calc_external_cluster_validation_metrics( clustering_results["clustering"], DATA["test_y"] ) - internal_metrics = cal_internal_cluster_validation_metrics( + internal_metrics = calc_internal_cluster_validation_metrics( clustering_results["latent_vars"]["z"], DATA["test_y"] ) logger.info(f"{external_metrics}") diff --git a/tests/forecasting/bttf.py b/tests/forecasting/bttf.py index 1ced03b0..1483e7d7 100644 --- a/tests/forecasting/bttf.py +++ b/tests/forecasting/bttf.py @@ -11,7 +11,7 @@ from pypots.forecasting import BTTF from pypots.utils.logging import logger -from pypots.utils.metrics import cal_mae +from pypots.utils.metrics import calc_mae from tests.forecasting.config import ( TEST_SET, TEST_SET_INTACT, @@ -38,7 +38,7 @@ class TestBTTF(unittest.TestCase): @pytest.mark.xdist_group(name="forecasting-bttf") def test_0_forecasting(self): predictions = self.bttf.forecast(TEST_SET) - mae = cal_mae(predictions, TEST_SET_INTACT["X"][:, -N_PRED_STEP:]) + mae = calc_mae(predictions, TEST_SET_INTACT["X"][:, -N_PRED_STEP:]) logger.info(f"prediction MAE: {mae}") diff --git a/tests/imputation/brits.py b/tests/imputation/brits.py index 69ea9613..e5eb2cb7 100644 --- a/tests/imputation/brits.py +++ b/tests/imputation/brits.py @@ -15,7 +15,7 @@ from pypots.imputation import BRITS from pypots.optim import Adam from pypots.utils.logging import logger -from pypots.utils.metrics import cal_mae +from pypots.utils.metrics import calc_mae from tests.global_test_config import ( DATA, DEVICE, @@ -61,7 +61,7 @@ def test_1_impute(self): assert not np.isnan( imputed_X ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( + test_MAE = calc_mae( imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] ) logger.info(f"BRITS test_MAE: {test_MAE}") diff --git a/tests/imputation/csdi.py b/tests/imputation/csdi.py index a5d3a73f..0ccf1222 100644 --- a/tests/imputation/csdi.py +++ b/tests/imputation/csdi.py @@ -15,7 +15,7 @@ from pypots.imputation import CSDI from pypots.optim import Adam from pypots.utils.logging import logger -from pypots.utils.metrics import cal_mae +from pypots.utils.metrics import calc_mae, calc_quantile_crps from tests.global_test_config import ( DATA, DEVICE, @@ -63,14 +63,17 @@ def test_0_fit(self): @pytest.mark.xdist_group(name="imputation-csdi") def test_1_impute(self): imputed_X = self.csdi.predict(TEST_SET)["imputation"] + test_CRPS = calc_quantile_crps( + imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] + ) imputed_X = imputed_X.mean(axis=1) # mean over sampling times assert not np.isnan( imputed_X ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( + test_MAE = calc_mae( imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] ) - logger.info(f"CSDI test_MAE: {test_MAE}") + logger.info(f"CSDI test_MAE: {test_MAE}, test_CRPS: {test_CRPS}") @pytest.mark.xdist_group(name="imputation-csdi") def test_2_parameters(self): diff --git a/tests/imputation/gpvae.py b/tests/imputation/gpvae.py index d2e45f31..b94bff37 100644 --- a/tests/imputation/gpvae.py +++ b/tests/imputation/gpvae.py @@ -15,7 +15,7 @@ from pypots.imputation import GPVAE from pypots.optim import Adam from pypots.utils.logging import logger -from pypots.utils.metrics import cal_mae +from pypots.utils.metrics import calc_mae from tests.global_test_config import ( DATA, DEVICE, @@ -61,7 +61,7 @@ def test_1_impute(self): assert not np.isnan( imputed_X ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( + test_MAE = calc_mae( imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] ) logger.info(f"GP-VAE test_MAE: {test_MAE}") diff --git a/tests/imputation/locf.py b/tests/imputation/locf.py index 18f7ed68..b43b7414 100644 --- a/tests/imputation/locf.py +++ b/tests/imputation/locf.py @@ -14,7 +14,7 @@ from pypots.imputation import LOCF from pypots.utils.logging import logger -from pypots.utils.metrics import cal_mae +from pypots.utils.metrics import calc_mae from tests.global_test_config import ( DATA, ) @@ -37,7 +37,7 @@ def test_0_impute(self): assert not np.isnan( test_X_imputed_zero ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( + test_MAE = calc_mae( test_X_imputed_zero, DATA["test_X_intact"], DATA["test_X_indicating_mask"] ) logger.info(f"LOCF (zero) test_MAE: {test_MAE}") @@ -46,7 +46,7 @@ def test_0_impute(self): assert not np.isnan( test_X_imputed_backward ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( + test_MAE = calc_mae( test_X_imputed_backward, DATA["test_X_intact"], DATA["test_X_indicating_mask"], @@ -57,7 +57,7 @@ def test_0_impute(self): assert not np.isnan( test_X_imputed_mean ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( + test_MAE = calc_mae( test_X_imputed_mean, DATA["test_X_intact"], DATA["test_X_indicating_mask"], @@ -80,14 +80,14 @@ def test_0_impute(self): assert not torch.isnan( test_X_imputed_zero ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae(test_X_imputed_zero, test_X_intact, test_X_indicating_mask) + test_MAE = calc_mae(test_X_imputed_zero, test_X_intact, test_X_indicating_mask) logger.info(f"LOCF (zero) test_MAE: {test_MAE}") test_X_imputed_backward = self.locf_backward.predict({"X": X})["imputation"] assert not torch.isnan( test_X_imputed_backward ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( + test_MAE = calc_mae( test_X_imputed_backward, test_X_intact, test_X_indicating_mask, @@ -98,7 +98,7 @@ def test_0_impute(self): assert not torch.isnan( test_X_imputed_mean ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( + test_MAE = calc_mae( test_X_imputed_mean, test_X_intact, test_X_indicating_mask, diff --git a/tests/imputation/mrnn.py b/tests/imputation/mrnn.py index ae28d0eb..b3074f09 100644 --- a/tests/imputation/mrnn.py +++ b/tests/imputation/mrnn.py @@ -15,7 +15,7 @@ from pypots.imputation import MRNN from pypots.optim import Adam from pypots.utils.logging import logger -from pypots.utils.metrics import cal_mae +from pypots.utils.metrics import calc_mae from tests.global_test_config import ( DATA, DEVICE, @@ -61,7 +61,7 @@ def test_1_impute(self): assert not np.isnan( imputed_X ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( + test_MAE = calc_mae( imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] ) logger.info(f"MRNN test_MAE: {test_MAE}") diff --git a/tests/imputation/saits.py b/tests/imputation/saits.py index a5620569..d25f9361 100644 --- a/tests/imputation/saits.py +++ b/tests/imputation/saits.py @@ -15,7 +15,7 @@ from pypots.imputation import SAITS from pypots.optim import Adam from pypots.utils.logging import logger -from pypots.utils.metrics import cal_mae +from pypots.utils.metrics import calc_mae from tests.global_test_config import ( DATA, DEVICE, @@ -71,7 +71,7 @@ def test_1_impute(self): "latent_vars" in imputation_results.keys() ), "Latent variables are not returned thought `return_latent_vars` is set as True." - test_MAE = cal_mae( + test_MAE = calc_mae( imputation_results["imputation"], DATA["test_X_intact"], DATA["test_X_indicating_mask"], diff --git a/tests/imputation/timesnet.py b/tests/imputation/timesnet.py index 52e33ae4..33bfae3e 100644 --- a/tests/imputation/timesnet.py +++ b/tests/imputation/timesnet.py @@ -15,7 +15,7 @@ from pypots.imputation import TimesNet from pypots.optim import Adam from pypots.utils.logging import logger -from pypots.utils.metrics import cal_mae +from pypots.utils.metrics import calc_mae from tests.global_test_config import ( DATA, DEVICE, @@ -67,7 +67,7 @@ def test_1_impute(self): imputation_results["imputation"] ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( + test_MAE = calc_mae( imputation_results["imputation"], DATA["test_X_intact"], DATA["test_X_indicating_mask"], diff --git a/tests/imputation/transformer.py b/tests/imputation/transformer.py index c145ecfa..15624dc4 100644 --- a/tests/imputation/transformer.py +++ b/tests/imputation/transformer.py @@ -15,7 +15,7 @@ from pypots.imputation import Transformer from pypots.optim import Adam from pypots.utils.logging import logger -from pypots.utils.metrics import cal_mae +from pypots.utils.metrics import calc_mae from tests.global_test_config import ( DATA, DEVICE, @@ -67,7 +67,7 @@ def test_1_impute(self): assert not np.isnan( imputed_X ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( + test_MAE = calc_mae( imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] ) logger.info(f"Transformer test_MAE: {test_MAE}") diff --git a/tests/imputation/usgan.py b/tests/imputation/usgan.py index ea723238..0ff25ea3 100644 --- a/tests/imputation/usgan.py +++ b/tests/imputation/usgan.py @@ -15,7 +15,7 @@ from pypots.imputation import USGAN from pypots.optim import Adam from pypots.utils.logging import logger -from pypots.utils.metrics import cal_mae +from pypots.utils.metrics import calc_mae from tests.global_test_config import ( DATA, DEVICE, @@ -63,7 +63,7 @@ def test_1_impute(self): assert not np.isnan( imputed_X ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( + test_MAE = calc_mae( imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] ) logger.info(f"US-GAN test_MAE: {test_MAE}") diff --git a/tests/optim/adadelta.py b/tests/optim/adadelta.py index 71c991f2..c7eb6e6d 100644 --- a/tests/optim/adadelta.py +++ b/tests/optim/adadelta.py @@ -13,7 +13,7 @@ from pypots.imputation import SAITS from pypots.optim import Adadelta from pypots.utils.logging import logger -from pypots.utils.metrics import cal_mae +from pypots.utils.metrics import calc_mae from tests.global_test_config import DATA from tests.optim.config import EPOCHS, TEST_SET, TRAIN_SET, VAL_SET @@ -46,7 +46,7 @@ def test_0_fit(self): assert not np.isnan( imputed_X ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( + test_MAE = calc_mae( imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] ) logger.info(f"SAITS test_MAE: {test_MAE}") diff --git a/tests/optim/adagrad.py b/tests/optim/adagrad.py index 6b055f21..7cb2a988 100644 --- a/tests/optim/adagrad.py +++ b/tests/optim/adagrad.py @@ -13,7 +13,7 @@ from pypots.imputation import SAITS from pypots.optim import Adagrad from pypots.utils.logging import logger -from pypots.utils.metrics import cal_mae +from pypots.utils.metrics import calc_mae from tests.global_test_config import DATA from tests.optim.config import EPOCHS, TEST_SET, TRAIN_SET, VAL_SET @@ -46,7 +46,7 @@ def test_0_fit(self): assert not np.isnan( imputed_X ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( + test_MAE = calc_mae( imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] ) logger.info(f"SAITS test_MAE: {test_MAE}") diff --git a/tests/optim/adam.py b/tests/optim/adam.py index be6cb89b..9f583aee 100644 --- a/tests/optim/adam.py +++ b/tests/optim/adam.py @@ -13,7 +13,7 @@ from pypots.imputation import SAITS from pypots.optim import Adam from pypots.utils.logging import logger -from pypots.utils.metrics import cal_mae +from pypots.utils.metrics import calc_mae from tests.global_test_config import DATA from tests.optim.config import EPOCHS, TEST_SET, TRAIN_SET, VAL_SET @@ -46,7 +46,7 @@ def test_0_fit(self): assert not np.isnan( imputed_X ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( + test_MAE = calc_mae( imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] ) logger.info(f"SAITS test_MAE: {test_MAE}") diff --git a/tests/optim/adamw.py b/tests/optim/adamw.py index e7f89797..e785e9f6 100644 --- a/tests/optim/adamw.py +++ b/tests/optim/adamw.py @@ -13,7 +13,7 @@ from pypots.imputation import SAITS from pypots.optim import AdamW from pypots.utils.logging import logger -from pypots.utils.metrics import cal_mae +from pypots.utils.metrics import calc_mae from tests.global_test_config import DATA from tests.optim.config import EPOCHS, TEST_SET, TRAIN_SET, VAL_SET @@ -46,7 +46,7 @@ def test_0_fit(self): assert not np.isnan( imputed_X ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( + test_MAE = calc_mae( imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] ) logger.info(f"SAITS test_MAE: {test_MAE}") diff --git a/tests/optim/lr_schedulers.py b/tests/optim/lr_schedulers.py index 3a88976a..2aa1c520 100644 --- a/tests/optim/lr_schedulers.py +++ b/tests/optim/lr_schedulers.py @@ -22,7 +22,7 @@ MultiplicativeLR, ) from pypots.utils.logging import logger -from pypots.utils.metrics import cal_mae +from pypots.utils.metrics import calc_mae from tests.global_test_config import DATA from tests.optim.config import EPOCHS, TEST_SET, TRAIN_SET, VAL_SET @@ -74,7 +74,7 @@ def test_0_lambda_lrs(self): assert not np.isnan( imputed_X ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( + test_MAE = calc_mae( imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] ) logger.info(f"SAITS test_MAE: {test_MAE}") @@ -102,7 +102,7 @@ def test_1_multiplicative_lrs(self): assert not np.isnan( imputed_X ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( + test_MAE = calc_mae( imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] ) logger.info(f"SAITS test_MAE: {test_MAE}") @@ -130,7 +130,7 @@ def test_2_step_lrs(self): assert not np.isnan( imputed_X ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( + test_MAE = calc_mae( imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] ) logger.info(f"SAITS test_MAE: {test_MAE}") @@ -158,7 +158,7 @@ def test_3_multistep_lrs(self): assert not np.isnan( imputed_X ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( + test_MAE = calc_mae( imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] ) logger.info(f"SAITS test_MAE: {test_MAE}") @@ -187,7 +187,7 @@ def test_4_constant_lrs(self): assert not np.isnan( imputed_X ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( + test_MAE = calc_mae( imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] ) logger.info(f"SAITS test_MAE: {test_MAE}") @@ -215,7 +215,7 @@ def test_5_linear_lrs(self): assert not np.isnan( imputed_X ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( + test_MAE = calc_mae( imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] ) logger.info(f"SAITS test_MAE: {test_MAE}") @@ -243,7 +243,7 @@ def test_6_exponential_lrs(self): assert not np.isnan( imputed_X ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( + test_MAE = calc_mae( imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] ) logger.info(f"SAITS test_MAE: {test_MAE}") diff --git a/tests/optim/rmsprop.py b/tests/optim/rmsprop.py index 29087520..f4a3f53c 100644 --- a/tests/optim/rmsprop.py +++ b/tests/optim/rmsprop.py @@ -13,7 +13,7 @@ from pypots.imputation import SAITS from pypots.optim import RMSprop from pypots.utils.logging import logger -from pypots.utils.metrics import cal_mae +from pypots.utils.metrics import calc_mae from tests.global_test_config import DATA from tests.optim.config import EPOCHS, TEST_SET, TRAIN_SET, VAL_SET @@ -46,7 +46,7 @@ def test_0_fit(self): assert not np.isnan( imputed_X ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( + test_MAE = calc_mae( imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] ) logger.info(f"SAITS test_MAE: {test_MAE}") diff --git a/tests/optim/sgd.py b/tests/optim/sgd.py index 569bb96d..7dec3bf3 100644 --- a/tests/optim/sgd.py +++ b/tests/optim/sgd.py @@ -13,7 +13,7 @@ from pypots.imputation import SAITS from pypots.optim import SGD from pypots.utils.logging import logger -from pypots.utils.metrics import cal_mae +from pypots.utils.metrics import calc_mae from tests.global_test_config import DATA from tests.optim.config import EPOCHS, TEST_SET, TRAIN_SET, VAL_SET @@ -46,7 +46,7 @@ def test_0_fit(self): assert not np.isnan( imputed_X ).any(), "Output still has missing values after running impute()." - test_MAE = cal_mae( + test_MAE = calc_mae( imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] ) logger.info(f"SAITS test_MAE: {test_MAE}") diff --git a/tests/utils/random.py b/tests/utils/random.py index 89d78d93..096b02d4 100644 --- a/tests/utils/random.py +++ b/tests/utils/random.py @@ -9,7 +9,7 @@ import torch -from pypots.utils.random import set_random_seed +from pypots.utils.random import set_random_seed, get_random_seed class TestRandom(unittest.TestCase): @@ -31,6 +31,16 @@ def test_set_random_seed(self): random_state1, random_state2 ), "The random seed has been set, two random states are not the same." + current_seed = get_random_seed() + assert ( + not current_seed == 32 + ), "The random seed has been set to 26, not equal to 32." + set_random_seed(32) + current_seed = get_random_seed() + assert ( + current_seed == 32 + ), "The random seed has been set to 32, should be equal." + if __name__ == "__main__": unittest.main() From e73f0973b59f1267683ffea498851814bcabc340 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Sat, 9 Dec 2023 00:43:16 +0800 Subject: [PATCH 10/16] fix: error in _fetch_data_from_file() of CSDI dataset; --- pypots/imputation/csdi/data.py | 44 +++++++++++++++++----------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/pypots/imputation/csdi/data.py b/pypots/imputation/csdi/data.py index e0cfc894..9f9f1cab 100644 --- a/pypots/imputation/csdi/data.py +++ b/pypots/imputation/csdi/data.py @@ -24,18 +24,6 @@ def __init__( rate: float = 0.1, ): super().__init__(data, return_labels, file_type) - self.time_points = ( - None if "time_points" not in data.keys() else data["time_points"] - ) - # _, self.time_points = self._check_input(self.X, time_points) - self.for_pattern_mask = ( - None if "for_pattern_mask" not in data.keys() else data["for_pattern_mask"] - ) - # _, self.for_pattern_mask = self._check_input(self.X, for_pattern_mask) - self.cut_length = ( - None if "cut_length" not in data.keys() else data["cut_length"] - ) - # _, self.cut_length = self._check_input(self.X, cut_length) self.rate = rate def _fetch_data_from_array(self, idx: int) -> Iterable: @@ -71,19 +59,21 @@ def _fetch_data_from_array(self, idx: int) -> Iterable: observed_data = X_intact observed_mask = missing_mask + indicating_mask + gt_mask = missing_mask observed_tp = ( torch.arange(0, self.n_steps, dtype=torch.float32) - if self.time_points is None - else self.time_points[idx].to(torch.float32) + if "time_points" not in self.data.keys() + else torch.from_numpy(self.data["time_points"][idx]).to(torch.float32) ) - gt_mask = missing_mask for_pattern_mask = ( - gt_mask if self.for_pattern_mask is None else self.for_pattern_mask[idx] + gt_mask + if "for_pattern_mask" not in self.data.keys() + else torch.from_numpy(self.data["for_pattern_mask"][idx]).to(torch.float32) ) cut_length = ( torch.zeros(len(observed_data)).long() - if self.cut_length is None - else self.cut_length[idx] + if "cut_length" not in self.data.keys() + else torch.from_numpy(self.data["cut_length"][idx]).to(torch.float32) ) sample = [ @@ -124,15 +114,25 @@ def _fetch_data_from_file(self, idx: int) -> Iterable: observed_data = X_intact observed_mask = missing_mask + indicating_mask - observed_tp = self.time_points[idx].to(torch.float32) gt_mask = indicating_mask + observed_tp = ( + torch.arange(0, self.n_steps, dtype=torch.float32) + if "time_points" not in self.file_handle.keys() + else torch.from_numpy(self.file_handle["time_points"][idx]).to( + torch.float32 + ) + ) for_pattern_mask = ( - gt_mask if self.for_pattern_mask is None else self.for_pattern_mask[idx] + gt_mask + if "for_pattern_mask" not in self.file_handle.keys() + else torch.from_numpy(self.file_handle["for_pattern_mask"][idx]).to( + torch.float32 + ) ) cut_length = ( torch.zeros(len(observed_data)).long() - if self.cut_length is None - else self.cut_length[idx] + if "cut_length" not in self.file_handle.keys() + else torch.from_numpy(self.file_handle["cut_length"][idx]).to(torch.float32) ) sample = [ From 77f6794e1d12fdaa9eecb18a02f343c9a5aa65c5 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Sat, 9 Dec 2023 01:00:44 +0800 Subject: [PATCH 11/16] fix: error args passing in deprecated functions; --- pypots/imputation/timesnet/modules/core.py | 4 +- pypots/utils/metrics/classification.py | 20 +++++----- pypots/utils/metrics/clustering.py | 45 ++++++++++++---------- pypots/utils/metrics/error.py | 16 ++++---- 4 files changed, 45 insertions(+), 40 deletions(-) diff --git a/pypots/imputation/timesnet/modules/core.py b/pypots/imputation/timesnet/modules/core.py index 9dd4bf5a..ff51fe86 100644 --- a/pypots/imputation/timesnet/modules/core.py +++ b/pypots/imputation/timesnet/modules/core.py @@ -11,7 +11,7 @@ from .embedding import DataEmbedding from .layer import TimesBlock -from ....utils.metrics import cal_mse +from ....utils.metrics import calc_mse class _TimesNet(nn.Module): @@ -88,7 +88,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: if training: # `loss` is always the item for backward propagating to update the model - loss = cal_mse(dec_out, inputs["X_intact"], inputs["indicating_mask"]) + loss = calc_mse(dec_out, inputs["X_intact"], inputs["indicating_mask"]) results["loss"] = loss return results diff --git a/pypots/utils/metrics/classification.py b/pypots/utils/metrics/classification.py index d9867f6f..ef74d218 100644 --- a/pypots/utils/metrics/classification.py +++ b/pypots/utils/metrics/classification.py @@ -263,31 +263,31 @@ def calc_acc(class_predictions: np.ndarray, targets: np.ndarray) -> float: ######################################################################################################################## -def cal_binary_classification_metrics(**kwargs): +def cal_binary_classification_metrics(*args): logger.warning( "🚨 cal_binary_classification_metrics() is deprecated, " "use calc_binary_classification_metrics() instead." ) - return calc_binary_classification_metrics(**kwargs) + return calc_binary_classification_metrics(*args) -def cal_precision_recall_f1(**kwargs): +def cal_precision_recall_f1(*args): logger.warning( "🚨 cal_precision_recall_f1() is deprecated, use calc_precision_recall_f1() instead." ) - return calc_precision_recall_f1(**kwargs) + return calc_precision_recall_f1(*args) -def cal_pr_auc(**kwargs): +def cal_pr_auc(*args): logger.warning("🚨 cal_pr_auc() is deprecated, use calc_pr_auc() instead.") - return calc_pr_auc(**kwargs) + return calc_pr_auc(*args) -def cal_roc_auc(**kwargs): +def cal_roc_auc(*args): logger.warning("🚨 cal_roc_auc() is deprecated, use calc_roc_auc() instead.") - return calc_roc_auc(**kwargs) + return calc_roc_auc(*args) -def cal_acc(**kwargs): +def cal_acc(*args): logger.warning("🚨 cal_acc() is deprecated, use calc_acc() instead.") - return calc_acc(**kwargs) + return calc_acc(*args) diff --git a/pypots/utils/metrics/clustering.py b/pypots/utils/metrics/clustering.py index f417d295..2e24edd7 100644 --- a/pypots/utils/metrics/clustering.py +++ b/pypots/utils/metrics/clustering.py @@ -154,7 +154,10 @@ def calc_cluster_purity( return cluster_purity -def calc_external_cluster_validation_metrics(class_predictions, targets): +def calc_external_cluster_validation_metrics( + class_predictions: np.ndarray, + targets: np.ndarray, +) -> dict: """Computer all external cluster validation metrics available in PyPOTS and return as a dictionary. Parameters @@ -270,7 +273,9 @@ def calc_dbs(X: np.ndarray, predicted_labels: np.ndarray) -> float: return davies_bouldin_score -def calc_internal_cluster_validation_metrics(X, predicted_labels): +def calc_internal_cluster_validation_metrics( + X: np.ndarray, predicted_labels: np.ndarray +) -> dict: """Computer all internal cluster validation metrics available in PyPOTS and return as a dictionary. Parameters @@ -304,52 +309,52 @@ def calc_internal_cluster_validation_metrics(X, predicted_labels): ######################################################################################################################## -def cal_rand_index(**kwargs): +def cal_rand_index(*args): logger.warning("🚨 Deprecated function, please use `calc_rand_index` instead.") - return calc_rand_index(**kwargs) + return calc_rand_index(*args) -def cal_adjusted_rand_index(**kwargs): +def cal_adjusted_rand_index(*args): logger.warning( "🚨 Deprecated function, please use `calc_adjusted_rand_index` instead." ) - return calc_adjusted_rand_index(**kwargs) + return calc_adjusted_rand_index(*args) -def cal_nmi(**kwargs): +def cal_nmi(*args): logger.warning("🚨 Deprecated function, please use `calc_nmi` instead.") - return calc_nmi(**kwargs) + return calc_nmi(*args) -def cal_cluster_purity(**kwargs): +def cal_cluster_purity(*args): logger.warning("🚨 Deprecated function, please use `calc_cluster_purity` instead.") - return calc_cluster_purity(**kwargs) + return calc_cluster_purity(*args) -def cal_external_cluster_validation_metrics(**kwargs): +def cal_external_cluster_validation_metrics(*args): logger.warning( "🚨 Deprecated function, please use `calc_external_cluster_validation_metrics` instead." ) - return calc_external_cluster_validation_metrics(**kwargs) + return calc_external_cluster_validation_metrics(*args) -def cal_silhouette(**kwargs): +def cal_silhouette(*args): logger.warning("🚨 Deprecated function, please use `calc_silhouette` instead.") - return calc_silhouette(**kwargs) + return calc_silhouette(*args) -def cal_chs(**kwargs): +def cal_chs(*args): logger.warning("🚨 Deprecated function, please use `calc_chs` instead.") - return calc_chs(**kwargs) + return calc_chs(*args) -def cal_dbs(**kwargs): +def cal_dbs(*args): logger.warning("🚨 Deprecated function, please use `calc_dbs` instead.") - return calc_dbs(**kwargs) + return calc_dbs(*args) -def cal_internal_cluster_validation_metrics(**kwargs): +def cal_internal_cluster_validation_metrics(*args): logger.warning( "🚨 Deprecated function, please use `calc_internal_cluster_validation_metrics` instead." ) - return calc_internal_cluster_validation_metrics(**kwargs) + return calc_internal_cluster_validation_metrics(*args) diff --git a/pypots/utils/metrics/error.py b/pypots/utils/metrics/error.py index 8cafcb61..b8d10428 100644 --- a/pypots/utils/metrics/error.py +++ b/pypots/utils/metrics/error.py @@ -296,21 +296,21 @@ def calc_quantile_crps_sum(predictions, targets, eval_points, mean_scaler=0, sca ######################################################################################################################## -def cal_mae(**kwargs): +def cal_mae(*args): logger.warning("🚨 cal_mae() is deprecated, use calc_mae() instead.") - return calc_mae(**kwargs) + return calc_mae(*args) -def cal_rmse(**kwargs): +def cal_rmse(*args): logger.warning("🚨 cal_rmse() is deprecated, use calc_rmse() instead.") - return calc_rmse(**kwargs) + return calc_rmse(*args) -def cal_mse(**kwargs): +def cal_mse(*args): logger.warning("🚨 cal_mse() is deprecated, use calc_mse() instead.") - return calc_mse(**kwargs) + return calc_mse(*args) -def cal_mre(**kwargs): +def cal_mre(*args): logger.warning("🚨 cal_mre() is deprecated, use calc_mre() instead.") - return calc_mre(**kwargs) + return calc_mre(*args) From b3ee912cfad92d22139fa163a63225b2e6fe1b9a Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Sun, 10 Dec 2023 18:28:11 +0800 Subject: [PATCH 12/16] fix: missing argument attn_dropout; --- pypots/imputation/transformer/model.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pypots/imputation/transformer/model.py b/pypots/imputation/transformer/model.py index fd8541e4..cfe0b009 100644 --- a/pypots/imputation/transformer/model.py +++ b/pypots/imputation/transformer/model.py @@ -20,7 +20,7 @@ import torch from torch.utils.data import DataLoader -from .data import DatasetForSAITS +from .data import DatasetForTransformer from .modules import _TransformerEncoder from ..base import BaseNNImputer from ...data.base import BaseDataset @@ -189,6 +189,7 @@ def __init__( self.d_k, self.d_v, self.dropout, + self.attn_dropout, self.ORT_weight, self.MIT_weight, ) @@ -237,7 +238,7 @@ def fit( file_type: str = "h5py", ) -> None: # Step 1: wrap the input data with classes Dataset and DataLoader - training_set = DatasetForSAITS( + training_set = DatasetForTransformer( train_set, return_labels=False, file_type=file_type ) training_loader = DataLoader( From 922902fe0831c4da8f03d22809e78a80571e7577 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Mon, 11 Dec 2023 21:45:00 +0800 Subject: [PATCH 13/16] refactor: move pypots.modules to pypots.nn.modules; --- pypots/classification/grud/modules/core.py | 2 +- pypots/imputation/brits/modules/core.py | 2 +- pypots/imputation/saits/modules/core.py | 2 +- pypots/imputation/timesnet/modules/embedding.py | 2 +- pypots/imputation/transformer/modules/core.py | 2 +- pypots/modules/__init__.py | 10 +++++++++- pypots/nn/__init__.py | 6 ++++++ pypots/nn/modules/__init__.py | 6 ++++++ pypots/{ => nn}/modules/rnn.py | 0 pypots/{ => nn}/modules/transformer/__init__.py | 0 pypots/{ => nn}/modules/transformer/attention.py | 0 pypots/{ => nn}/modules/transformer/auto_encoder.py | 0 pypots/{ => nn}/modules/transformer/layers.py | 0 pypots/{ => nn}/modules/transformer/pos_enc.py | 0 14 files changed, 26 insertions(+), 6 deletions(-) create mode 100644 pypots/nn/__init__.py create mode 100644 pypots/nn/modules/__init__.py rename pypots/{ => nn}/modules/rnn.py (100%) rename pypots/{ => nn}/modules/transformer/__init__.py (100%) rename pypots/{ => nn}/modules/transformer/attention.py (100%) rename pypots/{ => nn}/modules/transformer/auto_encoder.py (100%) rename pypots/{ => nn}/modules/transformer/layers.py (100%) rename pypots/{ => nn}/modules/transformer/pos_enc.py (100%) diff --git a/pypots/classification/grud/modules/core.py b/pypots/classification/grud/modules/core.py index c1b873f4..10c48fb4 100644 --- a/pypots/classification/grud/modules/core.py +++ b/pypots/classification/grud/modules/core.py @@ -16,7 +16,7 @@ import torch.nn as nn import torch.nn.functional as F -from ....modules.rnn import TemporalDecay +from ....nn.modules.rnn import TemporalDecay class _GRUD(nn.Module): diff --git a/pypots/imputation/brits/modules/core.py b/pypots/imputation/brits/modules/core.py index 689d5582..58413c7c 100644 --- a/pypots/imputation/brits/modules/core.py +++ b/pypots/imputation/brits/modules/core.py @@ -20,7 +20,7 @@ import torch.nn as nn from .submodules import FeatureRegression -from ....modules.rnn import TemporalDecay +from ....nn.modules.rnn import TemporalDecay from ....utils.metrics import calc_mae diff --git a/pypots/imputation/saits/modules/core.py b/pypots/imputation/saits/modules/core.py index 51c2dfd5..05dd0287 100644 --- a/pypots/imputation/saits/modules/core.py +++ b/pypots/imputation/saits/modules/core.py @@ -19,7 +19,7 @@ import torch.nn as nn import torch.nn.functional as F -from ....modules.transformer import EncoderLayer, PositionalEncoding +from ....nn.modules.transformer import EncoderLayer, PositionalEncoding from ....utils.metrics import calc_mae diff --git a/pypots/imputation/timesnet/modules/embedding.py b/pypots/imputation/timesnet/modules/embedding.py index 70bd739e..0823fad2 100644 --- a/pypots/imputation/timesnet/modules/embedding.py +++ b/pypots/imputation/timesnet/modules/embedding.py @@ -11,7 +11,7 @@ import torch.fft import torch.nn as nn -from ....modules.transformer import PositionalEncoding +from ....nn.modules.transformer import PositionalEncoding class TokenEmbedding(nn.Module): diff --git a/pypots/imputation/transformer/modules/core.py b/pypots/imputation/transformer/modules/core.py index 7044358d..682d5bd2 100644 --- a/pypots/imputation/transformer/modules/core.py +++ b/pypots/imputation/transformer/modules/core.py @@ -18,7 +18,7 @@ import torch import torch.nn as nn -from ....modules.transformer import EncoderLayer, PositionalEncoding +from ....nn.modules.transformer import EncoderLayer, PositionalEncoding from ....utils.metrics import calc_mae diff --git a/pypots/modules/__init__.py b/pypots/modules/__init__.py index 8121af3c..638464fe 100644 --- a/pypots/modules/__init__.py +++ b/pypots/modules/__init__.py @@ -1,6 +1,14 @@ """ -Frequently-used modules like self-attention modules of vanilla Transformer are put in this package. +Everything used to be in this package has been moved to pypots.nn.modules. +This package is kept for backward compatibility and will be removed in the future. """ # Created by Wenjie Du # License: BSD-3-Clause + +from ..utils.logging import logger + +logger.warning( + "🚨 pypots.modules package has been moved to pypots.nn.modules. " + "Please import everything from pypots.nn.modules instead." +) diff --git a/pypots/nn/__init__.py b/pypots/nn/__init__.py new file mode 100644 index 00000000..ceaa7ee3 --- /dev/null +++ b/pypots/nn/__init__.py @@ -0,0 +1,6 @@ +""" + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause diff --git a/pypots/nn/modules/__init__.py b/pypots/nn/modules/__init__.py new file mode 100644 index 00000000..8121af3c --- /dev/null +++ b/pypots/nn/modules/__init__.py @@ -0,0 +1,6 @@ +""" +Frequently-used modules like self-attention modules of vanilla Transformer are put in this package. +""" + +# Created by Wenjie Du +# License: BSD-3-Clause diff --git a/pypots/modules/rnn.py b/pypots/nn/modules/rnn.py similarity index 100% rename from pypots/modules/rnn.py rename to pypots/nn/modules/rnn.py diff --git a/pypots/modules/transformer/__init__.py b/pypots/nn/modules/transformer/__init__.py similarity index 100% rename from pypots/modules/transformer/__init__.py rename to pypots/nn/modules/transformer/__init__.py diff --git a/pypots/modules/transformer/attention.py b/pypots/nn/modules/transformer/attention.py similarity index 100% rename from pypots/modules/transformer/attention.py rename to pypots/nn/modules/transformer/attention.py diff --git a/pypots/modules/transformer/auto_encoder.py b/pypots/nn/modules/transformer/auto_encoder.py similarity index 100% rename from pypots/modules/transformer/auto_encoder.py rename to pypots/nn/modules/transformer/auto_encoder.py diff --git a/pypots/modules/transformer/layers.py b/pypots/nn/modules/transformer/layers.py similarity index 100% rename from pypots/modules/transformer/layers.py rename to pypots/nn/modules/transformer/layers.py diff --git a/pypots/modules/transformer/pos_enc.py b/pypots/nn/modules/transformer/pos_enc.py similarity index 100% rename from pypots/modules/transformer/pos_enc.py rename to pypots/nn/modules/transformer/pos_enc.py From 42d7c43ea2d6056fb050457a20b3c8484d8df9c9 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Mon, 11 Dec 2023 22:08:05 +0800 Subject: [PATCH 14/16] feat: add stable nonstationary norm and denorm functions; --- docs/references.bib | 7 ++ pypots/nn/functional/__init__.py | 14 ++++ pypots/nn/functional/normalization.py | 98 +++++++++++++++++++++++++++ 3 files changed, 119 insertions(+) create mode 100644 pypots/nn/functional/__init__.py create mode 100644 pypots/nn/functional/normalization.py diff --git a/docs/references.bib b/docs/references.bib index 28719e1d..c2ea2e35 100644 --- a/docs/references.bib +++ b/docs/references.bib @@ -466,3 +466,10 @@ @inproceedings{wu2023timesnet year={2023}, url={https://openreview.net/forum?id=ju_Uqw384Oq} } + +@inproceedings{liu2022nonstationary, + title={Non-stationary Transformers: Exploring the Stationarity in Time Series Forecasting}, + author={Liu, Yong and Wu, Haixu and Wang, Jianmin and Long, Mingsheng}, + booktitle={Advances in Neural Information Processing Systems}, + year={2022} +} diff --git a/pypots/nn/functional/__init__.py b/pypots/nn/functional/__init__.py new file mode 100644 index 00000000..36df2bc6 --- /dev/null +++ b/pypots/nn/functional/__init__.py @@ -0,0 +1,14 @@ +""" + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +from .normalization import nonstationary_norm, nonstationary_denorm + +__all__ = [ + # normalization functions + "nonstationary_norm", + "nonstationary_denorm", +] diff --git a/pypots/nn/functional/normalization.py b/pypots/nn/functional/normalization.py new file mode 100644 index 00000000..474bb637 --- /dev/null +++ b/pypots/nn/functional/normalization.py @@ -0,0 +1,98 @@ +""" +Store normalization functions for neural networks. +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + + +from typing import Tuple, Optional + +import torch + + +def nonstationary_norm( + X: torch.Tensor, + missing_mask: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Normalization from Non-stationary Transformer. Please refer to :cite:`liu2022nonstationary` for more details. + + Parameters + ---------- + X : torch.Tensor + Input data to be normalized. Shape: (n_samples, n_steps (seq_len), n_features). + + missing_mask : torch.Tensor, optional + Missing mask has the same shape as X. 1 indicates observed and 0 indicates missing. + + Returns + ------- + X_enc : torch.Tensor + Normalized data. Shape: (n_samples, n_steps (seq_len), n_features). + + means : torch.Tensor + Means values for de-normalization. Shape: (n_samples, n_features) or (n_samples, 1, n_features). + + stdev : torch.Tensor + Standard deviation values for de-normalization. Shape: (n_samples, n_features) or (n_samples, 1, n_features). + + """ + if torch.isnan(X).any(): + if missing_mask is None: + missing_mask = torch.isnan(X) + else: + raise ValueError("missing_mask is given but X still contains nan values.") + + if missing_mask is None: + means = X.mean(1, keepdim=True).detach() + X_enc = X - means + variance = torch.var(X_enc, dim=1, keepdim=True, unbiased=False) + 1e-9 + stdev = torch.sqrt(variance).detach() + else: + # for data contain missing values, add a small number to avoid dividing by 0 + missing_sum = torch.sum(missing_mask == 1, dim=1, keepdim=True) + 1e-9 + means = torch.sum(X, dim=1, keepdim=True) / missing_sum + X_enc = X - means + X_enc = X_enc.masked_fill(missing_mask == 0, 0) + variance = torch.sum(X_enc * X_enc, dim=1, keepdim=True) + 1e-9 + stdev = torch.sqrt(variance / missing_sum) + + X_enc /= stdev + return X_enc, means, stdev + + +def nonstationary_denorm( + X: torch.Tensor, + means: torch.Tensor, + stdev: torch.Tensor, +) -> torch.Tensor: + """De-Normalization from Non-stationary Transformer. Please refer to :cite:`liu2022nonstationary` for more details. + + Parameters + ---------- + X : torch.Tensor + Input data to be de-normalized. Shape: (n_samples, n_steps (seq_len), n_features). + + means : torch.Tensor + Means values for de-normalization . Shape: (n_samples, n_features) or (n_samples, 1, n_features). + + stdev : torch.Tensor + Standard deviation values for de-normalization. Shape: (n_samples, n_features) or (n_samples, 1, n_features). + + Returns + ------- + X_denorm : torch.Tensor + De-normalized data. Shape: (n_samples, n_steps (seq_len), n_features). + + """ + assert ( + len(X) == len(means) == len(stdev) + ), "Input data and normalization parameters should have the same number of samples." + if len(means.shape) == 2: + means = means.unsqueeze(1) + if len(stdev.shape) == 2: + stdev = stdev.unsqueeze(1) + + X = X * stdev # (stdev.repeat(1, n_steps, 1)) + X = X + means # (means.repeat(1, n_steps, 1)) + return X From bb817e647e2c1e4bd7fd0b50213e5f9e0708ab16 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Mon, 11 Dec 2023 22:10:12 +0800 Subject: [PATCH 15/16] fix: remove the TimesNet official implementation of nonstationary norm to fix #265; --- pypots/imputation/timesnet/model.py | 10 ++++++- pypots/imputation/timesnet/modules/core.py | 31 +++++++--------------- 2 files changed, 19 insertions(+), 22 deletions(-) diff --git a/pypots/imputation/timesnet/model.py b/pypots/imputation/timesnet/model.py index ec00835a..408f6cfd 100644 --- a/pypots/imputation/timesnet/model.py +++ b/pypots/imputation/timesnet/model.py @@ -21,12 +21,12 @@ from torch.utils.data import DataLoader from .data import DatasetForTimesNet -from ...utils.logging import logger from .modules.core import _TimesNet from ..base import BaseNNImputer from ...data.base import BaseDataset from ...optim.adam import Adam from ...optim.base import Optimizer +from ...utils.logging import logger class TimesNet(BaseNNImputer): @@ -59,6 +59,11 @@ class TimesNet(BaseNNImputer): dropout : The dropout rate for the model. + apply_nonstationary_norm : + Whether to apply non-stationary normalization to the input data for TimesNet. + Please refer to :cite:`liu2022nonstationary` for details about non-stationary normalization, + which is not the idea of the original TimesNet paper. Hence, we make it optional and default not to use here. + batch_size : The batch size for training and evaluating the model. @@ -117,6 +122,7 @@ def __init__( d_ffn: int, n_kernels: int, dropout: float = 0, + apply_nonstationary_norm: bool = False, batch_size: int = 32, epochs: int = 100, patience: int = None, @@ -145,6 +151,7 @@ def __init__( self.d_ffn = d_ffn self.n_kernels = n_kernels self.dropout = dropout + self.apply_nonstationary_norm = apply_nonstationary_norm # set up the model self.model = _TimesNet( @@ -156,6 +163,7 @@ def __init__( self.d_ffn, self.n_kernels, self.dropout, + self.apply_nonstationary_norm, ) self._send_model_to_given_device() self._print_model_size() diff --git a/pypots/imputation/timesnet/modules/core.py b/pypots/imputation/timesnet/modules/core.py index ff51fe86..87dea28c 100644 --- a/pypots/imputation/timesnet/modules/core.py +++ b/pypots/imputation/timesnet/modules/core.py @@ -5,12 +5,11 @@ # Created by Wenjie Du # License: BSD-3-Clause -import torch -import torch.fft import torch.nn as nn from .embedding import DataEmbedding from .layer import TimesBlock +from ....nn.functional import nonstationary_norm, nonstationary_denorm from ....utils.metrics import calc_mse @@ -25,11 +24,13 @@ def __init__( d_ffn, n_kernels, dropout, + apply_nonstationary_norm, ): super().__init__() self.seq_len = n_steps self.n_layers = n_layers + self.apply_nonstationary_norm = apply_nonstationary_norm self.pred_len = 0 # for the imputation task, the pred_len is always 0 self.model = nn.ModuleList( @@ -52,19 +53,12 @@ def __init__( def forward(self, inputs: dict, training: bool = True) -> dict: X, masks = inputs["X"], inputs["missing_mask"] - # Normalization from Non-stationary Transformer - means = torch.sum(X, dim=1) / torch.sum(masks == 1, dim=1) - means = means.unsqueeze(1).detach() - x_enc = X - means - x_enc = x_enc.masked_fill(masks == 0, 0) - stdev = torch.sqrt( - torch.sum(x_enc * x_enc, dim=1) / torch.sum(masks == 1, dim=1) + 1e-5 - ) - stdev = stdev.unsqueeze(1).detach() - x_enc /= stdev + if self.apply_nonstationary_norm: + # Normalization from Non-stationary Transformer + X, means, stdev = nonstationary_norm(X, masks) # embedding - enc_out = self.enc_embedding(x_enc) # [B,T,C] + enc_out = self.enc_embedding(X) # [B,T,C] # TimesNet for i in range(self.n_layers): enc_out = self.layer_norm(self.model[i](enc_out)) @@ -72,16 +66,11 @@ def forward(self, inputs: dict, training: bool = True) -> dict: # project back the original data space dec_out = self.projection(enc_out) - # De-Normalization from Non-stationary Transformer - dec_out = dec_out * ( - stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len + self.seq_len, 1) - ) - dec_out = dec_out + ( - means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len + self.seq_len, 1) - ) + if self.apply_nonstationary_norm: + # De-Normalization from Non-stationary Transformer + dec_out = nonstationary_denorm(dec_out, means, stdev) imputed_data = masks * X + (1 - masks) * dec_out - results = { "imputed_data": imputed_data, } From dd70ff215f145ae2db0a7add57a664fc3ab77a45 Mon Sep 17 00:00:00 2001 From: Jun Wang <140130585+AugustJW@users.noreply.github.com> Date: Tue, 12 Dec 2023 11:48:27 +0800 Subject: [PATCH 16/16] Add files via upload add two visualisation functions in utils --- pypots/utils/__init__.py | 1 + pypots/utils/visual.py | 102 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 103 insertions(+) create mode 100644 pypots/utils/visual.py diff --git a/pypots/utils/__init__.py b/pypots/utils/__init__.py index 6c5074ee..fdf5413c 100644 --- a/pypots/utils/__init__.py +++ b/pypots/utils/__init__.py @@ -12,4 +12,5 @@ "logging", "metrics", "random", + "visual", ] diff --git a/pypots/utils/visual.py b/pypots/utils/visual.py new file mode 100644 index 00000000..835523ba --- /dev/null +++ b/pypots/utils/visual.py @@ -0,0 +1,102 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +import matplotlib.pyplot as plt +import pandas as pd + + +def plot_data(vals_obs, vals_eval, vals_imputed, dataidx = None, nrows = 10, ncols = 4, figsize=[24.0, 36.0]): + """ Plot the imputed values, the observed values, and the evaluated values of one multivariate timeseries. The observed values are marked with red 'x', the evaluated values are marked with blue 'o', and the imputed values are marked with solid green line. + Parameters + ---------- + vals_obs : ndarray, + The observed values + + vals_eval : ndarray, + The evaluated values + + vals_imputed : ndarray, + The imputed values + + dataidx : int, + The index of the sample to be plotted + + nrows : int, + The number of rows in the plot + + ncols : int, + The number of columns in the plot + + figsize : list, + The size of the figure + """ + n_s, n_l, n_c = vals_obs.shape + if dataidx == None: + dataidx = np.random.randint(low=0, high=n_s) + + n_k = nrows*ncols + K = np.min([n_c, n_k]) + L = n_l + plt.rcParams["font.size"] = 16 + fig, axes = plt.subplots(nrows=nrows, ncols=ncols,figsize=(figsize[0], figsize[1])) + # fig.delaxes(axes[-1][-1]) + + for k in range(K): + df = pd.DataFrame({"x":np.arange(0,L), "val":vals_imputed[dataidx,:,k]}) + df1 = pd.DataFrame({"x":np.arange(0,L), "val":vals_obs[dataidx,:,k]}) + df2 = pd.DataFrame({"x":np.arange(0,L), "val":vals_eval[dataidx,:,k]}) + row = k // ncols + col = k % ncols + axes[row][col].plot(df1.x,df1.val, color = 'r', marker = 'x', linestyle='None') + axes[row][col].plot(df2.x,df2.val, color = 'b', marker = 'o', linestyle='None') + axes[row][col].plot(df.x,df.val, color = 'g', linestyle='solid') + if col == 0: + plt.setp(axes[row, 0], ylabel='value') + if row == -1: + plt.setp(axes[-1, col], xlabel='time') + +def plot_missingness(mask, t_max = 1, t_min = 0, dataidx = None): + """ Plot the missingness pattern of one multivariate timeseries. For each feature, the observed timestamp is marked with blue '|'. The distribution of sequence lengths is also plotted. Hereby, the sequence length is defined as the number of observed timestamps in one feature. + Parameters + ---------- + mask : ndarray, + The mask matrix of one multivariate timeseries + + t_max : int, + The maximum time + + t_min : int, + The minimum time + + dataidx : int, + The index of the sample to be plotted + """ + n_s,n_l,n_c = mask.shape + time = np.repeat(np.repeat(np.linspace(0,t_max, n_l).reshape(1, n_l, 1), axis=2, repeats=n_c), axis=0, repeats=n_s) + if dataidx == None: + dataidx = np.random.randint(low=0, high=n_s) + fig, axes = plt.subplots(figsize=[12,3.5], dpi = 200, nrows=1, ncols=2) + plt.subplots_adjust(hspace=0.1) + seq_len = [] + sample = np.transpose(time[dataidx], (1, 0)) + mask_s = np.transpose(mask[dataidx], (1, 0)) + for feature_idx in range(n_c): + t = sample[feature_idx][mask_s[feature_idx]==1] + axes[0].scatter(t, np.ones_like(t)*(feature_idx), alpha=1, c='C0', marker="|") + seq_len.append(len(t)) + axes[0].set_title("Visualization of arrival times", fontsize=9) + axes[0].set_xlabel("Time", fontsize=7) + axes[0].set_ylabel("Features #", fontsize=7) + axes[0].set_xlim(-1, n_l) + # axes[0].set_ylim(0, n_c-1) + axes[0].tick_params(axis="both", labelsize=7) + + axes[1].set_title("Distribution of sequence lengths", fontsize=9) + axes[1].hist(seq_len, bins = n_l, color="C1", range=(t_min, t_max),) + axes[1].set_xlabel(r"Sequence length", fontsize=7) + axes[1].set_ylabel("Frequency", fontsize=7) + axes[1].tick_params(axis="both", labelsize=7) + plt.show()