diff --git a/tests/performance_estimation/CBPE/test_cbpe_metrics.py b/tests/performance_estimation/CBPE/test_cbpe_metrics.py index d1e42c2e..c2ae06cb 100644 --- a/tests/performance_estimation/CBPE/test_cbpe_metrics.py +++ b/tests/performance_estimation/CBPE/test_cbpe_metrics.py @@ -3473,6 +3473,12 @@ def test_method_logs_warning_when_lower_threshold_is_overridden_by_metric_limits pd.DataFrame( { 'key': ['[0:19999]', '[20000:39999]', '[40000:59999]'], + 'realized_roc_auc': [0.909805, 0.840071, np.nan], + 'realized_f1': [0.759170, 0.658896, np.nan], + 'realized_precision': [0.759265, 0.660188, np.nan], + 'realized_recall': [0.759149, 0.658760, np.nan], + 'realized_specificity': [0.879632, 0.829581, np.nan], + 'realized_accuracy': [0.75925, 0.65950, np.nan], 'realized_true_highstreet_card_pred_highstreet_card': [ 4912.0, 4702.0, @@ -3523,7 +3529,7 @@ def test_method_logs_warning_when_lower_threshold_is_overridden_by_metric_limits ), ] ) -def test_cbpe_for_multiclass_classification_cm_with_nans(calculator_opts, estimated, realized): # noqa: D103 +def test_cbpe_for_multiclass_classification_cm_with_nans(calculator_opts, realized): # noqa: D103 """Test Nan Handling of CM MC metric.""" reference, analysis, targets = load_synthetic_multiclass_classification_dataset() analysis = analysis.merge(targets, left_index=True, right_index=True) @@ -3537,11 +3543,13 @@ def test_cbpe_for_multiclass_classification_cm_with_nans(calculator_opts, estima y_pred='y_pred', y_true='y_true', problem_type='classification_multiclass', - metrics=['confusion_matrix'], + metrics=['roc_auc', 'f1', 'precision', 'recall', 'specificity', 'accuracy', 'confusion_matrix'], **calculator_opts, ).fit(reference) result = cbpe.estimate(analysis) - column_names = [ + column_names = [(m.name, 'realized') for m in result.metrics] + column_names = [c for c in column_names if c[0] != 'confusion_matrix'] + column_names += [ ('true_highstreet_card_pred_highstreet_card', 'realized'), ('true_highstreet_card_pred_prepaid_card', 'realized'), ('true_highstreet_card_pred_upmarket_card', 'realized'), @@ -3555,6 +3563,12 @@ def test_cbpe_for_multiclass_classification_cm_with_nans(calculator_opts, estima sut = result.filter(period='analysis').to_df()[[('chunk', 'key')] + column_names] sut.columns = [ 'key', + 'realized_roc_auc', + 'realized_f1', + 'realized_precision', + 'realized_recall', + 'realized_specificity', + 'realized_accuracy', 'realized_true_highstreet_card_pred_highstreet_card', 'realized_true_highstreet_card_pred_prepaid_card', 'realized_true_highstreet_card_pred_upmarket_card',