Skip to content

Commit

Permalink
Add conveter for asymmetric mapping
Browse files Browse the repository at this point in the history
  • Loading branch information
Cadair committed Oct 16, 2023
1 parent de9b3d0 commit f4eaf1b
Show file tree
Hide file tree
Showing 7 changed files with 130 additions and 67 deletions.
54 changes: 33 additions & 21 deletions dkist/io/asdf/converters/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,8 @@ def select_tag(self, obj, tags, ctx):
from dkist.wcs.models import (InverseVaryingCelestialTransform,
InverseVaryingCelestialTransform2D,
InverseVaryingCelestialTransform3D,
InverseVaryingCelestialTransformSlit,
InverseVaryingCelestialTransformSlit2D,
InverseVaryingCelestialTransformSlit3D,
VaryingCelestialTransform, VaryingCelestialTransform2D,
VaryingCelestialTransform3D, VaryingCelestialTransformSlit,
VaryingCelestialTransformSlit2D,
VaryingCelestialTransformSlit3D)
VaryingCelestialTransform3D)

if isinstance(
obj,
Expand All @@ -49,20 +44,6 @@ def select_tag(self, obj, tags, ctx):
InverseVaryingCelestialTransform3D)
):
return "asdf://dkist.nso.edu/tags/inverse_varying_celestial_transform-1.0.0"
elif isinstance(
obj,
(VaryingCelestialTransformSlit,
VaryingCelestialTransformSlit2D,
VaryingCelestialTransformSlit3D)
):
return "asdf://dkist.nso.edu/tags/varying_celestial_transform_slit-1.0.0"
elif isinstance(
obj,
(InverseVaryingCelestialTransformSlit,
InverseVaryingCelestialTransformSlit2D,
InverseVaryingCelestialTransformSlit3D)
):
return "asdf://dkist.nso.edu/tags/inverse_varying_celestial_transform_slit-1.0.0"
else:
raise ValueError(f"Unsupported object: {obj}") # pragma: no cover

Expand All @@ -73,6 +54,11 @@ def from_yaml_tree_transform(self, node, tag, ctx):
if "inverse_varying_celestial_transform" in tag:
inverse = True

# Support reading files with the old Slit classes in them
slit = None
if "_slit" in tag:
slit = 1

return varying_celestial_transform_from_tables(
crpix=node["crpix"],
cdelt=node["cdelt"],
Expand All @@ -81,7 +67,7 @@ def from_yaml_tree_transform(self, node, tag, ctx):
pc_table=node["pc_table"],
projection=node["projection"],
inverse=inverse,
slit="_slit" in tag
slit=slit,
)

def to_yaml_tree_transform(self, model, tag, ctx):
Expand Down Expand Up @@ -161,3 +147,29 @@ def from_yaml_tree_transform(self, node, tag, ctx):
from dkist.wcs.models import Ravel

return Ravel(node["array_shape"], order=node["order"])


class AsymmetricMappingConverter(TransformConverterBase):
"""
ASDF serialization support for Ravel
"""

tags = [
"asdf://dkist.nso.edu/tags/asymmetric_mapping_model-1.0.0"
]

types = ["dkist.wcs.models.AsymmetricMapping"]

def to_yaml_tree_transform(self, model, tag, ctx):
node = {
"forward_mapping": model.forward_mapping,
"backward_mapping": model.backward_mapping,
"forward_n_inputs": model.forward_n_inputs,
"backward_n_inputs": model.backward_n_inputs,
}
return node

def from_yaml_tree_transform(self, node, tag, ctx):
from dkist.wcs.models import AsymmetricMapping

return AsymmetricMapping(**node)
22 changes: 22 additions & 0 deletions dkist/io/asdf/resources/manifests/dkist-wcs-1.2.0.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
%YAML 1.1
---
id: asdf://dkist.nso.edu/dkist/manifests/dkist-wcs-1.2.0
extension_uri: asdf://dkist.nso.edu/dkist/extensions/dkist-wcs-1.2.0
title: DKIST WCS extension
description: ASDF schemas and tags for models and WCS related classes.

