Add Metis extension (#1)
termi-official authored Jun 7, 2023
1 parent ab81f6f commit 6c38b6b
Showing 5 changed files with 86 additions and 3 deletions.
11 changes: 10 additions & 1 deletion Project.toml
@@ -6,14 +6,23 @@ version = "0.0.1-DEV"
 [deps]
 Ferrite = "c061ca5d-56c9-439f-9c0e-210fe06d3992"
 MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
+Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
 SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 Tensors = "48a634ad-e948-5137-8d70-aa71f2a747f4"

+[weakdeps]
+Metis = "2679e427-3c69-5b7f-982b-ece356f1e94b"
+
+[extensions]
+FerriteDistributedMetisPartitioning = "Metis"
+
 [compat]
 Ferrite = "0.3"
+Metis = "1.3"
 julia = "1"

 [extras]
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

 [targets]
-test = ["Test"]
+test = ["Metis", "Test"]
52 changes: 52 additions & 0 deletions ext/FerriteDistributedMetisPartitioning.jl
@@ -0,0 +1,52 @@

module FerriteDistributedMetisPartitioning

# This extension requires modules as type parameters
# https://github.com/JuliaLang/julia/pull/47749
if VERSION >= v"1.10.0-DEV.90"

using FerriteDistributed, Metis
import FerriteDistributed: CoverTopology, PartitioningAlgorithm

struct MetisPartitioning <: PartitioningAlgorithm.Ext{Metis}
    alg::Symbol
end

PartitioningAlgorithm.Ext{Metis}(alg::Symbol) = MetisPartitioning(alg)

function FerriteDistributed.create_partitioning(grid::Grid{dim,C,T}, grid_topology::CoverTopology, n_partitions::Int, partition_alg::PartitioningAlgorithm.Ext{Metis}) where {dim,C,T}
    n_cells_global = getncells(grid)
    @assert n_cells_global > 0

    if n_partitions == 1
        return ones(Metis.idx_t, n_cells_global)
    end

    # Set up the element connectivity graph
    xadj = Vector{Metis.idx_t}(undef, n_cells_global+1)
    xadj[1] = 1
    adjncy = Vector{Metis.idx_t}(undef, 0)
    @inbounds for i in 1:n_cells_global
        n_neighbors = 0
        for neighbor in getneighborhood(grid_topology, grid, CellIndex(i))
            push!(adjncy, neighbor)
            n_neighbors += 1
        end
        xadj[i+1] = xadj[i] + n_neighbors
    end

    # Generate a partitioning
    return Metis.partition(
        Metis.Graph(
            Metis.idx_t(n_cells_global),
            xadj,
            adjncy
        ),
        n_partitions;
        alg=partition_alg.alg
    )
end

end # VERSION check

end # module FerriteDistributedMetisPartitioning
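
The xadj/adjncy pair assembled above is the one-based CSR adjacency layout Metis.jl expects: the neighbours of cell i live in adjncy[xadj[i]:xadj[i+1]-1]. A hand-built sketch for a chain of three cells, assuming only Metis.jl (illustrative, not part of the commit):

using Metis

# Cells 1-2-3 in a chain: 1 <-> 2 and 2 <-> 3, so the neighbour counts are (1, 2, 1).
xadj   = Metis.idx_t[1, 2, 4, 5]    # prefix sums of the neighbour counts, starting at 1
adjncy = Metis.idx_t[2, 1, 3, 2]    # neighbours of cell 1, then cell 2, then cell 3

g = Metis.Graph(Metis.idx_t(3), xadj, adjncy)
parts = Metis.partition(g, 2; alg = :RECURSIVE)   # e.g. a vector like [1, 1, 2]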
3 changes: 2 additions & 1 deletion src/FerriteDistributed.jl
@@ -1,6 +1,7 @@
 module FerriteDistributed

-using Ferrite, MPI
+using Reexport
+@reexport using Ferrite, MPI

 include("CoverTopology.jl")

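
Re-exporting through Reexport.jl means user code (and the tests below) can reach Ferrite and MPI names with a single `using FerriteDistributed`. A rough user-side sketch of the effect, assuming a plain Ferrite grid (not part of the diff):

using FerriteDistributed            # Ferrite and MPI symbols are re-exported

MPI.Init()                          # MPI.Init comes from MPI.jl
grid = generate_grid(Quadrilateral, (4, 4))   # generate_grid and Quadrilateral come from Ferrite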
2 changes: 1 addition & 1 deletion src/NODGrid.jl
@@ -58,7 +58,7 @@ Get the number of ranks on the global communicator of the distributed grid.
 Construct a non-overlapping distributed grid from a grid with the SFC induced by the element ordering on a specified MPI communicator.
 It is assumed that this function is called with exactly the same grid on each MPI process in the communicator.
 """
-function NODGrid(grid_comm::MPI.Comm, grid_to_distribute::Grid{dim,C,T}) where {dim,C,T}
+function NODGrid(grid_comm::MPI.Comm, grid_to_distribute::Grid{dim,C,T}, alg = PartitioningAlgorithm.SFC()) where {dim,C,T}
     grid_topology = CoverTopology(grid_to_distribute)
     nparts = MPI.Comm_size(grid_comm)
     partitioning = create_partitioning(grid_to_distribute, grid_topology, nparts, PartitioningAlgorithm.SFC())
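
The added optional argument keeps the two-argument call working while letting callers pass a different partitioner, such as the Metis-backed one from the new extension. A hedged usage sketch mirroring the new test (assumes Metis is installed and a Julia version covered by the extension's VERSION guard):

using FerriteDistributed, Metis

MPI.Init()
grid = generate_grid(Quadrilateral, (4, 4))

dgrid_sfc   = NODGrid(MPI.COMM_WORLD, grid)    # default space-filling-curve partitioning
dgrid_metis = NODGrid(MPI.COMM_WORLD, grid,
                      FerriteDistributed.PartitioningAlgorithm.Ext{Metis}(:RECURSIVE))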
21 changes: 21 additions & 0 deletions test/runtests.jl
@@ -29,4 +29,25 @@ end
     @test FerriteDistributed.global_nranks(dgrid) == 1
 end

+if VERSION >= v"1.10.0-DEV.90"
+using Metis
+@testset "Metis integration" begin
+    grid = generate_grid(Quadrilateral, (4,4))
+    dgrid = NODGrid(MPI.COMM_WORLD, grid, FerriteDistributed.PartitioningAlgorithm.Ext{Metis}(:RECURSIVE))
+
+    @test isempty(get_shared_vertices(dgrid))
+    @test isempty(get_shared_edges(dgrid))
+    @test isempty(get_shared_faces(dgrid))
+    @test getvertexsets(getlocalgrid(dgrid)) == getvertexsets(grid)
+    @test getfacesets(getlocalgrid(dgrid)) == getfacesets(grid)
+    @test getcellsets(getlocalgrid(dgrid)) == getcellsets(grid)
+    @test getnnodes(getlocalgrid(dgrid)) == getnnodes(grid)
+    @test getncells(getlocalgrid(dgrid)) == getncells(grid)
+
+    @test FerriteDistributed.global_rank(dgrid) == 1
+
+    @test FerriteDistributed.global_nranks(dgrid) == 1
+end
+end # VERSION >= v"1.10.0-DEV.90"
+
 #TODO MPI based test sets with more than 1 proc
