Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(Performance) improvements on gathering field information from MixedDofHandler #621

Merged
merged 14 commits into from
Mar 21, 2023
Merged
7 changes: 4 additions & 3 deletions docs/src/reference/dofhandler.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ FieldHandler
close!(::MixedDofHandler)
```

## Dof renumbering
```@docs
renumber!
DofOrder.FieldWise
Expand All @@ -27,10 +28,10 @@ DofOrder.ComponentWise
## Common methods
```@docs
ndofs
ndofs_per_cell
dof_range
Ferrite.nfields(::MixedDofHandler)
Ferrite.getfieldnames(::MixedDofHandler)
Ferrite.getfielddim(::MixedDofHandler, ::Symbol)
celldofs
celldofs!
```

# CellIterator
Expand Down
194 changes: 165 additions & 29 deletions src/Dofs/MixedDofHandler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,17 +64,38 @@ end
getfieldnames(fh::FieldHandler) = [field.name for field in fh.fields]
getfielddims(fh::FieldHandler) = [field.dim for field in fh.fields]
getfieldinterpolations(fh::FieldHandler) = [field.interpolation for field in fh.fields]

"""
ndofs_per_cell(dh::AbstractDofHandler[, cell::Int=1])

Return the number of degrees of freedom for the cell with index `cell`.

See also [`ndofs`](@ref).
"""
ndofs_per_cell(dh::MixedDofHandler, cell::Int=1) = dh.cell_dofs.length[cell]
nnodes_per_cell(dh::MixedDofHandler, cell::Int=1) = dh.cell_nodes.length[cell]

"""
celldofs!(global_dofs::Vector{Int}, dh::AbstractDofHandler, i::Int)

Store the degrees of freedom that belong to cell `i` in `global_dofs`.

See also [`celldofs`](@ref).
"""
function celldofs!(global_dofs::Vector{Int}, dh::MixedDofHandler, i::Int)
@assert isclosed(dh)
@assert length(global_dofs) == ndofs_per_cell(dh, i)
unsafe_copyto!(global_dofs, 1, dh.cell_dofs.values, dh.cell_dofs.offset[i], length(global_dofs))
return global_dofs
end

"""
celldofs(dh::AbstractDofHandler, i::Int)

Return a vector with the degrees of freedom that belong to cell `i`.

See also [`celldofs!`](@ref).
"""
function celldofs(dh::MixedDofHandler, i::Int)
@assert isclosed(dh)
return dh.cell_dofs[i]
Expand All @@ -97,8 +118,10 @@ end

"""
getfieldnames(dh::MixedDofHandler)
getfieldnames(fh::FieldHandler)

Returns the union of all the fields. Can be used as an iterable over all the fields in the problem.
Return a vector with the names of all fields. Can be used as an iterable over all the fields
in the problem.
"""
function getfieldnames(dh::MixedDofHandler)
fieldnames = Vector{Symbol}()
Expand All @@ -108,22 +131,24 @@ function getfieldnames(dh::MixedDofHandler)
return unique!(fieldnames)
end

"""
getfielddim(dh::MixedDofHandler, name::Symbol)
getfielddim(fh::FieldHandler, field_idx::Int) = fh.fields[field_idx].dim
getfielddim(fh::FieldHandler, field_name::Symbol) = getfielddim(fh, find_field(fh, field_name))

Returns the dimension of a specific field, given by name. Note that it will return the dimension of the first field found among the `FieldHandler`s.
"""
function getfielddim(dh::MixedDofHandler, name::Symbol)
getfielddim(dh::MixedDofHandler, field_idxs::NTuple{2,Int})
getfielddim(dh::MixedDofHandler, field_name::Symbol)
getfielddim(dh::FieldHandler, field_idx::Int)
getfielddim(dh::FieldHandler, field_name::Symbol)

