diff --git a/bibmon/__init__.py b/bibmon/__init__.py index 3dd3ee7..affb6c0 100644 --- a/bibmon/__init__.py +++ b/bibmon/__init__.py @@ -5,11 +5,11 @@ from ._sklearn_regressor import sklearnRegressor from ._preprocess import PreProcess from ._load_data import load_tennessee_eastman, load_real_data -from ._bibmon_tools import train_val_test_split, complete_analysis, comparative_table, spearmanr_dendrogram, create_df_with_dates, create_df_with_noise, align_dfs_by_rows +from ._bibmon_tools import train_val_test_split, complete_analysis, comparative_table, spearmanr_dendrogram, create_df_with_dates, create_df_with_noise, align_dfs_by_rows, load_model __all__ = ['Autoencoder','PCA','ESN','SBM', 'sklearnRegressor', 'PreProcess', 'load_tennessee_eastman', 'load_real_data', 'train_val_test_split', 'complete_analysis', 'comparative_table', 'spearmanr_dendrogram', 'create_df_with_dates', - 'create_df_with_noise', 'align_dfs_by_rows'] + 'create_df_with_noise', 'align_dfs_by_rows', 'load_model'] diff --git a/bibmon/_bibmon_tools.py b/bibmon/_bibmon_tools.py index fd598d9..8e75bfb 100644 --- a/bibmon/_bibmon_tools.py +++ b/bibmon/_bibmon_tools.py @@ -2,6 +2,7 @@ import pandas as pd from datetime import datetime import matplotlib.pyplot as plt +import pickle ############################################################################### @@ -207,7 +208,9 @@ def complete_analysis (model, X_train, X_validation, X_test, count_limit = 1, count_window_size = 0, fault_start = None, - fault_end = None): + fault_end = None, + save_model=False, + model_filename=None): """ Performs a complete monitoring analysis, with train, validation, and test. @@ -262,6 +265,10 @@ def complete_analysis (model, X_train, X_validation, X_test, Start timestamp of the fault. fault_end: string, optional End timestamp of the fault. + save_model: bool, optional + If True, saves the trained model to a file. + model_filename: string, optional + Name of the file to save the model. If None, uses a default name. """ fig, ax = plt.subplots(3,2, figsize = (15,12)) @@ -341,10 +348,45 @@ def complete_analysis (model, X_train, X_validation, X_test, ax[2,1].axvline(datetime.strptime(str(fault_end), '%Y-%m-%d %H:%M:%S'), ls = '--') - fig.tight_layout(); + fig.tight_layout() + + ######## Saving the model ######## + if save_model: + if model_filename is None: + model_filename = f'model_{datetime.now().strftime("%Y%m%d_%H%M%S")}.pkl' + with open(model_filename, 'wb') as f: + pickle.dump(model, f) + print(f"Model saved as {model_filename}") + + return model ############################################################################## +def load_model(model_filename): + """ + Loads a trained model from a .pkl file and returns the model. + + Parameters + ---------- + model_filename: string + The name of the .pkl file that contains the saved model. + + Returns + ------- + model: The trained model that was saved. + """ + try: + with open(model_filename, 'rb') as f: + model = pickle.load(f) + print(f"Model loaded from {model_filename}") + return model + except FileNotFoundError: + print(f"Error: The file {model_filename} was not found.") + except Exception as e: + print(f"An error occurred while loading the model: {e}") + +############################################################################## + def comparative_table (models, X_train, X_validation, X_test, Y_train = None , Y_validation = None, Y_test = None, lim_conf = 0.99, diff --git a/test/test_tools.py b/test/test_tools.py index 03f0d5d..a15ca2e 100644 --- a/test/test_tools.py +++ b/test/test_tools.py @@ -8,6 +8,8 @@ import bibmon import pandas as pd +import os +import pickle def test_complete_analysis(): @@ -48,6 +50,10 @@ def test_complete_analysis(): from sklearn.metrics import mean_absolute_error mtr = [r2_score, mean_absolute_error] + + # Filename for saving the model + + model_filename = 'test_model.pkl' # complete analysis! @@ -58,6 +64,17 @@ def test_complete_analysis(): metrics = mtr, count_window_size = 3, count_limit = 2, fault_start = '2018-01-02 06:00:00', - fault_end = '2018-01-02 09:00:00') + fault_end = '2018-01-02 09:00:00', + save_model=True, + model_filename=model_filename) - model.plot_importances() \ No newline at end of file + model.plot_importances() + + assert os.path.exists(model_filename), "Model file was not saved." + + loaded_model = bibmon.load_model(model_filename) + + assert isinstance(loaded_model, type(model)), "Loaded model is not the same type as the original model." + + if os.path.exists(model_filename): + os.remove(model_filename) \ No newline at end of file