Skip to content

Commit b42bb0b

Browse files
committed
Sync pop
commit 4296e7964be0cccce99a18d8cda0c67cd4b97c41 Author: Dominik Jain <[email protected]> Date: Mon Sep 11 15:58:38 2023 +0200 Renamed classes to improve/simplify evaluation interfaces: - *EvaluationUtil -> *ModelEvaluation - Vector*ModelEvaluatorParams -> *EvaluatorParams (to make applications in high-level evaluation interfaces seem less random and more to the point, prioritising high-level interface) src/sensai/evaluation/__init__.py src/sensai/evaluation/crossval.py src/sensai/evaluation/eval_util.py src/sensai/evaluation/evaluator.py src/sensai/evaluation/metric_computation.py src/sensai/torch/torch_eval_util.py
1 parent c6f76f0 commit b42bb0b

File tree

6 files changed

+54
-43
lines changed

6 files changed

+54
-43
lines changed

src/sensai/evaluation/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from .crossval import VectorClassificationModelCrossValidator, VectorRegressionModelCrossValidator, \
22
VectorClassificationModelCrossValidationData, VectorRegressionModelCrossValidationData, \
33
VectorModelCrossValidatorParams
4-
from .eval_util import RegressionEvaluationUtil, ClassificationEvaluationUtil, MultiDataEvaluationUtil, \
4+
from .eval_util import RegressionModelEvaluation, ClassificationModelEvaluation, MultiDataModelEvaluation, \
55
eval_model_via_evaluator, create_evaluation_util, create_vector_model_evaluator, create_vector_model_cross_validator
66
from .evaluator import VectorClassificationModelEvaluator, VectorRegressionModelEvaluator, \
7-
VectorRegressionModelEvaluatorParams, VectorClassificationModelEvaluatorParams, \
7+
RegressionEvaluatorParams, ClassificationEvaluatorParams, \
88
VectorRegressionModelEvaluationData, VectorClassificationModelEvaluationData, \
99
RuleBasedVectorClassificationModelEvaluator, RuleBasedVectorRegressionModelEvaluator
1010

