diff --git a/README.md b/README.md
index 9ab1e108..7559a304 100644
--- a/README.md
+++ b/README.md
@@ -26,7 +26,7 @@
-
+
@@ -86,7 +86,7 @@ mae = cal_mae(imputation, X_intact, indicating_mask) # calculate mean absolute
| Task | Type | Algorithm | Year | Reference |
|-------------------------------|----------------|--------------------------------------------------------------------------|------|-----------|
| Imputation | Neural Network | SAITS (Self-Attention-based Imputation for Time Series) | 2023 | [^1] |
-| Imputation | Neural Network | Transformer | 2017 | [^2] [^1] |
+| Imputation | Neural Network | Transformer | 2017 | [^1] [^2] |
| Imputation,
Classification | Neural Network | BRITS (Bidirectional Recurrent Imputation for Time Series) | 2018 | [^3] |
| Imputation | Naive | LOCF (Last Observation Carried Forward) | - | - |
| Classification | Neural Network | GRU-D | 2018 | [^4] |
@@ -120,14 +120,14 @@ The documentation and tutorials are under construction. And a short paper introd
Thank you all for your attention! 😃
-[^1]: Du, W., Cote, D., & Liu, Y. (2023). [SAITS: Self-Attention-based Imputation for Time Series](https://doi.org/10.1016/j.eswa.2023.119619). Expert systems with applications.
-[^2]: Vaswani, A., Shazeer, N.M., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A.N., Kaiser, L., & Polosukhin, I. (2017). [Attention is All you Need](https://papers.nips.cc/paper/2017/hash/3f5ee243547dee91fbd053c1c4a845aa-Abstract.html). NeurIPS 2017.
-[^3]: Cao, W., Wang, D., Li, J., Zhou, H., Li, L., & Li, Y. (2018). [BRITS: Bidirectional Recurrent Imputation for Time Series](https://papers.nips.cc/paper/2018/hash/734e6bfcd358e25ac1db0a4241b95651-Abstract.html). NeurIPS 2018.
-[^4]: Che, Z., Purushotham, S., Cho, K., Sontag, D.A., & Liu, Y. (2018). [Recurrent Neural Networks for Multivariate Time Series with Missing Values](https://www.nature.com/articles/s41598-018-24271-9). Scientific Reports, 8.
-[^5]: Zhang, X., Zeman, M., Tsiligkaridis, T., & Zitnik, M. (2022). [Graph-Guided Network for Irregularly Sampled Multivariate Time Series](https://arxiv.org/abs/2110.05357). ICLR 2022.
-[^6]: Ma, Q., Chen, C., Li, S., & Cottrell, G. W. (2021). [Learning Representations for Incomplete Time Series Clustering](https://ojs.aaai.org/index.php/AAAI/article/view/17070). AAAI 2021.
-[^7]: Jong, J.D., Emon, M.A., Wu, P., Karki, R., Sood, M., Godard, P., Ahmad, A., Vrooman, H.A., Hofmann-Apitius, M., & Fröhlich, H. (2019). [Deep learning for clustering of multivariate clinical patient trajectories with missing values](https://academic.oup.com/gigascience/article/8/11/giz134/5626377). GigaScience, 8.
-[^8]: Chen, X., & Sun, L. (2021). [Bayesian Temporal Factorization for Multidimensional Time Series Prediction](https://arxiv.org/abs/1910.06366). IEEE transactions on pattern analysis and machine intelligence, PP.
+[^1]: Du, W., Cote, D., & Liu, Y. (2023). [SAITS: Self-Attention-based Imputation for Time Series](https://doi.org/10.1016/j.eswa.2023.119619). *Expert Systems with Applications*.
+[^2]: Vaswani, A., Shazeer, N.M., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A.N., Kaiser, L., & Polosukhin, I. (2017). [Attention is All you Need](https://papers.nips.cc/paper/2017/hash/3f5ee243547dee91fbd053c1c4a845aa-Abstract.html). *NeurIPS* 2017.
+[^3]: Cao, W., Wang, D., Li, J., Zhou, H., Li, L., & Li, Y. (2018). [BRITS: Bidirectional Recurrent Imputation for Time Series](https://papers.nips.cc/paper/2018/hash/734e6bfcd358e25ac1db0a4241b95651-Abstract.html). *NeurIPS* 2018.
+[^4]: Che, Z., Purushotham, S., Cho, K., Sontag, D.A., & Liu, Y. (2018). [Recurrent Neural Networks for Multivariate Time Series with Missing Values](https://www.nature.com/articles/s41598-018-24271-9). *Scientific Reports*.
+[^5]: Zhang, X., Zeman, M., Tsiligkaridis, T., & Zitnik, M. (2022). [Graph-Guided Network for Irregularly Sampled Multivariate Time Series](https://arxiv.org/abs/2110.05357). *ICLR* 2022.
+[^6]: Ma, Q., Chen, C., Li, S., & Cottrell, G. W. (2021). [Learning Representations for Incomplete Time Series Clustering](https://ojs.aaai.org/index.php/AAAI/article/view/17070). *AAAI* 2021.
+[^7]: Jong, J.D., Emon, M.A., Wu, P., Karki, R., Sood, M., Godard, P., Ahmad, A., Vrooman, H.A., Hofmann-Apitius, M., & Fröhlich, H. (2019). [Deep learning for clustering of multivariate clinical patient trajectories with missing values](https://academic.oup.com/gigascience/article/8/11/giz134/5626377). *GigaScience*.
+[^8]: Chen, X., & Sun, L. (2021). [Bayesian Temporal Factorization for Multidimensional Time Series Prediction](https://arxiv.org/abs/1910.06366). *IEEE Transactions on Pattern Analysis and Machine Intelligence*.
🏠 Visits
diff --git a/pypots/__init__.py b/pypots/__init__.py
index 694fee76..51e916a6 100644
--- a/pypots/__init__.py
+++ b/pypots/__init__.py
@@ -5,5 +5,13 @@
# Created by Wenjie Du
# License: GPL-v3
-
from .__version__ import version as __version__
+
+__all__ = [
+ "data",
+ "imputation",
+ "classification",
+ "clustering",
+ "forecasting",
+ "utils",
+]
diff --git a/pypots/base.py b/pypots/base.py
index 0669285d..106f981d 100644
--- a/pypots/base.py
+++ b/pypots/base.py
@@ -4,11 +4,15 @@
# Created by Wenjie Du
# License: GLP-v3
+
+import os
from abc import ABC
import numpy as np
import torch
-from torch.utils.tensorboard import SummaryWriter
+
+from pypots.utils.logging import logger
+from pypots.utils.files import create_dir_if_not_exist
class BaseModel(ABC):
@@ -24,7 +28,7 @@ def __init__(self, device):
if torch.cuda.is_available() and torch.cuda.device_count() > 0
else "cpu"
)
- print("No given device, using default device:", self.device)
+ logger.info(f"No given device, using default device: {self.device}")
else:
self.device = device
@@ -136,21 +140,41 @@ def save_logs_to_tensorboard(self, saving_path):
# tb_summary_writer = SummaryWriter(saving_path)
# tb_summary_writer.add_custom_scalars(self.logger)
# tb_summary_writer.close()
- # print(f'Log saved successfully to {saving_path}.')
+ # logger.info(f'Log saved successfully to {saving_path}.')
- def save_model(self, saving_path):
+ def save_model(self, saving_dir, name, overwrite=False):
"""Save the model to a disk file.
+ A .pypots extension will be appended to the filename if it does not already have one.
+ Please note that this extension is not required; it only indicates that the saved model comes from the PyPOTS framework, so that people can tell it apart from other files.
+
Parameters
----------
- saving_path : str,
- The given path to save the model.
+ saving_dir : str,
+ The given directory to save the model.
+
+ name : str,
+ The file name of the model to be saved.
+
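+ overwrite : bool, default = False
+ Whether to overwrite the model file if it already exists. If False and the file exists, the saving operation is aborted.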
"""
+ name = name + ".pypots" if name.split(".")[-1] != "pypots" else name
+ saving_path = os.path.join(saving_dir, name)
+ if os.path.exists(saving_path):
+ if overwrite:
+ logger.warning(
+ f"File {saving_path} exists. Argument `overwrite` is True. Overwriting now..."
+ )
+ else:
+ logger.error(f"File {saving_path} exists. Saving operation aborted.")
+ return
try:
+ create_dir_if_not_exist(saving_dir)
torch.save(self.model, saving_path)
+ logger.info(f"Saved successfully to {saving_path}.")
except Exception as e:
- print(e)
- print(f"Saved successfully to {saving_path}.")
+ raise RuntimeError(f'{e} Failed to save the model to "{saving_path}"!')
def load_model(self, model_path):
"""Load the saved model from a disk file.
@@ -174,7 +198,7 @@ def load_model(self, model_path):
self.model = loaded_model.model
except Exception as e:
raise e
- print(f"Model loaded successfully from {model_path}.")
+ logger.info(f"Model loaded successfully from {model_path}.")
class BaseNNModel(BaseModel):
@@ -202,6 +226,6 @@ def __init__(
def _print_model_size(self):
"""Print the number of trainable parameters in the initialized NN model."""
num_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
- print(
+ logger.info(
f"Model initialized successfully. Number of the trainable parameters: {num_params}"
)
diff --git a/pypots/classification/base.py b/pypots/classification/base.py
index 4ee02af0..54d40889 100644
--- a/pypots/classification/base.py
+++ b/pypots/classification/base.py
@@ -12,6 +12,7 @@
import torch
from pypots.base import BaseModel, BaseNNModel
+from pypots.utils.logging import logger
class BaseClassifier(BaseModel):
@@ -116,12 +117,12 @@ def _train_model(self, training_loader, val_loader=None):
mean_val_loss = np.mean(epoch_val_loss_collector)
self.logger["validating_loss"].append(mean_val_loss)
- print(
+ logger.info(
f"epoch {epoch}: training loss {mean_train_loss:.4f}, validating loss {mean_val_loss:.4f}"
)
mean_loss = mean_val_loss
else:
- print(f"epoch {epoch}: training loss {mean_train_loss:.4f}")
+ logger.info(f"epoch {epoch}: training loss {mean_train_loss:.4f}")
mean_loss = mean_train_loss
if mean_loss < self.best_loss:
@@ -131,12 +132,12 @@ def _train_model(self, training_loader, val_loader=None):
else:
self.patience -= 1
if self.patience == 0:
- print(
+ logger.info(
"Exceeded the training patience. Terminating the training procedure..."
)
break
except Exception as e:
- print(f"Exception: {e}")
+ logger.info(f"Exception: {e}")
if self.best_model_dict is None:
raise RuntimeError(
"Training got interrupted. Model was not get trained. Please try fit() again."
@@ -151,4 +152,4 @@ def _train_model(self, training_loader, val_loader=None):
if np.equal(self.best_loss, float("inf")):
raise ValueError("Something is wrong. best_loss is Nan after training.")
- print("Finished training.")
+ logger.info("Finished training.")
diff --git a/pypots/classification/raindrop.py b/pypots/classification/raindrop.py
index 1772c165..d63f0560 100644
--- a/pypots/classification/raindrop.py
+++ b/pypots/classification/raindrop.py
@@ -27,6 +27,10 @@
from torch.nn.parameter import Parameter
from torch.utils.data import DataLoader
+from pypots.classification.base import BaseNNClassifier
+from pypots.data.dataset_for_grud import DatasetForGRUD
+from pypots.utils.logging import logger
+
try:
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import glorot
@@ -35,15 +39,12 @@
from torch_scatter import scatter
from torch_sparse import SparseTensor
except ImportError as e:
- print(
+ logger.error(
f"{e}\n"
"torch_geometric is missing, "
"please install it with 'pip install torch_geometric' or 'conda install -c pyg pyg'"
)
-from pypots.classification.base import BaseNNClassifier
-from pypots.data.dataset_for_grud import DatasetForGRUD
-
class PositionalEncodingTF(nn.Module):
"""Generate positional encoding according to time information."""
@@ -96,7 +97,7 @@ def __init__(
edge_dim: Optional[int] = None,
bias: bool = True,
root_weight: bool = True,
- **kwargs
+ **kwargs,
):
kwargs.setdefault("aggr", "add")
super().__init__(node_dim=0, **kwargs)
diff --git a/pypots/clustering/__init__.py b/pypots/clustering/__init__.py
index b8fb3b8e..0ad39300 100644
--- a/pypots/clustering/__init__.py
+++ b/pypots/clustering/__init__.py
@@ -8,4 +8,7 @@
from pypots.clustering.crli import CRLI
from pypots.clustering.vader import VaDER
-__all__ = ["CRLI", "VaDER"]
+__all__ = [
+ "CRLI",
+ "VaDER",
+]
diff --git a/pypots/clustering/base.py b/pypots/clustering/base.py
index e7420521..30f69f49 100644
--- a/pypots/clustering/base.py
+++ b/pypots/clustering/base.py
@@ -12,6 +12,7 @@
import torch
from pypots.base import BaseModel, BaseNNModel
+from pypots.utils.logging import logger
class BaseClusterer(BaseModel):
@@ -110,12 +111,12 @@ def _train_model(self, training_loader, val_loader=None):
mean_val_loss = np.mean(epoch_val_loss_collector)
self.logger["validating_loss"].append(mean_val_loss)
- print(
+ logger.info(
f"epoch {epoch}: training loss {mean_train_loss:.4f}, validating loss {mean_val_loss:.4f}"
)
mean_loss = mean_val_loss
else:
- print(f"epoch {epoch}: training loss {mean_train_loss:.4f}")
+ logger.info(f"epoch {epoch}: training loss {mean_train_loss:.4f}")
mean_loss = mean_train_loss
if mean_loss < self.best_loss:
@@ -125,12 +126,12 @@ def _train_model(self, training_loader, val_loader=None):
else:
self.patience -= 1
if self.patience == 0:
- print(
+ logger.info(
"Exceeded the training patience. Terminating the training procedure..."
)
break
except Exception as e:
- print(f"Exception: {e}")
+ logger.info(f"Exception: {e}")
if self.best_model_dict is None:
raise RuntimeError(
"Training got interrupted. Model was not get trained. Please try fit() again."
@@ -145,4 +146,4 @@ def _train_model(self, training_loader, val_loader=None):
if np.equal(self.best_loss, float("inf")):
raise ValueError("Something is wrong. best_loss is Nan after training.")
- print("Finished training.")
+ logger.info("Finished training.")
diff --git a/pypots/clustering/crli.py b/pypots/clustering/crli.py
index 1903b515..f6305a41 100644
--- a/pypots/clustering/crli.py
+++ b/pypots/clustering/crli.py
@@ -16,6 +16,7 @@
from pypots.clustering.base import BaseNNClusterer
from pypots.data.dataset_for_grud import DatasetForGRUD
+from pypots.utils.logging import logger
from pypots.utils.metrics import cal_mse
RNN_CELL = {
@@ -437,7 +438,7 @@ def _train_model(self, training_loader, val_loader=None):
) # mean training loss of the current epoch
self.logger["training_loss_generator"].append(mean_train_G_loss)
self.logger["training_loss_discriminator"].append(mean_train_D_loss)
- print(
+ logger.info(
f"epoch {epoch}: "
f"training loss_generator {mean_train_G_loss:.4f}, "
f"train loss_discriminator {mean_train_D_loss:.4f}"
@@ -451,12 +452,12 @@ def _train_model(self, training_loader, val_loader=None):
else:
self.patience -= 1
if self.patience == 0:
- print(
+ logger.info(
"Exceeded the training patience. Terminating the training procedure..."
)
break
except Exception as e:
- print(f"Exception: {e}")
+ logger.info(f"Exception: {e}")
if self.best_model_dict is None:
raise RuntimeError(
"Training got interrupted. Model was not get trained. Please try fit() again."
@@ -471,7 +472,7 @@ def _train_model(self, training_loader, val_loader=None):
if np.equal(self.best_loss, float("inf")):
raise ValueError("Something is wrong. best_loss is Nan after training.")
- print("Finished training.")
+ logger.info("Finished training.")
def cluster(self, X):
X = self.check_input(self.n_steps, self.n_features, X)
diff --git a/pypots/clustering/vader.py b/pypots/clustering/vader.py
index 49d15d5e..7922daac 100644
--- a/pypots/clustering/vader.py
+++ b/pypots/clustering/vader.py
@@ -21,6 +21,7 @@
from pypots.clustering.base import BaseNNClusterer
from pypots.data.dataset_for_grud import DatasetForGRUD
+from pypots.utils.logging import logger
from pypots.utils.metrics import cal_mse
@@ -478,12 +479,12 @@ def _train_model(self, training_loader, val_loader=None):
mean_val_loss = np.mean(epoch_val_loss_collector)
self.logger["validating_loss"].append(mean_val_loss)
- print(
+ logger.info(
f"epoch {epoch}: training loss {mean_train_loss:.4f}, validating loss {mean_val_loss:.4f}"
)
mean_loss = mean_val_loss
else:
- print(f"epoch {epoch}: training loss {mean_train_loss:.4f}")
+ logger.info(f"epoch {epoch}: training loss {mean_train_loss:.4f}")
mean_loss = mean_train_loss
if mean_loss < self.best_loss:
@@ -493,12 +494,12 @@ def _train_model(self, training_loader, val_loader=None):
else:
self.patience -= 1
if self.patience == 0:
- print(
+ logger.info(
"Exceeded the training patience. Terminating the training procedure..."
)
break
except Exception as e:
- print(f"Exception: {e}")
+ logger.info(f"Exception: {e}")
if self.best_model_dict is None:
raise RuntimeError(
"Training got interrupted. Model was not get trained. Please try fit() again."
@@ -513,7 +514,7 @@ def _train_model(self, training_loader, val_loader=None):
if np.equal(self.best_loss, float("inf")):
raise ValueError("Something is wrong. best_loss is Nan after training.")
- print("Finished training.")
+ logger.info("Finished training.")
def cluster(self, X):
X = self.check_input(self.n_steps, self.n_features, X)
diff --git a/pypots/data/load_specific_datasets.py b/pypots/data/load_specific_datasets.py
index 9000ad92..7a4cf366 100644
--- a/pypots/data/load_specific_datasets.py
+++ b/pypots/data/load_specific_datasets.py
@@ -7,6 +7,7 @@
import pandas as pd
import tsdb
+from pypots.utils.logging import logger
SUPPORTED_DATASETS = [
"physionet_2012",
@@ -80,7 +81,7 @@ def load_specific_dataset(dataset_name, use_cache=True):
e.g. standardizing and splitting.
"""
- print(
+ logger.info(
f"Loading the dataset {dataset_name} with TSDB (https://github.com/WenjieDu/Time_Series_Database)..."
)
assert dataset_name in SUPPORTED_DATASETS, (
@@ -89,7 +90,7 @@ def load_specific_dataset(dataset_name, use_cache=True):
f"please create an issue on GitHub "
f"https://github.com/WenjieDu/PyPOTS/issues"
)
- print(f"Starting preprocessing {dataset_name}...")
+ logger.info(f"Starting preprocessing {dataset_name}...")
data = tsdb.load_dataset(dataset_name, use_cache)
data = PREPROCESSING[dataset_name](data)
return data
diff --git a/pypots/forecasting/__init__.py b/pypots/forecasting/__init__.py
index c28fe0fd..a43138f2 100644
--- a/pypots/forecasting/__init__.py
+++ b/pypots/forecasting/__init__.py
@@ -7,4 +7,6 @@
from pypots.forecasting.bttf import BTTF
-__all__ = ["BTTF"]
+__all__ = [
+ "BTTF",
+]
diff --git a/pypots/forecasting/base.py b/pypots/forecasting/base.py
index ce930fed..64beadde 100644
--- a/pypots/forecasting/base.py
+++ b/pypots/forecasting/base.py
@@ -12,7 +12,7 @@
import torch
from pypots.base import BaseModel, BaseNNModel
-
+from pypots.utils.logging import logger
class BaseForecaster(BaseModel):
"""Abstract class for all forecasting models."""
@@ -102,12 +102,12 @@ def _train_model(self, training_loader, val_loader=None):
mean_val_loss = np.mean(epoch_val_loss_collector)
self.logger["validating_loss"].append(mean_val_loss)
- print(
+ logger.info(
f"epoch {epoch}: training loss {mean_train_loss:.4f}, validating loss {mean_val_loss:.4f}"
)
mean_loss = mean_val_loss
else:
- print(f"epoch {epoch}: training loss {mean_train_loss:.4f}")
+ logger.info(f"epoch {epoch}: training loss {mean_train_loss:.4f}")
mean_loss = mean_train_loss
if mean_loss < self.best_loss:
@@ -117,12 +117,12 @@ def _train_model(self, training_loader, val_loader=None):
else:
self.patience -= 1
if self.patience == 0:
- print(
+ logger.info(
"Exceeded the training patience. Terminating the training procedure..."
)
break
except Exception as e:
- print(f"Exception: {e}")
+ logger.info(f"Exception: {e}")
if self.best_model_dict is None:
raise RuntimeError(
"Training got interrupted. Model was not get trained. Please try fit() again."
@@ -137,4 +137,4 @@ def _train_model(self, training_loader, val_loader=None):
if np.equal(self.best_loss, float("inf")):
raise ValueError("Something is wrong. best_loss is Nan after training.")
- print("Finished training.")
+ logger.info("Finished training.")
diff --git a/pypots/forecasting/bttf.py b/pypots/forecasting/bttf.py
index e119783c..03711d5f 100644
--- a/pypots/forecasting/bttf.py
+++ b/pypots/forecasting/bttf.py
@@ -22,6 +22,7 @@
from scipy.stats import wishart
from pypots.forecasting.base import BaseForecaster
+from pypots.utils.logging import logger
def mvnrnd_pre(mu, Lambda):
@@ -259,9 +260,9 @@ def _BTTF(
temp_hat += tensor_hat[pos_test]
if (it + 1) % show_iter == 0 and it < burn_iter:
# temp_hat = temp_hat / show_iter
- # print('Iter: {}'.format(it + 1))
- # print('MAPE: {:.6}'.format(compute_mape(dense_test, temp_hat)))
- # print('RMSE: {:.6}'.format(compute_rmse(dense_test, temp_hat)))
+ # logger.info('Iter: {}'.format(it + 1))
+ # logger.info('MAPE: {:.6}'.format(compute_mape(dense_test, temp_hat)))
+ # logger.info('RMSE: {:.6}'.format(compute_rmse(dense_test, temp_hat)))
temp_hat = np.zeros(len(pos_test[0]))
if it + 1 > burn_iter:
U_plus[:, :, it - burn_iter] = U
@@ -274,8 +275,8 @@ def _BTTF(
X_plus[:, :, it - burn_iter] = X0
tensor_new_plus += np.einsum("is, js, ts -> ijt", U, V, X0[-multi_step:, :])
tensor_hat = tensor_hat_plus / gibbs_iter
- # print('Imputation MAPE: {:.6}'.format(compute_mape(dense_test, tensor_hat[:, :, : dim3][pos_test])))
- # print('Imputation RMSE: {:.6}'.format(compute_rmse(dense_test, tensor_hat[:, :, : dim3][pos_test])))
+ # logger.info('Imputation MAPE: {:.6}'.format(compute_mape(dense_test, tensor_hat[:, :, : dim3][pos_test])))
+ # logger.info('Imputation RMSE: {:.6}'.format(compute_rmse(dense_test, tensor_hat[:, :, : dim3][pos_test])))
tensor_hat = np.append(tensor_hat, tensor_new_plus / gibbs_iter, axis=2)
tensor_hat[tensor_hat < 0] = 0
diff --git a/pypots/imputation/base.py b/pypots/imputation/base.py
index c40c09d0..a7290e14 100644
--- a/pypots/imputation/base.py
+++ b/pypots/imputation/base.py
@@ -13,6 +13,7 @@
from pypots.base import BaseModel, BaseNNModel
from pypots.utils.metrics import cal_mae
+from pypots.utils.logging import logger
try:
import nni
@@ -121,12 +122,12 @@ def _train_model(
imputation_collector, val_X_intact, val_indicating_mask
)
self.logger["validating_loss"].append(mean_val_loss)
- print(
+ logger.info(
f"epoch {epoch}: training loss {mean_train_loss:.4f}, validating loss {mean_val_loss:.4f}"
)
mean_loss = mean_val_loss
else:
- print(f"epoch {epoch}: training loss {mean_train_loss:.4f}")
+ logger.info(f"epoch {epoch}: training loss {mean_train_loss:.4f}")
mean_loss = mean_train_loss
if mean_loss < self.best_loss:
@@ -142,13 +143,13 @@ def _train_model(
nni.report_final_result(self.best_loss)
if self.patience == 0:
- print(
+ logger.info(
"Exceeded the training patience. Terminating the training procedure..."
)
break
except Exception as e:
- print(f"Exception: {e}")
+ logger.info(f"Exception: {e}")
if self.best_model_dict is None:
raise RuntimeError(
"Training got interrupted. Model was not get trained. Please try fit() again."
@@ -163,4 +164,4 @@ def _train_model(
if np.equal(self.best_loss.item(), float("inf")):
raise ValueError("Something is wrong. best_loss is Nan after training.")
- print("Finished training.")
+ logger.info("Finished training.")
diff --git a/pypots/tests/test_classification.py b/pypots/tests/test_classification.py
index 9f283ab2..8148980b 100644
--- a/pypots/tests/test_classification.py
+++ b/pypots/tests/test_classification.py
@@ -10,6 +10,7 @@
from pypots.classification import BRITS, GRUD, Raindrop
from pypots.tests.unified_data_for_test import DATA
from pypots.utils.metrics import cal_binary_classification_metrics
+from pypots.utils.logging import logger
EPOCHS = 5
@@ -22,7 +23,7 @@ def setUp(self) -> None:
self.val_y = DATA["val_y"]
self.test_X = DATA["test_X"]
self.test_y = DATA["test_y"]
- print("Running test cases for BRITS...")
+ logger.info("Running test cases for BRITS...")
self.brits = BRITS(
DATA["n_steps"],
DATA["n_features"],
@@ -48,7 +49,7 @@ def test_parameters(self):
def test_classify(self):
predictions = self.brits.classify(self.test_X)
metrics = cal_binary_classification_metrics(predictions, self.test_y)
- print(
+ logger.info(
f'ROC_AUC: {metrics["roc_auc"]}, \n'
f'PR_AUC: {metrics["pr_auc"]},\n'
f'F1: {metrics["f1"]},\n'
@@ -66,7 +67,7 @@ def setUp(self) -> None:
self.val_y = DATA["val_y"]
self.test_X = DATA["test_X"]
self.test_y = DATA["test_y"]
- print("Running test cases for GRUD...")
+ logger.info("Running test cases for GRUD...")
self.grud = GRUD(
DATA["n_steps"],
DATA["n_features"],
@@ -92,7 +93,7 @@ def test_parameters(self):
def test_classify(self):
predictions = self.grud.classify(self.test_X)
metrics = cal_binary_classification_metrics(predictions, self.test_y)
- print(
+ logger.info(
f'ROC_AUC: {metrics["roc_auc"]}, \n'
f'PR_AUC: {metrics["pr_auc"]},\n'
f'F1: {metrics["f1"]},\n'
@@ -110,7 +111,7 @@ def setUp(self) -> None:
self.val_y = DATA["val_y"]
self.test_X = DATA["test_X"]
self.test_y = DATA["test_y"]
- print("Running test cases for Raindrop...")
+ logger.info("Running test cases for Raindrop...")
self.raindrop = Raindrop(
DATA["n_features"],
2,
@@ -146,7 +147,7 @@ def test_parameters(self):
def test_classify(self):
predictions = self.raindrop.classify(self.test_X)
metrics = cal_binary_classification_metrics(predictions, self.test_y)
- print(
+ logger.info(
f'ROC_AUC: {metrics["roc_auc"]}, \n'
f'PR_AUC: {metrics["pr_auc"]},\n'
f'F1: {metrics["f1"]},\n'
diff --git a/pypots/tests/test_clustering.py b/pypots/tests/test_clustering.py
index 65f59d68..ce22c64a 100644
--- a/pypots/tests/test_clustering.py
+++ b/pypots/tests/test_clustering.py
@@ -11,6 +11,7 @@
import numpy as np
from pypots.clustering import VaDER, CRLI
+from pypots.utils.logging import logger
from pypots.tests.unified_data_for_test import DATA
from pypots.utils.metrics import cal_rand_index, cal_cluster_purity
@@ -21,7 +22,7 @@ class TestCRLI(unittest.TestCase):
def setUp(self) -> None:
self.train_X = DATA["train_X"]
self.train_y = DATA["train_y"]
- print("Running test cases for CRLI...")
+ logger.info("Running test cases for CRLI...")
self.crli = CRLI(
n_steps=DATA["n_steps"],
n_features=DATA["n_features"],
@@ -50,14 +51,14 @@ def test_cluster(self):
clustering = self.crli.cluster(self.train_X)
RI = cal_rand_index(clustering, self.train_y)
CP = cal_cluster_purity(clustering, self.train_y)
- print(f"RI: {RI}\nCP: {CP}")
+ logger.info(f"RI: {RI}\nCP: {CP}")
class TestVaDER(unittest.TestCase):
def setUp(self) -> None:
self.train_X = DATA["train_X"]
self.train_y = DATA["train_y"]
- print("Running test cases for VaDER...")
+ logger.info("Running test cases for VaDER...")
self.vader = VaDER(
n_steps=DATA["n_steps"],
n_features=DATA["n_features"],
@@ -87,9 +88,9 @@ def test_cluster(self):
clustering = self.vader.cluster(self.train_X)
RI = cal_rand_index(clustering, self.train_y)
CP = cal_cluster_purity(clustering, self.train_y)
- print(f"RI: {RI}\nCP: {CP}")
+ logger.info(f"RI: {RI}\nCP: {CP}")
except np.linalg.LinAlgError as e:
- print(
+ logger.info(
f"{e}\n"
"Got singular matrix, please try to retrain the model to fix this"
)
diff --git a/pypots/tests/test_forecasting.py b/pypots/tests/test_forecasting.py
index 74a9fb60..27734a68 100644
--- a/pypots/tests/test_forecasting.py
+++ b/pypots/tests/test_forecasting.py
@@ -11,6 +11,7 @@
from pypots.forecasting import BTTF
from pypots.tests.unified_data_for_test import gene_random_walk_data
+from pypots.utils.logging import logger
from pypots.utils.metrics import cal_mae
EPOCHS = 5
@@ -22,7 +23,7 @@ def setUp(self) -> None:
self.test_X = DATA["test_X"]
self.test_X_intact = DATA["test_X_intact"]
self.test_X_for_input = self.test_X[:, :100]
- print("Running test cases for BTTF...")
+ logger.info("Running test cases for BTTF...")
self.bttf = BTTF(
100,
10,
@@ -37,7 +38,8 @@ def setUp(self) -> None:
def test_forecasting(self):
predictions = self.bttf.forecast(self.test_X_for_input)
mae = cal_mae(predictions, self.test_X_intact[:, 100:])
- print(f"prediction MAE: {mae}")
+ logger.info(f"prediction MAE: {mae}")
- if __name__ == "__main__":
- unittest.main()
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/pypots/tests/test_imputation.py b/pypots/tests/test_imputation.py
index 4219aa4a..957a4d34 100644
--- a/pypots/tests/test_imputation.py
+++ b/pypots/tests/test_imputation.py
@@ -18,6 +18,7 @@
)
from pypots.tests.unified_data_for_test import DATA
from pypots.utils.metrics import cal_mae
+from pypots.utils.logging import logger
EPOCH = 5
@@ -29,7 +30,7 @@ def setUp(self) -> None:
self.test_X = DATA["test_X"]
self.test_X_intact = DATA["test_X_intact"]
self.test_X_indicating_mask = DATA["test_X_indicating_mask"]
- print("Running test cases for SAITS...")
+ logger.info("Running test cases for SAITS...")
self.saits = SAITS(
DATA["n_steps"],
DATA["n_features"],
@@ -63,7 +64,7 @@ def test_impute(self):
imputed_X
).any(), "Output still has missing values after running impute()."
test_MAE = cal_mae(imputed_X, self.test_X_intact, self.test_X_indicating_mask)
- print(f"SAITS test_MAE: {test_MAE}")
+ logger.info(f"SAITS test_MAE: {test_MAE}")
class TestTransformer(unittest.TestCase):
@@ -73,7 +74,7 @@ def setUp(self) -> None:
self.test_X = DATA["test_X"]
self.test_X_intact = DATA["test_X_intact"]
self.test_X_indicating_mask = DATA["test_X_indicating_mask"]
- print("Running test cases for Transformer...")
+ logger.info("Running test cases for Transformer...")
self.transformer = Transformer(
DATA["n_steps"],
DATA["n_features"],
@@ -110,7 +111,7 @@ def test_impute(self):
imputed_X
).any(), "Output still has missing values after running impute()."
test_MAE = cal_mae(imputed_X, self.test_X_intact, self.test_X_indicating_mask)
- print(f"Transformer test_MAE: {test_MAE}")
+ logger.info(f"Transformer test_MAE: {test_MAE}")
class TestBRITS(unittest.TestCase):
@@ -120,7 +121,7 @@ def setUp(self) -> None:
self.test_X = DATA["test_X"]
self.test_X_intact = DATA["test_X_intact"]
self.test_X_indicating_mask = DATA["test_X_indicating_mask"]
- print("Running test cases for BRITS...")
+ logger.info("Running test cases for BRITS...")
self.brits = BRITS(DATA["n_steps"], DATA["n_features"], 256, epochs=EPOCH)
self.brits.fit(self.train_X, self.val_X)
@@ -143,7 +144,7 @@ def test_impute(self):
imputed_X
).any(), "Output still has missing values after running impute()."
test_MAE = cal_mae(imputed_X, self.test_X_intact, self.test_X_indicating_mask)
- print(f"BRITS test_MAE: {test_MAE}")
+ logger.info(f"BRITS test_MAE: {test_MAE}")
class TestLOCF(unittest.TestCase):
@@ -153,7 +154,7 @@ def setUp(self) -> None:
self.test_X = DATA["test_X"]
self.test_X_intact = DATA["test_X_intact"]
self.test_X_indicating_mask = DATA["test_X_indicating_mask"]
- print("Running test cases for LOCF...")
+ logger.info("Running test cases for LOCF...")
self.locf = LOCF(nan=0)
def test_parameters(self):
@@ -167,7 +168,7 @@ def test_impute(self):
test_MAE = cal_mae(
test_X_imputed, self.test_X_intact, self.test_X_indicating_mask
)
- print(f"LOCF test_MAE: {test_MAE}")
+ logger.info(f"LOCF test_MAE: {test_MAE}")
if __name__ == "__main__":
diff --git a/pypots/tests/test_logging.py b/pypots/tests/test_logging.py
new file mode 100644
index 00000000..3ebc3fca
--- /dev/null
+++ b/pypots/tests/test_logging.py
@@ -0,0 +1,43 @@
+"""
+Test cases of logging.
+"""
+import os
+import shutil
+import unittest
+
+from pypots.utils.logging import Logger
+
+
+# Created by Wenjie Du
+# License: GPL-v3
+
+
+class TestLogger(unittest.TestCase):
+ def setUp(self) -> None:
+ self.logger_creator = Logger(name="PyPOTS testing log", logging_level="debug")
+ self.logger = self.logger_creator.logger
+
+ def test_different_level_logging(self):
+ self.logger.debug('debug')
+ self.logger.info('info')
+ self.logger.warning('warning')
+ self.logger.error('error')
+
+ def test_changing_level(self):
+ self.logger_creator.set_level('info')
+ assert self.logger.level == 20, f'the level of logger is {self.logger.level}, not INFO'
+ self.logger_creator.set_level('warning')
+ assert self.logger.level == 30, f'the level of logger is {self.logger.level}, not WARNING'
+ self.logger_creator.set_level('error')
+ assert self.logger.level == 40, f'the level of logger is {self.logger.level}, not ERROR'
+ self.logger_creator.set_level('debug')
+ assert self.logger.level == 10, f'the level of logger is {self.logger.level}, not DEBUG'
+
+ def test_saving_log_into_file(self):
+ self.logger_creator.set_saving_path('test_log', 'testing.log')
+ assert os.path.exists('test_log/testing.log')
+ shutil.rmtree('test_log', ignore_errors=True)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/pypots/utils/__init__.py b/pypots/utils/__init__.py
index 2d6f3394..8fc1eb72 100644
--- a/pypots/utils/__init__.py
+++ b/pypots/utils/__init__.py
@@ -4,3 +4,9 @@
# Created by Wenjie Du
# License: GPL-v3
+
+__all__ = [
+ "files",
+ "logging",
+ "metrics",
+]
diff --git a/pypots/utils/files.py b/pypots/utils/files.py
new file mode 100644
index 00000000..6b2344f3
--- /dev/null
+++ b/pypots/utils/files.py
@@ -0,0 +1,49 @@
+"""
+Utilities for handling files and directories.
+"""
+
+# Created by Wenjie Du
+# License: GPL-v3
+
+import os
+
+from pypots.utils.logging import logger
+
+
+def extract_parent_dir(path):
+ """Extract the given path's parent directory.
+
+ Parameters
+ ----------
+ path : str,
+ The path for extracting.
+
+ Returns
+ -------
+ parent_dir : str
+ The path to the parent dir of the given path.
+
+ """
+ parent_dir = os.path.abspath(os.path.join(path, ".."))
+ return parent_dir
+
+
+def create_dir_if_not_exist(path, is_dir=True):
+ """Create the given directory if it doesn't exist.
+
+ Parameters
+ ----------
+ path : str,
+ The path to check.
+
+ is_dir : bool,
+ Whether the given path points to a directory. If `is_dir` is False, the path is treated as a file or an object,
+ and its parent directory is checked (and created if missing) instead.
+
+ """
+ path = extract_parent_dir(path) if not is_dir else path
+ if os.path.exists(path):
+ logger.info(f'The given directory "{path}" exists.')
+ else:
+ os.makedirs(path, exist_ok=True)
+ logger.info(f'Successfully created "{path}".')
diff --git a/pypots/utils/logging.py b/pypots/utils/logging.py
new file mode 100644
index 00000000..b97285ac
--- /dev/null
+++ b/pypots/utils/logging.py
@@ -0,0 +1,108 @@
+"""
+Configure logging here.
+"""
+
+# Created by Wenjie Du
+# License: GPL-v3
+
+import logging
+import os
+
+LEVELS = {
+ "debug": logging.DEBUG,
+ "info": logging.INFO,
+ "warning": logging.WARNING,
+ "error": logging.ERROR,
+}
+
+
+class Logger:
+ """Logging class for PyPOTS logger initialization and control."""
+
+ def __init__(
+ self,
+ name="PyPOTS running log",
+ logging_level="debug",
+ logging_format="%(asctime)s [%(levelname)s]: %(message)s",
+ ):
+ """
+ Parameters
+ ----------
+ name : str,
+ The name for the logger to be initialized.
+
+ logging_level : str,
+ The logging level of the logger, should be debug/info/warning/error.
+
+ logging_format : str,
+ Logging format of the logger.
+ """
+
+ assert (
+ logging_level in LEVELS.keys()
+ ), f"logging_level should be {list(LEVELS.keys())}, but got {logging_level}"
+
+ self.logger = logging.getLogger(name)
+ self.logging_level = LEVELS[logging_level]
+
+ self.stream_handler = logging.StreamHandler()
+ self.formatter = None
+ self.file_handler = None
+
+ self.set_level(logging_level)
+ self.set_logging_format(logging_format)
+ self.logger.propagate = False
+
+ def set_logging_format(self, logging_format):
+ self.formatter = logging.Formatter(logging_format, datefmt="%Y-%m-%d %H:%M:%S")
+ self.stream_handler.setFormatter(self.formatter)
+ self.logger.addHandler(self.stream_handler)
+ if self.file_handler is not None:
+ self.file_handler.setFormatter(self.formatter)
+ self.logger.addHandler(self.file_handler)
+
+ def set_saving_path(self, saving_dir, name, mode="a"):
+ """Set the logger's saving path. This function will enable saving logs to the specified path.
+
+ Parameters
+ ----------
+ saving_dir : str,
+ The path to the directory for logging file saving.
+
+ name : str,
+ The name of the logging file to be saved.
+
+ mode : str, default = 'a'
+ Logging file writing mode.
+
+ """
+ if not os.path.exists(saving_dir):
+ self.logger.warning(f'{saving_dir} does not exist. Creating it now...')
+ os.makedirs(saving_dir)
+ path = os.path.join(saving_dir, name)
+ self.file_handler = logging.FileHandler(path, mode=mode)
+ self.file_handler.setLevel(self.logging_level)
+ self.file_handler.setFormatter(self.formatter)
+ self.logger.addHandler(self.file_handler)
+ self.logger.info(f"Log will be saved to {path}")
+
+ def set_level(self, level):
+ """Set the logger's logging level.
+
+ Parameters
+ ----------
+ level : str,
+ The logging level of the logger, should be debug/info/warning/error.
+
+ """
+ self.logging_level = LEVELS[level]
+ self.logger.setLevel(self.logging_level)
+ if self.stream_handler is not None:
+ self.stream_handler.setLevel(self.logging_level)
+ if self.file_handler is not None:
+ self.file_handler.setLevel(self.logging_level)
+
+
+# initialize a logger for PyPOTS logging
+logger_creator = Logger()
+logger = logger_creator.logger