diff --git a/bamt/nodes/conditional_gaussian_node.py b/bamt/nodes/conditional_gaussian_node.py index 58801b6..97a5949 100644 --- a/bamt/nodes/conditional_gaussian_node.py +++ b/bamt/nodes/conditional_gaussian_node.py @@ -7,7 +7,7 @@ from pandas import DataFrame from sklearn import linear_model from sklearn.base import clone -from sklearn.metrics import mean_squared_error as mse +from sklearn.metrics import root_mean_squared_error as rmse from .base import BaseNode from .schema import CondGaussParams @@ -51,8 +51,8 @@ def fit_parameters(self, data: DataFrame) -> Dict[str, Dict[str, CondGaussParams new_data[self.cont_parents].values, new_data[self.name].values ) predicted_value = model.predict(new_data[self.cont_parents].values) - variance = mse( - new_data[self.name].values, predicted_value, squared=False + variance = rmse( + new_data[self.name].values, predicted_value ) hycprob[str(key_comb)] = { "variance": variance, diff --git a/bamt/nodes/gaussian_node.py b/bamt/nodes/gaussian_node.py index 15144a8..5274ae6 100644 --- a/bamt/nodes/gaussian_node.py +++ b/bamt/nodes/gaussian_node.py @@ -5,7 +5,7 @@ import numpy as np from pandas import DataFrame from sklearn import linear_model -from sklearn.metrics import mean_squared_error as mse +from sklearn.metrics import root_mean_squared_error as rmse from .base import BaseNode from .schema import GaussianParams @@ -30,7 +30,7 @@ def fit_parameters(self, data: DataFrame, **kwargs) -> GaussianParams: if parents: self.regressor.fit(data[parents].values, data[self.name].values, **kwargs) predicted_value = self.regressor.predict(data[parents].values) - variance = mse(data[self.name].values, predicted_value, squared=False) + variance = rmse(data[self.name].values, predicted_value) return { "mean": np.nan, "regressor_obj": self.regressor, diff --git a/bamt/utils/composite_utils/CompositeGeneticOperators.py b/bamt/utils/composite_utils/CompositeGeneticOperators.py index 8de9d8d..6fc53ae 100644 --- a/bamt/utils/composite_utils/CompositeGeneticOperators.py +++ b/bamt/utils/composite_utils/CompositeGeneticOperators.py @@ -5,7 +5,7 @@ from golem.core.dag.graph_utils import ordered_subnodes_hierarchy from numpy import std, mean, log from scipy.stats import norm -from sklearn.metrics import mean_squared_error +from sklearn.metrics import root_mean_squared_error from sklearn.model_selection import train_test_split import numpy as np from .CompositeModel import CompositeModel @@ -168,8 +168,9 @@ def composite_metric(graph: CompositeModel, data: pd.DataFrame): if node_type == "cont": predictions = fitted_model.predict(features_test) - mse = mean_squared_error(target_test, predictions, squared=False) + 1e-7 - score += norm.logpdf(target_test, loc=predictions, scale=mse).sum() + rmse = root_mean_squared_error(target_test, predictions, squared=False) + 1e-7 + score += norm.logpdf(target_test, loc=predictions, scale=rmse).sum() + else: predict_proba = fitted_model.predict_proba(features_test) probas = np.maximum(predict_proba[range(len(target_test)), [index_test_dict[x] for x in target_test]], 1e-7) diff --git a/docs/source/examples/learn_sampling_predict.rst b/docs/source/examples/learn_sampling_predict.rst index d6b2fd7..938eb0e 100644 --- a/docs/source/examples/learn_sampling_predict.rst +++ b/docs/source/examples/learn_sampling_predict.rst @@ -13,7 +13,7 @@ Used imports: import matplotlib.pyplot as plt from sklearn import preprocessing - from sklearn.metrics import accuracy_score, mean_squared_error + from sklearn.metrics import accuracy_score from sklearn.ensemble import RandomForestClassifier from sklearn.neighbors import KNeighborsClassifier from sklearn.tree import DecisionTreeClassifier diff --git a/tests/MainTest.py b/tests/MainTest.py index 4796fe6..2b229d3 100644 --- a/tests/MainTest.py +++ b/tests/MainTest.py @@ -82,7 +82,7 @@ # print(cont_bn.weights) # print(cont_bn2.weights) # print('RMSE on predicted values with continuous data: ' + -# f'{mse(cont_target, cont_predicted_values, squared=False)}') +# f'{rmse(cont_target, cont_predicted_values)}') # print(cont_bn.get_info()) # print(cont_bn2.get_info()) # print(synth_cont_data) @@ -116,7 +116,7 @@ # print(hybrid_bn2.weights) # print(hybrid_bn3.weights) # print('RMSE on predicted values with hybrid data: ' + -# f'{mse(hybrid_target, hybrid_predicted_values, squared=False)}') +# f'{rmse(hybrid_target, hybrid_predicted_values)}') # print(hybrid_bn.get_info()) # print(hybrid_bn2.get_info()) # print(hybrid_bn3.get_info())