Skip to content

Commit

Permalink
[sq]
Browse files Browse the repository at this point in the history
  • Loading branch information
diegoferigo committed May 27, 2024
1 parent 9b4da27 commit ef74e1f
Showing 1 changed file with 33 additions and 23 deletions.
56 changes: 33 additions & 23 deletions src/jaxsim/api/com.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,24 +260,10 @@ def bias_acceleration(
and :math:`[C] = [B]` if the active velocity representation is body-fixed.
"""

# # Compute the bias acceleration of all links.
# Compute the bias acceleration of all links.
with data.switch_velocity_representation(velocity_representation=VelRepr.Body):
L_a_bias_WL = js.model.link_bias_accelerations(model=model, data=data)

# Compute the position of the CoM.
W_p_CoM = com_position(model=model, data=data)

# Compute the transform from the world frame W to the CoM frame G.
match data.velocity_representation:
case VelRepr.Inertial | VelRepr.Mixed:
# In this case G := G[W].
W_H_G = W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM)
case VelRepr.Body:
# In this case G := G[B].
W_H_G = W_H_GB = data.base_transform().at[0:3, 3].set(W_p_CoM)
case _:
raise ValueError(data.velocity_representation)

# Compute the pose of all links with forward kinematics.
W_H_L = js.model.forward_kinematics(model=model, data=data)

Expand All @@ -294,24 +280,48 @@ def bias_momentum_derivative_term(
model=model, data=data, link_index=link_index, output_vel_repr=VelRepr.Body
)

# Compute the link-to-CoM transform and the corresponding adjoint for 6D forces.
L_H_G = jaxsim.math.Transform.inverse(W_H_L[link_index]) @ W_H_G
G_Xf_L = jaxsim.math.Adjoint.from_transform(transform=L_H_G).T
# Compute the world-to-link transformations for 6D forces.
W_Xf_L = jaxsim.math.Adjoint.from_transform(
transform=W_H_L[link_index], inverse=True
).T

# Compute the contribution of the link to the bias acceleration of the CoM.
G_ḣ_bias_link_contribution = G_Xf_L @ (
W_ḣ_bias_link_contribution = W_Xf_L @ (
L_M_L @ L_a_bias_WL + jaxsim.math.Cross.vx_star(L_v_WL) @ L_M_L @ L_v_WL
)

return G_ḣ_bias_link_contribution
return W_ḣ_bias_link_contribution

# Sum the contributions of all links to the bias acceleration of the CoM.
G_ḣ_bias = jax.vmap(bias_momentum_derivative_term)(
W_ḣ_bias = jax.vmap(bias_momentum_derivative_term)(
jnp.arange(model.number_of_links()), L_a_bias_WL
).sum(axis=0)

# Compute the total mass of the model.
m = js.model.total_mass(model=model)
G_v̇l_bias_WG = G_ḣ_bias[0:3] / m

return G_v̇l_bias_WG
# Compute the position of the CoM.
W_p_CoM = com_position(model=model, data=data)

match data.velocity_representation:

# G := G[W] = (W_p_CoM, [W])
case VelRepr.Inertial | VelRepr.Mixed:

W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM)
GW_Xf_W = jaxsim.math.Adjoint.from_transform(W_H_GW).T
GW_ḣ_bias = GW_Xf_W @ W_ḣ_bias
GW_v̇l_com_bias = GW_ḣ_bias[0:3] / m
return GW_v̇l_com_bias

# G := G[B] = (W_p_CoM, [B])
case VelRepr.Body:
GB_Xf_W = jaxsim.math.Adjoint.from_transform(
transform=data.base_transform().at[0:3].set(W_p_CoM)
).T
GB_ḣ_bias = GB_Xf_W @ W_ḣ_bias
GB_v̇l_com_bias = GB_ḣ_bias[0:3] / m
return GB_v̇l_com_bias

case _:
raise ValueError(data.velocity_representation)

0 comments on commit ef74e1f

Please sign in to comment.