Skip to content

Commit

Permalink
Make standalone function for interpolation over non-matching meshes (#…
Browse files Browse the repository at this point in the history
…3177)

* Fix len comparison in interpolation

* Add non-matching interpolation flag enum to interpolate to ensure that we can distinguish between a submesh with 0 cells and a nonmatching mesh with 0 cells on a process

* Ruf formatting

* Add docstring and update year

* More doc updates

* Split nonmatching interpolation into separate function.

* Apply suggestions from code review

Co-authored-by: Garth N. Wells <[email protected]>

* Apply suggestions from code review

* Apply suggestions from code review

Co-authored-by: Garth N. Wells <[email protected]>

* Various renaming

* Simplify Python interface

* Make proper Python interface and remove simplified constructors

* Simplify and ruff formatting

* Add documentation example

* Grammar fix

* Remove \p and add docstring

* Apply suggestions from code review

Co-authored-by: Garth N. Wells <[email protected]>

* Fix documentation

* Shorten docstring

* Revert local clang formatting

* Add back newline

* Syntax fixes

* Simplify syntax

* Simplifications

* Demo update

* Small update

* Simplifications

* Simplifications

* Fix typo

* Improve docs

* Doc fixes

* Refine expression interpolate docs

* Remove a default args.

User can easily mix up arg meaning

* Resolve overload

* Doc update

* Tidy

* Doc improvement

* Move cell mapping interface to Function.h

* Doc improvement

* Doc work

* Change order of interpolate functions

* Refactor

* Update test

* Improve logic

* Doc fix

* Tidy

* Improvements

* Formatting

* More fixes

* Formatting fix

* Logic improvements

* Fix order

* Tidy up

* Interface update

* Re-enable test

* Tidy up

* Improve logic in ordering

* Doc improvements

* Minor tidy

* Minor doc edit

---------

Co-authored-by: Garth N. Wells <[email protected]>
  • Loading branch information
jorgensd and garth-wells authored May 14, 2024
1 parent b9cf771 commit ab79530
Show file tree
Hide file tree
Showing 17 changed files with 727 additions and 797 deletions.
2 changes: 1 addition & 1 deletion cpp/demo/hyperelasticity/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ int main(int argc, char* argv[])

auto sigma = fem::Function<T>(S);
sigma.name = "cauchy_stress";
sigma.interpolate(sigma_expression, *mesh);
sigma.interpolate(sigma_expression);

// Save solution in VTK format
io::VTKFile file_u(mesh->comm(), "u.pvd", "w");
Expand Down
19 changes: 12 additions & 7 deletions cpp/demo/interpolation_different_meshes/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,19 @@ int main(int argc, char* argv[])
u_tet->interpolate(fun);

// Interpolate from u_tet to u_hex
constexpr T padding = 1e-8;
auto nmm_interpolation_data
= fem::create_nonmatching_meshes_interpolation_data(
*u_hex->function_space()->mesh(),
auto cell_map
= mesh_hex->topology()->index_map(mesh_hex->topology()->dim());
assert(cell_map);
std::vector<std::int32_t> cells(
cell_map->size_local() + cell_map->num_ghosts(), 0);
std::iota(cells.begin(), cells.end(), 0);
geometry::PointOwnershipData<T> interpolation_data
= fem::create_interpolation_data(
u_hex->function_space()->mesh()->geometry(),
*u_hex->function_space()->element(),
*u_tet->function_space()->mesh(), padding);
constexpr std::span<const std::int32_t> cell_map;
u_hex->interpolate(*u_tet, cell_map, nmm_interpolation_data);
*u_tet->function_space()->mesh(),
std::span<const std::int32_t>(cells), 1e-8);
u_hex->interpolate(*u_tet, cells, interpolation_data);

