From c24109bfd0da28d2e071df4e502bb5c9303ad1d0 Mon Sep 17 00:00:00 2001 From: jrycw Date: Sun, 26 May 2024 02:33:37 +0800 Subject: [PATCH 1/4] Support specifying a subset of rows in `GT.data_color()` --- great_tables/_data_color/base.py | 56 ++++++++++------- .../__snapshots__/test_data_color.ambr | 60 +++++++++++++++++++ tests/data_color/test_data_color.py | 22 +++++++ 3 files changed, 116 insertions(+), 22 deletions(-) diff --git a/great_tables/_data_color/base.py b/great_tables/_data_color/base.py index 85ca8e7c7..20c85a959 100644 --- a/great_tables/_data_color/base.py +++ b/great_tables/_data_color/base.py @@ -3,7 +3,8 @@ from typing import TYPE_CHECKING import numpy as np -from great_tables._tbl_data import DataFrameLike, is_na +from great_tables._locations import resolve_cols_c, resolve_rows_i, RowSelectExpr +from great_tables._tbl_data import DataFrameLike, is_na, SelectExpr from great_tables.loc import body from great_tables.style import fill, text from typing_extensions import TypeAlias @@ -19,7 +20,8 @@ def data_color( self: GTSelf, - columns: str | list[str] | None = None, + columns: SelectExpr = None, + rows: RowSelectExpr = None, palette: str | list[str] | None = None, domain: list[str] | list[int] | list[float] | None = None, na_color: str | None = None, @@ -47,6 +49,10 @@ def data_color( columns The columns to target. Can either be a single column name or a series of column names provided in a list. + rows + In conjunction with `columns=`, we can specify which rows should be colored. By default, + all rows in the targeted columns will be colored. Alternatively, we can provide a list + of row indices. palette The color palette to use. This should be a list of colors (e.g., `["#FF0000", "#00FF00", "#0000FF"]`). A ColorBrewer palette could also be used, just supply the name (reference @@ -202,18 +208,20 @@ def data_color( # get a list of all columns in the table body columns_resolved: list[str] - if isinstance(columns, str): - columns_resolved = [columns] - elif columns is None: + if columns is None: columns_resolved = data_table.columns else: - columns_resolved = columns + columns_resolved = resolve_cols_c(data=self, expr=columns) + + row_res = resolve_rows_i(self, rows) + row_pos = [name_pos[1] for name_pos in row_res] gt_obj = self # For each column targeted, get the data values as a new list object for col in columns_resolved: - column_vals = data_table[col].to_list() + # This line handles both pandas and polars dataframes + column_vals = data_table[col][row_pos].to_list() # Filter out NA values from `column_vals` filtered_column_vals = [x for x in column_vals if not is_na(data_table, x)] @@ -258,21 +266,25 @@ def data_color( # Replace 'None' and 'np.nan' values in `color_vals` with the `na_color=` color color_vals = [na_color if is_na(data_table, x) else x for x in color_vals] - # for every color value in color_vals, apply a fill to the corresponding cell - # by using `tab_style()` - for i, color_val in enumerate(color_vals): - if autocolor_text: - fgnd_color = _ideal_fgnd_color(bgnd_color=color_val) - - gt_obj = gt_obj.tab_style( - style=[text(color=fgnd_color), fill(color=color_val)], - locations=body(columns=col, rows=[i]), - ) - - else: - gt_obj = gt_obj.tab_style( - style=fill(color=color_val), locations=body(columns=col, rows=[i]) - ) + # For each row, we check if the row index is selected, and then apply a fill to the + # corresponding cell using `tab_style()` + n_rows = len(resolve_rows_i(self)) + iter_color_vals = iter(color_vals) + for i in range(n_rows): + if i in row_pos: + color_val = next(iter_color_vals) + if autocolor_text: + fgnd_color = _ideal_fgnd_color(bgnd_color=color_val) + + gt_obj = gt_obj.tab_style( + style=[text(color=fgnd_color), fill(color=color_val)], + locations=body(columns=col, rows=[i]), + ) + + else: + gt_obj = gt_obj.tab_style( + style=fill(color=color_val), locations=body(columns=col, rows=[i]) + ) return gt_obj diff --git a/tests/data_color/__snapshots__/test_data_color.ambr b/tests/data_color/__snapshots__/test_data_color.ambr index df6e233ab..620ada1a0 100644 --- a/tests/data_color/__snapshots__/test_data_color.ambr +++ b/tests/data_color/__snapshots__/test_data_color.ambr @@ -362,6 +362,66 @@ ''' # --- +# name: test_data_color_pd_cols_rows_snap + ''' + + + 1 + 51 + + + 2 + 52 + + + 3 + 53 + + + 4 + 54 + + + 5 + 55 + + + 200 + 200 + + + ''' +# --- +# name: test_data_color_pl_cols_rows_snap + ''' + + + 1 + 51 + + + 2 + 52 + + + 3 + 53 + + + 4 + 54 + + + 5 + 55 + + + 200 + 200 + + + ''' +# --- # name: test_data_color_simple_df_snap ''' diff --git a/tests/data_color/test_data_color.py b/tests/data_color/test_data_color.py index 186deaa5c..eb01334ce 100644 --- a/tests/data_color/test_data_color.py +++ b/tests/data_color/test_data_color.py @@ -55,6 +55,28 @@ def test_data_color_simple_exibble_snap(snapshot: str, df: DataFrameLike): assert_rendered_body(snapshot, gt) +def test_data_color_pd_cols_rows_snap(snapshot: str): + df = pd.DataFrame({"a": [1, 2, 3, 4, 5, 200], "b": [51, 52, 53, 54, 55, 200]}) + new_gt = GT(df).data_color(columns=["a"], rows=[0, 1, 2, 3, 4]) + assert_rendered_body(snapshot, new_gt) + new_gt2 = GT(df).data_color(columns=["a"], rows=lambda df_: df_["a"].lt(60)) + assert create_body_component_h(new_gt._build_data("html")) == create_body_component_h( + new_gt2._build_data("html") + ) + + +def test_data_color_pl_cols_rows_snap(snapshot: str): + import polars.selectors as cs + + df = pl.DataFrame({"a": [1, 2, 3, 4, 5, 200], "b": [51, 52, 53, 54, 55, 200]}) + new_gt = GT(df).data_color(columns=["b"], rows=[0, 1, 2, 3, 4]) + assert_rendered_body(snapshot, new_gt) + new_gt2 = GT(df).data_color(columns=cs.starts_with("b"), rows=pl.col("b").lt(60)) + assert create_body_component_h(new_gt._build_data("html")) == create_body_component_h( + new_gt2._build_data("html") + ) + + @pytest.mark.parametrize("none_val", [None, np.nan, float("nan"), pd.NA]) @pytest.mark.parametrize("df_cls", [pd.DataFrame, pl.DataFrame]) def test_data_color_missing_value(df_cls, none_val): From 43b0ce5c6e5edc1280bcf13b6be8e37a615fc37f Mon Sep 17 00:00:00 2001 From: jrycw Date: Thu, 30 May 2024 06:45:20 +0800 Subject: [PATCH 2/4] Code review change --- great_tables/_data_color/base.py | 34 ++++++++++++++------------------ 1 file changed, 15 insertions(+), 19 deletions(-) diff --git a/great_tables/_data_color/base.py b/great_tables/_data_color/base.py index 20c85a959..ffc25e45a 100644 --- a/great_tables/_data_color/base.py +++ b/great_tables/_data_color/base.py @@ -266,25 +266,21 @@ def data_color( # Replace 'None' and 'np.nan' values in `color_vals` with the `na_color=` color color_vals = [na_color if is_na(data_table, x) else x for x in color_vals] - # For each row, we check if the row index is selected, and then apply a fill to the - # corresponding cell using `tab_style()` - n_rows = len(resolve_rows_i(self)) - iter_color_vals = iter(color_vals) - for i in range(n_rows): - if i in row_pos: - color_val = next(iter_color_vals) - if autocolor_text: - fgnd_color = _ideal_fgnd_color(bgnd_color=color_val) - - gt_obj = gt_obj.tab_style( - style=[text(color=fgnd_color), fill(color=color_val)], - locations=body(columns=col, rows=[i]), - ) - - else: - gt_obj = gt_obj.tab_style( - style=fill(color=color_val), locations=body(columns=col, rows=[i]) - ) + # for every color value in color_vals, apply a fill to the corresponding cell + # by using `tab_style()` + for i, color_val in zip(row_pos, color_vals): + if autocolor_text: + fgnd_color = _ideal_fgnd_color(bgnd_color=color_val) + + gt_obj = gt_obj.tab_style( + style=[text(color=fgnd_color), fill(color=color_val)], + locations=body(columns=col, rows=[i]), + ) + + else: + gt_obj = gt_obj.tab_style( + style=fill(color=color_val), locations=body(columns=col, rows=[i]) + ) return gt_obj From a683850ff9ab19d96739287db76c9d7e4b49445a Mon Sep 17 00:00:00 2001 From: jrycw Date: Thu, 30 May 2024 16:33:13 +0800 Subject: [PATCH 3/4] Alternative logic for `GT.data_color()` --- great_tables/_data_color/base.py | 18 +++-- .../__snapshots__/test_data_color.ambr | 72 +++++++++++++++++-- tests/data_color/test_data_color.py | 32 +++++---- 3 files changed, 95 insertions(+), 27 deletions(-) diff --git a/great_tables/_data_color/base.py b/great_tables/_data_color/base.py index ffc25e45a..c52895317 100644 --- a/great_tables/_data_color/base.py +++ b/great_tables/_data_color/base.py @@ -21,7 +21,7 @@ def data_color( self: GTSelf, columns: SelectExpr = None, - rows: RowSelectExpr = None, + include_last_row: bool = True, palette: str | list[str] | None = None, domain: list[str] | list[int] | list[float] | None = None, na_color: str | None = None, @@ -49,10 +49,10 @@ def data_color( columns The columns to target. Can either be a single column name or a series of column names provided in a list. - rows - In conjunction with `columns=`, we can specify which rows should be colored. By default, - all rows in the targeted columns will be colored. Alternatively, we can provide a list - of row indices. + include_last_row + Whether to include the last row for color styling. The default is set to `True`, but it + can be useful to set it to `False` if the last row is a summary or subtotal row of the + table. palette The color palette to use. This should be a list of colors (e.g., `["#FF0000", "#00FF00", "#0000FF"]`). A ColorBrewer palette could also be used, just supply the name (reference @@ -213,8 +213,12 @@ def data_color( else: columns_resolved = resolve_cols_c(data=self, expr=columns) - row_res = resolve_rows_i(self, rows) + row_res = resolve_rows_i(self) # get all rows row_pos = [name_pos[1] for name_pos in row_res] + if not include_last_row: + # Excluding the last row for now; + # adjustments may be needed after implementing `summary_rows()`. + row_pos = row_pos[:-1] gt_obj = self @@ -268,7 +272,7 @@ def data_color( # for every color value in color_vals, apply a fill to the corresponding cell # by using `tab_style()` - for i, color_val in zip(row_pos, color_vals): + for i, color_val in enumerate(color_vals): if autocolor_text: fgnd_color = _ideal_fgnd_color(bgnd_color=color_val) diff --git a/tests/data_color/__snapshots__/test_data_color.ambr b/tests/data_color/__snapshots__/test_data_color.ambr index 620ada1a0..026c76d3a 100644 --- a/tests/data_color/__snapshots__/test_data_color.ambr +++ b/tests/data_color/__snapshots__/test_data_color.ambr @@ -362,7 +362,7 @@ ''' # --- -# name: test_data_color_pd_cols_rows_snap +# name: test_data_color_pd_cols_rows_snap[False] ''' @@ -386,13 +386,43 @@ 55 - 200 - 200 + 15 + 265 ''' # --- -# name: test_data_color_pl_cols_rows_snap +# name: test_data_color_pd_cols_rows_snap[True] + ''' + + + 1 + 51 + + + 2 + 52 + + + 3 + 53 + + + 4 + 54 + + + 5 + 55 + + + 15 + 265 + + + ''' +# --- +# name: test_data_color_pl_cols_rows_snap[False] ''' @@ -416,8 +446,38 @@ 55 - 200 - 200 + 15 + 265 + + + ''' +# --- +# name: test_data_color_pl_cols_rows_snap[True] + ''' + + + 1 + 51 + + + 2 + 52 + + + 3 + 53 + + + 4 + 54 + + + 5 + 55 + + + 15 + 265 ''' diff --git a/tests/data_color/test_data_color.py b/tests/data_color/test_data_color.py index eb01334ce..237e8aa40 100644 --- a/tests/data_color/test_data_color.py +++ b/tests/data_color/test_data_color.py @@ -55,26 +55,30 @@ def test_data_color_simple_exibble_snap(snapshot: str, df: DataFrameLike): assert_rendered_body(snapshot, gt) -def test_data_color_pd_cols_rows_snap(snapshot: str): - df = pd.DataFrame({"a": [1, 2, 3, 4, 5, 200], "b": [51, 52, 53, 54, 55, 200]}) - new_gt = GT(df).data_color(columns=["a"], rows=[0, 1, 2, 3, 4]) - assert_rendered_body(snapshot, new_gt) - new_gt2 = GT(df).data_color(columns=["a"], rows=lambda df_: df_["a"].lt(60)) - assert create_body_component_h(new_gt._build_data("html")) == create_body_component_h( - new_gt2._build_data("html") +@pytest.mark.parametrize("include_last_row", [True, False]) +def test_data_color_pd_cols_rows_snap(snapshot: str, include_last_row: bool): + df = pd.DataFrame( + { + "a": [1, 2, 3, 4, 5, sum([1, 2, 3, 4, 5])], + "b": [51, 52, 53, 54, 55, sum([51, 52, 53, 54, 55])], + } ) + new_gt = GT(df).data_color(columns=["a"], include_last_row=include_last_row) + assert_rendered_body(snapshot, new_gt) -def test_data_color_pl_cols_rows_snap(snapshot: str): +@pytest.mark.parametrize("include_last_row", [True, False]) +def test_data_color_pl_cols_rows_snap(snapshot: str, include_last_row: bool): import polars.selectors as cs - df = pl.DataFrame({"a": [1, 2, 3, 4, 5, 200], "b": [51, 52, 53, 54, 55, 200]}) - new_gt = GT(df).data_color(columns=["b"], rows=[0, 1, 2, 3, 4]) - assert_rendered_body(snapshot, new_gt) - new_gt2 = GT(df).data_color(columns=cs.starts_with("b"), rows=pl.col("b").lt(60)) - assert create_body_component_h(new_gt._build_data("html")) == create_body_component_h( - new_gt2._build_data("html") + df = pl.DataFrame( + { + "a": [1, 2, 3, 4, 5, sum([1, 2, 3, 4, 5])], + "b": [51, 52, 53, 54, 55, sum([51, 52, 53, 54, 55])], + } ) + new_gt = GT(df).data_color(columns=cs.by_name("b"), include_last_row=include_last_row) + assert_rendered_body(snapshot, new_gt) @pytest.mark.parametrize("none_val", [None, np.nan, float("nan"), pd.NA]) From 610542e884db1d1b71b352aa1371af64ebf62100 Mon Sep 17 00:00:00 2001 From: jrycw Date: Sat, 1 Jun 2024 00:12:05 +0800 Subject: [PATCH 4/4] Revert "Alternative logic for `GT.data_color()`" This reverts commit a683850ff9ab19d96739287db76c9d7e4b49445a. --- great_tables/_data_color/base.py | 18 ++--- .../__snapshots__/test_data_color.ambr | 72 ++----------------- tests/data_color/test_data_color.py | 32 ++++----- 3 files changed, 27 insertions(+), 95 deletions(-) diff --git a/great_tables/_data_color/base.py b/great_tables/_data_color/base.py index c52895317..ffc25e45a 100644 --- a/great_tables/_data_color/base.py +++ b/great_tables/_data_color/base.py @@ -21,7 +21,7 @@ def data_color( self: GTSelf, columns: SelectExpr = None, - include_last_row: bool = True, + rows: RowSelectExpr = None, palette: str | list[str] | None = None, domain: list[str] | list[int] | list[float] | None = None, na_color: str | None = None, @@ -49,10 +49,10 @@ def data_color( columns The columns to target. Can either be a single column name or a series of column names provided in a list. - include_last_row - Whether to include the last row for color styling. The default is set to `True`, but it - can be useful to set it to `False` if the last row is a summary or subtotal row of the - table. + rows + In conjunction with `columns=`, we can specify which rows should be colored. By default, + all rows in the targeted columns will be colored. Alternatively, we can provide a list + of row indices. palette The color palette to use. This should be a list of colors (e.g., `["#FF0000", "#00FF00", "#0000FF"]`). A ColorBrewer palette could also be used, just supply the name (reference @@ -213,12 +213,8 @@ def data_color( else: columns_resolved = resolve_cols_c(data=self, expr=columns) - row_res = resolve_rows_i(self) # get all rows + row_res = resolve_rows_i(self, rows) row_pos = [name_pos[1] for name_pos in row_res] - if not include_last_row: - # Excluding the last row for now; - # adjustments may be needed after implementing `summary_rows()`. - row_pos = row_pos[:-1] gt_obj = self @@ -272,7 +268,7 @@ def data_color( # for every color value in color_vals, apply a fill to the corresponding cell # by using `tab_style()` - for i, color_val in enumerate(color_vals): + for i, color_val in zip(row_pos, color_vals): if autocolor_text: fgnd_color = _ideal_fgnd_color(bgnd_color=color_val) diff --git a/tests/data_color/__snapshots__/test_data_color.ambr b/tests/data_color/__snapshots__/test_data_color.ambr index 026c76d3a..620ada1a0 100644 --- a/tests/data_color/__snapshots__/test_data_color.ambr +++ b/tests/data_color/__snapshots__/test_data_color.ambr @@ -362,7 +362,7 @@ ''' # --- -# name: test_data_color_pd_cols_rows_snap[False] +# name: test_data_color_pd_cols_rows_snap ''' @@ -386,43 +386,13 @@ 55 - 15 - 265 + 200 + 200 ''' # --- -# name: test_data_color_pd_cols_rows_snap[True] - ''' - - - 1 - 51 - - - 2 - 52 - - - 3 - 53 - - - 4 - 54 - - - 5 - 55 - - - 15 - 265 - - - ''' -# --- -# name: test_data_color_pl_cols_rows_snap[False] +# name: test_data_color_pl_cols_rows_snap ''' @@ -446,38 +416,8 @@ 55 - 15 - 265 - - - ''' -# --- -# name: test_data_color_pl_cols_rows_snap[True] - ''' - - - 1 - 51 - - - 2 - 52 - - - 3 - 53 - - - 4 - 54 - - - 5 - 55 - - - 15 - 265 + 200 + 200 ''' diff --git a/tests/data_color/test_data_color.py b/tests/data_color/test_data_color.py index 237e8aa40..eb01334ce 100644 --- a/tests/data_color/test_data_color.py +++ b/tests/data_color/test_data_color.py @@ -55,30 +55,26 @@ def test_data_color_simple_exibble_snap(snapshot: str, df: DataFrameLike): assert_rendered_body(snapshot, gt) -@pytest.mark.parametrize("include_last_row", [True, False]) -def test_data_color_pd_cols_rows_snap(snapshot: str, include_last_row: bool): - df = pd.DataFrame( - { - "a": [1, 2, 3, 4, 5, sum([1, 2, 3, 4, 5])], - "b": [51, 52, 53, 54, 55, sum([51, 52, 53, 54, 55])], - } - ) - new_gt = GT(df).data_color(columns=["a"], include_last_row=include_last_row) +def test_data_color_pd_cols_rows_snap(snapshot: str): + df = pd.DataFrame({"a": [1, 2, 3, 4, 5, 200], "b": [51, 52, 53, 54, 55, 200]}) + new_gt = GT(df).data_color(columns=["a"], rows=[0, 1, 2, 3, 4]) assert_rendered_body(snapshot, new_gt) + new_gt2 = GT(df).data_color(columns=["a"], rows=lambda df_: df_["a"].lt(60)) + assert create_body_component_h(new_gt._build_data("html")) == create_body_component_h( + new_gt2._build_data("html") + ) -@pytest.mark.parametrize("include_last_row", [True, False]) -def test_data_color_pl_cols_rows_snap(snapshot: str, include_last_row: bool): +def test_data_color_pl_cols_rows_snap(snapshot: str): import polars.selectors as cs - df = pl.DataFrame( - { - "a": [1, 2, 3, 4, 5, sum([1, 2, 3, 4, 5])], - "b": [51, 52, 53, 54, 55, sum([51, 52, 53, 54, 55])], - } - ) - new_gt = GT(df).data_color(columns=cs.by_name("b"), include_last_row=include_last_row) + df = pl.DataFrame({"a": [1, 2, 3, 4, 5, 200], "b": [51, 52, 53, 54, 55, 200]}) + new_gt = GT(df).data_color(columns=["b"], rows=[0, 1, 2, 3, 4]) assert_rendered_body(snapshot, new_gt) + new_gt2 = GT(df).data_color(columns=cs.starts_with("b"), rows=pl.col("b").lt(60)) + assert create_body_component_h(new_gt._build_data("html")) == create_body_component_h( + new_gt2._build_data("html") + ) @pytest.mark.parametrize("none_val", [None, np.nan, float("nan"), pd.NA])