29
29
from .eval_stats .eval_stats_regression import RegressionEvalStats
30
30
from .evaluator import VectorModelEvaluator , VectorModelEvaluationData , VectorRegressionModelEvaluator , \
31
31
VectorRegressionModelEvaluationData , VectorClassificationModelEvaluator , VectorClassificationModelEvaluationData , \
32
- VectorRegressionModelEvaluatorParams , VectorClassificationModelEvaluatorParams
32
+ RegressionEvaluatorParams , ClassificationEvaluatorParams
33
33
from ..data import InputOutputData
34
34
from ..feature_importance import AggregatedFeatureImportance , FeatureImportanceProvider , plot_feature_importance , FeatureImportance
35
35
from ..tracking import TrackedExperiment
@@ -62,14 +62,14 @@ def _is_regression(model: Optional[VectorModel], is_regression: Optional[bool])
62
62
63
63
64
64
def create_vector_model_evaluator (data : InputOutputData , model : VectorModel = None ,
65
- is_regression : bool = None , params : Union [VectorRegressionModelEvaluatorParams , VectorClassificationModelEvaluatorParams ] = None ) \
65
+ is_regression : bool = None , params : Union [RegressionEvaluatorParams , ClassificationEvaluatorParams ] = None ) \
66
66
-> Union [VectorRegressionModelEvaluator , VectorClassificationModelEvaluator ]:
67
67
is_regression = _is_regression (model , is_regression )
68
68
if params is None :
69
69
if is_regression :
70
- params = VectorRegressionModelEvaluatorParams (fractional_split_test_fraction = 0.2 )
70
+ params = RegressionEvaluatorParams (fractional_split_test_fraction = 0.2 )
71
71
else :
72
- params = VectorClassificationModelEvaluatorParams (fractional_split_test_fraction = 0.2 )
72
+ params = ClassificationEvaluatorParams (fractional_split_test_fraction = 0.2 )
73
73
log .debug (f"No evaluator parameters specified, using default: { params } " )
74
74
if is_regression :
75
75
return VectorRegressionModelEvaluator (data , params = params )
@@ -89,13 +89,13 @@ def create_vector_model_cross_validator(data: InputOutputData,
89
89
90
90
91
91
def create_evaluation_util (data : InputOutputData , model : VectorModel = None , is_regression : bool = None ,
92
- evaluator_params : Optional [Union [VectorRegressionModelEvaluatorParams , VectorClassificationModelEvaluatorParams ]] = None ,
92
+ evaluator_params : Optional [Union [RegressionEvaluatorParams , ClassificationEvaluatorParams ]] = None ,
93
93
cross_validator_params : Optional [Dict [str , Any ]] = None ) \
94
- -> Union ["ClassificationEvaluationUtil " , "RegressionEvaluationUtil " ]:
94
+ -> Union ["ClassificationModelEvaluation " , "RegressionModelEvaluation " ]:
95
95
if _is_regression (model , is_regression ):
96
- return RegressionEvaluationUtil (data , evaluator_params = evaluator_params , cross_validator_params = cross_validator_params )
96
+ return RegressionModelEvaluation (data , evaluator_params = evaluator_params , cross_validator_params = cross_validator_params )
97
97
else :
98
- return ClassificationEvaluationUtil (data , evaluator_params = evaluator_params , cross_validator_params = cross_validator_params )
98
+ return ClassificationModelEvaluation (data , evaluator_params = evaluator_params , cross_validator_params = cross_validator_params )
99
99
100
100
101
101
def eval_model_via_evaluator (model : TModel , io_data : InputOutputData , test_fraction = 0.2 ,
@@ -131,10 +131,10 @@ def eval_model_via_evaluator(model: TModel, io_data: InputOutputData, test_fract
131
131
fig .show ()
132
132
133
133
if model .is_regression_model ():
134
- evaluator_params = VectorRegressionModelEvaluatorParams (fractional_split_test_fraction = test_fraction ,
134
+ evaluator_params = RegressionEvaluatorParams (fractional_split_test_fraction = test_fraction ,
135
135
fractional_split_random_seed = random_seed )
136
136
else :
137
- evaluator_params = VectorClassificationModelEvaluatorParams (fractional_split_test_fraction = test_fraction ,
137
+ evaluator_params = ClassificationEvaluatorParams (fractional_split_test_fraction = test_fraction ,
138
138
compute_probabilities = compute_probabilities , fractional_split_random_seed = random_seed )
139
139
ev = create_evaluation_util (io_data , model = model , evaluator_params = evaluator_params )
140
140
return ev .perform_simple_evaluation (model , show_plots = True , log_results = True )
@@ -209,14 +209,13 @@ def __init__(self):
209
209
self .add_plot ("threshold-counts" , ClassificationEvalStatsPlotProbabilityThresholdCounts ())
210
210
211
211
212
- # TODO conceive of better class name
213
- class EvaluationUtil (ABC , Generic [TModel , TEvaluator , TEvalData , TCrossValidator , TCrossValData , TEvalStats ]):
212
+ class ModelEvaluation (ABC , Generic [TModel , TEvaluator , TEvalData , TCrossValidator , TCrossValData , TEvalStats ]):
214
213
"""
215
214
Utility class for the evaluation of models based on a dataset
216
215
"""
217
216
def __init__ (self , io_data : InputOutputData ,
218
217
eval_stats_plot_collector : Union [RegressionEvalStatsPlotCollector , ClassificationEvalStatsPlotCollector ],
219
- evaluator_params : Optional [Union [VectorRegressionModelEvaluatorParams , VectorClassificationModelEvaluatorParams ,
218
+ evaluator_params : Optional [Union [RegressionEvaluatorParams , ClassificationEvaluatorParams ,
220
219
Dict [str , Any ]]] = None ,
221
220
cross_validator_params : Optional [Union [VectorModelCrossValidatorParams , Dict [str , Any ]]] = None ):
222
221
"""
@@ -291,7 +290,7 @@ def gather_results(result_data: VectorModelEvaluationData, res_writer, subtitle_
291
290
self .create_plots (result_data , show_plots = show_plots , result_writer = res_writer ,
292
291
subtitle_prefix = subtitle_prefix , tracking_context = trackingContext )
293
292
294
- eval_result_data = evaluator .eval_model (model , fit = True )
293
+ eval_result_data = evaluator .eval_model (model , fit = fit_model )
295
294
gather_results (eval_result_data , result_writer )
296
295
if additional_evaluation_on_training_data :
297
296
eval_result_data_train = evaluator .eval_model (model , on_training_data = True , track = False )
@@ -526,10 +525,10 @@ def _create_eval_stats_plots(self, eval_stats: TEvalStats, result_collector: Eva
526
525
self .eval_stats_plot_collector .create_plots (eval_stats , subtitle , result_collector )
527
526
528
527
529
- class RegressionEvaluationUtil ( EvaluationUtil [VectorRegressionModel , VectorRegressionModelEvaluator , VectorRegressionModelEvaluationData ,
528
+ class RegressionModelEvaluation ( ModelEvaluation [VectorRegressionModel , VectorRegressionModelEvaluator , VectorRegressionModelEvaluationData ,
530
529
VectorRegressionModelCrossValidator , VectorRegressionModelCrossValidationData , RegressionEvalStats ]):
531
530
def __init__ (self , io_data : InputOutputData ,
532
- evaluator_params : Optional [Union [VectorRegressionModelEvaluatorParams , Dict [str , Any ]]] = None ,
531
+ evaluator_params : Optional [Union [RegressionEvaluatorParams , Dict [str , Any ]]] = None ,
533
532
cross_validator_params : Optional [Union [VectorModelCrossValidatorParams , Dict [str , Any ]]] = None ):
534
533
"""
535
534
:param io_data: the data set to use for evaluation
@@ -540,11 +539,11 @@ def __init__(self, io_data: InputOutputData,
540
539
cross_validator_params = cross_validator_params )
541
540
542
541
543
- class ClassificationEvaluationUtil ( EvaluationUtil [VectorClassificationModel , VectorClassificationModelEvaluator ,
542
+ class ClassificationModelEvaluation ( ModelEvaluation [VectorClassificationModel , VectorClassificationModelEvaluator ,
544
543
VectorClassificationModelEvaluationData , VectorClassificationModelCrossValidator , VectorClassificationModelCrossValidationData ,
545
544
ClassificationEvalStats ]):
546
545
def __init__ (self , io_data : InputOutputData ,
547
- evaluator_params : Optional [Union [VectorClassificationModelEvaluatorParams , Dict [str , Any ]]] = None ,
546
+ evaluator_params : Optional [Union [ClassificationEvaluatorParams , Dict [str , Any ]]] = None ,
548
547
cross_validator_params : Optional [Union [VectorModelCrossValidatorParams , Dict [str , Any ]]] = None ):
549
548
"""
550
549
:param io_data: the data set to use for evaluation
@@ -555,10 +554,10 @@ def __init__(self, io_data: InputOutputData,
555
554
cross_validator_params = cross_validator_params )
556
555
557
556
558
- class MultiDataEvaluationUtil :
557
+ class MultiDataModelEvaluation :
559
558
def __init__ (self , io_data_dict : Dict [str , InputOutputData ], key_name : str = "dataset" ,
560
559
meta_data_dict : Optional [Dict [str , Dict [str , Any ]]] = None ,
561
- evaluator_params : Optional [Union [VectorRegressionModelEvaluatorParams , VectorClassificationModelEvaluatorParams , Dict [str , Any ]]] = None ,
560
+ evaluator_params : Optional [Union [RegressionEvaluatorParams , ClassificationEvaluatorParams , Dict [str , Any ]]] = None ,
562
561
cross_validator_params : Optional [Union [VectorModelCrossValidatorParams , Dict [str , Any ]]] = None ):
563
562
"""
564
563
:param io_data_dict: a dictionary mapping from names to the data sets with which to evaluate models
0 commit comments