diff --git a/pyproject.toml b/pyproject.toml index ada6eaf5..8c40a92e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ scipy = "*" seaborn = "*" matplotlib = "*" umap-learn = { version = "*", extras = ["plot"] } +pacmap = "*" pandas = "*" h5py = "*" PyYAML = "*" diff --git a/vital/utils/plot.py b/vital/utils/plot.py index 40e642c6..e6f9adae 100644 --- a/vital/utils/plot.py +++ b/vital/utils/plot.py @@ -1,11 +1,10 @@ import logging -from typing import Any, Dict, Iterable, Iterator, Union +from typing import Any, Dict, Iterable, Iterator, Literal, Union import matplotlib import numpy as np import pandas as pd import seaborn as sns -import umap from matplotlib import pyplot as plt from matplotlib.axes import Axes @@ -13,18 +12,23 @@ def embedding_scatterplot( - data: pd.DataFrame, plots_kwargs: Iterable[Dict[str, Any]], umap_kwargs: Dict[str, Any] = None, data_tag: str = None + data: pd.DataFrame, + plots_kwargs: Iterable[Dict[str, Any]], + data_tag: str = None, + method: Literal["tsne", "umap", "pacmap"] = "pacmap", + **embedding_kwargs, ) -> Iterator[Axes]: - """Generates 2D scatter plots of some data, reducing its dimensionality to 2 using UMAP if it's not already 2D. + """Generates 2D scatter plots of some data, reducing its dimensionality to 2 if it's not already 2D. Args: data: Dataframe with each column representing a dimension of the data, and relevant metadata being stored in a multiindex. plots_kwargs: Sets of kwargs to use to generate different versions of the scatter plot, e.g. modifying the variables used for hue and/or style. - umap_kwargs: If the data has more than 2 dimensions, UMAP is used to reduce the dimensionality of the data for - plotting purposes. This parameter is passed along to the UMAP estimator's `init`. data_tag: String describing the data used in the titles/logs, etc. If not specified, it defaults to 'data'. + method: If the data has more than 2 dimensions, this parameter specifies the method to use to reduce the + dimensionality of the data for plotting purposes. + **embedding_kwargs: Parameters passed along to the embedding's constructor. Returns: An iterator over the generated scatter plots. @@ -40,15 +44,31 @@ def embedding_scatterplot( elif len(data.columns) == 2: plot_title = f"2D {data_tag}" else: # len(encoding_dims) > 2 - if umap_kwargs is None: - umap_kwargs = {} - plot_title = f"2D UMAP embedding of the {len(data.columns)}D {data_tag}" - logger.info(f"Generating 2D UMAP embedding of {len(data.columns)}D {data_tag}...") - umap_embedding = umap.UMAP(**umap_kwargs).fit_transform(data) + if embedding_kwargs is None: + embedding_kwargs = {} + match method: + case "tsne": + from sklearn.manifold import TSNE - # Update the encodings dataframe with the new UMAP embedding + embedding_cls = TSNE + case "umap": + import umap + + embedding_cls = umap.UMAP + case "pacmap": + from pacmap import PaCMAP + + embedding_cls = PaCMAP + case _: + raise ValueError(f"Unknown embedding method '{method}'. Must be one of: ['tsne', 'umap', 'pacmap'].") + + plot_title = f"2D {embedding_cls.__name__} embedding of the {len(data.columns)}D {data_tag}" + logger.info(f"Generating 2D {method} embedding of {len(data.columns)}D {data_tag}...") + data_2d = embedding_cls(**embedding_kwargs).fit_transform(data.to_numpy()) + + # Update the encodings dataframe with the new 2D embedding data = data.drop(labels=data.columns, axis="columns") - data[[0, 1]] = umap_embedding + data[[0, 1]] = data_2d # Generate a plot of the embedding for each set of plot kwargs provided for plot_kwargs in plots_kwargs: