Skip to content

Commit

Permalink
Reduce function argument types to basic types (#2592)
Browse files Browse the repository at this point in the history
* Fix types

* More type fixing

* Fix tests

* Interface type improvements

* Small simplification

* Simplify
  • Loading branch information
garth-wells authored Mar 20, 2023
1 parent 362f241 commit 5290785
Show file tree
Hide file tree
Showing 14 changed files with 131 additions and 125 deletions.
4 changes: 3 additions & 1 deletion cpp/demo/interpolation_different_meshes/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ int main(int argc, char* argv[])
// Interpolate from u_tet to u_hex
auto nmm_interpolation_data
= fem::create_nonmatching_meshes_interpolation_data(
*u_hex->function_space(), *u_tet->function_space());
*u_hex->function_space()->mesh(),
*u_hex->function_space()->element(),
*u_tet->function_space()->mesh());
u_hex->interpolate(*u_tet, nmm_interpolation_data);

#ifdef HAS_ADIOS2
Expand Down
5 changes: 3 additions & 2 deletions cpp/dolfinx/fem/Function.h
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,9 @@ class Function
assert(_function_space);
assert(_function_space->element());
assert(_function_space->mesh());
const std::vector<double> x = fem::interpolation_coords(
*_function_space->element(), *_function_space->mesh(), cells);
const std::vector<double> x
= fem::interpolation_coords(*_function_space->element(),
_function_space->mesh()->geometry(), cells);
namespace stdex = std::experimental;
stdex::mdspan<const double,
stdex::extents<std::size_t, 3, stdex::dynamic_extent>>
Expand Down
58 changes: 23 additions & 35 deletions cpp/dolfinx/fem/interpolate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@ using namespace dolfinx;

