Skip to content

Commit

Permalink
Merge pull request #153 from ami-iit/feature/vmap_joint_computations
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti authored May 15, 2024
2 parents bab6755 + 16edc05 commit c14205a
Show file tree
Hide file tree
Showing 8 changed files with 72 additions and 115 deletions.
15 changes: 3 additions & 12 deletions src/jaxsim/api/joint.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import jax
import jax.numpy as jnp
import numpy as np

import jaxsim.api as js
import jaxsim.typing as jtp
Expand All @@ -30,17 +29,9 @@ def name_to_idx(model: js.model.JaxSimModel, *, joint_name: str) -> jtp.Int:
# Note: the index of the joint for RBDAs starts from 1, but
# the index for accessing the right element starts from 0.
# Therefore, there is a -1.
return (
jnp.array(
np.argwhere(
np.array(model.kin_dyn_parameters.joint_model.joint_names)
== joint_name
)
- 1
)
.squeeze()
.astype(int)
)
return jnp.array(
model.kin_dyn_parameters.joint_model.joint_names.index(joint_name) - 1
).squeeze()
return jnp.array(-1).astype(int)


Expand Down
19 changes: 9 additions & 10 deletions src/jaxsim/api/kin_dyn_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,21 +382,20 @@ def joint_transforms_and_motion_subspaces(
)

# Compute the transforms and motion subspaces of the joints.
# TODO: understand how to use joint_indices to access joint_types, right now
# it fails when used within a JIT context.
pre_H_suc_and_S = [
supported_joint_motion(
joint_type=self.joint_model.joint_types[i + 1],
joint_position=jnp.array(s),
if self.number_of_joints() == 0:
pre_H_suc_J, S_J = jnp.empty((0, 4, 4)), jnp.empty((0, 6, 1))
else:
pre_H_suc_J, S_J = jax.vmap(supported_joint_motion)(
jnp.array(self.joint_model.joint_types[1:]).astype(int),
jnp.array(joint_positions),
jnp.array(self.joint_model.joint_axis),
)
for i, s in enumerate(jnp.array(joint_positions).astype(float))
]

# Extract the transforms and motion subspaces of the joints.
# We stack the base transform W_H_B at index 0, and a dummy motion subspace
# for either the fixed or free-floating joint connecting the world to the base.
pre_H_suc = jnp.stack([W_H_B] + [H for H, _ in pre_H_suc_and_S])
S = jnp.stack([jnp.vstack(jnp.zeros(6))] + [S for _, S in pre_H_suc_and_S])
pre_H_suc = jnp.vstack([W_H_B[jnp.newaxis, ...], pre_H_suc_J])
S = jnp.vstack([jnp.zeros((6, 1))[jnp.newaxis, ...], S_J])

# Extract the successor-to-child fixed transforms.
# Note that here we include also the index 0 since suc_H_child[0] stores the
Expand Down
82 changes: 38 additions & 44 deletions src/jaxsim/math/joint_model.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,13 @@
from __future__ import annotations

import functools

import jax
import jax.numpy as jnp
import jax_dataclasses
import jaxlie
from jax_dataclasses import Static

import jaxsim.typing as jtp
from jaxsim.parsers.descriptions import (
JointDescriptor,
JointGenericAxis,
JointType,
ModelDescription,
)
from jaxsim.parsers.descriptions import JointGenericAxis, JointType, ModelDescription
from jaxsim.parsers.kinematic_graph import KinematicGraphTransforms

from .rotation import Rotation
Expand Down Expand Up @@ -46,7 +39,8 @@ class JointModel:

joint_dofs: Static[tuple[int, ...]]
joint_names: Static[tuple[str, ...]]
joint_types: Static[tuple[JointType | JointDescriptor, ...]]
joint_types: Static[tuple[JointType, ...]]
joint_axis: Static[tuple[JointGenericAxis, ...]]

@staticmethod
def build(description: ModelDescription) -> JointModel:
Expand Down Expand Up @@ -114,7 +108,8 @@ def build(description: ModelDescription) -> JointModel:
# Static attributes
joint_dofs=tuple([base_dofs] + [int(1) for _ in ordered_joints]),
joint_names=tuple(["world_to_base"] + [j.name for j in ordered_joints]),
joint_types=tuple([JointType.F] + [j.jtype for j in ordered_joints]),
joint_types=tuple([JointType.Fixed] + [j.jtype for j in ordered_joints]),
joint_axis=tuple([j.axis for j in ordered_joints]),
)

