Skip to content

Commit

Permalink
Handle the conversion of input representation
Browse files Browse the repository at this point in the history
  • Loading branch information
xela-95 committed Jul 19, 2024
1 parent 3c5e544 commit 78b690d
Showing 1 changed file with 106 additions and 30 deletions.
136 changes: 106 additions & 30 deletions src/jaxsim/api/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,50 +301,126 @@ def jacobian_derivative(
# Get the index of the parent link.
L = idx_of_parent_link(model=model, frame_index=frame_index)

# Compute the Jacobian of the parent link in inertial representation.
W_J_WL_I = js.link.jacobian(
model=model,
data=data,
link_index=L,
output_vel_repr=VelRepr.Inertial,
)
with data.switch_velocity_representation(VelRepr.Inertial):
# Compute the Jacobian of the parent link in inertial representation.
W_J_WL_W = js.link.jacobian(
model=model,
data=data,
link_index=L,
output_vel_repr=VelRepr.Inertial,
)

# Compute the Jacobian derivative of the parent link in inertial representation.
W_J̇_WL_I = js.link.jacobian_derivative(
model=model,
data=data,
link_index=L,
output_vel_repr=VelRepr.Inertial,
)
# Compute the Jacobian derivative of the parent link in inertial representation.
W_J̇_WL_W = js.link.jacobian_derivative(
model=model,
data=data,
link_index=L,
output_vel_repr=VelRepr.Inertial,
)

# =====================================================
# Compute quantities to adjust the input representation
# =====================================================

def compute_T(model: js.model.JaxSimModel, X: jtp.Matrix) -> jtp.Matrix:
In = jnp.eye(model.dofs())
T = jax.scipy.linalg.block_diag(X, In)
return T

def compute_Ṫ(model: js.model.JaxSimModel, : jtp.Matrix) -> jtp.Matrix:
On = jnp.zeros(shape=(model.dofs(), model.dofs()))
= jax.scipy.linalg.block_diag(, On)
return

# Compute the operator to change the representation of ν, and its
# time derivative.
match data.velocity_representation:
case VelRepr.Inertial:
W_H_W = jnp.eye(4)
W_X_W = Adjoint.from_transform(transform=W_H_W)
W_Ẋ_W = jnp.zeros((6, 6))

T = compute_T(model=model, X=W_X_W)
= compute_Ṫ(model=model, =W_Ẋ_W)

case VelRepr.Body:
W_H_B = data.base_transform()
W_X_B = Adjoint.from_transform(transform=W_H_B)
B_v_WB = data.base_velocity()
B_vx_WB = Cross.vx(B_v_WB)
W_Ẋ_B = W_X_B @ B_vx_WB

T = compute_T(model=model, X=W_X_B)
= compute_Ṫ(model=model, =W_Ẋ_B)

case VelRepr.Mixed:
W_H_B = data.base_transform()
W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
W_X_BW = Adjoint.from_transform(transform=W_H_BW)
BW_v_WB = data.base_velocity()
BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3))
BW_vx_W_BW = Cross.vx(BW_v_W_BW)
W_Ẋ_BW = W_X_BW @ BW_vx_W_BW

T = compute_T(model=model, X=W_X_BW)
= compute_Ṫ(model=model, =W_Ẋ_BW)

case _:
raise ValueError(data.velocity_representation)

# =====================================================
# Compute quantities to adjust the output representation
# =====================================================

# Compute the adjoint and its derivative from inertial to desired output representation.
match output_vel_repr:
case VelRepr.Inertial:
O_J̇_WF_I = W_J̇_WL_I
O_X_W = W_X_W = Adjoint.from_transform(transform=jnp.eye(4))
O_Ẋ_W = W_Ẋ_W = jnp.zeros((6, 6))
# O_J̇_WF_I = W_J̇_WL_I
O_J̇_WF_I = jnp.zeros(shape=(6, 6 + model.dofs()))
O_J̇_WF_I += O_Ẋ_W @ W_J_WL_W @ T
O_J̇_WF_I += O_X_W @ W_J̇_WL_W @ T
O_J̇_WF_I += O_X_W @ W_J_WL_W @

case VelRepr.Body:
W_H_F = transform(model=model, data=data, frame_index=frame_index)
F_H_W = Transform.inverse(W_H_F)
F_X_W = Adjoint.from_transform(transform=F_H_W)
W_v_WF = W_J_WL_I @ data.generalized_velocity()
O_X_W = F_X_W = Adjoint.from_transform(transform=F_H_W)
with data.switch_velocity_representation(VelRepr.Inertial):
W_nu = data.generalized_velocity()
W_v_WF = W_J_WL_W @ W_nu
W_vx_WF = Cross.vx(W_v_WF)
O_J̇_WF_I = F_X_W @ (W_J̇_WL_I - W_vx_WF @ W_J_WL_I)
O_Ẋ_W = F_Ẋ_W = -F_X_W @ W_vx_WF
# O_J̇_WF_I = F_X_W @ (W_J̇_WL_I - W_vx_WF @ W_J_WL_I)
O_J̇_WF_I = jnp.zeros(shape=(6, 6 + model.dofs()))
O_J̇_WF_I += O_Ẋ_W @ W_J_WL_W @ T
O_J̇_WF_I += O_X_W @ W_J̇_WL_W @ T
O_J̇_WF_I += O_X_W @ W_J_WL_W @

case VelRepr.Mixed:
W_H_F = transform(model=model, data=data, frame_index=frame_index)
W_H_FW = W_H_F.at[0:3, 0:3].set(jnp.zeros((3, 3)))
FW_H_W = Transform.inverse(W_H_FW)
FW_X_W = Adjoint.from_transform(transform=FW_H_W)
FW_J_WF_I = jacobian(
model=model,
data=data,
frame_index=frame_index,
output_vel_repr=VelRepr.Mixed,
)
FW_v_WF = FW_J_WF_I @ data.generalized_velocity()
W_v_WFW = jnp.zeros(6).at[0:3].set(FW_v_WF[0:3])
W_vx_WFW = Cross.vx(W_v_WFW)
O_J̇_WF_I = FW_X_W @ (W_J̇_WL_I - W_vx_WFW @ W_J_WL_I)
O_X_W = FW_X_W = Adjoint.from_transform(transform=FW_H_W)
with data.switch_velocity_representation(VelRepr.Mixed):
FW_J_WF_FW = jacobian(
model=model,
data=data,
frame_index=frame_index,
output_vel_repr=VelRepr.Mixed,
)
FW_v_WF = FW_J_WF_FW @ data.generalized_velocity()
W_v_W_FW = jnp.zeros(6).at[0:3].set(FW_v_WF[0:3])

W_vx_W_FW = Cross.vx(W_v_W_FW)

# O_J̇_WF_I = FW_X_W @ (W_J̇_WL_I - W_vx_WFW @ W_J_WL_I)
O_Ẋ_W = FW_Ẋ_W = -FW_X_W @ W_vx_W_FW

O_J̇_WF_I = jnp.zeros(shape=(6, 6 + model.dofs()))
O_J̇_WF_I += O_Ẋ_W @ W_J_WL_W @ T
O_J̇_WF_I += O_X_W @ W_J̇_WL_W @ T
O_J̇_WF_I += O_X_W @ W_J_WL_W @

case _:
raise ValueError(output_vel_repr)
Expand Down

0 comments on commit 78b690d

Please sign in to comment.