//-----------------------------------------------------------------------------
std::vector<double>
fem::interpolation_coords(const FiniteElement& element, const mesh::Mesh& mesh,
fem::interpolation_coords(const FiniteElement& element,
const mesh::Geometry<double>& geometry,
std::span<const std::int32_t> cells)
{
// Get mesh geometry data and the element coordinate map
const std::size_t gdim = mesh.geometry().dim();
const graph::AdjacencyList<std::int32_t>& x_dofmap = mesh.geometry().dofmap();

std::span<const double> x_g = mesh.geometry().x();
const CoordinateElement& cmap = mesh.geometry().cmap();
// Get geometry data and the element coordinate map
const std::size_t gdim = geometry.dim();
const graph::AdjacencyList<std::int32_t>& x_dofmap = geometry.dofmap();
std::span<const double> x_g = geometry.x();
const CoordinateElement& cmap = geometry.cmap();
const std::size_t num_dofs_g = cmap.dim();

// Get the interpolation points on the reference cells
Expand Down Expand Up @@ -71,49 +71,37 @@ fem::interpolation_coords(const FiniteElement& element, const mesh::Mesh& mesh,
}
//-----------------------------------------------------------------------------
fem::nmm_interpolation_data_t fem::create_nonmatching_meshes_interpolation_data(
const fem::FunctionSpace& Vu, const fem::FunctionSpace& Vv,
std::span<const std::int32_t> cells)
const mesh::Geometry<double>& geometry0, const FiniteElement& element0,
const mesh::Mesh& mesh1, std::span<const std::int32_t> cells)
{

// Collect all the points at which values are needed to define the
// interpolating function
auto element_u = Vu.element();
assert(element_u);
auto mesh = Vu.mesh();
assert(mesh);
const std::vector<double> coords_b
= interpolation_coords(*element_u, *mesh, cells);

namespace stdex = std::experimental;
using cmdspan2_t
= stdex::mdspan<const double, stdex::dextents<std::size_t, 2>>;
using mdspan2_t = stdex::mdspan<double, stdex::dextents<std::size_t, 2>>;
cmdspan2_t coords(coords_b.data(), 3, coords_b.size() / 3);
const std::vector<double> coords
= interpolation_coords(element0, geometry0, cells);

// Transpose interpolation coords
std::vector<double> x(coords.size());
mdspan2_t _x(x.data(), coords_b.size() / 3, 3);
for (std::size_t j = 0; j < coords.extent(1); ++j)
for (std::size_t i = 0; i < 3; ++i)
_x(j, i) = coords(i, j);
std::size_t num_points = coords.size() / 3;
for (std::size_t i = 0; i < num_points; ++i)
for (std::size_t j = 0; j < 3; ++j)
x[3 * i + j] = coords[i + j * num_points];

// Determine ownership of each point
auto mesh_v = Vv.mesh();
assert(mesh_v);
return geometry::determine_point_ownership(*mesh_v, x);
return geometry::determine_point_ownership(mesh1, x);
}
//-----------------------------------------------------------------------------
fem::nmm_interpolation_data_t
fem::create_nonmatching_meshes_interpolation_data(const FunctionSpace& Vu,
const FunctionSpace& Vv)
fem::create_nonmatching_meshes_interpolation_data(const mesh::Mesh& mesh0,
const FiniteElement& element0,
const mesh::Mesh& mesh1)
{
assert(Vu.mesh());
int tdim = Vu.mesh()->topology().dim();
auto cell_map = Vu.mesh()->topology().index_map(tdim);
int tdim = mesh0.topology().dim();
auto cell_map = mesh0.topology().index_map(tdim);
assert(cell_map);
std::int32_t num_cells = cell_map->size_local() + cell_map->num_ghosts();
std::vector<std::int32_t> cells(num_cells, 0);
std::iota(cells.begin(), cells.end(), 0);
return create_nonmatching_meshes_interpolation_data(Vu, Vv, cells);
return create_nonmatching_meshes_interpolation_data(mesh0.geometry(),
element0, mesh1, cells);
}
//-----------------------------------------------------------------------------
69 changes: 34 additions & 35 deletions cpp/dolfinx/fem/interpolate.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,24 +25,24 @@ namespace dolfinx::fem
template <typename T>
class Function;

/// Compute the evaluation points in the physical space at which an
/// expression should be computed to interpolate it in a finite element
/// space.
/// @brief Compute the evaluation points in the physical space at which
/// an expression should be computed to interpolate it in a finite
/// element space.
///
/// @param[in] element The element to be interpolated into
/// @param[in] mesh The domain
/// @param[in] geometry Mesh geometry
/// @param[in] cells Indices of the cells in the mesh to compute
/// interpolation coordinates for
/// @return The coordinates in the physical space at which to evaluate
/// an expression. The shape is (3, num_points) and storage is row-major.
std::vector<double> interpolation_coords(const fem::FiniteElement& element,
const mesh::Mesh& mesh,
const mesh::Geometry<double>& geometry,
std::span<const std::int32_t> cells);

/// Helper type for the data that can be cached to speed up repeated
/// interpolation of discrete functions on nonmatching meshes
using nmm_interpolation_data_t = decltype(std::function{
dolfinx::geometry::determine_point_ownership})::result_type;
using nmm_interpolation_data_t
= decltype(std::function{geometry::determine_point_ownership})::result_type;

/// Forward declaration
template <typename T>
Expand All @@ -52,18 +52,18 @@ void interpolate(Function<T>& u, std::span<const T> f,

namespace impl
{
/// @brief Scatter data into non-contiguous memory
/// @brief Scatter data into non-contiguous memory.
///
/// Scatter blocked data `send_values` to its
/// corresponding src_rank and insert the data into `recv_values`.
/// The insert location in `recv_values` is determined by `dest_ranks`.
/// If the j-th dest rank is -1, then
/// `recv_values[j*block_size:(j+1)*block_size]) = 0.
/// Scatter blocked data `send_values` to its corresponding src_rank and
/// insert the data into `recv_values`. The insert location in
/// `recv_values` is determined by `dest_ranks`. If the j-th dest rank
/// is -1, then `recv_values[j*block_size:(j+1)*block_size]) = 0.
///
/// @param[in] comm The mpi communicator
/// @param[in] src_ranks The rank owning the values of each row in send_values
/// @param[in] dest_ranks List of ranks receiving data. Size of array is how
/// many values we are receiving (not unrolled for blcok_size).
/// @param[in] src_ranks The rank owning the values of each row in
/// send_values
/// @param[in] dest_ranks List of ranks receiving data. Size of array is
/// how many values we are receiving (not unrolled for blcok_size).
/// @param[in] send_values The values to send back to owner. Shape
/// (src_ranks.size(), block_size). Storage is row-major.
/// @param[in] s_shape Shape of send_values
Expand All @@ -73,11 +73,10 @@ namespace impl
/// @note dest_ranks can contain repeated entries
/// @note dest_ranks might contain -1 (no process owns the point)
template <typename T>
void scatter_values(const MPI_Comm& comm,
std::span<const std::int32_t> src_ranks,
void scatter_values(MPI_Comm comm, std::span<const std::int32_t> src_ranks,
std::span<const std::int32_t> dest_ranks,
std::span<const T> send_values,
const std::array<std::size_t, 2>& s_shape,
std::array<std::size_t, 2> s_shape,
std::span<T> recv_values)
{
const std::size_t block_size = s_shape[1];
Expand Down Expand Up @@ -916,29 +915,29 @@ void interpolate(Function<T>& u, std::span<const T> f,
}
}

/// Generate data needed to interpolate discrete functions across
/// @brief Generate data needed to interpolate discrete functions across
/// different meshes.
///
/// @param[out] Vu The function space of the function to interpolate
/// into
/// @param[in] Vv The function space of the function to interpolate from
/// @param[in] geometry0 Mesh geometry of the space to interpolate into
/// @param[in] element0 Element of the space to interpolate into
/// @param[in] mesh1 Mesh of the function to interpolate from
/// @param[in] cells Indices of the cells in the destination mesh on
/// which to interpolate. Should be the same as the list used when
/// calling fem::interpolation_coords.
nmm_interpolation_data_t create_nonmatching_meshes_interpolation_data(
const FunctionSpace& Vu, const FunctionSpace& Vv,
std::span<const std::int32_t> cells);

/// Generate data needed to interpolate discrete functions defined on
/// different meshes. Interpolate on all cells in the mesh.
///
/// @param[out] Vu The function space of the function to interpolate into
/// @param[in] Vv The function space of the function to interpolate from
const mesh::Geometry<double>& geometry0, const FiniteElement& element0,
const mesh::Mesh& mesh1, std::span<const std::int32_t> cells);

/// @brief Generate data needed to interpolate discrete functions
/// defined on different meshes. Interpolate on all cells in the mesh.
/// @param[in] mesh0 Mesh of the space to interpolate into
/// @param[in] element0 Element of the space to interpolate into
/// @param[in] mesh1 Mesh of the function to interpolate from
nmm_interpolation_data_t
create_nonmatching_meshes_interpolation_data(const FunctionSpace& Vu,
const FunctionSpace& Vv);
create_nonmatching_meshes_interpolation_data(const mesh::Mesh& mesh0,
const FiniteElement& element0,
const mesh::Mesh& mesh1);

/// Interpolate from one finite element Function to another one
/// @brief Interpolate from one finite element Function to another one.
/// @param[out] u The function to interpolate into
/// @param[in] v The function to be interpolated
/// @param[in] cells List of cell indices to interpolate on
Expand Down
11 changes: 5 additions & 6 deletions cpp/dolfinx/geometry/BoundingBoxTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,10 @@ using namespace dolfinx::geometry;
namespace
{
//-----------------------------------------------------------------------------
std::vector<std::int32_t> range(const mesh::Mesh& mesh, int tdim)
std::vector<std::int32_t> range(mesh::Topology& topology, int tdim)
{
// Initialize entities of given dimension if they don't exist
mesh.topology_mutable().create_entities(tdim);

auto map = mesh.topology().index_map(tdim);
topology.create_entities(tdim);
auto map = topology.index_map(tdim);
assert(map);
const std::int32_t num_entities = map->size_local() + map->num_ghosts();
std::vector<std::int32_t> r(num_entities);
Expand Down Expand Up @@ -207,7 +205,8 @@ std::int32_t _build_from_point(
//-----------------------------------------------------------------------------
BoundingBoxTree::BoundingBoxTree(const mesh::Mesh& mesh, int tdim,
double padding)
: BoundingBoxTree::BoundingBoxTree(mesh, tdim, range(mesh, tdim), padding)
: BoundingBoxTree::BoundingBoxTree(
mesh, tdim, range(mesh.topology_mutable(), tdim), padding)
{
// Do nothing
}
Expand Down
13 changes: 8 additions & 5 deletions cpp/dolfinx/geometry/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -532,10 +532,10 @@ int geometry::compute_first_colliding_cell(
if (norm < eps2)
return cell;
}

return -1;
}
return -1;
}

//-------------------------------------------------------------------------------
std::tuple<std::vector<std::int32_t>, std::vector<std::int32_t>,
std::vector<double>, std::vector<std::int32_t>>
Expand Down Expand Up @@ -649,8 +649,9 @@ geometry::determine_point_ownership(const mesh::Mesh& mesh,
cell_indicator[p / 3] = (colliding_cell >= 0) ? rank : -1;
colliding_cells[p / 3] = colliding_cell;
}
// Create neighborhood communicator in the reverse direction: send back col to
// requesting processes

// Create neighborhood communicator in the reverse direction: send
// back col to requesting processes
MPI_Comm reverse_comm;
MPI_Dist_graph_create_adjacent(
comm, out_ranks.size(), out_ranks.data(), MPI_UNWEIGHTED, in_ranks.size(),
Expand Down Expand Up @@ -689,6 +690,7 @@ geometry::determine_point_ownership(const mesh::Mesh& mesh,
if ((recv_ranks[i] >= 0) && (point_owners[pos] == -1))
point_owners[pos] = recv_ranks[i];
}

// Communication is reversed again to send dest ranks to all processes
std::swap(send_sizes, recv_sizes);
std::swap(send_offsets, recv_offsets);
Expand Down Expand Up @@ -739,4 +741,5 @@ geometry::determine_point_ownership(const mesh::Mesh& mesh,

return std::make_tuple(point_owners, owned_recv_ranks, owned_recv_points,
owned_recv_cells);
};
};
//-------------------------------------------------------------------------------
6 changes: 4 additions & 2 deletions cpp/dolfinx/io/ADIOS2Writers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ void vtx_write_mesh(adios2::IO& io, adios2::Engine& engine,
io, "NumberOfNodes", {adios2::LocalValueDim});
engine.Put<std::uint32_t>(vertices, num_vertices);

const auto [vtkcells, shape] = io::extract_vtk_connectivity(mesh);
const auto [vtkcells, shape] = io::extract_vtk_connectivity(
mesh.geometry(), mesh.topology().cell_type());

// Add cell metadata
const int tdim = topology.dim();
Expand Down Expand Up @@ -552,7 +553,8 @@ void fides_write_mesh(adios2::IO& io, adios2::Engine& engine,
const int tdim = topology.dim();
const std::int32_t num_cells = topology.index_map(tdim)->size_local();
const int num_nodes = geometry.cmap().dim();
const auto [cells, shape] = io::extract_vtk_connectivity(mesh);
const auto [cells, shape] = io::extract_vtk_connectivity(
mesh.geometry(), mesh.topology().cell_type());

// "Put" topology data in the result in the ADIOS2 file
adios2::Variable<std::int64_t> local_topology = define_variable<std::int64_t>(
Expand Down
6 changes: 4 additions & 2 deletions cpp/dolfinx/io/VTKFile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,8 @@ void write_function(
if (is_cellwise(*V0))
{
std::vector<std::int64_t> tmp;
std::tie(tmp, cshape) = io::extract_vtk_connectivity(*mesh0);
std::tie(tmp, cshape) = io::extract_vtk_connectivity(
mesh0->geometry(), mesh0->topology().cell_type());
cells.assign(tmp.begin(), tmp.end());
const mesh::Geometry<double>& geometry = mesh0->geometry();
x.assign(geometry.x().begin(), geometry.x().end());
Expand Down Expand Up @@ -770,7 +771,8 @@ void io::VTKFile::write(const mesh::Mesh& mesh, double time)
piece_node.append_attribute("NumberOfCells") = num_cells;

// Add mesh data to "Piece" node
const auto [cells, cshape] = extract_vtk_connectivity(mesh);
const auto [cells, cshape]
= extract_vtk_connectivity(mesh.geometry(), mesh.topology().cell_type());
std::array<std::size_t, 2> xshape = {geometry.x().size() / 3, 3};
std::vector<std::uint8_t> x_ghost(xshape[0], 0);
std::fill(std::next(x_ghost.begin(), xmap->size_local()), x_ghost.end(), 1);
Expand Down
12 changes: 5 additions & 7 deletions cpp/dolfinx/io/vtk_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,20 +188,18 @@ io::vtk_mesh_from_space(const fem::FunctionSpace& V)
}
//-----------------------------------------------------------------------------
std::pair<std::vector<std::int64_t>, std::array<std::size_t, 2>>
io::extract_vtk_connectivity(const mesh::Mesh& mesh)
io::extract_vtk_connectivity(const mesh::Geometry<double>& geometry,
mesh::CellType cell_type)
{
// Get DOLFINx to VTK permutation
// FIXME: Use better way to get number of nodes
const graph::AdjacencyList<std::int32_t>& dofmap_x = mesh.geometry().dofmap();
const std::size_t num_nodes = mesh.geometry().cmap().dim();
mesh::CellType cell_type = mesh.topology().cell_type();
const graph::AdjacencyList<std::int32_t>& dofmap_x = geometry.dofmap();
const std::size_t num_nodes = geometry.cmap().dim();
std::vector vtkmap
= io::cells::transpose(io::cells::perm_vtk(cell_type, num_nodes));

// Extract mesh 'nodes'
const int tdim = mesh.topology().dim();
const std::size_t num_cells = mesh.topology().index_map(tdim)->size_local()
+ mesh.topology().index_map(tdim)->num_ghosts();
const std::size_t num_cells = dofmap_x.num_nodes();

// Build mesh connectivity

Expand Down
Loading

0 comments on commit 5290785

Please sign in to comment.