Skip to content

Commit

Permalink
Fix and update array handle
Browse files Browse the repository at this point in the history
  • Loading branch information
Giulero committed Oct 15, 2023
1 parent 810168d commit 158dec6
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/adam/casadi/casadi_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def array(*x) -> "CasadiLike":
Returns:
CasadiLike: Vector wrapping *x
"""
return CasadiLike(cs.DM(*x))
return CasadiLike(cs.SX(*x))


class SpatialMath(SpatialMath):
Expand Down
8 changes: 6 additions & 2 deletions src/adam/core/rbd_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,9 @@ def jacobian(
return J_tot
# let's move to mixed representation
elif self.frame_velocity_representation == Representations.MIXED_REPRESENTATION:
w_H_L = base_transform @ B_H_L.array
if type(base_transform) != type(B_H_L):
base_transform = self.math.factory.array(base_transform)
w_H_L = base_transform @ B_H_L
LI_X_L = self.math.adjoint_mixed(w_H_L)
X = self.math.factory.eye(6 + self.NDoF)
X[:6, :6] = self.math.adjoint_mixed_inverse(base_transform)
Expand Down Expand Up @@ -302,7 +304,9 @@ def jacobian_dot(
return J_dot
# let's move to mixed representation
elif self.frame_velocity_representation == Representations.MIXED_REPRESENTATION:
I_H_L = base_transform @ B_H_j.array
if type(base_transform) != type(B_H_L):
base_transform = self.math.factory.array(base_transform)
I_H_L = base_transform @ B_H_j
LI_X_L = self.math.adjoint_mixed(I_H_L)
X = self.math.factory.eye(6 + self.NDoF)
X[:6, :6] = self.math.adjoint_mixed_inverse(base_transform)
Expand Down
2 changes: 1 addition & 1 deletion src/adam/jax/jax_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def eye(x) -> "JaxLike":
return JaxLike(jnp.eye(x))

@staticmethod
def array(*x) -> "JaxLike":
def array(x) -> "JaxLike":
"""
Returns:
JaxLike: Vector wrapping *x
Expand Down
2 changes: 1 addition & 1 deletion src/adam/numpy/numpy_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def eye(x: int) -> "NumpyLike":
return NumpyLike(np.eye(x))

@staticmethod
def array(*x) -> "NumpyLike":
def array(x) -> "NumpyLike":
"""
Returns:
NumpyLike: Vector wrapping *x
Expand Down
2 changes: 1 addition & 1 deletion src/adam/pytorch/torch_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def eye(x: int) -> "TorchLike":
return TorchLike(torch.eye(x).float())

@staticmethod
def array(*x: ntp.ArrayLike) -> "TorchLike":
def array(x: ntp.ArrayLike) -> "TorchLike":
"""
Returns:
TorchLike: vector wrapping x
Expand Down

0 comments on commit 158dec6

Please sign in to comment.