Skip to content

Commit

Permalink
Add critical lines option to convergence_diff plotting routine
Browse files Browse the repository at this point in the history
  • Loading branch information
aymgal committed Oct 10, 2023
1 parent dfaebd4 commit 206994d
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 19 deletions.
10 changes: 7 additions & 3 deletions coolest/api/composable_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,12 @@ def evaluate_magnification(self, x, y):
mu = 1. / det_A
return mu

def ray_shooting(self, x, y):
"""evaluates the lens equation beta = theta - alpha(theta)"""
alpha_x, alpha_y = self.evaluate_deflection(x, y)
x_rs, y_rs = x - alpha_x, y - alpha_y
return x_rs, y_rs


class ComposableLensModel(object):
"""Given a COOLEST object, evaluates a selection of entity and
Expand Down Expand Up @@ -382,6 +388,4 @@ def evaluate_lensed_surface_brightness(self, x, y):

def ray_shooting(self, x, y):
"""evaluates the lens equation beta = theta - alpha(theta)"""
alpha_x, alpha_y = self.lens_mass.evaluate_deflection(x, y)
x_rs, y_rs = x - alpha_x, y - alpha_y
return x_rs, y_rs
return self.lens_mass.ray_shooting(x, y)
4 changes: 2 additions & 2 deletions coolest/api/plot_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def cax_colorbar(fig, cax, norm=None, cmap=None, mappable=None, label=None, font
if label is not None:
cb.set_label(label, fontsize=fontsize, **label_kwargs)

def plot_regular_grid(ax, image_, neg_values_as_bad=True, xylim=None, **imshow_kwargs):
def plot_regular_grid(ax, image_, neg_values_as_bad=False, xylim=None, **imshow_kwargs):
if neg_values_as_bad:
image = np.copy(image_)
image[image < 0] = np.nan
Expand All @@ -212,7 +212,7 @@ def plot_regular_grid(ax, image_, neg_values_as_bad=True, xylim=None, **imshow_k
ax.yaxis.set_major_locator(plt.MaxNLocator(3))
return ax, im

def plot_irregular_grid(ax, points, xylim, neg_values_as_bad=True,
def plot_irregular_grid(ax, points, xylim, neg_values_as_bad=False,
norm=None, cmap=None, plot_points=False):
x, y, z = points
im = plot_voronoi(ax, x, y, z, neg_values_as_bad=neg_values_as_bad,
Expand Down
60 changes: 49 additions & 11 deletions coolest/api/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,8 @@ def plot_surface_brightness(self, ax, coordinates=None,
if coordinates_lens is None:
coordinates_lens = util.get_coordinates(self.coolest).create_new_coordinates(pixel_scale_factor=0.1)
# NOTE: here we assume that `kwargs_light` is for the source!
lens_model = ComposableLensModel(self.coolest, self._directory,
kwargs_selection_source=kwargs_light,
kwargs_selection_lens_mass=kwargs_lens_mass)
_, caustics = util.find_all_lens_lines(coordinates_lens, lens_model)
mass_model = ComposableMassModel(self.coolest, self._directory, **kwargs_lens_mass)
_, caustics = util.find_all_lens_lines(coordinates_lens, mass_model)
if cmap is None:
cmap = self.cmap_flux
if coordinates is not None:
Expand Down Expand Up @@ -181,7 +179,7 @@ def plot_model_residuals(self, ax, mask=None,
bbox={'color': 'white', 'alpha': 0.6})
return image

def plot_convergence(self, ax,
def plot_convergence(self, ax, coordinates=None,
norm=None, cmap=None, xylim=None, neg_values_as_bad=False,
add_colorbar=True, kwargs_lens_mass=None):
"""plt.imshow panel showing the 2D convergence map associated to the
Expand All @@ -193,7 +191,8 @@ def plot_convergence(self, ax,
**kwargs_lens_mass)
if cmap is None:
cmap = self.cmap_conv
coordinates = util.get_coordinates(self.coolest)
if coordinates is None:
coordinates = util.get_coordinates(self.coolest)
extent = coordinates.plt_extent
x, y = coordinates.pixel_coordinates
image = mass_model.evaluate_convergence(x, y)
Expand All @@ -205,9 +204,47 @@ def plot_convergence(self, ax,
cb = plut.nice_colorbar(im, ax=ax, max_nbins=4)
cb.set_label(r"$\kappa$")
return image

def plot_convergence_diff(
self, ax, reference_map, relative_error=True,
norm=None, cmap=None, xylim=None, coordinates=None,
add_colorbar=True, kwargs_lens_mass=None,
plot_crit_lines=False, crit_lines_color='black', crit_lines_alpha=0.5):
"""plt.imshow panel showing the 2D convergence map associated to the
selected lensing entities (see ComposableMassModel docstring)
"""
if kwargs_lens_mass is None:
kwargs_lens_mass = {}
mass_model = ComposableMassModel(self.coolest, self._directory,
**kwargs_lens_mass)
if cmap is None:
cmap = self.cmap_res
if norm is None:
norm = Normalize(-1, 1)
if coordinates is None:
coordinates = util.get_coordinates(self.coolest)
if plot_crit_lines:
critical_lines, _ = util.find_all_lens_lines(coordinates, mass_model)
extent = coordinates.plt_extent
x, y = coordinates.pixel_coordinates
image = mass_model.evaluate_convergence(x, y)
if relative_error is True:
diff = (reference_map - image) / reference_map
else:
diff = reference_map - image
ax, im = plut.plot_regular_grid(ax, diff, extent=extent,
cmap=cmap,
norm=norm, xylim=xylim)
if plot_crit_lines:
for cline in critical_lines:
ax.plot(cline[0], cline[1], lw=1, color=crit_lines_color, alpha=crit_lines_alpha)
if add_colorbar:
cb = plut.nice_colorbar(im, ax=ax, max_nbins=4)
cb.set_label(r"$\kappa$")
return image

def plot_magnification(self, ax,
norm=None, cmap=None, xylim=None, neg_values_as_bad=False,
norm=None, cmap=None, xylim=None,
add_colorbar=True, coordinates=None, kwargs_lens_mass=None):
"""plt.imshow panel showing the 2D magnification map associated to the
selected lensing entities (see ComposableMassModel docstring)
Expand All @@ -226,8 +263,7 @@ def plot_magnification(self, ax,
extent = coordinates.plt_extent
image = mass_model.evaluate_magnification(x, y)
ax, im = plut.plot_regular_grid(ax, image, extent=extent,
cmap=cmap,
neg_values_as_bad=neg_values_as_bad,
cmap=cmap,
norm=norm, xylim=xylim)
if add_colorbar:
cb = plut.nice_colorbar(im, ax=ax, max_nbins=4)
Expand All @@ -236,7 +272,7 @@ def plot_magnification(self, ax,

def plot_magnification_diff(
self, ax, reference_map, relative_error=True,
norm=None, cmap=None, xylim=None, neg_values_as_bad=False,
norm=None, cmap=None, xylim=None,
add_colorbar=True, coordinates=None, kwargs_lens_mass=None):
"""plt.imshow panel showing the (absolute or relative)
difference between 2D magnification maps
Expand All @@ -260,7 +296,6 @@ def plot_magnification_diff(
diff = reference_map - image
ax, im = plut.plot_regular_grid(ax, diff, extent=extent,
cmap=cmap,
neg_values_as_bad=neg_values_as_bad,
norm=norm, xylim=xylim)
if add_colorbar:
cb = plut.nice_colorbar(im, ax=ax, max_nbins=4)
Expand Down Expand Up @@ -310,6 +345,9 @@ def plot_convergence(self, axes, **kwargs):
def plot_magnification(self, axes, **kwargs):
return self._plot_lens_model_multi('plot_magnification', axes, **kwargs)

def plot_convergence_diff(self, axes, *args, **kwargs):
return self._plot_lens_model_multi('plot_convergence_diff', axes, *args, **kwargs)

def plot_magnification_diff(self, axes, *args, **kwargs):
return self._plot_lens_model_multi('plot_magnification_diff', axes, *args, **kwargs)

Expand Down
13 changes: 10 additions & 3 deletions coolest/api/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,9 +407,7 @@ def read_sersic(light, param={}, prefix='Sersic_0_'):

return param


def find_critical_lines(coordinates, mag_map):
from skimage import measure
# invert and find contours corresponding to infinite magnification (i.e., changing sign)
inv_mag = 1. / np.array(mag_map)
contours = measure.find_contours(inv_mag, 0.)
Expand All @@ -421,14 +419,23 @@ def find_critical_lines(coordinates, mag_map):
return lines

def find_caustics(crit_lines, composable_lens):
"""`composable_lens` can be an instance of `ComposableLens` or `ComposableMass`"""
lines = []
for cline in crit_lines:
cl_src_x, cl_src_y = composable_lens.ray_shooting(cline[0], cline[1])
lines.append((np.array(cl_src_x), np.array(cl_src_y)))
return lines

def find_all_lens_lines(coordinates, composable_lens):
mag_map = composable_lens.lens_mass.evaluate_magnification(*coordinates.pixel_coordinates)
"""`composable_lens` can be an instance of `ComposableLens` or `ComposableMass`"""
from coolest.api.composable_models import ComposableLensModel, ComposableMassModel # avoiding circular imports
if isinstance(composable_lens, ComposableLensModel):
mag_fn = composable_lens.lens_mass.evaluate_magnification
elif isinstance(composable_lens, ComposableMassModel):
mag_fn = composable_lens.evaluate_magnification
else:
raise ValueError("`composable_lens` must be a ComposableLensModel or a ComposableMassModel.")
mag_map = mag_fn(*coordinates.pixel_coordinates)
crit_lines = find_critical_lines(coordinates, mag_map)
caustics = find_caustics(crit_lines, composable_lens)
return crit_lines, caustics
Expand Down

0 comments on commit 206994d

Please sign in to comment.