From 1685ad5026af34791909ecb50ecd6fc88669e4cf Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Wed, 15 May 2024 13:21:45 +0200 Subject: [PATCH 1/6] Avoid the use of `enum` in `Joint*` classes Deprecate `JointDescriptor`, introduce new `JointType`, separate `JointType` from `JointGenericAxis` --- src/jaxsim/math/joint_model.py | 72 +++++++++------------ src/jaxsim/parsers/descriptions/__init__.py | 2 +- src/jaxsim/parsers/descriptions/joint.py | 49 +++++++------- src/jaxsim/parsers/kinematic_graph.py | 8 ++- src/jaxsim/parsers/rod/utils.py | 10 +-- 5 files changed, 63 insertions(+), 78 deletions(-) diff --git a/src/jaxsim/math/joint_model.py b/src/jaxsim/math/joint_model.py index 6e52952a2..a6f3776ec 100644 --- a/src/jaxsim/math/joint_model.py +++ b/src/jaxsim/math/joint_model.py @@ -1,7 +1,5 @@ from __future__ import annotations -import functools - import jax import jax.numpy as jnp import jax_dataclasses @@ -9,12 +7,7 @@ from jax_dataclasses import Static import jaxsim.typing as jtp -from jaxsim.parsers.descriptions import ( - JointDescriptor, - JointGenericAxis, - JointType, - ModelDescription, -) +from jaxsim.parsers.descriptions import JointGenericAxis, JointType, ModelDescription from jaxsim.parsers.kinematic_graph import KinematicGraphTransforms from .rotation import Rotation @@ -46,7 +39,8 @@ class JointModel: joint_dofs: Static[tuple[int, ...]] joint_names: Static[tuple[str, ...]] - joint_types: Static[tuple[JointType | JointDescriptor, ...]] + joint_types: Static[tuple[JointType, ...]] + joint_axis: Static[tuple[JointGenericAxis, ...]] @staticmethod def build(description: ModelDescription) -> JointModel: @@ -115,6 +109,7 @@ def build(description: ModelDescription) -> JointModel: joint_dofs=tuple([base_dofs] + [int(1) for _ in ordered_joints]), joint_names=tuple(["world_to_base"] + [j.name for j in ordered_joints]), joint_types=tuple([JointType.F] + [j.jtype for j in ordered_joints]), + joint_axis=tuple([j.axis for j in ordered_joints]), ) def parent_H_child( @@ -226,59 +221,54 @@ def successor_H_child(self, joint_index: jtp.IntLike) -> jtp.Matrix: return self.suc_H_i[joint_index] -@functools.partial(jax.jit, static_argnames=["joint_type"]) +@jax.jit def supported_joint_motion( - joint_type: JointType | JointDescriptor, joint_position: jtp.VectorLike + joint_type: JointType, joint_axis: JointGenericAxis, joint_position: jtp.VectorLike ) -> tuple[jtp.Matrix, jtp.Array]: """ Compute the homogeneous transformation and motion subspace of a joint. Args: joint_type: The type of the joint. + joint_axis: The axis of rotation or translation of the joint. joint_position: The position of the joint. Returns: A tuple containing the homogeneous transformation and the motion subspace. """ - if isinstance(joint_type, JointType): - type_enum = joint_type - elif isinstance(joint_type, JointDescriptor): - type_enum = joint_type.joint_type - else: - raise ValueError(joint_type) - # Prepare the joint position s = jnp.array(joint_position).astype(float) - match type_enum: - - case JointType.R: - joint_type: JointGenericAxis + def compute_F(): + return jaxlie.SE3.identity(), jnp.zeros(shape=(6, 1)) - pre_H_suc = jaxlie.SE3.from_rotation( - rotation=jaxlie.SO3.from_matrix( - Rotation.from_axis_angle(vector=s * joint_type.axis) - ) + def compute_R(): + pre_H_suc = jaxlie.SE3.from_rotation( + rotation=jaxlie.SO3.from_matrix( + Rotation.from_axis_angle(vector=s * joint_axis) ) + ) - S = jnp.vstack(jnp.hstack([jnp.zeros(3), joint_type.axis.squeeze()])) - - case JointType.P: - joint_type: JointGenericAxis - - pre_H_suc = jaxlie.SE3.from_rotation_and_translation( - rotation=jaxlie.SO3.identity(), - translation=jnp.array(s * joint_type.axis), - ) + S = jnp.vstack(jnp.hstack([jnp.zeros(3), joint_axis.squeeze()])) + return pre_H_suc, S - S = jnp.vstack(jnp.hstack([joint_type.axis.squeeze(), jnp.zeros(3)])) + def compute_P(): + pre_H_suc = jaxlie.SE3.from_rotation_and_translation( + rotation=jaxlie.SO3.identity(), + translation=jnp.array(s * joint_axis), + ) - case JointType.F: - pre_H_suc = jaxlie.SE3.identity() - S = jnp.zeros(shape=(6, 1)) + S = jnp.vstack(jnp.hstack([joint_axis.squeeze(), jnp.zeros(3)])) + return pre_H_suc, S - case _: - raise ValueError(joint_type) + pre_H_suc, S = jax.lax.switch( + index=joint_type, + branches=( + compute_F, # JointType.F + compute_R, # JointType.R + compute_P, # JointType.P + ), + ) return pre_H_suc.as_matrix(), S diff --git a/src/jaxsim/parsers/descriptions/__init__.py b/src/jaxsim/parsers/descriptions/__init__.py index 08e97c15b..9e180c155 100644 --- a/src/jaxsim/parsers/descriptions/__init__.py +++ b/src/jaxsim/parsers/descriptions/__init__.py @@ -1,4 +1,4 @@ from .collision import BoxCollision, CollidablePoint, CollisionShape, SphereCollision -from .joint import JointDescription, JointDescriptor, JointGenericAxis, JointType +from .joint import JointDescription, JointGenericAxis, JointType from .link import LinkDescription from .model import ModelDescription diff --git a/src/jaxsim/parsers/descriptions/joint.py b/src/jaxsim/parsers/descriptions/joint.py index 26af59000..0a4d21d84 100644 --- a/src/jaxsim/parsers/descriptions/joint.py +++ b/src/jaxsim/parsers/descriptions/joint.py @@ -1,7 +1,6 @@ from __future__ import annotations import dataclasses -import enum from typing import Tuple, Union import jax_dataclasses @@ -14,39 +13,35 @@ from .link import LinkDescription -@enum.unique -class JointType(enum.IntEnum): - """ - Type of supported joints. - """ - - @staticmethod - def _generate_next_value_(name, start, count, last_values): - # Start auto Enum value from 0 instead of 1 - return count - - #: Fixed joint. - F = enum.auto() +class _JointTypeMeta(type): + def __new__(cls, name, bases, dct): + cls_instance = super().__new__(cls, name, bases, dct) - #: Revolute joint (1 DoF around axis). - R = enum.auto() + # Assign integer values to the descriptors + cls_instance.F = 0 + cls_instance.R = 1 + cls_instance.P = 2 - #: Prismatic joint (1 DoF along axis). - P = enum.auto() + return cls_instance -@jax_dataclasses.pytree_dataclass -class JointDescriptor: +class JointType(metaclass=_JointTypeMeta): """ - Base class for joint types requiring to store additional metadata. + Type of supported joints. """ - #: The joint type. - joint_type: JointType + class F: + pass + + class R: + pass + + class P: + pass @jax_dataclasses.pytree_dataclass -class JointGenericAxis(JointDescriptor): +class JointGenericAxis: """ A joint requiring the specification of a 3D axis. """ @@ -55,7 +50,7 @@ class JointGenericAxis(JointDescriptor): axis: jtp.Vector def __hash__(self) -> int: - return hash((self.joint_type, tuple(np.array(self.axis).tolist()))) + return hash((tuple(np.array(self.axis).tolist()))) def __eq__(self, other: JointGenericAxis) -> bool: if not isinstance(other, JointGenericAxis): @@ -73,7 +68,7 @@ class JointDescription(JaxsimDataclass): name (str): The name of the joint. axis (npt.NDArray): The axis of rotation or translation for the joint. pose (npt.NDArray): The pose transformation matrix of the joint. - jtype (Union[JointType, JointDescriptor]): The type of the joint. + jtype (JointType): The type of the joint. child (LinkDescription): The child link attached to the joint. parent (LinkDescription): The parent link attached to the joint. index (Optional[int]): An optional index for the joint. @@ -89,7 +84,7 @@ class JointDescription(JaxsimDataclass): name: jax_dataclasses.Static[str] axis: npt.NDArray pose: npt.NDArray - jtype: jax_dataclasses.Static[Union[JointType, JointDescriptor]] + jtype: jax_dataclasses.Static[JointType] child: LinkDescription = dataclasses.dataclass(repr=False) parent: LinkDescription = dataclasses.dataclass(repr=False) diff --git a/src/jaxsim/parsers/kinematic_graph.py b/src/jaxsim/parsers/kinematic_graph.py index 040c6d332..7c847b480 100644 --- a/src/jaxsim/parsers/kinematic_graph.py +++ b/src/jaxsim/parsers/kinematic_graph.py @@ -689,6 +689,7 @@ def transform(self, name: str) -> npt.NDArray: # Compute the joint transform from the predecessor to the successor frame. pre_H_J = self.pre_H_suc( joint_type=joint.jtype, + joint_axis=joint.axis, joint_position=self._initial_joint_positions[joint.name], ) @@ -762,7 +763,8 @@ def relative_transform(self, relative_to: str, name: str) -> npt.NDArray: @staticmethod def pre_H_suc( - joint_type: descriptions.JointType | descriptions.JointDescriptor, + joint_type: descriptions.JointType, + joint_axis: descriptions.JointGenericAxis, joint_position: float | None = None, ) -> npt.NDArray: @@ -770,6 +772,8 @@ def pre_H_suc( return np.array( jaxsim.math.supported_joint_motion( - joint_type=joint_type, joint_position=joint_position + joint_type=joint_type, + joint_axis=joint_axis, + joint_position=joint_position, )[0] ) diff --git a/src/jaxsim/parsers/rod/utils.py b/src/jaxsim/parsers/rod/utils.py index afdc880d5..cfa0d508e 100644 --- a/src/jaxsim/parsers/rod/utils.py +++ b/src/jaxsim/parsers/rod/utils.py @@ -61,7 +61,7 @@ def from_sdf_inertial(inertial: rod.Inertial) -> jtp.Matrix: def joint_to_joint_type( joint: rod.Joint, -) -> descriptions.JointType | descriptions.JointDescriptor: +) -> descriptions.JointType: """ Extract the joint type from an SDF joint. @@ -86,14 +86,10 @@ def joint_to_joint_type( axis_xyz = axis_xyz / np.linalg.norm(axis_xyz) if joint_type in {"revolute", "continuous"}: - return descriptions.JointGenericAxis( - joint_type=descriptions.JointType.R, axis=axis_xyz - ) + return descriptions.JointType.R if joint_type == "prismatic": - return descriptions.JointGenericAxis( - joint_type=descriptions.JointType.P, axis=axis_xyz - ) + return descriptions.JointType.P raise ValueError("Joint not supported", axis_xyz, joint_type) From 747940326ee7df1013ccb989496b49db257c5e79 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Wed, 15 May 2024 13:23:46 +0200 Subject: [PATCH 2/6] Simplify finding joint names from idx --- src/jaxsim/api/joint.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/src/jaxsim/api/joint.py b/src/jaxsim/api/joint.py index e70020496..ae4049ae3 100644 --- a/src/jaxsim/api/joint.py +++ b/src/jaxsim/api/joint.py @@ -3,7 +3,6 @@ import jax import jax.numpy as jnp -import numpy as np import jaxsim.api as js import jaxsim.typing as jtp @@ -30,17 +29,9 @@ def name_to_idx(model: js.model.JaxSimModel, *, joint_name: str) -> jtp.Int: # Note: the index of the joint for RBDAs starts from 1, but # the index for accessing the right element starts from 0. # Therefore, there is a -1. - return ( - jnp.array( - np.argwhere( - np.array(model.kin_dyn_parameters.joint_model.joint_names) - == joint_name - ) - - 1 - ) - .squeeze() - .astype(int) - ) + return jnp.array( + model.kin_dyn_parameters.joint_model.joint_names.index(joint_name) - 1 + ).squeeze() return jnp.array(-1).astype(int) From 3c53127c3b8377d0373f6948b53bbbe5bf0cf7f0 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Wed, 15 May 2024 13:24:53 +0200 Subject: [PATCH 3/6] Compute `pre_H_suc` and `S` using `vmap` --- src/jaxsim/api/kin_dyn_parameters.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index 8861cd0aa..4efb00bab 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -382,21 +382,17 @@ def joint_transforms_and_motion_subspaces( ) # Compute the transforms and motion subspaces of the joints. - # TODO: understand how to use joint_indices to access joint_types, right now - # it fails when used within a JIT context. - pre_H_suc_and_S = [ - supported_joint_motion( - joint_type=self.joint_model.joint_types[i + 1], - joint_position=jnp.array(s), - ) - for i, s in enumerate(jnp.array(joint_positions).astype(float)) - ] + pre_H_suc_J, S_J = jax.vmap(supported_joint_motion)( + jnp.array(self.joint_model.joint_types[1:]), + jnp.array(self.joint_model.joint_axis), + jnp.array(joint_positions).astype(float), + ) # Extract the transforms and motion subspaces of the joints. # We stack the base transform W_H_B at index 0, and a dummy motion subspace # for either the fixed or free-floating joint connecting the world to the base. - pre_H_suc = jnp.stack([W_H_B] + [H for H, _ in pre_H_suc_and_S]) - S = jnp.stack([jnp.vstack(jnp.zeros(6))] + [S for _, S in pre_H_suc_and_S]) + pre_H_suc = jnp.vstack([W_H_B[None, ...], pre_H_suc_J]) + S = jnp.vstack([jnp.zeros((6, 1))[None, ...], S_J]) # Extract the successor-to-child fixed transforms. # Note that here we include also the index 0 since suc_H_child[0] stores the From 486def159e8e79d1f06138115014050139520b25 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Wed, 15 May 2024 15:30:16 +0200 Subject: [PATCH 4/6] Handle joint computation for jointless models Co-authored-by: Diego Ferigo --- src/jaxsim/api/kin_dyn_parameters.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index 4efb00bab..c5df87d8b 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -382,11 +382,14 @@ def joint_transforms_and_motion_subspaces( ) # Compute the transforms and motion subspaces of the joints. - pre_H_suc_J, S_J = jax.vmap(supported_joint_motion)( - jnp.array(self.joint_model.joint_types[1:]), - jnp.array(self.joint_model.joint_axis), - jnp.array(joint_positions).astype(float), - ) + if self.number_of_joints() == 0: + pre_H_suc_J, S_J = jnp.empty((0, 4, 4)), jnp.empty((0, 6, 1)) + else: + pre_H_suc_J, S_J = jax.vmap(supported_joint_motion)( + jnp.array(self.joint_model.joint_types[1:]).astype(int), + jnp.array(self.joint_model.joint_axis), + jnp.array(joint_positions), + ) # Extract the transforms and motion subspaces of the joints. # We stack the base transform W_H_B at index 0, and a dummy motion subspace From 4f142bc27096e9d8ca570807d2ebcb70716c7522 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Wed, 15 May 2024 16:52:48 +0200 Subject: [PATCH 5/6] Address review - Use explicit joint type names - Use `jnp.newaxis` instead of `None` - Use dataclass for `JointType` instead of metaclass Co-authored-by: Diego Ferigo --- src/jaxsim/api/kin_dyn_parameters.py | 4 +-- src/jaxsim/math/joint_model.py | 8 +++--- src/jaxsim/parsers/descriptions/joint.py | 32 +++++------------------- src/jaxsim/parsers/rod/parser.py | 2 +- src/jaxsim/parsers/rod/utils.py | 6 ++--- 5 files changed, 16 insertions(+), 36 deletions(-) diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index c5df87d8b..b9c536320 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -394,8 +394,8 @@ def joint_transforms_and_motion_subspaces( # Extract the transforms and motion subspaces of the joints. # We stack the base transform W_H_B at index 0, and a dummy motion subspace # for either the fixed or free-floating joint connecting the world to the base. - pre_H_suc = jnp.vstack([W_H_B[None, ...], pre_H_suc_J]) - S = jnp.vstack([jnp.zeros((6, 1))[None, ...], S_J]) + pre_H_suc = jnp.vstack([W_H_B[jnp.newaxis, ...], pre_H_suc_J]) + S = jnp.vstack([jnp.zeros((6, 1))[jnp.newaxis, ...], S_J]) # Extract the successor-to-child fixed transforms. # Note that here we include also the index 0 since suc_H_child[0] stores the diff --git a/src/jaxsim/math/joint_model.py b/src/jaxsim/math/joint_model.py index a6f3776ec..635fe2003 100644 --- a/src/jaxsim/math/joint_model.py +++ b/src/jaxsim/math/joint_model.py @@ -108,7 +108,7 @@ def build(description: ModelDescription) -> JointModel: # Static attributes joint_dofs=tuple([base_dofs] + [int(1) for _ in ordered_joints]), joint_names=tuple(["world_to_base"] + [j.name for j in ordered_joints]), - joint_types=tuple([JointType.F] + [j.jtype for j in ordered_joints]), + joint_types=tuple([JointType.Fixed] + [j.jtype for j in ordered_joints]), joint_axis=tuple([j.axis for j in ordered_joints]), ) @@ -265,9 +265,9 @@ def compute_P(): pre_H_suc, S = jax.lax.switch( index=joint_type, branches=( - compute_F, # JointType.F - compute_R, # JointType.R - compute_P, # JointType.P + compute_F, # JointType.Fixed + compute_R, # JointType.Revolute + compute_P, # JointType.Prismatic ), ) diff --git a/src/jaxsim/parsers/descriptions/joint.py b/src/jaxsim/parsers/descriptions/joint.py index 0a4d21d84..9abf7d257 100644 --- a/src/jaxsim/parsers/descriptions/joint.py +++ b/src/jaxsim/parsers/descriptions/joint.py @@ -1,7 +1,7 @@ from __future__ import annotations import dataclasses -from typing import Tuple, Union +from typing import ClassVar, Tuple, Union import jax_dataclasses import numpy as np @@ -13,31 +13,11 @@ from .link import LinkDescription -class _JointTypeMeta(type): - def __new__(cls, name, bases, dct): - cls_instance = super().__new__(cls, name, bases, dct) - - # Assign integer values to the descriptors - cls_instance.F = 0 - cls_instance.R = 1 - cls_instance.P = 2 - - return cls_instance - - -class JointType(metaclass=_JointTypeMeta): - """ - Type of supported joints. - """ - - class F: - pass - - class R: - pass - - class P: - pass +@dataclasses.dataclass(frozen=True) +class JointType: + Fixed: ClassVar[int] = 0 + Revolute: ClassVar[int] = 1 + Prismatic: ClassVar[int] = 2 @jax_dataclasses.pytree_dataclass diff --git a/src/jaxsim/parsers/rod/parser.py b/src/jaxsim/parsers/rod/parser.py index ed49e7702..c2e396d57 100644 --- a/src/jaxsim/parsers/rod/parser.py +++ b/src/jaxsim/parsers/rod/parser.py @@ -352,7 +352,7 @@ def build_model_description( considered_joints=[ j.name for j in sdf_data.joint_descriptions - if j.jtype is not descriptions.JointType.F + if j.jtype is not descriptions.JointType.Fixed ], ) diff --git a/src/jaxsim/parsers/rod/utils.py b/src/jaxsim/parsers/rod/utils.py index cfa0d508e..a001a1da7 100644 --- a/src/jaxsim/parsers/rod/utils.py +++ b/src/jaxsim/parsers/rod/utils.py @@ -76,7 +76,7 @@ def joint_to_joint_type( joint_type = joint.type if joint_type == "fixed": - return descriptions.JointType.F + return descriptions.JointType.Fixed if not (axis.xyz is not None and axis.xyz.xyz is not None): raise ValueError("Failed to read axis xyz data") @@ -86,10 +86,10 @@ def joint_to_joint_type( axis_xyz = axis_xyz / np.linalg.norm(axis_xyz) if joint_type in {"revolute", "continuous"}: - return descriptions.JointType.R + return descriptions.JointType.Revolute if joint_type == "prismatic": - return descriptions.JointType.P + return descriptions.JointType.Prismatic raise ValueError("Joint not supported", axis_xyz, joint_type) From 16edc053fb8d6aeed9da1f56c90f74faaa968443 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Wed, 15 May 2024 18:33:59 +0200 Subject: [PATCH 6/6] Allow positional-only arg `supported_joint_motion` Co-authored-by: Diego Ferigo --- src/jaxsim/api/kin_dyn_parameters.py | 2 +- src/jaxsim/math/joint_model.py | 10 +++++++--- src/jaxsim/parsers/kinematic_graph.py | 8 +++----- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index b9c536320..622af2498 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -387,8 +387,8 @@ def joint_transforms_and_motion_subspaces( else: pre_H_suc_J, S_J = jax.vmap(supported_joint_motion)( jnp.array(self.joint_model.joint_types[1:]).astype(int), - jnp.array(self.joint_model.joint_axis), jnp.array(joint_positions), + jnp.array(self.joint_model.joint_axis), ) # Extract the transforms and motion subspaces of the joints. diff --git a/src/jaxsim/math/joint_model.py b/src/jaxsim/math/joint_model.py index 635fe2003..2e29553e9 100644 --- a/src/jaxsim/math/joint_model.py +++ b/src/jaxsim/math/joint_model.py @@ -199,8 +199,9 @@ def predecessor_H_successor( """ pre_H_suc, S = supported_joint_motion( - joint_type=self.joint_types[joint_index], - joint_position=joint_position, + self.joint_types[joint_index], + joint_position, + self.joint_axis[joint_index], ) return pre_H_suc, S @@ -223,7 +224,10 @@ def successor_H_child(self, joint_index: jtp.IntLike) -> jtp.Matrix: @jax.jit def supported_joint_motion( - joint_type: JointType, joint_axis: JointGenericAxis, joint_position: jtp.VectorLike + joint_type: JointType, + joint_position: jtp.VectorLike, + joint_axis: JointGenericAxis, + /, ) -> tuple[jtp.Matrix, jtp.Array]: """ Compute the homogeneous transformation and motion subspace of a joint. diff --git a/src/jaxsim/parsers/kinematic_graph.py b/src/jaxsim/parsers/kinematic_graph.py index 7c847b480..3e7c3e2ac 100644 --- a/src/jaxsim/parsers/kinematic_graph.py +++ b/src/jaxsim/parsers/kinematic_graph.py @@ -771,9 +771,7 @@ def pre_H_suc( import jaxsim.math return np.array( - jaxsim.math.supported_joint_motion( - joint_type=joint_type, - joint_axis=joint_axis, - joint_position=joint_position, - )[0] + jaxsim.math.supported_joint_motion(joint_type, joint_position, joint_axis)[ + 0 + ] )