Skip to content

Commit

Permalink
fix: add support for loading and processing Taichi meshes
Browse files Browse the repository at this point in the history
This commit introduces support for loading and processing Taichi meshes in the mkit library. Taichi meshes can now be loaded using the new `load_taichi` function. This enhancement expands the capabilities of the library to work with Taichi meshes, providing users with more flexibility in their mesh processing workflows.
  • Loading branch information
liblaf committed May 5, 2024
1 parent d625a56 commit f332363
Show file tree
Hide file tree
Showing 17 changed files with 341 additions and 3 deletions.
4 changes: 4 additions & 0 deletions src/mkit/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from mkit.io._meshio import load_meshio as load_meshio
from mkit.io._pyvista import as_pyvista as as_pyvista
from mkit.io._pyvista import load_pyvista as load_pyvista
from mkit.io._taichi import as_taichi as as_taichi
from mkit.io._taichi import load_taichi as load_taichi
from mkit.io._trimesh import as_trimesh as as_trimesh
from mkit.io._trimesh import load_trimesh as load_trimesh
from mkit.io.types import AnyMesh
Expand All @@ -15,9 +17,11 @@
__all__ = [
"as_meshio",
"as_pyvista",
"as_taichi",
"as_trimesh",
"load_meshio",
"load_pyvista",
"load_taichi",
"load_trimesh",
"save",
]
Expand Down
39 changes: 39 additions & 0 deletions src/mkit/io/_taichi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from typing import Any

import meshio
import meshtaichi_patcher
import taichi as ti
import taichi.lang.mesh
from numpy import typing as npt

from mkit.io.types import AnyMesh
from mkit.typing import StrPath


def load_taichi(filename: StrPath) -> ti.MeshInstance:
raise NotImplementedError # TODO


def as_taichi(mesh: AnyMesh, relations: list[str] = []) -> ti.MeshInstance:
match mesh:
case ti.MeshInstance():
return mesh
case meshio.Mesh():
return meshio_to_taichi(mesh, relations)
case _:
raise NotImplementedError(f"unsupported mesh: {mesh}")


def meshio_to_taichi(mesh: meshio.Mesh, relations: list[str] = []) -> ti.MeshInstance:
total: dict[int, npt.NDArray] = {
0: mesh.points,
3: mesh.get_cells_type("tetra"),
}
patcher = meshtaichi_patcher.meshpatcher.MeshPatcher(total)
patcher.patcher.patch_size = 256
patcher.patcher.cluster_option = "greedy"
patcher.patcher.debug = False
patcher.patch(max_order=-1, patch_relation="all")
meta: dict[str, Any] = patcher.get_meta(relations)
metadata: taichi.lang.mesh.MeshMetadata = ti.Mesh.generate_meta(meta)
return ti.Mesh._create_instance(metadata)
Empty file added src/mkit/physics/__init__.py
Empty file.
25 changes: 25 additions & 0 deletions src/mkit/physics/moduli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""https://en.wikipedia.org/wiki/Template:Elastic_moduli"""


def E_nu2lambda(E: float, nu: float) -> float:
"""
Args:
E: Young's modulus
nu: Poisson's ratio
Returns:
Lamé's first parameter
"""
return E * nu / ((1 + nu) * (1 - 2 * nu))


