Skip to content

Commit

Permalink
Merge pull request #85 from Jammy2211/feature/contour
Browse files Browse the repository at this point in the history
Feature/contour
  • Loading branch information
Jammy2211 authored Dec 30, 2023
2 parents 7bc78c4 + 1595ad7 commit 8954aba
Show file tree
Hide file tree
Showing 22 changed files with 197 additions and 310 deletions.
11 changes: 11 additions & 0 deletions autoarray/config/visualize/mat_wrap_2d.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,17 @@ ArrayOverlay: # wrapper for `plt.imshow()`: customize arrays overlaid.
alpha: 0.5
subplot:
alpha: 0.5
Contour: # wrapper for `plt.contour()`: customize contours plotted on the figure.
figure:
colors: "k"
total_contours: 10 # Number of contours to plot
use_log10: true # If true, contours are plotted with log10 spacing, if False, linear spacing.
include_values: true # If true, the values of the contours are plotted on the figure.
subplot:
colors: "k"
total_contours: 10 # Number of contours to plot
use_log10: true # If true, contours are plotted with log10 spacing, if False, linear spacing.
include_values: true # If true, the values of the contours are plotted on the figure.
BorderScatter: # wrapper for `plt.scatter()`: customize the apperance of 2D borders.
figure:
c: r
Expand Down
1 change: 1 addition & 0 deletions autoarray/plot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from autoarray.plot.wrap.one_d.fill_between import FillBetween

