diff --git a/docs/_quarto.yml b/docs/_quarto.yml index 62b591348..53c815bca 100644 --- a/docs/_quarto.yml +++ b/docs/_quarto.yml @@ -91,6 +91,7 @@ quartodoc: contents: - GT.tab_header - GT.tab_spanner + - GT.tab_stub - GT.tab_stubhead - GT.tab_source_note - GT.tab_style @@ -155,6 +156,8 @@ quartodoc: [`tab_stubhead()`](`great_tables.GT.tab_stubhead`), and [`tab_source_note()`](`great_tables.GT.tab_source_note`) methods. contents: + - GT.with_id + - GT.with_locale - md - html - from_column diff --git a/docs/examples/index.qmd b/docs/examples/index.qmd index b45b91ce3..3ad858e81 100644 --- a/docs/examples/index.qmd +++ b/docs/examples/index.qmd @@ -25,11 +25,12 @@ islands_mini = ( ) ( - GT(islands_mini, rowname_col="name") + GT(islands_mini) .tab_header( title="Large Landmasses of the World", subtitle="The top ten largest are presented" ) + .tab_stub(rowname_col="name") .tab_source_note(source_note="Source: The World Almanac and Book of Facts, 1975, page 406.") .tab_source_note( source_note=md("Reference: McNeil, D. R. (1977) *Interactive Data Analysis*. Wiley.") @@ -108,9 +109,10 @@ wide_pops = ( ) ( - GT(wide_pops, rowname_col="country_name", groupname_col="region") + GT(wide_pops) .tab_header(title="Populations of Oceania's Countries in 2000, 2010, and 2020") .tab_spanner(label="Total Population", columns=cs.all()) + .tab_stub(rowname_col="country_name", groupname_col="region") .fmt_integer() ) ``` diff --git a/docs/get-started/basic-stub.qmd b/docs/get-started/basic-stub.qmd index 0d8ec7cd6..ee9a341f1 100644 --- a/docs/get-started/basic-stub.qmd +++ b/docs/get-started/basic-stub.qmd @@ -16,14 +16,15 @@ from great_tables.data import islands islands_mini = islands.head(10) -GT(islands_mini, rowname_col="name") +GT(islands_mini).tab_stub(rowname_col="name") ``` Notice that the landmass names are now placed to the left? That's the **Stub**. Notably, there is a prominent border to the right of it but there's no label above the **Stub**. We can change this and apply what's known as a *stubhead label* through use of the [`tab_stubhead()`](`great_tables.GT.tab_stubhead`) method: ```{python} ( - GT(islands_mini, rowname_col="name") + GT(islands_mini) + .tab_stub(rowname_col="name") .tab_stubhead(label="landmass") ) ``` @@ -37,5 +38,19 @@ Let's incorporate row groups into the display table. This divides rows into grou ```{python} island_groups = islands.head(10).assign(group = ["subregion"] * 2 + ["country"] * 2 + ["continent"] * 6) -GT(island_groups, rowname_col="name", groupname_col="group").tab_stubhead(label="landmass") +( + GT(island_groups) + .tab_stub(rowname_col="name", groupname_col="group") + .tab_stubhead(label="landmass") +) +``` + +## GT convenience arguments + +Rather than using the `GT.tab_stub()` method, the `GT(rowname_col=..., groupname_col=...)` arguments +provide a quick way to specify row names and groups. + + +```{python} +GT(island_groups, rowname_col="name", groupname_col="group") ``` diff --git a/great_tables/_body.py b/great_tables/_body.py index 01be8c86c..5c97494ab 100644 --- a/great_tables/_body.py +++ b/great_tables/_body.py @@ -8,7 +8,7 @@ from ._gt_data import Body, Boxhead, RowGroups, Stub -def body_reassemble(body: Body, row_groups: RowGroups, stub_df: Stub, boxhead: Boxhead) -> Body: +def body_reassemble(body: Body, stub_df: Stub, boxhead: Boxhead) -> Body: # Note that this used to order the body based on groupings, but now that occurs in the # renderer itself. return body.__class__(copy_data(body.body)) diff --git a/great_tables/_gt_data.py b/great_tables/_gt_data.py index 12919a3be..3d199f17c 100644 --- a/great_tables/_gt_data.py +++ b/great_tables/_gt_data.py @@ -5,11 +5,12 @@ from collections.abc import Sequence from dataclasses import dataclass, field, replace from enum import Enum, auto -from typing import Any, Callable, TypeVar, overload +from typing import Any, Callable, Tuple, TypeVar, overload from typing_extensions import Self, TypeAlias # TODO: move this class somewhere else (even gt_data could work) +from ._options import tab_options from ._styles import CellStyle from ._tbl_data import ( DataFrameLike, @@ -29,7 +30,19 @@ # GT Data ---- -__GT = None + + +def _prep_gt(data, rowname_col, groupname_col, auto_align) -> Tuple[Stub, Boxhead, GroupRows]: + # this function is similar to Stub._set_cols, except it differs in two ways. + # * it supports auto-alignment (an expensive operation) + # * it assumes its run on data initialization, whereas _set_cols may be run after + + stub = Stub.from_data(data, rowname_col=rowname_col, groupname_col=groupname_col) + boxhead = Boxhead( + data, auto_align=auto_align, rowname_col=rowname_col, groupname_col=groupname_col + ) + + return stub, boxhead @dataclass(frozen=True) @@ -38,8 +51,6 @@ class GTData: _body: Body _boxhead: Boxhead _stub: Stub - _row_groups: RowGroups - _group_rows: GroupRows _spanners: Spanners _heading: Heading _stubhead: Stubhead @@ -74,13 +85,7 @@ def from_data( locale: str | None = None, ): data = validate_frame(data) - stub = Stub(data, rowname_col=rowname_col, groupname_col=groupname_col) - boxhead = Boxhead( - data, auto_align=auto_align, rowname_col=rowname_col, groupname_col=groupname_col - ) - - row_groups = stub._to_row_groups() - group_rows = GroupRows(data, group_key=groupname_col).reorder(row_groups) + stub, boxhead = _prep_gt(data, rowname_col, groupname_col, auto_align) if id is not None: options = Options(table_id=OptionsInfo(True, "table", "value", id)) @@ -92,8 +97,6 @@ def from_data( _body=Body.from_empty(data), _boxhead=boxhead, # uses get_tbl_data() _stub=stub, # uses get_tbl_data - _row_groups=row_groups, - _group_rows=group_rows, _spanners=Spanners([]), _heading=Heading(), _stubhead=None, @@ -141,7 +144,6 @@ def __eq__(self, other: Any) -> bool: # Body ---- -__Body = None # TODO: it seems like this could just be a DataFrameLike object? @@ -182,7 +184,6 @@ def from_empty(cls, body: DataFrameLike): # Boxhead ---- -__Boxhead = None class ColumnAlignment(Enum): @@ -234,48 +235,57 @@ def defaulted_align(self) -> str: class Boxhead(_Sequence[ColInfo]): _d: list[ColInfo] - def __init__( - self, + def __new__( + cls, data: TblData | list[ColInfo], auto_align: bool = True, rowname_col: str | None = None, groupname_col: str | None = None, ): + obj = super().__new__(cls) + if isinstance(data, list): - self._d = data + obj._d = data else: # Obtain the column names from the data and initialize the # `_boxhead` from that column_names = get_column_names(data) - self._d = [ColInfo(col) for col in column_names] + obj._d = [ColInfo(col) for col in column_names] + obj = obj.set_stub_cols(rowname_col, groupname_col) + if not isinstance(data, list) and auto_align: - self.align_from_data(data=data) + return obj.align_from_data(data=data) - if rowname_col is not None: - self.set_rowname_col(rowname_col) + return obj - if groupname_col is not None: - self.set_groupname_col(groupname_col) + def __init__(self, *args, **kwargs): + pass - def set_rowname_col(self, rowname_col: str): + def set_stub_cols(self, rowname_col: str | None, groupname_col: str | None): + # Note that None unsets a column # TODO: validate that rowname_col is in the boxhead - for ii, col in enumerate(self._d): + if rowname_col is not None and rowname_col == groupname_col: + raise ValueError( + "rowname_col and groupname_col may not be set to the same column. " + f"Received column name: `{rowname_col}`." + ) + new_cols = [] + for col in self: + # either set the col to be the new stub or row_group ---- + # note that this assumes col.var is always a string, so never equals None if col.var == rowname_col: new_col = replace(col, type=ColInfoTypeEnum.stub) - self._d[ii] = new_col - elif col.type == ColInfoTypeEnum.stub: - new_col = replace(col, type=ColInfoTypeEnum.default) - self._d[ii] = new_col - - def set_groupname_col(self, groupname_col: str): - # TODO: validate that groupname_col is in the boxhead - for ii, col in enumerate(self._d): - if col.var == groupname_col: + elif col.var == groupname_col: new_col = replace(col, type=ColInfoTypeEnum.row_group) - self._d[ii] = new_col - elif col.type == ColInfoTypeEnum.row_group: + # otherwise, unset the existing stub or row_group ---- + elif col.type == ColInfoTypeEnum.stub or col.type == ColInfoTypeEnum.row_group: new_col = replace(col, type=ColInfoTypeEnum.default) - self._d[ii] = new_col + else: + new_col = replace(col) + + new_cols.append(new_col) + + return self.__class__(new_cols) def set_cols_hidden(self, colnames: list[str]): # TODO: validate that colname is in the boxhead @@ -363,7 +373,7 @@ def align_from_data(self, data: TblData): for col, alignment in zip(self._d, align): new_cols.append(replace(col, column_align=alignment)) - self._d = new_cols + return self.__class__(new_cols) def vars_from_type(self, type: ColInfoTypeEnum) -> list[str]: return [x.var for x in self._d if x.type == type] @@ -424,6 +434,12 @@ def _get_stub_column(self) -> ColInfo | None: return None return stub_column[0] + def _get_row_group_column(self) -> ColInfo | None: + column = [x for x in self._d if x.type == ColInfoTypeEnum.row_group] + if len(column) == 0: + return None + return column[0] + # Get a list of visible column labels def _get_default_column_labels(self) -> list[str | None]: default_column_labels = [ @@ -463,12 +479,10 @@ def _get_number_of_visible_data_columns(self) -> int: # Obtain the number of visible columns in the built table; this should # account for the size of the stub in the final, built table - def _get_effective_number_of_columns( - self, stub: Stub, row_groups: RowGroups, options: Options - ) -> int: + def _get_effective_number_of_columns(self, stub: Stub, options: Options) -> int: n_data_cols = self._get_number_of_visible_data_columns() - stub_layout = stub._get_stub_layout(row_groups=row_groups, options=options) + stub_layout = stub._get_stub_layout(options=options) # Once the stub is defined in the package, we need to account # for the width of the stub at build time to fully obtain the number # of visible columns in the built table @@ -494,7 +508,6 @@ def _set_column_width(self, colname: str, width: str) -> Self: # Stub ---- -__Stub = None @dataclass(frozen=True) @@ -514,58 +527,108 @@ class RowInfo: # `built` = False -class Stub(_Sequence[RowInfo]): +class Stub: + """Container for row information and labels, along with grouping information. + + This class handles the following: + + * Creating row and grouping information from data. + * Determining row order for final presentation. + + Note that the order of entries in .group_rows determines final rendering order. + When .group_rows is empty, the original data order is used. + """ + + # TODO: the rows get reordered at various points, but are never used in rendering? + # the html rendering uses group_rows to index into the underlying DataFrame + _d: list[RowInfo] + rows: list[RowInfo] + group_rows: GroupRows - def __init__( - self, - data: TblData | list[RowInfo], - rowname_col: str | None = None, - groupname_col: str | None = None, - ): - if isinstance(data, list): - self._d = list(data) + def __init__(self, rows: list[RowInfo], group_rows: GroupRows): + self.rows = self._d = list(rows) + self.group_rows = group_rows + @classmethod + def from_data(cls, data, rowname_col: str | None = None, groupname_col: str | None = None): + # Obtain a list of row indices from the data and initialize + # the `_stub` from that + row_indices = list(range(n_rows(data))) + + if groupname_col is not None: + group_id = to_list(data[groupname_col]) else: - # Obtain a list of row indices from the data and initialize - # the `_stub` from that - row_indices = list(range(n_rows(data))) + group_id = [None] * n_rows(data) - if groupname_col is not None: - group_id = to_list(data[groupname_col]) - else: - group_id = [None] * n_rows(data) + if rowname_col is not None: + row_names = to_list(data[rowname_col]) + else: + row_names = [None] * n_rows(data) - if rowname_col is not None: - row_names = to_list(data[rowname_col]) - else: - row_names = [None] * n_rows(data) + # Obtain the column names from the data and initialize the + # `_stub` from that + row_info = [RowInfo(*i) for i in zip(row_indices, group_id, row_names)] - # Obtain the column names from the data and initialize the - # `_stub` from that - self._d = [RowInfo(*i) for i in zip(row_indices, group_id, row_names)] + # create groups, and ensure they're ordered by first observed + group_names = list({row.group_id: True for row in row_info if row.group_id is not None}) + group_rows = GroupRows(data, group_key=groupname_col).reorder(group_names) + + return cls(row_info, group_rows) + + def _set_cols( + self, data: TblData, boxhead: Boxhead, rowname_col: str | None, groupname_col: str | None + ) -> Tuple[Stub, Boxhead]: + """Return a new Stub and Boxhead, with updated rowname and groupname columns. + + Note that None unsets a column. + """ + + new_boxhead = boxhead.set_stub_cols(rowname_col, groupname_col) + new_stub = self.from_data(data, rowname_col, groupname_col) + + return new_stub, new_boxhead + + @property + def group_ids(self) -> RowGroups: + return [group.group_id for group in self.group_rows] + + def reorder_rows(self, indices) -> Self: + new_rows = [self.rows[ii] for ii in indices] - def _to_row_groups(self) -> RowGroups: - # get unique group_ids, using dict as an ordered set - group_ids = list({row.group_id: True for row in self if row.group_id is not None}) + return self.__class__(new_rows, self.group_rows) - return group_ids + def order_groups(self, group_order: RowGroups): + # TODO: validate + return self.__class__(self.rows, self.group_rows.reorder(group_order)) + + def group_indices_map(self) -> list[tuple[int, str | None]]: + return self.group_rows.indices_map(len(self.rows)) + + def __iter__(self): + return iter(self.rows) + + def __len__(self): + return len(self.rows) + + def __getitem__(self, ii: int): + return self.rows[ii] def _get_stub_components(self) -> list[str]: stub_components: list[str] = [] - if any(entry.group_id is not None for entry in self): + if any(entry.group_id is not None for entry in self.rows): stub_components.append("group_id") - if any(entry.rowname is not None for entry in self): + if any(entry.rowname is not None for entry in self.rows): stub_components.append("row_id") return stub_components # Determine whether the table should have row group labels set within a column in the stub - def _stub_group_names_has_column(self, row_groups: RowGroups, options: Options) -> bool: + def _stub_group_names_has_column(self, options: Options) -> bool: # If there aren't any row groups then the result is always False - if len(row_groups) < 1: + if len(self.group_ids) < 1: return False # Given that there are row groups, we need to look at the option `row_group_as_column` to @@ -579,12 +642,10 @@ def _stub_group_names_has_column(self, row_groups: RowGroups, options: Options) return row_group_as_column - def _get_stub_layout(self, row_groups: RowGroups, options: Options) -> list[str]: + def _get_stub_layout(self, options: Options) -> list[str]: # Determine which stub components are potentially present as columns stub_rownames_is_column = "row_id" in self._get_stub_components() - stub_groupnames_is_column = self._stub_group_names_has_column( - row_groups=row_groups, options=options - ) + stub_groupnames_is_column = self._stub_group_names_has_column(options=options) # Get the potential total number of columns in the table stub n_stub_cols = stub_rownames_is_column + stub_groupnames_is_column @@ -614,12 +675,9 @@ def _get_stub_layout(self, row_groups: RowGroups, options: Options) -> list[str] # Row groups ---- -__RowGroups = None - RowGroups: TypeAlias = list[str] # Group rows ---- -__GroupRows = None @dataclass(frozen=True) @@ -657,8 +715,8 @@ def __init__(self, data: list[GroupRowInfo] | DataFrameLike, group_key: str | No from ._tbl_data import group_splits self._d = [] - for grp_key, ind in group_splits(data, group_key=group_key).items(): - self._d.append(GroupRowInfo(grp_key, indices=ind)) + for group_id, ind in group_splits(data, group_key=group_key).items(): + self._d.append(GroupRowInfo(group_id, indices=ind)) def reorder(self, group_ids: list[str | MISSING_GROUP]) -> Self: # TODO: validate all group_ids are in data @@ -689,7 +747,6 @@ def indices_map(self, n: int) -> list[tuple[int, str | None]]: # Spanners ---- -__Spanners = None @dataclass(frozen=True) @@ -709,6 +766,14 @@ def built_label(self) -> str: raise ValueError("Spanner label must be a string and not None.") return label + def drop_var(self, name: str) -> Self: + new_vars = [entry for entry in self.vars if entry != name] + + if len(new_vars) == len(self.vars): + return self + + return replace(self, vars=new_vars) + class Spanners(_Sequence[SpannerInfo]): _d: list[SpannerInfo] @@ -739,7 +804,7 @@ def next_level(self, column_names: list[str]) -> int: return 0 overlapping_levels = [ - s.spanner_level for s in self if any(v in column_names for v in s.vars) + span.spanner_level for span in self if any(v in column_names for v in span.vars) ] return max(overlapping_levels, default=-1) + 1 @@ -747,9 +812,11 @@ def next_level(self, column_names: list[str]) -> int: def append_entry(self, span: SpannerInfo) -> Self: return self.__class__(self._d + [span]) + def remove_column(self, column: str) -> Self: + return self.__class__([span.drop_var(column) for span in self]) + # Heading --- -__Heading = None @dataclass(frozen=True) @@ -760,18 +827,13 @@ class Heading: # Stubhead ---- -__Stubhead = None - Stubhead: TypeAlias = "str | None" # Sourcenotes ---- -__Sourcenotes = None - SourceNotes = list[str] # Footnotes ---- -__Footnotes = None class FootnotePlacement(Enum): @@ -795,7 +857,6 @@ class FootnoteInfo: Footnotes: TypeAlias = list[FootnoteInfo] # Styles ---- -__Styles = None @dataclass(frozen=True) @@ -812,7 +873,6 @@ class StyleInfo: Styles: TypeAlias = list[StyleInfo] # Locale ---- -__Locale = None class Locale: @@ -823,7 +883,6 @@ def __init__(self, locale: str | None): # Formats ---- -__Formats = None class FormatterSkipElement: @@ -885,8 +944,6 @@ def __init__(self, func: FormatFns, cols: list[str], rows: list[int]): # Options ---- -__Options = None - default_fonts_list = [ "-apple-system", @@ -1128,5 +1185,7 @@ def _get_all_options_keys(self) -> list[str | None]: # return self._options[option].type def _set_option_value(self, option: str, value: Any): - self._options[option].value = value - return self + old_info = getattr(self, option) + new_info = replace(old_info, value=value) + + return replace(self, **{option: new_info}) diff --git a/great_tables/_modify_rows.py b/great_tables/_modify_rows.py new file mode 100644 index 000000000..013fe458e --- /dev/null +++ b/great_tables/_modify_rows.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Literal, TypedDict, TypeVar, cast + +from ._gt_data import GTData, Locale, Options, RowGroups, Spanners, Stub, Boxhead, Styles + +if TYPE_CHECKING: + from ._types import GTSelf + + +def row_group_order(self: GTSelf, groups: RowGroups) -> GTSelf: + new_stub = self._stub.order_groups(groups) + + return self._replace(_stub=new_stub) + + +def _remove_from_body_styles(styles: Styles, column: str) -> Styles: + new_styles = [ + info for info in styles if not (info.locname == "data" and info.colname == column) + ] + + return new_styles + + +def _remove_from_group_styles(styles: Styles, column: str): + # TODO(#341): once group styles are supported, will need to wire this up. + return list(styles) + + +def tab_stub( + self: GTSelf, rowname_col: str | None = None, groupname_col: str | None = None +) -> GTSelf: + """Add a table stub, to emphasize row and group information. + + Parameters + ---------- + rowname_col: + The column to use for row names. By default, no row names added. + groupname_col: + The column to use for group names. By default no group names added. + + Examples + -------- + + By default, all data is together in the body of the table. + + ```{python} + from great_tables import GT, exibble + + GT(exibble) + ``` + + The table stub separates row names with a vertical line, and puts group names + on their own line. + + ```{python} + GT(exibble).tab_stub(rowname_col="row", groupname_col="group") + ``` + """ + # old columns ---- + _info = self._boxhead._get_row_group_column() + old_groupname_col = _info.var if _info is not None else None + + styles = self._styles + + # remove group styles ---- + if old_groupname_col is not None and old_groupname_col != groupname_col: + styles = _remove_from_group_styles(styles, old_groupname_col) + + # remove table body styles ---- + # they no longer apply to groupname_col + if groupname_col is not None: + styles = _remove_from_body_styles(self._styles, groupname_col) + + self = self._replace(_styles=styles) + + # remove from spanners ---- + if groupname_col is not None: + self = self._replace(_spanners=self._spanners.remove_column(groupname_col)) + + if rowname_col is not None: + self = self._replace(_spanners=self._spanners.remove_column(rowname_col)) + + # set new row and group name cols ---- + stub, boxhead = self._stub._set_cols(self._tbl_data, self._boxhead, rowname_col, groupname_col) + + return self._replace(_stub=stub, _boxhead=boxhead) + + +def with_locale(self: GTSelf, locale: str | None = None) -> GTSelf: + """Set a column to be the default locale. + + Setting a default locale affects formatters like .fmt_number, and .fmt_date, + by having them default to locale-specific features (e.g. representing one thousand + as 1.000,00) + """ + + return self._replace(_locale=Locale(locale)) + + +def with_id(self: GTSelf, id: str | None = None) -> GTSelf: + """Set the id for this table. + + Note that this is a shortcut for the `table_id=` argument in `GT.tab_options()`. + """ + return self._replace(_options=self._options._set_option_value("table_id", id)) diff --git a/great_tables/_stub.py b/great_tables/_stub.py index 2f2842461..c431caa50 100644 --- a/great_tables/_stub.py +++ b/great_tables/_stub.py @@ -4,7 +4,7 @@ from .utils_render_common import get_row_reorder_df -def reorder_stub_df(stub_df: Stub, row_groups: RowGroups) -> Stub: +def reorder_stub_df(stub_df: Stub) -> Stub: """ Reorders the components of the stub object based on the given row groups. @@ -15,8 +15,12 @@ def reorder_stub_df(stub_df: Stub, row_groups: RowGroups) -> Stub: Returns: Stub: The reordered stub object. """ - start_final = get_row_reorder_df(row_groups, stub_df) - stub_df = stub_df[[final for _, final in start_final]] + # NOTE: the original R package reordered stub rows, and returned a new GT object. + # However, since the final order is determined by the groups, we use those to + # determine the final order, just before rendering + + # start_final = get_row_reorder_df(stub_df) + # stub_df = stub_df.reorder_rows([final for _, final in start_final]) return stub_df diff --git a/great_tables/_tbl_data.py b/great_tables/_tbl_data.py index 84853ba8c..8ee25df9f 100644 --- a/great_tables/_tbl_data.py +++ b/great_tables/_tbl_data.py @@ -258,7 +258,7 @@ def group_splits(data: DataFrameLike, group_key: str) -> dict[Any, list[int]]: @group_splits.register def _(data: PdDataFrame, group_key: str) -> dict[Any, list[int]]: - g_df = data.groupby(group_key) + g_df = data.groupby(group_key, dropna=False) return {k: list(v) for k, v in g_df.indices.items()} diff --git a/great_tables/_utils_render_html.py b/great_tables/_utils_render_html.py index 856d1658a..5bacdc1e7 100644 --- a/great_tables/_utils_render_html.py +++ b/great_tables/_utils_render_html.py @@ -35,7 +35,7 @@ def create_heading_component_h(data: GTData) -> StringBuilder: # Get the effective number of columns, which is number of columns # that will finally be rendered accounting for the stub layout n_cols_total = data._boxhead._get_effective_number_of_columns( - stub=data._stub, row_groups=data._row_groups, options=data._options + stub=data._stub, options=data._options ) if has_subtitle: @@ -79,7 +79,7 @@ def create_columns_component_h(data: GTData) -> str: # body = data._body # Get vector representation of stub layout - stub_layout = data._stub._get_stub_layout(row_groups=data._row_groups, options=data._options) + stub_layout = data._stub._get_stub_layout(options=data._options) # Determine the finalized number of spanner rows spanner_row_count = _get_spanners_matrix_height(data=data, omit_columns_row=True) @@ -437,11 +437,11 @@ def create_body_component_h(data: GTData) -> str: stub_var = data._boxhead._get_stub_column() - stub_layout = data._stub._get_stub_layout(row_groups=data._row_groups, options=data._options) + stub_layout = data._stub._get_stub_layout(options=data._options) has_stub_column = "rowname" in stub_layout has_two_col_stub = "group_label" in stub_layout - has_groups = data._row_groups is not None and len(data._row_groups) > 0 + has_groups = data._stub.group_ids is not None and len(data._stub.group_ids) > 0 # If there is a stub, then prepend that to the `column_vars` list if stub_var is not None: @@ -452,14 +452,14 @@ def create_body_component_h(data: GTData) -> str: # iterate over rows (ordered by groupings) prev_group_label = None - ordered_index = data._group_rows.indices_map(n_rows(data._tbl_data)) + ordered_index = data._stub.group_indices_map() for i, group_label in ordered_index: body_cells: list[str] = [] if has_stub_column and has_groups and not has_two_col_stub: colspan_value = data._boxhead._get_effective_number_of_columns( - stub=data._stub, row_groups=data._row_groups, options=data._options + stub=data._stub, options=data._options ) # Generate a row that contains the row group label (this spans the entire row) but @@ -536,7 +536,7 @@ def create_source_notes_component_h(data: GTData) -> str: # Get the effective number of columns, which is number of columns # that will finally be rendered accounting for the stub layout n_cols_total = data._boxhead._get_effective_number_of_columns( - stub=data._stub, row_groups=data._row_groups, options=data._options + stub=data._stub, options=data._options ) # Handle the multiline source notes case (each note takes up one line) diff --git a/great_tables/gt.py b/great_tables/gt.py index 81a5fd32d..a58076e3e 100644 --- a/great_tables/gt.py +++ b/great_tables/gt.py @@ -30,6 +30,12 @@ ) from great_tables._heading import tab_header from great_tables._helpers import random_id +from great_tables._modify_rows import ( + row_group_order, + tab_stub, + with_id, + with_locale, +) from great_tables._options import ( opt_align_table_header, opt_all_caps, @@ -254,6 +260,11 @@ def __init__( tab_style = tab_style tab_options = tab_options + row_group_order = row_group_order + tab_stub = tab_stub + with_id = with_id + with_locale = with_locale + save = save as_raw_html = as_raw_html @@ -291,11 +302,11 @@ def _build_data(self, context: str) -> Self: # built._body = _migrate_unformatted_to_output(body) # built._perform_col_merge() - final_body = body_reassemble(built._body, built._row_groups, built._stub, built._boxhead) + final_body = body_reassemble(built._body, built._stub, built._boxhead) # Reordering of the metadata elements of the table - final_stub = reorder_stub_df(built._stub, built._row_groups) + final_stub = reorder_stub_df(built._stub) # self = self.reorder_footnotes() # self = self.reorder_styles() diff --git a/great_tables/utils_render_common.py b/great_tables/utils_render_common.py index 436108ea7..58e306dfb 100644 --- a/great_tables/utils_render_common.py +++ b/great_tables/utils_render_common.py @@ -11,9 +11,14 @@ TupleStartFinal: TypeAlias = tuple[int, int] -def get_row_reorder_df(groups: RowGroups, stub_df: Stub) -> list[TupleStartFinal]: +def get_row_reorder_df(stub_df: Stub, groups: RowGroups | None = None) -> list[TupleStartFinal]: + # TODO: this function should be removed, since the stub generates indices directly. + + if groups is None: + groups = stub_df.group_ids + # Get the number of non-None entries in the `groupname_col` - n_stub_entries = len([entry for entry in stub_df if entry.group_id is not None]) + n_stub_entries = len([entry for entry in stub_df.rows if entry.group_id is not None]) # Raise a ValueError if there are row group entries but no RowGroups if n_stub_entries and not len(groups): @@ -22,7 +27,7 @@ def get_row_reorder_df(groups: RowGroups, stub_df: Stub) -> list[TupleStartFinal # If there aren't any row groups then return a list of tuples that don't lead # to any resorting later on (e.g., `[(0, 0), (1, 1), (2, 2) ... (n, n)]`) if not len(groups): - indices = range(len(stub_df)) + indices = range(len(stub_df.rows)) # TODO: is this used in indexing? If so, we may need to use # ii + 1 for the final part? @@ -30,7 +35,9 @@ def get_row_reorder_df(groups: RowGroups, stub_df: Stub) -> list[TupleStartFinal # where in the group each element is # TODO: this doesn't yield consistent values - groups_pos = [groups.index(row.group_id) if row.group_id is not None else -1 for row in stub_df] + groups_pos = [ + groups.index(row.group_id) if row.group_id is not None else -1 for row in stub_df.rows + ] # From running test_body_reassemble(): # print(groups_pos) diff --git a/tests/__snapshots__/test_modify_rows.ambr b/tests/__snapshots__/test_modify_rows.ambr new file mode 100644 index 000000000..a3eeea475 --- /dev/null +++ b/tests/__snapshots__/test_modify_rows.ambr @@ -0,0 +1,15 @@ +# serializer version: 1 +# name: test_row_group_order + ''' + + + 2 + 4 + + + 1 + 3 + + + ''' +# --- diff --git a/tests/test_gt.py b/tests/test_gt.py index df6640fe7..3601abea8 100644 --- a/tests/test_gt.py +++ b/tests/test_gt.py @@ -14,10 +14,16 @@ def gt_tbl(): def test_gt_replace(gt_tbl: GT): - row_groups = ["x"] - new_gt_tbl = gt_tbl._replace(_row_groups=row_groups) + new_gt_tbl = gt_tbl._replace(_has_built=True) - assert new_gt_tbl._row_groups is row_groups + assert new_gt_tbl._has_built + + +def test_gt_groupname_and_rowname_col_equal_raises(): + with pytest.raises(ValueError) as exc_info: + GT(pd.DataFrame({"g": [1], "row": [1]}), rowname_col="g", groupname_col="g") + + assert "may not be set to the same column." in exc_info.value.args[0] def test_gt_object_prerender(gt_tbl: GT): @@ -32,7 +38,6 @@ def test_gt_object_prerender(gt_tbl: GT): assert isinstance(gt_tbl._source_notes, list) assert isinstance(gt_tbl._footnotes, list) assert isinstance(gt_tbl._styles, list) - assert isinstance(gt_tbl._row_groups, list) assert type(gt_tbl._locale).__name__ == "Locale" diff --git a/tests/test_gt_data.py b/tests/test_gt_data.py index e89e66897..ee89b3d70 100644 --- a/tests/test_gt_data.py +++ b/tests/test_gt_data.py @@ -2,22 +2,36 @@ from great_tables._gt_data import Boxhead, ColInfo, RowInfo, Stub -def test_stub_construct_manual(): - stub = Stub([RowInfo(0), RowInfo(1)]) - assert stub[0] == RowInfo(0) - - def test_stub_construct_df(): - stub = Stub(pd.DataFrame({"x": [8, 9]})) + stub = Stub.from_data(pd.DataFrame({"x": [8, 9]})) assert len(stub) == 2 assert stub[0] == RowInfo(0) assert stub[1] == RowInfo(1) +def test_stub_construct_manual(): + stub = Stub.from_data(pd.DataFrame({"x": [8, 9]})) + + stub2 = Stub(stub.rows, stub.group_rows) + assert stub2[0] == RowInfo(0) + + def test_stub_construct_df_rowname(): # TODO: remove groupname_col from here - stub = Stub(pd.DataFrame({"x": [8, 9], "y": [1, 2]}), rowname_col="x", groupname_col=None) + stub = Stub.from_data( + pd.DataFrame({"x": [8, 9], "y": [1, 2]}), rowname_col="x", groupname_col=None + ) + + +def test_stub_order_groups(): + stub = Stub.from_data(pd.DataFrame({"g": ["b", "a", "b", "c"]}), groupname_col="g") + assert stub.group_ids == ["b", "a", "c"] + + stub2 = stub.order_groups(["c", "a", "b"]) + assert stub2.group_ids == ["c", "a", "b"] + + assert stub2.group_indices_map() == [(3, "c"), (1, "a"), (0, "b"), (2, "b")] def test_boxhead_reorder(): diff --git a/tests/test_modify_rows.py b/tests/test_modify_rows.py new file mode 100644 index 000000000..cad42526b --- /dev/null +++ b/tests/test_modify_rows.py @@ -0,0 +1,171 @@ +import pandas as pd + +from great_tables import GT, loc, style +from great_tables._utils_render_html import create_body_component_h + + +def assert_rendered_body(snapshot, gt): + built = gt._build_data("html") + body = create_body_component_h(built) + + assert snapshot == body + + +def test_row_group_order(snapshot): + gt = GT(pd.DataFrame({"g": ["a", "b"], "x": [1, 2], "y": [3, 4]}), groupname_col="g") + + assert_rendered_body(snapshot, gt.row_group_order(["b", "a"])) + + +def test_with_groupname_col(): + gt = GT(pd.DataFrame({"g": ["b", "a"], "x": [1, 2], "y": [3, 4]})) + + new_gt = gt.tab_stub(groupname_col="g") + group_rows = new_gt._stub.group_rows + + assert list(grp.group_id for grp in group_rows) == ["b", "a"] + assert [grp.indices for grp in group_rows] == [[0], [1]] + + +def test_with_groupname_col_undo_spanner_style(): + SPAN_COLS = ["g", "x"] + STYLE_COLS = ["g", "y"] + + gt = ( + GT(pd.DataFrame({"g": ["b"], "x": [1], "y": [3]})) + .tab_spanner("A", SPAN_COLS) + .tab_style(style.fill("red"), loc.body(columns=STYLE_COLS)) + ) + + assert gt._spanners[0].vars == SPAN_COLS + assert len(gt._styles) == 2 + assert {style.colname for style in gt._styles} == set(STYLE_COLS) + + new_gt = gt.tab_stub(groupname_col="g") + + # grouping col dropped from spanner vars + assert len(new_gt._spanners) == 1 + assert new_gt._spanners[0].vars == ["x"] + + # grouping col dropped from body styles + assert len(new_gt._styles) == 1 + assert new_gt._styles[0].colname == "y" + + +def test_with_groupname_col_unset(): + + gt = GT( + pd.DataFrame({"g": ["b"], "row": ["one"], "x": [1], "y": [3]}), + rowname_col="row", + groupname_col="g", + ) + + assert gt._boxhead._get_row_group_column().var == "g" + assert len(gt._stub.group_rows) == 1 + + new_gt = gt.tab_stub(rowname_col="row") + + # check row unchanged ---- + assert new_gt._boxhead._get_stub_column().var == "row" + assert new_gt._stub.rows[0].rowname == "one" + + # check group col removed ---- + assert new_gt._boxhead._get_row_group_column() is None + assert len(new_gt._stub.group_rows) == 0 + + +def test_with_rowname_col(): + gt = GT(pd.DataFrame({"g": ["b", "a"], "x": [1, 2], "y": [3, 4]})) + + new_gt = gt.tab_stub(rowname_col="g") + rows = new_gt._stub.rows + + assert [row.rowname for row in rows] == ["b", "a"] + + +def test_with_rowname_col_undo_spanner_style(): + SPAN_COLS = ["g", "x"] + STYLE_COLS = ["g", "y"] + + gt = ( + GT(pd.DataFrame({"g": ["b"], "x": [1], "y": [3]})) + .tab_spanner("A", SPAN_COLS) + .tab_style(style.fill("red"), loc.body(columns=STYLE_COLS)) + ) + + assert gt._spanners[0].vars == SPAN_COLS + assert len(gt._styles) == 2 + assert {style.colname for style in gt._styles} == set(STYLE_COLS) + + new_gt = gt.tab_stub(rowname_col="g") + + # rowname col dropped from spanner vars + assert len(new_gt._spanners) == 1 + assert new_gt._spanners[0].vars == ["x"] + + # rowname col *kept* in body styles + assert len(new_gt._styles) == 2 + + +def test_with_rowname_col_unset(): + + gt = GT( + pd.DataFrame({"g": ["b"], "row": ["one"], "x": [1], "y": [3]}), + rowname_col="row", + groupname_col="g", + ) + + assert gt._boxhead._get_stub_column().var == "row" + assert gt._stub.rows[0].rowname == "one" + + new_gt = gt.tab_stub(groupname_col="g") + + # check row removed ---- + assert new_gt._boxhead._get_stub_column() is None + assert new_gt._stub.rows[0].rowname is None + + # check group col unchanged ---- + assert new_gt._boxhead._get_row_group_column().var == "g" + assert len(new_gt._stub.group_rows) == 1 + + +def test_with_locale(): + gt = GT(pd.DataFrame({"x": [1]}), locale="es") + + assert gt._locale._locale == "es" + + assert gt.with_locale("de")._locale._locale == "de" + + +def test_with_locale_unset(): + gt = GT(pd.DataFrame({"x": [1]}), locale="es") + + assert gt._locale._locale == "es" + + assert gt.with_locale()._locale._locale is None + + +def test_with_id(): + gt = GT(pd.DataFrame({"x": [1]}), id="abc") + + assert gt._options.table_id.value == "abc" + + assert gt.with_id("zzz")._options.table_id.value == "zzz" + + +def test_with_id_unset(): + gt = GT(pd.DataFrame({"x": [1]}), id="abc") + + assert gt._options.table_id.value == "abc" + + assert gt.with_id()._options.table_id.value is None + + +def test_with_id_preserves_other_options(): + gt = GT(pd.DataFrame({"x": [1]})).tab_options(container_width="20px") + + assert gt._options.container_width.value == "20px" + + new_gt = gt.with_id("zzz") + assert new_gt._options.table_id.value == "zzz" + assert new_gt._options.container_width.value == "20px" diff --git a/tests/test_tbl_data.py b/tests/test_tbl_data.py index 69c9d0862..dfc038fdd 100644 --- a/tests/test_tbl_data.py +++ b/tests/test_tbl_data.py @@ -1,3 +1,4 @@ +import math import pandas as pd import polars as pl import polars.testing @@ -13,6 +14,7 @@ create_empty_frame, eval_select, get_column_names, + group_splits, is_series, reorder, to_frame, @@ -130,6 +132,35 @@ def test_eval_selector_polars_list_raises(): assert "entry 1 is type: " in str(exc_info.value.args[0]) +@pytest.mark.parametrize("Frame", [pd.DataFrame, pl.DataFrame]) +def test_group_splits_pd(Frame): + df = Frame({"g": ["b", "a", "b", "c"]}) + + splits = group_splits(df, "g") + assert set(splits.keys()) == {"a", "b", "c"} + assert splits["b"] == [0, 2] + assert splits["a"] == [1] + assert splits["c"] == [3] + + +def test_group_splits_pd_na(): + df = pd.DataFrame({"g": ["b", "a", None]}) + + splits = group_splits(df, "g") + assert len(splits.keys()) == 3 + nan_key = [k for k in splits if isinstance(k, float) and math.isnan(k)][0] + + assert splits[nan_key] == [2] + + +def test_group_splits_pl_na(): + df = pl.DataFrame({"g": ["b", "a", None]}) + + splits = group_splits(df, "g") + assert set(splits.keys()) == {"b", "a", None} + assert splits[None] == [2] + + def test_validate_selector_list_strict_raises(): with pytest.raises(TypeError) as exc_info: _validate_selector_list([pl.col("a")]) diff --git a/tests/test_utils_render_common.py b/tests/test_utils_render_common.py index 4cf4047f0..0c13ce22a 100644 --- a/tests/test_utils_render_common.py +++ b/tests/test_utils_render_common.py @@ -1,20 +1,26 @@ import pytest -from great_tables._gt_data import RowInfo, Stub +from great_tables._gt_data import RowInfo, Stub, GroupRowInfo from great_tables.utils_render_common import get_row_reorder_df def test_get_row_reorder_df_simple(): groups = ["b", "a"] - stub = Stub([RowInfo(0, "a"), RowInfo(1, "b"), RowInfo(2, "a")]) + stub = Stub( + [RowInfo(0, "a"), RowInfo(1, "b"), RowInfo(2, "a")], + [GroupRowInfo("a", indices=[0, 2]), GroupRowInfo("b", indices=[1])], + ) - start_end = get_row_reorder_df(groups, stub) + start_end = get_row_reorder_df(stub, groups) assert start_end == [(0, 1), (1, 0), (2, 2)] def test_get_row_reorder_df_no_groups(): groups = [] - stub = Stub([RowInfo(0, "a"), RowInfo(1, "b")]) + stub = Stub( + [RowInfo(0, "a"), RowInfo(1, "b")], + [GroupRowInfo("a", indices=[0]), GroupRowInfo("b", indices=[1])], + ) with pytest.raises(ValueError): - get_row_reorder_df(groups, stub) + get_row_reorder_df(stub, groups)