Skip to content

Commit

Permalink
♻️ refactor(frames): don't separate Coordinate constructor from base …
Browse files Browse the repository at this point in the history
…class
  • Loading branch information
nstarman committed Dec 29, 2024
1 parent e400275 commit 41481a3
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 103 deletions.
163 changes: 60 additions & 103 deletions src/coordinax/_src/frames/coordinate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


from textwrap import indent
from typing import Any, NoReturn
from typing import Any, ClassVar, NoReturn
from typing_extensions import override

import equinox as eqx
Expand Down Expand Up @@ -39,14 +39,6 @@ class AbstractCoordinate(AbstractVector):
#: `coordinax.frames.AbstractReferenceFrame` object.
frame: eqx.AbstractVar[AbstractReferenceFrame]

@classmethod
@dispatch # type: ignore[misc]
def from_(
cls: "type[AbstractCoordinate]", *args: Any, **kwargs: Any
) -> "AbstractCoordinate":
"""Construct a coordinate from other data."""
return super().from_(*args, **kwargs)

# ===============================================================
# Coordinate API

Expand All @@ -68,33 +60,10 @@ def to_frame(self, to_frame: AbstractReferenceFrame, /) -> "AbstractCoordinate":
)
"""
op = self.transform_op(to_frame)
op = self.frame.transform_op(to_frame)
new_data = op(self.data)
return replace(self, data=new_data, frame=to_frame)

# ===============================================================
# Frame API

def transform_op(self, to_frame: AbstractReferenceFrame, /) -> AbstractOperator:
"""Make a frame transform operator.
Examples
--------
>>> import coordinax as cx
>>> cicrs = cx.Coordinate(cx.CartesianPos3D.from_([1, 2, 3], "kpc"),
... cx.frames.ICRS())
>>> cicrs.transform_op(cx.frames.Galactocentric())
Pipe((
GalileanRotation(rotation=f32[3,3]),
GalileanSpatialTranslation(CartesianPos3D( ... )),
GalileanRotation(rotation=f32[3,3]),
VelocityBoost(CartesianVel3D( ... ))
))
"""
return self.frame.transform_op(to_frame)

# ===============================================================
# Vector API

Expand Down Expand Up @@ -125,7 +94,7 @@ def aval(self) -> NoReturn:
# ===============================================================
# Plum API

__faithful__ = True
__faithful__: ClassVar = True

# ===============================================================
# Python API
Expand Down Expand Up @@ -270,58 +239,6 @@ def _dimensionality(self) -> int:
# TODO: Space is currently not implemented.
return self.data._dimensionality() # noqa: SLF001

# ---------------------------------------------------------------
# Constructors

@classmethod
@AbstractCoordinate.from_.dispatch
def from_(
cls: "type[Coordinate]",
data: Space | AbstractPos,
frame: AbstractReferenceFrame,
/,
) -> "Coordinate":
"""Construct a coordinate from data and a frame.
Examples
--------
>>> import coordinax as cx
>>> data = cx.CartesianPos3D.from_([1, 2, 3], "kpc")
>>> cx.Coordinate.from_(data, cx.frames.ICRS())
Coordinate(
data=Space({ 'length': CartesianPos3D( ... ) }),
frame=ICRS()
)
"""
return cls(data=data, frame=frame)

@classmethod
@AbstractCoordinate.from_.dispatch
def from_(
cls: "type[Coordinate]",
data: Space | AbstractPos,
base_frame: AbstractReferenceFrame,
ops: AbstractOperator,
/,
) -> "Coordinate":
"""Construct a coordinate from data and a frame.
Examples
--------
>>> import coordinax as cx
>>> data = cx.CartesianPos3D.from_([1, 2, 3], "kpc")
>>> cx.Coordinate.from_(data, cx.frames.ICRS(), cx.ops.Identity())
Coordinate(
data=Space({ 'length': CartesianPos3D( ... ) }),
frame=TransformedReferenceFrame(base_frame=ICRS(), xop=Identity())
)
"""
return cls(data=data, frame=TransformedReferenceFrame(base_frame, ops))

# ===============================================================
# Vector API

Expand Down Expand Up @@ -363,7 +280,62 @@ def __getitem__(self: "Coordinate", index: str) -> AbstractVector:
return self.data[index]


##############################################################################
# ===============================================================
# Constructors


@dispatch
def vector(
cls: type[Coordinate],
data: Space | AbstractPos,
frame: AbstractReferenceFrame,
/,
) -> Coordinate:
"""Construct a coordinate from data and a frame.
Examples
--------
>>> import coordinax as cx
>>> data = cx.CartesianPos3D.from_([1, 2, 3], "kpc")
>>> cx.Coordinate.from_(data, cx.frames.ICRS())
Coordinate(
data=Space({ 'length': CartesianPos3D( ... ) }),
frame=ICRS()
)
"""
return cls(data=data, frame=frame)


@dispatch
def vector(
cls: type[Coordinate],
data: Space | AbstractPos,
base_frame: AbstractReferenceFrame,
ops: AbstractOperator,
/,
) -> Coordinate:
"""Construct a coordinate from data and a frame.
Examples
--------
>>> import coordinax as cx
>>> data = cx.CartesianPos3D.from_([1, 2, 3], "kpc")
>>> cx.Coordinate.from_(data, cx.frames.ICRS(), cx.ops.Identity())
Coordinate(
data=Space({ 'length': CartesianPos3D( ... ) }),
frame=TransformedReferenceFrame(base_frame=ICRS(), xop=Identity())
)
"""
frame = TransformedReferenceFrame(base_frame, ops)
return cls(data=data, frame=frame)


# ===============================================================
# Vector conversion


@dispatch # type: ignore[misc]
Expand All @@ -388,7 +360,7 @@ def vconvert(target: type[AbstractPos], w: Coordinate, /) -> Coordinate:
return replace(w, data=w.data.vconvert(target))


##############################################################################
# ===============================================================
# Transform operations


Expand Down Expand Up @@ -417,18 +389,3 @@ def call(self: AbstractOperator, x: Coordinate, /) -> Coordinate:
"""
return replace(x, data=self(x.data))


##############################################################################
# Math operations


# @register(jax.lax.add_p) # type: ignore[misc]
# def _add_crd_crd(w1: Coordinate, w2: Coordinate, /) -> Coordinate:
# """Add two coordinates."""
# # Transform w2 to w1's frame if necessary
# new2 = w2 if w1.frame == w2.frame else w2.to_frame(w1.frame)
# # Add the data
# data = w1.data + new2.data

# return replace(w1, data=data)
5 changes: 5 additions & 0 deletions src/coordinax/_src/vectors/base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,11 @@ def _dimensionality(self) -> int:
def from_(
cls: "type[AbstractVector]", *args: Any, **kwargs: Any
) -> "AbstractVector":
"""Create a vector from arguments.
See `coordinax.vector` for more information.
"""
return vector(cls, *args, **kwargs)

# ===============================================================
Expand Down

0 comments on commit 41481a3

Please sign in to comment.