Skip to content

Commit

Permalink
Merge pull request #369 from jrycw/organize-imports
Browse files Browse the repository at this point in the history
Move `pairwise`, `seq_groups`, and `is_equal` functions to `_utils.py`
  • Loading branch information
rich-iannone authored Jun 4, 2024
2 parents 02a04ee + fd19eda commit 5a67bee
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 96 deletions.
9 changes: 5 additions & 4 deletions great_tables/_data_color/palettes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from math import isinf, isnan
from typing import TypedDict

from great_tables._utils import pairwise

from .base import RGBColor, _hex_to_rgb, _html_color


Expand Down Expand Up @@ -155,11 +157,10 @@ def _linspace_to_one(n_steps: int) -> list[float]:
def _create_coefficients(self, cutoffs: list[float], channel: list[int]) -> CoeffSequence:
"""Return coefficients for interpolating between cutoffs on a color channel."""

p_cutoffs = list(zip(cutoffs[:-1], cutoffs[1:]))
p_colors = list(zip(channel[:-1], channel[1:]))

coeffs: list[GradientCoefficients] = []
for (prev_cutoff, crnt_cutoff), (prev_color, crnt_color) in zip(p_cutoffs, p_colors):
for (prev_cutoff, crnt_cutoff), (prev_color, crnt_color) in zip(
pairwise(cutoffs), pairwise(channel)
):
cutoff_diff = crnt_cutoff - prev_cutoff
color_scalar = (crnt_color - prev_color) / cutoff_diff

Expand Down
53 changes: 2 additions & 51 deletions great_tables/_spanners.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from __future__ import annotations

import itertools
from typing import TYPE_CHECKING, Any, Iterable
from collections.abc import Generator
from typing import TYPE_CHECKING