#ifdef HAS_ADIOS2
io::VTXWriter<double> write_tet(mesh_tet->comm(), "u_tet.bp", {u_tet});
Expand Down
48 changes: 23 additions & 25 deletions cpp/dolfinx/fem/Expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ template <dolfinx::scalar T,
class Expression
{
public:
/// @brief Scalar type
/// @brief Scalar type.
///
/// Field type for the Expression, e.g. `double`,
/// `std::complex<float>`, etc.
Expand All @@ -51,16 +51,18 @@ class Expression

/// @brief Create an Expression.
///
/// @note Users should prefer the @ref create_expression factory functions.
/// @note Users should prefer the @ref create_expression factory
/// functions.
///
/// @param[in] coefficients Coefficients in the Expression
/// @param[in] coefficients Coefficients in the Expression.
/// @param[in] constants Constants in the Expression
/// @param[in] X points on reference cell, `shape=(number of points,
/// tdim)` and storage is row-major.
/// @param[in] X Points on the reference cell, `shape=(number of
/// points, tdim)` and storage is row-major.
/// @param[in] Xshape Shape of `X`.
/// @param[in] fn function for tabulating expression
/// @param[in] value_shape shape of expression evaluated at single point
/// @param[in] argument_function_space Function space for Argument
/// @param[in] fn Function for tabulating the Expression.
/// @param[in] value_shape Shape of Expression evaluated at single
/// point.
/// @param[in] argument_function_space Function space for Argument.
Expression(
const std::vector<std::shared_ptr<
const Function<scalar_type, geometry_type>>>& coefficients,
Expand Down Expand Up @@ -142,35 +144,31 @@ class Expression
}

/// @brief Evaluate Expression on cells or facets.
///
/// @param[in] mesh Cells on which to evaluate the Expression.
/// @param[in] entities List of entities to evaluate the expression on. This
/// could be either a list of cells or a list of (cell, local facet index)
/// tuples. Array is flattened per entity.
/// @param[out] values A 2D array to store the result. Caller
/// is responsible for correct sizing which should be `(num_cells,
/// @param[in] entities List of entities to evaluate the expression
/// on. This could be either a list of cells or a list of (cell, local
/// @param[out] values A 2D array to store the result. Caller is
/// responsible for correct sizing which should be `(num_cells,
/// num_points * value_size * num_all_argument_dofs columns)`.
/// @param[in] vshape The shape of @p values (row-major storage).
/// facet index) tuples. Array is flattened per entity.
/// @param[in] vshape The shape of `values` (row-major storage).
void eval(const mesh::Mesh<geometry_type>& mesh,
std::span<const std::int32_t> entities,
std::span<scalar_type> values,
std::array<std::size_t, 2> vshape) const
{
std::size_t estride;
if (mesh.topology()->dim() == _x_ref.second[1])
{
estride = 1;
}
else if (mesh.topology()->dim() == _x_ref.second[1] + 1)
{
estride = 2;
}
else
{
throw std::runtime_error("Invalid dimension of evaluation points.");
}

// Prepare coefficients and constants
const auto [coeffs, cstride] = pack_coefficients(*this, entities, estride);
const std::vector<scalar_type> constant_data = pack_constants(*this);
auto [coeffs, cstride] = pack_coefficients(*this, entities, estride);
std::vector<scalar_type> constant_data = pack_constants(*this);
auto fn = this->get_tabulate_expression();

// Prepare cell geometry
Expand All @@ -179,7 +177,7 @@ class Expression
// Get geometry data
auto& cmap = mesh.geometry().cmap();

const std::size_t num_dofs_g = cmap.dim();
std::size_t num_dofs_g = cmap.dim();
auto x_g = mesh.geometry().x();

// Create data structures used in evaluation
Expand Down Expand Up @@ -227,11 +225,11 @@ class Expression
}

// Iterate over cells and 'assemble' into values
const int size0 = _x_ref.second[0] * value_size();
int size0 = _x_ref.second[0] * value_size();
std::vector<scalar_type> values_local(size0 * num_argument_dofs, 0);
for (std::size_t e = 0; e < entities.size() / estride; ++e)
{
const std::int32_t entity = entities[e * estride];
std::int32_t entity = entities[e * estride];
auto x_dofs = MDSPAN_IMPL_STANDARD_NAMESPACE::submdspan(
x_dofmap, entity, MDSPAN_IMPL_STANDARD_NAMESPACE::full_extent);
for (std::size_t i = 0; i < x_dofs.size(); ++i)
Expand Down
Loading

0 comments on commit ab79530

Please sign in to comment.