Skip to content

Commit

Permalink
linter fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
sibre28 committed Aug 19, 2024
1 parent 78c7180 commit 3a4233a
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
18 changes: 9 additions & 9 deletions src/safeds/ml/nn/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1124,49 +1124,49 @@ def _get_best_cnn_model_column(
match optimization_metric:
case "accuracy":
best_metric_value = ClassificationMetrics.accuracy(
predicted=fitted_model.predict(input_data).get_output(), # type: ignore[arg-type]
predicted=fitted_model.predict(input_data).get_output(), # type: ignore
expected=expected) # type: ignore[arg-type]
case "precision":
best_metric_value = ClassificationMetrics.precision(
predicted=fitted_model.predict(input_data).get_output(), # type: ignore[arg-type]
predicted=fitted_model.predict(input_data).get_output(), # type: ignore
expected=expected,
positive_class=positive_class) # type: ignore[arg-type]
case "recall":
best_metric_value = ClassificationMetrics.recall(predicted=fitted_model.predict(input_data).get_output(), # type: ignore[arg-type]
best_metric_value = ClassificationMetrics.recall(predicted=fitted_model.predict(input_data).get_output(), # type: ignore
expected=expected,
positive_class=positive_class) # type: ignore[arg-type]
case "f1_score":
best_metric_value = ClassificationMetrics.f1_score(
predicted=fitted_model.predict(input_data).get_output(), # type: ignore[arg-type]
predicted=fitted_model.predict(input_data).get_output(), # type: ignore
expected=expected,
positive_class=positive_class) # type: ignore[arg-type]
else:
match optimization_metric:
case "accuracy":
error_of_fitted_model = ClassificationMetrics.accuracy(
predicted=fitted_model.predict(input_data).get_output(), # type: ignore[arg-type]
predicted=fitted_model.predict(input_data).get_output(), # type: ignore
expected=expected) # type: ignore[arg-type]
if error_of_fitted_model > best_metric_value:
best_model = fitted_model # pragma: no cover
best_metric_value = error_of_fitted_model # pragma: no cover
case "precision":
error_of_fitted_model = ClassificationMetrics.precision(
predicted=fitted_model.predict(input_data).get_output(), expected=expected, # type: ignore[arg-type]
predicted=fitted_model.predict(input_data).get_output(), expected=expected, # type: ignore
positive_class=positive_class) # type: ignore[arg-type]
if error_of_fitted_model > best_metric_value:
best_model = fitted_model # pragma: no cover
best_metric_value = error_of_fitted_model # pragma: no cover
case "recall":
error_of_fitted_model = ClassificationMetrics.recall(
predicted=fitted_model.predict(input_data).get_output(), # type: ignore[arg-type]
predicted=fitted_model.predict(input_data).get_output(), # type: ignore
expected=expected,
positive_class=positive_class) # type: ignore[arg-type]
if error_of_fitted_model > best_metric_value:
best_model = fitted_model # pragma: no cover
best_metric_value = error_of_fitted_model # pragma: no cover
case "f1_score":
error_of_fitted_model = ClassificationMetrics.f1_score(
predicted=fitted_model.predict(input_data).get_output(), expected=expected, # type: ignore[arg-type]
predicted=fitted_model.predict(input_data).get_output(), expected=expected, # type: ignore
positive_class=positive_class) # type: ignore[arg-type]
if error_of_fitted_model > best_metric_value:
best_model = fitted_model # pragma: no cover
Expand Down Expand Up @@ -1196,7 +1196,7 @@ def _get_best_cnn_model_table(
best_model = None
best_metric_value = None
for fitted_model in list_of_fitted_models:
prediction = self._inverse_one_hot_encode_by_index_of_column(fitted_model.predict(input_data).get_output()) # type: ignore[arg-type] #type: ignore[attr-defined]
prediction = self._inverse_one_hot_encode_by_index_of_column(fitted_model.predict(input_data).get_output()) # type: ignore
if best_model is None:
best_model = fitted_model
match optimization_metric:
Expand Down
4 changes: 2 additions & 2 deletions tests/safeds/ml/nn/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ def test_should_assert_that_is_fitted_is_set_correctly_and_check_return_type_for
],
ids=["accuracy", "precision", "recall", "f1_score"],
)
def test_should_assert_that_is_fitted_is_set_correctly_and_check_return_type_for_cnns_Column_Output(
def test_should_assert_that_is_fitted_is_set_correctly_and_check_return_type_for_cnns_column_output(
self,
metric: Literal["accuracy", "precision", "recall", "f1_score"],
positive_class: Any,
Expand Down Expand Up @@ -483,7 +483,7 @@ def test_should_assert_that_is_fitted_is_set_correctly_and_check_return_type_for
],
ids=["accuracy", "precision", "recall", "f1_score"],
)
def test_should_assert_that_is_fitted_is_set_correctly_and_check_return_type_for_cnns_Table_Output(
def test_should_assert_that_is_fitted_is_set_correctly_and_check_return_type_for_cnns_table_output(
self,
metric: Literal["accuracy", "precision", "recall", "f1_score"],
positive_class: Any,
Expand Down

0 comments on commit 3a4233a

Please sign in to comment.