tags:
- schema_uri: "asdf://dkist.nso.edu/schemas/varying_celestial_transform-1.0.0"
tag_uri: "asdf://dkist.nso.edu/tags/varying_celestial_transform-1.0.0"
- schema_uri: "asdf://dkist.nso.edu/schemas/varying_celestial_transform-1.0.0"
tag_uri: "asdf://dkist.nso.edu/tags/inverse_varying_celestial_transform-1.0.0"
- schema_uri: "asdf://dkist.nso.edu/schemas/coupled_compound_model-1.0.0"
tag_uri: "asdf://dkist.nso.edu/tags/coupled_compound_model-1.0.0"
- schema_uri: "asdf://dkist.nso.edu/schemas/varying_celestial_transform-1.0.0"
tag_uri: "asdf://dkist.nso.edu/tags/varying_celestial_transform_slit-1.0.0"
- schema_uri: "asdf://dkist.nso.edu/schemas/varying_celestial_transform-1.0.0"
tag_uri: "asdf://dkist.nso.edu/tags/inverse_varying_celestial_transform_slit-1.0.0"
- schema_uri: "asdf://dkist.nso.edu/schemas/ravel_model-1.0.0"
tag_uri: "asdf://dkist.nso.edu/tags/ravel_model-1.0.0"
- schema_uri: "asdf://dkist.nso.edu/schemas/asymmetric_mapping_model-1.0.0"
tag_uri: "asdf://dkist.nso.edu/tags/asymmetric_mapping_model-1.0.0"
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
%YAML 1.1
---
$schema: "http://stsci.edu/schemas/yaml-schema/draft-01"
id: "asdf://dkist.nso.edu/schemas/asymmetric_mapping_model-1.0.0"
title: >
Reorder, add and drop axes with different mappings in forward and reverse transforms.
definitions:
mapping:
type: array
items:
type: integer

allOf:
- $ref: "http://stsci.edu/schemas/asdf/transform/transform-1.2.0"
- properties:
forward_n_inputs:
description: |
Explicitly set the number of input axes in the forward direction.
type: integer
backward_n_inputs:
description: |
Explicitly set the number of input axes in the backward direction.
type: integer
forward_mapping:
$ref: "#/definitions/mapping"
backward_mapping:
$ref: "#/definitions/mapping"
required: [forward_mapping, backward_mapping, forward_n_inputs, backward_n_inputs]
...
8 changes: 4 additions & 4 deletions dkist/io/asdf/resources/schemas/ravel_model-1.0.0.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ description:
allOf:
- $ref: "http://stsci.edu/schemas/asdf/transform/transform-1.2.0"
- properties:
array_shape:
type: array
order:
type: string
array_shape:
type: array
order:
type: string
required: [array_shape, order]
38 changes: 18 additions & 20 deletions dkist/io/asdf/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,9 @@
from astropy.coordinates.matrix_utilities import rotation_matrix
from astropy.modeling import CompoundModel

from dkist.wcs.models import (CoupledCompoundModel,
InverseVaryingCelestialTransform,
InverseVaryingCelestialTransform2D,
Ravel,
Unravel,
VaryingCelestialTransform,
VaryingCelestialTransform2D,
from dkist.wcs.models import (CoupledCompoundModel, InverseVaryingCelestialTransform,
InverseVaryingCelestialTransform2D, Ravel, Unravel,
VaryingCelestialTransform, VaryingCelestialTransform2D,
varying_celestial_transform_from_tables)


Expand Down Expand Up @@ -75,29 +71,31 @@ def test_roundtrip_vct_slit():
for a in np.linspace(0, 90, 10)] * u.pix

vct = varying_celestial_transform_from_tables(crpix=(5, 5) * u.pix,
cdelt=(1, 1) * u.arcsec/u.pix,
crval_table=(0, 0) * u.arcsec,
pc_table=varying_matrix_lt,
lon_pole=180 * u.deg)
cdelt=(1, 1) * u.arcsec/u.pix,
crval_table=(0, 0) * u.arcsec,
pc_table=varying_matrix_lt,
lon_pole=180 * u.deg,
slit=0)
new_vct = roundtrip_object(vct)
assert isinstance(new_vct, VaryingCelestialTransformSlit)
assert isinstance(new_vct, CompoundModel)
new_ivct = roundtrip_object(vct.inverse)
assert isinstance(new_ivct, InverseVaryingCelestialTransformSlit)
assert isinstance(new_ivct, CompoundModel)


def test_roundtrip_vct_slit2d():
varying_matrix_lt = [rotation_matrix(a)[:2, :2] for a in np.linspace(0, 90, 15)] * u.pix
varying_matrix_lt = varying_matrix_lt.reshape((5, 3, 2, 2))

vct = VaryingCelestialTransformSlit2D(crpix=(5, 5) * u.pix,
cdelt=(1, 1) * u.arcsec/u.pix,
crval_table=(0, 0) * u.arcsec,
pc_table=varying_matrix_lt,
lon_pole=180 * u.deg)
vct = varying_celestial_transform_from_tables(crpix=(5, 5) * u.pix,
cdelt=(1, 1) * u.arcsec/u.pix,
crval_table=(0, 0) * u.arcsec,
pc_table=varying_matrix_lt,
lon_pole=180 * u.deg,
slit=0)
new_vct = roundtrip_object(vct)
assert isinstance(new_vct, VaryingCelestialTransformSlit2D)
assert isinstance(new_vct, CompoundModel)
new_ivct = roundtrip_object(vct.inverse)
assert isinstance(new_ivct, InverseVaryingCelestialTransformSlit2D)
assert isinstance(new_ivct, CompoundModel)