def E_nu2mu(E: float, nu: float) -> float:
"""
Args:
E: Young's modulus
nu: Poisson's ratio
Returns:
Lamé's second parameter / Shear modulus
"""
return E / (2 * (1 + nu))
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
7 changes: 4 additions & 3 deletions tasks/tetgen/src/create.py → tasks/01-tetgen/src/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@ def main(
output_file: Annotated[
pathlib.Path, typer.Option("-o", "--output", dir_okay=False, writable=True)
],
displacement: Annotated[float, typer.Option()] = 0.05,
) -> None:
face: trimesh.Trimesh = trimesh.creation.icosphere(radius=1.0)
pre_skull: trimesh.Trimesh = trimesh.creation.icosphere(radius=0.5)
face: trimesh.Trimesh = trimesh.creation.icosphere(radius=0.2)
pre_skull: trimesh.Trimesh = trimesh.creation.icosphere(radius=0.1)
post_skull: trimesh.Trimesh = pre_skull.copy()
post_skull.vertices[post_skull.vertices[:, 1] < 0.0] += [0.0, -0.1, 0.0]
post_skull.vertices[post_skull.vertices[:, 1] < 0.0] += [0.0, -displacement, 0.0]
tri: trimesh.Trimesh = trimesh.util.concatenate([face, pre_skull])
tet: meshio.Mesh = mkit.ops.tetgen.tetgen(
mkit.io.as_meshio(
Expand Down
File renamed without changes.
25 changes: 25 additions & 0 deletions tasks/02-simulation/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
default:

data/%.vtu: ../01-tetgen/src/create.py
@ mkdir --parents --verbose "$(@D)"
python "$<" --output="$@" --displacement="$*"

# displacement, Poisson's ratio, method
define run
default: data/$(1)-$(2)-$(3).vtu
data/$(1)-$(2)-$(3).vtu: data/$(1).vtu src/mtm-$(3).py
python "src/mtm-$(3).py" --output="$$@" --poisson-ratio="$(2)" "$$<"
endef

$(eval $(call run,0.01,0.00,matrix-free))
$(eval $(call run,0.01,0.10,matrix-free))
$(eval $(call run,0.01,0.20,matrix-free))
$(eval $(call run,0.01,0.30,matrix-free))
$(eval $(call run,0.01,0.40,matrix-free))
$(eval $(call run,0.01,0.46,matrix-free))
$(eval $(call run,0.05,0.00,matrix-free))
$(eval $(call run,0.05,0.10,matrix-free))
$(eval $(call run,0.05,0.20,matrix-free))
$(eval $(call run,0.05,0.30,matrix-free))
$(eval $(call run,0.05,0.40,matrix-free))
$(eval $(call run,0.05,0.46,matrix-free))
244 changes: 244 additions & 0 deletions tasks/02-simulation/src/mtm-matrix-free.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
import functools
import pathlib
from typing import Annotated, no_type_check

import meshio
import mkit.cli
import mkit.io
import mkit.physics.moduli
import numpy as np
import scipy.sparse.linalg
import taichi as ti
import typer
from loguru import logger
from numpy import typing as npt


@no_type_check
@ti.func
def _stiffness_tet(
volume: float, Mj: ti.math.vec3, Mk: ti.math.vec3, lambda_: float, mu: float
) -> ti.math.mat3:
return (
1
/ (36 * volume)
* (
lambda_ * Mk.outer_product(Mj)
+ mu * Mj.outer_product(Mk)
+ mu * Mj.dot(Mk) * ti.Matrix.identity(float, 3)
)
)


@no_type_check
@ti.func
def is_fixed(vert: ti.template()) -> bool:
return not is_free(vert)


@no_type_check
@ti.func
def is_free(vert: ti.template()) -> bool:
return ti.math.isnan(vert.disp).any()


@ti.data_oriented
class MTM:
mesh: ti.MeshInstance

def __init__(self, mesh: meshio.Mesh) -> None:
self.mesh = mkit.io.as_taichi(mesh, ["CE", "CV", "EV", "VE"])
self.mesh.cells.place({"lambda_": float, "mu": float, "volume": float})
self.mesh.edges.place({"K": ti.math.mat3})
self.mesh.verts.place(
{
"Ax": ti.math.vec3,
"b": ti.math.vec3,
"disp": ti.math.vec3,
"f": ti.math.vec3,
"K": ti.math.mat3,
"pos": ti.math.vec3,
"x": ti.math.vec3,
}
)
disp: ti.MatrixField = self.mesh.verts.get_member_field("disp")
disp.from_numpy(mesh.point_data["disp"])
pos: ti.MatrixField = self.mesh.verts.get_member_field("pos")
pos.from_numpy(self.mesh.get_position_as_numpy())

@functools.cached_property
def fixed_mask(self) -> npt.NDArray[np.bool_]:
return ~self.free_mask

@functools.cached_property
def free_mask(self) -> npt.NDArray[np.bool_]:
disp_ti: ti.MatrixField = self.mesh.verts.get_member_field("disp")
disp_np: npt.NDArray[np.floating] = disp_ti.to_numpy()
fixed_mask: npt.NDArray[np.bool_] = np.isnan(disp_np).any(axis=1)
return fixed_mask

def init_material(self, E: float, nu: float) -> None:
lambda_: float = mkit.physics.moduli.E_nu2lambda(E, nu)
lambda_ti: ti.ScalarField = self.mesh.cells.get_member_field("lambda_")
lambda_ti.fill(lambda_)
mu: float = mkit.physics.moduli.E_nu2mu(E, nu)
mu_ti: ti.ScalarField = self.mesh.cells.get_member_field("mu")
mu_ti.fill(mu)

def volume(self) -> None:
self._volume()

def stiffness(self) -> None:
K_edges: ti.MatrixField = self.mesh.edges.get_member_field("K")
K_edges.fill(0)
K_verts: ti.MatrixField = self.mesh.verts.get_member_field("K")
K_verts.fill(0)
self._stiffness()

def force(self, x_free: npt.NDArray[np.floating]) -> npt.NDArray[np.floating]:
disp_ti: ti.MatrixField = self.mesh.verts.get_member_field("disp")
x_np: npt.NDArray[np.floating] = disp_ti.to_numpy()
x_np[self.free_mask] = x_free
x_ti: ti.MatrixField = self.mesh.verts.get_member_field("x")
x_ti.from_numpy(x_np)
self._force()
f_ti: ti.MatrixField = self.mesh.verts.get_member_field("f")
f_np: npt.NDArray[np.floating] = f_ti.to_numpy()
return f_np

def Ax(self, x_free: npt.NDArray[np.floating]) -> npt.NDArray[np.floating]:
disp_ti: ti.MatrixField = self.mesh.verts.get_member_field("disp")
x_np: npt.NDArray[np.floating] = disp_ti.to_numpy()
x_np[self.free_mask] = x_free
x_ti: ti.MatrixField = self.mesh.verts.get_member_field("x")
x_ti.from_numpy(x_np)
self._Ax()
Ax_ti: ti.MatrixField = self.mesh.verts.get_member_field("Ax")
Ax_np: npt.NDArray[np.floating] = Ax_ti.to_numpy()
return Ax_np[self.free_mask]

def b(self) -> npt.NDArray[np.floating]:
disp_ti: ti.MatrixField = self.mesh.verts.get_member_field("disp")
x_ti: ti.MatrixField = self.mesh.verts.get_member_field("x")
x_ti.copy_from(disp_ti)
self._b()
b_ti: ti.MatrixField = self.mesh.verts.get_member_field("b")
b_np: npt.NDArray[np.floating] = b_ti.to_numpy()
return b_np[self.free_mask]

@no_type_check
@ti.kernel
def _volume(self):
for c in self.mesh.cells:
dX = ti.Matrix.cols(
[c.verts[i].pos - c.verts[3].pos for i in ti.static(range(3))]
)
c.volume = 1 / 6 * ti.abs(dX.determinant())

@no_type_check
@ti.kernel
def _stiffness(self):
for c in self.mesh.cells:
M = ti.Matrix.zero(ti.f64, 4, 3)
for i in ti.static(range(4)):
v = c.verts[i]
p = [c.verts[(i + j) % 4].pos for j in range(1, 4)]
Mi = p[0].cross(p[1]) + p[1].cross(p[2]) + p[2].cross(p[0])
if Mi.dot(v.pos - p[0]) > 0:
Mi = -Mi
M[i, :] = Mi
v.K += _stiffness_tet(c.volume, Mi, Mi, c.lambda_, c.mu)
for e_id in ti.static(range(6)):
e = c.edges[e_id]
v_idx = ti.Vector.zero(int, 2) # vertex index in cell (in range(4))
for i in ti.static(range(2)):
for j in range(4):
if e.verts[i].id == c.verts[j].id:
v_idx[i] = j
break
j, k = v_idx
Mj, Mk = M[j, :], M[k, :]
e.K += _stiffness_tet(c.volume, Mj, Mk, c.lambda_, c.mu)

@no_type_check
@ti.kernel
def _force(self):
for v in self.mesh.verts:
v.f = v.K @ v.x
for e_idx in range(v.edges.size):
e = v.edges[e_idx]
if e.verts[0].id == v.id:
v.Ax += e.K @ e.verts[1].x
elif e.verts[1].id == v.id:
v.Ax += e.K.transpose() @ e.verts[0].x

@no_type_check
@ti.kernel
def _Ax(self):
for v in self.mesh.verts:
if is_free(v):
v.Ax = v.K @ v.x
for e_idx in range(v.edges.size):
e = v.edges[e_idx]
if e.verts[0].id == v.id and is_free(e.verts[1]):
v.Ax += e.K @ e.verts[1].x
elif e.verts[1].id == v.id and is_free(e.verts[0]):
v.Ax += e.K.transpose() @ e.verts[0].x

@no_type_check
@ti.kernel
def _b(self):
for v in self.mesh.verts:
if is_free(v):
v.b = 0
for e_idx in range(v.edges.size):
e = v.edges[e_idx]
if e.verts[0].id == v.id and is_fixed(e.verts[1]):
v.b -= e.K @ e.verts[1].disp
elif e.verts[1].id == v.id and is_fixed(e.verts[0]):
v.b -= e.K.transpose() @ e.verts[0].disp


class Operator(scipy.sparse.linalg.LinearOperator):
mtm: MTM

def __init__(self, mtm: MTM) -> None:
self.mtm = mtm
super().__init__(float, (self.n_free * 3, self.n_free * 3))

@functools.cached_property
def n_free(self) -> int:
return np.count_nonzero(self.mtm.free_mask)

def _matvec(self, x: npt.NDArray[np.floating]) -> npt.NDArray[np.floating]:
return self.mtm.Ax(x.reshape((self.n_free, 3))).flatten()


def main(
input_file: Annotated[pathlib.Path, typer.Argument(exists=True, dir_okay=False)],
*,
output_file: Annotated[
pathlib.Path, typer.Option("-o", "--output", dir_okay=False, writable=True)
],
poisson_ratio: Annotated[float, typer.Option()] = 0.0,
) -> None:
ti.init(ti.cpu, default_fp=ti.float64)
mesh: meshio.Mesh = mkit.io.load_meshio(input_file)
mtm = MTM(mesh)
mtm.init_material(E=3000, nu=poisson_ratio)
mtm.volume()
mtm.stiffness()
A = Operator(mtm)
x_free: npt.NDArray[np.floating]
info: int
x_free, info = scipy.sparse.linalg.gmres(A, mtm.b().flatten())
x_free = x_free.reshape((A.n_free, 3))
logger.info("GMRES info: {}", info)
disp: npt.NDArray[np.floating] = mesh.point_data["disp"]
mesh.points[mtm.fixed_mask] += disp[mtm.fixed_mask]
mesh.points[mtm.free_mask] += x_free
mkit.io.save(output_file, mesh, point_data={"force": mtm.force(x_free)})


if __name__ == "__main__":
mkit.cli.run(main)

0 comments on commit f332363

Please sign in to comment.