Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

347 refactor validation category #368

Merged
merged 13 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/evaluation/calculate_base_metrics.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Calculate base metrics

::: eis_toolkit.evaluation.calculate_base_metrics
3 changes: 3 additions & 0 deletions docs/evaluation/classification_label_evaluation.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Classification label evaluation

::: eis_toolkit.evaluation.classification_label_evaluation
3 changes: 3 additions & 0 deletions docs/evaluation/classification_probability_evaluation.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Classification probability evaluation

::: eis_toolkit.evaluation.classification_probability_evaluation
3 changes: 3 additions & 0 deletions docs/evaluation/plot_confusion_matrix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Plot confusion matrix

::: eis_toolkit.evaluation.plot_confusion_matrix
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Plot neural network training performance (accuracy and loss)

::: eis_toolkit.validation.plot_nn_model_performance
::: eis_toolkit.evaluation.plot_nn_model_performance
3 changes: 3 additions & 0 deletions docs/evaluation/plot_prediction_area_curves.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Plot prediction-area (P-A) curves

::: eis_toolkit.evaluation.plot_prediction_area_curves
3 changes: 3 additions & 0 deletions docs/evaluation/plot_rate_curve.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Plot rate curve

::: eis_toolkit.evaluation.plot_rate_curve
3 changes: 0 additions & 3 deletions docs/validation/calculate_auc.md

This file was deleted.

3 changes: 0 additions & 3 deletions docs/validation/calculate_base_metrics.md

This file was deleted.

3 changes: 0 additions & 3 deletions docs/validation/get_pa_intersection.md

This file was deleted.

3 changes: 0 additions & 3 deletions docs/validation/plot_confusion_matrix.md

This file was deleted.

3 changes: 0 additions & 3 deletions docs/validation/plot_prediction_area_curves.md

This file was deleted.

3 changes: 0 additions & 3 deletions docs/validation/plot_rate_curve.md

This file was deleted.

2 changes: 1 addition & 1 deletion eis_toolkit/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2879,7 +2879,7 @@ def winsorize_transform_cli(
typer.echo(f"Winsorize transform completed, writing raster to {output_raster}.")


# ---VALIDATION ---
# ---EVALUATION ---
# TODO


Expand Down
File renamed without changes.
37 changes: 37 additions & 0 deletions eis_toolkit/evaluation/classification_label_evaluation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from numbers import Number
from typing import Dict

import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix, precision_recall_fscore_support


def summarize_label_metrics_binary(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, Number]:
"""
Generate a comprehensive report of various evaluation metrics for binary classification results.

The output includes accuracy, precision, recall, F1 scores and confusion matrix elements
(true negatives, false positives, false negatives, true positives).

Args:
y_true: True labels.
y_pred: Predicted labels. The array should come from a binary classifier.

Returns:
A dictionary containing the evaluated metrics.
"""
metrics = {}

metrics["Accuracy"] = accuracy_score(y_true, y_pred)

precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="binary")
metrics["Precision"] = precision
metrics["Recall"] = recall
metrics["F1_score"] = f1

tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
metrics["True_negatives"] = tn
metrics["False_positives"] = fp
metrics["False_negatives"] = fn
metrics["True_positives"] = tp

return metrics
192 changes: 192 additions & 0 deletions eis_toolkit/evaluation/classification_probability_evaluation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
from typing import Dict

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from beartype.typing import Optional
from sklearn.calibration import CalibrationDisplay
from sklearn.metrics import (
DetCurveDisplay,
PrecisionRecallDisplay,
RocCurveDisplay,
average_precision_score,
brier_score_loss,
log_loss,
roc_auc_score,
)


def summarize_probability_metrics(y_true: np.ndarray, y_prob: np.ndarray) -> Dict[str, float]:
"""
Generate a comprehensive report of various evaluation metrics for classification probabilities.

The output includes ROC AUC, log loss, average precision and Brier score loss.

Args:
y_true: True labels.
y_prob: Predicted probabilities for the positive class. The array should come from
a binary classifier.

Returns:
A dictionary containing the evaluated metrics.
"""
metrics = {}

metrics["roc_auc"] = roc_auc_score(y_true, y_prob)
metrics["log_loss"] = log_loss(y_true, y_prob)
metrics["average_precision"] = average_precision_score(y_true, y_prob)
metrics["brier_score_loss"] = brier_score_loss(y_true, y_prob)

return metrics


def plot_roc_curve(
y_true: np.ndarray,
y_prob: np.ndarray,
plot_title: Optional[str] = "ROC curve",
ax: Optional[plt.Axes] = None,
**kwargs
) -> plt.Axes:
"""
Plot ROC (receiver operating characteristic) curve.

ROC curve is a binary classification multi-threshold metric. The ideal performance corner of the plot
is top-left. AUC of the ROC curve summarizes model performance across different classification thresholds.

Args:
y_true: True labels.
y_prob: Predicted probabilities for the positive class. The array should come from
a binary classifier.
plot_title: Title for the plot. Defaults to "ROC curve".
ax: An existing Axes in which to draw the plot. Defaults to None.
**kwargs: Additional keyword arguments passed to matplotlib.pyplot.plot.

Returns:
Matplotlib axes containing the plot.
"""
display = RocCurveDisplay.from_predictions(y_true, y_prob, plot_chance_level=True, ax=ax, **kwargs)
out_ax = display.ax_
out_ax.set(xlabel="False positive rate", ylabel="True positive rate", title=plot_title)
return out_ax


