Skip to content

Commit

Permalink
JP-3588: Use Pastasoss datamodel for NIRISS SOSS transform solution (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
tapastro authored Sep 20, 2024
1 parent 4d60e7a commit 40a019e
Show file tree
Hide file tree
Showing 8 changed files with 693 additions and 839 deletions.
6 changes: 6 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,12 @@ emicorr

- Removed unnecessary copies, and created a single copy at step.py level. [#8676]

extract_1d
----------

- Updated NIRISS SOSS extraction to utilize ``pastasoss``
rotation solution. [#8763]

first_frame
-----------

Expand Down
5 changes: 0 additions & 5 deletions docs/jwst/extract_1d/arguments.rst
Original file line number Diff line number Diff line change
Expand Up @@ -169,11 +169,6 @@ The ``extract_1d`` step has the following step-specific arguments.
used when soss_wave_grid is not provided to make sure the computation time or the memory
used stays reasonable. Default value is 20000.

``--soss_transform``
This is a NIRISS-SOSS algorithm-specific parameter; this defines a rotation to
apply to the reference files to match the observation. It should be specified as
a list of three floats, with default values of None.

``--soss_tikfac``
This is a NIRISS-SOSS algorithm-specific parameter; this is the regularization
factor used in the SOSS extraction. If not specified, ATOCA will calculate a
Expand Down
10 changes: 3 additions & 7 deletions jwst/extract_1d/extract_1d_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,14 +174,13 @@ class Extract1dStep(Step):
soss_estimate = input_file(default = None) # Estimate used to generate the wavelength grid
soss_rtol = float(default=1.0e-4) # Relative tolerance needed on a pixel model
soss_max_grid_size = integer(default=20000) # Maximum grid size, if wave_grid not specified
soss_transform = list(default=None, min=3, max=3) # rotation applied to the ref files to match observation.
soss_tikfac = float(default=None) # regularization factor for NIRISS SOSS extraction
soss_width = float(default=40.) # aperture width used to extract the 1D spectrum from the de-contaminated trace.
soss_bad_pix = option("model", "masking", default="masking") # method used to handle bad pixels
soss_modelname = output_file(default = None) # Filename for optional model output of traces and pixel weights
"""

reference_file_types = ['extract1d', 'apcorr', 'wavemap', 'spectrace', 'specprofile', 'speckernel']
reference_file_types = ['extract1d', 'apcorr', 'pastasoss', 'specprofile', 'speckernel']

def process(self, input):
"""Execute the step.
Expand Down Expand Up @@ -432,8 +431,7 @@ def process(self, input):
return input_model

# Load reference files.
spectrace_ref_name = self.get_reference_file(input_model, 'spectrace')
wavemap_ref_name = self.get_reference_file(input_model, 'wavemap')
pastasoss_ref_name = self.get_reference_file(input_model, 'pastasoss')
specprofile_ref_name = self.get_reference_file(input_model, 'specprofile')
speckernel_ref_name = self.get_reference_file(input_model, 'speckernel')

Expand All @@ -444,7 +442,6 @@ def process(self, input):
soss_kwargs['tikfac'] = self.soss_tikfac
soss_kwargs['width'] = self.soss_width
soss_kwargs['bad_pix'] = self.soss_bad_pix
soss_kwargs['transform'] = self.soss_transform
soss_kwargs['subtract_background'] = self.subtract_background
soss_kwargs['rtol'] = self.soss_rtol
soss_kwargs['max_grid_size'] = self.soss_max_grid_size
Expand All @@ -458,8 +455,7 @@ def process(self, input):
# Run the extraction.
result, ref_outputs, atoca_outputs = soss_extract.run_extract1d(
input_model,
spectrace_ref_name,
wavemap_ref_name,
pastasoss_ref_name,
specprofile_ref_name,
speckernel_ref_name,
subarray,
Expand Down
74 changes: 0 additions & 74 deletions jwst/extract_1d/soss_extract/atoca_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,80 +252,6 @@ def get_wv_map_bounds(wave_map, dispersion_axis=1):
return wave_top, wave_bottom


def check_dispersion_direction(wave_map, dispersion_axis=1, dwv_sign=-1):
"""Check that the dispersion axis is increasing in the good direction
given by `dwv_sign``
Parameters
----------
wave_map : array[float]
2d-map of the pixel central wavelength
dispersion_axis : int, optional
Which axis is the dispersion axis (0 or 1)
dwv_sign : int, optional
Direction of increasing wavelengths (-1 or 1)
Returns
-------
bool_map : array[bool]
Boolean 2d map of the valid dispersion direction, same shape as `wave_map`
"""

# Estimate the direction of increasing wavelength
wave_left, wave_right = get_wv_map_bounds(wave_map, dispersion_axis=dispersion_axis)
dwv = wave_right - wave_left

# Return bool map of pixels following the good direction
bool_map = (dwv_sign * dwv >= 0)
# The bad value could be from left or right so mask both
bool_map &= np.roll(bool_map, 1, axis=dispersion_axis)

return bool_map


def mask_bad_dispersion_direction(wave_map, n_max=10, fill_value=0, dispersion_axis=1, dwv_sign=-1):
"""Change value of the pixels in `wave_map` that do not follow the
general dispersion direction.
Parameters
----------
wave_map : array[float]
2d-map of the pixel central wavelength
n_max : int
Maximum number of iterations
fill_value : float
Value use to replace pixels that do not follow the dispersion direction
dispersion_axis : int, optional
Which axis is the dispersion axis (0 or 1)
dwv_sign : int, optional
Direction of increasing wavelengths (-1 or 1)
Returns
-------
wave_map : array[float]
The corrected wave_map.
convergence flag : bool
Boolean set to True if all the pixels are now valid, False otherwise.
"""
# Do not modify the input
wave_map = wave_map.copy()

# Make the correction iteratively
for i_try in range(n_max):
# Check which pixels are good
is_good_direction = check_dispersion_direction(wave_map, dispersion_axis, dwv_sign)
# Stop iteration if all good, or apply correction where needed.
if is_good_direction.all():
convergence_flag = True
break
else:
wave_map[~is_good_direction] = fill_value
else:
# Did not succeed! :(
convergence_flag = False

return wave_map, convergence_flag


def oversample_grid(wave_grid, n_os=1):
"""Create an oversampled version of the input 1D wavelength grid.
Expand Down
Loading

0 comments on commit 40a019e

Please sign in to comment.