-
-
Notifications
You must be signed in to change notification settings - Fork 107
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
test: add testing cases for TimeMixer;
- Loading branch information
Showing
1 changed file
with
128 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
""" | ||
Test cases for TimeMixer imputation model. | ||
""" | ||
|
||
# Created by Wenjie Du <[email protected]> | ||
# License: BSD-3-Clause | ||
|
||
|
||
import os.path | ||
import unittest | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
from pypots.imputation import TimeMixer | ||
from pypots.optim import Adam | ||
from pypots.utils.logging import logger | ||
from pypots.utils.metrics import calc_mse | ||
from tests.global_test_config import ( | ||
DATA, | ||
EPOCHS, | ||
DEVICE, | ||
TRAIN_SET, | ||
VAL_SET, | ||
TEST_SET, | ||
GENERAL_H5_TRAIN_SET_PATH, | ||
GENERAL_H5_VAL_SET_PATH, | ||
GENERAL_H5_TEST_SET_PATH, | ||
RESULT_SAVING_DIR_FOR_IMPUTATION, | ||
check_tb_and_model_checkpoints_existence, | ||
) | ||
|
||
|
||
class TestTimeMixer(unittest.TestCase): | ||
logger.info("Running tests for an imputation model TimeMixer...") | ||
|
||
# set the log and model saving path | ||
saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "TimeMixer") | ||
model_save_name = "saved_timemixer_model.pypots" | ||
|
||
# initialize an Adam optimizer | ||
optimizer = Adam(lr=0.001, weight_decay=1e-5) | ||
|
||
# initialize a TimeMixer model | ||
timemixer = TimeMixer( | ||
DATA["n_steps"], | ||
DATA["n_features"], | ||
n_layers=2, | ||
top_k=5, | ||
d_model=512, | ||
d_ffn=512, | ||
dropout=0.1, | ||
epochs=EPOCHS, | ||
saving_path=saving_path, | ||
optimizer=optimizer, | ||
device=DEVICE, | ||
) | ||
|
||
@pytest.mark.xdist_group(name="imputation-timemixer") | ||
def test_0_fit(self): | ||
self.timemixer.fit(TRAIN_SET, VAL_SET) | ||
|
||
@pytest.mark.xdist_group(name="imputation-timemixer") | ||
def test_1_impute(self): | ||
imputation_results = self.timemixer.predict(TEST_SET) | ||
assert not np.isnan( | ||
imputation_results["imputation"] | ||
).any(), "Output still has missing values after running impute()." | ||
|
||
test_MSE = calc_mse( | ||
imputation_results["imputation"], | ||
DATA["test_X_ori"], | ||
DATA["test_X_indicating_mask"], | ||
) | ||
logger.info(f"TimeMixer test_MSE: {test_MSE}") | ||
|
||
@pytest.mark.xdist_group(name="imputation-timemixer") | ||
def test_2_parameters(self): | ||
assert hasattr(self.timemixer, "model") and self.timemixer.model is not None | ||
|
||
assert ( | ||
hasattr(self.timemixer, "optimizer") | ||
and self.timemixer.optimizer is not None | ||
) | ||
|
||
assert hasattr(self.timemixer, "best_loss") | ||
self.assertNotEqual(self.timemixer.best_loss, float("inf")) | ||
|
||
assert ( | ||
hasattr(self.timemixer, "best_model_dict") | ||
and self.timemixer.best_model_dict is not None | ||
) | ||
|
||
@pytest.mark.xdist_group(name="imputation-timemixer") | ||
def test_3_saving_path(self): | ||
# whether the root saving dir exists, which should be created by save_log_into_tb_file | ||
assert os.path.exists( | ||
self.saving_path | ||
), f"file {self.saving_path} does not exist" | ||
|
||
# check if the tensorboard file and model checkpoints exist | ||
check_tb_and_model_checkpoints_existence(self.timemixer) | ||
|
||
# save the trained model into file, and check if the path exists | ||
saved_model_path = os.path.join(self.saving_path, self.model_save_name) | ||
self.timemixer.save(saved_model_path) | ||
|
||
# test loading the saved model, not necessary, but need to test | ||
self.timemixer.load(saved_model_path) | ||
|
||
@pytest.mark.xdist_group(name="imputation-timemixer") | ||
def test_4_lazy_loading(self): | ||
self.timemixer.fit(GENERAL_H5_TRAIN_SET_PATH, GENERAL_H5_VAL_SET_PATH) | ||
imputation_results = self.timemixer.predict(GENERAL_H5_TEST_SET_PATH) | ||
assert not np.isnan( | ||
imputation_results["imputation"] | ||
).any(), "Output still has missing values after running impute()." | ||
|
||
test_MSE = calc_mse( | ||
imputation_results["imputation"], | ||
DATA["test_X_ori"], | ||
DATA["test_X_indicating_mask"], | ||
) | ||
logger.info(f"Lazy-loading TimeMixer test_MSE: {test_MSE}") | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |