Commit

Add metrics ROC AUC, f1, precision, recall.
Liraim committed Dec 22, 2024
1 parent cca2ad3 commit 863f12a
Showing 4 changed files with 829 additions and 40 deletions.
752 changes: 739 additions & 13 deletions examples/list_metrics.ipynb

Large diffs are not rendered by default.
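
For orientation, a minimal usage sketch based only on what this commit defines (the constructors and display_name methods in src/evidently/v2/metrics/data_quality.py below). How the metrics get attached to a v2 report is not shown in this diff, so it is left out here:

from evidently.v2.metrics.data_quality import F1Metric, PrecisionMetric, RecallMetric, RocAucMetric

# Each metric forwards the same optional arguments to the legacy
# ClassificationQualityByClass metric it wraps.
metrics = [
    F1Metric(probas_threshold=0.5),
    PrecisionMetric(probas_threshold=0.5),
    RecallMetric(probas_threshold=0.5),
    RocAucMetric(),
]
for metric in metrics:
    print(metric.display_name())  # "F1 metric", "Precision metric", "Recall metric", "ROC AUC metric"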

@@ -96,6 +96,8 @@ def calculate(self, data: InputData) -> ClassificationQualityByClassResult:
            current_roc_aucs = sklearn.metrics.roc_auc_score(
                binaraized_target, prediction.prediction_probas, average=None
            ).tolist()
+           for idx, item in enumerate(list(prediction.prediction_probas.columns)):
+               metrics_matrix[item].roc_auc = current_roc_aucs[idx]
        reference_roc_aucs = None

        reference = None
@@ -115,6 +117,8 @@ def calculate(self, data: InputData) -> ClassificationQualityByClassResult:
                reference_roc_aucs = sklearn.metrics.roc_auc_score(
                    binaraized_target, ref_prediction.prediction_probas, average=None
                ).tolist()
+               for idx, item in enumerate(list(prediction.prediction_probas.columns)):
+                   ref_metrics[item].roc_auc = reference_roc_aucs[idx]
            reference = ClassificationQuality(metrics=ref_metrics, roc_aucs=reference_roc_aucs)
        return ClassificationQualityByClassResult(
            columns=columns,
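
The added loops rely on sklearn.metrics.roc_auc_score returning one AUC per class, in column order, when it is given a binarized target and average=None; that is why the per-class scores can be zipped back onto prediction_probas.columns by index. A small self-contained illustration of that behaviour (labels and probabilities are made up):

import numpy as np
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import label_binarize

labels = ["bird", "cat", "dog"]                    # hypothetical class labels
y_true = np.array(["cat", "dog", "bird", "dog"])   # hypothetical targets
probas = np.array([                                # one probability column per label, in `labels` order
    [0.1, 0.7, 0.2],
    [0.1, 0.1, 0.8],
    [0.6, 0.2, 0.2],
    [0.2, 0.3, 0.5],
])

binarized = label_binarize(y_true, classes=labels)              # shape (n_samples, n_classes)
per_class_auc = roc_auc_score(binarized, probas, average=None)  # one AUC per column
per_label = dict(zip(labels, per_class_auc.tolist()))           # same index -> column pairing as the diff
print(per_label)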
86 changes: 86 additions & 0 deletions src/evidently/v2/metrics/data_quality.py
@@ -0,0 +1,86 @@
from typing import Optional

from evidently.metrics import ClassificationQualityByClass
from evidently.v2.datasets import Dataset
from evidently.v2.metrics import ByLabelValue
from evidently.v2.metrics import Metric
from evidently.v2.metrics.base import TResult
from evidently.v2.report import Context


class F1Metric(Metric[ByLabelValue]):
    def __init__(self, probas_threshold: Optional[float] = None, k: Optional[int] = None):
        super().__init__("f1")
        self.probas_threshold = probas_threshold
        self.k = k

    def calculate(self, current_data: Dataset, reference_data: Optional[Dataset]) -> TResult:
        raise ValueError()

    def _call(self, context: Context) -> ByLabelValue:
        result = context.get_legacy_metric(ClassificationQualityByClass(self.probas_threshold, self.k))
        return ByLabelValue(
            {k: v.f1 for k, v in result.current.metrics.items()},
        )

    def display_name(self) -> str:
        return "F1 metric"


class PrecisionMetric(Metric[ByLabelValue]):
    def __init__(self, probas_threshold: Optional[float] = None, k: Optional[int] = None):
        super().__init__("precision")
        self.probas_threshold = probas_threshold
        self.k = k

    def calculate(self, current_data: Dataset, reference_data: Optional[Dataset]) -> TResult:
        raise ValueError()

    def _call(self, context: Context) -> ByLabelValue:
        result = context.get_legacy_metric(ClassificationQualityByClass(self.probas_threshold, self.k))
        return ByLabelValue(
            {k: v.precision for k, v in result.current.metrics.items()},
        )

    def display_name(self) -> str:
        return "Precision metric"


class RecallMetric(Metric[ByLabelValue]):
    def __init__(self, probas_threshold: Optional[float] = None, k: Optional[int] = None):
        super().__init__("recall")
        self.probas_threshold = probas_threshold
        self.k = k

    def calculate(self, current_data: Dataset, reference_data: Optional[Dataset]) -> TResult:
        raise ValueError()

    def _call(self, context: Context) -> ByLabelValue:
        result = context.get_legacy_metric(ClassificationQualityByClass(self.probas_threshold, self.k))

        return ByLabelValue(
            {k: v.recall for k, v in result.current.metrics.items()},
        )

    def display_name(self) -> str:
        return "Recall metric"


class RocAucMetric(Metric[ByLabelValue]):
    def __init__(self, probas_threshold: Optional[float] = None, k: Optional[int] = None):
        super().__init__("roc_auc")
        self.probas_threshold = probas_threshold
        self.k = k

    def calculate(self, current_data: Dataset, reference_data: Optional[Dataset]) -> TResult:
        raise ValueError()

    def _call(self, context: Context) -> ByLabelValue:
        result = context.get_legacy_metric(ClassificationQualityByClass(self.probas_threshold, self.k))

        return ByLabelValue(
            {k: v.roc_auc for k, v in result.current.metrics.items()},
        )

    def display_name(self) -> str:
        return "ROC AUC metric"
27 changes: 0 additions & 27 deletions src/evidently/v2/metrics/f1.py

This file was deleted.
