Skip to content

Commit

Permalink
Move ExplicitRungeKuttaSO3Mixin to integrators.common
Browse files Browse the repository at this point in the history
  • Loading branch information
diegoferigo committed Mar 21, 2024
1 parent 1907455 commit fbf9986
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 67 deletions.
65 changes: 65 additions & 0 deletions src/jaxsim/integrators/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@
import jax
import jax.numpy as jnp
import jax_dataclasses
import jaxlie
from jax_dataclasses import Static

import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim.math import Quaternion
from jaxsim.utils.jaxsim_dataclass import JaxsimDataclass, Mutability

try:
Expand Down Expand Up @@ -527,3 +530,65 @@ def butcher_tableau_supports_fsal(
# possibly intermediate kᵢ derivative).
# Note that if multiple rows match (it should not), we return the first match.
return True, int(jnp.where(rows_of_A_with_fsal == True)[0].tolist()[0])


class ExplicitRungeKuttaSO3Mixin:
"""
Mixin class to apply over explicit RK integrators defined on
`PyTreeType = ODEState` to integrate the quaternion on SO(3).
"""

@classmethod
def integrate_rk_stage(
cls, x0: js.ode_data.ODEState, t0: Time, dt: TimeStep, k: js.ode_data.ODEState
) -> js.ode_data.ODEState:

op = lambda x0_leaf, k_leaf: x0_leaf + dt * k_leaf
xf: js.ode_data.ODEState = jax.tree_util.tree_map(op, x0, k)

W_Q_B_t0 = x0.physics_model.base_quaternion
W_ω_WB_t0 = x0.physics_model.base_angular_velocity

return xf.replace(
physics_model=xf.physics_model.replace(
base_quaternion=Quaternion.integration(
quaternion=W_Q_B_t0,
dt=dt,
omega=W_ω_WB_t0,
omega_in_body_fixed=False,
),
)
)

@classmethod
def post_process_state(
cls, x0: js.ode_data.ODEState, t0: Time, xf: js.ode_data.ODEState, dt: TimeStep
) -> js.ode_data.ODEState:

# Indices to convert quaternions between serializations.
to_xyzw = jnp.array([1, 2, 3, 0])
to_wxyz = jnp.array([3, 0, 1, 2])

# Get the initial quaternion.
W_Q_B_t0 = jaxlie.SO3.from_quaternion_xyzw(
xyzw=x0.physics_model.base_quaternion[to_xyzw]
)

# Get the final angular velocity.
# This is already computed by averaging the kᵢ in RK-based schemes.
# Therefore, by using the ω at tf, we obtain a RK scheme operating
# on the SO(3) manifold.
W_ω_WB_tf = xf.physics_model.base_angular_velocity

# Integrate the quaternion on SO(3).
# Note that we left-multiply with the exponential map since the angular
# velocity is expressed in the inertial frame.
W_Q_B_tf = jaxlie.SO3.exp(tangent=dt * W_ω_WB_tf) @ W_Q_B_t0

# Replace the quaternion in the final state.
return xf.replace(
physics_model=xf.physics_model.replace(
base_quaternion=W_Q_B_tf.as_quaternion_xyzw()[to_wxyz]
),
validate=True,
)
67 changes: 1 addition & 66 deletions src/jaxsim/integrators/fixed_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,13 @@
import jax
import jax.numpy as jnp
import jax_dataclasses
import jaxlie

import jaxsim.api as js
from jaxsim.math import Quaternion

from .common import ExplicitRungeKutta, PyTreeType, Time, TimeStep
from .common import ExplicitRungeKutta, ExplicitRungeKuttaSO3Mixin, PyTreeType

ODEStateDerivative = js.ode_data.ODEState


# =====================================================
# Explicit Runge-Kutta integrators operating on PyTrees
# =====================================================
Expand Down Expand Up @@ -90,68 +87,6 @@ class RungeKutta4(ExplicitRungeKutta[PyTreeType], Generic[PyTreeType]):
# ===============================================================================


class ExplicitRungeKuttaSO3Mixin:
"""
Mixin class to apply over explicit RK integrators defined on
`PyTreeType = ODEState` to integrate the quaternion on SO(3).
"""

@classmethod
def integrate_rk_stage(
cls, x0: js.ode_data.ODEState, t0: Time, dt: TimeStep, k: js.ode_data.ODEState
) -> js.ode_data.ODEState:

op = lambda x0_leaf, k_leaf: x0_leaf + dt * k_leaf
xf: js.ode_data.ODEState = jax.tree_util.tree_map(op, x0, k)

W_Q_B_t0 = x0.physics_model.base_quaternion
W_ω_WB_t0 = x0.physics_model.base_angular_velocity

return xf.replace(
physics_model=xf.physics_model.replace(
base_quaternion=Quaternion.integration(
quaternion=W_Q_B_t0,
dt=dt,
omega=W_ω_WB_t0,
omega_in_body_fixed=False,
),
)
)

@classmethod
def post_process_state(
cls, x0: js.ode_data.ODEState, t0: Time, xf: js.ode_data.ODEState, dt: TimeStep
) -> js.ode_data.ODEState:

# Indices to convert quaternions between serializations.
to_xyzw = jnp.array([1, 2, 3, 0])
to_wxyz = jnp.array([3, 0, 1, 2])

# Get the initial quaternion.
W_Q_B_t0 = jaxlie.SO3.from_quaternion_xyzw(
xyzw=x0.physics_model.base_quaternion[to_xyzw]
)

# Get the final angular velocity.
# This is already computed by averaging the kᵢ in RK-based schemes.
# Therefore, by using the ω at tf, we obtain a RK scheme operating
# on the SO(3) manifold.
W_ω_WB_tf = xf.physics_model.base_angular_velocity

# Integrate the quaternion on SO(3).
# Note that we left-multiply with the exponential map since the angular
# velocity is expressed in the inertial frame.
W_Q_B_tf = jaxlie.SO3.exp(tangent=dt * W_ω_WB_tf) @ W_Q_B_t0

# Replace the quaternion in the final state.
return xf.replace(
physics_model=xf.physics_model.replace(
base_quaternion=W_Q_B_tf.as_quaternion_xyzw()[to_wxyz]
),
validate=True,
)


@jax_dataclasses.pytree_dataclass
class ForwardEulerSO3(ExplicitRungeKuttaSO3Mixin, ForwardEuler[js.ode_data.ODEState]):
pass
Expand Down
2 changes: 1 addition & 1 deletion src/jaxsim/integrators/variable_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
from jax_dataclasses import Static

from jaxsim import typing as jtp
from jaxsim.integrators.fixed_step import ExplicitRungeKuttaSO3Mixin
from jaxsim.utils import Mutability

from .common import (
ExplicitRungeKutta,
ExplicitRungeKuttaSO3Mixin,
NextState,
PyTreeType,
State,
Expand Down

0 comments on commit fbf9986

Please sign in to comment.