def plot_det_curve(
y_true: np.ndarray,
y_prob: np.ndarray,
plot_title: Optional[str] = "DET curve",
ax: Optional[plt.Axes] = None,
**kwargs
) -> plt.Axes:
"""
Plot DET (detection error tradeoff) curve.

DET curve is a binary classification multi-threshold metric. DET curves are a variation of ROC curves where
False Negative Rate is plotted on the y-axis instead of True Positive Rate. The ideal performance corner of
the plot is bottom-left. When comparing the performance of different models, DET curves can be
slightly easier to assess visually than ROC curves.

Args:
y_true: True labels.
y_prob: Predicted probabilities for the positive class. The array should come from
a binary classifier.
plot_title: Title for the plot. Defaults to "DET curve".
ax: An existing Axes in which to draw the plot. Defaults to None.
**kwargs: Additional keyword arguments passed to matplotlib.pyplot.plot.

Returns:
Matplotlib axes containing the plot.
"""
display = DetCurveDisplay.from_predictions(y_true, y_prob, ax=ax, **kwargs)
out_ax = display.ax_
out_ax.set(xlabel="False positive rate", ylabel="False negative rate", title=plot_title)
return out_ax


def plot_precision_recall_curve(
y_true: np.ndarray,
y_prob: np.ndarray,
plot_title: Optional[str] = "Precision-Recall curve",
ax: Optional[plt.Axes] = None,
**kwargs
) -> plt.Axes:
"""
Plot precision-recall curve.

Precision-recall curve is a binary classification multi-threshold metric. Precision-recall curve shows
the tradeoff between precision and recall for different classification thresholds.
It can be a useful measure of success when classes are imbalanced.

Args:
y_true: True labels.
y_prob: Predicted probabilities for the positive class. The array should come from
a binary classifier.
plot_title: Title for the plot. Defaults to "Precision-Recall curve".
ax: An existing Axes in which to draw the plot. Defaults to None.
**kwargs: Additional keyword arguments passed to matplotlib.pyplot.plot.

Returns:
Matplotlib axes containing the plot.
"""
display = PrecisionRecallDisplay.from_predictions(y_true, y_prob, plot_chance_level=True, ax=ax, **kwargs)
out_ax = display.ax_
out_ax.set(xlabel="Recall", ylabel="Precision", title=plot_title)
return out_ax


def plot_calibration_curve(
y_true: np.ndarray,
y_prob: np.ndarray,
n_bins: int = 5,
plot_title: Optional[str] = "Calibration curve",
ax: Optional[plt.Axes] = None,
**kwargs
) -> plt.Axes:
"""
Plot calibration curve (aka realibity diagram).

Calibration curve has the frequency of the positive labels on the y-axis and the predicted probability on
the x-axis. Generally, the close the calibration curve is to line x=y, the better the model is calibrated.

Args:
y_true: True labels.
y_prob: Predicted probabilities for the positive class. The array should come from
a binary classifier.
plot_title: Title for the plot. Defaults to "Precision-Recall curve".
ax: An existing Axes in which to draw the plot. Defaults to None.
**kwargs: Additional keyword arguments passed to matplotlib.pyplot.plot.

Returns:
Matplotlib axes containing the plot.
"""
display = CalibrationDisplay.from_predictions(y_true, y_prob, n_bins=n_bins, ax=ax, **kwargs)
out_ax = display.ax_
out_ax.set(xlabel="Mean predicted probability", ylabel="Fraction of positives", title=plot_title)
return out_ax


def plot_predicted_probability_distribution(
y_prob: np.ndarray,
n_bins: int = 5,
plot_title: Optional[str] = "Distribution of predicted probabilities",
ax: Optional[plt.Axes] = None,
**kwargs
) -> plt.Axes:
"""
Plot a histogram of the predicted probabilities.

Args:
y_prob: Predicted probabilities for the positive class. The array should come from
a binary classifier.
n_bins: Number of bins used for the histogram. Defaults to 5.
plot_title: Title for the plot. Defaults to "Distribution of predicted probabilities".
ax: An existing Axes in which to draw the plot. Defaults to None.
**kwargs: Additional keyword arguments passed to sns.histplot and matplotlib.

Returns:
Matplolib axes containing the plot.
"""
sns.set_theme(style="white")
plt.figure()
out_ax = sns.histplot(y_prob, bins=n_bins, ax=ax, **kwargs)
out_ax.set(xlabel="Predicted probability", ylabel="Count", title=plot_title)
return out_ax
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@

@beartype
def plot_confusion_matrix(
confusion_matrix: np.ndarray, cmap: Optional[Union[str, Colormap, Sequence]] = None
confusion_matrix: np.ndarray,
cmap: Optional[Union[str, Colormap, Sequence]] = None,
plot_title: str = "Confusion matrix",
ax: Optional[plt.Axes] = None,
**kwargs,
) -> plt.Axes:
"""Plot confusion matrix to visualize classification results.

Expand All @@ -19,6 +23,9 @@ def plot_confusion_matrix(
(upper-left corner) to have True negatives.
cmap: Colormap name, matploltib colormap objects or list of colors for coloring the plot.
Optional parameter.
plot_title: Title for the plot. Defaults to "Confusion matrix".
ax: An existing Axes in which to draw the plot. Defaults to None.
**kwargs: Additional keyword arguments passed to sns.heatmap.

Returns:
Matplotlib axes containing the plot.
Expand All @@ -40,7 +47,7 @@ def plot_confusion_matrix(
else:
labels = np.asarray([f"{v1}\n{v2}" for v1, v2 in zip(counts, percentages)]).reshape(shape)

ax = sns.heatmap(confusion_matrix, annot=labels, fmt="", cmap=cmap)
ax.set(xlabel="Predicted label", ylabel="True label")
out_ax = sns.heatmap(confusion_matrix, annot=labels, fmt="", cmap=cmap, ax=ax, **kwargs)
out_ax.set(xlabel="Predicted label", ylabel="True label", title=plot_title)

return ax
return out_ax
Loading
Loading