from autoarray.plot.wrap.two_d.array_overlay import ArrayOverlay
from autoarray.plot.wrap.two_d.contour import Contour
from autoarray.plot.wrap.two_d.grid_scatter import GridScatter
from autoarray.plot.wrap.two_d.grid_plot import GridPlot
from autoarray.plot.wrap.two_d.grid_errorbar import GridErrorbar
Expand Down
25 changes: 14 additions & 11 deletions autoarray/plot/mat_plot/two_d.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __init__(
legend: Optional[wb.Legend] = None,
output: Optional[wb.Output] = None,
array_overlay: Optional[w2d.ArrayOverlay] = None,
contour: Optional[w2d.Contour] = None,
grid_scatter: Optional[w2d.GridScatter] = None,
grid_plot: Optional[w2d.GridPlot] = None,
grid_errorbar: Optional[w2d.GridErrorbar] = None,
Expand All @@ -57,6 +58,7 @@ def __init__(
parallel_overscan_plot: Optional[w2d.ParallelOverscanPlot] = None,
serial_prescan_plot: Optional[w2d.SerialPrescanPlot] = None,
serial_overscan_plot: Optional[w2d.SerialOverscanPlot] = None,
use_log10 : bool = False
):
"""
Visualizes 2D data structures (e.g an `Array2D`, `Grid2D`, `VectorField`, etc.) using Matplotlib.
Expand Down Expand Up @@ -115,6 +117,8 @@ def __init__(
Sets if the figure is displayed on the user's screen or output to `.png` using `plt.show` and `plt.savefig`
array_overlay
Overlays an input `Array2D` over the figure using `plt.imshow`.
contour
Overlays contours of an input `Array2D` over the figure using `plt.contour`.
grid_scatter
Scatters a `Grid2D` of (y,x) coordinates over the figure using `plt.scatter`.
grid_plot
Expand Down Expand Up @@ -145,6 +149,8 @@ def __init__(
Plots the serial prescan on an `Array2D` data structure representing a CCD imaging via `plt.plot`.
serial_overscan_plot
Plots the serial overscan on an `Array2D` data structure representing a CCD imaging via `plt.plot`.
use_log10
If True, the plot has a log10 colormap, colorbar and contours showing the values.
"""

super().__init__(
Expand All @@ -168,6 +174,8 @@ def __init__(

self.array_overlay = array_overlay or w2d.ArrayOverlay(is_default=True)

self.contour = contour or w2d.Contour(is_default=True)

self.grid_scatter = grid_scatter or w2d.GridScatter(is_default=True)
self.grid_plot = grid_plot or w2d.GridPlot(is_default=True)
self.grid_errorbar = grid_errorbar or w2d.GridErrorbar(is_default=True)
Expand Down Expand Up @@ -202,6 +210,8 @@ def __init__(
is_default=True
)

self.use_log10 = use_log10

self.is_for_subplot = False

def plot_array(
Expand Down Expand Up @@ -277,7 +287,7 @@ def plot_array(
ax = self.setup_subplot()

aspect = self.figure.aspect_from(shape_native=array.shape_native)
norm = self.cmap.norm_from(array=array)
norm = self.cmap.norm_from(array=array, use_log10=self.use_log10)

origin = conf.instance["visualize"]["general"]["general"]["imshow_origin"]

Expand Down Expand Up @@ -318,7 +328,7 @@ def plot_array(
pixels=array.shape_native[1],
)

self.title.set(auto_title=auto_labels.title)
self.title.set(auto_title=auto_labels.title, use_log10=self.use_log10)
self.ylabel.set()
self.xlabel.set()

Expand All @@ -334,18 +344,11 @@ def plot_array(

if self.colorbar is not False:
cb = self.colorbar.set(
units=self.units, ax=ax, norm=norm, cb_unit=auto_labels.cb_unit
units=self.units, ax=ax, norm=norm, cb_unit=auto_labels.cb_unit, use_log10=self.use_log10
)
self.colorbar_tickparams.set(cb=cb)

# levels = np.logspace(np.log10(0.3), np.log10(20.0), 10)
# plt.contour(
# # array.mask.derive_grid.unmasked_sub_1,
# array.native[::-1],
# levels=levels,
# colors="black",
# extent=extent,
# )
self.contour.set(array=array, extent=extent, use_log10=self.use_log10)

grid_indexes = None

Expand Down
8 changes: 4 additions & 4 deletions autoarray/plot/wrap/base/cmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def vmax_from(self, array: np.ndarray):
return np.max(array)
return self.config_dict["vmax"]

def norm_from(self, array: np.ndarray) -> object:
def norm_from(self, array: np.ndarray, use_log10 : bool = False) -> object:
"""
Returns the `Normalization` object which scales of the colormap.
Expand Down Expand Up @@ -86,12 +86,12 @@ def norm_from(self, array: np.ndarray) -> object:
if isinstance(self.config_dict["norm"], colors.Normalize):
return self.config_dict["norm"]

if self.config_dict["norm"] in "linear":
return colors.Normalize(vmin=vmin, vmax=vmax)
elif self.config_dict["norm"] in "log":
if self.config_dict["norm"] in "log" or use_log10:
if vmin == 0.0:
vmin = 1.0e-4
return colors.LogNorm(vmin=vmin, vmax=vmax)
elif self.config_dict["norm"] in "linear":
return colors.Normalize(vmin=vmin, vmax=vmax)
elif self.config_dict["norm"] in "symmetric_log":
return colors.SymLogNorm(
vmin=vmin,
Expand Down
44 changes: 26 additions & 18 deletions autoarray/plot/wrap/base/colorbar.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def cb_unit(self):
return conf.instance["visualize"]["general"]["units"]["cb_unit"]
return self.manual_unit

def manual_tick_values_from(self, norm=None):
def tick_values_from(self, norm=None, use_log10 : bool = False):
if (
sum(
x is not None
Expand All @@ -75,19 +75,24 @@ def manual_tick_values_from(self, norm=None):
if self.manual_tick_values is not None:
return self.manual_tick_values

if (
self.manual_tick_values is None
and self.manual_tick_labels is None
and norm is not None
):
if norm is not None:

min_value = norm.vmin
max_value = norm.vmax
mid_value = (max_value + min_value) / 2.0

if use_log10:

log_mid_value = (np.log10(max_value) + np.log10(min_value)) / 2.0
mid_value = 10 ** log_mid_value

else:

mid_value = (max_value + min_value) / 2.0

return [min_value, mid_value, max_value]

def manual_tick_labels_from(
self, units: Units, manual_tick_values: List[float], cb_unit=None
def tick_labels_from(
self, units: Units, manual_tick_values: List[float], cb_unit=None,
):
if manual_tick_values is None:
return None
Expand All @@ -102,12 +107,15 @@ def manual_tick_labels_from(
]

if self.manual_log10:

manual_tick_labels = [
"{:.0e}".format(label) for label in manual_tick_labels
]

manual_tick_labels = [
label.replace("1e", "$10^{") + "}$" for label in manual_tick_labels
]

manual_tick_labels = [
label.replace("{-0", "{-").replace("{+0", "{+").replace("+", "")
for label in manual_tick_labels
Expand All @@ -126,22 +134,22 @@ def manual_tick_labels_from(

return manual_tick_labels

def set(self, units: Units, ax=None, norm=None, cb_unit=None):
def set(self, units: Units, ax=None, norm=None, cb_unit=None, use_log10 : bool = False):
"""
Set the figure's colorbar, optionally overriding the tick labels and values with manual inputs.
"""

manual_tick_values = self.manual_tick_values_from(norm=norm)
manual_tick_labels = self.manual_tick_labels_from(
manual_tick_values=manual_tick_values, units=units, cb_unit=cb_unit
tick_values = self.tick_values_from(norm=norm, use_log10=use_log10)
tick_labels = self.tick_labels_from(
manual_tick_values=tick_values, units=units, cb_unit=cb_unit,
)

if manual_tick_values is None and manual_tick_labels is None:
if tick_values is None and tick_labels is None:
cb = plt.colorbar(ax=ax, **self.config_dict)
else:
cb = plt.colorbar(ticks=manual_tick_values, ax=ax, **self.config_dict)
cb = plt.colorbar(ticks=tick_values, ax=ax, **self.config_dict)
cb.ax.set_yticklabels(
labels=manual_tick_labels, va=self.manual_alignment or "center"
labels=tick_labels, va=self.manual_alignment or "center"
)

return cb
Expand All @@ -167,8 +175,8 @@ def set_with_color_values(
mappable = cm.ScalarMappable(cmap=cmap)
mappable.set_array(color_values)

manual_tick_values = self.manual_tick_values_from(norm=norm)
manual_tick_labels = self.manual_tick_labels_from(
manual_tick_values = self.tick_values_from(norm=norm)
manual_tick_labels = self.tick_labels_from(
manual_tick_values=manual_tick_values, units=units
)

Expand Down
6 changes: 5 additions & 1 deletion autoarray/plot/wrap/base/title.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,15 @@ def __init__(self, **kwargs):

self.manual_label = self.kwargs.get("label")

def set(self, auto_title=None):
def set(self, auto_title=None, use_log10 : bool = False):

config_dict = self.config_dict

label = auto_title if self.manual_label is None else self.manual_label

if use_log10:
label = f"{label} (log10)"

if "label" in config_dict:
config_dict.pop("label")

Expand Down
1 change: 1 addition & 0 deletions autoarray/plot/wrap/two_d/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .array_overlay import ArrayOverlay
from .contour import Contour
from .grid_scatter import GridScatter
from .grid_plot import GridPlot
from .grid_errorbar import GridErrorbar
Expand Down
96 changes: 96 additions & 0 deletions autoarray/plot/wrap/two_d/contour.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import matplotlib.pyplot as plt
import numpy as np
from typing import List, Optional, Union

from autoarray.plot.wrap.two_d.abstract import AbstractMatWrap2D
from autoarray.structures.arrays.uniform_2d import Array2D


class Contour(AbstractMatWrap2D):
def __init__(
self,
manual_levels: Optional[List[float]] = None,
total_contours: Optional[int] = None,
use_log10: Optional[bool] = None,
include_values: Optional[bool] = None,
**kwargs,
):
"""
Customizes the contours of the plotted figure.
This object wraps the following Matplotlib method:
- plt.contours: https://matplotlib.org/3.3.2/api/_as_gen/matplotlib.pyplot.contours.html
Parameters
----------
manual_levels
Manually override the levels at which the contours are plotted.
total_contours
The total number of contours plotted, which also determines the spacing between each contour.
use_log10
Whether the contours are plotted with a log10 spacing between each contour (alternative is linear).
include_values
Whether the values of the contours are included on the figure.
"""

super().__init__(**kwargs)

self.manual_levels = manual_levels
self.total_contours = total_contours or self.config_dict.get("total_contours")
self.use_log10 = use_log10 or self.config_dict.get("use_log10")
self.include_values = include_values or self.config_dict.get("include_values")

def levels_from(
self, array: Union[np.ndarray, Array2D]
) -> Union[np.ndarray, List[float]]:
"""
The levels at which the contours are plotted, which may be determined in the following ways:
- Automatically computed from the minimum and maximum values of the array, using a log10 or linear spacing.
- Overriden by the input `manual_levels` (e.g. if it is not None).
Returns
-------
The levels at which the contours are plotted.
"""
if self.manual_levels is None:
if self.use_log10:
return np.logspace(
np.log10(np.min(array)),
np.log10(np.max(array)),
self.total_contours,
)
return np.linspace(np.min(array), np.max(array), self.total_contours)

return self.manual_levels

def set(self, array: Union[np.ndarray, Array2D], extent: List[float] = None, use_log10 : bool = False):
"""
Plot an input grid of (y,x) coordinates using the matplotlib method `plt.scatter`.
Parameters
----------
array
The array of values the contours are plotted over.
"""

if not use_log10:
if self.kwargs.get("is_default") is True:
return

config_dict = self.config_dict
config_dict.pop("total_contours")
config_dict.pop("use_log10")
config_dict.pop("include_values")

levels = self.levels_from(array)

ax = plt.contour(
array.native[::-1], levels=levels, extent=extent, **config_dict
)
if self.include_values:
try:
ax.clabel(levels=levels, inline=True, fontsize=10)
except ValueError:
pass
27 changes: 27 additions & 0 deletions test_autoarray/plot/wrap/base/test_abstract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import autoarray.plot as aplt


def test__from_config_or_via_manual_input():
# Testing for config loading, could be any matplot object but use GridScatter as example

grid_scatter = aplt.GridScatter()

assert grid_scatter.config_dict["marker"] == "x"
assert grid_scatter.config_dict["c"] == "y"

grid_scatter = aplt.GridScatter(marker="x")

assert grid_scatter.config_dict["marker"] == "x"
assert grid_scatter.config_dict["c"] == "y"

grid_scatter = aplt.GridScatter()
grid_scatter.is_for_subplot = True

assert grid_scatter.config_dict["marker"] == "."
assert grid_scatter.config_dict["c"] == "r"

grid_scatter = aplt.GridScatter(c=["r", "b"])
grid_scatter.is_for_subplot = True

assert grid_scatter.config_dict["marker"] == "."
assert grid_scatter.config_dict["c"] == ["r", "b"]
Loading

0 comments on commit 8954aba

Please sign in to comment.