Skip to content

Commit

Permalink
Merge pull request #364 from jrycw/data-color-row-sel
Browse files Browse the repository at this point in the history
Support specifying a subset of rows in `GT.data_color()`
  • Loading branch information
machow authored Jun 6, 2024
2 parents 82f66ad + 610542e commit b6aa4f8
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 8 deletions.
24 changes: 16 additions & 8 deletions great_tables/_data_color/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from typing import TYPE_CHECKING

import numpy as np
from great_tables._tbl_data import DataFrameLike, is_na
from great_tables._locations import resolve_cols_c, resolve_rows_i, RowSelectExpr
from great_tables._tbl_data import DataFrameLike, is_na, SelectExpr
from great_tables.loc import body
from great_tables.style import fill, text
from typing_extensions import TypeAlias
Expand All @@ -19,7 +20,8 @@

def data_color(
self: GTSelf,
columns: str | list[str] | None = None,
columns: SelectExpr = None,
rows: RowSelectExpr = None,
palette: str | list[str] | None = None,
domain: list[str] | list[int] | list[float] | None = None,
na_color: str | None = None,
Expand Down Expand Up @@ -47,6 +49,10 @@ def data_color(
columns
The columns to target. Can either be a single column name or a series of column names
provided in a list.
rows
In conjunction with `columns=`, we can specify which rows should be colored. By default,
all rows in the targeted columns will be colored. Alternatively, we can provide a list
of row indices.
palette
The color palette to use. This should be a list of colors (e.g., `["#FF0000", "#00FF00",
"#0000FF"]`). A ColorBrewer palette could also be used, just supply the name (reference
Expand Down Expand Up @@ -202,18 +208,20 @@ def data_color(
# get a list of all columns in the table body
columns_resolved: list[str]

if isinstance(columns, str):
columns_resolved = [columns]
elif columns is None:
if columns is None:
columns_resolved = data_table.columns
else:
columns_resolved = columns
columns_resolved = resolve_cols_c(data=self, expr=columns)

row_res = resolve_rows_i(self, rows)
row_pos = [name_pos[1] for name_pos in row_res]

gt_obj = self

# For each column targeted, get the data values as a new list object
for col in columns_resolved:
column_vals = data_table[col].to_list()
# This line handles both pandas and polars dataframes
column_vals = data_table[col][row_pos].to_list()

# Filter out NA values from `column_vals`
filtered_column_vals = [x for x in column_vals if not is_na(data_table, x)]
Expand Down Expand Up @@ -260,7 +268,7 @@ def data_color(

# for every color value in color_vals, apply a fill to the corresponding cell
# by using `tab_style()`
for i, color_val in enumerate(color_vals):
for i, color_val in zip(row_pos, color_vals):
if autocolor_text:
fgnd_color = _ideal_fgnd_color(bgnd_color=color_val)

Expand Down
60 changes: 60 additions & 0 deletions tests/data_color/__snapshots__/test_data_color.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,66 @@
</tbody>
'''
# ---
# name: test_data_color_pd_cols_rows_snap
'''
<tbody class="gt_table_body">
<tr>
<td style="color: #FFFFFF; background-color: #000000;" class="gt_row gt_right">1</td>
<td class="gt_row gt_right">51</td>
</tr>
<tr>
<td style="color: #000000; background-color: #80b156;" class="gt_row gt_right">2</td>
<td class="gt_row gt_right">52</td>
</tr>
<tr>
<td style="color: #000000; background-color: #25bce6;" class="gt_row gt_right">3</td>
<td class="gt_row gt_right">53</td>
</tr>
<tr>
<td style="color: #000000; background-color: #d73a91;" class="gt_row gt_right">4</td>
<td class="gt_row gt_right">54</td>
</tr>
<tr>
<td style="color: #000000; background-color: #9e9e9e;" class="gt_row gt_right">5</td>
<td class="gt_row gt_right">55</td>
</tr>
<tr>
<td class="gt_row gt_right">200</td>
<td class="gt_row gt_right">200</td>
</tr>
</tbody>
'''
# ---
# name: test_data_color_pl_cols_rows_snap
'''
<tbody class="gt_table_body">
<tr>
<td class="gt_row gt_right">1</td>
<td style="color: #FFFFFF; background-color: #000000;" class="gt_row gt_right">51</td>
</tr>
<tr>
<td class="gt_row gt_right">2</td>
<td style="color: #000000; background-color: #80b156;" class="gt_row gt_right">52</td>
</tr>
<tr>
<td class="gt_row gt_right">3</td>
<td style="color: #000000; background-color: #25bce6;" class="gt_row gt_right">53</td>
</tr>
<tr>
<td class="gt_row gt_right">4</td>
<td style="color: #000000; background-color: #d73a91;" class="gt_row gt_right">54</td>
</tr>
<tr>
<td class="gt_row gt_right">5</td>
<td style="color: #000000; background-color: #9e9e9e;" class="gt_row gt_right">55</td>
</tr>
<tr>
<td class="gt_row gt_right">200</td>
<td class="gt_row gt_right">200</td>
</tr>
</tbody>
'''
# ---
# name: test_data_color_simple_df_snap
'''
<tbody class="gt_table_body">
Expand Down
22 changes: 22 additions & 0 deletions tests/data_color/test_data_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,28 @@ def test_data_color_simple_exibble_snap(snapshot: str, df: DataFrameLike):
assert_rendered_body(snapshot, gt)


def test_data_color_pd_cols_rows_snap(snapshot: str):
df = pd.DataFrame({"a": [1, 2, 3, 4, 5, 200], "b": [51, 52, 53, 54, 55, 200]})
new_gt = GT(df).data_color(columns=["a"], rows=[0, 1, 2, 3, 4])
assert_rendered_body(snapshot, new_gt)
new_gt2 = GT(df).data_color(columns=["a"], rows=lambda df_: df_["a"].lt(60))
assert create_body_component_h(new_gt._build_data("html")) == create_body_component_h(
new_gt2._build_data("html")
)


def test_data_color_pl_cols_rows_snap(snapshot: str):
import polars.selectors as cs

df = pl.DataFrame({"a": [1, 2, 3, 4, 5, 200], "b": [51, 52, 53, 54, 55, 200]})
new_gt = GT(df).data_color(columns=["b"], rows=[0, 1, 2, 3, 4])
assert_rendered_body(snapshot, new_gt)
new_gt2 = GT(df).data_color(columns=cs.starts_with("b"), rows=pl.col("b").lt(60))
assert create_body_component_h(new_gt._build_data("html")) == create_body_component_h(
new_gt2._build_data("html")
)


@pytest.mark.parametrize("none_val", [None, np.nan, float("nan"), pd.NA])
@pytest.mark.parametrize("df_cls", [pd.DataFrame, pl.DataFrame])
def test_data_color_missing_value(df_cls, none_val):
Expand Down

0 comments on commit b6aa4f8

Please sign in to comment.