Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/pixelization log10 #89

Merged
merged 5 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions autoarray/config/visualize/plots.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ imaging: # Settings for plots of imaging datas
psf: false
fit: # Settings for plots of all fits (e.g. FitImagingPlotter, FitInterferometerPlotter).
subplot_fit: true # Plot subplot of all fit quantities for any dataset (e.g. the model data, residual-map, etc.)?
subplot_fit_log10: true # Plot subplot of all fit quantities for any dataset using log10 color maps (e.g. the model data, residual-map, etc.)?
all_at_end_png: true # Plot all individual plots listed below as .png (even if False)?
all_at_end_fits: true # Plot all individual plots listed below as .fits (even if False)?
all_at_end_pdf: false # Plot all individual plots listed below as publication-quality .pdf (even if False)?
Expand All @@ -29,6 +30,7 @@ inversion: # Settings for plots of inversions (e
all_at_end_png: true # Plot all individual plots listed below as .png (even if False)?
all_at_end_fits: true # Plot all individual plots listed below as .fits (even if False)?
all_at_end_pdf: false # Plot all individual plots listed below as publication-quality .pdf (even if False)?
data_subtracted: false # Plot individual plots of the data with the other inversion linear objects subtracted?
errors: false # Plot image of the errors of every mesh-pixel reconstructed value?
mesh_pixels_per_image_pixels : false # Plot the number of image-plane mesh pixels per masked data pixels?
reconstructed_image: false # Plot image of the reconstructed data (e.g. in the image-plane)?
Expand Down
28 changes: 28 additions & 0 deletions autoarray/inversion/inversion/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,34 @@ def mapped_reconstructed_image(self) -> Array2D:
"""
return sum(self.mapped_reconstructed_image_dict.values())

@cached_property
def data_subtracted_dict(self) -> Dict[LinearObj, Array2D]:
"""
Returns a dictionary of the data subtracted by the reconstructed images of combinations of all but one of the
linear objects the inversion.

This produces images of the data showing what each linear object is actually fitted to, after accounting for
the signal in the other linear objects.

