diff --git a/spatial_casadi/spatial.py b/spatial_casadi/spatial.py index cd02223..69496b0 100644 --- a/spatial_casadi/spatial.py +++ b/spatial_casadi/spatial.py @@ -99,7 +99,7 @@ def __mul__(self, other): p = self.as_quat() q = other.as_quat() r = _compose_quat(p, q) - return Rotation(r, normalize=not isinstance(r, casadi.SX)) + return Rotation(r, normalize=not isinstance(r, (casadi.SX, casadi.MX))) elif isinstance(other, Translation): return Translation(self.as_matrix() @ other.as_vector()) @@ -135,7 +135,7 @@ def inv(self): """! Invert this rotation.""" return Rotation( casadi.vertcat(self._quat[:-1], -self._quat[-1]), - normalize=not isinstance(self._quat, casadi.SX), + normalize=not isinstance(self._quat, (casadi.SX, casadi.MX)), ) def magnitude(self): @@ -175,7 +175,7 @@ def from_quat(quat: ArrayType, seq: str = "xyzw"): quat_use = casadi.vertcat(x, y, z, w) - return Rotation(quat_use, normalize=not isinstance(quat, casadi.SX)) + return Rotation(quat_use, normalize=not isinstance(quat, (casadi.SX, casadi.MX))) @staticmethod def from_matrix(matrix: ArrayType): @@ -228,7 +228,7 @@ def alt_true_case(i, j, k): ), ) - return Rotation(quat, normalize=not isinstance(quat, casadi.SX)) + return Rotation(quat, normalize=not isinstance(quat, (casadi.SX, casadi.MX))) @staticmethod def from_rotvec(rotvec: ArrayType, degrees: bool = False): @@ -260,7 +260,7 @@ def from_rotvec(rotvec: ArrayType, degrees: bool = False): casadi.cos(angle * 0.5), ) - return Rotation(quat, normalize=not isinstance(quat, casadi.SX)) + return Rotation(quat, normalize=not isinstance(quat, (casadi.SX, casadi.MX))) @staticmethod def from_mrp(mrp: ArrayType): @@ -280,7 +280,7 @@ def from_mrp(mrp: ArrayType): (2.0 - mrp_squared_plus_1) / mrp_squared_plus_1, ) - return Rotation(quat, normalize=not isinstance(mrp, casadi.SX)) + return Rotation(quat, normalize=not isinstance(mrp, (casadi.SX, casadi.MX))) @staticmethod def from_euler(seq, angles, degrees=False): @@ -323,7 +323,7 @@ def from_euler(seq, angles, degrees=False): quat = _elementary_quat_compose(seq, angles, intrinsic) - return Rotation(quat, normalize=not isinstance(quat, casadi.SX)) + return Rotation(quat, normalize=not isinstance(quat, (casadi.SX, casadi.MX))) # # As methods