diff --git a/requirements.github_actions.txt b/requirements.github_actions.txt index 35c2791d..4094251f 100644 --- a/requirements.github_actions.txt +++ b/requirements.github_actions.txt @@ -5,6 +5,8 @@ scipy astropy matplotlib jupyter +numba +joblib docutils requests diff --git a/requirements.readthedocs.txt b/requirements.readthedocs.txt index 9236de4d..114aa7e6 100644 --- a/requirements.readthedocs.txt +++ b/requirements.readthedocs.txt @@ -2,6 +2,7 @@ numpy>=1.16 scipy matplotlib astropy +numba docutils requests diff --git a/scopesim/effects/psf_utils.py b/scopesim/effects/psf_utils.py index 4f61fb9e..1d2001c6 100644 --- a/scopesim/effects/psf_utils.py +++ b/scopesim/effects/psf_utils.py @@ -1,14 +1,17 @@ +from typing import Tuple, List + +import matplotlib.pyplot as plt import numpy as np -from scipy import ndimage as spi -from scipy.interpolate import RectBivariateSpline, griddata -from scipy.ndimage import zoom +import numpy.typing as npt from astropy import units as u from astropy.convolution import Gaussian2DKernel from astropy.io import fits -import matplotlib.pyplot as plt -from matplotlib.colors import LogNorm +from numba import njit, prange +from scipy import ndimage as spi +from scipy.interpolate import RectBivariateSpline, griddata +from scipy.ndimage import zoom -from .. import rc, utils +from .. import utils from ..optics import image_plane_utils as imp_utils @@ -81,9 +84,7 @@ def nmrms_from_strehl_and_wavelength(strehl, wavelength, strehl_hdu, return nm -def make_strehl_map_from_table(tbl, pixel_scale=1*u.arcsec): - - +def make_strehl_map_from_table(tbl, pixel_scale=1 * u.arcsec): # pixel_scale = utils.quantify(pixel_scale, u.um).to(u.deg) # coords = np.array([tbl["x"], tbl["y"]]).T # @@ -102,7 +103,7 @@ def make_strehl_map_from_table(tbl, pixel_scale=1*u.arcsec): hdr = imp_utils.header_from_list_of_xy(np.array([-25, 25]) / 3600., np.array([-25, 25]) / 3600., - pixel_scale=1/3600) + pixel_scale=1 / 3600) map_hdu = fits.ImageHDU(header=hdr, data=map) @@ -114,19 +115,19 @@ def rescale_kernel(image, scale_factor, spline_order=None): spline_order = utils.from_currsys("!SIM.computing.spline_order") sum_image = np.sum(image) image = zoom(image, scale_factor, order=spline_order) - image = np.nan_to_num(image, copy=False) # numpy version >=1.13 + image = np.nan_to_num(image, copy=False) # numpy version >=1.13 # Re-centre kernel im_shape = image.shape dy, dx = np.divmod(np.argmax(image), im_shape[1]) - np.array(im_shape) // 2 if dy > 0: - image = image[2*dy:, :] + image = image[2 * dy:, :] elif dy < 0: - image = image[:2*dy, :] + image = image[:2 * dy, :] if dx > 0: - image = image[:, 2*dx:] + image = image[:, 2 * dx:] elif dx < 0: - image = image[:, :2*dx] + image = image[:, :2 * dx] sum_new_image = np.sum(image) image *= sum_image / sum_new_image @@ -139,15 +140,14 @@ def cutout_kernel(image, fov_header): xcen, ycen = 0.5 * w, 0.5 * h dx = 0.5 * fov_header["NAXIS1"] dy = 0.5 * fov_header["NAXIS2"] - x0, x1 = max(0, int(xcen-dx)), min(w, int(xcen+dx)) - y0, y1 = max(0, int(ycen-dy)), min(w, int(ycen+dy)) + x0, x1 = max(0, int(xcen - dx)), min(w, int(xcen + dx)) + y0, y1 = max(0, int(ycen - dy)), min(w, int(ycen + dy)) image_cutout = image[y0:y1, x0:x1] return image_cutout def get_strehl_cutout(fov_header, strehl_imagehdu): - image = np.zeros((fov_header["NAXIS2"], fov_header["NAXIS1"])) canvas_hdu = fits.ImageHDU(header=fov_header, data=image) canvas_hdu = imp_utils.add_imagehdu_to_imagehdu(strehl_imagehdu, @@ -197,7 +197,7 @@ def get_psf_wave_exts(hdu_list, wave_key="WAVE0"): def get_total_wfe_from_table(tbl): wfes = utils.quantity_from_table("wfe_rms", tbl, "um") n_surfs = tbl["n_surfaces"] - total_wfe = np.sum(n_surfs * wfes**2)**0.5 + total_wfe = np.sum(n_surfs * wfes ** 2) ** 0.5 return total_wfe @@ -217,7 +217,7 @@ def wfe2strehl(wfe, wave): wave = utils.quantify(wave, u.um) wfe = utils.quantify(wfe, u.um) x = 2 * 3.1415926526 * wfe / wave - strehl = np.exp(-x**2) + strehl = np.exp(-x ** 2) return strehl @@ -269,6 +269,7 @@ def rotational_blur(image, angle): return image_rot / n_angles + def get_bkg_level(obj, bg_w): """ Determine the background level of image or cube slices @@ -289,7 +290,7 @@ def get_bkg_level(obj, bg_w): else: mask = np.zeros_like(obj, dtype=np.bool8) if bg_w > 0: - mask[bg_w:-bg_w,bg_w:-bg_w] = True + mask[bg_w:-bg_w, bg_w:-bg_w] = True bkg_level = np.ma.median(np.ma.masked_array(obj, mask=mask)) elif obj.ndim == 3: @@ -305,3 +306,196 @@ def get_bkg_level(obj, bg_w): else: raise ValueError("Unsupported dimension:", obj.ndim) return bkg_level + + +@njit() +def kernel_grid_linear_interpolation(kernel_grid: npt.NDArray, position: npt.NDArray) -> npt.NDArray: + """Bi-linear interpolation of a grid of 2D arrays at a given position. + + This function interpolates a grid of 2D arrays at a given position using a weighted mean (i.e. bi-linear + interpolation). The grid object should be of shape (M, N, I, J), with MxN the shape of the grid of arrays and IxJ + the shape of the array at each point. + + Parameters + ---------- + kernel_grid : npt.NDArray + An array with shape `(M, N, I, J)` defining a `MxN` grid of 2D arrays to be interpolated. + position: npt.NDArray + An array containing the position in the `MxN` at which the resulting 2D array is computed. + + Returns + ------- + npt.NDArray + An IxJ array at the given position obtained by interpolation. + """ + # Grid and kernel dimensions + grid_i, grid_j, kernel_i, kernel_j = kernel_grid.shape + + # Find the closest grid points to the given position + x, y = position + x0 = int(x) + y0 = int(y) + x1 = x0 + 1 + y1 = y0 + 1 + + # Get the four closest arrays to the given position + psf00 = kernel_grid[x0, y0, :, :] + psf01 = kernel_grid[x0, y1, :, :] + psf10 = kernel_grid[x1, y0, :, :] + psf11 = kernel_grid[x1, y1, :, :] + + # Define the weights for each grid point + dx = x - x0 + dy = y - y0 + inv_dx = 1 - dx + inv_dy = 1 - dy + + a = inv_dx * inv_dy + b = dx * inv_dy + c = inv_dx * dy + d = dx * dy + + # Construct support array and retrieve pixel values by interpolating + output = np.empty((kernel_i, kernel_j), dtype=kernel_grid.dtype) + for i in range(kernel_i): + for j in range(kernel_j): + output[i, j] = ( + a * psf00[i, j] + + b * psf01[i, j] + + c * psf10[i, j] + + d * psf11[i, j] + ) + return output + + +@njit(parallel=True) +def _convolve2d_varying_kernel(image: npt.NDArray, + kernel_grid: npt.NDArray, + coordinates: Tuple[npt.NDArray, npt.NDArray], + interpolator) -> npt.NDArray: + """(Helper) Convolve an image with a spatially-varying kernel by interpolating a discrete kernel grid. + + Numba JIT function for performing the convolution of an image with a spatially-varying kernel by interpolation of a + kernel grid at each pixel position. Check `convolve2d_varying_kernel` for more information. + + Parameters + ---------- + image: npt.NDArray + The image to be convolved. + kernel_grid : npt.NDArray + An array with shape `(M, N, I, J)` defining an `MxN` grid of 2D kernels. + coordinates : Tuple[npt.ArrayLike, npt.ArrayLike] + A tuple of arrays defining the axis coordinates of each pixel of the image in the kernel grid coordinated in + which the kernel is to be computed. + interpolator + A Numba njit'ted function that performs the interpolation. It's signature should be + `(kernel_grid: npt.NDArray, position: npt.NDArray, check_bounds: bool) -> npt.NDArray`. + + Returns + ------- + npt.NDArray + The image convolved with the kernel grid interpolated at each pixel. + """ + # [JA] TODO: Allow for kernel center != kernel.shape // 2 + # Get image, grid and kernel dimensions + img_i, img_j = image.shape + grid_i, grid_j, kernel_i, kernel_j = kernel_grid.shape + + # Add padding to the image (note: Numba doesn't support np.pad) + kernel_ci, kernel_cj = kernel_i // 2, kernel_j // 2 + padded_img = np.zeros((img_i + kernel_i - 1, img_j + kernel_j - 1), dtype=image.dtype) + padded_img[kernel_ci:kernel_ci + img_i, kernel_cj:kernel_cj + img_j] = image + + # Create output array + output = np.zeros_like(padded_img) + # Compute kernel and convolve for each pixel + for i in prange(img_i): + x = coordinates[0][i] + for j in range(img_j): + pixel_value = image[i, j] + if pixel_value != 0: + y = coordinates[1][j] + # Get kernel for current pixel + position = np.array((x, y)) + kernel = interpolator(kernel_grid=kernel_grid, + position=position) + + # Apply to image + tmp = np.zeros_like(padded_img) + + start_i, start_j = i, j + stop_i, stop_j = start_i + kernel_i, start_j + kernel_j + tmp[start_i:stop_i, start_j:stop_j] += pixel_value * kernel + tmp[start_i:stop_i, start_j:stop_j] = pixel_value * kernel + + output += tmp + return output[kernel_ci:kernel_ci + img_i, kernel_cj:kernel_cj + img_j] + + +def convolve2d_varying_kernel(image: npt.ArrayLike, + kernel_grid: npt.ArrayLike, + coordinates: List[npt.ArrayLike], + *, + mode: str = "linear") -> npt.NDArray: + """Convolve an image with a spatially-varying kernel by interpolating a discrete kernel grid. + + An image is convolved with a spatially-varying kernel, as defined by a discrete kernel grid, by computing, for each + of the image pixels, an effective kernel. The effective kernel is obtained by interpolating the origin kernel grid + at the position of each image pixel. + + + Parameters + ---------- + image: npt.Arraylike + The image to be convolved. + kernel_grid : npt.ArrayLike + An array with shape `(M, N, I, J)` defining an `MxN` grid of 2D kernels. + coordinates : List[npt.ArrayLike] + A tuple of arrays defining the axis coordinates of each pixel of the image in the kernel grid coordinated in + which the kernel is to be computed. + mode : str + The interpolation mode to be used to interpolate the convolution kernel (currently only `\"linear\"` - for + bi-linear interpolation - is implemented). + + Returns + ------- + npt.NDArray + The image convolved with the kernel grid interpolated at each pixel. + + Raises + ------ + ValueError + If the provided axis coordinates are out of bounds with respect to the provided kernel grid. + ValueError + If the provided axis coordinates do not match the image shape. + NotImplementedError + If the interpolation mode (`mode`) is `nearest` (nearest neighbor interpolation). + ValueError + If the interpolation mode (`mode`) is not `nearest` or `linear`. + """ + + image = np.array(image) + kernel_grid = np.array(kernel_grid) + x, y = (np.array(axis) for axis in tuple(coordinates)) + + # Validate coordinates + if np.any((x.max(), y.max()) >= image.shape) or np.any((x.min(), y.min()) < (0, 0)): + raise ValueError("Coordinates out of kernel grid bounds.") + + if (x.size, y.size) != image.shape: + raise ValueError("Coordinates provided do not match image shape.") + + # Select interpolation mode + mode = str(mode).lower() + if mode == "linear": + interpolation_fn = kernel_grid_linear_interpolation + elif mode == "nearest": + interpolation_fn = None + raise NotImplementedError(f"Mode \'{mode}\' not implemented.") + else: + raise ValueError(f"Invalid interpolation mode \'{mode}\'") + + return _convolve2d_varying_kernel(image=image, + kernel_grid=kernel_grid, + coordinates=(x, y), + interpolator=interpolation_fn) diff --git a/scopesim/effects/psfs.py b/scopesim/effects/psfs.py index 90417674..b19f6561 100644 --- a/scopesim/effects/psfs.py +++ b/scopesim/effects/psfs.py @@ -1,19 +1,19 @@ -from copy import deepcopy +import anisocado as aniso import numpy as np -from scipy.signal import convolve -from scipy.interpolate import RectBivariateSpline - +import numpy.typing as npt from astropy import units as u -from astropy.io import fits from astropy.convolution import Gaussian2DKernel +from astropy.io import fits from astropy.wcs import WCS -import anisocado as aniso +from scipy.interpolate import RectBivariateSpline +from scipy.signal import convolve -from .effects import Effect -from . import ter_curves_utils as tu from . import psf_utils as pu -from ..base_classes import ImagePlaneBase, FieldOfViewBase, FOVSetupBase +from . import ter_curves_utils as tu +from .effects import Effect +from .psf_utils import convolve2d_varying_kernel from .. import utils +from ..base_classes import ImagePlaneBase, FieldOfViewBase, FOVSetupBase class PoorMansFOV: @@ -22,7 +22,7 @@ def __init__(self, recursion_call=False): self.header = {"CDELT1": pixel_scale / 3600., "CDELT2": pixel_scale / 3600., "NAXIS1": 128, - "NAXIS2": 128,} + "NAXIS2": 128, } self.meta = utils.from_currsys("!SIM.spectral") self.wavelength = self.meta["wave_mid"] * u.um if not recursion_call: @@ -39,7 +39,7 @@ def __init__(self, **kwargs): params = {"flux_accuracy": "!SIM.computing.flux_accuracy", "sub_pixel_flag": "!SIM.sub_pixel.flag", "z_order": [40, 640], - "convolve_mode": "same", # "full", "same" + "convolve_mode": "same", # "full", "same" "bkg_width": -1, "wave_key": "WAVE0", "normalise_kernel": True, @@ -65,7 +65,7 @@ def apply_to(self, obj, **kwargs): # 2. During observe: convolution elif isinstance(obj, self.convolution_classes): if ((hasattr(obj, "fields") and len(obj.fields) > 0) or - (obj.hdu is not None)): + (obj.hdu is not None)): kernel = self.get_kernel(obj).astype(float) # apply rotational blur for field-tracking observations @@ -75,7 +75,7 @@ def apply_to(self, obj, **kwargs): kernel = pu.rotational_blur(kernel, rot_blur_angle) # normalise psf kernel KERNEL SHOULD BE normalised within get_kernel() - #if utils.from_currsys(self.meta["normalise_kernel"]) is True: + # if utils.from_currsys(self.meta["normalise_kernel"]) is True: # kernel /= np.sum(kernel) # kernel[kernel < 0.] = 0. @@ -112,7 +112,6 @@ def apply_to(self, obj, **kwargs): return obj - def fov_grid(self, which="waveset", **kwargs): waveset = [] if which == "waveset": @@ -144,7 +143,6 @@ def plot(self, obj=None, **kwargs): return plt.gcf() - ################################################################################ # Analytical PSFs - Vibration, Seeing, NCPAs @@ -159,6 +157,7 @@ class Vibration(AnalyticalPSF): """ Creates a wavelength independent kernel image """ + def __init__(self, **kwargs): super().__init__(**kwargs) self.meta["z_order"] = [244, 744] @@ -187,6 +186,7 @@ class NonCommonPathAberration(AnalyticalPSF): Needed: pixel_scale Accepted: kernel_width, strehl_drift """ + def __init__(self, **kwargs): super().__init__(**kwargs) self.meta["z_order"] = [241, 641] @@ -213,7 +213,7 @@ def fov_grid(self, which="waveset", **kwargs): max_sr = pu.wfe2strehl(self.total_wfe, self.meta["wave_max"]) srs = np.arange(min_sr, max_sr, self.meta["strehl_drift"]) - waves = 6.2831853 * self.total_wfe * (-np.log(srs))**-0.5 + waves = 6.2831853 * self.total_wfe * (-np.log(srs)) ** -0.5 waves = utils.quantify(waves, u.um).value waves = (list(waves) + [self.meta["wave_max"]]) * u.um else: @@ -274,6 +274,7 @@ class SeeingPSF(AnalyticalPSF): [arcsec] """ + def __init__(self, fwhm=1.5, **kwargs): super().__init__(**kwargs) @@ -427,6 +428,7 @@ class AnisocadoConstPSF(SemiAnalyticalPSF): psf_side_length: 512 """ + def __init__(self, **kwargs): super(AnisocadoConstPSF, self).__init__(**kwargs) params = {"z_order": [42, 652], @@ -438,7 +440,7 @@ def __init__(self, **kwargs): self.required_keys = ["filename", "strehl", "wavelength"] utils.check_keys(self.meta, self.required_keys, action="error") - self.nmRms # check to see if it throws an error + self.nmRms # check to see if it throws an error self._psf_object = None self._kernel = None @@ -520,17 +522,17 @@ def plot(self, obj=None, **kwargs): im = kernel r_sky = pixel_scale * im.shape[0] plt.imshow(im, norm=LogNorm(), origin='lower', - extent= [-r_sky, r_sky, -r_sky, r_sky], **kwargs) + extent=[-r_sky, r_sky, -r_sky, r_sky], **kwargs) plt.ylabel("[arcsec]") plt.subplot2grid((2, 2), (0, 1)) x = kernel.shape[1] // 2 y = kernel.shape[0] // 2 r = 16 - im = kernel[y-r:y+r, x-r:x+r] + im = kernel[y - r:y + r, x - r:x + r] r_sky = pixel_scale * im.shape[0] plt.imshow(im, norm=LogNorm(), origin='lower', - extent= [-r_sky, r_sky, -r_sky, r_sky], **kwargs) + extent=[-r_sky, r_sky, -r_sky, r_sky], **kwargs) plt.ylabel("[arcsec]") plt.gca().yaxis.set_label_position('right') @@ -554,7 +556,6 @@ def plot(self, obj=None, **kwargs): return plt.gcf() - ################################################################################ # Discrete PSFs - MAORY and co PSFs @@ -573,6 +574,7 @@ class FieldConstantPSF(DiscretePSF): For spectroscopy, the a wavelength-dependent PSF cube is built, where for each wavelength the reference PSF is scaled proportional to wavelength. """ + def __init__(self, **kwargs): # sub_pixel_flag and flux_accuracy are taken care of in PSF base class super().__init__(**kwargs) @@ -582,7 +584,7 @@ def __init__(self, **kwargs): self.meta["z_order"] = [262, 662] self._waveset, self.kernel_indexes = pu.get_psf_wave_exts( - self._file, self.meta["wave_key"]) + self._file, self.meta["wave_key"]) self.current_layer_id = None self.current_ext = None self.current_data = None @@ -624,7 +626,7 @@ def get_kernel(self, fov): self.kernel = pu.rescale_kernel(self.kernel, pix_ratio) if ((fov.header["NAXIS1"] < hdr["NAXIS1"]) or - (fov.header["NAXIS2"] < hdr["NAXIS2"])): + (fov.header["NAXIS2"] < hdr["NAXIS2"])): self.kernel = pu.cutout_kernel(self.kernel, fov.header) return self.kernel @@ -639,7 +641,7 @@ def make_psf_cube(self, fov): lam = fov.hdu.header["CDELT3"] * (1 + np.arange(fov.hdu.header["NAXIS3"]) - fov.hdu.header["CRPIX3"]) \ - + fov.hdu.header["CRVAL3"] + + fov.hdu.header["CRVAL3"] # adapt the size of the output cube to the FOV's spatial shape nxpsf = min(512, 2 * nxfov + 1) @@ -658,7 +660,7 @@ def make_psf_cube(self, fov): psfwcs = WCS(hdr) psf = self._file[ext].data - psf = psf/psf.sum() # normalisation of the input psf + psf = psf / psf.sum() # normalisation of the input psf nxin, nyin = psf.shape # We need linear interpolation to preserve positivity. Might think of @@ -682,7 +684,7 @@ def make_psf_cube(self, fov): psf_wave_pixscale] xpsf, ypsf = psfwcs.all_world2pix(xworld, yworld, 0) outcube[i,] = (ipsf(ypsf, xpsf, grid=False) - * fov_pixel_scale**2 / psf_wave_pixscale**2) + * fov_pixel_scale ** 2 / psf_wave_pixscale ** 2) self.kernel = outcube.reshape((lam.shape[0], nypsf, nxpsf)) # fits.writeto("test_psfcube.fits", data=self.kernel, overwrite=True) @@ -700,9 +702,9 @@ class FieldVaryingPSF(DiscretePSF): Default 1e-3. Level of flux conservation during rescaling of kernel """ + def __init__(self, **kwargs): - # sub_pixel_flag and flux_accuracy are taken care of in PSF base class - super(FieldVaryingPSF, self).__init__(**kwargs) + super().__init__(**kwargs) self.required_keys = ["filename"] utils.check_keys(self.meta, self.required_keys, action="error") @@ -735,7 +737,7 @@ def apply_to(self, fov, **kwargs): for kernel, mask in kernels_masks: # renormalise the kernel if needs be - kernel[kernel<0.] = 0. + kernel[kernel < 0.] = 0. sum_kernel = np.sum(kernel) if abs(sum_kernel - 1) > self.meta["flux_accuracy"]: kernel /= sum_kernel @@ -836,3 +838,261 @@ def strehl_imagehdu(self): def plot(self): return super().plot(PoorMansFOV()) + + +# [JA] TODO: The following docstrings are incomplete in some places (e.g. apply_to), mostly due the method and +# attribute definition being quite obfuscated throughout the codebase (e.g. meta) and little to no type hints being +# provided. Little can be done unless this aspects are improved upon. +class GridFieldVaryingPSF(DiscretePSF): + """A PSF that varies across the field, defined by a regular grid of PSFs at different positions. + + This effect specifies a field-varying PSF which is defined at different positions in a regular grid in the field. + Applying this effect to a Field of View object constitutes the convolution of the respective image with an effective + PSF at each pixel, which is computed by interpolating the PSFs defined at discrete grid points. + + Parameters + ---------- + ... + + Attributes + ---------- + ... + + Methods + ------- + normalise_kernel_grid + Normalises each kernel in a kernel grid by its sum. + get_kernel_grid + Find the nearest wavelength and build PSF kernel grid from file. + apply_to + TODO + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + self.meta["z_order"] = [262, 662] # [JA] TODO: Is this ok? + + self._waveset, self._kernel_indexes = pu.get_psf_wave_exts(self._file, self.meta["wave_key"]) + self._current_layer_id = None # The index of the PSF layer used previously + self._kernel = None # Cached PSF grid + + def normalise_kernel_grid(self, grid: npt.NDArray, *, force_normalise: bool = False) -> npt.NDArray: + """Normalises each kernel in a kernel grid by its sum. + + Normalises the integral of each kernel in a kernel grid to unity. This normalisation is only performed if the + flux accuracy is not respected or the `force_normalise` flag is set to `True`. + + Parameters + ---------- + grid : npt.NDArray + An array with shape `(M, N, I, J)` defining a `MxN` grid of 2D kernels to be normalised. + force_normalise : bool + Whether to normalise the kernel grid unheeded to whether the flux accuracy is respected. + + Returns + ------- + npt.NDArray + The input kernel grid with each grid normalised to unity. + """ + kernel_sum = np.sum(grid, axis=(2, 3), keepdims=True) + if not force_normalise and np.all(abs(kernel_sum - 1) < self.meta["flux_accuracy"]): + return grid + else: + return grid / kernel_sum + + def _setup_fov(self, obj: FOVSetupBase): + """Setups effect with Field of View object for subsequent observation. + + Parameters + ---------- + obj + + Returns + ------- + obj + + """ + waveset = self._waveset + if len(waveset) != 0: + waveset_edges = 0.5 * (waveset[:-1] + waveset[1:]) + obj.split("wave", utils.quantify(waveset_edges, u.um).value) + return obj + + def _apply_to_fov(self, obj, *, normalise_kernel: bool = True, interpolation_mode="linear"): + """Applies PSF to provided Field of View object. + + Parameters + ---------- + obj + ... + normalise_kernel : bool + Whether to force normalise the kernel grid unheeded to whether the flux accuracy is respected + interpolation_mode : str + The interpolation mode to be used to interpolate the convolution kernel. + + Returns + ------- + obj + + """ + # Check if there are any fov.fields to apply a psf to + # [JA] TODO: Check this alternative condition. Which is correct? + if len(obj.fields) > 0: + if obj.image is None: + obj.image = obj.make_image_hdu().data + + # old_shape = obj.image.shape + + # normalise kernel + kernel_grid = self.get_kernel_grid(obj) + kernel_grid = self.normalise_kernel_grid(kernel_grid, force_normalise=normalise_kernel) + + # image = obj.image.astype(float) + image = obj.image.astype(float) + + # Get image and kernel dimensions and positions + img_m, img_n = image.shape + unit_factor_1 = u.Unit(obj.header.get("CUNIT1").lower()).to(u.deg) if obj.header.get("CUNIT1") else 1 + unit_factor_2 = u.Unit(obj.header.get("CUNIT2").lower()).to(u.deg) if obj.header.get("CUNIT2") else 1 + img_coors = utils.GridCoordinates(m=img_m, + n=img_n, + x0=obj.header["CRVAL1"] * unit_factor_1, + y0=obj.header["CRVAL2"] * unit_factor_2, + dx=obj.header["CDELT1"] * unit_factor_1, + dy=obj.header["CDELT2"] * unit_factor_2, + i0=obj.header["CRPIX1"], + j0=obj.header["CRPIX2"]) + + hdr = self._file[self._current_layer_id].header + unit_factor_1 = u.Unit(hdr.get("CUNIT1").lower()).to(u.deg) if hdr.get("CUNIT1") else 1 + unit_factor_2 = u.Unit(hdr.get("CUNIT2").lower()).to(u.deg) if hdr.get("CUNIT2") else 1 + grid_m, grid_n, psf_m, psf_n = kernel_grid.shape + grid_coors = utils.GridCoordinates(m=grid_m, + n=grid_n, + x0=hdr["CRVAL1"] * unit_factor_1, + y0=hdr["CRVAL2"] * unit_factor_2, + dx=hdr["CDELT1"] * unit_factor_1, + dy=hdr["CDELT2"] * unit_factor_2, + i0=hdr["CRPIX1"], + j0=hdr["CRPIX2"]) + + coordinates = img_coors.pix_in_reference_frame(grid_coors) + + # Convolve + bkg_level = pu.get_bkg_level(image, self.meta["bkg_width"]) + canvas = convolve2d_varying_kernel(image=image-bkg_level, + kernel_grid=kernel_grid, + coordinates=coordinates, + mode=interpolation_mode) + canvas += bkg_level + + # [JA] TODO: In some classes, obj.hdu.data is updated; in other, it's the obj.image. Which is correct? + obj.hdu.data = canvas + obj.image = canvas + + # reset WCS header info + new_shape = canvas.shape + + # [JA] TODO: Implement this + # # ..todo: careful with which dimensions mean what + # if "CRPIX1" in fov.header: + # fov.header["CRPIX1"] += (new_shape[0] - old_shape[0]) / 2 + # fov.header["CRPIX2"] += (new_shape[1] - old_shape[1]) / 2 + # + # if "CRPIX1D" in fov.header: + # fov.header["CRPIX1D"] += (new_shape[0] - old_shape[0]) / 2 + # fov.header["CRPIX2D"] += (new_shape[1] - old_shape[1]) / 2 + + return obj + + def get_kernel_grid(self, fov) -> npt.NDArray: + """Find the nearest wavelength and build PSF kernel grid from file. + + Given a Field of View object, the closest wavelength PSF grid is retrieved from the input file and, if needed, + the PSFs are rescaled to match the FOV pixel scale. + + Parameters + ---------- + fov + The Field of View object to be convolved with the PSF grid. + + Returns + ------- + npt.NDArray + The PSF grid to be used for the provided Field of View object, rescaled to its pixel scale. + + Raises + ------ + NotImplementedError + If the FOV object does not have 2 axis. + """ + ii = pu.nearest_index(fov.wavelength, self._waveset) + ext = self._kernel_indexes[ii] + + # Update kernel + if ext != self._current_layer_id: + if fov.hdu.header['NAXIS'] != 2: + raise NotImplementedError("Only FOV with 2 axis supported for now.") + + self._kernel = self._file[ext].data + self._current_layer_id = ext + hdr = self._file[ext].header + + # [JA] TODO: Check if CDELTs are different for the same grid + # compare kernel and fov pixel scales, rescale if needed + kernel_unit_factor = u.Unit(hdr.get("CUNIT3").lower()).to(u.deg) if hdr.get("CUNIT3") else 1 + kernel_pixel_scale = hdr["CDELT3"] * kernel_unit_factor + + fov_unit_factor = u.Unit(fov.header.get("CUNIT1").lower()).to(u.deg) if fov.header.get("CUNIT1") else 1 + fov_pixel_scale = fov.header["CDELT1"] * fov_unit_factor + + # rescaling kept inside loop to avoid rescaling for every fov + pix_ratio = kernel_pixel_scale / fov_pixel_scale + if abs(pix_ratio - 1) > self.meta["flux_accuracy"]: # [JA] TODO: Does this make sense? + kernel_center_pixel = np.array([hdr["CRPIX3"], hdr["CRPIX4"]]) + self._kernel = utils.rescale_array_grid(self._kernel, + center_pixel=kernel_center_pixel, + origin_pixel_scale=kernel_pixel_scale, + target_pixel_scale=fov_pixel_scale) + + # [JA] TODO: fix this + # if ((fov.header["NAXIS1"] < hdr["NAXIS3"]) or + # (fov.header["NAXIS2"] < hdr["NAXIS4"])): + # self._kernel = pu.cutout_kernel(self._kernel, fov.header) + + return self._kernel + + def apply_to(self, obj, **kwargs): + """TODO docstring + + Parameters + ---------- + obj + ... + + Returns + ------- + obj + ... + + Raises + ------ + ValueError + If `obj` is not of type `FOVSetupBase` or doesn't match one of `self.convolution_classes`. + """ + # [JA] TODO: Write unified docstring for PSF classes + + # 1. During setup of the FieldOfViews + if isinstance(obj, FOVSetupBase) and self._waveset is not None: + obj = self._setup_fov(obj) + # 2. During observe: convolution + elif isinstance(obj, self.convolution_classes): + if (hasattr(obj, "fields") and len(obj.fields) > 0) or (obj.hdu is not None): + obj = self._apply_to_fov(obj, + normalise_kernel=self.meta.get("normalise_kernel", True), + interpolation_mode=self.meta.get("interpolation_mode", "linear")) + else: + raise ValueError("\'obj\' should be of type \'FOVSetupBase\' or match one of \'self.convolution_classes\'.") + + return obj diff --git a/scopesim/utils.py b/scopesim/utils.py index d0599250..a1eb6e0d 100644 --- a/scopesim/utils.py +++ b/scopesim/utils.py @@ -1,23 +1,28 @@ """ Helper functions for ScopeSim """ +from __future__ import annotations + +import logging import math import os -from pathlib import Path import sys -import logging -import logging from collections import OrderedDict -from docutils.core import publish_string from copy import deepcopy +from pathlib import Path +from typing import Tuple, Union, Optional, List +import numpy as np +import numpy.typing as npt import requests +import scipy as sp import yaml -import numpy as np from astropy import units as u -from astropy.io import fits from astropy.io import ascii as ioascii +from astropy.io import fits from astropy.table import Column, Table +from docutils.core import publish_string +from joblib import Parallel, delayed from . import rc @@ -131,7 +136,7 @@ def moffat(r, alpha, beta): ------- eta """ - return (beta - 1)/(np.pi * alpha**2) * (1 + (r/alpha)**2)**(-beta) + return (beta - 1) / (np.pi * alpha ** 2) * (1 + (r / alpha) ** 2) ** (-beta) def poissonify(arr): @@ -174,12 +179,14 @@ def nearest(arr, val): return np.argmin(abs(arr - val)) + def power_vector(val, degree): """Return the vector of powers of val up to a degree""" if degree < 0 or not isinstance(degree, int): raise ValueError("degree must be a positive integer") - return np.array([val**exp for exp in range(degree + 1)]) + return np.array([val ** exp for exp in range(degree + 1)]) + def deriv_polynomial2d(poly): """Derivatives (gradient) of a Polynomial2D model @@ -204,8 +211,8 @@ def deriv_polynomial2d(poly): i = int(match.group(1)) j = int(match.group(2)) cij = getattr(poly, pname) - pname_x = "c%d_%d" % (i-1, j) - pname_y = "c%d_%d" % (i, j-1) + pname_x = "c%d_%d" % (i - 1, j) + pname_y = "c%d_%d" % (i, j - 1) setattr(dpoly_dx, pname_x, i * cij) setattr(dpoly_dy, pname_y, j * cij) @@ -256,7 +263,8 @@ def add_SED_to_scopesim(file_in, file_out=None, wave_units="um"): if file_out is None: if "SED_" not in file_name: file_out = rc.__data_dir__ + "SED_" + file_name + ".dat" - else: file_out = rc.__data_dir__ + file_name + ".dat" + else: + file_out = rc.__data_dir__ + file_name + ".dat" if file_ext.lower() in "fits": data = fits.getdata(file_in) @@ -265,7 +273,7 @@ def add_SED_to_scopesim(file_in, file_out=None, wave_units="um"): lam, val = ioascii.read(file_in)[:2] lam = (lam * u.Unit(wave_units)).to(u.um) - mask = (lam > 0.3*u.um) * (lam < 5.0*u.um) + mask = (lam > 0.3 * u.um) * (lam < 5.0 * u.um) np.savetxt(file_out, np.array((lam[mask], val[mask]), dtype=np.float32).T, header="wavelength value \n [um] [flux]") @@ -308,7 +316,7 @@ def seq(start, stop, step=1): increment of the sequence, defaults to 1 """ - feps = 1e-10 # value used in R seq.default + feps = 1e-10 # value used in R seq.default delta = stop - start if delta == 0 and stop == 0: @@ -349,7 +357,7 @@ def add_mags(mags): """ Returns a combined magnitude for a group of py_objects with ``mags`` """ - return -2.5*np.log10((10**(-0.4*np.array(mags))).sum()) + return -2.5 * np.log10((10 ** (-0.4 * np.array(mags))).sum()) def dist_mod_from_distance(d): @@ -366,7 +374,7 @@ def distance_from_dist_mod(mu): d = 10**(1 + mu / 5) """ - d = 10**(1 + mu / 5) + d = 10 ** (1 + mu / 5) return d @@ -395,7 +403,7 @@ def telescope_diffraction_limit(aperture_size, wavelength, distance=None): """ - diff_limit = (((wavelength*u.um)/(aperture_size*u.m))*u.rad).to(u.arcsec).value + diff_limit = (((wavelength * u.um) / (aperture_size * u.m)) * u.rad).to(u.arcsec).value if distance is not None: diff_limit *= distance / u.pc.to(u.AU) @@ -477,7 +485,6 @@ def set_logger_level(which="console", level="ERROR"): """ - hdlr_name = f"scopesim_{which}_logger" level = {"ON": "INFO", "OFF": "CRITICAL"}.get(level.upper(), level) logger = logging.getLogger() @@ -557,7 +564,7 @@ def find_file(filename, path=None, silent=False): for trydir in path if trydir is not None] for fname in trynames: - if os.path.exists(fname): # success + if os.path.exists(fname): # success # strip leading ./ while fname[:2] == './': fname = fname[2:] @@ -604,11 +611,10 @@ def airmass2zendist(airmass): zenith distance in degrees """ - return np.rad2deg(np.arccos(1/airmass)) + return np.rad2deg(np.arccos(1 / airmass)) def convert_table_comments_to_dict(tbl): - comments_dict = {} if "comments" in tbl.meta: try: @@ -630,7 +636,6 @@ def convert_table_comments_to_dict(tbl): def change_table_entry(tbl, col_name, new_val, old_val=None, position=None): - offending_col = list(tbl[col_name].data) if old_val is not None: @@ -747,7 +752,6 @@ def quantify(item, unit): return quant - def extract_type_from_unit(unit, unit_type): """ Extract ``astropy`` physical type from a compound unit @@ -767,11 +771,11 @@ def extract_type_from_unit(unit, unit_type): """ - unit = unit**1 + unit = unit ** 1 extracted_units = u.Unit("") for base, power in zip(unit._bases, unit._powers): - if unit_type == (base**abs(power)).physical_type: - extracted_units *= base**power + if unit_type == (base ** abs(power)).physical_type: + extracted_units *= base ** power new_unit = unit / extracted_units @@ -796,13 +800,13 @@ def extract_base_from_unit(unit, base_unit): """ - unit = unit**1 + unit = unit ** 1 extracted_units = u.Unit("") for base, power in zip(unit._bases, unit._powers): if base == base_unit: - extracted_units *= base**power + extracted_units *= base ** power - new_unit = unit * extracted_units**-1 + new_unit = unit * extracted_units ** -1 return new_unit, extracted_units @@ -892,7 +896,7 @@ def has_needed_keywords(header, suffix=""): """ keys = ["CDELT1", "CRVAL1", "CRPIX1"] return sum([key + suffix in header.keys() for key in keys]) == 3 and \ - "NAXIS1" in header.keys() + "NAXIS1" in header.keys() def stringify_dict(dic, ignore_types=(str, int, float)): @@ -986,8 +990,8 @@ def check_keys(input_dict, required_keys, action="error", all_any="all"): "".format(required_keys, input_dict.keys())) elif "warn" in action: logging.warning("One or more of the following keys missing " - "from input_dict: \n{} \n{}" - "".format(required_keys, input_dict.keys())) + "from input_dict: \n{} \n{}" + "".format(required_keys, input_dict.keys())) return keys_present @@ -1064,3 +1068,427 @@ def return_latest_github_actions_jobs_status(owner_name="AstarVienna", repo_name params_list += [params] return params_list + + +class GridAxis: + """Class defining a grid axis coordinates and transformations between pixel coordinates and real coordinates. + + Parameters + ---------- + n : int + The dimension of the axis. + x0 : float + The central coordinate - default `0`` + i0 : Optional[int] + The central pixel coordinate - default `n//2`. + dx : float + The pixel spacing - default `1.0`. + + Attributes + ---------- + n : int + The dimension of the axis. + x0 : float + The central corodinate. + i0 : Optional[int] + The central pixel coordinate. + dx : float + The pixel spacing. + + Methods + ------- + pos_to_pix + Transforms real reference frame position to pixel positions. + pix_to_pos + Transforms pixel positions to real reference frame positions. + pix_in_reference_frame + Transforms pixel positions from one reference frame to another. + """ + + def __init__(self, n: int, *, x0: float = 0, i0: Optional[int] = None, dx: float = 1.0): + """ + Parameters + ---------- + n : int + The dimension of the axis. + x0 : float + The central corodinate - default `0`` + i0 : Optional[int] + The central pixel coordinate - default `n//2`. + dx : float + The pixel spacing - default `1.0`. + """ + self.n = n + self.x0 = x0 + self.i0 = i0 if i0 else n // 2 + self.dx = dx + + def pos_to_pix(self, pos: Union[float, npt.NDArray]) -> npt.NDArray: + """Transforms real reference frame position to pixel positions. + + Parameters + ---------- + pos : + A single or an array of positions in the reference frame. + + Returns + ------- + npt.NDArray + The pixel position corresponding to the provided reference frame position. + """ + pos = np.array(pos, ndmin=1) + return (pos - self.x0) / self.dx + self.i0 + + def pix_to_pos(self, pix: Union[float, npt.NDArray]) -> npt.NDArray: + """Transforms pixel positions to real reference frame positions. + + Parameters + ---------- + pix : Union[float, npt.NDArray] + A single or an array of pixel positions. + + Returns + ------- + npt.NDArray + The position in the reference frame corresponding to the provided pixel positions. + """ + pix = np.array(pix, ndmin=1) + return self.x0 + (pix - self.i0) * self.dx + + def pix_in_reference_frame(self, + other: GridAxis, + pix: Optional[Union[float, npt.NDArray]] = None) -> npt.NDArray: + """Transforms pixel positions from one reference frame to another. + + Parameters + ---------- + other : GridAxis + The coordinate axis onto which the pixel coordinates are projected. + pix : Optional[Union[float, npt.NDArray]] + The pixel positions to be evaluated. If not provided, all pixel positions are evaluated. + + Returns + ------- + npt.NDArray + The pixel coordinates projected in the other reference frame. + """ + pix = np.array(pix, ndmin=1) if pix else np.arange(self.n) + return (self.x0 + (pix - self.i0) * self.dx - other.x0) / other.dx + other.i0 + + @property + def coors(self) -> npt.NDArray: + """Axis coordinates""" + return self.pix_to_pos(np.arange(self.n)) + + @property + def pix(self) -> npt.NDArray: + """Axis pixel coordinates""" + return np.arange(self.n) - self.i0 + + @property + def shape(self) -> Tuple[float]: + """Axis shape""" + return self.n, + + +class GridCoordinates: + """Class defining the coordinates of a grid and transformations between pixel coordinates and real coordinates. + + Parameters + ---------- + m : int + The first dimension of the grid. + n : int + The second dimension of the grid. + x0 : float + The central coordinate of the first dimension - default `0`. + y0 : float + The central coordinate of the second dimension - default `0`. + i0 : Optional[int] + The central pixel coordinate of the first dimension - default `m//2`. + j0 : Optional[int] + The central pixel coordinate of the second dimension - default `n//2`. + dx : float + The pixel spacing of the first dimension - default `1.0`. + dy : float + The pixel spacing of the second dimension - default `1.0`. + + Attributes + ---------- + None + + Methods + ------- + pos_to_pix + Transforms real reference frame positions to pixel positions. + pix_to_pos + Transforms pixel positions to real reference frame positions. + pix_in_reference_frame + Transforms pixel positions from one reference frame to another. + """ + + def __init__(self, + m: int, n: int, + *, + x0: float = 0, y0: float = 0, + i0: Optional[int] = None, j0: Optional[int] = None, + dx: float = 1.0, dy: float = 1.0): + """ + Parameters + ---------- + m : int + The first dimension of the grid. + n : int + The second dimension of the grid. + x0 : float + The central coordinate of the first dimension - default `0`. + y0 : float + The central coordinate of the second dimension - default `0`. + i0 : Optional[int] + The central pixel coordinate of the first dimension - default `m//2`. + j0 : Optional[int] + The central pixel coordinate of the second dimension - default `n//2`. + dx : float + The pixel spacing of the first dimension - default `1.0`. + dy : float + The pixel spacing of the second dimension - default `1.0`. + """ + i0 = i0 if i0 else m // 2 + j0 = j0 if j0 else n // 2 + self._x_axis = GridAxis(n=m, x0=x0, i0=i0, dx=dx) + self._y_axis = GridAxis(n=n, x0=y0, i0=j0, dx=dy) + + @property + def x(self) -> GridAxis: + """X GridAxis object""" + return self._x_axis + + @property + def y(self) -> GridAxis: + """Y GridAxis object""" + return self._y_axis + + @property + def shape(self) -> Tuple: + """Grid shape""" + return self.x.shape + self.y.shape + + @property + def i(self) -> npt.NDArray: + """X pixel coordinates""" + return self.x.pix + + @property + def j(self) -> npt.NDArray: + """Y pixel coordinates""" + return self.y.pix + + def pos_to_pix(self, x: Union[float, npt.NDArray], y: Union[float, npt.NDArray]) -> npt.NDArray: + """Transforms real reference frame positions to pixel positions. + + Parameters + ---------- + x : float + A single or an array of X positions in the reference frame. + y : float + A single or an array of Y position in the reference frame. + + Returns + ------- + npt.NDArray + The pixel position corresponding to the provided reference frame position. + """ + return np.array([self.x.pos_to_pix(x), self.y.pos_to_pix(y)]) + + def pix_to_pos(self, i: Union[float, npt.NDArray], j: Union[float, npt.NDArray]) -> npt.NDArray: + """Transforms pixel positions to real reference frame positions. + + Parameters + ---------- + i : Union[float, npt.NDArray] + A single or an array of X pixel positions. + j : Union[float, npt.NDArray] + A single or an array of Y pixel positions. + + Returns + ------- + npt.NDArray + The position in the reference frame corresponding to the provided pixel positions. + """ + return np.array([self.x.pix_to_pos(i), self.y.pix_to_pos(j)]) + + def pix_in_reference_frame(self, + other: GridCoordinates, + i: Optional[Union[float, npt.NDArray]] = None, + j: Optional[Union[float, npt.NDArray]] = None, + grid: bool = False) -> List[npt.NDArray]: + """Transforms pixel positions from one reference frame to another. + + Parameters + ---------- + other : GridCoordinates + The coordinate reference frame onto which the pixel coordinates are projected. + i : Optional[Union[float, npt.NDArray]] + The X axis pixel positions to be evaluated. If not provided, all pixel positions are evaluated. + j : Optional[Union[float, npt.NDArray]] + The Y axis pixel positions to be evaluated. If not provided, all pixel positions are evaluated. + grid : bool + If `True` the return is a meshgrid instead of only the axis coordinates. + + Returns + ------- + List[npt.NDArray] + The pixel coordinates or a meshgrid of pixel coordinates projected in the other reference frame. + """ + i_proj = self.x.pix_in_reference_frame(other.x, i) + j_proj = self.y.pix_in_reference_frame(other.y, j) + + if grid: + return np.meshgrid(i_proj, j_proj) + else: + return [i_proj, j_proj] + + +def rescale_array(array: npt.ArrayLike, + array_x: npt.ArrayLike, + array_y: npt.ArrayLike, + target_x: npt.ArrayLike, + target_y: npt.ArrayLike, + *, + normalize: bool = True, + method: str = "linear") -> npt.NDArray: + """Rescales an array to the desired axis coordinates. + + Parameters + ---------- + array : npt.ArrayLike + The input array to be rescaled. + array_x : npt.ArrayLike + The coordinates of the X axis of the input array. + array_y : npt.ArrayLike + The coordinates of the X axis of the input array. + target_x : npt.ArrayLike + The desired coordinates of the X axis after rescaling. + target_y : npt.ArrayLike + The desired coordinates of the X axis after rescaling. + normalize : bool + Whether to normalize the array by its sum. + method : str + Interpolation method (Default = "linear" : bi-linear interpolation) + + Returns + ------- + npt.NDArray + The rescaled array. + """ + rescaled_array = sp.interpolate.interpn(points=(array_x, array_y), + values=array, + xi=(target_x, target_y), + bounds_error=False, + fill_value=0, + method=method) + + if normalize: + rescaled_array *= np.sum(array) / np.sum(rescaled_array) + return rescaled_array + + +def rescale_array_grid(grid: npt.NDArray, + center_pixel: npt.NDArray, + origin_pixel_scale: float, + target_pixel_scale: float, + *, + method: str = "cubic") -> npt.NDArray: + """Rescales each array in a grid of arrays to the desired pixel scale. + + Parameters + ---------- + grid : npt.NDArray + An array with shape `(M, N, I, J)` defining a `MxN` grid of 2D arrays to be rescaled. + center_pixel : npt.NDArray + The pixel coordinates defining the center pixel position of each array + origin_pixel_scale : float + The pixel width of the input grid. + target_pixel_scale : float + The desired pixel width for the output grid. + method : int + Interpolation method (Default = "linear" : bi-linear interpolation) + + + Returns + ------- + npt.NDArray + A `MxN` grid of 2D arrays rescaled by interpolation to the desired pixel_scale. + """ + + # Get grid shape + grid_i, grid_j, array_i, array_j = grid.shape + + # Define origin coordinates + array_x = (np.arange(start=0, stop=array_i, step=1) - center_pixel[0]) * origin_pixel_scale + array_y = (np.arange(start=0, stop=array_j, step=1) - center_pixel[1]) * origin_pixel_scale + + # Define target coordinates + target_x_lim = np.max((abs(np.floor(array_x[0] / target_pixel_scale)), + abs(np.ceil(array_x[-1] / target_pixel_scale)))) + target_x_start = -target_x_lim * target_pixel_scale + target_x_stop = (target_x_lim + 0.001) * target_pixel_scale + target_x = np.arange(start=target_x_start, stop=target_x_stop, step=target_pixel_scale) + + target_y_lim = np.max((abs(np.floor(array_y[0] / target_pixel_scale)), + abs(np.ceil(array_y[-1] / target_pixel_scale)))) + target_y_start = -target_y_lim * target_pixel_scale + target_y_stop = (target_y_lim + 0.001) * target_pixel_scale + target_y = np.arange(start=target_y_start, stop=target_y_stop, step=target_pixel_scale) + + target_shape = (target_x.size, target_y.size) + + target_x, target_y = np.meshgrid(target_x, target_y, indexing='ij') + + # Flatten the input grid along the first two dimensions (i, j) + flat_input_grid = grid.reshape((-1, array_i, array_j)) + + # Use joblib to rescale the arrays in parallel + parallel = False + if parallel: + rescaled_arrays = Parallel(n_jobs=-1)(delayed(rescale_array)(array=array, + array_x=array_x, + array_y=array_y, + target_x=target_x, + target_y=target_y, + method=method) for array in flat_input_grid) + else: + rescaled_arrays = [rescale_array(array=array, + array_x=array_x, + array_y=array_y, + target_x=target_x, + target_y=target_y, + method=method) + for array in flat_input_grid] + + # Combine the zoomed arrays back into a single grid + output_grid = np.array(rescaled_arrays).reshape((grid_i, grid_j, *target_shape)) + return output_grid + + +def normalise_array_grid(grid: npt.ArrayLike) -> npt.NDArray: + """Normalizes each array in a grid of arrays by its sum. + + Parameters + ---------- + grid : npt.ArrayLike + An array with shape `(M, N, I, J)` defining a `MxN` grid of 2D arrays to be normalised. + + Returns + ------- + npt.NDArray + The input grid with each array normalized by its sum to unity. + """ + grid = np.array(grid) + + # Calculate the sum along the arrays + grid_sum = np.sum(grid, axis=(2, 3), keepdims=True) + + # Apply normalization factor + normalized_grid = grid / grid_sum + return normalized_grid diff --git a/setup.py b/setup.py index 7b06bc0d..399b0bfd 100644 --- a/setup.py +++ b/setup.py @@ -52,6 +52,8 @@ def setup_package(): "scipy>=1.0.0", "astropy>=2.0", "matplotlib>=1.5", + "numba", + "joblib", "docutils", "requests>=2.20", diff --git a/test-GridFieldVaryingPSF.py b/test-GridFieldVaryingPSF.py new file mode 100644 index 00000000..03686ff4 --- /dev/null +++ b/test-GridFieldVaryingPSF.py @@ -0,0 +1,103 @@ +import numpy as np +from matplotlib import pyplot as plt + +import scopesim as sim +from scopesim.effects.psfs import GridFieldVaryingPSF +import synphot +from astropy.table import Table +from astropy import units as u + +# sim.download_packages(["LFOA"]) +# sim.download_packages(["Armazones", "ELT", "MICADO", "MAORY"]) + +# Create sims and load mds +cmd = sim.UserCommands(use_instrument="MICADO", set_modes=["SCAO", "IMG_1.5mas"]) +micado = sim.OpticalTrain(cmd) +micado.cmds["!SIM.sub_pixel.flag"] = True + +for effect_name in ["full_detector_array", "micado_adc_3D_shift", + "micado_ncpas_psf", "relay_psf"]: + micado[effect_name].include = False + print(micado[effect_name]) + +micado["detector_window"].data["x_cen"] = 0 # [mm] distance from optical axis on the focal plane +micado["detector_window"].data["y_cen"] = 0 +micado["detector_window"].data["x_size"] = 256 # [pixel] width of detector +micado["detector_window"].data["y_size"] = 256 + +params = { + "filename": "psf_grid.fits", + "wave_key": "WAVELENG"} + +psf = GridFieldVaryingPSF(name="psf_grid", **params) +micado.optics_manager.add_effect(psf) + +# Object +pixel_scale = 0.0015 * u.arcsec / u.pixel +vega = synphot.spectrum.SourceSpectrum.from_vega() +sep = 64 +i = [*([-sep, 0, sep] * 3)] +j = [*([sep] * 3), *([0] * 3), *([-sep] * 3)] +x = [ii * u.pixel * pixel_scale for ii in i] +y = [jj * u.pixel * pixel_scale for jj in j] +ref = [0] * len(x) +weight = [*([10 ** -1] * 3), *([10 ** -2] * 3), *([10 ** -3] * 3)] +obj = Table(names=["x", "y", "ref", "weight"], + data=[x, y, ref, weight], + units=[u.arcsec, u.arcsec, None, None]) +src = sim.Source(table=obj, spectra=[vega]) + +# # Baseline +micado["psf_grid"].include = False +micado.observe(src, update=True) + +plt.figure(figsize=(8, 8)) +plt.imshow(micado.image_planes[0].data, origin="lower") +plt.title("Baseline") +plt.show() + + +# GridFieldVaryingPSF +micado["psf_grid"].include = True +print(micado.effects) + +micado.observe(src, update=True) +plt.figure(figsize=(8, 8)) +img = micado.image_planes[0].data +plt.imshow(img, origin="lower") +plt.title("Custom PSF") +plt.show() + + +s_img = 256 +n = 1000 +i = np.random.rand(n) * (s_img-1) - s_img // 2 +j = np.random.rand(n) * (s_img-1) - s_img // 2 +x = [ii * u.pixel * pixel_scale for ii in i] +y = [jj * u.pixel * pixel_scale for jj in j] +ref = [0] * len(x) +weight = [1e-2]*n +obj_rand = Table(names=["x", "y", "ref", "weight"], + data=[x, y, ref, weight], + units=[u.arcsec, u.arcsec, None, None]) +src_rand = sim.Source(table=obj_rand, spectra=[vega]) + +# Baseline (random) +micado["psf_grid"].include = False +micado.observe(src_rand, update=True) +micado[effect_name].include = False + +plt.figure(figsize=(8, 8)) +plt.imshow(micado.image_planes[0].data, origin="lower") +plt.title("Baseline (random)") +plt.show() + + +# GridFieldVaryingPSF (random) +micado["psf_grid"].include = True +micado.observe(src_rand, update=True) +plt.figure(figsize=(8, 8)) +img = micado.image_planes[0].data +plt.imshow(img, origin="lower") +plt.title("Custom PSF (random)") +plt.show() diff --git a/test-make_psf_grid.py b/test-make_psf_grid.py new file mode 100644 index 00000000..439153fa --- /dev/null +++ b/test-make_psf_grid.py @@ -0,0 +1,71 @@ +import numpy as np +from astropy.convolution import Gaussian2DKernel +from astropy.io import fits + +# Create PSF grid +n_psf = 11 +s_psf = 128 +epsf = [ + [Gaussian2DKernel(1 + 0.1*x + 0.1*y, x_size=s_psf, y_size=s_psf).array + for y in range(n_psf)] for x in range(n_psf)] +epsf = np.array(epsf) + +pix_size = 1.5 # mas +fov = pix_size * 256 + +grid_spacing = fov/(n_psf-1) +waveleng = 3.7 # um + +primary_hdr = fits.Header() +primary_hdr["SIMPLE"] = (True, "conforms to FITS standard") +primary_hdr["BITPIX"] = (8, "array data type") +primary_hdr["NAXIS"] = (0, "number of array dimensions") +primary_hdr["EXTEND"] = True +primary_hdr["FILETYPE"] = 'Point Spread Function (Grid)' +primary_hdr["AUTHOR"] = 'J. Aveiro' +primary_hdr["DATE"] = '2023' +primary_hdr["SOURCE"] = 'TEST' +primary_hdr["ORIGDATE"] = '2023' +primary_hdr["WAVELENG"] = (waveleng, "microns") +primary_hdr["PIXSIZE"] = (pix_size, "milliarcsec") +primary_hdr["XPOSITIO"] = (0.00000, "arcsec") +primary_hdr["YPOSITIO"] = (0.00000, "arcsec") + +image_hdr = fits.Header() +image_hdr["WAVELENG"] = (waveleng, "microns") +image_hdr["PIXSIZE"] = (pix_size, "milliarcsec") +image_hdr["NAXIS"] = 4 +image_hdr["NAXIS1"] = n_psf +image_hdr["NAXIS2"] = n_psf +image_hdr["NAXIS3"] = s_psf +image_hdr["NAXIS4"] = s_psf +image_hdr["PIXSCALE"] = (pix_size, "milliarcsec") +image_hdr["CDELT1"] = (grid_spacing, "[mas] Coordinate increment at reference point") +image_hdr["CDELT2"] = (grid_spacing, "[mas] Coordinate increment at reference point") +image_hdr["CDELT3"] = (pix_size, "[mas] Coordinate increment at reference point") +image_hdr["CDELT4"] = (pix_size, "[mas] Coordinate increment at reference point") +image_hdr["CTYPE1"] = ("LINEAR", "Coordinate type code") +image_hdr["CTYPE2"] = ("LINEAR", "Coordinate type code") +image_hdr["CTYPE3"] = ("LINEAR", "Coordinate type code") +image_hdr["CTYPE4"] = ("LINEAR", "Coordinate type code") +image_hdr["CUNIT1"] = ("mas", "Units of coordinate increment and value") +image_hdr["CUNIT2"] = ("mas", "Units of coordinate increment and value") +image_hdr["CUNIT3"] = ("mas", "Units of coordinate increment and value") +image_hdr["CUNIT4"] = ("mas", "Units of coordinate increment and value") +image_hdr["CRVAL1"] = (0.0, "[mas] Coordinate value at reference point") +image_hdr["CRVAL2"] = (0.0, "[mas] Coordinate value at reference point") +image_hdr["CRVAL3"] = (0.0, "[mas] Coordinate value at reference point") +image_hdr["CRVAL4"] = (0.0, "[mas] Coordinate value at reference point") +image_hdr["CRPIX1"] = (n_psf//2, "Grid coordinate of reference point") +image_hdr["CRPIX2"] = (n_psf//2, "Grid coordinate of reference point") +image_hdr["CRPIX3"] = (s_psf//2, "Pixel coordinate of reference point") +image_hdr["CRPIX4"] = (s_psf//2, "Pixel coordinate of reference point") + +# Construct FITS +primary_hdu = fits.PrimaryHDU(header=primary_hdr) +image_hdu = fits.ImageHDU(epsf, header=image_hdr) +hdul = fits.HDUList([primary_hdu, image_hdu]) + +# Save +filename = "psf_grid.fits" +hdul.writeto(filename, overwrite=True)