Returns
-------
A dictionary of the data subtracted by the reconstructed images of combinations of all but one of the
linear objects the inversion.
"""

data_subtracted_dict = {}

for linear_obj in self.linear_obj_list:
data_subtracted_dict[linear_obj] = copy.copy(self.data)

for linear_obj_other in self.linear_obj_list:
if linear_obj != linear_obj_other:
data_subtracted_dict[
linear_obj
] -= self.mapped_reconstructed_image_dict[linear_obj_other]

return data_subtracted_dict

@cached_property
@profile_func
def regularization_term(self) -> float:
Expand Down
58 changes: 39 additions & 19 deletions autoarray/inversion/plot/inversion_plotters.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def figures_2d(self, reconstructed_image: bool = False):
def figures_2d_of_pixelization(
self,
pixelization_index: int = 0,
data_subtracted: bool = False,
reconstructed_image: bool = False,
reconstruction: bool = False,
errors: bool = False,
Expand Down Expand Up @@ -143,6 +144,17 @@ def figures_2d_of_pixelization(

mapper_plotter = self.mapper_plotter_from(mapper_index=pixelization_index)

if data_subtracted:
array = self.inversion.data_subtracted_dict[mapper_plotter.mapper]

self.mat_plot_2d.plot_array(
array=array,
visuals_2d=self.get_visuals_2d_for_data(),
auto_labels=AutoLabels(
title="Data Subtracted", filename="data_subtracted"
),
)

if reconstructed_image:
array = self.inversion.mapped_reconstructed_image_dict[
mapper_plotter.mapper
Expand Down Expand Up @@ -248,10 +260,8 @@ def subplot_of_mapper(

self.include_2d._mapper_image_plane_mesh_grid = False

self.mat_plot_2d.plot_array(
array=self.inversion.data,
visuals_2d=self.get_visuals_2d_for_data(),
auto_labels=AutoLabels(title=f" Data"),
self.figures_2d_of_pixelization(
pixelization_index=mapper_index, data_subtracted=True
)

self.figures_2d_of_pixelization(
Expand All @@ -274,45 +284,55 @@ def subplot_of_mapper(
)

self.mat_plot_2d.use_log10 = False
self.mat_plot_2d.contour = contour_original

self.include_2d._mapper_image_plane_mesh_grid = True
self.figures_2d_of_pixelization(
pixelization_index=mapper_index, reconstructed_image=True
pixelization_index=mapper_index, reconstruction=True
)

self.include_2d._mapper_image_plane_mesh_grid = False
self.set_title(label="Source Reconstruction (Unzoomed)")
self.figures_2d_of_pixelization(
pixelization_index=mapper_index, mesh_pixels_per_image_pixels=True
pixelization_index=mapper_index,
reconstruction=True,
zoom_to_brightest=False,
)
self.set_title(label=None)

self.include_2d._mapper_image_plane_mesh_grid = mapper_image_plane_mesh_grid
self.mat_plot_2d.use_log10 = True

self.set_title(label="Source Reconstruction (log10)")

self.figures_2d_of_pixelization(
pixelization_index=mapper_index, reconstruction=True
)

self.set_title(label="Source Reconstruction (Unzoomed)")
self.set_title(label="Source Reconstruction (Unzoomed log10)")
self.figures_2d_of_pixelization(
pixelization_index=mapper_index,
reconstruction=True,
zoom_to_brightest=False,
)
self.set_title(label=None)

self.figures_2d_of_pixelization(pixelization_index=mapper_index, errors=True)
self.mat_plot_2d.use_log10 = False
self.mat_plot_2d.contour = contour_original

self.include_2d._mapper_image_plane_mesh_grid = True
self.figures_2d_of_pixelization(
pixelization_index=mapper_index, reconstructed_image=True
)

self.include_2d._mapper_image_plane_mesh_grid = False
self.figures_2d_of_pixelization(
pixelization_index=mapper_index, mesh_pixels_per_image_pixels=True
)

self.include_2d._mapper_image_plane_mesh_grid = mapper_image_plane_mesh_grid

self.set_title(label="Errors (Unzoomed)")
self.figures_2d_of_pixelization(
pixelization_index=mapper_index, errors=True, zoom_to_brightest=False
)

try:
self.figures_2d_of_pixelization(
pixelization_index=mapper_index, regularization_weights=True
)
except IndexError:
pass

self.set_title(label="Regularization Weights (Unzoomed)")
try:
self.figures_2d_of_pixelization(
Expand Down
6 changes: 6 additions & 0 deletions autoarray/plot/mat_plot/two_d.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,9 @@ def plot_array(
if array is None or np.all(array == 0):
return

if self.use_log10 and (np.all(array == array[0]) or np.all(array < 0)):
return

if array.pixel_scales is None and self.units.use_scaled:
raise exc.ArrayException(
"You cannot plot an array using its scaled unit_label if the input array does not have "
Expand Down Expand Up @@ -624,6 +627,7 @@ def _plot_delaunay_mapper(
colorbar_tickparams=self.colorbar_tickparams,
aspect=aspect_inv,
ax=ax,
use_log10=self.use_log10,
)

self.title.set(auto_title=auto_labels.title)
Expand Down Expand Up @@ -690,6 +694,7 @@ def _plot_voronoi_mapper(
colorbar=self.colorbar,
colorbar_tickparams=self.colorbar_tickparams,
ax=ax,
use_log10=self.use_log10,
)

else:
Expand All @@ -702,6 +707,7 @@ def _plot_voronoi_mapper(
colorbar_tickparams=self.colorbar_tickparams,
aspect=aspect_inv,
ax=ax,
use_log10=self.use_log10,
)

self.title.set(auto_title=auto_labels.title)
Expand Down
23 changes: 15 additions & 8 deletions autoarray/plot/wrap/base/colorbar.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,13 @@ def set(
return cb

def set_with_color_values(
self, units: Units, cmap: str, color_values: np.ndarray, ax=None, norm=None
self,
units: Units,
cmap: str,
color_values: np.ndarray,
ax=None,
norm=None,
use_log10: bool = False,
):
"""
Set the figure's colorbar using an array of already known color values.
Expand All @@ -178,15 +184,16 @@ def set_with_color_values(
The values of the pixels on the Voronoi mesh which are used to create the colorbar.
"""

mappable = cm.ScalarMappable(cmap=cmap)
mappable = cm.ScalarMappable(norm=norm, cmap=cmap)
mappable.set_array(color_values)

manual_tick_values = self.tick_values_from(norm=norm)
manual_tick_labels = self.tick_labels_from(
manual_tick_values=manual_tick_values, units=units
tick_values = self.tick_values_from(norm=norm, use_log10=use_log10)
tick_labels = self.tick_labels_from(
manual_tick_values=tick_values,
units=units,
)

if manual_tick_values is None and manual_tick_labels is None:
if tick_values is None and tick_labels is None:
cb = plt.colorbar(
mappable=mappable,
ax=ax,
Expand All @@ -196,11 +203,11 @@ def set_with_color_values(
cb = plt.colorbar(
mappable=mappable,
ax=ax,
ticks=manual_tick_values,
ticks=tick_values,
**self.config_dict,
)
cb.ax.set_yticklabels(
labels=manual_tick_labels, va=self.manual_alignment or "center"
labels=tick_labels, va=self.manual_alignment or "center"
)

return cb
15 changes: 12 additions & 3 deletions autoarray/plot/wrap/two_d/interpolated_reconstruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def imshow_reconstruction(
colorbar_tickparams: wb.ColorbarTickParams = None,
aspect=None,
ax=None,
use_log10: bool = False,
):
"""
Given a `Mapper` and a corresponding array of `pixel_values` (e.g. the reconstruction values of a Delaunay
Expand Down Expand Up @@ -76,6 +77,10 @@ def imshow_reconstruction(
if pixel_values is None:
return

interpolation_array = mapper.interpolated_array_from(values=pixel_values)

norm = cmap.norm_from(array=interpolation_array, use_log10=use_log10)

vmin = cmap.vmin_from(array=pixel_values)
vmax = cmap.vmax_from(array=pixel_values)

Expand All @@ -86,16 +91,20 @@ def imshow_reconstruction(

if colorbar is not None:
colorbar = colorbar.set_with_color_values(
units=units, cmap=cmap, color_values=color_values, ax=ax
units=units,
cmap=cmap,
norm=norm,
color_values=color_values,
ax=ax,
use_log10=use_log10,
)
if colorbar is not None and colorbar_tickparams is not None:
colorbar_tickparams.set(cb=colorbar)

interpolation_array = mapper.interpolated_array_from(values=pixel_values)

plt.imshow(
X=interpolation_array.native,
cmap=cmap,
norm=norm,
extent=mapper.source_plane_mesh_grid.geometry.extent_square,
aspect=aspect,
)
Expand Down
20 changes: 19 additions & 1 deletion autoarray/plot/wrap/two_d/voronoi_drawer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def draw_voronoi_pixels(
colorbar: Optional[wb.Colorbar],
colorbar_tickparams: Optional[wb.ColorbarTickParams] = None,
ax=None,
use_log10: bool = False,
):
"""
Draws the Voronoi pixels of the input `mapper` using its `mesh_grid` which contains the (y,x)
Expand All @@ -46,6 +47,12 @@ def draw_voronoi_pixels(
The colormap used to plot each Voronoi cell.
colorbar
The `Colorbar` object in `mat_base` used to set the colorbar of the figure the Voronoi mesh is plotted on.
colorbar_tickparams
The `ColorbarTickParams` object in `mat_base` used to set the tick labels of the colorbar.
ax
The matplotlib axis the Voronoi mesh is plotted on.
use_log10
If `True`, the colorbar is plotted using a log10 scale.
"""

if ax is None:
Expand All @@ -54,6 +61,12 @@ def draw_voronoi_pixels(
regions, vertices = mesh_util.voronoi_revised_from(voronoi=mapper.voronoi)

if pixel_values is not None:
norm = cmap.norm_from(array=pixel_values, use_log10=use_log10)

if use_log10:
pixel_values[pixel_values < 1e-4] = 1e-4
pixel_values = np.log10(pixel_values)

vmin = cmap.vmin_from(array=pixel_values)
vmax = cmap.vmax_from(array=pixel_values)

Expand All @@ -69,7 +82,12 @@ def draw_voronoi_pixels(

if colorbar is not None:
cb = colorbar.set_with_color_values(
units=units, cmap=cmap, color_values=color_values, ax=ax
units=units,
norm=norm,
cmap=cmap,
color_values=color_values,
ax=ax,
use_log10=use_log10,
)

if cb is not None and colorbar_tickparams is not None:
Expand Down
32 changes: 32 additions & 0 deletions test_autoarray/inversion/inversion/test_abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,38 @@ def test__mapped_reconstructed_image():
assert (inversion.mapped_reconstructed_image == 3.0 * np.ones(2)).all()


def test__data_subtracted_dict():
linear_obj_0 = aa.m.MockLinearObj()

mapped_reconstructed_data_dict = {linear_obj_0: np.ones(3)}

# noinspection PyTypeChecker
inversion = aa.m.MockInversion(
data=3.0 * np.ones(3),
linear_obj_list=[linear_obj_0],
mapped_reconstructed_data_dict=mapped_reconstructed_data_dict,
)

assert (inversion.data_subtracted_dict[linear_obj_0] == 3.0 * np.ones(3)).all()

linear_obj_1 = aa.m.MockLinearObj()

mapped_reconstructed_data_dict = {
linear_obj_0: np.ones(3),
linear_obj_1: 2.0 * np.ones(3),
}

# noinspection PyTypeChecker
inversion = aa.m.MockInversion(
data=3.0 * np.ones(3),
linear_obj_list=[linear_obj_0, linear_obj_1],
mapped_reconstructed_data_dict=mapped_reconstructed_data_dict,
)

assert (inversion.data_subtracted_dict[linear_obj_0] == np.ones(3)).all()
assert (inversion.data_subtracted_dict[linear_obj_1] == 2.0 * np.ones(3)).all()


def test__reconstruction_raises_exception_for_linalg_error():
# noinspection PyTypeChecker
inversion = aa.m.MockInversion(
Expand Down
4 changes: 2 additions & 2 deletions test_autoarray/plot/wrap/two_d/test_voronoi_drawer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def test__draws_voronoi_pixels_for_sensible_input(voronoi_mapper_9_3x3):
voronoi_drawer.draw_voronoi_pixels(
mapper=voronoi_mapper_9_3x3,
pixel_values=None,
units=None,
units=aplt.Units(),
cmap=aplt.Cmap(),
colorbar=None,
)
Expand All @@ -20,7 +20,7 @@ def test__draws_voronoi_pixels_for_sensible_input(voronoi_mapper_9_3x3):
voronoi_drawer.draw_voronoi_pixels(
mapper=voronoi_mapper_9_3x3,
pixel_values=values,
units=None,
units=aplt.Units(),
cmap=aplt.Cmap(),
colorbar=aplt.Colorbar(fraction=0.1, pad=0.05),
)
Loading