diff --git a/src/sensai/evaluation/eval_util.py b/src/sensai/evaluation/eval_util.py
index 78623472..e92084e3 100644
--- a/src/sensai/evaluation/eval_util.py
+++ b/src/sensai/evaluation/eval_util.py
@@ -91,12 +91,12 @@ def create_vector_model_cross_validator(data: InputOutputData,
 
 def create_evaluation_util(data: InputOutputData, model: VectorModel = None, is_regression: bool = None,
         evaluator_params: Optional[Union[RegressionEvaluatorParams, ClassificationEvaluatorParams]] = None,
-        cross_validator_params: Optional[Dict[str, Any]] = None) \
+        cross_validator_params: Optional[Dict[str, Any]] = None, test_io_data: Optional[InputOutputData] = None) \
         -> Union["ClassificationModelEvaluation", "RegressionModelEvaluation"]:
     if _is_regression(model, is_regression):
-        return RegressionModelEvaluation(data, evaluator_params=evaluator_params, cross_validator_params=cross_validator_params)
+        return RegressionModelEvaluation(data, evaluator_params=evaluator_params, cross_validator_params=cross_validator_params, test_io_data=test_io_data)
     else:
-        return ClassificationModelEvaluation(data, evaluator_params=evaluator_params, cross_validator_params=cross_validator_params)
+        return ClassificationModelEvaluation(data, evaluator_params=evaluator_params, cross_validator_params=cross_validator_params, test_io_data=test_io_data)
 
 
 def eval_model_via_evaluator(model: TModel, io_data: InputOutputData, test_fraction=0.2,
@@ -576,7 +576,8 @@ class MultiDataModelEvaluation:
     def __init__(self, io_data_dict: Dict[str, InputOutputData], key_name: str = "dataset",
            meta_data_dict: Optional[Dict[str, Dict[str, Any]]] = None,
            evaluator_params: Optional[Union[RegressionEvaluatorParams, ClassificationEvaluatorParams, Dict[str, Any]]] = None,
-           cross_validator_params: Optional[Union[VectorModelCrossValidatorParams, Dict[str, Any]]] = None):
+           cross_validator_params: Optional[Union[VectorModelCrossValidatorParams, Dict[str, Any]]] = None,
+           test_io_data_dict: Optional[Dict[str, InputOutputData]] = None):
         """
         :param io_data_dict: a dictionary mapping from names to the data sets with which to evaluate models
         :param key_name: a name for the key value used in inputOutputDataDict, which will be used as a column name in result data frames
@@ -584,8 +585,12 @@ def __init__(self, io_data_dict: Dict[str, InputOutputData], key_name: str = "da
             from a column name to a value and which is to be used to extend the result data frames containing per-dataset results
         :param evaluator_params: parameters to use for the instantiation of evaluators (relevant if useCrossValidation==False)
         :param cross_validator_params: parameters to use for the instantiation of cross-validators (relevant if useCrossValidation==True)
+        :param test_io_data_dict: a dictionary mapping from names to the test data sets to use for evaluation. Data sets under the same
+            keys as in io_data_dict will be used to evaluate the models trained on the respective io_data_dict entry. The keys need not
+            match those of io_data_dict: unused keys are ignored, and for missing keys no dedicated test set is used (test_io_data is None).
""" self.io_data_dict = io_data_dict + self.test_io_data_dict = test_io_data_dict or {} self.key_name = key_name self.evaluator_params = evaluator_params self.cross_validator_params = cross_validator_params @@ -659,8 +664,9 @@ def compare_models(self, else: raise ValueError("The models have to be either all regression models or all classification, not a mixture") + test_io_data = self.test_io_data_dict.get(key) ev = create_evaluation_util(inputOutputData, is_regression=is_regression, evaluator_params=self.evaluator_params, - cross_validator_params=self.cross_validator_params) + cross_validator_params=self.cross_validator_params, test_io_data=test_io_data) if plot_collector is None: plot_collector = ev.eval_stats_plot_collector