Skip to content

Commit

Permalink
Kwargs optimization of high dimensional plots (#28)
Browse files Browse the repository at this point in the history
  • Loading branch information
nikhilbhavikatti authored Oct 21, 2024
1 parent 3b096b1 commit a789df4
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 63 deletions.
20 changes: 9 additions & 11 deletions uadapy/plotting/plots1D.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,9 @@ def _setup_plot(distributions, n_samples, seed, fig=None, axs=None, colors=None,
Axes object(s) to use for plotting. If None, new axes will be created.
colors : list or None, optional
List of colors to use for each distribution. If None, Glasbey colors will be used.
**kwargs : additional keyword arguments
Additional optional arguments.
- colorblind_safe : bool, optional
If True, the plot will use colors suitable for colorblind individuals.
Default is False.
colorblind_safe : bool, optional
If True, the plot will use colors suitable for colorblind individuals.
Default is False.
Returns
-------
Expand Down Expand Up @@ -122,7 +120,7 @@ def plot_1d_distribution(
distrib_colors=None,
vert=True,
colorblind_safe=False,
show_plot=True,
show_plot=False,
dot_size=0,
**kwargs):
"""
Expand Down Expand Up @@ -316,7 +314,7 @@ def generate_boxplot(distributions,
distrib_colors=None,
vert=True,
colorblind_safe=False,
show_plot=True,
show_plot=False,
**kwargs):
"""
Plot box plots for samples drawn from given distributions.
Expand Down Expand Up @@ -375,7 +373,7 @@ def generate_violinplot(distributions,
distrib_colors=None,
vert=True,
colorblind_safe=False,
show_plot=True,
show_plot=False,
**kwargs):
"""
Plot violin plots for samples drawn from given distributions.
Expand Down Expand Up @@ -433,7 +431,7 @@ def generate_dotplot(distributions,
distrib_colors=None,
vert=True,
colorblind_safe=False,
show_plot=True,
show_plot=False,
dot_size=0):
"""
Plot dot plots for samples drawn from given distributions.
Expand Down Expand Up @@ -491,7 +489,7 @@ def generate_stripplot(distributions,
distrib_colors=None,
vert=True,
colorblind_safe=False,
show_plot=True,
show_plot=False,
dot_size=0):
"""
Plot strip plots for samples drawn from given distributions.
Expand Down Expand Up @@ -551,7 +549,7 @@ def generate_swarmplot(distributions,
distrib_colors=None,
vert=True,
colorblind_safe=False,
show_plot=True,
show_plot=False,
dot_size=0):
"""
Plot swarm plots for samples drawn from given distributions.
Expand Down
55 changes: 24 additions & 31 deletions uadapy/plotting/plots2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from numpy import ma
from matplotlib import ticker

def plot_samples(distributions, n_samples, seed=55, **kwargs):
def plot_samples(distributions, n_samples, seed=55, xlabel=None, ylabel=None, title=None, show_plot=False):
"""
Plot samples from the given distribution. If several distributions should be
plotted together, an array can be passed to this function.
Expand All @@ -17,15 +17,15 @@ def plot_samples(distributions, n_samples, seed=55, **kwargs):
Number of samples per distribution.
seed : int
Seed for the random number generator for reproducibility. It defaults to 55 if not provided.
**kwargs : additional keyword arguments
Additional optional plotting arguments.
- xlabel : string, optional
label for x-axis.
- ylabel : string, optional
label for y-axis.
- show_plot : bool, optional
If True, display the plot.
Default is False.
xlabel : string, optional
label for x-axis.
ylabel : string, optional
label for y-axis.
title : string, optional
title for the plot.
show_plot : bool, optional
If True, display the plot.
Default is False.
Returns
-------
Expand All @@ -40,25 +40,24 @@ def plot_samples(distributions, n_samples, seed=55, **kwargs):
for d in distributions:
samples = d.sample(n_samples, seed)
plt.scatter(x=samples[:,0], y=samples[:,1])
if 'xlabel' in kwargs:
plt.xlabel(kwargs['xlabel'])
if 'ylabel' in kwargs:
plt.ylabel(kwargs['ylabel'])
if 'title' in kwargs:
plt.title(kwargs['title'])
if xlabel:
plt.xlabel(xlabel)
if ylabel:
plt.ylabel(ylabel)
if title:
plt.title(title)

# Get the current figure and axes
fig = plt.gcf()
axs = plt.gca()

show_plot = kwargs.get('show_plot', False)
if show_plot:
fig.tight_layout()
plt.show()

return fig, axs

def plot_contour(distributions, resolution=128, ranges=None, quantiles:list=None, seed=55, **kwargs):
def plot_contour(distributions, resolution=128, ranges=None, quantiles:list=None, seed=55, show_plot=False):
"""
Plot contour plots for samples drawn from given distributions.
Expand All @@ -74,11 +73,9 @@ def plot_contour(distributions, resolution=128, ranges=None, quantiles:list=None
List of quantiles to use for determining isovalues. If None, the 99.7%, 95%, and 68% quantiles are used.
seed : int
Seed for the random number generator for reproducibility. It defaults to 55 if not provided.
**kwargs : additional keyword arguments
Additional optional plotting arguments.
- show_plot : bool, optional
If True, display the plot.
Default is False.
show_plot : bool, optional
If True, display the plot.
Default is False.
Returns
-------
Expand Down Expand Up @@ -144,15 +141,14 @@ def plot_contour(distributions, resolution=128, ranges=None, quantiles:list=None
fig = plt.gcf()
axs = plt.gca()

show_plot = kwargs.get('show_plot', False)
if show_plot:
fig.tight_layout()
plt.show()

return fig, axs

def plot_contour_bands(distributions, n_samples, resolution=128, ranges=None, quantiles: list = None, seed=55,
**kwargs):
show_plot=False):
"""
Plot contour bands for samples drawn from given distributions.
Expand All @@ -170,11 +166,9 @@ def plot_contour_bands(distributions, n_samples, resolution=128, ranges=None, qu
List of quantiles to use for determining isovalues. If None, the 99.7%, 95%, and 68% quantiles are used.
seed : int
Seed for the random number generator for reproducibility. It defaults to 55 if not provided.
**kwargs : additional keyword arguments
Additional optional plotting arguments.
- show_plot : bool, optional
If True, display the plot.
Default is False.
show_plot : bool, optional
If True, display the plot.
Default is False.
Returns
-------
Expand Down Expand Up @@ -247,7 +241,6 @@ def plot_contour_bands(distributions, n_samples, resolution=128, ranges=None, qu
fig = plt.gcf()
axs = plt.gca()

show_plot = kwargs.get('show_plot', False)
if show_plot:
fig.tight_layout()
plt.show()
Expand Down
33 changes: 12 additions & 21 deletions uadapy/plotting/plotsND.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from uadapy import Distribution
import uadapy.plotting.utils as utils

def plot_samples(distributions, n_samples, seed=55, **kwargs):
def plot_samples(distributions, n_samples, seed=55, show_plot=False):
"""
Plot samples from the multivariate distribution as a SLOM.
Expand All @@ -15,11 +15,9 @@ def plot_samples(distributions, n_samples, seed=55, **kwargs):
Number of samples per distribution.
seed : int
Seed for the random number generator for reproducibility. It defaults to 55 if not provided.
**kwargs : additional keyword arguments
Additional optional plotting arguments.
- show_plot : bool, optional
If True, display the plot.
Default is False.
show_plot : bool, optional
If True, display the plot.
Default is False.
Returns
-------
Expand Down Expand Up @@ -64,14 +62,13 @@ def plot_samples(distributions, n_samples, seed=55, **kwargs):
fig = plt.gcf()
axs = plt.gca()

show_plot = kwargs.get('show_plot', False)
if show_plot:
fig.tight_layout()
plt.show()

return fig, axs

def plot_contour(distributions, n_samples, resolution=128, ranges=None, quantiles: list = None, seed=55, **kwargs):
def plot_contour(distributions, n_samples, resolution=128, ranges=None, quantiles: list = None, seed=55, show_plot=False):
"""
Visualizes a multidimensional distribution in a matrix of contour plots.
Expand All @@ -89,11 +86,9 @@ def plot_contour(distributions, n_samples, resolution=128, ranges=None, quantile
List of quantiles to use for determining isovalues. If None, the 99.7%, 95%, and 68% quantiles are used.
seed : int
Seed for the random number generator for reproducibility. It defaults to 55 if not provided.
**kwargs : additional keyword arguments
Additional optional plotting arguments.
- show_plot : bool, optional
If True, display the plot.
Default is False.
show_plot : bool, optional
If True, display the plot.
Default is False.
Returns
-------
Expand Down Expand Up @@ -192,15 +187,14 @@ def plot_contour(distributions, n_samples, resolution=128, ranges=None, quantile
fig = plt.gcf()
axs = plt.gca()

show_plot = kwargs.get('show_plot', False)
if show_plot:
fig.tight_layout()
plt.show()

return fig, axs

def plot_contour_samples(distributions, n_samples, resolution=128, ranges=None, quantiles: list = None, seed=55,
**kwargs):
show_plot=False):
"""
Visualizes a multidimensional distribution in a matrix visualization where the
upper diagonal contains contour plots and the lower diagonal contains scatterplots.
Expand All @@ -219,11 +213,9 @@ def plot_contour_samples(distributions, n_samples, resolution=128, ranges=None,
List of quantiles to use for determining isovalues. If None, the 99.7%, 95%, and 68% quantiles are used.
seed : int
Seed for the random number generator for reproducibility. It defaults to 55 if not provided.
**kwargs : additional keyword arguments
Additional optional plotting arguments.
- show_plot : bool, optional
If True, display the plot.
Default is False.
show_plot : bool, optional
If True, display the plot.
Default is False.
Returns
-------
Expand Down Expand Up @@ -324,7 +316,6 @@ def plot_contour_samples(distributions, n_samples, resolution=128, ranges=None,
fig = plt.gcf()
axs = plt.gca()

show_plot = kwargs.get('show_plot', False)
if show_plot:
fig.tight_layout()
plt.show()
Expand Down

0 comments on commit a789df4

Please sign in to comment.