diff --git a/src/jaxsim/api/references.py b/src/jaxsim/api/references.py index ae4da7c53..6f20857ef 100644 --- a/src/jaxsim/api/references.py +++ b/src/jaxsim/api/references.py @@ -535,7 +535,11 @@ def convert_to_link_force( return F_X_L @ W_f_F - W_f_L = jax.vmap(convert_to_link_force)(W_f_F, W_H_Fi, parent_link_idxs) + W_f_L_i = jax.vmap(convert_to_link_force)(W_f_F, W_H_Fi, parent_link_idxs) + + # Sum the forces on the parent links. + mask = parent_link_idxs[:, jnp.newaxis] == jnp.arange(model.number_of_links()) + W_f_L = mask.T @ W_f_L_i return self.apply_link_forces( model=model,