diff --git a/README.md b/README.md index a92a5c39..83b6486d 100644 --- a/README.md +++ b/README.md @@ -17,18 +17,20 @@ **adam** is based on Roy Featherstone's Rigid Body Dynamics Algorithms. ### Table of contents - - [🐍 Dependencies](#-dependencies) - - [πŸ’Ύ Installation](#-installation) - - [🐍 Installation with pip](#-installation-with-pip) - - [πŸ“¦ Installation with conda](#-installation-with-conda) - - [Installation from conda-forge package](#installation-from-conda-forge-package) - - [πŸ”¨ Installation from repo](#-installation-from-repo) - - [πŸš€ Usage](#-usage) - - [Jax interface](#jax-interface) - - [CasADi interface](#casadi-interface) - - [PyTorch interface](#pytorch-interface) - - [πŸ¦Έβ€β™‚οΈ Contributing](#️-contributing) - - [Todo](#todo) + +- [🐍 Dependencies](#-dependencies) +- [πŸ’Ύ Installation](#-installation) + - [🐍 Installation with pip](#-installation-with-pip) + - [πŸ“¦ Installation with conda](#-installation-with-conda) + - [Installation from conda-forge package](#installation-from-conda-forge-package) + - [πŸ”¨ Installation from repo](#-installation-from-repo) +- [πŸš€ Usage](#-usage) + - [Jax interface](#jax-interface) + - [CasADi interface](#casadi-interface) + - [PyTorch interface](#pytorch-interface) + - [PyTorch Batched interface](#pytorch-batched-interface) +- [πŸ¦Έβ€β™‚οΈ Contributing](#️-contributing) +- [Todo](#todo) ## 🐍 Dependencies @@ -41,6 +43,7 @@ Other requisites are: - `casadi` - `pytorch` - `numpy` +- `jax2torch` They will be installed in the installation step! @@ -114,6 +117,9 @@ mamba create -n adamenv -c conda-forge adam-robotics If you want to use `jax` or `pytorch`, just install the corresponding package as well. +> [!NOTE] +> Check also the conda JAX installation guide [here](https://jax.readthedocs.io/en/latest/installation.html#conda-community-supported) + ### πŸ”¨ Installation from repo Install in a conda environment the required dependencies: @@ -133,13 +139,13 @@ Install in a conda environment the required dependencies: - **PyTorch** interface dependencies: ```bash - mamba create -n adamenv -c conda-forge pytorch numpy lxml prettytable matplotlib urdfdom-py + mamba create -n adamenv -c conda-forge pytorch numpy lxml prettytable matplotlib urdfdom-py jax2torch ``` - **ALL** interfaces dependencies: ```bash - mamba create -n adamenv -c conda-forge jax casadi pytorch numpy lxml prettytable matplotlib urdfdom-py + mamba create -n adamenv -c conda-forge jax casadi pytorch numpy lxml prettytable matplotlib urdfdom-py jax2torch ``` Activate the environment, clone the repo and install the library: @@ -154,10 +160,13 @@ pip install --no-deps . ## πŸš€ Usage The following are small snippets of the use of **adam**. More examples are arriving! -Have also a look at te `tests` folder. +Have also a look at the `tests` folder. ### Jax interface +> [!NOTE] +> Check also the Jax installation guide [here](https://jax.readthedocs.io/en/latest/installation.html#) + ```python import adam from adam.jax import KinDynComputations @@ -205,11 +214,14 @@ jitted_vmapped_frame_fk = jit(vmapped_frame_fk) # and called on a batch of data joints_batch = jnp.tile(joints, (1024, 1)) w_H_b_batch = jnp.tile(w_H_b, (1024, 1, 1)) - w_H_f_batch = jitted_vmapped_frame_fk(w_H_b_batch, joints_batch) + ``` +> [!NOTE] +> The first call of the jitted function can be slow, since JAX needs to compile the function. Then it will be faster! + ### CasADi interface ```python @@ -251,7 +263,6 @@ joints = cs.MX.sym('joints', len(joints_name_list)) M = kinDyn.mass_matrix_fun() print(M(w_H_b, joints)) - ``` ### PyTorch interface @@ -284,6 +295,43 @@ M = kinDyn.mass_matrix(w_H_b, joints) print(M) ``` +### PyTorch Batched interface + +> [!NOTE] +> When using this interface, note that the first call of the jitted function can be slow, since JAX needs to compile the function. Then it will be faster! + +```python +import adam +from adam.pytorch import KinDynComputationsBatch +import icub_models + +# if you want to icub-models +model_path = icub_models.get_model_file("iCubGazeboV2_5") +# The joint list +joints_name_list = [ + 'torso_pitch', 'torso_roll', 'torso_yaw', 'l_shoulder_pitch', + 'l_shoulder_roll', 'l_shoulder_yaw', 'l_elbow', 'r_shoulder_pitch', + 'r_shoulder_roll', 'r_shoulder_yaw', 'r_elbow', 'l_hip_pitch', 'l_hip_roll', + 'l_hip_yaw', 'l_knee', 'l_ankle_pitch', 'l_ankle_roll', 'r_hip_pitch', + 'r_hip_roll', 'r_hip_yaw', 'r_knee', 'r_ankle_pitch', 'r_ankle_roll' +] + +kinDyn = KinDynComputationsBatch(model_path, joints_name_list) +# choose the representation you want to use the body fixed representation +kinDyn.set_frame_velocity_representation(adam.Representations.BODY_FIXED_REPRESENTATION) +# or, if you want to use the mixed representation (that is the default) +kinDyn.set_frame_velocity_representation(adam.Representations.MIXED_REPRESENTATION) +w_H_b = np.eye(4) +joints = np.ones(len(joints_name_list)) + +num_samples = 1024 +w_H_b_batch = torch.tensor(np.tile(w_H_b, (num_samples, 1, 1)), dtype=torch.float32) +joints_batch = torch.tensor(np.tile(joints, (num_samples, 1)), dtype=torch.float32) + +M = kinDyn.mass_matrix(w_H_b_batch, joints_batch) +w_H_f = kinDyn.forward_kinematics('frame_name', w_H_b_batch, joints_batch) +``` + ## πŸ¦Έβ€β™‚οΈ Contributing **adam** is an open-source project. Contributions are very welcome! diff --git a/ci_env.yml b/ci_env.yml index 1a858b48..e97eff79 100644 --- a/ci_env.yml +++ b/ci_env.yml @@ -16,6 +16,7 @@ dependencies: - pytest-repeat - icub-models - idyntree >=11.0.0 - - gitpython + - gitpython - jax - pytorch + - jax2torch diff --git a/setup.cfg b/setup.cfg index e93a0967..b77e5411 100644 --- a/setup.cfg +++ b/setup.cfg @@ -44,6 +44,9 @@ casadi = casadi pytorch = torch + jax + jaxlib + jax2torch test = jax jaxlib @@ -54,6 +57,8 @@ test = icub-models black gitpython + jax2torch + conversions = idyntree all = @@ -61,6 +66,7 @@ all = jaxlib casadi torch + jax2torch [tool:pytest] addopts = --capture=no --verbose diff --git a/src/adam/pytorch/__init__.py b/src/adam/pytorch/__init__.py index 8a3a50b6..7a33bea1 100644 --- a/src/adam/pytorch/__init__.py +++ b/src/adam/pytorch/__init__.py @@ -3,4 +3,5 @@ # GNU Lesser General Public License v2.1 or any later version. from .computations import KinDynComputations +from .computation_batch import KinDynComputationsBatch from .torch_like import TorchLike diff --git a/src/adam/pytorch/computation_batch.py b/src/adam/pytorch/computation_batch.py new file mode 100644 index 00000000..b5e96a10 --- /dev/null +++ b/src/adam/pytorch/computation_batch.py @@ -0,0 +1,444 @@ +# Copyright (C) 2024 Istituto Italiano di Tecnologia (IIT). All rights reserved. +# This software may be modified and distributed under the terms of the +# GNU Lesser General Public License v2.1 or any later version. + +import jax +import jax.numpy as jnp +import numpy as np +import torch +from jax2torch import jax2torch + +from adam.core.constants import Representations +from adam.core.rbd_algorithms import RBDAlgorithms +from adam.jax.jax_like import SpatialMath +from adam.model import Model, URDFModelFactory + + +class KinDynComputationsBatch: + """This is a small class that retrieves robot quantities using Jax for Floating Base systems. + These functions are vmapped and jit compiled and passed to jax2torch to convert them to PyTorch functions. + """ + + def __init__( + self, + urdfstring: str, + joints_name_list: list = None, + root_link: str = "root_link", + gravity: np.array = jnp.array([0, 0, -9.80665, 0, 0, 0]), + ) -> None: + """ + Args: + urdfstring (str): path of the urdf + joints_name_list (list): list of the actuated joints + root_link (str, optional): the first link. Defaults to 'root_link'. + """ + math = SpatialMath() + factory = URDFModelFactory(path=urdfstring, math=math) + model = Model.build(factory=factory, joints_name_list=joints_name_list) + self.rbdalgos = RBDAlgorithms(model=model, math=math) + self.NDoF = self.rbdalgos.NDoF + self.g = gravity + self.funcs = {} + + def set_frame_velocity_representation( + self, representation: Representations + ) -> None: + """Sets the representation of the velocity of the frames + + Args: + representation (Representations): The representation of the velocity + """ + self.rbdalgos.set_frame_velocity_representation(representation) + + def mass_matrix( + self, base_transform: torch.Tensor, joint_positions: torch.Tensor + ) -> torch.Tensor: + """Returns the Mass Matrix functions computed the CRBA + + Args: + base_transform (torch.Tensor): The batch of homogenous transforms from base to world frame + joint_positions (torch.Tensor): The batch of joints position + + Returns: + M (torch.Tensor): The batch Mass Matrix + """ + + return self.mass_matrix_fun()(base_transform, joint_positions) + + def mass_matrix_fun(self): + """Returns the Mass Matrix functions computed the CRBA as a pytorch function + + Returns: + M (pytorch function): Mass Matrix + """ + + if self.funcs.get("mass_matrix") is not None: + return self.funcs["mass_matrix"] + print("[INFO] Compiling mass matrix function") + + def fun(base_transform, joint_positions): + [M, _] = self.rbdalgos.crba(base_transform, joint_positions) + return M.array + + vmapped_fun = jax.vmap(fun, in_axes=(0, 0)) + jit_vmapped_fun = jax.jit(vmapped_fun) + self.funcs["mass_matrix"] = jax2torch(jit_vmapped_fun) + return self.funcs["mass_matrix"] + + def centroidal_momentum_matrix( + self, base_transform: torch.Tensor, joint_positions: torch.Tensor + ) -> torch.Tensor: + """Returns the Centroidal Momentum Matrix functions computed the CRBA + + Args: + base_transform (torch.Tensor): The homogenous transform from base to world frame + joint_positions (torch.Tensor): The joints position + + Returns: + Jcc (torch.Tensor): Centroidal Momentum matrix + """ + + return self.centroidal_momentum_matrix_fun()(base_transform, joint_positions) + + def centroidal_momentum_matrix_fun(self): + """Returns the Centroidal Momentum Matrix functions computed the CRBA as a pytorch function + + Returns: + Jcc (pytorch function): Centroidal Momentum matrix + """ + + if self.funcs.get("centroidal_momentum_matrix") is not None: + return self.funcs["centroidal_momentum_matrix"] + print("[INFO] Compiling centroidal momentum matrix function") + + def fun(base_transform, joint_positions): + [_, Jcm] = self.rbdalgos.crba(base_transform, joint_positions) + return Jcm.array + + vmapped_fun = jax.vmap(fun, in_axes=(0, 0)) + jit_vmapped_fun = jax.jit(vmapped_fun) + self.funcs["centroidal_momentum_matrix"] = jax2torch(jit_vmapped_fun) + return self.funcs["centroidal_momentum_matrix"] + + def relative_jacobian( + self, frame: str, joint_positions: torch.Tensor + ) -> torch.Tensor: + + return self.relative_jacobian_fun(frame)(joint_positions) + + def relative_jacobian_fun(self, frame: str): + """Returns the Jacobian between the root link and a specified frame frames as a pytorch function + + Args: + frame (str): The tip of the chain + + Returns: + J (pytorch function): The Jacobian between the root and the frame + """ + + if self.funcs.get(f"relative_jacobian_{frame}") is not None: + return self.funcs[f"relative_jacobian_{frame}"] + print(f"[INFO] Compiling relative jacobian function for {frame} frame") + + def fun(joint_positions): + return self.rbdalgos.relative_jacobian(frame, joint_positions).array + + vmapped_fun = jax.vmap(fun) + jit_vmapped_fun = jax.jit(vmapped_fun) + self.funcs[f"relative_jacobian_{frame}"] = jax2torch(jit_vmapped_fun) + return self.funcs[f"relative_jacobian_{frame}"] + + def jacobian_dot( + self, + frame: str, + base_transform: torch.Tensor, + joint_positions: torch.Tensor, + base_velocity: torch.Tensor, + joint_velocities: torch.Tensor, + ) -> torch.Tensor: + """Returns the Jacobian derivative relative to the specified frame + + Args: + frame (str): The frame to which the jacobian will be computed + base_transform (torch.Tensor): The homogenous transform from base to world frame + joint_positions (torch.Tensor): The joints position + base_velocity (torch.Tensor): The base velocity + joint_velocities (torch.Tensor): The joint velocities + + Returns: + Jdot (torch.Tensor): The Jacobian derivative relative to the frame + """ + + return self.jacobian_dot_fun(frame)( + base_transform, joint_positions, base_velocity, joint_velocities + ) + + def jacobian_dot_fun( + self, + frame: str, + ): + """Returns the Jacobian derivative between the root and the specified frame as a pytorch function + + Args: + frame (str): The frame to which the jacobian will be computed + + Returns: + Jdot (pytorch function): The Jacobian derivative between the root and the frame + """ + + if self.funcs.get(f"jacobian_dot_{frame}") is not None: + return self.funcs[f"jacobian_dot_{frame}"] + print(f"[INFO] Compiling jacobian dot function for {frame} frame") + + def fun(base_transform, joint_positions, base_velocity, joint_velocities): + return self.rbdalgos.jacobian_dot( + frame, base_transform, joint_positions, base_velocity, joint_velocities + ).array + + vmapped_fun = jax.vmap(fun, in_axes=(0, 0, 0, 0)) + jit_vmapped_fun = jax.jit(vmapped_fun) + self.funcs[f"jacobian_dot_{frame}"] = jax2torch(jit_vmapped_fun) + return self.funcs[f"jacobian_dot_{frame}"] + + def forward_kinematics( + self, frame: str, base_transform: torch.Tensor, joint_positions: torch.Tensor + ) -> torch.Tensor: + """Computes the forward kinematics between the root and the specified frame + + Args: + frame (str): The frame to which the fk will be computed + + Returns: + H (torch.Tensor): The fk represented as Homogenous transformation matrix + """ + + return self.forward_kinematics_fun(frame)(base_transform, joint_positions) + + def forward_kinematics_fun(self, frame: str): + """Computes the forward kinematics between the root and the specified frame as a pytorch function + + Args: + frame (str): The frame to which the fk will be computed + + Returns: + H (pytorch function): The fk represented as Homogenous transformation matrix + """ + + if self.funcs.get(f"forward_kinematics_{frame}") is not None: + return self.funcs[f"forward_kinematics_{frame}"] + print(f"[INFO] Compiling forward kinematics function for {frame} frame") + + def fun(base_transform, joint_positions): + return self.rbdalgos.forward_kinematics( + frame, base_transform, joint_positions + ).array + + vmapped_fun = jax.vmap(fun, in_axes=(0, 0)) + jit_vmapped_fun = jax.jit(vmapped_fun) + self.funcs[f"forward_kinematics_{frame}"] = jax2torch(jit_vmapped_fun) + return self.funcs[f"forward_kinematics_{frame}"] + + def jacobian( + self, frame: str, base_transform: torch.Tensor, joint_positions: torch.Tensor + ) -> torch.Tensor: + """Returns the Jacobian relative to the specified frame + + Args: + frame (str): The frame to which the jacobian will be computed + base_transform (torch.Tensor): The homogenous transform from base to world frame + joint_positions (torch.Tensor): The joints position + + Returns: + J (torch.Tensor): The Jacobian between the root and the frame + """ + return self.jacobian_fun(frame)(base_transform, joint_positions) + + def jacobian_fun(self, frame: str): + """Returns the Jacobian relative to the specified frame as a pytorch function + + Args: + frame (str): The frame to which the jacobian will be computed + + Returns: + J (pytorch function): The Jacobian relative to the frame + """ + if self.funcs.get(f"jacobian_{frame}") is not None: + return self.funcs[f"jacobian_{frame}"] + print(f"[INFO] Compiling jacobian function for {frame} frame") + + def fun(base_transform, joint_positions): + return self.rbdalgos.jacobian(frame, base_transform, joint_positions).array + + vmapped_fun = jax.vmap(fun, in_axes=(0, 0)) + jit_vmapped_fun = jax.jit(vmapped_fun) + self.funcs[f"jacobian_{frame}"] = jax2torch(jit_vmapped_fun) + return self.funcs[f"jacobian_{frame}"] + + def bias_force( + self, + base_transform: torch.Tensor, + joint_positions: torch.Tensor, + base_velocity: torch.Tensor, + joint_velocities: torch.Tensor, + ) -> jnp.array: + """Returns the bias force of the floating-base dynamics equation, + using a reduced RNEA (no acceleration and external forces) + + Args: + base_transform (torch.Tensor): The homogenous transform from base to world frame + joint_positions (torch.Tensor): The joints position + base_velocity (torch.Tensor): The base velocity + joint_velocities (torch.Tensor): The joints velocity + + Returns: + h (torch.Tensor): the bias force + """ + return self.bias_force_fun()( + base_transform, joint_positions, base_velocity, joint_velocities + ) + + def bias_force_fun(self): + """Returns the bias force of the floating-base dynamics equation as a pytorch function + + Returns: + h (pytorch function): the bias force + """ + if self.funcs.get("bias_force") is not None: + return self.funcs["bias_force"] + print("[INFO] Compiling bias force function") + + def fun(base_transform, joint_positions, base_velocity, joint_velocities): + return self.rbdalgos.rnea( + base_transform, joint_positions, base_velocity, joint_velocities, self.g + ).array.squeeze() + + vmapped_fun = jax.vmap(fun, in_axes=(0, 0, 0, 0)) + jit_vmapped_fun = jax.jit(vmapped_fun) + self.funcs["bias_force"] = jax2torch(jit_vmapped_fun) + return self.funcs["bias_force"] + + def coriolis_term( + self, + base_transform: torch.Tensor, + joint_positions: torch.Tensor, + base_velocity: torch.Tensor, + joint_velocities: torch.Tensor, + ) -> torch.Tensor: + """Returns the coriolis term of the floating-base dynamics equation, + using a reduced RNEA (no acceleration and external forces) + + Args: + base_transform (torch.Tensor): The homogenous transform from base to world frame + joint_positions (torch.Tensor): The joints position + base_velocity (torch.Tensor): The base velocity + joint_velocities (torch.Tensor): The joints velocity + + Returns: + C (torch.Tensor): the Coriolis term + """ + return self.coriolis_term_fun()( + base_transform, joint_positions, base_velocity, joint_velocities + ) + + def coriolis_term_fun(self): + """Returns the coriolis term of the floating-base dynamics equation as a pytorch function + + Returns: + C (pytorch function): the Coriolis term + """ + if self.funcs.get("coriolis_term") is not None: + return self.funcs["coriolis_term"] + print("[INFO] Compiling coriolis term function") + + def fun(base_transform, joint_positions, base_velocity, joint_velocities): + return self.rbdalgos.rnea( + base_transform, + joint_positions, + base_velocity.reshape(6, 1), + joint_velocities, + np.zeros(6), + ).array.squeeze() + + vmapped_fun = jax.vmap(fun, in_axes=(0, 0, 0, 0)) + jit_vmapped_fun = jax.jit(vmapped_fun) + self.funcs["coriolis_term"] = jax2torch(jit_vmapped_fun) + return self.funcs["coriolis_term"] + + def gravity_term( + self, base_transform: torch.Tensor, joint_positions: torch.Tensor + ) -> torch.Tensor: + """Returns the gravity term of the floating-base dynamics equation, + using a reduced RNEA (no acceleration and external forces) + + Args: + base_transform (torch.Tensor): The homogenous transform from base to world frame + joint_positions (torch.Tensor): The joints position + + Returns: + G (jnp.array): the gravity term + """ + return self.gravity_term_fun()(base_transform, joint_positions) + + def gravity_term_fun(self): + """Returns the gravity term of the floating-base dynamics equation as a pytorch function + + Returns: + G (pytorch function): the gravity term + """ + if self.funcs.get("gravity_term") is not None: + return self.funcs["gravity_term"] + print("[INFO] Compiling gravity term function") + + def fun(base_transform, joint_positions): + return self.rbdalgos.rnea( + base_transform, + joint_positions, + np.zeros(6).reshape(6, 1), + np.zeros(self.NDoF), + self.g, + ).array.squeeze() + + vmapped_fun = jax.vmap(fun, in_axes=(0, 0)) + jit_vmapped_fun = jax.jit(vmapped_fun) + self.funcs["gravity_term"] = jax2torch(jit_vmapped_fun) + return self.funcs["gravity_term"] + + def CoM_position( + self, base_transform: torch.Tensor, joint_positions: torch.Tensor + ) -> torch.Tensor: + """Returns the CoM positon + + Args: + base_transform (torch.Tensor): The homogenous transform from base to world frame + joint_positions (torch.Tensor): The joints position + + Returns: + CoM (torch.Tensor): The CoM position + """ + return self.CoM_position_fun()(base_transform, joint_positions) + + def CoM_position_fun(self): + """Returns the CoM positon as a pytorch function + + Returns: + CoM (pytorch function): The CoM position + """ + if self.funcs.get("CoM_position") is not None: + return self.funcs["CoM_position"] + print("[INFO] Compiling CoM position function") + + def fun(base_transform, joint_positions): + return self.rbdalgos.CoM_position(base_transform, joint_positions).array + + vmapped_fun = jax.vmap(fun, in_axes=(0, 0)) + jit_vmapped_fun = jax.jit(vmapped_fun) + self.funcs["CoM_position"] = jax2torch(jit_vmapped_fun) + return self.funcs["CoM_position"] + + def get_total_mass(self) -> float: + """Returns the total mass of the robot + + Returns: + mass: The total mass + """ + return self.rbdalgos.get_total_mass() diff --git a/tests/pytorch_batch/test_pytorch_batch.py b/tests/pytorch_batch/test_pytorch_batch.py new file mode 100644 index 00000000..44b3a33d --- /dev/null +++ b/tests/pytorch_batch/test_pytorch_batch.py @@ -0,0 +1,184 @@ +import logging + +import icub_models +import idyntree.swig as idyntree +import jax.numpy as jnp +import numpy as np +import pytest +from jax import config + +import adam +from adam.geometry import utils +from adam.pytorch import KinDynComputationsBatch +from adam.numpy import KinDynComputations +import torch + +np.random.seed(42) +config.update("jax_enable_x64", True) + +model_path = str(icub_models.get_model_file("iCubGazeboV2_5")) + +joints_name_list = [ + "torso_pitch", + "torso_roll", + "torso_yaw", + "l_shoulder_pitch", + "l_shoulder_roll", + "l_shoulder_yaw", + "l_elbow", + "r_shoulder_pitch", + "r_shoulder_roll", + "r_shoulder_yaw", + "r_elbow", + "l_hip_pitch", + "l_hip_roll", + "l_hip_yaw", + "l_knee", + "l_ankle_pitch", + "l_ankle_roll", + "r_hip_pitch", + "r_hip_roll", + "r_hip_yaw", + "r_knee", + "r_ankle_pitch", + "r_ankle_roll", +] + + +comp = KinDynComputationsBatch(model_path, joints_name_list) +comp.set_frame_velocity_representation(adam.Representations.MIXED_REPRESENTATION) + +comp_np = KinDynComputations(model_path, joints_name_list) +comp_np.set_frame_velocity_representation(adam.Representations.MIXED_REPRESENTATION) + +n_dofs = len(joints_name_list) +# base pose quantities +xyz = (np.random.rand(3) - 0.5) * 5 +rpy = (np.random.rand(3) - 0.5) * 5 +base_vel = (np.random.rand(6) - 0.5) * 5 +# joints quantitites +joints_val = (np.random.rand(n_dofs) - 0.5) * 5 +joints_dot_val = (np.random.rand(n_dofs) - 0.5) * 5 + +g = np.array([0, 0, -9.80665]) +H_b = utils.H_from_Pos_RPY(xyz, rpy) +n_samples = 10 + +H_b_batch = torch.tile(torch.tensor(H_b), (n_samples, 1, 1)).requires_grad_() +joints_val_batch = torch.tile(torch.tensor(joints_val), (n_samples, 1)).requires_grad_() +base_vel_batch = torch.tile(torch.tensor(base_vel), (n_samples, 1)).requires_grad_() +joints_dot_val_batch = torch.tile( + torch.tensor(joints_dot_val), (n_samples, 1) +).requires_grad_() + + +# Check if the quantities are the correct testing against the numpy implementation +# Check if the dimensions are correct (batch dimension) +# Check if the gradient is computable + + +def test_mass_matrix(): + mass_matrix = comp.mass_matrix(H_b_batch, joints_val_batch) + mass_matrix_np = comp_np.mass_matrix(H_b, joints_val) + assert np.allclose(mass_matrix[0].detach().numpy(), mass_matrix_np) + assert mass_matrix.shape == (n_samples, n_dofs + 6, n_dofs + 6) + mass_matrix.sum().backward() + + +def test_centroidal_momentum_matrix(): + centroidal_momentum_matrix = comp.centroidal_momentum_matrix( + H_b_batch, joints_val_batch + ) + centroidal_momentum_matrix_np = comp_np.centroidal_momentum_matrix(H_b, joints_val) + assert np.allclose( + centroidal_momentum_matrix[0].detach().numpy(), centroidal_momentum_matrix_np + ) + assert centroidal_momentum_matrix.shape == (n_samples, 6, n_dofs + 6) + centroidal_momentum_matrix.sum().backward() + + +def test_relative_jacobian(): + frame = "l_sole" + relative_jacobian = comp.relative_jacobian(frame, joints_val_batch) + assert np.allclose( + relative_jacobian[0].detach().numpy(), + comp_np.relative_jacobian(frame, joints_val), + ) + assert relative_jacobian.shape == (n_samples, 6, n_dofs) + relative_jacobian.sum().backward() + + +def test_jacobian_dot(): + frame = "l_sole" + jacobian_dot = comp.jacobian_dot( + frame, H_b_batch, joints_val_batch, base_vel_batch, joints_dot_val_batch + ) + assert np.allclose( + jacobian_dot[0].detach().numpy(), + comp_np.jacobian_dot(frame, H_b, joints_val, base_vel, joints_dot_val), + ) + assert jacobian_dot.shape == (n_samples, 6, n_dofs + 6) + jacobian_dot.sum().backward() + + +def test_forward_kineamtics(): + frame = "l_sole" + forward_kinematics = comp.forward_kinematics(frame, H_b_batch, joints_val_batch) + assert np.allclose( + forward_kinematics[0].detach().numpy(), + comp_np.forward_kinematics(frame, H_b, joints_val), + ) + assert forward_kinematics.shape == (n_samples, 4, 4) + forward_kinematics.sum().backward() + + +def test_jacobian(): + frame = "l_sole" + jacobian = comp.jacobian(frame, H_b_batch, joints_val_batch) + assert np.allclose( + jacobian[0].detach().numpy(), comp_np.jacobian(frame, H_b, joints_val) + ) + assert jacobian.shape == (n_samples, 6, n_dofs + 6) + jacobian.sum().backward() + + +def test_bias_force(): + bias_force = comp.bias_force( + H_b_batch, joints_val_batch, base_vel_batch, joints_dot_val_batch + ) + assert np.allclose( + bias_force[0].detach().numpy(), + comp_np.bias_force(H_b, joints_val, base_vel, joints_dot_val), + ) + assert bias_force.shape == (n_samples, n_dofs + 6) + bias_force.sum().backward() + + +def test_coriolis_term(): + coriolis_term = comp.coriolis_term( + H_b_batch, joints_val_batch, base_vel_batch, joints_dot_val_batch + ) + assert np.allclose( + coriolis_term[0].detach().numpy(), + comp_np.coriolis_term(H_b, joints_val, base_vel, joints_dot_val), + ) + assert coriolis_term.shape == (n_samples, n_dofs + 6) + coriolis_term.sum().backward() + + +def test_gravity_term(): + gravity_term = comp.gravity_term(H_b_batch, joints_val_batch) + assert np.allclose( + gravity_term[0].detach().numpy(), comp_np.gravity_term(H_b, joints_val) + ) + assert gravity_term.shape == (n_samples, n_dofs + 6) + gravity_term.sum().backward() + + +def test_CoM_position(): + CoM_position = comp.CoM_position(H_b_batch, joints_val_batch) + assert np.allclose( + CoM_position[0].detach().numpy(), comp_np.CoM_position(H_b, joints_val) + ) + assert CoM_position.shape == (n_samples, 3) + CoM_position.sum().backward()