-
Notifications
You must be signed in to change notification settings - Fork 613
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add metrics ROC AUC, f1, precision, recall.
- Loading branch information
Showing
4 changed files
with
829 additions
and
40 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
from typing import Optional | ||
|
||
from evidently.metrics import ClassificationQualityByClass | ||
from evidently.v2.datasets import Dataset | ||
from evidently.v2.metrics import ByLabelValue | ||
from evidently.v2.metrics import Metric | ||
from evidently.v2.metrics.base import TResult | ||
from evidently.v2.report import Context | ||
|
||
|
||
class F1Metric(Metric[ByLabelValue]): | ||
def __init__(self, probas_threshold: Optional[float] = None, k: Optional[int] = None): | ||
super().__init__("f1") | ||
self.probas_threshold = probas_threshold | ||
self.k = k | ||
|
||
def calculate(self, current_data: Dataset, reference_data: Optional[Dataset]) -> TResult: | ||
raise ValueError() | ||
|
||
def _call(self, context: Context) -> ByLabelValue: | ||
result = context.get_legacy_metric(ClassificationQualityByClass(self.probas_threshold, self.k)) | ||
return ByLabelValue( | ||
{k: v.f1 for k, v in result.current.metrics.items()}, | ||
) | ||
|
||
def display_name(self) -> str: | ||
return "F1 metric" | ||
|
||
|
||
class PrecisionMetric(Metric[ByLabelValue]): | ||
def __init__(self, probas_threshold: Optional[float] = None, k: Optional[int] = None): | ||
super().__init__("precision") | ||
self.probas_threshold = probas_threshold | ||
self.k = k | ||
|
||
def calculate(self, current_data: Dataset, reference_data: Optional[Dataset]) -> TResult: | ||
raise ValueError() | ||
|
||
def _call(self, context: Context) -> ByLabelValue: | ||
result = context.get_legacy_metric(ClassificationQualityByClass(self.probas_threshold, self.k)) | ||
return ByLabelValue( | ||
{k: v.precision for k, v in result.current.metrics.items()}, | ||
) | ||
|
||
def display_name(self) -> str: | ||
return "Precision metric" | ||
|
||
|
||
class RecallMetric(Metric[ByLabelValue]): | ||
def __init__(self, probas_threshold: Optional[float] = None, k: Optional[int] = None): | ||
super().__init__("recall") | ||
self.probas_threshold = probas_threshold | ||
self.k = k | ||
|
||
def calculate(self, current_data: Dataset, reference_data: Optional[Dataset]) -> TResult: | ||
raise ValueError() | ||
|
||
def _call(self, context: Context) -> ByLabelValue: | ||
result = context.get_legacy_metric(ClassificationQualityByClass(self.probas_threshold, self.k)) | ||
|
||
return ByLabelValue( | ||
{k: v.recall for k, v in result.current.metrics.items()}, | ||
) | ||
|
||
def display_name(self) -> str: | ||
return "Recall metric" | ||
|
||
|
||
class RocAucMetric(Metric[ByLabelValue]): | ||
def __init__(self, probas_threshold: Optional[float] = None, k: Optional[int] = None): | ||
super().__init__("roc_auc") | ||
self.probas_threshold = probas_threshold | ||
self.k = k | ||
|
||
def calculate(self, current_data: Dataset, reference_data: Optional[Dataset]) -> TResult: | ||
raise ValueError() | ||
|
||
def _call(self, context: Context) -> ByLabelValue: | ||
result = context.get_legacy_metric(ClassificationQualityByClass(self.probas_threshold, self.k)) | ||
|
||
return ByLabelValue( | ||
{k: v.roc_auc for k, v in result.current.metrics.items()}, | ||
) | ||
|
||
def display_name(self) -> str: | ||
return "ROC AUC metric" |
This file was deleted.
Oops, something went wrong.