Skip to content

Commit

Permalink
Add zeros_like and ones_like methods to ArrayLikeFactory subclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
GiulioRomualdi committed Jan 9, 2025
1 parent 03167f8 commit 8a99b1a
Show file tree
Hide file tree
Showing 9 changed files with 233 additions and 3 deletions.
63 changes: 62 additions & 1 deletion src/adam/casadi/casadi_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ def __setitem__(self, idx, value: Union["CasadiLike", npt.ArrayLike]):
else:
self.array[idx] = value.array if isinstance(value, CasadiLike) else value


def __getitem__(self, idx) -> "CasadiLike":
"""Overrides get item operator"""
if idx is Ellipsis:
Expand Down Expand Up @@ -145,6 +144,68 @@ def array(*x) -> "CasadiLike":
"""
return CasadiLike(cs.SX(*x))

@staticmethod
def zeros_like(x) -> CasadiLike:
"""
Args:
x (npt.ArrayLike): matrix
Returns:
npt.ArrayLike: zero matrix of dimension x
"""

kind = (
cs.DM
if (isinstance(x, CasadiLike) and isinstance(x.array, cs.DM))
or isinstance(x, cs.DM)
else cs.SX
)

return (
CasadiLike(kind.zeros(x.array.shape))
if isinstance(x, CasadiLike)
else (
CasadiLike(kind.zeros(x.shape))
if isinstance(x, (cs.SX, cs.DM))
else (
TypeError(f"Unsupported type for zeros_like: {type(x)}")
if isinstance(x, CasadiLike)
else CasadiLike(kind.zeros(x.shape))
)
)
)

@staticmethod
def ones_like(x) -> CasadiLike:
"""
Args:
x (npt.ArrayLike): matrix
Returns:
npt.ArrayLike: Identity matrix of dimension x
"""

kind = (
cs.DM
if (isinstance(x, CasadiLike) and isinstance(x.array, cs.DM))
or isinstance(x, cs.DM)
else cs.SX
)

return (
CasadiLike(kind.ones(x.array.shape))
if isinstance(x, CasadiLike)
else (
CasadiLike(kind.ones(x.shape))
if isinstance(x, (cs.SX, cs.DM))
else (
TypeError(f"Unsupported type for ones_like: {type(x)}")
if isinstance(x, CasadiLike)
else CasadiLike(kind.ones(x.shape))
)
)
)


class SpatialMath(SpatialMath):

Expand Down
30 changes: 28 additions & 2 deletions src/adam/core/spatial_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,9 @@ def T(self):
class ArrayLikeFactory(abc.ABC):
"""Abstract class for a generic Array wrapper. Every method should be implemented for every data type."""

@staticmethod
@abc.abstractmethod
def zeros(self, x: npt.ArrayLike) -> npt.ArrayLike:
def zeros(x: npt.ArrayLike) -> npt.ArrayLike:
"""
Args:
x (npt.ArrayLike): matrix dimension
Expand All @@ -79,8 +80,9 @@ def zeros(self, x: npt.ArrayLike) -> npt.ArrayLike:
"""
pass

@staticmethod
@abc.abstractmethod
def eye(self, x: npt.ArrayLike) -> npt.ArrayLike:
def eye(x: npt.ArrayLike) -> npt.ArrayLike:
"""
Args:
x (npt.ArrayLike): matrix dimension
Expand All @@ -90,6 +92,30 @@ def eye(self, x: npt.ArrayLike) -> npt.ArrayLike:
"""
pass

@staticmethod
@abc.abstractmethod
def zeros_like(x: npt.ArrayLike) -> npt.ArrayLike:
"""
Args:
x (npt.ArrayLike): matrix
Returns:
npt.ArrayLike: zero matrix of dimension x
"""
pass

@staticmethod
@abc.abstractmethod
def ones_like(x: npt.ArrayLike) -> npt.ArrayLike:
"""
Args:
x (npt.ArrayLike): matrix
Returns:
npt.ArrayLike: ones matrix of dimension x
"""
pass


class SpatialMath:
"""Class implementing the main geometric functions used for computing rigid-body algorithm
Expand Down
30 changes: 30 additions & 0 deletions src/adam/jax/jax_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,36 @@ def array(x) -> "JaxLike":
"""
return JaxLike(jnp.array(x))

@staticmethod
def zeros_like(x) -> JaxLike:
"""
Args:
x (npt.ArrayLike): matrix
Returns:
npt.ArrayLike: zero matrix of dimension x
"""
return (
JaxLike(jnp.zeros_like(x.array))
if isinstance(x, JaxLike)
else JaxLike(jnp.zeros_like(x))
)

@staticmethod
def ones_like(x) -> JaxLike:
"""
Args:
x (npt.ArrayLike): matrix
Returns:
npt.ArrayLike: Ones matrix of dimension x
"""
return (
JaxLike(jnp.ones_like(x.array))
if isinstance(x, JaxLike)
else JaxLike(jnp.ones_like(x))
)


class SpatialMath(SpatialMath):
def __init__(self):
Expand Down
30 changes: 30 additions & 0 deletions src/adam/numpy/numpy_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,36 @@ def array(x) -> "NumpyLike":
"""
return NumpyLike(np.array(x))

@staticmethod
def zeros_like(x) -> NumpyLike:
"""
Args:
x (npt.ArrayLike): matrix
Returns:
npt.ArrayLike: zero matrix of dimension x
"""
return (
NumpyLike(np.zeros_like(x.array))
if isinstance(x, NumpyLike)
else NumpyLike(np.zeros_like(x))
)

