From 8cf7696b785647a9359fbaebd0b9d379edd547f4 Mon Sep 17 00:00:00 2001 From: Lin Zhang Date: Mon, 18 Mar 2024 13:39:17 +0100 Subject: [PATCH] fix importError of obsolete scipy.interp with scipy.interpolate --- scikitplot/metrics.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scikitplot/metrics.py b/scikitplot/metrics.py index 08ec693..df8357f 100644 --- a/scikitplot/metrics.py +++ b/scikitplot/metrics.py @@ -24,7 +24,7 @@ from sklearn.calibration import calibration_curve from sklearn.utils import deprecated -from scipy import interp +from scipy import interpolate from scikitplot.helpers import binary_ks_curve, validate_labels from scikitplot.helpers import cumulative_gain_curve @@ -281,7 +281,7 @@ def plot_roc_curve(y_true, y_probas, title='ROC Curves', # Then interpolate all ROC curves at this points mean_tpr = np.zeros_like(all_fpr) for i in range(len(classes)): - mean_tpr += interp(all_fpr, fpr[i], tpr[i]) + mean_tpr += interpolate(all_fpr, fpr[i], tpr[i]) # Finally average it and compute AUC mean_tpr /= len(classes) @@ -440,7 +440,7 @@ def plot_roc(y_true, y_probas, title='ROC Curves', # Then interpolate all ROC curves at this points mean_tpr = np.zeros_like(all_fpr) for i in range(len(classes)): - mean_tpr += interp(all_fpr, fpr_dict[i], tpr_dict[i]) + mean_tpr += interpolate(all_fpr, fpr_dict[i], tpr_dict[i]) # Finally average it and compute AUC mean_tpr /= len(classes)