Skip to content

Commit

Permalink
🎨 style(vecs): improve base
Browse files Browse the repository at this point in the history
  • Loading branch information
nstarman committed Jan 4, 2025
1 parent 91d509c commit daba406
Showing 1 changed file with 13 additions and 10 deletions.
23 changes: 13 additions & 10 deletions src/coordinax/_src/vectors/base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

__all__ = ["AbstractVector"]

import math
from abc import abstractmethod
from collections.abc import Callable, Mapping
from types import MappingProxyType
Expand Down Expand Up @@ -222,7 +223,7 @@ def aval(self) -> jax.core.ShapedArray:
# ===============================================================
# Array API

def __array_namespace__(self) -> "ArrayAPINamespace":
def __array_namespace__(self) -> Any:
"""Return the array API namespace.
Here we return the `quaxed.numpy` module, which is a drop-in replacement
Expand Down Expand Up @@ -333,7 +334,7 @@ def size(self) -> int:
4
"""
return int(jnp.prod(jnp.asarray(self.shape)))
return int(math.prod(self.shape))

@property
def T(self) -> "Self": # noqa: N802
Expand Down Expand Up @@ -476,8 +477,16 @@ def __eq__(self: "AbstractVector", other: object) -> Any:
>>> vec == 2
False
Positions are covered by a separate dispatch. So here we show velocities
and accelerations:
And positions.
>>> q = cx.vecs.CylindricalPos(rho=u.Quantity([1.0, 2.0], "kpc"),
... phi=u.Quantity([0.0, 0.2], "rad"),
... z=u.Quantity(0.0, "kpc"))
>>> q == q
Array([ True, True], dtype=bool)
Most positions are covered by a separate dispatch. So here we show
velocities and accelerations:
>>> vel1 = cx.vecs.CartesianVel1D(u.Quantity([1, 2, 3], "km/s"))
>>> vel2 = cx.vecs.CartesianVel1D(u.Quantity([1, 0, 3], "km/s"))
Expand Down Expand Up @@ -534,12 +543,6 @@ def __eq__(self: "AbstractVector", other: object) -> Any:
>>> vel1 == vel2
Array([ True, False], dtype=bool)
>>> q = cx.vecs.CylindricalPos(rho=u.Quantity([1.0, 2.0], "kpc"),
... phi=u.Quantity([0.0, 0.2], "rad"),
... z=u.Quantity(0.0, "kpc"))
>>> q == q
Array([ True, True], dtype=bool)
"""
if type(other) is not type(self):
return NotImplemented
Expand Down

0 comments on commit daba406

Please sign in to comment.