diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 97ed652..5e6e9a6 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -6,6 +6,10 @@ All notable changes to this project will be documented in this file. The format is based on `Keep a Changelog `_. +6.9 +--- +- Added ``rename_std`` option to ``lineplot_and_heatmaps``, which fixes a quasi-bug introduced in the ``rename_stat_col`` option by the changes in version 6.8. + 6.8 --- - Add ``addtl_slider_stats_as_max`` to ``lineplot_and_heatmap``. diff --git a/polyclonal/__init__.py b/polyclonal/__init__.py index 4602f15..b61f263 100644 --- a/polyclonal/__init__.py +++ b/polyclonal/__init__.py @@ -31,7 +31,7 @@ __author__ = "`the Bloom lab `_" __email__ = "jbloom@fredhutch.org" -__version__ = "6.8" +__version__ = "6.9" __url__ = "https://github.com/jbloomlab/polyclonal" from polyclonal.alphabets import AAS diff --git a/polyclonal/plot.py b/polyclonal/plot.py index 84d27a0..568bc70 100644 --- a/polyclonal/plot.py +++ b/polyclonal/plot.py @@ -483,6 +483,7 @@ def lineplot_and_heatmap( show_heatmap=True, scale_stat_col=1, rename_stat_col=None, + rename_std=True, sites_to_show=None, ): """Lineplots and heatmaps of per-site and per-mutation values. @@ -580,6 +581,9 @@ def lineplot_and_heatmap( Multiply numbers in `stat_col` by this number before plotting. rename_stat_col : None or str If a str, rename `stat_col` to this. Also changes y-axis labels. + rename_std : bool + If ``rename_stat_col`` is ``True``, rename any column named ``_std`` + to ``_std``. sites_to_show : None or dict If `None`, all sites are shown. If a dict, can be keyed by "include_range" (value a 2-tuple giving first and last site to include, inclusive), @@ -590,13 +594,6 @@ def lineplot_and_heatmap( altair.Chart Interactive plot. """ - if rename_stat_col: - if rename_stat_col in data_df.columns: - raise ValueError(f"{rename_stat_col=} already in {data_df.columns=}") - data_df = data_df.rename(columns={stat_col: rename_stat_col}) - stat_col = rename_stat_col - - basic_req_cols = ["site", "wildtype", "mutant", stat_col, category_col] if addtl_tooltip_stats is None: addtl_tooltip_stats = [] if addtl_slider_stats is None: @@ -612,6 +609,40 @@ def lineplot_and_heatmap( addtl_slider_stats_hide_not_filter = set(addtl_slider_stats).intersection( addtl_slider_stats_hide_not_filter ) + + if rename_stat_col: + if rename_stat_col in data_df.columns: + raise ValueError(f"{rename_stat_col=} already in {data_df.columns=}") + data_df = data_df.rename(columns={stat_col: rename_stat_col}) + std_col = f"{stat_col}_std" + stat_col = rename_stat_col + if rename_std: + rename_std_col = f"{rename_stat_col}_std" + if std_col in set(data_df.columns): + if rename_std_col in set(data_df.columns): + raise ValueError( + f"{data_df.columns=} has both {std_col=} and {rename_std_col=}" + ) + data_df = data_df.rename(columns={std_col: rename_std_col}) + if std_col in addtl_slider_stats: + addtl_slider_stats[rename_std_col] = addtl_slider_stats[std_col] + del addtl_slider_stats[std_col] + + def replace_std(col): + if col == std_col: + return rename_std_col + else: + return col + + addtl_tooltip_stats = [replace_std(c) for c in addtl_tooltip_stats] + addtl_slider_stats_as_max = [ + replace_std(c) for c in addtl_slider_stats_as_max + ] + addtl_slider_stats_hide_not_filter = [ + replace_std(c) for c in addtl_slider_stats_hide_not_filter + ] + + basic_req_cols = ["site", "wildtype", "mutant", stat_col, category_col] req_cols = basic_req_cols + addtl_tooltip_stats + list(addtl_slider_stats) if site_zoom_bar_color_col: req_cols.append(site_zoom_bar_color_col)