Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added rename_std option to lineplot_and_heatmaps #181

Merged
merged 1 commit into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ All notable changes to this project will be documented in this file.

The format is based on `Keep a Changelog <https://keepachangelog.com>`_.

6.9
---
- Added ``rename_std`` option to ``lineplot_and_heatmaps``, which fixes a quasi-bug introduced in the ``rename_stat_col`` option by the changes in version 6.8.

6.8
---
- Add ``addtl_slider_stats_as_max`` to ``lineplot_and_heatmap``.
Expand Down
2 changes: 1 addition & 1 deletion polyclonal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

__author__ = "`the Bloom lab <https://research.fhcrc.org/bloom/en.html>`_"
__email__ = "[email protected]"
__version__ = "6.8"
__version__ = "6.9"
__url__ = "https://github.com/jbloomlab/polyclonal"

from polyclonal.alphabets import AAS
Expand Down
45 changes: 38 additions & 7 deletions polyclonal/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,7 @@ def lineplot_and_heatmap(
show_heatmap=True,
scale_stat_col=1,
rename_stat_col=None,
rename_std=True,
sites_to_show=None,
):
"""Lineplots and heatmaps of per-site and per-mutation values.
Expand Down Expand Up @@ -580,6 +581,9 @@ def lineplot_and_heatmap(
Multiply numbers in `stat_col` by this number before plotting.
rename_stat_col : None or str
If a str, rename `stat_col` to this. Also changes y-axis labels.
rename_std : bool
If ``rename_stat_col`` is ``True``, rename any column named ``<stat_col>_std``
to ``<rename_stat_col>_std``.
sites_to_show : None or dict
If `None`, all sites are shown. If a dict, can be keyed by "include_range"
(value a 2-tuple giving first and last site to include, inclusive),
Expand All @@ -590,13 +594,6 @@ def lineplot_and_heatmap(
altair.Chart
Interactive plot.
"""
if rename_stat_col:
if rename_stat_col in data_df.columns:
raise ValueError(f"{rename_stat_col=} already in {data_df.columns=}")
data_df = data_df.rename(columns={stat_col: rename_stat_col})
stat_col = rename_stat_col

basic_req_cols = ["site", "wildtype", "mutant", stat_col, category_col]
if addtl_tooltip_stats is None:
addtl_tooltip_stats = []
if addtl_slider_stats is None:
Expand All @@ -612,6 +609,40 @@ def lineplot_and_heatmap(
addtl_slider_stats_hide_not_filter = set(addtl_slider_stats).intersection(
addtl_slider_stats_hide_not_filter
)

if rename_stat_col:
if rename_stat_col in data_df.columns:
raise ValueError(f"{rename_stat_col=} already in {data_df.columns=}")
data_df = data_df.rename(columns={stat_col: rename_stat_col})
std_col = f"{stat_col}_std"
stat_col = rename_stat_col
if rename_std:
rename_std_col = f"{rename_stat_col}_std"
if std_col in set(data_df.columns):
if rename_std_col in set(data_df.columns):
raise ValueError(
f"{data_df.columns=} has both {std_col=} and {rename_std_col=}"
)
data_df = data_df.rename(columns={std_col: rename_std_col})
if std_col in addtl_slider_stats:
addtl_slider_stats[rename_std_col] = addtl_slider_stats[std_col]
del addtl_slider_stats[std_col]

def replace_std(col):
if col == std_col:
return rename_std_col
else:
return col

addtl_tooltip_stats = [replace_std(c) for c in addtl_tooltip_stats]
addtl_slider_stats_as_max = [
replace_std(c) for c in addtl_slider_stats_as_max
]
addtl_slider_stats_hide_not_filter = [
replace_std(c) for c in addtl_slider_stats_hide_not_filter
]

basic_req_cols = ["site", "wildtype", "mutant", stat_col, category_col]
req_cols = basic_req_cols + addtl_tooltip_stats + list(addtl_slider_stats)
if site_zoom_bar_color_col:
req_cols.append(site_zoom_bar_color_col)
Expand Down