Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance complete_analysis with model saving and add model loading functionality #61

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions bibmon/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
from ._sklearn_regressor import sklearnRegressor
from ._preprocess import PreProcess
from ._load_data import load_tennessee_eastman, load_real_data
from ._bibmon_tools import train_val_test_split, complete_analysis, comparative_table, spearmanr_dendrogram, create_df_with_dates, create_df_with_noise, align_dfs_by_rows
from ._bibmon_tools import train_val_test_split, complete_analysis, comparative_table, spearmanr_dendrogram, create_df_with_dates, create_df_with_noise, align_dfs_by_rows, load_model

__all__ = ['Autoencoder','PCA','ESN','SBM',
'sklearnRegressor', 'PreProcess',
'load_tennessee_eastman', 'load_real_data',
'train_val_test_split', 'complete_analysis', 'comparative_table',
'spearmanr_dendrogram', 'create_df_with_dates',
'create_df_with_noise', 'align_dfs_by_rows']
'create_df_with_noise', 'align_dfs_by_rows', 'load_model']
46 changes: 44 additions & 2 deletions bibmon/_bibmon_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pandas as pd
from datetime import datetime
import matplotlib.pyplot as plt
import pickle

###############################################################################

Expand Down Expand Up @@ -207,7 +208,9 @@ def complete_analysis (model, X_train, X_validation, X_test,
count_limit = 1,
count_window_size = 0,
fault_start = None,
fault_end = None):
fault_end = None,
save_model=False,
model_filename=None):
"""
Performs a complete monitoring analysis, with train, validation, and test.

Expand Down Expand Up @@ -262,6 +265,10 @@ def complete_analysis (model, X_train, X_validation, X_test,
Start timestamp of the fault.
fault_end: string, optional
End timestamp of the fault.
save_model: bool, optional
If True, saves the trained model to a file.
model_filename: string, optional
Name of the file to save the model. If None, uses a default name.
"""
fig, ax = plt.subplots(3,2, figsize = (15,12))

Expand Down Expand Up @@ -341,10 +348,45 @@ def complete_analysis (model, X_train, X_validation, X_test,
ax[2,1].axvline(datetime.strptime(str(fault_end),
'%Y-%m-%d %H:%M:%S'), ls = '--')

fig.tight_layout();
fig.tight_layout()

######## Saving the model ########
if save_model:
if model_filename is None:
model_filename = f'model_{datetime.now().strftime("%Y%m%d_%H%M%S")}.pkl'
with open(model_filename, 'wb') as f:
pickle.dump(model, f)
print(f"Model saved as {model_filename}")

return model

##############################################################################

def load_model(model_filename):
    """
    Loads a trained model from a .pkl file and returns the model.

    Parameters
    ----------
    model_filename: string
        The name of the .pkl file that contains the saved model.

    Returns
    -------
    model: object or None
        The object unpickled from the file. If the file is not found,
        or any other exception is raised while opening or unpickling
        it, an error message is printed and None is returned instead.

    Notes
    -----
    Unpickling can execute arbitrary code. Only load files that come
    from a trusted source.
    """
    # SECURITY: pickle.load runs code embedded in the file; never call
    # this function on files received from untrusted parties.
    try:
        # Keep the guarded region limited to opening and unpickling.
        with open(model_filename, 'rb') as f:
            model = pickle.load(f)
    except FileNotFoundError:
        print(f"Error: The file {model_filename} was not found.")
        return None
    except Exception as e:
        # Deliberately best-effort: report the problem and signal the
        # failure with None rather than propagating the exception.
        print(f"An error occurred while loading the model: {e}")
        return None

    print(f"Model loaded from {model_filename}")
    return model

##############################################################################

def comparative_table (models, X_train, X_validation, X_test,
Y_train = None , Y_validation = None, Y_test = None,
lim_conf = 0.99,
Expand Down
21 changes: 19 additions & 2 deletions test/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

import bibmon
import pandas as pd
import os
import pickle

def test_complete_analysis():

Expand Down Expand Up @@ -48,6 +50,10 @@ def test_complete_analysis():
from sklearn.metrics import mean_absolute_error

mtr = [r2_score, mean_absolute_error]

# Filename for saving the model

model_filename = 'test_model.pkl'

# complete analysis!

Expand All @@ -58,6 +64,17 @@ def test_complete_analysis():
metrics = mtr,
count_window_size = 3, count_limit = 2,
fault_start = '2018-01-02 06:00:00',
fault_end = '2018-01-02 09:00:00')
fault_end = '2018-01-02 09:00:00',
save_model=True,
model_filename=model_filename)

model.plot_importances()
model.plot_importances()

assert os.path.exists(model_filename), "Model file was not saved."

loaded_model = bibmon.load_model(model_filename)

assert isinstance(loaded_model, type(model)), "Loaded model is not the same type as the original model."

if os.path.exists(model_filename):
os.remove(model_filename)