diff --git a/herculens/Analysis/plot.py b/herculens/Analysis/plot.py index cc985e5..39a5ba2 100644 --- a/herculens/Analysis/plot.py +++ b/herculens/Analysis/plot.py @@ -79,8 +79,9 @@ def __init__(self, data_name=None, base_fontsize=14, flux_log_scale=True, def set_data(self, data): self._data = data - def set_ref_source(self, ref_source): + def set_ref_source(self, ref_source, plt_extent=None): self._ref_source = ref_source + self._ref_source_extent = plt_extent def set_ref_lens_light(self, ref_lens_light): self._ref_lens_light = ref_lens_light @@ -159,11 +160,13 @@ def model_summary(self, lens_image, kwargs_result, show_source_diff = False else: show_source_diff = True + ref_src_extent = self._ref_source_extent else: ref_source = None + ref_src_extent = None show_source_diff = False - if 'kwargs_point_source' in kwargs_result: + if len(lens_image.PointSourceModel.type_list) > 0: #TODO: support several point source models ps0_params = kwargs_result['kwargs_point_source'][0] all_ps_src_x, all_ps_src_y = lens_image.PointSourceModel.get_source_plane_points( @@ -284,7 +287,7 @@ def model_summary(self, lens_image, kwargs_result, ax.plot(curve[0], curve[1], linewidth=0.8, color='white') ax.scatter(*centers, s=20, c='gray', marker='+', linewidths=0.5) if show_shear_field: - shear_field = model_util.shear_deflection_field(lens_image, kwargs_lens, num_pixels=8) + shear_field = model_util.shear_deflection_field(lens_image, kwargs_result['kwargs_lens'], num_pixels=8) if shear_field is not None: x_field, y_field, g1_field, g2_field, ax_field, ay_field = shear_field qu = ax.quiver(x_field, y_field, @@ -337,7 +340,7 @@ def model_summary(self, lens_image, kwargs_result, ##### UNLENSED AND UNCONVOLVED SOURCE MODEL ##### ax = axes[i_row, 0] if ref_source is not None: - im = ax.imshow(ref_source, extent=src_extent, cmap=self.cmap_flux_alt, norm=norm_flux) #, vmax=vmax) + im = ax.imshow(ref_source, extent=ref_src_extent, cmap=self.cmap_flux_alt, norm=norm_flux) #, vmax=vmax) im.set_rasterized(True) ax.set_title("ref. source", fontsize=self.base_fontsize) nice_colorbar(im, position='top', pad=0.4, size=0.2,