diff --git a/src/safeds/ml/nn/_model.py b/src/safeds/ml/nn/_model.py index 7947eb22c..9473f9314 100644 --- a/src/safeds/ml/nn/_model.py +++ b/src/safeds/ml/nn/_model.py @@ -1124,41 +1124,41 @@ def _get_best_cnn_model_column( match optimization_metric: case "accuracy": best_metric_value = ClassificationMetrics.accuracy( - predicted=fitted_model.predict(input_data).get_output(), # type: ignore[arg-type] + predicted=fitted_model.predict(input_data).get_output(), # type: ignore expected=expected) # type: ignore[arg-type] case "precision": best_metric_value = ClassificationMetrics.precision( - predicted=fitted_model.predict(input_data).get_output(), # type: ignore[arg-type] + predicted=fitted_model.predict(input_data).get_output(), # type: ignore expected=expected, positive_class=positive_class) # type: ignore[arg-type] case "recall": - best_metric_value = ClassificationMetrics.recall(predicted=fitted_model.predict(input_data).get_output(), # type: ignore[arg-type] + best_metric_value = ClassificationMetrics.recall(predicted=fitted_model.predict(input_data).get_output(), # type: ignore expected=expected, positive_class=positive_class) # type: ignore[arg-type] case "f1_score": best_metric_value = ClassificationMetrics.f1_score( - predicted=fitted_model.predict(input_data).get_output(), # type: ignore[arg-type] + predicted=fitted_model.predict(input_data).get_output(), # type: ignore expected=expected, positive_class=positive_class) # type: ignore[arg-type] else: match optimization_metric: case "accuracy": error_of_fitted_model = ClassificationMetrics.accuracy( - predicted=fitted_model.predict(input_data).get_output(), # type: ignore[arg-type] + predicted=fitted_model.predict(input_data).get_output(), # type: ignore expected=expected) # type: ignore[arg-type] if error_of_fitted_model > best_metric_value: best_model = fitted_model # pragma: no cover best_metric_value = error_of_fitted_model # pragma: no cover case "precision": error_of_fitted_model = ClassificationMetrics.precision( - predicted=fitted_model.predict(input_data).get_output(), expected=expected, # type: ignore[arg-type] + predicted=fitted_model.predict(input_data).get_output(), expected=expected, # type: ignore positive_class=positive_class) # type: ignore[arg-type] if error_of_fitted_model > best_metric_value: best_model = fitted_model # pragma: no cover best_metric_value = error_of_fitted_model # pragma: no cover case "recall": error_of_fitted_model = ClassificationMetrics.recall( - predicted=fitted_model.predict(input_data).get_output(), # type: ignore[arg-type] + predicted=fitted_model.predict(input_data).get_output(), # type: ignore expected=expected, positive_class=positive_class) # type: ignore[arg-type] if error_of_fitted_model > best_metric_value: @@ -1166,7 +1166,7 @@ def _get_best_cnn_model_column( best_metric_value = error_of_fitted_model # pragma: no cover case "f1_score": error_of_fitted_model = ClassificationMetrics.f1_score( - predicted=fitted_model.predict(input_data).get_output(), expected=expected, # type: ignore[arg-type] + predicted=fitted_model.predict(input_data).get_output(), expected=expected, # type: ignore positive_class=positive_class) # type: ignore[arg-type] if error_of_fitted_model > best_metric_value: best_model = fitted_model # pragma: no cover @@ -1196,7 +1196,7 @@ def _get_best_cnn_model_table( best_model = None best_metric_value = None for fitted_model in list_of_fitted_models: - prediction = self._inverse_one_hot_encode_by_index_of_column(fitted_model.predict(input_data).get_output()) # type: ignore[arg-type] #type: ignore[attr-defined] + prediction = self._inverse_one_hot_encode_by_index_of_column(fitted_model.predict(input_data).get_output()) # type: ignore if best_model is None: best_model = fitted_model match optimization_metric: diff --git a/tests/safeds/ml/nn/test_model.py b/tests/safeds/ml/nn/test_model.py index a3644088c..7fef2ea74 100644 --- a/tests/safeds/ml/nn/test_model.py +++ b/tests/safeds/ml/nn/test_model.py @@ -428,7 +428,7 @@ def test_should_assert_that_is_fitted_is_set_correctly_and_check_return_type_for ], ids=["accuracy", "precision", "recall", "f1_score"], ) - def test_should_assert_that_is_fitted_is_set_correctly_and_check_return_type_for_cnns_Column_Output( + def test_should_assert_that_is_fitted_is_set_correctly_and_check_return_type_for_cnns_column_output( self, metric: Literal["accuracy", "precision", "recall", "f1_score"], positive_class: Any, @@ -483,7 +483,7 @@ def test_should_assert_that_is_fitted_is_set_correctly_and_check_return_type_for ], ids=["accuracy", "precision", "recall", "f1_score"], ) - def test_should_assert_that_is_fitted_is_set_correctly_and_check_return_type_for_cnns_Table_Output( + def test_should_assert_that_is_fitted_is_set_correctly_and_check_return_type_for_cnns_table_output( self, metric: Literal["accuracy", "precision", "recall", "f1_score"], positive_class: Any,