for fh in dh.fieldhandlers
field_pos = findfirst(i->i == name, getfieldnames(fh))
if field_pos !== nothing
return fh.fields[field_pos].dim
end
end
error("did not find field $name")
Return the dimension of a given field. The field can be specified by its index (see
[`find_field`](@ref)) or its name.
"""
function getfielddim(dh::MixedDofHandler, field_idxs::NTuple{2, Int})
fh_idx, field_idx = field_idxs
fielddim = getfielddim(dh.fieldhandlers[fh_idx], field_idx)
return fielddim
end

getfielddim(dh::MixedDofHandler, name::Symbol) = getfielddim(dh, find_field(dh, name))

"""
nfields(dh::MixedDofHandler)
Expand All @@ -138,6 +163,7 @@ nfields(dh::MixedDofHandler) = length(getfieldnames(dh))
Add all fields of the [`FieldHandler`](@ref) `fh` to `dh`.
"""
function add!(dh::MixedDofHandler, fh::FieldHandler)
# TODO: perhaps check that a field with the same name is the same field?
@assert !isclosed(dh)
_check_same_celltype(dh.grid, collect(fh.cellset))
_check_cellset_intersections(dh, fh)
Expand Down Expand Up @@ -376,37 +402,147 @@ function add_cell_dofs(cell_dofs, cell, celldict, field_dim, ncelldofs, nextdof)
return nextdof
end

"""
find_field(dh::MixedDofHandler, field_name::Symbol)::NTuple{2,Int}

Return the index of the field with name `field_name` in a `MixedDofHandler`. The index is a
`NTuple{2,Int}`, where the 1st entry is the index of the `FieldHandler` within which the
field was found and the 2nd entry is the index of the field within the `FieldHandler`.

!!! note
Always finds the 1st occurence of a field within `MixedDofHandler`.

See also: [`find_field(fh::FieldHandler, field_name::Symbol)`](@ref),
[`_find_field(fh::FieldHandler, field_name::Symbol)`](@ref).
"""
function find_field(dh::MixedDofHandler, field_name::Symbol)
for (fh_idx, fh) in pairs(dh.fieldhandlers)
field_idx = _find_field(fh, field_name)
!isnothing(field_idx) && return (fh_idx, field_idx)
end
error("Did not find field :$field_name (existing fields: $(getfieldnames(dh))).")
end

"""
find_field(fh::FieldHandler, field_name::Symbol)::Int

Return the index of the field with name `field_name` in a `FieldHandler`. Throw an
error if the field is not found.

See also: [`find_field(dh::MixedDofHandler, field_name::Symbol)`](@ref), [`_find_field(fh::FieldHandler, field_name::Symbol)`](@ref).
"""
function find_field(fh::FieldHandler, field_name::Symbol)
j = findfirst(i->i == field_name, getfieldnames(fh))
j === nothing && error("could not find field :$field_name in FieldHandler (existing fields: $(getfieldnames(fh)))")
return j
field_idx = _find_field(fh, field_name)
if field_idx === nothing
error("Did not find field :$field_name in FieldHandler (existing fields: $(getfieldnames(fh)))")
end
return field_idx
end

# No error if field not found
"""
_find_field(fh::FieldHandler, field_name::Symbol)::Int

Return the index of the field with name `field_name` in the `FieldHandler` `fh`. Return
`nothing` if the field is not found.

See also: [`find_field(dh::MixedDofHandler, field_name::Symbol)`](@ref), [`find_field(fh::FieldHandler, field_name::Symbol)`](@ref).
"""
function _find_field(fh::FieldHandler, field_name::Symbol)
for (field_idx, field) in pairs(fh.fields)
if field.name == field_name
return field_idx
end
end
return nothing
end

# Calculate the offset to the first local dof of a field
function field_offset(fh::FieldHandler, field_name::Symbol)
function field_offset(fh::FieldHandler, field_idx::Int)
offset = 0
for i in 1:find_field(fh, field_name)-1
offset += getnbasefunctions(getfieldinterpolations(fh)[i])::Int * getfielddims(fh)[i]
for i in 1:(field_idx-1)
offset += getnbasefunctions(fh.fields[i].interpolation)::Int * fh.fields[i].dim
end
return offset
end
field_offset(fh::FieldHandler, field_name::Symbol) = field_offset(fh, find_field(fh, field_name))

field_offset(dh::MixedDofHandler, field_name::Symbol) = field_offset(dh, find_field(dh, field_name))
function field_offset(dh::MixedDofHandler, field_idxs::Tuple{Int, Int})
fh_idx, field_idx = field_idxs
field_offset(dh.fieldhandlers[fh_idx], field_idx)
end

"""
dof_range(fh::FieldHandler, field_idx::Int)
dof_range(fh::FieldHandler, field_name::Symbol)
dof_range(dh:MixedDofHandler, field_name::Symbol)

Return the local dof range for a given field. The field can be specified by its name or
index, where `field_idx` represents the index of a field within a `FieldHandler` and
`field_idxs` is a tuple of the `FieldHandler`-index within the `MixedDofHandler` and the
`field_idx`.

