-
Notifications
You must be signed in to change notification settings - Fork 89
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add package extension for BlockArrays.jl.
This patch adds support for global block matrices. A global block matrix can be obtained by `create_sparsity_pattern(BlockMatrix, dh, ch)` or by `create_sparsity_pattern(A::BlockMatrix, dh, ch)`. The former infers the block sizes by looking at the fields of the DofHandler, and the latters uses the blocksizes of the (partially) initilized BlockMatrix (i.e. by `undef_blocks`). For global assembly, `start_assemble(::BlockMatrix, ...)` returns a `BlockAssembler` which pre allocates a vector for pre-computing the (block number, local index) pairs for each DoF. BlockArrays only support local condensation of constraints currently and you have to use `apply_assemble!` instead of `assemble!` + `apply!`.
- Loading branch information
1 parent
eba00a4
commit 4dfccd9
Showing
4 changed files
with
239 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
module BlockArraysExt | ||
|
||
using BlockArrays: BlockArray, BlockIndex, BlockMatrix, BlockVector, block, blockaxes, | ||
blockindex, blocks, findblockindex | ||
using Ferrite | ||
using Ferrite: addindex!, fillzero! | ||
|
||
# TODO: Move into Ferrite | ||
function global_dof_range(dh::DofHandler, f::Symbol) | ||
set = Set{Int}() | ||
frange = dof_range(dh, f) | ||
for cc in CellIterator(dh) | ||
union!(set, @view celldofs(cc)[frange]) | ||
end | ||
dofmin, dofmax = extrema(set) | ||
r = dofmin:dofmax | ||
if length(set) != length(r) | ||
error("renumber by blocks you donkey") | ||
end | ||
return r | ||
end | ||
|
||
################################### | ||
## Creating the sparsity pattern ## | ||
################################### | ||
|
||
# Note: | ||
# Creating the full unblocked matrix and then splitting into blocks inside the BlockArray | ||
# constructor (i.e. by `getindex(::SparseMatrixCSC, ::UnitRange, ::UnitRange)`) is | ||
# consistently faster than creating individual blocks directly. However, the latter approach | ||
# uses less than half of the memory (measured for a 2x2 block system and various problem | ||
# sizes), so might be useful in the future to provide an option on what algorithm to use. | ||
|
||
# TODO: Could potentially extract the element type and matrix type for the individual blocks | ||
# by allowing e.g. create_sparsity_pattern(BlockMatrix{Float32}, ...) but that is not | ||
# even supported by regular pattern right now. | ||
function Ferrite.create_sparsity_pattern(::Type{<:BlockMatrix}, dh, ch; kwargs...) | ||
K = create_sparsity_pattern(dh, ch; kwargs...) | ||
# Infer block sizes from the fields in the DofHandler | ||
global_ranges = [global_dof_range(dh, f) for f in dh.field_names] | ||
block_sizes = length.(global_ranges) | ||
return BlockArray(K, block_sizes, block_sizes) | ||
end | ||
|
||
function Ferrite.create_sparsity_pattern(B::BlockMatrix, dh, ch; kwargs...) | ||
if !(size(B, 1) == size(B, 2) == ndofs(dh)) | ||
error("size of input matrix ($(size(B))) does not match number of dofs ($(ndofs(dh)))") | ||
end | ||
K = create_sparsity_pattern(dh, ch; kwargs...) | ||
ax = axes(B) | ||
for block_j in blockaxes(B, 2), block_i in blockaxes(B, 1) | ||
range_j = ax[2][block_j] | ||
range_i = ax[1][block_i] | ||
B[block_i, block_j] = K[range_i, range_j] | ||
end | ||
return B | ||
end | ||
|
||
########################################### | ||
## BlockAssembler and associated methods ## | ||
########################################### | ||
|
||
struct BlockAssembler{BM, Bv} <: Ferrite.AbstractSparseAssembler | ||
K::BM | ||
f::Bv | ||
blockindices::Vector{BlockIndex{1}} | ||
end | ||
|
||
Ferrite.matrix_handle(ba::BlockAssembler) = ba.K | ||
Ferrite.vector_handle(ba::BlockAssembler) = ba.f | ||
|
||
function Ferrite.start_assemble(K::BlockMatrix, f; fillzero::Bool=true) | ||
fillzero && (fillzero!(K); fillzero!(f)) | ||
return BlockAssembler(K, f, BlockIndex{1}[]) | ||
end | ||
|
||
# Split into the block and the local index | ||
splindex(idx::BlockIndex{1}) = (block(idx), blockindex(idx)) | ||
|
||
function Ferrite.assemble!(assembler::BlockAssembler, dofs::AbstractVector{<:Integer}, ke::AbstractMatrix, fe::AbstractVector) | ||
K = assembler.K | ||
f = assembler.f | ||
blockindices = assembler.blockindices | ||
|
||
@assert blockaxes(K, 1) == blockaxes(K, 2) | ||
@assert axes(K, 1) == axes(K, 2) == axes(f, 1) | ||
@boundscheck checkbounds(K, dofs, dofs) | ||
@boundscheck checkbounds(f, dofs) | ||
|
||
# Update the cached the block indices | ||
resize!(blockindices, length(dofs)) | ||
@inbounds for (i, I) in pairs(dofs) | ||
blockindices[i] = findblockindex(axes(K, 1), I) | ||
end | ||
|
||
# Assemble matrix entries | ||
@inbounds for (j, blockindex_j) in pairs(blockindices) | ||
Bj, lj = splindex(blockindex_j) | ||
for (i, blockindex_i) in pairs(blockindices) | ||
Bi, li = splindex(blockindex_i) | ||
KB = @view K[Bi, Bj] | ||
addindex!(KB, ke[i, j], li, lj) | ||
end | ||
end | ||
|
||
# Assemble vector entries | ||
if blockaxes(f) == blockaxes(K, 1) | ||
# If f::BlockVector with the same axes the same blockindex cache can be used... | ||
@inbounds for (i, blockindex_i) in pairs(blockindices) | ||
Bi, li = splindex(blockindex_i) | ||
fB = @view f[Bi] | ||
fB[li] += fe[i] | ||
end | ||
else | ||
# ... otherwise, use regular indexing in fallback assemble! | ||
@inbounds assemble!(f, dofs, fe) | ||
end | ||
return | ||
end | ||
|
||
function Ferrite.apply!(::BlockMatrix, ::AbstractVector, ::ConstraintHandler) | ||
error( | ||
"Condensation of constraints with `apply!` after assembling not supported yet " * | ||
"for BlockMatrix, use local condensation with `apply_assemble!` instead." | ||
) | ||
end | ||
|
||
|
||
####################################################### | ||
## Overloaded assembly pieces from src/arrayutils.jl ## | ||
####################################################### | ||
|
||
function Ferrite.addindex!(B::BlockMatrix{Tv}, v::Tv, i::Int, j::Int) where Tv | ||
@boundscheck checkbounds(B, i, j) | ||
Bi, li = splindex(findblockindex(axes(B, 1), i)) | ||
Bj, lj = splindex(findblockindex(axes(B, 2), j)) | ||
BB = @view B[Bi, Bj] | ||
@inbounds addindex!(BB, v, li, lj) | ||
return B | ||
end | ||
|
||
function Ferrite.fillzero!(B::Union{BlockVector,BlockMatrix}) | ||
for blk in blocks(B) | ||
fillzero!(blk) | ||
end | ||
return B | ||
end | ||
|
||
end # module BlockArraysExt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
using Ferrite, BlockArrays, Test | ||
|
||
@testset "BlockArrays.jl extension" begin | ||
grid = generate_grid(Triangle, (10, 10)) | ||
|
||
dh = DofHandler(grid) | ||
push!(dh, :u, 2) | ||
push!(dh, :p, 1) | ||
close!(dh) | ||
renumber!(dh, DofOrder.FieldWise()) | ||
nd = ndofs(dh) ÷ 3 | ||
|
||
ch = ConstraintHandler(dh) | ||
periodic_faces = collect_periodic_faces(grid, "top", "bottom") | ||
add!(ch, PeriodicDirichlet(:u, periodic_faces)) | ||
add!(ch, Dirichlet(:u, union(getfaceset(grid, "left"), getfaceset(grid, "top")), (x, t) -> [0, 0])) | ||
add!(ch, Dirichlet(:p, getfaceset(grid, "left"), (x, t) -> 0)) | ||
close!(ch) | ||
update!(ch, 0) | ||
|
||
K = create_sparsity_pattern(dh, ch) | ||
f = zeros(axes(K, 1)) | ||
KB = create_sparsity_pattern(BlockMatrix, dh, ch) | ||
@test KB isa BlockMatrix | ||
@test blocksize(KB) == (2, 2) | ||
@test size(KB[Block(1), Block(1)]) == (2nd, 2nd) | ||
@test size(KB[Block(2), Block(1)]) == (1nd, 2nd) | ||
@test size(KB[Block(1), Block(2)]) == (2nd, 1nd) | ||
@test size(KB[Block(2), Block(2)]) == (1nd, 1nd) | ||
fB = similar(KB, axes(KB, 1)) | ||
|
||
# Test the pattern | ||
fill!(K.nzval, 1) | ||
foreach(x -> fill!(x.nzval, 1), blocks(KB)) | ||
@test K == KB | ||
|
||
# Zeroing out in start_assemble | ||
assembler = start_assemble(K, f) | ||
@test iszero(K) | ||
@test iszero(f) | ||
block_assembler = start_assemble(KB, fB) | ||
@test iszero(KB) | ||
@test iszero(fB) | ||
|
||
# Assembly procedure | ||
npc = ndofs_per_cell(dh) | ||
for cc in CellIterator(dh) | ||
ke = rand(npc, npc) | ||
fe = rand(npc) | ||
dofs = celldofs(cc) | ||
# Standard assemble | ||
assemble!(assembler, dofs, ke, fe) | ||
assemble!(block_assembler, dofs, ke, fe) | ||
# Assemble with local condensation of constraints | ||
let ke = copy(ke), fe = copy(fe) | ||
apply_assemble!(assembler, ch, dofs, ke, fe) | ||
end | ||
let ke = copy(ke), fe = copy(fe) | ||
apply_assemble!(block_assembler, ch, dofs, ke, fe) | ||
end | ||
end | ||
@test K ≈ KB | ||
@test f ≈ fB | ||
|
||
# Global application of BC not supported yet | ||
@test_throws ErrorException apply!(KB, fB, ch) | ||
|
||
# Custom blocking by passing a partially initialized matrix | ||
perm = invperm([ch.free_dofs; ch.prescribed_dofs]) | ||
renumber!(dh, ch, perm) | ||
nfree = length(ch.free_dofs) | ||
npres = length(ch.prescribed_dofs) | ||
K = create_sparsity_pattern(dh, ch) | ||
block_sizes = [nfree, npres] | ||
KBtmp = BlockArray(undef_blocks, SparseMatrixCSC{Float64, Int}, block_sizes, block_sizes) | ||
KB = create_sparsity_pattern(KBtmp, dh, ch) | ||
@test KBtmp === KB | ||
@test blocksize(KB) == (2, 2) | ||
@test size(KB[Block(1), Block(1)]) == (nfree, nfree) | ||
@test size(KB[Block(2), Block(1)]) == (npres, nfree) | ||
@test size(KB[Block(1), Block(2)]) == (nfree, npres) | ||
@test size(KB[Block(2), Block(2)]) == (npres, npres) | ||
# Test the pattern | ||
fill!(K.nzval, 1) | ||
foreach(x -> fill!(x.nzval, 1), blocks(KB)) | ||
@test K == KB | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters