Skip to content

Commit

Permalink
Support test_io_data for MultiDataEval and eval util
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Panchenko committed Feb 27, 2024
1 parent d560fe0 commit a7d1005
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions src/sensai/evaluation/eval_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,12 @@ def create_vector_model_cross_validator(data: InputOutputData,

def create_evaluation_util(data: InputOutputData, model: VectorModel = None, is_regression: bool = None,
evaluator_params: Optional[Union[RegressionEvaluatorParams, ClassificationEvaluatorParams]] = None,
cross_validator_params: Optional[Dict[str, Any]] = None) \
cross_validator_params: Optional[Dict[str, Any]] = None, test_io_data: Optional[InputOutputData] = None) \
-> Union["ClassificationModelEvaluation", "RegressionModelEvaluation"]:
if _is_regression(model, is_regression):
return RegressionModelEvaluation(data, evaluator_params=evaluator_params, cross_validator_params=cross_validator_params)
return RegressionModelEvaluation(data, evaluator_params=evaluator_params, cross_validator_params=cross_validator_params, test_io_data=test_io_data)
else:
return ClassificationModelEvaluation(data, evaluator_params=evaluator_params, cross_validator_params=cross_validator_params)
return ClassificationModelEvaluation(data, evaluator_params=evaluator_params, cross_validator_params=cross_validator_params, test_io_data=test_io_data)


def eval_model_via_evaluator(model: TModel, io_data: InputOutputData, test_fraction=0.2,
Expand Down Expand Up @@ -576,16 +576,21 @@ class MultiDataModelEvaluation:
def __init__(self, io_data_dict: Dict[str, InputOutputData], key_name: str = "dataset",
meta_data_dict: Optional[Dict[str, Dict[str, Any]]] = None,
evaluator_params: Optional[Union[RegressionEvaluatorParams, ClassificationEvaluatorParams, Dict[str, Any]]] = None,
cross_validator_params: Optional[Union[VectorModelCrossValidatorParams, Dict[str, Any]]] = None):
cross_validator_params: Optional[Union[VectorModelCrossValidatorParams, Dict[str, Any]]] = None,
test_io_data_dict: Optional[Dict[str, InputOutputData]] = None):
"""
:param io_data_dict: a dictionary mapping from names to the data sets with which to evaluate models
:param key_name: a name for the key value used in inputOutputDataDict, which will be used as a column name in result data frames
:param meta_data_dict: a dictionary which maps from a name (same keys as in inputOutputDataDict) to a dictionary, which maps
from a column name to a value and which is to be used to extend the result data frames containing per-dataset results
:param evaluator_params: parameters to use for the instantiation of evaluators (relevant if useCrossValidation==False)
:param cross_validator_params: parameters to use for the instantiation of cross-validators (relevant if useCrossValidation==True)
:param test_io_data_dict: a dictionary mapping from names to the test data sets to use for evaluation. Datasets under the same
keys as in io_data_dict will be used for evaluation of the models that were trained on the respective io_data_dict.
The keys don't need to be the same as io_data_dict, unused keys are ignored, and for missing keys the test_io_data is None.
"""
self.io_data_dict = io_data_dict
self.test_io_data_dict = test_io_data_dict or {}
self.key_name = key_name
self.evaluator_params = evaluator_params
self.cross_validator_params = cross_validator_params
Expand Down Expand Up @@ -659,8 +664,9 @@ def compare_models(self,
else:
raise ValueError("The models have to be either all regression models or all classification, not a mixture")

test_io_data = self.test_io_data_dict.get(key)
ev = create_evaluation_util(inputOutputData, is_regression=is_regression, evaluator_params=self.evaluator_params,
cross_validator_params=self.cross_validator_params)
cross_validator_params=self.cross_validator_params, test_io_data=test_io_data)

if plot_collector is None:
plot_collector = ev.eval_stats_plot_collector
Expand Down

0 comments on commit a7d1005

Please sign in to comment.