Merge pull request #368 from GispoCoding/347-refactor-validation-category

347 refactor validation category, now called evaluation
Showing 40 changed files with 1,082 additions and 654 deletions.
@@ -0,0 +1,3 @@
# Calculate base metrics

::: eis_toolkit.evaluation.calculate_base_metrics
@@ -0,0 +1,3 @@
# Classification label evaluation

::: eis_toolkit.evaluation.classification_label_evaluation
@@ -0,0 +1,3 @@
# Classification probability evaluation

::: eis_toolkit.evaluation.classification_probability_evaluation
@@ -0,0 +1,3 @@
# Plot confusion matrix

::: eis_toolkit.evaluation.plot_confusion_matrix
2 changes: 1 addition & 1 deletion
docs/validation/plot_nn_model_performance.md → docs/evaluation/plot_nn_model_performance.md
@@ -1,3 +1,3 @@
# Plot neural network training performance (accuracy and loss)

-::: eis_toolkit.validation.plot_nn_model_performance
+::: eis_toolkit.evaluation.plot_nn_model_performance
@@ -0,0 +1,3 @@
# Plot prediction-area (P-A) curves

::: eis_toolkit.evaluation.plot_prediction_area_curves
@@ -0,0 +1,3 @@
# Plot rate curve

::: eis_toolkit.evaluation.plot_rate_curve
6 files deleted (diffs not shown).
File renamed without changes.
File renamed without changes.
eis_toolkit/evaluation/classification_label_evaluation.py

@@ -0,0 +1,37 @@
from numbers import Number
from typing import Dict

import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix, precision_recall_fscore_support


def summarize_label_metrics_binary(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, Number]:
    """
    Generate a comprehensive report of various evaluation metrics for binary classification results.

    The output includes accuracy, precision, recall, F1 scores and confusion matrix elements
    (true negatives, false positives, false negatives, true positives).

    Args:
        y_true: True labels.
        y_pred: Predicted labels. The array should come from a binary classifier.

    Returns:
        A dictionary containing the evaluated metrics.
    """
    metrics = {}

    metrics["Accuracy"] = accuracy_score(y_true, y_pred)

    precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="binary")
    metrics["Precision"] = precision
    metrics["Recall"] = recall
    metrics["F1_score"] = f1

    tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
    metrics["True_negatives"] = tn
    metrics["False_positives"] = fp
    metrics["False_negatives"] = fn
    metrics["True_positives"] = tp

    return metrics
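As a quick illustration (not part of this diff), the new label-metrics helper can be called on a pair of label arrays; the toy data below is invented for the example:

import numpy as np

from eis_toolkit.evaluation.classification_label_evaluation import summarize_label_metrics_binary

# Toy binary labels and predictions, purely illustrative.
y_true = np.array([1, 0, 1, 1, 0, 0, 1, 0])
y_pred = np.array([1, 0, 0, 1, 0, 1, 1, 0])

metrics = summarize_label_metrics_binary(y_true, y_pred)
print(metrics["Accuracy"], metrics["F1_score"], metrics["True_positives"])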
192 changes: 192 additions & 0 deletions
eis_toolkit/evaluation/classification_probability_evaluation.py
@@ -0,0 +1,192 @@
from typing import Dict

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from beartype.typing import Optional
from sklearn.calibration import CalibrationDisplay
from sklearn.metrics import (
    DetCurveDisplay,
    PrecisionRecallDisplay,
    RocCurveDisplay,
    average_precision_score,
    brier_score_loss,
    log_loss,
    roc_auc_score,
)


def summarize_probability_metrics(y_true: np.ndarray, y_prob: np.ndarray) -> Dict[str, float]:
    """
    Generate a comprehensive report of various evaluation metrics for classification probabilities.

    The output includes ROC AUC, log loss, average precision and Brier score loss.

    Args:
        y_true: True labels.
        y_prob: Predicted probabilities for the positive class. The array should come from
            a binary classifier.

    Returns:
        A dictionary containing the evaluated metrics.
    """
    metrics = {}

    metrics["roc_auc"] = roc_auc_score(y_true, y_prob)
    metrics["log_loss"] = log_loss(y_true, y_prob)
    metrics["average_precision"] = average_precision_score(y_true, y_prob)
    metrics["brier_score_loss"] = brier_score_loss(y_true, y_prob)

    return metrics


def plot_roc_curve(
    y_true: np.ndarray,
    y_prob: np.ndarray,
    plot_title: Optional[str] = "ROC curve",
    ax: Optional[plt.Axes] = None,
    **kwargs
) -> plt.Axes:
    """
    Plot ROC (receiver operating characteristic) curve.

    ROC curve is a binary classification multi-threshold metric. The ideal performance corner of the plot
    is top-left. AUC of the ROC curve summarizes model performance across different classification thresholds.

    Args:
        y_true: True labels.
        y_prob: Predicted probabilities for the positive class. The array should come from
            a binary classifier.
        plot_title: Title for the plot. Defaults to "ROC curve".
        ax: An existing Axes in which to draw the plot. Defaults to None.
        **kwargs: Additional keyword arguments passed to matplotlib.pyplot.plot.

    Returns:
        Matplotlib axes containing the plot.
    """
    display = RocCurveDisplay.from_predictions(y_true, y_prob, plot_chance_level=True, ax=ax, **kwargs)
    out_ax = display.ax_
    out_ax.set(xlabel="False positive rate", ylabel="True positive rate", title=plot_title)
    return out_ax


def plot_det_curve(
    y_true: np.ndarray,
    y_prob: np.ndarray,
    plot_title: Optional[str] = "DET curve",
    ax: Optional[plt.Axes] = None,
    **kwargs
) -> plt.Axes:
    """
    Plot DET (detection error tradeoff) curve.

    DET curve is a binary classification multi-threshold metric. DET curves are a variation of ROC curves where
    False Negative Rate is plotted on the y-axis instead of True Positive Rate. The ideal performance corner of
    the plot is bottom-left. When comparing the performance of different models, DET curves can be
    slightly easier to assess visually than ROC curves.

    Args:
        y_true: True labels.
        y_prob: Predicted probabilities for the positive class. The array should come from
            a binary classifier.
        plot_title: Title for the plot. Defaults to "DET curve".
        ax: An existing Axes in which to draw the plot. Defaults to None.
        **kwargs: Additional keyword arguments passed to matplotlib.pyplot.plot.

    Returns:
        Matplotlib axes containing the plot.
    """
    display = DetCurveDisplay.from_predictions(y_true, y_prob, ax=ax, **kwargs)
    out_ax = display.ax_
    out_ax.set(xlabel="False positive rate", ylabel="False negative rate", title=plot_title)
    return out_ax


def plot_precision_recall_curve(
    y_true: np.ndarray,
    y_prob: np.ndarray,
    plot_title: Optional[str] = "Precision-Recall curve",
    ax: Optional[plt.Axes] = None,
    **kwargs
) -> plt.Axes:
    """
    Plot precision-recall curve.

    Precision-recall curve is a binary classification multi-threshold metric. Precision-recall curve shows
    the tradeoff between precision and recall for different classification thresholds.
    It can be a useful measure of success when classes are imbalanced.

    Args:
        y_true: True labels.
        y_prob: Predicted probabilities for the positive class. The array should come from
            a binary classifier.
        plot_title: Title for the plot. Defaults to "Precision-Recall curve".
        ax: An existing Axes in which to draw the plot. Defaults to None.
        **kwargs: Additional keyword arguments passed to matplotlib.pyplot.plot.

    Returns:
        Matplotlib axes containing the plot.
    """
    display = PrecisionRecallDisplay.from_predictions(y_true, y_prob, plot_chance_level=True, ax=ax, **kwargs)
    out_ax = display.ax_
    out_ax.set(xlabel="Recall", ylabel="Precision", title=plot_title)
    return out_ax


def plot_calibration_curve(
    y_true: np.ndarray,
    y_prob: np.ndarray,
    n_bins: int = 5,
    plot_title: Optional[str] = "Calibration curve",
    ax: Optional[plt.Axes] = None,
    **kwargs
) -> plt.Axes:
    """
    Plot calibration curve (also known as a reliability diagram).

    Calibration curve has the frequency of the positive labels on the y-axis and the predicted probability on
    the x-axis. Generally, the closer the calibration curve is to the line x=y, the better the model is calibrated.

    Args:
        y_true: True labels.
        y_prob: Predicted probabilities for the positive class. The array should come from
            a binary classifier.
        n_bins: Number of bins used to compute the calibration curve. Defaults to 5.
        plot_title: Title for the plot. Defaults to "Calibration curve".
        ax: An existing Axes in which to draw the plot. Defaults to None.
        **kwargs: Additional keyword arguments passed to matplotlib.pyplot.plot.

    Returns:
        Matplotlib axes containing the plot.
    """
    display = CalibrationDisplay.from_predictions(y_true, y_prob, n_bins=n_bins, ax=ax, **kwargs)
    out_ax = display.ax_
    out_ax.set(xlabel="Mean predicted probability", ylabel="Fraction of positives", title=plot_title)
    return out_ax


def plot_predicted_probability_distribution(
    y_prob: np.ndarray,
    n_bins: int = 5,
    plot_title: Optional[str] = "Distribution of predicted probabilities",
    ax: Optional[plt.Axes] = None,
    **kwargs
) -> plt.Axes:
    """
    Plot a histogram of the predicted probabilities.

    Args:
        y_prob: Predicted probabilities for the positive class. The array should come from
            a binary classifier.
        n_bins: Number of bins used for the histogram. Defaults to 5.
        plot_title: Title for the plot. Defaults to "Distribution of predicted probabilities".
        ax: An existing Axes in which to draw the plot. Defaults to None.
        **kwargs: Additional keyword arguments passed to sns.histplot and matplotlib.

    Returns:
        Matplotlib axes containing the plot.
    """
    sns.set_theme(style="white")
    plt.figure()
    out_ax = sns.histplot(y_prob, bins=n_bins, ax=ax, **kwargs)
    out_ax.set(xlabel="Predicted probability", ylabel="Count", title=plot_title)
    return out_ax
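For context (again, not part of the diff), a minimal usage sketch of the new probability-evaluation helpers could look like the following; the labels and probabilities are invented for illustration:

import matplotlib.pyplot as plt
import numpy as np

from eis_toolkit.evaluation.classification_probability_evaluation import (
    plot_roc_curve,
    summarize_probability_metrics,
)

# Invented true labels and predicted positive-class probabilities.
y_true = np.array([1, 0, 1, 1, 0, 0, 1, 0])
y_prob = np.array([0.9, 0.2, 0.6, 0.8, 0.3, 0.55, 0.7, 0.1])

# Prints roc_auc, log_loss, average_precision and brier_score_loss.
print(summarize_probability_metrics(y_true, y_prob))

# Draw the ROC curve for the same predictions.
ax = plot_roc_curve(y_true, y_prob)
plt.show()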
File renamed without changes.