Skip to content

Commit

Permalink
Set Default colors for all the plots (#27)
Browse files Browse the repository at this point in the history
* Set Default colors for all the plots

* Add distribution colors and colorblind safe option for high dimensional plots
  • Loading branch information
nikhilbhavikatti authored Oct 25, 2024
1 parent b7d6ca8 commit 42a9bf2
Show file tree
Hide file tree
Showing 4 changed files with 247 additions and 50 deletions.
27 changes: 17 additions & 10 deletions uadapy/plotting/plots1D.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import glasbey as gb
import seaborn as sns
from matplotlib.patches import Ellipse
import uadapy.plotting.utils as utils


def _calculate_freedman_diaconis_bins(data):
Expand Down Expand Up @@ -42,7 +43,7 @@ def _setup_plot(distributions, n_samples, seed, fig=None, axs=None, colors=None,
axs : matplotlib.axes.Axes or array of Axes or None, optional
Axes object(s) to use for plotting. If None, new axes will be created.
colors : list or None, optional
List of colors to use for each distribution. If None, Glasbey colors will be used.
List of colors to use for each distribution. If None, Matplotlib Set2 and glasbey colors will be used.
colorblind_safe : bool, optional
If True, the plot will use colors suitable for colorblind individuals.
Default is False.
Expand Down Expand Up @@ -96,13 +97,19 @@ def _setup_plot(distributions, n_samples, seed, fig=None, axs=None, colors=None,
for d in distributions:
samples.append(d.sample(n_samples, seed))

# Generate Glasbey colors
# Generate colors
if colors is None:
palette = gb.create_palette(palette_size=len(samples), colorblind_safe=colorblind_safe)
if colorblind_safe:
palette = gb.create_palette(palette_size=len(samples), colorblind_safe=colorblind_safe)
else:
palette = utils.get_colors(len(samples))
else:
# If colors are provided but fewer than the number of samples, add more colors from Glasbey palette
if len(colors) < len(samples):
additional_colors = gb.create_palette(palette_size=len(samples) - len(colors), colorblind_safe=colorblind_safe)
if colorblind_safe:
additional_colors = gb.create_palette(palette_size=len(samples) - len(colors), colorblind_safe=colorblind_safe)
else:
additional_colors = utils.get_colors(len(samples) - len(colors))
colors.extend(additional_colors)
palette = colors

Expand Down Expand Up @@ -145,7 +152,7 @@ def plot_1d_distribution(
dim_labels : list or None, optional
Titles for each subplot.
distrib_colors : list or None, optional
List of colors to use for each distribution. If None, Glasbey colors will be used.
List of colors to use for each distribution. If None, Matplotlib Set2 and glasbey colors will be used.
vert : bool, optional
If True, plots will be drawn vertically. If False, plots will be drawn horizontally.
Default is True.
Expand Down Expand Up @@ -336,7 +343,7 @@ def generate_boxplot(distributions,
dim_labels : list or None, optional
Titles for each subplot.
distrib_colors : list or None, optional
List of colors to use for each distribution. If None, Glasbey colors will be used.
List of colors to use for each distribution. If None, Matplotlib Set2 and glasbey colors will be used.
vert : bool, optional
If True, plots will be drawn vertically. If False, plots will be drawn horizontally.
Default is True.
Expand Down Expand Up @@ -395,7 +402,7 @@ def generate_violinplot(distributions,
dim_labels : list or None, optional
Titles for each subplot.
distrib_colors : list or None, optional
List of colors to use for each distribution. If None, Glasbey colors will be used.
List of colors to use for each distribution. If None, Matplotlib Set2 and glasbey colors will be used.
vert : bool, optional
If True, plots will be drawn vertically. If False, plots will be drawn horizontally.
Default is True.
Expand Down Expand Up @@ -451,7 +458,7 @@ def generate_dotplot(distributions,
dim_labels : list or None, optional
Titles for each subplot.
distrib_colors : list or None, optional
List of colors to use for each distribution. If None, Glasbey colors will be used.
List of colors to use for each distribution. If None, Matplotlib Set2 and glasbey colors will be used.
vert : bool, optional
If True, plots will be drawn vertically. If False, plots will be drawn horizontally.
Default is True.
Expand Down Expand Up @@ -511,7 +518,7 @@ def generate_stripplot(distributions,
dim_labels : list or None, optional
Titles for each subplot.
distrib_colors : list or None, optional
List of colors to use for each distribution. If None, Glasbey colors will be used.
List of colors to use for each distribution. If None, Matplotlib Set2 and glasbey colors will be used.
vert : bool, optional
If True, plots will be drawn vertically. If False, plots will be drawn horizontally.
Default is True.
Expand Down Expand Up @@ -571,7 +578,7 @@ def generate_swarmplot(distributions,
dim_labels : list or None, optional
Titles for each subplot.
distrib_colors : list or None, optional
List of colors to use for each distribution. If None, Glasbey colors will be used.
List of colors to use for each distribution. If None, Matplotlib Set2 and glasbey colors will be used.
vert : bool, optional
If True, plots will be drawn vertically. If False, plots will be drawn horizontally.
Default is True.
Expand Down
112 changes: 85 additions & 27 deletions uadapy/plotting/plots2D.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
import matplotlib.pyplot as plt
import numpy as np
from uadapy import Distribution
from numpy import ma
from matplotlib import ticker

def plot_samples(distributions, n_samples, seed=55, xlabel=None, ylabel=None, title=None, show_plot=False):
from matplotlib.colors import ListedColormap
import uadapy.plotting.utils as utils
import glasbey as gb

def plot_samples(distributions,
n_samples,
seed=55,
xlabel=None,
ylabel=None,
title=None,
distrib_colors=None,
colorblind_safe=False,
show_plot=False):
"""
Plot samples from the given distribution. If several distributions should be
plotted together, an array can be passed to this function.
Expand All @@ -23,6 +32,11 @@ def plot_samples(distributions, n_samples, seed=55, xlabel=None, ylabel=None, ti
label for y-axis.
title : string, optional
title for the plot.
distrib_colors : list or None, optional
List of colors to use for each distribution. If None, Matplotlib Set2 and glasbey colors will be used.
colorblind_safe : bool, optional
If True, the plot will use colors suitable for colorblind individuals.
Default is False.
show_plot : bool, optional
If True, display the plot.
Default is False.
Expand All @@ -37,9 +51,25 @@ def plot_samples(distributions, n_samples, seed=55, xlabel=None, ylabel=None, ti

if isinstance(distributions, Distribution):
distributions = [distributions]
for d in distributions:

# Generate colors
if distrib_colors is None:
if colorblind_safe:
palette = gb.create_palette(palette_size=len(distributions), colorblind_safe=colorblind_safe)
else:
palette = utils.get_colors(len(distributions))
else:
if len(distrib_colors) < len(distributions):
if colorblind_safe:
additional_colors = gb.create_palette(palette_size=len(distributions) - len(distrib_colors), colorblind_safe=colorblind_safe)
else:
additional_colors = utils.get_colors(len(distributions) - len(distrib_colors))
distrib_colors.extend(additional_colors)
palette = distrib_colors

for i, d in enumerate(distributions):
samples = d.sample(n_samples, seed)
plt.scatter(x=samples[:,0], y=samples[:,1])
plt.scatter(x=samples[:,0], y=samples[:,1], color=palette[i])
if xlabel:
plt.xlabel(xlabel)
if ylabel:
Expand All @@ -57,7 +87,14 @@ def plot_samples(distributions, n_samples, seed=55, xlabel=None, ylabel=None, ti

return fig, axs

def plot_contour(distributions, resolution=128, ranges=None, quantiles:list=None, seed=55, show_plot=False):
def plot_contour(distributions,
resolution=128,
ranges=None,
quantiles:list=None,
seed=55,
distrib_colors=None,
colorblind_safe=False,
show_plot=False):
"""
Plot contour plots for samples drawn from given distributions.
Expand All @@ -73,6 +110,11 @@ def plot_contour(distributions, resolution=128, ranges=None, quantiles:list=None
List of quantiles to use for determining isovalues. If None, the 99.7%, 95%, and 68% quantiles are used.
seed : int
Seed for the random number generator for reproducibility. It defaults to 55 if not provided.
distrib_colors : list or None, optional
List of colors to use for each distribution. If None, Matplotlib Set2 and glasbey colors will be used.
colorblind_safe : bool, optional
If True, the plot will use colors suitable for colorblind individuals.
Default is False.
show_plot : bool, optional
If True, display the plot.
Default is False.
Expand All @@ -92,7 +134,21 @@ def plot_contour(distributions, resolution=128, ranges=None, quantiles:list=None

if isinstance(distributions, Distribution):
distributions = [distributions]
contour_colors = generate_spectrum_colors(len(distributions))

# Generate colors
if distrib_colors is None:
if colorblind_safe:
palette = gb.create_palette(palette_size=len(distributions), colorblind_safe=colorblind_safe)
else:
palette = utils.get_colors(len(distributions))
else:
if len(distrib_colors) < len(distributions):
if colorblind_safe:
additional_colors = gb.create_palette(palette_size=len(distributions) - len(distrib_colors), colorblind_safe=colorblind_safe)
else:
additional_colors = utils.get_colors(len(distributions) - len(distrib_colors))
distrib_colors.extend(additional_colors)
palette = distrib_colors

if ranges is None:
min_val = np.zeros(distributions[0].mean().shape)+1000
Expand All @@ -114,7 +170,7 @@ def plot_contour(distributions, resolution=128, ranges=None, quantiles:list=None
coordinates = coordinates.reshape((-1, 2))
pdf = d.pdf(coordinates)
pdf = pdf.reshape(xv.shape)
color = contour_colors[i]
color = palette[i]

# Monte Carlo approach for determining isovalues
isovalues = []
Expand Down Expand Up @@ -147,7 +203,12 @@ def plot_contour(distributions, resolution=128, ranges=None, quantiles:list=None

return fig, axs

def plot_contour_bands(distributions, n_samples, resolution=128, ranges=None, quantiles: list = None, seed=55,
def plot_contour_bands(distributions,
n_samples,
resolution=128,
ranges=None,
quantiles: list = None,
seed=55,
show_plot=False):
"""
Plot contour bands for samples drawn from given distributions.
Expand Down Expand Up @@ -186,13 +247,9 @@ def plot_contour_bands(distributions, n_samples, resolution=128, ranges=None, qu
if isinstance(distributions, Distribution):
distributions = [distributions]

# Sequential and perceptually uniform colormaps
colormaps = [
'Reds', 'Blues', 'Greens', 'Greys', 'Oranges', 'Purples',
'YlOrBr', 'YlOrRd', 'OrRd', 'PuRd', 'RdPu', 'BuPu',
'GnBu', 'PuBu', 'YlGnBu', 'PuBuGn', 'BuGn', 'YlGn',
'viridis', 'plasma', 'inferno', 'magma', 'cividis'
]
n_quantiles = len(quantiles)
alpha_values = np.linspace(1/n_quantiles, 1.0, n_quantiles) # Creates alpha values from 1/n to 1.0
custom_cmap = utils.create_shaded_set2_colormap(alpha_values)

if ranges is None:
min_val = np.zeros(distributions[0].mean().shape)+1000
Expand Down Expand Up @@ -233,9 +290,18 @@ def plot_contour_bands(distributions, n_samples, resolution=128, ranges=None, qu
elif int((1 - quantile/100) * n_samples) >= n_samples:
raise ValueError(f"Quantile {quantile} results in an index that is out of bounds.")
isovalues.append(densities[int((1 - quantile/100) * n_samples)])
isovalues.append(densities[-1]) # Minimum density value

# Extract the subset of colors corresponding to the current Set2 color and its 3 alpha variations
start_idx = i * n_quantiles
end_idx = start_idx + n_quantiles
color_subset = custom_cmap.colors[start_idx:end_idx]

# Generate logarithmic levels and create the contour plot with different colormap for each distribution
plt.contourf(xv, yv, pdf, levels=isovalues, locator=ticker.LogLocator(), cmap=colormaps[i % len(colormaps)])
# Create a ListedColormap for the current color and its alpha variations
cmap_subset = ListedColormap(color_subset)

# Generate the filled contour plot with transparency and better visibility
plt.contourf(xv, yv, pdf, levels=isovalues, cmap=cmap_subset)

# Get the current figure and axes
fig = plt.gcf()
Expand All @@ -246,11 +312,3 @@ def plot_contour_bands(distributions, n_samples, resolution=128, ranges=None, qu
plt.show()

return fig, axs

# HELPER FUNCTIONS
def generate_random_colors(length):
return ["#"+''.join([np.random.choice('0123456789ABCDEF') for j in range(6)]) for _ in range(length)]

def generate_spectrum_colors(length):
cmap = plt.cm.get_cmap('viridis', length) # You can choose different colormaps like 'jet', 'hsv', 'rainbow', etc.
return np.array([cmap(i) for i in range(length)])
Loading

0 comments on commit 42a9bf2

Please sign in to comment.