Skip to content

Commit

Permalink
Refactor handling of invalid exposure types.
Browse files Browse the repository at this point in the history
  • Loading branch information
mairanteodoro committed Sep 16, 2024
1 parent 0ef1c1d commit d49afef
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 122 deletions.
22 changes: 21 additions & 1 deletion romancal/tweakreg/tests/test_tweakreg.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from astropy import coordinates as coord
from astropy import table
from astropy import units as u
from astropy.modeling import models
from astropy.modeling import models, Model
from astropy.modeling.models import RotationSequence3D, Scale, Shift
from astropy.time import Time
from gwcs import coordinate_frames as cf
Expand Down Expand Up @@ -984,3 +984,23 @@ def test_parse_catfile_raises_error_on_invalid_content(tmp_path, catfile_line_co
trs._parse_catfile(catfile)

assert type(exec_info.value) == ValueError


@pytest.mark.parametrize(
"exposure_type",
["WFI_GRISM", "WFI_PRISM", "WFI_DARK", "WFI_FLAT", "WFI_WFSC"],
)
def test_tweakreg_skips_invalid_exposure_types(exposure_type, tmp_path, base_image):
"""Test that TweakReg updates meta.cal_step with tweakreg = COMPLETE."""
img1 = base_image(shift_1=1000, shift_2=1000)
img1.meta.exposure.type = exposure_type
img2 = base_image(shift_1=1000, shift_2=1000)
img2.meta.exposure.type = exposure_type
res = trs.TweakRegStep.call([img1, img2])

assert type(res) == ModelLibrary
with res:
for i, model in enumerate(res):
assert hasattr(model.meta.cal_step, "tweakreg")
assert model.meta.cal_step.tweakreg == "SKIPPED"
res.shelve(model, i, modify=False)
246 changes: 125 additions & 121 deletions romancal/tweakreg/tweakreg_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,142 +155,146 @@ def process(self, input):
self.expand_refcat = True

# build the catalogs for input images
imcats = []
with images:
for i, image_model in enumerate(images):
exposure_type = image_model.meta.exposure.type
if exposure_type != "WFI_IMAGE":
self.log.info("Skipping TweakReg for spectral exposure.")
image_model.meta.cal_step.tweakreg = "SKIPPED"
images.shelve(image_model)
return image_model

source_detection = getattr(image_model.meta, "source_detection", None)
if source_detection is None:
images.shelve(image_model, i, modify=False)
raise AttributeError(
"Attribute 'meta.source_detection' is missing. "
"Please either run SourceDetectionStep or provide a custom source catalog."
else:
source_detection = getattr(
image_model.meta, "source_detection", None
)
if source_detection is None:
images.shelve(image_model, i, modify=False)
raise AttributeError(
"Attribute 'meta.source_detection' is missing. "
"Please either run SourceDetectionStep or provide a custom source catalog."
)

try:
catalog = self.get_tweakreg_catalog(
source_detection, image_model, i
)
except AttributeError as e:
self.log.error(f"Failed to retrieve tweakreg_catalog: {e}")
images.shelve(image_model, i, modify=False)
raise AttributeError() from e

try:
for axis in ["x", "y"]:
# validate catalog columns
if axis not in catalog.colnames:
long_axis = f"{axis}centroid"
if long_axis in catalog.colnames:
catalog.rename_column(long_axis, axis)
else:
raise ValueError(
"'tweakreg' source catalogs must contain a header with "
"columns named either 'x' and 'y' or 'xcentroid' and 'ycentroid'."
)
except ValueError as e:
self.log.error(f"Failed to validate catalog columns: {e}")
images.shelve(image_model, i, modify=False)
raise ValueError() from e

filename = image_model.meta.filename
catalog = tweakreg.filter_catalog_by_bounding_box(
catalog, image_model.meta.wcs.bounding_box
)

try:
catalog = self.get_tweakreg_catalog(
source_detection, image_model, i
if self.save_abs_catalog:
output_name = os.path.join(
self.catalog_path, f"fit_{self.abs_refcat.lower()}_ref.ecsv"
)
catalog.write(
output_name, format=self.catalog_format, overwrite=True
)

image_model.meta["tweakreg_catalog"] = catalog.as_array()
nsources = len(catalog)
self.log.info(
f"Detected {nsources} sources in {filename}."
if nsources
else f"No sources found in {filename}."
)
except AttributeError as e:
self.log.error(f"Failed to retrieve tweakreg_catalog: {e}")
images.shelve(image_model, i, modify=False)
raise AttributeError() from e

try:
for axis in ["x", "y"]:
# validate catalog columns
if axis not in catalog.colnames:
long_axis = f"{axis}centroid"
if long_axis in catalog.colnames:
catalog.rename_column(long_axis, axis)
else:
raise ValueError(
"'tweakreg' source catalogs must contain a header with "
"columns named either 'x' and 'y' or 'xcentroid' and 'ycentroid'."
)
except ValueError as e:
self.log.error(f"Failed to validate catalog columns: {e}")
images.shelve(image_model, i, modify=False)
raise ValueError() from e

filename = image_model.meta.filename
catalog = tweakreg.filter_catalog_by_bounding_box(
catalog, image_model.meta.wcs.bounding_box
)

if self.save_abs_catalog:
output_name = os.path.join(
self.catalog_path, f"fit_{self.abs_refcat.lower()}_ref.ecsv"
# build image catalog
# catalog name
catalog_name = os.path.splitext(image_model.meta.filename)[0].strip(
"_- "
)
catalog.write(
output_name, format=self.catalog_format, overwrite=True
# catalog data
catalog_table = Table(image_model.meta.tweakreg_catalog)
catalog_table.meta["name"] = catalog_name

imcats.append(
tweakreg.construct_wcs_corrector(
wcs=image_model.meta.wcs,
refang=image_model.meta.wcsinfo,
catalog=catalog_table,
group_id=image_model.meta.group_id,
)
)

image_model.meta["tweakreg_catalog"] = catalog.as_array()
nsources = len(catalog)
self.log.info(
f"Detected {nsources} sources in {filename}."
if nsources
else f"No sources found in {filename}."
)
images.shelve(image_model, i)

# build image catalogs
imcats = []
with images:
for i, m in enumerate(images):
# catalog name
catalog_name = os.path.splitext(m.meta.filename)[0].strip("_- ")
# catalog data
catalog_table = Table(m.meta.tweakreg_catalog)
catalog_table.meta["name"] = catalog_name

imcats.append(
tweakreg.construct_wcs_corrector(
wcs=m.meta.wcs,
refang=m.meta.wcsinfo,
catalog=catalog_table,
group_id=m.meta.group_id,
)
)
images.shelve(m, i, modify=False)
# run alignment only if it was possible to build image catalogs
if len(imcats):
if getattr(images, "group_indices", None) and len(images.group_indices) > 1:
self.do_relative_alignment(imcats)

if getattr(images, "group_indices", None) and len(images.group_indices) > 1:
self.do_relative_alignment(imcats)
if self.abs_refcat in SINGLE_GROUP_REFCAT:
self.do_absolute_alignment(ref_image, imcats)

if self.abs_refcat in SINGLE_GROUP_REFCAT:
self.do_absolute_alignment(ref_image, imcats)

# finalize step
with images:
for i, imcat in enumerate(imcats):
image_model = images.borrow(i)
image_model.meta.cal_step["tweakreg"] = "COMPLETE"
# remove source catalog
del image_model.meta["tweakreg_catalog"]

# retrieve fit status and update wcs if fit is successful:
if "SUCCESS" in imcat.meta.get("fit_info")["status"]:
# Update/create the WCS .name attribute with information
# on this astrometric fit as the only record that it was
# successful:

# NOTE: This .name attrib agreed upon by the JWST Cal
# Working Group.
# Current value is merely a place-holder based
# on HST conventions. This value should also be
# translated to the FITS WCSNAME keyword
# IF that is what gets recorded in the archive
# for end-user searches.
imcat.wcs.name = f"FIT-LVL2-{self.abs_refcat}"

# serialize object from tweakwcs
# (typecasting numpy objects to python types so that it doesn't cause an
# issue when saving datamodel to ASDF)
wcs_fit_results = {
k: v.tolist() if isinstance(v, (np.ndarray, np.bool_)) else v
for k, v in imcat.meta["fit_info"].items()
}
# add fit results and new WCS to datamodel
image_model.meta["wcs_fit_results"] = wcs_fit_results
# remove unwanted keys from WCS fit results
for k in [
"eff_minobj",
"matched_ref_idx",
"matched_input_idx",
"fit_RA",
"fit_DEC",
"fitmask",
]:
del image_model.meta["wcs_fit_results"][k]

image_model.meta.wcs = imcat.wcs
images.shelve(image_model, i)
# finalize step
with images:
for i, imcat in enumerate(imcats):
image_model = images.borrow(i)
image_model.meta.cal_step["tweakreg"] = "COMPLETE"
# remove source catalog
del image_model.meta["tweakreg_catalog"]

# retrieve fit status and update wcs if fit is successful:
if "SUCCESS" in imcat.meta.get("fit_info")["status"]:
# Update/create the WCS .name attribute with information
# on this astrometric fit as the only record that it was
# successful:

# NOTE: This .name attrib agreed upon by the JWST Cal
# Working Group.
# Current value is merely a place-holder based
# on HST conventions. This value should also be
# translated to the FITS WCSNAME keyword
# IF that is what gets recorded in the archive
# for end-user searches.
imcat.wcs.name = f"FIT-LVL2-{self.abs_refcat}"

# serialize object from tweakwcs
# (typecasting numpy objects to python types so that it doesn't cause an
# issue when saving datamodel to ASDF)
wcs_fit_results = {
k: (
v.tolist()
if isinstance(v, (np.ndarray, np.bool_))
else v
)
for k, v in imcat.meta["fit_info"].items()
}
# add fit results and new WCS to datamodel
image_model.meta["wcs_fit_results"] = wcs_fit_results
# remove unwanted keys from WCS fit results
for k in [
"eff_minobj",
"matched_ref_idx",
"matched_input_idx",
"fit_RA",
"fit_DEC",
"fitmask",
]:
del image_model.meta["wcs_fit_results"][k]

image_model.meta.wcs = imcat.wcs
images.shelve(image_model, i)

return images

Expand Down

0 comments on commit d49afef

Please sign in to comment.