def test_coupled_compound_model():
Expand Down
32 changes: 16 additions & 16 deletions dkist/wcs/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC
from typing import Union, Iterable
from typing import Union, Literal, Iterable

import numpy as np

Expand All @@ -23,7 +23,7 @@
"BaseVaryingCelestialTransform",
"BaseVaryingCelestialTransform2D",
"generate_celestial_transform",
"BiDirectionalMapping",
"AsymmetricMapping",
"varying_celestial_transform_from_tables",
"Ravel",
"Unravel",
Expand Down Expand Up @@ -633,18 +633,18 @@ def _calculate_separability_matrix(self):
return matrix


class BiDirectionalMapping(m.Mapping):
class AsymmetricMapping(m.Mapping):
"""
A Mapping which uses a different mapping for the forward and backward directions.
"""
def __init__(
self,
forward_mapping,
backward_mapping,
forward_n_inputs=None,
backward_n_inputs=None,
name=None,
meta=None,
self,
forward_mapping,
backward_mapping,
forward_n_inputs=None,
backward_n_inputs=None,
name=None,
meta=None,
):
super().__init__(forward_mapping, n_inputs=forward_n_inputs, name=name, meta=meta)
self.backward_mapping = backward_mapping
Expand Down Expand Up @@ -673,8 +673,8 @@ def varying_celestial_transform_from_tables(
crval_table: Union[Iterable[float], u.Quantity],
lon_pole: Union[float, u.Quantity] = None,
projection: Model = m.Pix2Sky_TAN(),
inverse=False,
slit=None,
inverse: bool = False,
slit: Union[None, Literal[0, 1]] = None,
) -> BaseVaryingCelestialTransform:
"""
Generate a `.BaseVaryingCelestialTransform` based on the dimensionality of the tables.
Expand Down Expand Up @@ -706,10 +706,10 @@ def varying_celestial_transform_from_tables(
mapping = list(range(table_d + 2 - 1))
mapping.insert(2, slit)
backward_mapping = [[1, 0][slit]]
transform = BiDirectionalMapping(forward_mapping=mapping,
backward_mapping=backward_mapping,
backward_n_inputs=transform.inverse.n_outputs,
name="SlitMapping") | transform
transform = AsymmetricMapping(forward_mapping=mapping,
backward_mapping=backward_mapping,
backward_n_inputs=transform.inverse.n_outputs,
name="SlitMapping") | transform
return transform


Expand Down
13 changes: 7 additions & 6 deletions dkist/wcs/tests/test_coupled_compound_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
from astropy.modeling.separable import separability_matrix

from dkist.wcs.models import (CoupledCompoundModel, VaryingCelestialTransform,
VaryingCelestialTransform2D, VaryingCelestialTransformSlit,
VaryingCelestialTransformSlit2D)
VaryingCelestialTransform2D, varying_celestial_transform_from_tables)


@pytest.fixture
Expand Down Expand Up @@ -137,9 +136,10 @@ def test_coupled_slit_no_repeat(linear_time):
kwargs = dict(crpix=(5, 5) * u.pix,
cdelt=(1, 1) * u.arcsec/u.pix,
crval_table=(0, 0) * u.arcsec,
lon_pole=180 * u.deg)
lon_pole=180 * u.deg,
slit=1)

vct_slit = VaryingCelestialTransformSlit(pc_table=pc_table, **kwargs)
vct_slit = varying_celestial_transform_from_tables(pc_table=pc_table, **kwargs)

tfrm = CoupledCompoundModel("&", vct_slit, linear_time, shared_inputs=1)
pixel = (0*u.pix, 4*u.pix)
Expand All @@ -155,9 +155,10 @@ def test_coupled_slit_with_repeat(linear_time):
kwargs = dict(crpix=(5, 5) * u.pix,
cdelt=(1, 1) * u.arcsec/u.pix,
crval_table=(0, 0) * u.arcsec,
lon_pole=180 * u.deg)
lon_pole=180 * u.deg,
slit=1)

vct_slit = VaryingCelestialTransformSlit2D(pc_table=pc_table, **kwargs)
vct_slit = varying_celestial_transform_from_tables(pc_table=pc_table, **kwargs)

tfrm = CoupledCompoundModel("&", vct_slit, linear_time & linear_time, shared_inputs=2)
pixel = (0*u.pix, 0*u.pix, 0*u.pix)
Expand Down

0 comments on commit f4eaf1b

Please sign in to comment.