Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add function to compute the Jacobian derivative of additional frames rigidly attached to links #210

Merged
merged 11 commits into from
Jul 19, 2024

Conversation

xela-95
Copy link
Member

@xela-95 xela-95 commented Jul 17, 2024

Closes #208


📚 Documentation preview 📚: https://jaxsim--210.org.readthedocs.build//210/

This test checks that the bias acceleration of a frame obtained by computing J dot @ nu JaxSim corresponds to the same quantity obtained through KinDynComputations
@xela-95
Copy link
Member Author

xela-95 commented Jul 17, 2024

I created this draft PR to share my code in an easy way with you.

However the test is failing for two reasons:

  • some of the frames are not found by KinDynComputations
  • for the other frames the resulting bias acceleration is quite different

@xela-95
Copy link
Member Author

xela-95 commented Jul 17, 2024

In b747a26 I commented out the test against iDynTree and tried to test directly $\dot J$ against the same quantity computing through AD. It seems that the inertial and body cases pass for all models, while the mixed case no.

@xela-95
Copy link
Member Author

xela-95 commented Jul 19, 2024

In 78b690d I also explicitly handled the conversion of the input representation of the jacobian via matrix ${} ^W T _I = diag({} ^W X _I, I _n)$ (with $I$ the active data representation). Unfortunately the test still fails for the [UR10-mixed] and [erogcub-mixed] cases, while it works for [box-mixed]. This is making me think of something related to joints that the box does not have, but I could not spot the bug here.

I suspect something here:

case VelRepr.Mixed:
W_H_F = transform(model=model, data=data, frame_index=frame_index)
W_H_FW = W_H_F.at[0:3, 0:3].set(jnp.zeros((3, 3)))
FW_H_W = Transform.inverse(W_H_FW)
O_X_W = FW_X_W = Adjoint.from_transform(transform=FW_H_W)
with data.switch_velocity_representation(VelRepr.Mixed):
FW_J_WF_FW = jacobian(
model=model,
data=data,
frame_index=frame_index,
output_vel_repr=VelRepr.Mixed,
)
FW_v_WF = FW_J_WF_FW @ data.generalized_velocity()
W_v_W_FW = jnp.zeros(6).at[0:3].set(FW_v_WF[0:3])
W_vx_W_FW = Cross.vx(W_v_W_FW)
# O_J̇_WF_I = FW_X_W @ (W_J̇_WL_I - W_vx_WFW @ W_J_WL_I)
O_Ẋ_W = FW_Ẋ_W = -FW_X_W @ W_vx_W_FW
O_J̇_WF_I = jnp.zeros(shape=(6, 6 + model.dofs()))
O_J̇_WF_I += O_Ẋ_W @ W_J_WL_W @ T
O_J̇_WF_I += O_X_W @ W_J̇_WL_W @ T
O_J̇_WF_I += O_X_W @ W_J_WL_W @

@diegoferigo @flferretti any idea?

Copy link
Member

@diegoferigo diegoferigo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I need some time to go through you code, but the following is definitely a bug. Let's see what happens as soon as you fix it.

src/jaxsim/api/frame.py Outdated Show resolved Hide resolved
@xela-95 xela-95 marked this pull request as ready for review July 19, 2024 09:35
@xela-95 xela-95 requested a review from flferretti as a code owner July 19, 2024 09:35
Copy link
Member

@diegoferigo diegoferigo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Just minor suggestions.

src/jaxsim/api/frame.py Outdated Show resolved Hide resolved
src/jaxsim/api/frame.py Outdated Show resolved Hide resolved
tests/test_api_frame.py Outdated Show resolved Hide resolved
tests/test_api_frame.py Outdated Show resolved Hide resolved
@diegoferigo diegoferigo changed the title Add frame Jacobian derivative Add function to compute the Jacobian derivative of additional frames rigidly attached to links Jul 19, 2024
Copy link
Collaborator

@flferretti flferretti left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's great @xela-95, thanks a lot! LGTM

@diegoferigo diegoferigo merged commit b0c0db5 into ami-iit:main Jul 19, 2024
13 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add frame Jacobian derivative
3 participants