
Commit

Sync pop
  commit 4f31bb5714cff5d0fff7879c7d70752fcbbfdbe1
  Author: Dominik Jain <[email protected]>
  Date:   Thu Aug 3 17:49:19 2023 +0200

      VectorClassificationModelEvaluatorParams: Improve docstrings

  src/sensai/evaluation/evaluator.py

  commit 7f6dd0136b6597c7238d4f22d37918e33a6aca53
  Author: Dominik Jain <[email protected]>
  Date:   Thu Aug 3 17:44:00 2023 +0200

      EvalStatsClassification:
        Adjust naming of attributes to use snake_case,
        making is_probabilities_available a public attribute

      Fix: Do not attempt to create precision/recall plots if
      class probabilities are unavailable

  src/sensai/evaluation/eval_stats/eval_stats_classification.py
opcode81 committed Aug 3, 2023
1 parent d6e4c36 commit dd237eb
Showing 2 changed files with 29 additions and 28 deletions.
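For orientation, the attribute renames introduced by the second commit on ClassificationEvalStats map as follows (old name left, new public name right). This summary is derived from the diff below and is not part of the commit itself:

    y_predictedClassProbabilities  ->  y_predicted_class_probabilities
    _probabilitiesAvailable        ->  is_probabilities_available  (made public)
    binaryPositiveLabel            ->  binary_positive_label
    isBinary                       ->  is_binary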
52 changes: 26 additions & 26 deletions src/sensai/evaluation/eval_stats/eval_stats_classification.py
@@ -36,7 +36,7 @@ def __init__(self, name=None, bounds: Tuple[float, float] = (0, 1), requires_pro
             else self.__class__.requires_probabilities
 
     def compute_value_for_eval_stats(self, eval_stats: "ClassificationEvalStats"):
-        return self.compute_value(eval_stats.y_true, eval_stats.y_predicted, eval_stats.y_predictedClassProbabilities)
+        return self.compute_value(eval_stats.y_true, eval_stats.y_predicted, eval_stats.y_predicted_class_probabilities)
 
     def compute_value(self, y_true, y_predicted, y_predicted_class_probabilities=None):
         if self.requires_probabilities and y_predicted_class_probabilities is None:
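The hunk above shows the value-computation interface used by all metrics in this file: compute_value_for_eval_stats simply delegates to compute_value(y_true, y_predicted, y_predicted_class_probabilities). A minimal, hedged usage sketch with made-up labels; the two metric classes and the import path are taken from elsewhere in this diff, and the behaviour of the probability guard is inferred from the check shown above:

    from sensai.evaluation.eval_stats.eval_stats_classification import (
        ClassificationMetricAccuracy, ClassificationMetricGeometricMeanOfTrueClassProbability)

    acc = ClassificationMetricAccuracy()
    print(acc.compute_value(y_true=["a", "b", "a"], y_predicted=["a", "b", "b"]))  # 2 of 3 correct -> ~0.667

    # gm = ClassificationMetricGeometricMeanOfTrueClassProbability()
    # gm.compute_value(y_true=["a", "b"], y_predicted=["a", "b"])
    # would hit the guard above, since this metric needs class probabilities but none are passed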
@@ -348,9 +348,9 @@ def __init__(self, y_predicted: PredictionArray = None,
             if None, treat the problem as non-binary, regardless of the labels being used.
         """
         self.labels = labels
-        self.y_predictedClassProbabilities = y_predicted_class_probabilities
-        self._probabilitiesAvailable = y_predicted_class_probabilities is not None
-        if self._probabilitiesAvailable:
+        self.y_predicted_class_probabilities = y_predicted_class_probabilities
+        self.is_probabilities_available = y_predicted_class_probabilities is not None
+        if self.is_probabilities_available:
             col_set = set(y_predicted_class_probabilities.columns)
             if col_set != set(labels):
                 raise ValueError(f"Columns in class probabilities data frame ({y_predicted_class_probabilities.columns}) do not "
@@ -377,22 +377,22 @@ def __init__(self, y_predicted: PredictionArray = None,
         if num_labels == 2 and binary_positive_label is None:
             log.warning(f"Binary classification (labels={labels}) without specification of positive class label; "
                 f"binary classification metrics will not be considered")
-        self.binaryPositiveLabel = binary_positive_label
-        self.isBinary = binary_positive_label is not None
+        self.binary_positive_label = binary_positive_label
+        self.is_binary = binary_positive_label is not None
 
         if metrics is None:
             metrics = [ClassificationMetricAccuracy(), ClassificationMetricBalancedAccuracy(),
                 ClassificationMetricGeometricMeanOfTrueClassProbability()]
-            if self.isBinary:
+            if self.is_binary:
                 metrics.extend([
-                    BinaryClassificationMetricPrecision(self.binaryPositiveLabel),
-                    BinaryClassificationMetricRecall(self.binaryPositiveLabel),
-                    BinaryClassificationMetricF1Score(self.binaryPositiveLabel)])
+                    BinaryClassificationMetricPrecision(self.binary_positive_label),
+                    BinaryClassificationMetricRecall(self.binary_positive_label),
+                    BinaryClassificationMetricF1Score(self.binary_positive_label)])
 
         metrics = list(metrics)
         if additional_metrics is not None:
             for m in additional_metrics:
-                if not self._probabilitiesAvailable and m.requires_probabilities:
+                if not self.is_probabilities_available and m.requires_probabilities:
                     raise ValueError(f"Additional metric {m} not supported, as class probabilities were not provided")
 
         super().__init__(y_predicted, y_true, metrics, additional_metrics=additional_metrics)
@@ -417,7 +417,7 @@ def get_accuracy(self):
     def metrics_dict(self) -> Dict[str, float]:
         d = {}
         for metric in self.metrics:
-            if not metric.requires_probabilities or self._probabilitiesAvailable:
+            if not metric.requires_probabilities or self.is_probabilities_available:
                 d[metric.name] = self.compute_metric_value(metric)
         return d
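Given the guard above, metrics that require probabilities are silently skipped rather than raising when no probabilities were provided. A short sketch, reusing the eval_stats object constructed earlier and assuming metrics_dict is called as the plain method its definition above suggests:

    d = eval_stats.metrics_dict()
    for name, value in d.items():
        # with probabilities available this includes probability-based metrics;
        # without them, those keys are simply absent
        print(f"{name}: {value:.4f}")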

@@ -431,13 +431,13 @@ def plot_confusion_matrix(self, normalize=True, title_add: str = None):
 
     def plot_precision_recall_curve(self, title_add: str = None):
         from sklearn.metrics import PrecisionRecallDisplay  # only supported by newer versions of sklearn
-        if not self._probabilitiesAvailable:
+        if not self.is_probabilities_available:
             raise Exception("Precision-recall curve requires probabilities")
-        if not self.isBinary:
+        if not self.is_binary:
             raise Exception("Precision-recall curve is not applicable to non-binary classification")
-        probabilities = self.y_predictedClassProbabilities[self.binaryPositiveLabel]
+        probabilities = self.y_predicted_class_probabilities[self.binary_positive_label]
         precision, recall, thresholds = precision_recall_curve(y_true=self.y_true, probas_pred=probabilities,
-            pos_label=self.binaryPositiveLabel)
+            pos_label=self.binary_positive_label)
         disp = PrecisionRecallDisplay(precision, recall)
         disp.plot()
         ax: plt.Axes = disp.ax_
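This is the method affected by the commit's fix note: plotting presupposes both a binary problem and available probabilities. A hedged usage sketch continuing the earlier eval_stats example:

    if eval_stats.is_binary and eval_stats.is_probabilities_available:
        eval_stats.plot_precision_recall_curve(title_add="test set")
    # the plot classes at the end of this file perform the same check and now
    # return None instead of attempting the plot when either condition fails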
@@ -468,14 +468,14 @@ def get_combined_eval_stats(self) -> ClassificationEvalStats:
         y_true = np.concatenate([evalStats.y_true for evalStats in self.statsList])
         y_predicted = np.concatenate([evalStats.y_predicted for evalStats in self.statsList])
         es0 = self.statsList[0]
-        if es0.y_predictedClassProbabilities is not None:
-            y_probs = pd.concat([evalStats.y_predictedClassProbabilities for evalStats in self.statsList])
+        if es0.y_predicted_class_probabilities is not None:
+            y_probs = pd.concat([evalStats.y_predicted_class_probabilities for evalStats in self.statsList])
             labels = list(y_probs.columns)
         else:
             y_probs = None
             labels = es0.labels
         self.globalStats = ClassificationEvalStats(y_predicted=y_predicted, y_true=y_true, y_predicted_class_probabilities=y_probs,
-            labels=labels, binary_positive_label=es0.binaryPositiveLabel, metrics=es0.metrics)
+            labels=labels, binary_positive_label=es0.binary_positive_label, metrics=es0.metrics)
         return self.globalStats


@@ -522,12 +522,12 @@ def from_probability_threshold(cls, probabilities: Sequence[float], threshold: f
 
     @classmethod
     def from_eval_stats(cls, eval_stats: ClassificationEvalStats, threshold=0.5) -> "BinaryClassificationCounts":
-        if not eval_stats.isBinary:
+        if not eval_stats.is_binary:
             raise ValueError("Probability threshold variation data can only be computed for binary classification problems")
-        if eval_stats.y_predictedClassProbabilities is None:
+        if eval_stats.y_predicted_class_probabilities is None:
             raise ValueError("No probability data")
-        pos_class_label = eval_stats.binaryPositiveLabel
-        probs = eval_stats.y_predictedClassProbabilities[pos_class_label]
+        pos_class_label = eval_stats.binary_positive_label
+        probs = eval_stats.y_predicted_class_probabilities[pos_class_label]
         is_positive_gt = [gtLabel == pos_class_label for gtLabel in eval_stats.y_true]
         return cls.from_probability_threshold(probabilities=probs, threshold=threshold, is_positive_ground_truth=is_positive_gt)
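A hedged sketch of the threshold-based entry point shown above, again reusing the earlier eval_stats; only the classmethod call itself is taken from this diff, the chosen threshold is illustrative:

    # BinaryClassificationCounts is defined in this same module
    counts = BinaryClassificationCounts.from_eval_stats(eval_stats, threshold=0.7)
    # probabilities of the positive label ("yes") are compared against the chosen
    # threshold to derive confusion counts at that operating point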

@@ -605,20 +605,20 @@ def create_figure(self, eval_stats: ClassificationEvalStats, subtitle: str) -> p
 
 class ClassificationEvalStatsPlotPrecisionRecall(ClassificationEvalStatsPlot):
     def create_figure(self, eval_stats: ClassificationEvalStats, subtitle: str) -> Optional[plt.Figure]:
-        if not eval_stats.isBinary:
+        if not eval_stats.is_binary or not eval_stats.is_probabilities_available:
             return None
         return eval_stats.plot_precision_recall_curve(title_add=subtitle)
 
 
 class ClassificationEvalStatsPlotProbabilityThresholdPrecisionRecall(ClassificationEvalStatsPlot):
     def create_figure(self, eval_stats: ClassificationEvalStats, subtitle: str) -> Optional[plt.Figure]:
-        if not eval_stats.isBinary:
+        if not eval_stats.is_binary or not eval_stats.is_probabilities_available:
             return None
         return eval_stats.get_binary_classification_probability_threshold_variation_data().plot_precision_recall(subtitle=subtitle)
 
 
 class ClassificationEvalStatsPlotProbabilityThresholdCounts(ClassificationEvalStatsPlot):
     def create_figure(self, eval_stats: ClassificationEvalStats, subtitle: str) -> Optional[plt.Figure]:
-        if not eval_stats.isBinary:
+        if not eval_stats.is_binary or not eval_stats.is_probabilities_available:
             return None
         return eval_stats.get_binary_classification_probability_threshold_variation_data().plot_counts(subtitle=subtitle)
5 changes: 3 additions & 2 deletions src/sensai/evaluation/evaluator.py
@@ -367,9 +367,10 @@ def __init__(self, data_splitter: DataSplitter = None, fractional_split_test_fra
         :param fractional_split_shuffle: [if dataSplitter is None, test data must be obtained via split] whether to randomly (based on
             randomSeed) shuffle the dataset before splitting it
         :param additional_metrics: additional metrics to apply
-        :param compute_probabilities: whether to compute class probabilities
+        :param compute_probabilities: whether to compute class probabilities. Enabling this will enable many downstream computations
+            and visualisations (e.g. precision-recall plots) but requires the model to support probability computation in general.
         :param binary_positive_label: the positive class label for binary classification; if GUESS, try to detect from labels;
-            if None, no detection (non-binary classification)
+            if None, no detection (assume non-binary classification)
         """
         super().__init__(data_splitter,
             fractional_split_test_fraction=fractional_split_test_fraction,
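A hedged sketch of how these parameters might be set, assuming the constructor documented above belongs to VectorClassificationModelEvaluatorParams (the class named in the first commit message); the parameter names are taken from this hunk, the concrete values are illustrative:

    params = VectorClassificationModelEvaluatorParams(
        fractional_split_test_fraction=0.2,   # hold out 20% of the data for evaluation
        compute_probabilities=True,           # enables probability-based metrics and precision-recall plots
        binary_positive_label="yes")          # or GUESS to auto-detect, or None for non-binary problems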
