Skip to content

Modernize more type definitions related to text support #52

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions eli5/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


@attrs
class Explanation(object):
class Explanation:
""" An explanation for classifier or regressor,
it can either explain weights or a single prediction.
"""
Expand Down Expand Up @@ -49,7 +49,7 @@ def _repr_html_(self):


@attrs
class FeatureImportances(object):
class FeatureImportances:
""" Feature importances with number of remaining non-zero features.
"""
def __init__(self, importances, remaining):
Expand All @@ -64,7 +64,7 @@ def from_names_values(cls, names, values, std=None, **kwargs):


@attrs
class TargetExplanation(object):
class TargetExplanation:
""" Explanation for a single target or class.
Feature weights are stored in the :feature_weights: attribute,
and features highlighted in text in the :weighted_spans: attribute.
Expand Down Expand Up @@ -92,7 +92,7 @@ def __init__(self,


@attrs
class FeatureWeights(object):
class FeatureWeights:
""" Weights for top features, :pos: for positive and :neg: for negative,
sorted by descending absolute value.
Number of remaining positive and negative features are stored in
Expand All @@ -111,7 +111,7 @@ def __init__(self,


@attrs
class FeatureWeight(object):
class FeatureWeight:
def __init__(self, feature: Feature, weight: float, std: Optional[float] = None, value=None):
self.feature = feature
self.weight = weight
Expand All @@ -120,7 +120,7 @@ def __init__(self, feature: Feature, weight: float, std: Optional[float] = None,


@attrs
class WeightedSpans(object):
class WeightedSpans:
""" Holds highlighted spans for parts of document - a DocWeightedSpans
object for each vectorizer, and other features not highlighted anywhere.
"""
Expand All @@ -140,7 +140,7 @@ def __init__(self,


@attrs
class DocWeightedSpans(object):
class DocWeightedSpans:
""" Features highlighted in text. :document: is a pre-processed document
before applying the analyzer. :weighted_spans: holds a list of spans
for features found in text (span indices correspond to
Expand All @@ -161,15 +161,15 @@ def __init__(self,


@attrs
class TransitionFeatureWeights(object):
class TransitionFeatureWeights:
""" Weights matrix for transition features. """
def __init__(self, class_names: list[str], coef):
self.class_names = class_names
self.coef = coef


@attrs
class TreeInfo(object):
class TreeInfo:
""" Information about the decision tree. :criterion: is the name of
the function to measure the quality of a split, :tree: holds all nodes
of the tree, and :graphviz: is the tree rendered in graphviz .dot format.
Expand All @@ -182,7 +182,7 @@ def __init__(self, criterion: str, tree: 'NodeInfo', graphviz: str, is_classific


@attrs
class NodeInfo(object):
class NodeInfo:
""" A node in a binary tree.
Pointers to left and right children are in :left: and :right: attributes.
"""
Expand Down
98 changes: 37 additions & 61 deletions eli5/formatters/html.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from itertools import groupby
from typing import List, Optional, Tuple
from html import escape
from typing import Optional

import numpy as np
from jinja2 import Environment, PackageLoader
Expand Down Expand Up @@ -32,16 +31,15 @@
))


def format_as_html(explanation, # type: Explanation
include_styles=True, # type: bool
force_weights=True, # type: bool
def format_as_html(explanation: Explanation,
include_styles=True,
force_weights=True,
show=fields.ALL,
preserve_density=None, # type: Optional[bool]
highlight_spaces=None, # type: Optional[bool]
horizontal_layout=True, # type: bool
show_feature_values=False # type: bool
):
# type: (...) -> str
preserve_density: Optional[bool] = None,
highlight_spaces: Optional[bool] = None,
horizontal_layout=True,
show_feature_values=False,
) -> str:
""" Format explanation as html.
Most styles are inline, but some are included separately in <style> tag,
you can omit them by passing ``include_styles=False`` and call
Expand Down Expand Up @@ -130,42 +128,37 @@ def format_as_html(explanation, # type: Explanation
'''.replace('\n', ' ')


def format_html_styles():
# type: () -> str
def format_html_styles() -> str:
""" Format just the styles,
use with ``format_as_html(explanation, include_styles=False)``.
"""
return template_env.get_template('styles.html').render()


def render_targets_weighted_spans(
targets, # type: List[TargetExplanation]
preserve_density, # type: Optional[bool]
):
# type: (...) -> List[Optional[str]]
targets: list[TargetExplanation],
preserve_density: Optional[bool],
) -> list[Optional[str]]:
""" Return a list of rendered weighted spans for targets.
Function must accept a list in order to select consistent weight
ranges across all targets.
"""
prepared_weighted_spans = prepare_weighted_spans(
targets, preserve_density)

def _fmt_pws(pws):
# type: (PreparedWeightedSpans) -> str
def _fmt_pws(pws: PreparedWeightedSpans) -> str:
name = ('<b>{}:</b> '.format(pws.doc_weighted_spans.vec_name)
if pws.doc_weighted_spans.vec_name else '')
return '{}{}'.format(name, render_weighted_spans(pws))

def _fmt_pws_list(pws_lst):
# type: (List[PreparedWeightedSpans]) -> str
def _fmt_pws_list(pws_lst: list[PreparedWeightedSpans]) -> str:
return '<br/>'.join(_fmt_pws(pws) for pws in pws_lst)

return [_fmt_pws_list(pws_lst) if pws_lst else None
for pws_lst in prepared_weighted_spans]


def render_weighted_spans(pws):
# type: (PreparedWeightedSpans) -> str
def render_weighted_spans(pws: PreparedWeightedSpans) -> str:
# TODO - for longer documents, an option to remove text
# without active features
return ''.join(
Expand All @@ -177,11 +170,10 @@ def render_weighted_spans(pws):
key=lambda x: x[1]))


def _colorize(token, # type: str
weight, # type: float
weight_range, # type: float
):
# type: (...) -> str
def _colorize(token: str,
weight: float,
weight_range: float,
) -> str:
""" Return token wrapped in a span with some styles
(calculated from weight and weight_range) applied.
"""
Expand All @@ -208,8 +200,7 @@ def _colorize(token, # type: str
)


def _weight_opacity(weight, weight_range):
# type: (float, float) -> str
def _weight_opacity(weight: float, weight_range: float) -> str:
""" Return opacity value for given weight as a string.
"""
min_opacity = 0.8
Expand All @@ -220,11 +211,10 @@ def _weight_opacity(weight, weight_range):
return '{:.2f}'.format(min_opacity + (1 - min_opacity) * rel_weight)


_HSL_COLOR = Tuple[float, float, float]
_HSL_COLOR = tuple[float, float, float]


def weight_color_hsl(weight, weight_range, min_lightness=0.8):
# type: (float, float, float) -> _HSL_COLOR
def weight_color_hsl(weight: float, weight_range: float, min_lightness=0.8) -> _HSL_COLOR:
""" Return HSL color components for given weight,
where the max absolute weight is given by weight_range.
"""
Expand All @@ -235,21 +225,18 @@ def weight_color_hsl(weight, weight_range, min_lightness=0.8):
return hue, saturation, lightness


def format_hsl(hsl_color):
# type: (_HSL_COLOR) -> str
def format_hsl(hsl_color: _HSL_COLOR) -> str:
""" Format hsl color as css color string.
"""
hue, saturation, lightness = hsl_color
return 'hsl({}, {:.2%}, {:.2%})'.format(hue, saturation, lightness)


def _hue(weight):
# type: (float) -> float
def _hue(weight: float) -> float:
return 120 if weight > 0 else 0


def get_weight_range(weights):
# type: (FeatureWeights) -> float
def get_weight_range(weights: FeatureWeights) -> float:
""" Max absolute feature for pos and neg weights.
"""
return max_or_0(abs(fw.weight)
Expand All @@ -258,11 +245,10 @@ def get_weight_range(weights):


def remaining_weight_color_hsl(
ws, # type: List[FeatureWeight]
weight_range, # type: float
pos_neg, # type: str
):
# type: (...) -> _HSL_COLOR
ws: list[FeatureWeight],
weight_range: float,
pos_neg: str,
) -> _HSL_COLOR:
""" Color for "remaining" row.
Handles a number of edge cases: if there are no weights in ws or weight_range
is zero, assume the worst (most intensive positive or negative color).
Expand All @@ -278,8 +264,7 @@ def remaining_weight_color_hsl(
return weight_color_hsl(weight, weight_range)


def _format_unhashed_feature(feature, weight, hl_spaces):
# type: (...) -> str
def _format_unhashed_feature(feature, weight, hl_spaces) -> str:
""" Format unhashed feature: show first (most probable) candidate,
display other candidates in title attribute.
"""
Expand All @@ -295,8 +280,7 @@ def _format_unhashed_feature(feature, weight, hl_spaces):
return html


def _format_feature(feature, weight, hl_spaces):
# type: (...) -> str
def _format_feature(feature, weight, hl_spaces) -> str:
""" Format any feature.
"""
if isinstance(feature, FormattedFeatureName):
Expand All @@ -308,14 +292,12 @@ def _format_feature(feature, weight, hl_spaces):
return _format_single_feature(feature, weight, hl_spaces=hl_spaces)


def _format_single_feature(feature, weight, hl_spaces):
# type: (str, float, bool) -> str
def _format_single_feature(feature: str, weight: float, hl_spaces: bool) -> str:
feature = html_escape(feature)
if not hl_spaces:
return feature

def replacer(n_spaces, side):
# type: (int, str) -> str
def replacer(n_spaces: int, side: str) -> str:
m = '0.1em'
margins = {'left': (m, 0), 'right': (0, m), 'center': (m, m)}[side]
style = '; '.join([
Expand All @@ -331,18 +313,12 @@ def replacer(n_spaces, side):
return replace_spaces(feature, replacer)


def _format_decision_tree(treedict):
# type: (...) -> str
def _format_decision_tree(treedict) -> str:
if treedict.graphviz and _graphviz.is_supported():
return _graphviz.dot2svg(treedict.graphviz)
else:
return tree2text(treedict)


def html_escape(text):
# type: (str) -> str
try:
from html import escape
except ImportError:
from cgi import escape # type: ignore
def html_escape(text) -> str:
return escape(text, quote=True)
31 changes: 15 additions & 16 deletions eli5/formatters/text_helpers.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from collections import Counter
from typing import List, Optional
from typing import Optional

import numpy as np

from eli5.base import TargetExplanation, WeightedSpans, DocWeightedSpans
from eli5.base import TargetExplanation, DocWeightedSpans
from eli5.base_utils import attrs
from eli5.utils import max_or_0


def get_char_weights(doc_weighted_spans, preserve_density=None):
# type: (DocWeightedSpans, Optional[bool]) -> np.ndarray
def get_char_weights(
doc_weighted_spans: DocWeightedSpans, preserve_density: Optional[bool] = None,
) -> np.ndarray:
""" Return character weights for a text document with highlighted features.
If preserve_density is True, then color for longer fragments will be
less intensive than for shorter fragments, so that "sum" of intensities
Expand All @@ -35,11 +36,10 @@ def get_char_weights(doc_weighted_spans, preserve_density=None):
@attrs
class PreparedWeightedSpans(object):
def __init__(self,
doc_weighted_spans, # type: DocWeightedSpans
char_weights, # type: np.ndarray
weight_range, # type: float
doc_weighted_spans: DocWeightedSpans,
char_weights: np.ndarray,
weight_range: float,
):
# type: (...) -> None
self.doc_weighted_spans = doc_weighted_spans
self.char_weights = char_weights
self.weight_range = weight_range
Expand All @@ -55,25 +55,24 @@ def __eq__(self, other):
return False


def prepare_weighted_spans(targets, # type: List[TargetExplanation]
preserve_density=None, # type: Optional[bool]
):
# type: (...) -> List[Optional[List[PreparedWeightedSpans]]]
def prepare_weighted_spans(targets: list[TargetExplanation],
preserve_density: Optional[bool] = None,
) -> list[Optional[list[PreparedWeightedSpans]]]:
""" Return weighted spans prepared for rendering.
Calculate a separate weight range for each different weighted
span (for each different index): each target has the same number
of weighted spans.
"""
targets_char_weights = [
targets_char_weights: list[Optional[list[np.ndarray]]] = [
[get_char_weights(ws, preserve_density=preserve_density)
for ws in t.weighted_spans.docs_weighted_spans]
if t.weighted_spans else None
for t in targets] # type: List[Optional[List[np.ndarray]]]
for t in targets]
max_idx = max_or_0(len(ch_w or []) for ch_w in targets_char_weights)

targets_char_weights_not_None = [
targets_char_weights_not_None: list[list[np.ndarray]] = [
cw for cw in targets_char_weights
if cw is not None] # type: List[List[np.ndarray]]
if cw is not None]

spans_weight_ranges = [
max_or_0(
Expand Down
Loading