Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mesh support #156

Open
wants to merge 36 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
eb08892
Set env var when parsing from `robot-descriptions`
flferretti May 13, 2024
39d41da
Initial version of mesh support
lorycontixd May 13, 2024
dd34048
Format and lint
flferretti May 14, 2024
1e3162f
Use already existing env var to solve mesh URIs
flferretti May 17, 2024
f965925
Skip loading empty meshes
flferretti May 20, 2024
c93ca33
Added `networkx` as testing dependency
lorycontixd May 16, 2024
dc46576
Address to reviews:
lorycontixd May 20, 2024
5152a02
Added trimesh dependecy for conda-forge
lorycontixd May 21, 2024
7ea8d8a
Moved mesh parsing logic inside mesh collision function
lorycontixd May 21, 2024
be1ac21
Implemented UniformSurfaceSampling for mesh point wrapping
lorycontixd May 21, 2024
3b77b12
Address reviews
lorycontixd May 21, 2024
1b2bcc8
Pre-commit
lorycontixd May 21, 2024
a9c34d9
Update `__eq__` magic and type hints
flferretti Jun 13, 2024
1a5fe78
Removed unused lines in conftest
lorycontixd Jun 20, 2024
65800e7
Fixed typo on logging message
lorycontixd Jun 20, 2024
ede5838
Removed unused import in conftest
lorycontixd Jun 20, 2024
7a921eb
Moved from jax.numpy to numpy in PlaneTerrain __eq__ magic to bypass …
lorycontixd Jul 5, 2024
d6c4ffc
Resolved conflict when merged main
lorycontixd Jul 5, 2024
73f34e5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 5, 2024
a9e817c
Resolved conflict when merged main
lorycontixd Jul 5, 2024
e60daee
Merge branch 'add_mesh_support' of github.com:lorycontixd/jaxsim into…
lorycontixd Jul 5, 2024
fcd4cc6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 5, 2024
b0c4d12
Implemented initial reviews from https://github.com/ami-iit/jaxsim/pu…
lorycontixd Jul 5, 2024
83d51d3
Merge branch 'add_mesh_support' of github.com:lorycontixd/jaxsim into…
lorycontixd Jul 5, 2024
77a7484
Merge branch 'main' of https://github.com/ami-iit/jaxsim into add_mes…
lorycontixd Jul 15, 2024
366a6c3
First draft of new mesh wrapping algorithms
lorycontixd Jul 17, 2024
03d0ad8
Implemented structure for new mesh wrapping algorithms
lorycontixd Jul 17, 2024
5c0d62e
New mesh wrapping algorithms
lorycontixd Jul 17, 2024
bac3206
New mesh wrapping algorithms with relative tests
lorycontixd Jul 18, 2024
30a6fc6
Renamed some parameters
lorycontixd Jul 18, 2024
1166995
Restructured mesh mapping methods to follow inheritance
lorycontixd Jul 18, 2024
a8f2b3a
Merge pull request #4 from lorycontixd/new_mesh_wrapping_algos
lorycontixd Jul 18, 2024
30c1397
Run pre-commit
lorycontixd Jul 18, 2024
0c61ec3
Removed leftover parameters on create_mesh_collision
lorycontixd Jul 18, 2024
2ba53fc
Removed wrong point selection & added logs
lorycontixd Jul 18, 2024
ec7b829
Added string magics on wrapping methods
lorycontixd Jul 18, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ dependencies:
- jax-dataclasses >= 1.4.0
- pptree
- rod >= 0.3.0
- trimesh
- typing_extensions # python<3.12
# ====================================
# Optional dependencies from setup.cfg
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ dependencies = [
"pptree",
"rod >= 0.3.0",
"typing_extensions ; python_version < '3.12'",
"trimesh",
"manifold3d",
]

