Cherry-pick HYPRE extension from Ferrite-FEM/Ferrite.jl#486 - which in turn is based on work done by @fredrikekre

1 parent 6c38b6b · commit cbe1d30
Showing 4 changed files with 189 additions and 0 deletions.
@@ -0,0 +1,20 @@
""" | ||
Module containing the code for distributed assembly via HYPRE.jl | ||
""" | ||
module FerriteDistributedHYPREAssembly | ||
|
||
using FerriteDistributed | ||
# TODO remove me. These are merely hotfixes to split the extensions trasiently via an internal API. | ||
import FerriteDistributed: getglobalgrid, num_local_true_dofs, num_local_dofs, global_comm, interface_comm, global_rank, compute_owner, remote_entities | ||
using MPI | ||
using HYPRE | ||
using Base: @propagate_inbounds | ||
|
||
include("FerriteDistributedHYPREAssembly/assembler.jl") | ||
include("FerriteDistributedHYPREAssembly/conversion.jl") | ||
|
||
function __init__() | ||
@info "FerriteHYPRE extension loaded." | ||
end | ||
|
||
end # module FerriteHYPRE |
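For context, a minimal driver sketch of the setup this module presupposes before any distributed assembly can happen; the comments describe typical use and are assumptions, not part of the commit:

    using MPI, HYPRE, FerriteDistributed

    MPI.Init()    # must come before any distributed grid or HYPRE call
    HYPRE.Init()  # initializes HYPRE's global library state

    # Loading HYPRE.jl alongside FerriteDistributed triggers this extension;
    # __init__ above then reports that the extension is active.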
@@ -0,0 +1,48 @@
function Ferrite.create_sparsity_pattern(::Type{<:HYPREMatrix}, dh::Ferrite.AbstractDofHandler, ch::Union{ConstraintHandler,Nothing}=nothing; kwargs...)
    K = create_sparsity_pattern(dh, ch; kwargs...)
    # Fill the stored entries with a nonzero so the full sparsity pattern
    # survives the conversion to a HYPREMatrix (stored zeros may be dropped).
    fill!(K.nzval, 1)
    return HYPREMatrix(K)
end

###########################################
## HYPREAssembler and associated methods ##
###########################################

struct HYPREAssembler <: Ferrite.AbstractSparseAssembler
    A::HYPRE.HYPREAssembler
end

Ferrite.matrix_handle(a::HYPREAssembler) = a.A.A.A # :)
Ferrite.vector_handle(a::HYPREAssembler) = a.A.b.b # :)

function Ferrite.start_assemble(K::HYPREMatrix, f::HYPREVector)
    return HYPREAssembler(HYPRE.start_assemble!(K, f))
end

function Ferrite.assemble!(a::HYPREAssembler, dofs::AbstractVector{<:Integer}, ke::AbstractMatrix, fe::AbstractVector)
    HYPRE.assemble!(a.A, dofs, ke, fe)
end

function Ferrite.end_assemble(a::HYPREAssembler)
    HYPRE.finish_assemble!(a.A)
end

## Methods for arrayutils.jl ##

function Ferrite.addindex!(A::HYPREMatrix, v, i::Int, j::Int)
    nrows = HYPRE_Int(1)
    ncols = Ref{HYPRE_Int}(1)
    rows = Ref{HYPRE_BigInt}(i)
    cols = Ref{HYPRE_BigInt}(j)
    values = Ref{HYPRE_Complex}(v)
    HYPRE.@check HYPRE_IJMatrixAddToValues(A.ijmatrix, nrows, ncols, rows, cols, values)
    return A
end

function Ferrite.addindex!(b::HYPREVector, v, i::Int)
    nvalues = HYPRE_Int(1)
    indices = Ref{HYPRE_BigInt}(i)
    values = Ref{HYPRE_Complex}(v)
    HYPRE.@check HYPRE_IJVectorAddToValues(b.ijvector, nvalues, indices, values)
    return b
end
@@ -0,0 +1,119 @@
# HYPRE to Ferrite vector
function hypre_to_ferrite!(u::Vector{T}, uh::HYPREVector, dh::Ferrite.AbstractDofHandler) where {T}
    # Copy solution from HYPRE to Julia
    uj = Vector{Float64}(undef, num_local_true_dofs(dh))
    copy!(uj, uh)

    my_rank = global_rank(getglobalgrid(dh))

    # Helper to gather which global dofs and values have to be sent to which process
    gdof_value_send = [Dict{Int,Float64}() for i ∈ 1:MPI.Comm_size(MPI.COMM_WORLD)]
    # Helper to get the global dof to local dof mapping
    rank_recv_count = [0 for i ∈ 1:MPI.Comm_size(MPI.COMM_WORLD)]
    gdof_to_ldof = Dict{Int,Int}()

    next_dof = 1
    for (ldof, rank) ∈ enumerate(dh.ldof_to_rank)
        if rank == my_rank
            u[ldof] = uj[next_dof]
            next_dof += 1
        else
            # We have to sync these later.
            gdof_to_ldof[dh.ldof_to_gdof[ldof]] = ldof
            rank_recv_count[rank] += 1
        end
    end

    # TODO speed this up and provide a better API
    dgrid = getglobalgrid(dh)
    for sv ∈ get_shared_vertices(dgrid)
        lvi = sv.local_idx
        my_rank != compute_owner(dgrid, sv) && continue
        for field_idx in 1:num_fields(dh)
            if Ferrite.has_vertex_dofs(dh, field_idx, lvi)
                local_dofs = Ferrite.vertex_dofs(dh, field_idx, lvi)
                global_dofs = dh.ldof_to_gdof[local_dofs]
                for receiver_rank ∈ keys(remote_entities(sv))
                    for i ∈ 1:length(global_dofs)
                        # Note that u already has the correct values for all locally owned dofs due to the loop above!
                        gdof_value_send[receiver_rank][global_dofs[i]] = u[local_dofs[i]]
                    end
                end
            end
        end
    end

    for se ∈ get_shared_edges(dgrid)
        lei = se.local_idx
        my_rank != compute_owner(dgrid, se) && continue
        for field_idx in 1:num_fields(dh)
            if Ferrite.has_edge_dofs(dh, field_idx, lei)
                local_dofs = Ferrite.edge_dofs(dh, field_idx, lei)
                global_dofs = dh.ldof_to_gdof[local_dofs]
                for receiver_rank ∈ keys(remote_entities(se))
                    for i ∈ 1:length(global_dofs)
                        # Note that u already has the correct values for all locally owned dofs due to the loop above!
                        gdof_value_send[receiver_rank][global_dofs[i]] = u[local_dofs[i]]
                    end
                end
            end
        end
    end

    for sf ∈ get_shared_faces(dgrid)
        lfi = sf.local_idx
        my_rank != compute_owner(dgrid, sf) && continue
        for field_idx in 1:num_fields(dh)
            if Ferrite.has_face_dofs(dh, field_idx, lfi)
                local_dofs = Ferrite.face_dofs(dh, field_idx, lfi)
                global_dofs = dh.ldof_to_gdof[local_dofs]
                for receiver_rank ∈ keys(remote_entities(sf))
                    for i ∈ 1:length(global_dofs)
                        # Note that u already has the correct values for all locally owned dofs due to the loop above!
                        gdof_value_send[receiver_rank][global_dofs[i]] = u[local_dofs[i]]
                    end
                end
            end
        end
    end

    Ferrite.@debug println("preparing to distribute $gdof_value_send (R$my_rank)")

    # TODO precompute the graph, as it is static
    graph_source = Cint[my_rank - 1]
    graph_dest = Cint[]
    for r ∈ 1:MPI.Comm_size(MPI.COMM_WORLD)
        !isempty(gdof_value_send[r]) && push!(graph_dest, r - 1)
    end

    graph_degree = Cint[length(graph_dest)]
    graph_comm = MPI.Dist_graph_create(MPI.COMM_WORLD, graph_source, graph_degree, graph_dest)
    indegree, outdegree, _ = MPI.Dist_graph_neighbors_count(graph_comm)

    inranks = Vector{Cint}(undef, indegree)
    outranks = Vector{Cint}(undef, outdegree)
    MPI.Dist_graph_neighbors!(graph_comm, inranks, outranks)

    send_count = [length(gdof_value_send[outrank + 1]) for outrank ∈ outranks]
    recv_count = [rank_recv_count[inrank + 1] for inrank ∈ inranks]

    send_gdof = Cint[]
    for outrank ∈ outranks
        append!(send_gdof, Cint.(keys(gdof_value_send[outrank + 1])))
    end
    recv_gdof = Vector{Cint}(undef, sum(recv_count))
    MPI.Neighbor_alltoallv!(VBuffer(send_gdof, send_count), VBuffer(recv_gdof, recv_count), graph_comm)

    send_val = Cdouble[]
    for outrank ∈ outranks
        append!(send_val, Cdouble.(values(gdof_value_send[outrank + 1])))
    end
    recv_val = Vector{Cdouble}(undef, sum(recv_count))
    MPI.Neighbor_alltoallv!(VBuffer(send_val, send_count), VBuffer(recv_val, recv_count), graph_comm)

    for (gdof, val) ∈ zip(recv_gdof, recv_val)
        u[gdof_to_ldof[gdof]] = val
    end

    return u
end
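Finally, a short sketch of where `hypre_to_ferrite!` fits in a solve: after a HYPRE solver produces the distributed solution, the locally owned values are copied back and ghost dofs are synchronized over MPI. The solver choice below is illustrative (standard HYPRE.jl API), not prescribed by this commit:

    # Solve Ku = f with CG preconditioned by BoomerAMG.
    precond = HYPRE.BoomerAMG()
    solver = HYPRE.PCG(; Precond = precond)
    uh = HYPRE.solve(solver, K, f)   # HYPREVector holding the owned dofs

    # Gather owned + ghost values into a plain Julia vector.
    u = Vector{Float64}(undef, num_local_dofs(dh))
    hypre_to_ferrite!(u, uh, dh)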