Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new ViscoElasticContacts #248

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 67 additions & 16 deletions src/jaxsim/api/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@
import jax.numpy as jnp

import jaxsim.api as js
import jaxsim.exceptions
import jaxsim.terrain
import jaxsim.typing as jtp
from jaxsim import logging
from jaxsim.math import Adjoint, Cross, Transform
from jaxsim.rbda import contacts

from .common import VelRepr

Expand Down Expand Up @@ -156,14 +159,11 @@ def collidable_point_dynamics(
Instead, the 6D forces are returned in the active representation.
"""

# Import privately the contacts classes.
from jaxsim.rbda.contacts import RelaxedRigidContacts, RigidContacts, SoftContacts

# Build the soft contact model.
match model.contact_model:

case SoftContacts():
assert isinstance(model.contact_model, SoftContacts)
case contacts.SoftContacts():
assert isinstance(model.contact_model, contacts.SoftContacts)

# Compute the 6D force expressed in the inertial frame and applied to each
# collidable point, and the corresponding material deformation rate.
Expand All @@ -178,8 +178,8 @@ def collidable_point_dynamics(
# of the ODE system. We need to pass its dynamics to the integrator.
aux_data = dict(m_dot=CW_ṁ)

case RigidContacts():
assert isinstance(model.contact_model, RigidContacts)
case contacts.RigidContacts():
assert isinstance(model.contact_model, contacts.RigidContacts)

# Compute the 6D force expressed in the inertial frame and applied to each
# collidable point.
Expand All @@ -192,8 +192,8 @@ def collidable_point_dynamics(

aux_data = dict()

case RelaxedRigidContacts():
assert isinstance(model.contact_model, RelaxedRigidContacts)
case contacts.RelaxedRigidContacts():
assert isinstance(model.contact_model, contacts.RelaxedRigidContacts)

# Compute the 6D force expressed in the inertial frame and applied to each
# collidable point.
Expand All @@ -206,6 +206,31 @@ def collidable_point_dynamics(

aux_data = dict()

case contacts.ViscoElasticContacts():
assert isinstance(model.contact_model, contacts.ViscoElasticContacts)

# It is not yet clear how to pass the time step to this stage.
# A possibility is to restrict the integrator to only forward Euler
# and store the Δt inside the model.
module = jaxsim.rbda.contacts.visco_elastic.step.__module__
name = jaxsim.rbda.contacts.visco_elastic.step.__name__
msg = "You need to use the custom '{}.{}' function with this contact model."
jaxsim.exceptions.raise_runtime_error_if(
condition=True, msg=msg.format(module, name)
)

# Compute the 6D force expressed in the inertial frame and applied to each
# collidable point.
W_f_Ci, (W_f̿_Ci, m_tf) = model.contact_model.compute_contact_forces(
model=model,
data=data,
dt=None, # TODO
link_forces=link_forces,
joint_force_references=joint_force_references,
)

aux_data = dict(W_f_avg2_C=W_f̿_Ci, m_tf=m_tf)

case _:
raise ValueError(f"Invalid contact model {model.contact_model}")

Expand Down Expand Up @@ -278,7 +303,6 @@ def in_contact(
return links_in_contact


@jax.jit
def estimate_good_soft_contacts_parameters(
model: js.model.JaxSimModel,
*,
Expand All @@ -287,9 +311,15 @@ def estimate_good_soft_contacts_parameters(
number_of_active_collidable_points_steady_state: jtp.IntLike = 1,
damping_ratio: jtp.FloatLike = 1.0,
max_penetration: jtp.FloatLike | None = None,
) -> jaxsim.rbda.contacts.SoftContactsParams:
**kwargs,
) -> (
jaxsim.rbda.contacts.RelaxedRigidContactsParams
| jaxsim.rbda.contacts.RigidContactsParams
| jaxsim.rbda.contacts.SoftContactsParams
| jaxsim.rbda.contacts.ViscoElasticContactsParams
):
"""
Estimate good soft contacts parameters for the given model.
Estimate good parameters for soft-like contact models.
Args:
model: The model to consider.
Expand All @@ -313,7 +343,10 @@ def estimate_good_soft_contacts_parameters(
"""

def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:
""""""
"""
Displacement between the CoM and the lowest collidable point using zero
joint positions.
"""

zero_data = js.data.JaxSimModelData.build(
model=model,
Expand All @@ -338,21 +371,39 @@ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:

match model.contact_model:

case jaxsim.rbda.contacts.SoftContacts():
assert isinstance(model.contact_model, jaxsim.rbda.contacts.SoftContacts)
case contacts.SoftContacts():
assert isinstance(model.contact_model, contacts.SoftContacts)

parameters = contacts.SoftContactsParams.build_default_from_jaxsim_model(
model=model,
standard_gravity=standard_gravity,
static_friction_coefficient=static_friction_coefficient,
max_penetration=max_δ,
number_of_active_collidable_points_steady_state=nc,
damping_ratio=damping_ratio,
p=model.contact_model.parameters.p,
q=model.contact_model.parameters.q,
)

case contacts.ViscoElasticContacts():
assert isinstance(model.contact_model, contacts.ViscoElasticContacts)

parameters = (
jaxsim.rbda.contacts.SoftContactsParams.build_default_from_jaxsim_model(
contacts.ViscoElasticContactsParams.build_default_from_jaxsim_model(
model=model,
standard_gravity=standard_gravity,
static_friction_coefficient=static_friction_coefficient,
max_penetration=max_δ,
number_of_active_collidable_points_steady_state=nc,
damping_ratio=damping_ratio,
p=model.contact_model.parameters.p,
q=model.contact_model.parameters.q,
**kwargs,
)
)

case _:
logging.warning("The active contact model is not soft-like, no-op.")
parameters = model.contact_model.parameters

return parameters
Expand Down
6 changes: 5 additions & 1 deletion src/jaxsim/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,11 @@ def build(

if contacts_params is None:

if isinstance(model.contact_model, jaxsim.rbda.contacts.SoftContacts):
if isinstance(
model.contact_model,
jaxsim.rbda.contacts.SoftContacts
| jaxsim.rbda.contacts.ViscoElasticContacts,
):
contacts_params = js.contact.estimate_good_soft_contacts_parameters(
model=model, standard_gravity=standard_gravity
)
Expand Down
15 changes: 14 additions & 1 deletion src/jaxsim/api/kin_dyn_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import jax.lax
import jax.numpy as jnp
import jax_dataclasses
import numpy as np
import numpy.typing as npt
from jax_dataclasses import Static

import jaxsim.typing as jtp
Expand Down Expand Up @@ -753,6 +755,13 @@ class ContactParameters(JaxsimDataclass):

point: jtp.Matrix = dataclasses.field(default_factory=lambda: jnp.array([]))

enabled: Static[tuple[bool, ...]] = dataclasses.field(default_factory=tuple)

@property
def indices_of_enabled_collidable_points(self) -> npt.NDArray:

return np.where(np.array(self.enabled))[0]

@staticmethod
def build_from(model_description: ModelDescription) -> ContactParameters:
"""
Expand Down Expand Up @@ -785,7 +794,11 @@ def build_from(model_description: ModelDescription) -> ContactParameters:
)

# Build the ContactParameters object.
cp = ContactParameters(point=points, body=link_index_of_points)
cp = ContactParameters(
point=points,
body=link_index_of_points,
enabled=tuple(jnp.ones(len(link_index_of_points), dtype=bool).tolist()),
)

assert cp.point.shape[1] == 3, cp.point.shape[1]
assert cp.point.shape[0] == len(cp.body), cp.point.shape[0]
Expand Down
15 changes: 11 additions & 4 deletions src/jaxsim/api/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import jaxsim.typing as jtp
from jaxsim.integrators import Time
from jaxsim.math import Quaternion
from jaxsim.rbda import contacts

from .common import VelRepr
from .ode_data import ODEState
Expand Down Expand Up @@ -371,8 +372,6 @@ def system_dynamics(
by the system dynamics evaluation.
"""

from jaxsim.rbda.contacts import RelaxedRigidContacts, RigidContacts, SoftContacts

# Compute the accelerations and the material deformation rate.
W_v̇_WB, s̈, aux_dict = system_velocity_dynamics(
model=model,
Expand All @@ -387,10 +386,18 @@ def system_dynamics(

match model.contact_model:

case SoftContacts():
case contacts.SoftContacts():
extended_ode_state["tangential_deformation"] = aux_dict["m_dot"]

case RigidContacts() | RelaxedRigidContacts():
case contacts.ViscoElasticContacts():

extended_ode_state["contacts_state"] = {
"tangential_deformation": jnp.zeros_like(
data.state.extended["tangential_deformation"]
)
}

case contacts.RigidContacts() | contacts.RelaxedRigidContacts():
pass

case _:
Expand Down
3 changes: 2 additions & 1 deletion src/jaxsim/rbda/contacts/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from . import relaxed_rigid, rigid, soft
from . import relaxed_rigid, rigid, soft, visco_elastic
from .common import ContactModel, ContactsParams
from .relaxed_rigid import RelaxedRigidContacts, RelaxedRigidContactsParams
from .rigid import RigidContacts, RigidContactsParams
from .soft import SoftContacts, SoftContactsParams
from .visco_elastic import ViscoElasticContacts, ViscoElasticContactsParams
Loading