@staticmethod
def ones_like(x) -> NumpyLike:
"""
Args:
x (npt.ArrayLike): matrix
Returns:
npt.ArrayLike: Ones matrix of dimension x
"""
return (
NumpyLike(np.ones_like(x.array))
if isinstance(x, NumpyLike)
else NumpyLike(np.ones_like(x))
)


class SpatialMath(SpatialMath):
def __init__(self):
Expand Down
30 changes: 30 additions & 0 deletions src/adam/pytorch/torch_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,36 @@ def array(x: ntp.ArrayLike) -> "TorchLike":
"""
return TorchLike(torch.tensor(x))

@staticmethod
def zeros_like(x) -> TorchLike:
"""
Args:
x (npt.ArrayLike): matrix
Returns:
npt.ArrayLike: zero matrix of dimension x
"""
return (
TorchLike(torch.zeros_like(x.array))
if isinstance(x, TorchLike)
else TorchLike(torch.zeros_like(x))
)

@staticmethod
def ones_like(x) -> TorchLike:
"""
Args:
x (npt.ArrayLike): matrix
Returns:
npt.ArrayLike: Identity matrix of dimension x
"""
return (
TorchLike(torch.ones_like(x.array))
if isinstance(x, TorchLike)
else TorchLike(torch.ones_like(x))
)


class SpatialMath(SpatialMath):
def __init__(self):
Expand Down
13 changes: 13 additions & 0 deletions tests/test_casadi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from conftest import RobotCfg, State

from adam.casadi import KinDynComputations
from adam.casadi.casadi_like import CasadiLike, CasadiLikeFactory


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -173,3 +174,15 @@ def test_gravity_term(setup_test):
assert idyn_gravity - adam_gravity == pytest.approx(0.0, abs=1e-4)
adam_gravity = cs.DM(adam_kin_dyn.gravity_term_fun()(state.H, state.joints_pos))
assert idyn_gravity - adam_gravity == pytest.approx(0.0, abs=1e-4)


def test_casadi_like():
B = cs.DM([[1.0, 2.0], [3.0, 4.0]])
B_like = CasadiLike(B)
assert B_like[...].array - B == pytest.approx(0.0, abs=1e-5)

ones = CasadiLikeFactory.ones_like(B)
assert ones[...].array - cs.DM.ones(2, 2) == pytest.approx(0.0, abs=1e-5)

zeros = CasadiLikeFactory.zeros_like(B)
assert zeros[...].array - cs.DM.zeros(2, 2) == pytest.approx(0.0, abs=1e-5)
14 changes: 14 additions & 0 deletions tests/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
import pytest
from conftest import RobotCfg, State
from jax import config
import jax.numpy as jnp


from adam.jax import KinDynComputations
from adam.jax.jax_like import JaxLike, JaxLikeFactory

config.update("jax_enable_x64", True)

Expand Down Expand Up @@ -119,3 +122,14 @@ def test_gravity_term(setup_test):
idyn_gravity = robot_cfg.idyn_function_values.gravity_term
adam_gravity = adam_kin_dyn.gravity_term(state.H, state.joints_pos)
assert idyn_gravity - adam_gravity == pytest.approx(0.0, abs=1e-4)

def test_jax_like():
B = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
B_like = JaxLike(B)
assert B_like[...].array - B == pytest.approx(0.0, abs=1e-5)

ones = JaxLikeFactory.ones_like(B_like)
assert ones.array - jnp.ones_like(B) == pytest.approx(0.0, abs=1e-5)

zeros = JaxLikeFactory.zeros_like(B_like)
assert zeros.array - jnp.zeros_like(B) == pytest.approx(0.0, abs=1e-5)
13 changes: 13 additions & 0 deletions tests/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from conftest import RobotCfg, State

from adam.numpy import KinDynComputations
from adam.numpy.numpy_like import NumpyLike, NumpyLikeFactory


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -116,3 +117,15 @@ def test_gravity_term(setup_test):
idyn_gravity = robot_cfg.idyn_function_values.gravity_term
adam_gravity = adam_kin_dyn.gravity_term(state.H, state.joints_pos)
assert idyn_gravity - adam_gravity == pytest.approx(0.0, abs=1e-4)


def test_numpy_like():
B = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
B_like = NumpyLike(B)
assert B_like[...].array - B == pytest.approx(0.0, abs=1e-5)

ones = NumpyLikeFactory.ones_like(B_like)
assert ones.array - np.ones_like(B) == pytest.approx(0.0, abs=1e-5)

zeros = NumpyLikeFactory.zeros_like(B_like)
assert zeros.array - np.zeros_like(B) == pytest.approx(0.0, abs=1e-5)
13 changes: 13 additions & 0 deletions tests/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from conftest import RobotCfg, State

from adam.pytorch import KinDynComputations
from adam.pytorch.torch_like import TorchLike, TorchLikeFactory

torch.set_default_dtype(torch.float64)

Expand Down Expand Up @@ -128,3 +129,15 @@ def test_gravity_term(setup_test):
idyn_gravity = robot_cfg.idyn_function_values.gravity_term
adam_gravity = adam_kin_dyn.gravity_term(state.H, state.joints_pos)
assert idyn_gravity - adam_gravity.numpy() == pytest.approx(0.0, abs=1e-4)


def test_torch_like():
B = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
B_like = TorchLike(B)
assert B_like[...].array - B == pytest.approx(0.0, abs=1e-5)

ones = TorchLikeFactory.ones_like(B_like)
assert ones.array - np.ones_like(B) == pytest.approx(0.0, abs=1e-5)

zeros = TorchLikeFactory.zeros_like(B_like)
assert zeros.array - np.zeros_like(B) == pytest.approx(0.0, abs=1e-5)

0 comments on commit 8a99b1a

Please sign in to comment.