from ._gt_data import SpannerInfo, Spanners
from ._locations import resolve_cols_c
from ._tbl_data import SelectExpr
Expand Down Expand Up @@ -560,55 +560,6 @@ def empty_spanner_matrix(
return [{var: var for var in vars}], vars


def pairwise(iterable: Iterable[Any]) -> Generator[tuple[Any, Any], None, None]:
    """
    Yield every pair of neighbouring items from `iterable`, in order.

    https://docs.python.org/3/library/itertools.html#itertools.pairwise
    pairwise('ABCDEFG') → AB BC CD DE EF FG

    An input with fewer than two items yields nothing.
    """
    # Local stand-in for `itertools.pairwise`, which is only available on
    # Python 3.10+; it can be dropped once older versions are unsupported.
    items = iter(iterable)
    try:
        left = next(items)
    except StopIteration:
        # Nothing to pair up.
        return
    for right in items:
        yield left, right
        left = right


def seq_groups(seq: Iterable[str]) -> Generator[tuple[str, int], None, None]:
    """
    Collapse runs of consecutive equal items into `(item, run_length)` pairs.

    seq_groups('aabbccd') → ('a', 2) ('b', 2) ('c', 2) ('d', 1)

    `None` never matches its neighbour, so every `None` is reported as its own
    run of length 1.

    Parameters
    ----------
    seq
        Any iterable, including one-shot iterators and generators. It is
        consumed lazily, one item at a time.

    Yields
    ------
    tuple[str, int]
        An item from each run together with the number of items in that run,
        in the order the runs appear.

    Raises
    ------
    RuntimeError
        If `seq` is empty. The `StopIteration` raised while fetching the first
        item escapes the generator body, which Python reports as a
        `RuntimeError`.
    """
    iterator = iter(seq)

    # TODO: 0-length sequence
    # Deliberately unguarded: an empty `seq` raises StopIteration here, which
    # callers observe as a RuntimeError.
    prev_el = next(iterator)

    crnt_ttl = 1
    for crnt_el in iterator:
        # Same rule as `is_equal()`: `None` is never equal to anything.
        if prev_el is not None and prev_el == crnt_el:
            crnt_ttl += 1
        else:
            yield prev_el, crnt_ttl
            crnt_ttl = 1
        prev_el = crnt_el

    # The last run is still open once the input is exhausted.
    yield prev_el, crnt_ttl


def is_equal(x: Any, y: Any) -> bool:
    """Return whether `x` equals `y`, treating a `None` `x` as equal to nothing."""
    if x is None:
        # `None` never matches, not even another `None`.
        return False
    return x == y


def cols_width(data: GTSelf, cases: dict[str, str]) -> GTSelf:
"""Set the widths of columns.
Expand Down
53 changes: 52 additions & 1 deletion great_tables/_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

import importlib
import itertools
import json
import re
from collections.abc import Generator
from types import ModuleType
from typing import Any
from typing import Any, Iterable

from ._tbl_data import PdDataFrame

Expand Down Expand Up @@ -160,3 +162,52 @@ def _str_replace(string: str, pattern: str, replace: str) -> str:

def _str_detect(string: str, pattern: str) -> bool:
return bool(re.match(pattern, string))


def pairwise(iterable: Iterable[Any]) -> Generator[tuple[Any, Any], None, None]:
    """
    Yield every pair of neighbouring items from `iterable`, in order.

    https://docs.python.org/3/library/itertools.html#itertools.pairwise
    pairwise('ABCDEFG') → AB BC CD DE EF FG

    An input with fewer than two items yields nothing.
    """
    # Local stand-in for `itertools.pairwise`, which is only available on
    # Python 3.10+; it can be dropped once older versions are unsupported.
    items = iter(iterable)
    try:
        left = next(items)
    except StopIteration:
        # Nothing to pair up.
        return
    for right in items:
        yield left, right
        left = right


def seq_groups(seq: Iterable[str]) -> Generator[tuple[str, int], None, None]:
    """
    Collapse runs of consecutive equal items into `(item, run_length)` pairs.

    seq_groups('aabbccd') → ('a', 2) ('b', 2) ('c', 2) ('d', 1)

    `None` never matches its neighbour, so every `None` is reported as its own
    run of length 1.

    Parameters
    ----------
    seq
        Any iterable, including one-shot iterators and generators. It is
        consumed lazily, one item at a time.

    Yields
    ------
    tuple[str, int]
        An item from each run together with the number of items in that run,
        in the order the runs appear.

    Raises
    ------
    RuntimeError
        If `seq` is empty. The `StopIteration` raised while fetching the first
        item escapes the generator body, which Python reports as a
        `RuntimeError`.
    """
    iterator = iter(seq)

    # TODO: 0-length sequence
    # Deliberately unguarded: an empty `seq` raises StopIteration here, which
    # callers observe as a RuntimeError.
    prev_el = next(iterator)

    crnt_ttl = 1
    for crnt_el in iterator:
        # Same rule as `is_equal()`: `None` is never equal to anything.
        if prev_el is not None and prev_el == crnt_el:
            crnt_ttl += 1
        else:
            yield prev_el, crnt_ttl
            crnt_ttl = 1
        prev_el = crnt_el

    # The last run is still open once the input is exhausted.
    yield prev_el, crnt_ttl


def is_equal(x: Any, y: Any) -> bool:
    """Return whether `x` equals `y`, treating a `None` `x` as equal to nothing."""
    if x is None:
        # `None` never matches, not even another `None`.
        return False
    return x == y
6 changes: 3 additions & 3 deletions great_tables/_utils_render_html.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from __future__ import annotations

from itertools import chain, groupby
from itertools import chain
from typing import Any, cast

from great_tables._spanners import seq_groups, spanners_print_matrix
from great_tables._spanners import spanners_print_matrix
from htmltools import HTML, TagList, css, tags

from ._gt_data import GTData
from ._tbl_data import _get_cell, cast_frame_to_string, n_rows, replace_null_frame
from ._text import StringBuilder, _process_text, _process_text_id
from ._utils import heading_has_subtitle, heading_has_title
from ._utils import heading_has_subtitle, heading_has_title, seq_groups


def create_heading_component_h(data: GTData) -> StringBuilder:
Expand Down
37 changes: 0 additions & 37 deletions tests/test_spanners.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from collections.abc import Generator

import pandas as pd
import polars as pl
import polars.selectors as cs
Expand All @@ -14,45 +12,10 @@
cols_move_to_start,
empty_spanner_matrix,
spanners_print_matrix,
seq_groups,
tab_spanner,
)


@pytest.mark.parametrize(
    "seq, grouped",
    [
        # plain strings
        ("a", [("a", 1)]),
        ("abc", [("a", 1), ("b", 1), ("c", 1)]),
        ("aabbcc", [("a", 2), ("b", 2), ("c", 2)]),
        ("aabbccd", [("a", 2), ("b", 2), ("c", 2), ("d", 1)]),
        # tuples of strings
        (("a", "b", "c"), [("a", 1), ("b", 1), ("c", 1)]),
        (("aa", "bb", "cc"), [("aa", 1), ("bb", 1), ("cc", 1)]),
        # one-shot iterators and generators
        (iter("xyyzzz"), [("x", 1), ("y", 2), ("z", 3)]),
        ((i for i in "333221"), [("3", 3), ("2", 2), ("1", 1)]),
        # `None` entries are never merged with their neighbours
        (["a", "a", "b", None, "c"], [("a", 2), ("b", 1), (None, 1), ("c", 1)]),
        (
            ["a", "a", "b", None, None, "c"],
            [("a", 2), ("b", 1), (None, 1), (None, 1), ("c", 1)],
        ),
        ([None, "a", "a", "b"], [(None, 1), ("a", 2), ("b", 1)]),
        ([None, None, "a", "a", "b"], [(None, 1), (None, 1), ("a", 2), ("b", 1)]),
        (
            [None, None, None, "a", "a", "b"],
            [(None, 1), (None, 1), (None, 1), ("a", 2), ("b", 1)],
        ),
        ([None, None, None], [(None, 1), (None, 1), (None, 1)]),
    ],
)
def test_seq_groups(seq, grouped):
    """`seq_groups` returns a generator of `(item, run_length)` pairs."""
    result = seq_groups(seq)
    # The return value must be a generator rather than a pre-built list.
    assert isinstance(result, Generator)
    assert [*result] == grouped


def test_seq_groups_raises():
    """
    An empty input surfaces as a RuntimeError mentioning StopIteration.

    https://stackoverflow.com/questions/66566960/pytest-raises-does-not-catch-stopiteration-error
    """
    # `match` searches the string form of the raised exception.
    with pytest.raises(RuntimeError, match="StopIteration"):
        next(seq_groups([]))


@pytest.fixture
def spanners() -> Spanners:
return Spanners(
Expand Down
36 changes: 36 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections.abc import Generator
import pytest
from great_tables._utils import (
_assert_list_is_subset,
Expand All @@ -11,6 +12,7 @@
_unique_set,
heading_has_subtitle,
heading_has_title,
seq_groups,
)


Expand Down Expand Up @@ -125,3 +127,37 @@ def test_collapse_list_elements():
def test_insert_into_list():
lst = ["b", "c"]
assert _insert_into_list(lst, "a") == ["a", "b", "c"]


@pytest.mark.parametrize(
    "seq, grouped",
    [
        # plain strings
        ("a", [("a", 1)]),
        ("abc", [("a", 1), ("b", 1), ("c", 1)]),
        ("aabbcc", [("a", 2), ("b", 2), ("c", 2)]),
        ("aabbccd", [("a", 2), ("b", 2), ("c", 2), ("d", 1)]),
        # tuples of strings
        (("a", "b", "c"), [("a", 1), ("b", 1), ("c", 1)]),
        (("aa", "bb", "cc"), [("aa", 1), ("bb", 1), ("cc", 1)]),
        # one-shot iterators and generators
        (iter("xyyzzz"), [("x", 1), ("y", 2), ("z", 3)]),
        ((i for i in "333221"), [("3", 3), ("2", 2), ("1", 1)]),
        # `None` entries are never merged with their neighbours
        (["a", "a", "b", None, "c"], [("a", 2), ("b", 1), (None, 1), ("c", 1)]),
        (
            ["a", "a", "b", None, None, "c"],
            [("a", 2), ("b", 1), (None, 1), (None, 1), ("c", 1)],
        ),
        ([None, "a", "a", "b"], [(None, 1), ("a", 2), ("b", 1)]),
        ([None, None, "a", "a", "b"], [(None, 1), (None, 1), ("a", 2), ("b", 1)]),
        (
            [None, None, None, "a", "a", "b"],
            [(None, 1), (None, 1), (None, 1), ("a", 2), ("b", 1)],
        ),
        ([None, None, None], [(None, 1), (None, 1), (None, 1)]),
    ],
)
def test_seq_groups(seq, grouped):
    """`seq_groups` returns a generator of `(item, run_length)` pairs."""
    result = seq_groups(seq)
    # The return value must be a generator rather than a pre-built list.
    assert isinstance(result, Generator)
    assert [*result] == grouped


def test_seq_groups_raises():
    """
    An empty input surfaces as a RuntimeError mentioning StopIteration.

    https://stackoverflow.com/questions/66566960/pytest-raises-does-not-catch-stopiteration-error
    """
    # `match` searches the string form of the raised exception.
    with pytest.raises(RuntimeError, match="StopIteration"):
        next(seq_groups([]))

0 comments on commit 5a67bee

Please sign in to comment.