src/sensai/evaluation/crossval.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
from .eval_stats.eval_stats_regression import RegressionEvalStats, RegressionEvalStatsCollection
1212
from .evaluator import VectorRegressionModelEvaluationData, VectorClassificationModelEvaluationData, \
1313
VectorModelEvaluationData, VectorClassificationModelEvaluator, VectorRegressionModelEvaluator, \
14-
MetricsDictProvider, VectorModelEvaluator, VectorClassificationModelEvaluatorParams, \
15-
VectorRegressionModelEvaluatorParams, MetricsDictProviderFromFunction
14+
MetricsDictProvider, VectorModelEvaluator, ClassificationEvaluatorParams, \
15+
RegressionEvaluatorParams, MetricsDictProviderFromFunction
1616
from ..data import InputOutputData, DataSplitterFractional
1717
from ..tracking.tracking_base import TrackingContext
1818
from ..util.typing import PandasNamedTuple
@@ -128,7 +128,7 @@ def __init__(self,
128128
folds: int = 5,
129129
splitter: CrossValidationSplitter = None,
130130
return_trained_models=False,
131-
evaluator_params: Union[VectorRegressionModelEvaluatorParams, VectorClassificationModelEvaluatorParams] = None,
131+
evaluator_params: Union[RegressionEvaluatorParams, ClassificationEvaluatorParams] = None,
132132
default_splitter_random_seed=42,
133133
default_splitter_shuffle=True):
134134
"""
@@ -233,7 +233,7 @@ def _create_eval_stats_collection(self, l: List[RegressionEvalStats]) -> Regress
233233

234234
class VectorRegressionModelCrossValidator(VectorModelCrossValidator[VectorRegressionModelCrossValidationData]):
235235
def _create_model_evaluator(self, training_data: InputOutputData, test_data: InputOutputData) -> VectorRegressionModelEvaluator:
236-
evaluator_params = VectorRegressionModelEvaluatorParams.from_dict_or_instance(self.params.evaluatorParams)
236+
evaluator_params = RegressionEvaluatorParams.from_dict_or_instance(self.params.evaluatorParams)
237237
return VectorRegressionModelEvaluator(training_data, test_data=test_data, params=evaluator_params)
238238

239239
def _create_result_data(self, trained_models, eval_data_list, test_indices_list, predicted_var_names) \
@@ -249,7 +249,7 @@ def _create_eval_stats_collection(self, l: List[ClassificationEvalStats]) -> Cla
249249

250250
class VectorClassificationModelCrossValidator(VectorModelCrossValidator[VectorClassificationModelCrossValidationData]):
251251
def _create_model_evaluator(self, training_data: InputOutputData, test_data: InputOutputData):
252-
evaluator_params = VectorClassificationModelEvaluatorParams.from_dict_or_instance(self.params.evaluatorParams)
252+
evaluator_params = ClassificationEvaluatorParams.from_dict_or_instance(self.params.evaluatorParams)
253253
return VectorClassificationModelEvaluator(training_data, test_data=test_data, params=evaluator_params)
254254

255255
def _create_result_data(self, trained_models, eval_data_list, test_indices_list, predicted_var_names) \

src/sensai/evaluation/eval_util.py

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from .eval_stats.eval_stats_regression import RegressionEvalStats
3030
from .evaluator import VectorModelEvaluator, VectorModelEvaluationData, VectorRegressionModelEvaluator, \
3131
VectorRegressionModelEvaluationData, VectorClassificationModelEvaluator, VectorClassificationModelEvaluationData, \
32-
VectorRegressionModelEvaluatorParams, VectorClassificationModelEvaluatorParams
32+
RegressionEvaluatorParams, ClassificationEvaluatorParams
3333
from ..data import InputOutputData
3434
from ..feature_importance import AggregatedFeatureImportance, FeatureImportanceProvider, plot_feature_importance, FeatureImportance
3535
from ..tracking import TrackedExperiment
@@ -62,14 +62,14 @@ def _is_regression(model: Optional[VectorModel], is_regression: Optional[bool])
6262

6363

6464
def create_vector_model_evaluator(data: InputOutputData, model: VectorModel = None,
65-
is_regression: bool = None, params: Union[VectorRegressionModelEvaluatorParams, VectorClassificationModelEvaluatorParams] = None) \
65+
is_regression: bool = None, params: Union[RegressionEvaluatorParams, ClassificationEvaluatorParams] = None) \
6666
-> Union[VectorRegressionModelEvaluator, VectorClassificationModelEvaluator]:
6767
is_regression = _is_regression(model, is_regression)
6868
if params is None:
6969
if is_regression:
70-
params = VectorRegressionModelEvaluatorParams(fractional_split_test_fraction=0.2)
70+
params = RegressionEvaluatorParams(fractional_split_test_fraction=0.2)
7171
else:
72-
params = VectorClassificationModelEvaluatorParams(fractional_split_test_fraction=0.2)
72+
params = ClassificationEvaluatorParams(fractional_split_test_fraction=0.2)
7373
log.debug(f"No evaluator parameters specified, using default: {params}")
7474
if is_regression:
7575
return VectorRegressionModelEvaluator(data, params=params)
@@ -89,13 +89,13 @@ def create_vector_model_cross_validator(data: InputOutputData,
8989

9090

9191
def create_evaluation_util(data: InputOutputData, model: VectorModel = None, is_regression: bool = None,
92-
evaluator_params: Optional[Union[VectorRegressionModelEvaluatorParams, VectorClassificationModelEvaluatorParams]] = None,
92+
evaluator_params: Optional[Union[RegressionEvaluatorParams, ClassificationEvaluatorParams]] = None,
9393
cross_validator_params: Optional[Dict[str, Any]] = None) \
94-
-> Union["ClassificationEvaluationUtil", "RegressionEvaluationUtil"]:
94+
-> Union["ClassificationModelEvaluation", "RegressionModelEvaluation"]:
9595
if _is_regression(model, is_regression):
96-
return RegressionEvaluationUtil(data, evaluator_params=evaluator_params, cross_validator_params=cross_validator_params)
96+
return RegressionModelEvaluation(data, evaluator_params=evaluator_params, cross_validator_params=cross_validator_params)
9797
else:
98-
return ClassificationEvaluationUtil(data, evaluator_params=evaluator_params, cross_validator_params=cross_validator_params)
98+
return ClassificationModelEvaluation(data, evaluator_params=evaluator_params, cross_validator_params=cross_validator_params)
9999

100100

101101
def eval_model_via_evaluator(model: TModel, io_data: InputOutputData, test_fraction=0.2,
@@ -131,10 +131,10 @@ def eval_model_via_evaluator(model: TModel, io_data: InputOutputData, test_fract
131131
fig.show()
132132

133133
if model.is_regression_model():
134-
evaluator_params = VectorRegressionModelEvaluatorParams(fractional_split_test_fraction=test_fraction,
134+
evaluator_params = RegressionEvaluatorParams(fractional_split_test_fraction=test_fraction,
135135
fractional_split_random_seed=random_seed)
136136
else:
137-
evaluator_params = VectorClassificationModelEvaluatorParams(fractional_split_test_fraction=test_fraction,
137+
evaluator_params = ClassificationEvaluatorParams(fractional_split_test_fraction=test_fraction,
138138
compute_probabilities=compute_probabilities, fractional_split_random_seed=random_seed)
139139
ev = create_evaluation_util(io_data, model=model, evaluator_params=evaluator_params)
140140
return ev.perform_simple_evaluation(model, show_plots=True, log_results=True)
@@ -209,14 +209,13 @@ def __init__(self):
209209
self.add_plot("threshold-counts", ClassificationEvalStatsPlotProbabilityThresholdCounts())
210210

211211

212-
# TODO conceive of better class name
213-
class EvaluationUtil(ABC, Generic[TModel, TEvaluator, TEvalData, TCrossValidator, TCrossValData, TEvalStats]):
212+
class ModelEvaluation(ABC, Generic[TModel, TEvaluator, TEvalData, TCrossValidator, TCrossValData, TEvalStats]):
214213
"""
215214
Utility class for the evaluation of models based on a dataset
216215
"""
217216
def __init__(self, io_data: InputOutputData,
218217
eval_stats_plot_collector: Union[RegressionEvalStatsPlotCollector, ClassificationEvalStatsPlotCollector],
219-
evaluator_params: Optional[Union[VectorRegressionModelEvaluatorParams, VectorClassificationModelEvaluatorParams,
218+
evaluator_params: Optional[Union[RegressionEvaluatorParams, ClassificationEvaluatorParams,
220219
Dict[str, Any]]] = None,
221220
cross_validator_params: Optional[Union[VectorModelCrossValidatorParams, Dict[str, Any]]] = None):
222221
"""
@@ -291,7 +290,7 @@ def gather_results(result_data: VectorModelEvaluationData, res_writer, subtitle_
291290
self.create_plots(result_data, show_plots=show_plots, result_writer=res_writer,
292291
subtitle_prefix=subtitle_prefix, tracking_context=trackingContext)
293292

294-
eval_result_data = evaluator.eval_model(model, fit=True)
293+
eval_result_data = evaluator.eval_model(model, fit=fit_model)
295294
gather_results(eval_result_data, result_writer)
296295
if additional_evaluation_on_training_data:
297296
eval_result_data_train = evaluator.eval_model(model, on_training_data=True, track=False)
@@ -526,10 +525,10 @@ def _create_eval_stats_plots(self, eval_stats: TEvalStats, result_collector: Eva
526525
self.eval_stats_plot_collector.create_plots(eval_stats, subtitle, result_collector)
527526

528527

529-
class RegressionEvaluationUtil(EvaluationUtil[VectorRegressionModel, VectorRegressionModelEvaluator, VectorRegressionModelEvaluationData,
528+
class RegressionModelEvaluation(ModelEvaluation[VectorRegressionModel, VectorRegressionModelEvaluator, VectorRegressionModelEvaluationData,
530529
VectorRegressionModelCrossValidator, VectorRegressionModelCrossValidationData, RegressionEvalStats]):
531530
def __init__(self, io_data: InputOutputData,
532-
evaluator_params: Optional[Union[VectorRegressionModelEvaluatorParams, Dict[str, Any]]] = None,
531+
evaluator_params: Optional[Union[RegressionEvaluatorParams, Dict[str, Any]]] = None,
533532
cross_validator_params: Optional[Union[VectorModelCrossValidatorParams, Dict[str, Any]]] = None):
534533
"""
535534
:param io_data: the data set to use for evaluation
@@ -540,11 +539,11 @@ def __init__(self, io_data: InputOutputData,
540539
cross_validator_params=cross_validator_params)
541540

542541

543-
class ClassificationEvaluationUtil(EvaluationUtil[VectorClassificationModel, VectorClassificationModelEvaluator,
542+
class ClassificationModelEvaluation(ModelEvaluation[VectorClassificationModel, VectorClassificationModelEvaluator,
544543
VectorClassificationModelEvaluationData, VectorClassificationModelCrossValidator, VectorClassificationModelCrossValidationData,
545544
ClassificationEvalStats]):
546545
def __init__(self, io_data: InputOutputData,
547-
evaluator_params: Optional[Union[VectorClassificationModelEvaluatorParams, Dict[str, Any]]] = None,
546+
evaluator_params: Optional[Union[ClassificationEvaluatorParams, Dict[str, Any]]] = None,
548547
cross_validator_params: Optional[Union[VectorModelCrossValidatorParams, Dict[str, Any]]] = None):
549548
"""
550549
:param io_data: the data set to use for evaluation
@@ -555,10 +554,10 @@ def __init__(self, io_data: InputOutputData,
555554
cross_validator_params=cross_validator_params)
556555

557556

558-
class MultiDataEvaluationUtil:
557+
class MultiDataModelEvaluation:
559558
def __init__(self, io_data_dict: Dict[str, InputOutputData], key_name: str = "dataset",
560559
meta_data_dict: Optional[Dict[str, Dict[str, Any]]] = None,
561-
evaluator_params: Optional[Union[VectorRegressionModelEvaluatorParams, VectorClassificationModelEvaluatorParams, Dict[str, Any]]] = None,
560+
evaluator_params: Optional[Union[RegressionEvaluatorParams, ClassificationEvaluatorParams, Dict[str, Any]]] = None,
562561
cross_validator_params: Optional[Union[VectorModelCrossValidatorParams, Dict[str, Any]]] = None):
563562
"""
564563
:param io_data_dict: a dictionary mapping from names to the data sets with which to evaluate models

src/sensai/evaluation/evaluator.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from ..data import DataSplitter, DataSplitterFractional, InputOutputData
1313
from ..data_transformation import DataFrameTransformer
1414
from ..tracking import TrackingMixin, TrackedExperiment
15-
from ..tracking.tracking_base import TrackingContext
15+
from ..util.deprecation import deprecated
1616
from ..util.string import ToStringMixin
1717
from ..util.typing import PandasNamedTuple
1818
from ..vector_model import VectorClassificationModel, VectorModel, VectorModelBase, VectorModelFittableBase, VectorRegressionModel
@@ -122,7 +122,7 @@ def get_eval_stats_collection(self):
122122
TEvalData = TypeVar("TEvalData", bound=VectorModelEvaluationData)
123123

124124

125-
class VectorModelEvaluatorParams(ToStringMixin, ABC):
125+
class EvaluatorParams(ToStringMixin, ABC):
126126
def __init__(self, data_splitter: DataSplitter = None, fractional_split_test_fraction: float = None, fractional_split_random_seed=42,
127127
fractional_split_shuffle=True):
128128
"""
@@ -166,7 +166,7 @@ def set_data_splitter(self, splitter: DataSplitter):
166166

167167

168168
class VectorModelEvaluator(MetricsDictProvider, Generic[TEvalData], ABC):
169-
def __init__(self, data: Optional[InputOutputData], test_data: InputOutputData = None, params: VectorModelEvaluatorParams = None):
169+
def __init__(self, data: Optional[InputOutputData], test_data: InputOutputData = None, params: EvaluatorParams = None):
170170
"""
171171
Constructs an evaluator with test and training data.
172172
@@ -248,7 +248,7 @@ def fit_model(self, model: VectorModelFittableBase):
248248
model.fit(self.training_data.inputs, self.training_data.outputs)
249249

250250

251-
class VectorRegressionModelEvaluatorParams(VectorModelEvaluatorParams):
251+
class RegressionEvaluatorParams(EvaluatorParams):
252252
def __init__(self,
253253
data_splitter: DataSplitter = None,
254254
fractional_split_test_fraction: float = None,
@@ -281,9 +281,9 @@ def __init__(self,
281281

282282
@classmethod
283283
def from_dict_or_instance(cls,
284-
params: Optional[Union[Dict[str, Any], "VectorRegressionModelEvaluatorParams"]]) -> "VectorRegressionModelEvaluatorParams":
284+
params: Optional[Union[Dict[str, Any], "RegressionEvaluatorParams"]]) -> "RegressionEvaluatorParams":
285285
if params is None:
286-
return VectorRegressionModelEvaluatorParams()
286+
return RegressionEvaluatorParams()
287287
elif type(params) == dict:
288288
raise Exception("Old-style dictionary parametrisation is no longer supported")
289289
elif isinstance(params, cls):
@@ -292,9 +292,15 @@ def from_dict_or_instance(cls,
292292
raise ValueError(f"Must provide dictionary or {cls} instance, got {params}, type {type(params)}")
293293

294294

295+
class VectorRegressionModelEvaluatorParams(RegressionEvaluatorParams):
296+
@deprecated("Use RegressionEvaluatorParams instead")
297+
def __init__(self, *args, **kwargs):
298+
super().__init__(*args, **kwargs)
299+
300+
295301
class VectorRegressionModelEvaluator(VectorModelEvaluator[VectorRegressionModelEvaluationData]):
296302
def __init__(self, data: Optional[InputOutputData], test_data: InputOutputData = None,
297-
params: VectorRegressionModelEvaluatorParams = None):
303+
params: RegressionEvaluatorParams = None):
298304
"""
299305
Constructs an evaluator with test and training data.
300306
@@ -358,7 +364,7 @@ def get_misclassified_triples_pred_true_input(self) -> List[Tuple[Any, Any, pd.S
358364
return [(eval_stats.y_predicted[i], eval_stats.y_true[i], self.input_data.iloc[i]) for i in indices]
359365

360366

361-
class VectorClassificationModelEvaluatorParams(VectorModelEvaluatorParams):
367+
class ClassificationEvaluatorParams(EvaluatorParams):
362368
def __init__(self, data_splitter: DataSplitter = None, fractional_split_test_fraction: float = None, fractional_split_random_seed=42,
363369
fractional_split_shuffle=True, additional_metrics: Sequence[ClassificationMetric] = None,
364370
compute_probabilities: bool = False, binary_positive_label=GUESS):
@@ -387,18 +393,24 @@ def __init__(self, data_splitter: DataSplitter = None, fractional_split_test_fra
387393

388394
@classmethod
389395
def from_dict_or_instance(cls,
390-
params: Optional[Union[Dict[str, Any], "VectorClassificationModelEvaluatorParams"]]) \
391-
-> "VectorClassificationModelEvaluatorParams":
396+
params: Optional[Union[Dict[str, Any], "ClassificationEvaluatorParams"]]) \
397+
-> "ClassificationEvaluatorParams":
392398
if params is None:
393-
return VectorClassificationModelEvaluatorParams()
399+
return ClassificationEvaluatorParams()
394400
elif type(params) == dict:
395401
raise ValueError("Old-style dictionary parametrisation is no longer supported")
396-
elif isinstance(params, VectorClassificationModelEvaluatorParams):
402+
elif isinstance(params, ClassificationEvaluatorParams):
397403
return params
398404
else:
399405
raise ValueError(f"Must provide dictionary or instance, got {params}")
400406

401407

408+
class VectorClassificationModelEvaluatorParams(ClassificationEvaluatorParams):
409+
@deprecated("Use ClassificationEvaluatorParams instead")
410+
def __init__(self, *args, **kwargs):
411+
super().__init__(*args, **kwargs)
412+
413+
402414
class VectorClassificationModelEvaluator(VectorModelEvaluator[VectorClassificationModelEvaluationData]):
403415
def __init__(self,
404416
data: Optional[InputOutputData],

src/sensai/evaluation/metric_computation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Union, List, Callable
44

55
from sensai import VectorRegressionModel, VectorClassificationModel, VectorModelBase
6-
from sensai.evaluation import MultiDataEvaluationUtil
6+
from sensai.evaluation import MultiDataModelEvaluation
77
from sensai.evaluation.eval_stats import RegressionMetric, ClassificationMetric
88

99
TMetric = Union[RegressionMetric, ClassificationMetric]
@@ -26,7 +26,7 @@ def compute_metric_value(self, model_factory: Callable[[], TModel]) -> MetricCom
2626

2727

2828
class MetricComputationMultiData(MetricComputation):
29-
def __init__(self, ev_util: MultiDataEvaluationUtil, use_cross_validation: bool, metric: TMetric,
29+
def __init__(self, ev_util: MultiDataModelEvaluation, use_cross_validation: bool, metric: TMetric,
3030
use_combined_eval_stats: bool):
3131
super().__init__(metric)
3232
self.use_combined_eval_stats = use_combined_eval_stats

src/sensai/torch/torch_eval_util.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
from typing import Union
22

33
from . import TorchVectorRegressionModel
4-
from ..evaluation import RegressionEvaluationUtil
4+
from ..evaluation import RegressionModelEvaluation
55
from ..evaluation.crossval import VectorModelCrossValidationData, VectorRegressionModelCrossValidationData
66
from ..evaluation.eval_util import EvaluationResultCollector
77
from ..evaluation.evaluator import VectorModelEvaluationData, VectorRegressionModelEvaluationData
88

99

10-
class TorchVectorRegressionModelEvaluationUtil(RegressionEvaluationUtil):
10+
class TorchVectorRegressionModelEvaluationUtil(RegressionModelEvaluation):
1111

1212
def _create_plots(self,
1313
data: Union[VectorRegressionModelEvaluationData, VectorRegressionModelCrossValidationData],

0 commit comments

Comments
 (0)