[project.optional-dependencies]
Expand All @@ -64,6 +66,8 @@ testing = [
"pytest >=6.0",
"pytest-icdiff",
"robot-descriptions",
"networkx",
"lark"
]
viz = [
"lxml",
Expand Down
89 changes: 89 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
[metadata]
name = jaxsim
description = A differentiable physics engine and multibody dynamics library for control and robot learning.
long_description = file: README.md
long_description_content_type = text/markdown
author = Diego Ferigo
author_email = [email protected]
license = BSD
license_files = LICENSE
platforms = any
url = https://github.com/ami-iit/jaxsim

project_urls =
Changelog = https://github.com/ami-iit/jaxsim/releases
Documentation = https://jaxsim.readthedocs.io
Source = https://github.com/ami-iit/jaxsim
Tracker = https://github.com/ami-iit/jaxsim/issues

keywords =
physics
physics engine
jax
rigid body dynamics
featherstone
reinforcement learning
robot
robotics
sdf
urdf

classifiers =
Development Status :: 4 - Beta
Framework :: Robot Framework
Intended Audience :: Developers
Intended Audience :: Science/Research
License :: OSI Approved :: BSD License
Operating System :: POSIX :: Linux
Operating System :: MacOS
Operating System :: Microsoft
Programming Language :: Python :: 3.10
Programming Language :: Python :: 3.11
Programming Language :: Python :: 3.12
Programming Language :: Python :: 3 :: Only
Programming Language :: Python :: Implementation :: CPython
Topic :: Games/Entertainment :: Simulation
Topic :: Scientific/Engineering :: Artificial Intelligence
Topic :: Scientific/Engineering :: Physics
Topic :: Software Development

[options]
zip_safe = False
packages = find:
package_dir =
=src
python_requires = >=3.10
install_requires =
coloredlogs
jax >= 0.4.13
jaxlib >= 0.4.13
jaxlie >= 1.3.0
jax_dataclasses >= 1.4.0
pptree
rod >= 0.3.0
trimesh
lorycontixd marked this conversation as resolved.
Show resolved Hide resolved
typing_extensions ; python_version < '3.12'

[options.packages.find]
where = src

[options.extras_require]
style =
# Note: keep the black version in sync with style.yml and environment.yml.
black[jupyter] ~= 24.0
isort
pre-commit
testing =
idyntree >= 12.2.1
networkx
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is this new dependency used for?

pytest >=6.0
pytest-icdiff
robot-descriptions
viz =
lxml
mediapy
mujoco >= 3.0.0
all =
%(style)s
%(testing)s
%(viz)s
8 changes: 7 additions & 1 deletion src/jaxsim/parsers/descriptions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
from .collision import BoxCollision, CollidablePoint, CollisionShape, SphereCollision
from .collision import (
BoxCollision,
CollidablePoint,
CollisionShape,
MeshCollision,
SphereCollision,
)
from .joint import JointDescription, JointGenericAxis, JointType
from .link import LinkDescription
from .model import ModelDescription
19 changes: 19 additions & 0 deletions src/jaxsim/parsers/descriptions/collision.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,3 +154,22 @@ def __eq__(self, other: BoxCollision) -> bool:
return False

return hash(self) == hash(other)


@dataclasses.dataclass
class MeshCollision(CollisionShape):
center: jtp.VectorLike

def __hash__(self) -> int:
return hash(
(
hash(tuple(self.center.tolist())),
hash(self.collidable_points),
)
)

def __eq__(self, other: MeshCollision) -> bool:
if not isinstance(other, MeshCollision):
return False

return hash(self) == hash(other)
208 changes: 208 additions & 0 deletions src/jaxsim/parsers/rod/meshes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
from abc import ABC, abstractmethod
from typing import List, Sequence

import numpy as np
import rod
import trimesh

from jaxsim import logging


def parse_object_mapping_object(obj) -> trimesh.Trimesh:
if isinstance(obj, trimesh.Trimesh):
return obj
elif isinstance(obj, dict):
if "type" not in obj:
raise ValueError("Object type not specified")
if obj["type"] == "box":
if "extents" not in obj:
raise ValueError("Box extents not specified")
return trimesh.creation.box(extents=obj["extents"])
elif obj["type"] == "sphere":
if "radius" not in obj:
raise ValueError("Sphere radius not specified")
return trimesh.creation.icosphere(subdivisions=4, radius=obj["radius"])
else:
raise ValueError(f"Invalid object type {obj['type']}")
elif isinstance(obj, rod.builder.primitive_builder.PrimitiveBuilder):
raise NotImplementedError("PrimitiveBuilder not implemented")
else:
raise ValueError("Invalid object type")


class MeshMappingMethod(ABC):
@abstractmethod
def __init__(self, **kwargs):
pass

@abstractmethod
def extract_points(self, mesh: trimesh.Trimesh) -> np.ndarray:
self.mesh = mesh
pass

def __call__(self, mesh: trimesh.Trimesh):
return self.extract_points(mesh=mesh)

@abstractmethod
def __str__(self):
return self.__class__.__name__


class VertexExtraction(MeshMappingMethod):
def __init__(self):
pass

def extract_points(self, mesh: trimesh.Trimesh) -> np.ndarray:
super().extract_points(mesh)
return self.mesh.vertices

def __str__(self):
return "VertexExtraction"


class RandomSurfaceSampling(MeshMappingMethod):
def __init__(self, n: int = -1):
if n <= 0 or n > len(self.mesh.vertices):
logging.warning(
"Invalid number of points for random surface sampling. Defaulting to all vertices"
)
n = len(self.mesh.vertices)
self.n = n

def extract_points(self, mesh: trimesh.Trimesh) -> np.ndarray:
super().extract_points(mesh)
return self.mesh.sample(self.n)

def __str__(self):
return f"RandomSurfaceSampling(n={self.n})"


class UniformSurfaceSampling(MeshMappingMethod):
def __init__(self, n: int = -1):
if n <= 0 or n > len(self.mesh.vertices):
logging.warning(
"Invalid number of points for uniform surface sampling. Defaulting to all vertices"
)
n = len(self.mesh.vertices)
self.n = n

def extract_points(self, mesh: trimesh.Trimesh) -> np.ndarray:
super().extract_points(mesh)
return trimesh.sample.sample_surface_even(mesh=self.mesh, count=self.n)

def __str__(self):
return f"UniformSurfaceSampling(n={self.n})"


class AAP(MeshMappingMethod):
def __init__(self, axis: str, operator: str, value: float):
valid_methods = [">", "<", ">=", "<="]
if operator not in valid_methods:
raise ValueError(
f"Invalid method {operator} for AAP. Valid methods are {valid_methods}"
)

match (operator):
case ">":
self.aap_operator = np.greater
case "<":
self.aap_operator = np.less
case ">=":
self.aap_operator = np.greater_equal
case "<=":
self.aap_operator = np.less_equal
case _:
raise ValueError(f"Invalid method {operator} for AAP")

self.axis = axis
self.value = value

def extract_points(self, mesh: trimesh.Trimesh) -> np.ndarray:
super().extract_points(mesh)
if self.axis == "x":
points = self.mesh.vertices[
self.aap_operator(self.mesh.vertices[:, 0], self.value)
]
elif self.axis == "y":
points = self.mesh.vertices[
self.aap_operator(self.mesh.vertices[:, 1], self.value)
]
elif self.axis == "z":
points = self.mesh.vertices[
self.aap_operator(self.mesh.vertices[:, 2], self.value)
]
else:
raise ValueError("Invalid axis for axis-aligned plane")

return points

def __str__(self):
return (
f"AAP(axis={self.axis}, operator={self.aap_operator}, value={self.value})"
)


class SelectPointsOverAxis(MeshMappingMethod):
def __init__(self, axis: str, direction: str, n: int):
valid_dirs = ["higher", "lower"]
if direction not in valid_dirs:
raise ValueError(f"Invalid direction. Valid directions are {valid_dirs}")
self.axis = axis
self.direction = direction
self.n = n

def extract_points(self, mesh: trimesh.Trimesh) -> np.ndarray:
super().extract_points(mesh)
arr = self.mesh.vertices

index = 0 if self.axis == "x" else 1 if self.axis == "y" else 2
# Sort the array in ascending order
sorted_arr = arr[arr[:, index].argsort()]

if self.direction == "lower":
# Select first N points
points = sorted_arr[: self.n]
elif self.direction == "higher":
# Select last N points
points = sorted_arr[-self.n :]
else:
raise ValueError(
f"Invalid direction {self.direction} for SelectPointsOverAxis method"
)

return points

def __str__(self):
return f"SelectPointsOverAxis(axis={self.axis}, direction={self.direction}, n={self.n})"


class ObjectMapping(MeshMappingMethod):
def __init__(
self, objs: Sequence[trimesh.Trimesh | dict], method: str = "subtract"
):
valid_methods = ["subtract", "intersect"]
if method not in valid_methods:
raise ValueError(f"Invalid method {method} for object mapping")
self.method = method

self.objs: List[trimesh.Trimesh] = []
for obj in objs:
self.objs.append(parse_object_mapping_object(obj))

def extract_points(self, mesh: trimesh.Trimesh) -> np.ndarray:
super().extract_points(mesh)
if len(self.objs) == 0:
return mesh.vertices
if self.method == "subtract":
for obj in self.objs:
mesh = mesh.difference(obj)
elif self.method == "intersect":
for obj in self.objs:
mesh = mesh.intersection(obj)
else:
raise ValueError(f"Invalid method {self.method} for object mapping")

return mesh.vertices

def __str__(self):
return f"ObjectMapping(method={self.method}, objs={len(self.objs)})"
11 changes: 11 additions & 0 deletions src/jaxsim/parsers/rod/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,17 @@ def extract_model_data(

collisions.append(sphere_collision)

if collision.geometry.mesh is not None:
mesh_collision = utils.create_mesh_collision(
collision=collision,
link_description=links_dict[link.name],
method=utils.meshes.SelectPointsOverAxis(
axis="z", direction="lower", n=50
),
)
if mesh_collision is not None:
collisions.append(mesh_collision)
Comment on lines +336 to +337
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this ever happen? I guess that in the case utils.create_mesh_collision fails, an exception is raised.


return SDFData(
model_name=sdf_model.name,
link_descriptions=links,
Expand Down
Loading
Loading