Skip to content

Commit

Permalink
Add separate render for metrics and checks.
Browse files Browse the repository at this point in the history
  • Loading branch information
Liraim committed Dec 16, 2024
1 parent ccab7fc commit e35ef8e
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 49 deletions.
57 changes: 37 additions & 20 deletions examples/metric_workbench.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions src/evidently/model/widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class WidgetType(Enum):
COUNTER = "counter"
TABLE = "table"
BIG_TABLE = "big_table"
GROUP = "group"
BIG_GRAPH = "big_graph"
RICH_DATA = "rich_data"
TABBED_GRAPH = "tabbed_graph"
Expand Down
13 changes: 13 additions & 0 deletions src/evidently/renderers/html_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -913,3 +913,16 @@ def get_class_separation_plot_data_agg(

additional_plots.append((str(label), plotly_figure(title="", figure=fig)))
return additional_plots


def group_widget(
*,
title: str,
widgets: List[BaseWidgetInfo],
) -> BaseWidgetInfo:
return BaseWidgetInfo(
title=title,
type=WidgetType.GROUP.value,
widgets=widgets,
size=2,
)
15 changes: 9 additions & 6 deletions src/evidently/v2/checks/numerical_checks.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,30 @@
from typing import Union

from ..metrics import Metric
from ..metrics.base import CheckResult
from ..metrics.base import SingleValue
from ..metrics.base import SingleValueCheck
from ..metrics.base import TestStatus


def le(threshold: Union[int, float]) -> SingleValueCheck:
def func(value: SingleValue):
def func(metric: Metric, value: SingleValue):
return CheckResult(
f"Less or Equal {threshold}",
"",
"le",
f"{metric.display_name()}: Less or Equal {threshold}",
f"Actual value {value.value} {'<' if value.value < threshold else '>='} {threshold}",
TestStatus.SUCCESS if value.value <= threshold else TestStatus.FAIL,
)

return func


def ge(threshold: Union[int, float]) -> SingleValueCheck:
def func(value: SingleValue):
def func(metric: Metric, value: SingleValue):
return CheckResult(
f"Greater or Equal {threshold}",
"",
"ge",
f"{metric.display_name()}: Greater or Equal {threshold}",
f"Actual value {value.value} {'<' if value.value < threshold else '>='} {threshold}",
TestStatus.SUCCESS if value.value >= threshold else TestStatus.FAIL,
)

Expand Down
36 changes: 25 additions & 11 deletions src/evidently/v2/metrics/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import typing
import uuid
from abc import abstractmethod
from copy import copy
from typing import Generic
from typing import List
from typing import Optional
Expand All @@ -30,12 +31,18 @@


class MetricResult:
_metric: Optional["Metric"] = None
_widget: Optional[List[BaseWidgetInfo]] = None
_checks: Optional[List["CheckResult"]] = None

def set_widget(self, widget: List[BaseWidgetInfo]):
self._widget = widget
def set_checks(self, checks: List["CheckResult"]):
self._checks = checks

def _repr_html_(self):
assert self._widget
widget = copy(self._widget)
if self._checks:
widget.append(checks_widget(self))
return render_results(self, html=False)

def is_widget_set(self) -> bool:
Expand All @@ -49,6 +56,10 @@ def widget(self) -> List[BaseWidgetInfo]:
def widget(self, value: List[BaseWidgetInfo]):
self._widget = value

@property
def checks(self) -> List["CheckResult"]:
return self._checks


def render_widgets(widgets: List[BaseWidgetInfo]):
dashboard_id, dashboard_info, graphs = (
Expand Down Expand Up @@ -81,9 +92,12 @@ def render_results(results: Union[MetricResult, List[MetricResult]], html=True):

MetricReturnValue = Tuple[TResult, BaseWidgetInfo]

CheckId = str


@dataclasses.dataclass
class CheckResult:
id: CheckId
name: str
description: str
status: TestStatus
Expand All @@ -95,17 +109,17 @@ class SingleValue(MetricResult):


class Check(Protocol[TResult]):
def __call__(self, value: TResult) -> CheckResult: ...
def __call__(self, metric: "Metric", value: TResult) -> CheckResult: ...


class SingleValueCheck(Check[TResult], Protocol):
def __call__(self, value: SingleValue) -> CheckResult: ...
def __call__(self, metric: "Metric", value: SingleValue) -> CheckResult: ...


MetricId = str


def checks_widget(metric: "Metric", result: TResult) -> BaseWidgetInfo:
def checks_widget(result: TResult) -> BaseWidgetInfo:
return BaseWidgetInfo(
title="",
size=2,
Expand All @@ -118,17 +132,17 @@ def checks_widget(metric: "Metric", result: TResult) -> BaseWidgetInfo:
state=check.status.value.lower(),
groups=[],
)
for idx, check in enumerate([check(result) for check in metric.checks()])
for idx, check in enumerate(result.checks)
],
},
)


