diff --git a/src/coordinax/_src/operators/galilean/spatial_translation.py b/src/coordinax/_src/operators/galilean/spatial_translation.py index 3caab77f..d834e19a 100644 --- a/src/coordinax/_src/operators/galilean/spatial_translation.py +++ b/src/coordinax/_src/operators/galilean/spatial_translation.py @@ -23,6 +23,7 @@ from coordinax._src.vectors.d2 import CartesianPos2D from coordinax._src.vectors.d3 import CartesianPos3D from coordinax._src.vectors.d4 import FourVector +from coordinax._src.vectors.private_api import spatial_component ############################################################################## # Spatial Translations @@ -345,8 +346,15 @@ def call( # Translate the velocity (this operator will have no effect on the # velocity). # 1. convert to a Quantity in Cartesian coordinates. + qvec = spatial_component(qvec) + pvec = eqx.error_if( + pvec, + qvec._dimensionality() != pvec._dimensionality(), # noqa: SLF001 + "The position and velocity vectors must have the same dimensionality.", + ) + q = convert(qvec.vconvert(qvec._cartesian_cls), u.Quantity) # noqa: SLF001 - p = convert(pvec.vconvert(pvec._cartesian_cls, q), u.Quantity) # noqa: SLF001 + p = convert(pvec.vconvert(pvec._cartesian_cls, qvec), u.Quantity) # noqa: SLF001 # 1.5 cast to float dtype # TODO: more careful casting q, p = q.astype(float, copy=False), p.astype(float, copy=False) # 2. create the Jacobian of the operation on the position diff --git a/src/coordinax/_src/vectors/d4/spacetime.py b/src/coordinax/_src/vectors/d4/spacetime.py index badfc3cd..17c42a15 100644 --- a/src/coordinax/_src/vectors/d4/spacetime.py +++ b/src/coordinax/_src/vectors/d4/spacetime.py @@ -229,7 +229,7 @@ def __str__(self) -> str: # =============================================================== -# Constructors +# Vector API @dispatch # type: ignore[misc] @@ -280,10 +280,6 @@ def vector(cls: type[FourVector], obj: AbstractQuantity, /) -> FourVector: return cls(t=obj[..., 0] / c, q=obj[..., 1:], c=c) -# =============================================================== -# Vector Convert - - @dispatch # type: ignore[misc] def vconvert( spatial_target: type[AbstractPos3D], current: FourVector, /, **kwargs: Any @@ -304,8 +300,26 @@ def vconvert( return replace(current, q=vconvert(spatial_target, current.q, **kwargs)) +@dispatch # type: ignore[misc] +def spatial_component(x: FourVector, /) -> AbstractPos3D: + """Return the spatial component of the vector. + + Examples + -------- + >>> import unxt as u + >>> import coordinax as cx + + >>> w = cx.FourVector(t=u.Quantity(1, "s"), q=u.Quantity([1, 2, 3], "m")) + >>> print(spatial_component(w)) + + + """ + return x.q + + # =============================================================== -# Converters +# Plum API @conversion_method(type_from=FourVector, type_to=u.Quantity) # type: ignore[misc] diff --git a/src/coordinax/_src/vectors/private_api.py b/src/coordinax/_src/vectors/private_api.py new file mode 100644 index 00000000..e90260f5 --- /dev/null +++ b/src/coordinax/_src/vectors/private_api.py @@ -0,0 +1,21 @@ +"""Copyright (c) 2023 coordinax maintainers. All rights reserved.""" + +__all__ = ["spatial_component"] + +from typing import Any + +from plum import dispatch + +from .base import AbstractPos + + +@dispatch.abstract +def spatial_component(x: Any, /) -> Any: + """Return the spatial component of the vector.""" + raise NotImplementedError + + +@dispatch +def spatial_component(x: AbstractPos, /) -> AbstractPos: + """Return the spatial component of the vector.""" + return x