!!! note
The `dof_range` of a field can vary between different `FieldHandler`s. Therefore, it is
advised to use the `field_idxs` or refer to a given `FieldHandler` directly in case
several `FieldHandler`s exist. Using the `field_name` will always refer to the first
occurence of `field` within the `MixedDofHandler`.

function dof_range(fh::FieldHandler, field_name::Symbol)
f = find_field(fh, field_name)
offset = field_offset(fh, field_name)
field_interpolation = fh.fields[f].interpolation
field_dim = fh.fields[f].dim
Example:
```jldoctest
julia> grid = generate_grid(Triangle, (3, 3))
Grid{2, Triangle, Float64} with 18 Triangle cells and 16 nodes

julia> dh = MixedDofHandler(grid); add!(dh, :u, 3); add!(dh, :p, 1); close!(dh);

julia> dof_range(dh, :u)
1:9

julia> dof_range(dh, :p)
10:12

julia> dof_range(dh, (1,1)) # field :u
1:9

julia> dof_range(dh.fieldhandlers[1], 2) # field :p
10:12
```
"""
function dof_range(fh::FieldHandler, field_idx::Int)
offset = field_offset(fh, field_idx)
field_interpolation = fh.fields[field_idx].interpolation
field_dim = fh.fields[field_idx].dim
n_field_dofs = getnbasefunctions(field_interpolation)::Int * field_dim
return (offset+1):(offset+n_field_dofs)
end
dof_range(fh::FieldHandler, field_name::Symbol) = dof_range(fh, find_field(fh, field_name))

function dof_range(dh::MixedDofHandler, field_name::Symbol)
if length(dh.fieldhandlers) > 1
error("The given MixedDofHandler has $(length(dh.fieldhandlers)) FieldHandlers.
Extracting the dof range based on the fieldname might not be a unique problem
in this case. Use `dof_range(fh::FieldHandler, field_name)` instead.")
end
fh_idx, field_idx = find_field(dh, field_name)
return dof_range(dh.fieldhandlers[fh_idx], field_idx)
end

"""
getfieldinterpolation(dh::MixedDofHandler, field_idxs::NTuple{2,Int})
getfieldinterpolation(dh::FieldHandler, field_idx::Int)
getfieldinterpolation(dh::FieldHandler, field_name::Symbol)

find_field(dh::MixedDofHandler, field_name::Symbol) = find_field(first(dh.fieldhandlers), field_name)
field_offset(dh::MixedDofHandler, field_name::Symbol) = field_offset(first(dh.fieldhandlers), field_name)
Return the interpolation of a given field. The field can be specified by its index (see
[`find_field`](@ref) or its name.
"""
function getfieldinterpolation(dh::MixedDofHandler, field_idxs::NTuple{2,Int})
fh_idx, field_idx = field_idxs
ip = dh.fieldhandlers[fh_idx].fields[field_idx].interpolation
return ip
end
getfieldinterpolation(fh::FieldHandler, field_idx::Int) = fh.fields[field_idx].interpolation
getfieldinterpolation(dh::MixedDofHandler, field_idx::Int) = dh.fieldhandlers[1].fields[field_idx].interpolation
getfielddim(fh::FieldHandler, field_idx::Int) = fh.fields[field_idx].dim
getfielddim(dh::MixedDofHandler, field_idx::Int) = dh.fieldhandlers[1].fields[field_idx].dim
getfieldinterpolation(fh::FieldHandler, field_name::Symbol) = getfieldinterpolation(fh, find_field(fh, field_name))

function reshape_to_nodes(dh::MixedDofHandler, u::Vector{T}, fieldname::Symbol) where T
# make sure the field exists
Expand Down
4 changes: 2 additions & 2 deletions test/test_mixeddofhandler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ function test_subparametric_quad()
add!(ch, dbc1)
close!(ch)
update!(ch, 1.0)
@test getnbasefunctions(Ferrite.getfieldinterpolation(dh,1)) == 9 # algebraic nbasefunctions
@test getnbasefunctions(Ferrite.getfieldinterpolation(fh,1)) == 9 # algebraic nbasefunctions
@test celldofs(dh, 1) == [i for i in 1:18]
end

Expand All @@ -530,7 +530,7 @@ function test_subparametric_triangle()
add!(ch, dbc1)
close!(ch)
update!(ch, 1.0)
@test getnbasefunctions(Ferrite.getfieldinterpolation(dh,1)) == 6 # algebraic nbasefunctions
@test getnbasefunctions(Ferrite.getfieldinterpolation(fh,1)) == 6 # algebraic nbasefunctions
@test celldofs(dh, 1) == [i for i in 1:12]
end

Expand Down