From 89208216679aad2ea6a55fbab826fca6761e4c69 Mon Sep 17 00:00:00 2001 From: jrycw Date: Fri, 31 May 2024 19:13:34 +0800 Subject: [PATCH 1/2] Move `pairwise`, `seq_groups`, and `is_equal` functions to `_utils.py`. --- great_tables/_spanners.py | 53 ++---------------------------- great_tables/_utils.py | 53 +++++++++++++++++++++++++++++- great_tables/_utils_render_html.py | 6 ++-- tests/test_spanners.py | 37 --------------------- tests/test_utils.py | 36 ++++++++++++++++++++ 5 files changed, 93 insertions(+), 92 deletions(-) diff --git a/great_tables/_spanners.py b/great_tables/_spanners.py index 67fc433bb..6d0651b6f 100644 --- a/great_tables/_spanners.py +++ b/great_tables/_spanners.py @@ -1,8 +1,8 @@ from __future__ import annotations import itertools -from typing import TYPE_CHECKING, Any, Iterable -from collections.abc import Generator +from typing import TYPE_CHECKING + from ._gt_data import SpannerInfo, Spanners from ._locations import resolve_cols_c from ._tbl_data import SelectExpr @@ -560,55 +560,6 @@ def empty_spanner_matrix( return [{var: var for var in vars}], vars -def pairwise(iterable: Iterable[Any]) -> Generator[tuple[Any, Any], None, None]: - """ - https://docs.python.org/3/library/itertools.html#itertools.pairwise - pairwise('ABCDEFG') → AB BC CD DE EF FG - """ - # This function can be replaced by `itertools.pairwise` if we only plan to support - # Python 3.10+ in the future. - iterator = iter(iterable) - a = next(iterator, None) - for b in iterator: - yield a, b - a = b - - -def seq_groups(seq: Iterable[str]) -> Generator[tuple[str, int], None, None]: - iterator = iter(seq) - - # TODO: 0-length sequence - a = next(iterator) # will raise StopIteration if `seq` is empty - - try: - b = next(iterator) - except StopIteration: - yield a, 1 - return - - # We can confirm that we have two elements and both are not `None`, - # so we can chain them back together as the original seq. - seq = itertools.chain([a, b], iterator) - - crnt_ttl = 1 - for crnt_el, next_el in pairwise(seq): - if is_equal(crnt_el, next_el): - crnt_ttl += 1 - else: - yield crnt_el, crnt_ttl - crnt_ttl = 1 - - # final step has same elements, so we need to yield one last time - if is_equal(crnt_el, next_el): - yield crnt_el, crnt_ttl - else: - yield next_el, 1 - - -def is_equal(x: Any, y: Any) -> bool: - return x is not None and x == y - - def cols_width(data: GTSelf, cases: dict[str, str]) -> GTSelf: """Set the widths of columns. diff --git a/great_tables/_utils.py b/great_tables/_utils.py index 008302170..ad364445c 100644 --- a/great_tables/_utils.py +++ b/great_tables/_utils.py @@ -1,10 +1,12 @@ from __future__ import annotations import importlib +import itertools import json import re +from collections.abc import Generator from types import ModuleType -from typing import Any +from typing import Any, Iterable from ._tbl_data import PdDataFrame @@ -160,3 +162,52 @@ def _str_replace(string: str, pattern: str, replace: str) -> str: def _str_detect(string: str, pattern: str) -> bool: return bool(re.match(pattern, string)) + + +def pairwise(iterable: Iterable[Any]) -> Generator[tuple[Any, Any], None, None]: + """ + https://docs.python.org/3/library/itertools.html#itertools.pairwise + pairwise('ABCDEFG') → AB BC CD DE EF FG + """ + # This function can be replaced by `itertools.pairwise` if we only plan to support + # Python 3.10+ in the future. + iterator = iter(iterable) + a = next(iterator, None) + for b in iterator: + yield a, b + a = b + + +def seq_groups(seq: Iterable[str]) -> Generator[tuple[str, int], None, None]: + iterator = iter(seq) + + # TODO: 0-length sequence + a = next(iterator) # will raise StopIteration if `seq` is empty + + try: + b = next(iterator) + except StopIteration: + yield a, 1 + return + + # We can confirm that we have two elements and both are not `None`, + # so we can chain them back together as the original seq. + seq = itertools.chain([a, b], iterator) + + crnt_ttl = 1 + for crnt_el, next_el in pairwise(seq): + if is_equal(crnt_el, next_el): + crnt_ttl += 1 + else: + yield crnt_el, crnt_ttl + crnt_ttl = 1 + + # final step has same elements, so we need to yield one last time + if is_equal(crnt_el, next_el): + yield crnt_el, crnt_ttl + else: + yield next_el, 1 + + +def is_equal(x: Any, y: Any) -> bool: + return x is not None and x == y diff --git a/great_tables/_utils_render_html.py b/great_tables/_utils_render_html.py index 06fb54220..856d1658a 100644 --- a/great_tables/_utils_render_html.py +++ b/great_tables/_utils_render_html.py @@ -1,15 +1,15 @@ from __future__ import annotations -from itertools import chain, groupby +from itertools import chain from typing import Any, cast -from great_tables._spanners import seq_groups, spanners_print_matrix +from great_tables._spanners import spanners_print_matrix from htmltools import HTML, TagList, css, tags from ._gt_data import GTData from ._tbl_data import _get_cell, cast_frame_to_string, n_rows, replace_null_frame from ._text import StringBuilder, _process_text, _process_text_id -from ._utils import heading_has_subtitle, heading_has_title +from ._utils import heading_has_subtitle, heading_has_title, seq_groups def create_heading_component_h(data: GTData) -> StringBuilder: diff --git a/tests/test_spanners.py b/tests/test_spanners.py index b05784bf9..c33720278 100644 --- a/tests/test_spanners.py +++ b/tests/test_spanners.py @@ -1,5 +1,3 @@ -from collections.abc import Generator - import pandas as pd import polars as pl import polars.selectors as cs @@ -14,45 +12,10 @@ cols_move_to_start, empty_spanner_matrix, spanners_print_matrix, - seq_groups, tab_spanner, ) -@pytest.mark.parametrize( - "seq, grouped", - [ - ("a", [("a", 1)]), - ("abc", [("a", 1), ("b", 1), ("c", 1)]), - ("aabbcc", [("a", 2), ("b", 2), ("c", 2)]), - ("aabbccd", [("a", 2), ("b", 2), ("c", 2), ("d", 1)]), - (("a", "b", "c"), [("a", 1), ("b", 1), ("c", 1)]), - (("aa", "bb", "cc"), [("aa", 1), ("bb", 1), ("cc", 1)]), - (iter("xyyzzz"), [("x", 1), ("y", 2), ("z", 3)]), - ((i for i in "333221"), [("3", 3), ("2", 2), ("1", 1)]), - (["a", "a", "b", None, "c"], [("a", 2), ("b", 1), (None, 1), ("c", 1)]), - (["a", "a", "b", None, None, "c"], [("a", 2), ("b", 1), (None, 1), (None, 1), ("c", 1)]), - ([None, "a", "a", "b"], [(None, 1), ("a", 2), ("b", 1)]), - ([None, None, "a", "a", "b"], [(None, 1), (None, 1), ("a", 2), ("b", 1)]), - ([None, None, None, "a", "a", "b"], [(None, 1), (None, 1), (None, 1), ("a", 2), ("b", 1)]), - ([None, None, None], [(None, 1), (None, 1), (None, 1)]), - ], -) -def test_seq_groups(seq, grouped): - g = seq_groups(seq) - assert isinstance(g, Generator) - assert list(g) == grouped - - -def test_seq_groups_raises(): - """ - https://stackoverflow.com/questions/66566960/pytest-raises-does-not-catch-stopiteration-error - """ - with pytest.raises(RuntimeError) as exc_info: - next(seq_groups([])) - assert "StopIteration" in str(exc_info.value) - - @pytest.fixture def spanners() -> Spanners: return Spanners( diff --git a/tests/test_utils.py b/tests/test_utils.py index 0b6edd8d4..0c3ea4111 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,3 +1,4 @@ +from collections.abc import Generator import pytest from great_tables._utils import ( _assert_list_is_subset, @@ -11,6 +12,7 @@ _unique_set, heading_has_subtitle, heading_has_title, + seq_groups, ) @@ -125,3 +127,37 @@ def test_collapse_list_elements(): def test_insert_into_list(): lst = ["b", "c"] assert _insert_into_list(lst, "a") == ["a", "b", "c"] + + +@pytest.mark.parametrize( + "seq, grouped", + [ + ("a", [("a", 1)]), + ("abc", [("a", 1), ("b", 1), ("c", 1)]), + ("aabbcc", [("a", 2), ("b", 2), ("c", 2)]), + ("aabbccd", [("a", 2), ("b", 2), ("c", 2), ("d", 1)]), + (("a", "b", "c"), [("a", 1), ("b", 1), ("c", 1)]), + (("aa", "bb", "cc"), [("aa", 1), ("bb", 1), ("cc", 1)]), + (iter("xyyzzz"), [("x", 1), ("y", 2), ("z", 3)]), + ((i for i in "333221"), [("3", 3), ("2", 2), ("1", 1)]), + (["a", "a", "b", None, "c"], [("a", 2), ("b", 1), (None, 1), ("c", 1)]), + (["a", "a", "b", None, None, "c"], [("a", 2), ("b", 1), (None, 1), (None, 1), ("c", 1)]), + ([None, "a", "a", "b"], [(None, 1), ("a", 2), ("b", 1)]), + ([None, None, "a", "a", "b"], [(None, 1), (None, 1), ("a", 2), ("b", 1)]), + ([None, None, None, "a", "a", "b"], [(None, 1), (None, 1), (None, 1), ("a", 2), ("b", 1)]), + ([None, None, None], [(None, 1), (None, 1), (None, 1)]), + ], +) +def test_seq_groups(seq, grouped): + g = seq_groups(seq) + assert isinstance(g, Generator) + assert list(g) == grouped + + +def test_seq_groups_raises(): + """ + https://stackoverflow.com/questions/66566960/pytest-raises-does-not-catch-stopiteration-error + """ + with pytest.raises(RuntimeError) as exc_info: + next(seq_groups([])) + assert "StopIteration" in str(exc_info.value) From fd19edaa895625b9cfe134c852173187b0fc1f15 Mon Sep 17 00:00:00 2001 From: jrycw Date: Fri, 31 May 2024 19:23:31 +0800 Subject: [PATCH 2/2] Update `GradientPalette._create_coefficients` to utilize `pairwise` --- great_tables/_data_color/palettes.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/great_tables/_data_color/palettes.py b/great_tables/_data_color/palettes.py index c49286f7a..4dcf2e31f 100644 --- a/great_tables/_data_color/palettes.py +++ b/great_tables/_data_color/palettes.py @@ -9,6 +9,8 @@ from math import isinf, isnan from typing import TypedDict +from great_tables._utils import pairwise + from .base import RGBColor, _hex_to_rgb, _html_color @@ -155,11 +157,10 @@ def _linspace_to_one(n_steps: int) -> list[float]: def _create_coefficients(self, cutoffs: list[float], channel: list[int]) -> CoeffSequence: """Return coefficients for interpolating between cutoffs on a color channel.""" - p_cutoffs = list(zip(cutoffs[:-1], cutoffs[1:])) - p_colors = list(zip(channel[:-1], channel[1:])) - coeffs: list[GradientCoefficients] = [] - for (prev_cutoff, crnt_cutoff), (prev_color, crnt_color) in zip(p_cutoffs, p_colors): + for (prev_cutoff, crnt_cutoff), (prev_color, crnt_color) in zip( + pairwise(cutoffs), pairwise(channel) + ): cutoff_diff = crnt_cutoff - prev_cutoff color_scalar = (crnt_color - prev_color) / cutoff_diff