Skip to content

Commit

Permalink
Fix sklearn deprecation warning for parameter squared of function mea…
Browse files Browse the repository at this point in the history
…n_squared_error (#114)

Replace mean_squared_error(..., squared=False) with root_mean_squared_error

Co-authored-by: Jerzy Kamiński <[email protected]>
  • Loading branch information
anton-golubkov and jrzkaminski authored Oct 11, 2024
1 parent f521f95 commit 938753c
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 11 deletions.
6 changes: 3 additions & 3 deletions bamt/nodes/conditional_gaussian_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pandas import DataFrame
from sklearn import linear_model
from sklearn.base import clone
from sklearn.metrics import mean_squared_error as mse
from sklearn.metrics import root_mean_squared_error as rmse

from .base import BaseNode
from .schema import CondGaussParams
Expand Down Expand Up @@ -51,8 +51,8 @@ def fit_parameters(self, data: DataFrame) -> Dict[str, Dict[str, CondGaussParams
new_data[self.cont_parents].values, new_data[self.name].values
)
predicted_value = model.predict(new_data[self.cont_parents].values)
variance = mse(
new_data[self.name].values, predicted_value, squared=False
variance = rmse(
new_data[self.name].values, predicted_value
)
hycprob[str(key_comb)] = {
"variance": variance,
Expand Down
4 changes: 2 additions & 2 deletions bamt/nodes/gaussian_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
from pandas import DataFrame
from sklearn import linear_model
from sklearn.metrics import mean_squared_error as mse
from sklearn.metrics import root_mean_squared_error as rmse

from .base import BaseNode
from .schema import GaussianParams
Expand All @@ -30,7 +30,7 @@ def fit_parameters(self, data: DataFrame, **kwargs) -> GaussianParams:
if parents:
self.regressor.fit(data[parents].values, data[self.name].values, **kwargs)
predicted_value = self.regressor.predict(data[parents].values)
variance = mse(data[self.name].values, predicted_value, squared=False)
variance = rmse(data[self.name].values, predicted_value)
return {
"mean": np.nan,
"regressor_obj": self.regressor,
Expand Down
7 changes: 4 additions & 3 deletions bamt/utils/composite_utils/CompositeGeneticOperators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from golem.core.dag.graph_utils import ordered_subnodes_hierarchy
from numpy import std, mean, log
from scipy.stats import norm
from sklearn.metrics import mean_squared_error
from sklearn.metrics import root_mean_squared_error
from sklearn.model_selection import train_test_split
import numpy as np
from .CompositeModel import CompositeModel
Expand Down Expand Up @@ -168,8 +168,9 @@ def composite_metric(graph: CompositeModel, data: pd.DataFrame):

if node_type == "cont":
predictions = fitted_model.predict(features_test)
mse = mean_squared_error(target_test, predictions, squared=False) + 1e-7
score += norm.logpdf(target_test, loc=predictions, scale=mse).sum()
rmse = root_mean_squared_error(target_test, predictions, squared=False) + 1e-7
score += norm.logpdf(target_test, loc=predictions, scale=rmse).sum()

else:
predict_proba = fitted_model.predict_proba(features_test)
probas = np.maximum(predict_proba[range(len(target_test)), [index_test_dict[x] for x in target_test]], 1e-7)
Expand Down
2 changes: 1 addition & 1 deletion docs/source/examples/learn_sampling_predict.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Used imports:
import matplotlib.pyplot as plt
from sklearn import preprocessing
from sklearn.metrics import accuracy_score, mean_squared_error
from sklearn.metrics import accuracy_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
Expand Down
4 changes: 2 additions & 2 deletions tests/MainTest.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@
# print(cont_bn.weights)
# print(cont_bn2.weights)
# print('RMSE on predicted values with continuous data: ' +
# f'{mse(cont_target, cont_predicted_values, squared=False)}')
# f'{rmse(cont_target, cont_predicted_values)}')
# print(cont_bn.get_info())
# print(cont_bn2.get_info())
# print(synth_cont_data)
Expand Down Expand Up @@ -116,7 +116,7 @@
# print(hybrid_bn2.weights)
# print(hybrid_bn3.weights)
# print('RMSE on predicted values with hybrid data: ' +
# f'{mse(hybrid_target, hybrid_predicted_values, squared=False)}')
# f'{rmse(hybrid_target, hybrid_predicted_values)}')
# print(hybrid_bn.get_info())
# print(hybrid_bn2.get_info())
# print(hybrid_bn3.get_info())
Expand Down

0 comments on commit 938753c

Please sign in to comment.