def get_default_render(metric: "Metric", result: TResult) -> List[BaseWidgetInfo]:
def get_default_render(title: str, result: TResult) -> List[BaseWidgetInfo]:
if isinstance(result, SingleValue):
return [
counter(
title=metric.display_name(),
title=title,
size=WidgetSize.FULL,
counters=[CounterData(label="", value=result.value)],
),
Expand Down Expand Up @@ -160,9 +174,9 @@ def call(self, context: "Context") -> TResult:
"""
result = self.calculate(*context._input_data)
if not result.is_widget_set():
result.widget = get_default_render(self, result)
result.widget = get_default_render(self.display_name(), result)
if self._checks and len(self._checks) > 0:
result.widget.append(checks_widget(self, result))
result.set_checks([check(self, result) for check in self._checks])
return result

@abc.abstractmethod
Expand Down Expand Up @@ -195,7 +209,7 @@ def display_name(self) -> str:
raise NotImplementedError()

def checks(self) -> List[Check]:
return self._checks
return self._checks or []

def group_by(self, group_by: Optional[str]) -> Union["Metric", List["Metric"]]:
if group_by is None:
Expand Down
69 changes: 57 additions & 12 deletions src/evidently/v2/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import TypeVar
from typing import Union
Expand All @@ -13,6 +14,7 @@
from .metrics import MetricPreset
from .metrics import MetricResult
from .metrics.base import MetricId
from .metrics.base import checks_widget
from .metrics.base import render_widgets

TResultType = TypeVar("TResultType", bound="MetricResult")
Expand Down Expand Up @@ -42,6 +44,7 @@ class Context:

def __init__(self):
self._metrics = {}
self._metric_defs = {}
self._configuration = None
self._data_columns = {}
self._metrics_graph = {}
Expand All @@ -59,14 +62,22 @@ def column(self, column_name: str) -> ContextColumnData:

def calculate_metric(self, metric: Metric[TResultType]) -> TResultType:
if metric.id not in self._current_graph_level:
self._current_graph_level[metric.id] = {}
self._current_graph_level[metric.id] = {"_self": metric}
prev_level = self._current_graph_level
self._current_graph_level = prev_level[metric.id]
if metric.id not in self._metrics:
self._metrics[metric.id] = metric.call(self)
self._current_graph_level = prev_level
return self._metrics[metric.id]

def get_metric_result(self, metric: Union[MetricId, Metric[TResultType]]) -> TResultType:
if isinstance(metric, MetricId):
return self._metrics[metric]
return self.calculate_metric(metric)

def get_metric(self, metric: MetricId) -> Metric[TResultType]:
return self._metrics_graph[metric]["_self"]


class Snapshot:
_report: "Report"
Expand All @@ -76,29 +87,63 @@ def __init__(self, report: "Report"):
self._report = report
self._context = Context()

@property
def context(self) -> Context:
return self._context

@property
def report(self) -> "Report":
return self._report

def run(self, current_data: Dataset, reference_data: Optional[Dataset]):
self._context.init_dataset(current_data, reference_data)
for metric in self._report._metrics:
if isinstance(metric, (MetricPreset,)):
for metric in metric.metrics():
self._context.calculate_metric(metric)
elif isinstance(metric, (MetricContainer,)):
for metric in metric.metrics(self._context):
self._context.calculate_metric(metric)
self.context.init_dataset(current_data, reference_data)
for item in self.report.items():
if isinstance(item, (MetricPreset,)):
for metric in item.metrics():
self.context.calculate_metric(metric)
elif isinstance(item, (MetricContainer,)):
for metric in item.metrics(self.context):
self.context.calculate_metric(metric)
else:
self._context.calculate_metric(metric)
self.context.calculate_metric(item)

def _repr_html_(self):
from evidently.renderers.html_widgets import TabData
from evidently.renderers.html_widgets import group_widget
from evidently.renderers.html_widgets import widget_tabs

results = [
(
metric,
self._context.get_metric_result(metric).widget,
checks_widget(self.context.get_metric_result(metric))
if self.context.get_metric_result(metric).checks
else None,
)
for metric in self.context._metrics_graph.keys()
]
tabs = widget_tabs(
title="tabs",
tabs=[
TabData("Metrics", group_widget(title="", widgets=list(chain(*[result[1] for result in results])))),
TabData(
"Checks", group_widget(title="", widgets=[result[2] for result in results if result[2] is not None])
),
],
)
return render_widgets(
list(chain(*[self._context._metrics[metric].widget for metric in self._context._metrics_graph.keys()]))
[tabs],
)


class Report:
def __init__(self, metrics: List[Union[Metric, MetricPreset]]):
def __init__(self, metrics: List[Union[Metric, MetricPreset, MetricContainer]]):
self._metrics = metrics

def run(self, current_data: Dataset, reference_data: Optional[Dataset]) -> Snapshot:
snapshot = Snapshot(self)
snapshot.run(current_data, reference_data)
return snapshot

def items(self) -> Sequence[Union[Metric, MetricPreset, MetricContainer]]:
return self._metrics

0 comments on commit e35ef8e

Please sign in to comment.