def parent_H_child(
Expand Down Expand Up @@ -204,8 +199,9 @@ def predecessor_H_successor(
"""

pre_H_suc, S = supported_joint_motion(
joint_type=self.joint_types[joint_index],
joint_position=joint_position,
self.joint_types[joint_index],
joint_position,
self.joint_axis[joint_index],
)

return pre_H_suc, S
Expand All @@ -226,59 +222,57 @@ def successor_H_child(self, joint_index: jtp.IntLike) -> jtp.Matrix:
return self.suc_H_i[joint_index]


@functools.partial(jax.jit, static_argnames=["joint_type"])
@jax.jit
def supported_joint_motion(
joint_type: JointType | JointDescriptor, joint_position: jtp.VectorLike
joint_type: JointType,
joint_position: jtp.VectorLike,
joint_axis: JointGenericAxis,
/,
) -> tuple[jtp.Matrix, jtp.Array]:
"""
Compute the homogeneous transformation and motion subspace of a joint.
Args:
joint_type: The type of the joint.
joint_axis: The axis of rotation or translation of the joint.
joint_position: The position of the joint.
Returns:
A tuple containing the homogeneous transformation and the motion subspace.
"""

if isinstance(joint_type, JointType):
type_enum = joint_type
elif isinstance(joint_type, JointDescriptor):
type_enum = joint_type.joint_type
else:
raise ValueError(joint_type)

# Prepare the joint position
s = jnp.array(joint_position).astype(float)

match type_enum:

case JointType.R:
joint_type: JointGenericAxis
def compute_F():
return jaxlie.SE3.identity(), jnp.zeros(shape=(6, 1))

pre_H_suc = jaxlie.SE3.from_rotation(
rotation=jaxlie.SO3.from_matrix(
Rotation.from_axis_angle(vector=s * joint_type.axis)
)
def compute_R():
pre_H_suc = jaxlie.SE3.from_rotation(
rotation=jaxlie.SO3.from_matrix(
Rotation.from_axis_angle(vector=s * joint_axis)
)
)

S = jnp.vstack(jnp.hstack([jnp.zeros(3), joint_type.axis.squeeze()]))

case JointType.P:
joint_type: JointGenericAxis

pre_H_suc = jaxlie.SE3.from_rotation_and_translation(
rotation=jaxlie.SO3.identity(),
translation=jnp.array(s * joint_type.axis),
)
S = jnp.vstack(jnp.hstack([jnp.zeros(3), joint_axis.squeeze()]))
return pre_H_suc, S

S = jnp.vstack(jnp.hstack([joint_type.axis.squeeze(), jnp.zeros(3)]))
def compute_P():
pre_H_suc = jaxlie.SE3.from_rotation_and_translation(
rotation=jaxlie.SO3.identity(),
translation=jnp.array(s * joint_axis),
)

case JointType.F:
pre_H_suc = jaxlie.SE3.identity()
S = jnp.zeros(shape=(6, 1))
S = jnp.vstack(jnp.hstack([joint_axis.squeeze(), jnp.zeros(3)]))
return pre_H_suc, S

case _:
raise ValueError(joint_type)
pre_H_suc, S = jax.lax.switch(
index=joint_type,
branches=(
compute_F, # JointType.Fixed
compute_R, # JointType.Revolute
compute_P, # JointType.Prismatic
),
)

return pre_H_suc.as_matrix(), S
2 changes: 1 addition & 1 deletion src/jaxsim/parsers/descriptions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .collision import BoxCollision, CollidablePoint, CollisionShape, SphereCollision
from .joint import JointDescription, JointDescriptor, JointGenericAxis, JointType
from .joint import JointDescription, JointGenericAxis, JointType
from .link import LinkDescription
from .model import ModelDescription
45 changes: 10 additions & 35 deletions src/jaxsim/parsers/descriptions/joint.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from __future__ import annotations

import dataclasses
import enum
from typing import Tuple, Union
from typing import ClassVar, Tuple, Union

import jax_dataclasses
import numpy as np
Expand All @@ -14,39 +13,15 @@
from .link import LinkDescription


@enum.unique
class JointType(enum.IntEnum):
"""
Type of supported joints.
"""

@staticmethod
def _generate_next_value_(name, start, count, last_values):
# Start auto Enum value from 0 instead of 1
return count

#: Fixed joint.
F = enum.auto()

#: Revolute joint (1 DoF around axis).
R = enum.auto()

#: Prismatic joint (1 DoF along axis).
P = enum.auto()


@jax_dataclasses.pytree_dataclass
class JointDescriptor:
"""
Base class for joint types requiring to store additional metadata.
"""

#: The joint type.
joint_type: JointType
@dataclasses.dataclass(frozen=True)
class JointType:
Fixed: ClassVar[int] = 0
Revolute: ClassVar[int] = 1
Prismatic: ClassVar[int] = 2


@jax_dataclasses.pytree_dataclass
class JointGenericAxis(JointDescriptor):
class JointGenericAxis:
"""
A joint requiring the specification of a 3D axis.
"""
Expand All @@ -55,7 +30,7 @@ class JointGenericAxis(JointDescriptor):
axis: jtp.Vector

def __hash__(self) -> int:
return hash((self.joint_type, tuple(np.array(self.axis).tolist())))
return hash((tuple(np.array(self.axis).tolist())))

def __eq__(self, other: JointGenericAxis) -> bool:
if not isinstance(other, JointGenericAxis):
Expand All @@ -73,7 +48,7 @@ class JointDescription(JaxsimDataclass):
name (str): The name of the joint.
axis (npt.NDArray): The axis of rotation or translation for the joint.
pose (npt.NDArray): The pose transformation matrix of the joint.
jtype (Union[JointType, JointDescriptor]): The type of the joint.
jtype (JointType): The type of the joint.
child (LinkDescription): The child link attached to the joint.
parent (LinkDescription): The parent link attached to the joint.
index (Optional[int]): An optional index for the joint.
Expand All @@ -89,7 +64,7 @@ class JointDescription(JaxsimDataclass):
name: jax_dataclasses.Static[str]
axis: npt.NDArray
pose: npt.NDArray
jtype: jax_dataclasses.Static[Union[JointType, JointDescriptor]]
jtype: jax_dataclasses.Static[JointType]
child: LinkDescription = dataclasses.dataclass(repr=False)
parent: LinkDescription = dataclasses.dataclass(repr=False)

Expand Down
10 changes: 6 additions & 4 deletions src/jaxsim/parsers/kinematic_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,7 @@ def transform(self, name: str) -> npt.NDArray:
# Compute the joint transform from the predecessor to the successor frame.
pre_H_J = self.pre_H_suc(
joint_type=joint.jtype,
joint_axis=joint.axis,
joint_position=self._initial_joint_positions[joint.name],
)

Expand Down Expand Up @@ -762,14 +763,15 @@ def relative_transform(self, relative_to: str, name: str) -> npt.NDArray:

@staticmethod
def pre_H_suc(
joint_type: descriptions.JointType | descriptions.JointDescriptor,
joint_type: descriptions.JointType,
joint_axis: descriptions.JointGenericAxis,
joint_position: float | None = None,
) -> npt.NDArray:

import jaxsim.math

return np.array(
jaxsim.math.supported_joint_motion(
joint_type=joint_type, joint_position=joint_position
)[0]
jaxsim.math.supported_joint_motion(joint_type, joint_position, joint_axis)[
0
]
)
2 changes: 1 addition & 1 deletion src/jaxsim/parsers/rod/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def build_model_description(
considered_joints=[
j.name
for j in sdf_data.joint_descriptions
if j.jtype is not descriptions.JointType.F
if j.jtype is not descriptions.JointType.Fixed
],
)

Expand Down
12 changes: 4 additions & 8 deletions src/jaxsim/parsers/rod/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def from_sdf_inertial(inertial: rod.Inertial) -> jtp.Matrix:

def joint_to_joint_type(
joint: rod.Joint,
) -> descriptions.JointType | descriptions.JointDescriptor:
) -> descriptions.JointType:
"""
Extract the joint type from an SDF joint.
Expand All @@ -76,7 +76,7 @@ def joint_to_joint_type(
joint_type = joint.type

if joint_type == "fixed":
return descriptions.JointType.F
return descriptions.JointType.Fixed

if not (axis.xyz is not None and axis.xyz.xyz is not None):
raise ValueError("Failed to read axis xyz data")
Expand All @@ -86,14 +86,10 @@ def joint_to_joint_type(
axis_xyz = axis_xyz / np.linalg.norm(axis_xyz)

if joint_type in {"revolute", "continuous"}:
return descriptions.JointGenericAxis(
joint_type=descriptions.JointType.R, axis=axis_xyz
)
return descriptions.JointType.Revolute

if joint_type == "prismatic":
return descriptions.JointGenericAxis(
joint_type=descriptions.JointType.P, axis=axis_xyz
)
return descriptions.JointType.Prismatic

raise ValueError("Joint not supported", axis_xyz, joint_type)

Expand Down

0 comments on commit c14205a

Please sign in to comment.