Skip to content

Commit

Permalink
Implement Jacobian with jax.lax.scan
Browse files Browse the repository at this point in the history
  • Loading branch information
diegoferigo committed Sep 21, 2022
1 parent a218f5e commit 0a273b2
Showing 1 changed file with 63 additions and 13 deletions.
76 changes: 63 additions & 13 deletions src/jaxsim/physics/algos/jacobian.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from typing import Tuple

import jax
import jax.numpy as jnp
import numpy as np

Expand All @@ -7,11 +10,7 @@
from . import utils


def jacobian(
model: PhysicsModel,
body_index: int,
q: jtp.Vector,
) -> jtp.Matrix:
def jacobian(model: PhysicsModel, body_index: jtp.Int, q: jtp.Vector) -> jtp.Matrix:

_, q, _, _, _, _ = utils.process_inputs(physics_model=model, q=q)

Expand All @@ -23,27 +22,78 @@ def jacobian(
i_X_0 = jnp.zeros_like(i_X_pre)
i_X_0 = i_X_0.at[0].set(jnp.eye(6))

for i in np.arange(start=1, stop=model.NB):
# Parent array mapping: i -> λ(i).
# Exception: λ(0) must not be used, it's initialized to -1.
λ = model.parent

# ====================
# Propagate kinematics
# ====================

PropagateKinematicsCarry = Tuple[jtp.MatrixJax, jtp.MatrixJax]
propagate_kinematics_carry = (i_X_λi, i_X_0)

def propagate_kinematics(
carry: PropagateKinematicsCarry, i: jtp.Int
) -> Tuple[PropagateKinematicsCarry, None]:

i_X_λi, i_X_0 = carry

i_X_λi_i = i_X_pre[i] @ pre_X_λi[i]
i_X_λi = i_X_λi.at[i].set(i_X_λi_i)

i_X_0_i = i_X_λi[i] @ i_X_0[model.parent[i]]
i_X_0_i = i_X_λi[i] @ i_X_0[λ[i]]
i_X_0 = i_X_0.at[i].set(i_X_0_i)

return (i_X_λi, i_X_0), None

(i_X_λi, i_X_0), _ = jax.lax.scan(
f=propagate_kinematics,
init=propagate_kinematics_carry,
xs=np.arange(start=1, stop=model.NB),
)

# ============================
# Compute doubly-left Jacobian
# ============================

J = jnp.zeros(shape=(6, 6 + model.dofs()))

Jb = i_X_0[body_index]
J = J.at[0:6, 0:6].set(Jb)

for i in reversed(model.support_body_array(body_index=body_index)):
ComputeJacobianCarry = jtp.MatrixJax
compute_jacobian_carry = J

def compute_jacobian(
carry: ComputeJacobianCarry, i: jtp.Int
) -> Tuple[ComputeJacobianCarry, None]:
def update_jacobian(
carry: Tuple[ComputeJacobianCarry, jtp.Int]
) -> ComputeJacobianCarry:

J, i = carry

ii = i - 1

Js_i = i_X_0[body_index] @ jnp.linalg.inv(i_X_0[i]) @ S[i]
J = J.at[0:6, 6 + ii].set(Js_i.squeeze())

return J

ii = i - 1
carry = jax.lax.cond(
pred=(jnp.any(i == model.support_body_array(body_index=body_index))),
true_fun=update_jacobian,
false_fun=lambda carry_i: carry_i[0],
operand=(carry, i),
)

if i == 0:
break
return carry, None

Js_i = i_X_0[body_index] @ jnp.linalg.inv(i_X_0[i]) @ S[i]
J = J.at[0:6, 6 + ii].set(Js_i.squeeze())
J, _ = jax.lax.scan(
f=compute_jacobian,
init=compute_jacobian_carry,
xs=np.arange(start=1, stop=model.NB),
)

return J

0 comments on commit 0a273b2

Please sign in to comment.