diff --git a/src/func.jl b/src/func.jl index 08798076..b4e60f1f 100644 --- a/src/func.jl +++ b/src/func.jl @@ -4,20 +4,8 @@ Matrix free operator given by a function $(FIELDS) """ -mutable struct FunctionOperator{ - iip, - oop, - mul5, - T <: Number, - F, - Fa, - Fi, - Fai, - Tr, - P, - Tt, - C, -} <: AbstractSciMLOperator{T} +mutable struct FunctionOperator{iip, oop, mul5, T <: Number, F, Fa, Fi, Fai, Tr, P, Tt, + C, iType, oType} <: AbstractSciMLOperator{T} """ Function with signature op(u, p, t) and (if isinplace) op(v, u, p, t) """ op::F """ Adjoint operator""" @@ -34,42 +22,82 @@ mutable struct FunctionOperator{ t::Tt """ Cache """ cache::C +end - function FunctionOperator(op, - op_adjoint, - op_inverse, - op_adjoint_inverse, - traits, - p, - t, - cache) - iip = traits.isinplace - oop = traits.outofplace - mul5 = traits.has_mul5 - T = traits.T - - new{ - iip, - oop, - mul5, - T, - typeof(op), - typeof(op_adjoint), - typeof(op_inverse), - typeof(op_adjoint_inverse), - typeof(traits), - typeof(p), - typeof(t), - typeof(cache), - }(op, - op_adjoint, - op_inverse, - op_adjoint_inverse, - traits, - p, - t, - cache) - end +function FunctionOperator(op, op_adjoint, op_inverse, op_adjoint_inverse, traits, p, t, + cache, ::Type{iType}, ::Type{oType}) where {iType, oType} + iip = traits.isinplace + oop = traits.outofplace + mul5 = traits.has_mul5 + T = traits.T + + return FunctionOperator{iip, oop, mul5, T, typeof(op), typeof(op_adjoint), + typeof(op_inverse), typeof(op_adjoint_inverse), typeof(traits), typeof(p), + typeof(t), typeof(cache), iType, oType}(op, op_adjoint, op_inverse, + op_adjoint_inverse, traits, p, t, cache) +end + +function set_op(f::FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, Tt, C, iType, + oType}, op) where {iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, Tt, C, iType, oType} + return FunctionOperator{iip, oop, mul5, T, typeof(op), Fa, Fi, Fai, Tr, P, Tt, C, iType, + oType}(op, f.op_adjoint, f.op_inverse, f.op_adjoint_inverse, f.traits, f.p, f.t, + f.cache) +end + +function set_op_adjoint(f::FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, Tt, C, + iType, oType}, op_adjoint) where {iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, Tt, + C, iType, oType} + return FunctionOperator{iip, oop, mul5, T, F, typeof(op_adjoint), Fi, Fai, Tr, P, Tt, + C, iType, oType}(f.op, op_adjoint, f.op_inverse, f.op_adjoint_inverse, f.traits, + f.p, f.t, f.cache) +end + +function set_op_inverse(f::FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, Tt, C, + iType, oType}, op_inverse) where {iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, Tt, + C, iType, oType} + return FunctionOperator{iip, oop, mul5, T, F, Fa, typeof(op_inverse), Fai, Tr, P, Tt, + C, iType, oType}(f.op, f.op_adjoint, op_inverse, f.op_adjoint_inverse, f.traits, + f.p, f.t, f.cache) +end + +function set_op_adjoint_inverse(f::FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, + P, Tt, C, iType, oType}, + op_adjoint_inverse) where {iip, oop, mul5, T, F, Fa, + Fi, Fai, Tr, P, Tt, C, iType, oType} + return FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, typeof(op_adjoint_inverse), Tr, + P, Tt, C, iType, oType}(f.op, f.op_adjoint, f.op_inverse, op_adjoint_inverse, + f.traits, f.p, f.t, f.cache) +end + +function set_traits(f::FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, Tt, C, + iType, oType}, traits) where {iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, Tt, C, + iType, oType} + return FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, typeof(traits), P, Tt, + C, iType, oType}(f.op, f.op_adjoint, f.op_inverse, f.op_adjoint_inverse, traits, + f.p, f.t, f.cache) +end + +function set_p(f::FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, Tt, C, + iType, oType}, p) where {iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, Tt, C, iType, + oType} + return FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, typeof(p), Tt, C, iType, + oType}(f.op, f.op_adjoint, f.op_inverse, f.op_adjoint_inverse, f.traits, p, f.t, + f.cache) +end + +function set_t(f::FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, Tt, C, iType, + oType}, t) where {iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, Tt, C, iType, oType} + return FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, typeof(t), C, iType, + oType}(f.op, f.op_adjoint, f.op_inverse, f.op_adjoint_inverse, f.traits, f.p, t, + f.cache) +end + +function set_cache(f::FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, Tt, C, + iType, oType}, cache) where {iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, Tt, C, + iType, oType} + return FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, Tt, typeof(cache), + iType, oType}(f.op, f.op_adjoint, f.op_inverse, f.op_adjoint_inverse, f.traits, + f.p, f.t, cache) end """ @@ -138,29 +166,34 @@ function FunctionOperator(op, # traits T::Union{Type{<:Number}, Nothing} = nothing, - isinplace::Union{Nothing, Bool} = nothing, - outofplace::Union{Nothing, Bool} = nothing, - has_mul5::Union{Nothing, Bool} = nothing, + isinplace::Union{Nothing, Bool, Val} = nothing, + outofplace::Union{Nothing, Bool, Val} = nothing, + has_mul5::Union{Nothing, Bool, Val} = nothing, isconstant::Bool = false, islinear::Bool = false, isconvertible::Bool = false, batch::Bool = false, - ifcache::Bool = true, + ifcache::Union{Bool, Val} = Val(true), cache::Union{Nothing, NTuple{2}} = nothing, # LinearAlgebra traits opnorm = nothing, - issymmetric::Bool = false, - ishermitian::Bool = false, + issymmetric::Union{Bool, Val} = Val(false), + ishermitian::Union{Bool, Val} = Val(false), isposdef::Bool = false,) where {N} # establish types # store eltype of input/output for caching with ComposedOperator. - eltypes = eltype(input), eltype(output) - T = isnothing(T) ? promote_type(eltypes...) : T - t = isnothing(t) ? zero(real(T)) : t + _T = T === nothing ? promote_type(eltype(input), eltype(output)) : T + _t = t === nothing ? zero(real(_T)) : t + + isinplace isa Val && (@assert _unwrap_val(isinplace) isa Bool) + outofplace isa Val && (@assert _unwrap_val(outofplace) isa Bool) + has_mul5 isa Val && (@assert _unwrap_val(has_mul5) isa Bool) + issymmetric isa Val && (@assert _unwrap_val(issymmetric) isa Bool) + ishermitian isa Val && (@assert _unwrap_val(ishermitian) isa Bool) - @assert T<:Number """The `eltype` of `FunctionOperator`, as well as + @assert _T<:Number """The `eltype` of `FunctionOperator`, as well as the `input`/`output` arrays must be `<:Number`.""" # establish sizes @@ -203,106 +236,109 @@ function FunctionOperator(op, # evaluation signatures - isinplace = if isnothing(isinplace) - static_hasmethod(op, typeof((output, input, p, t))) + _isinplace = if isinplace === nothing + Val(static_hasmethod(op, typeof((output, input, p, _t)))) + elseif isinplace isa Bool + Val(isinplace) else isinplace end - outofplace = if isnothing(outofplace) - static_hasmethod(op, typeof((input, p, t))) + _outofplace = if outofplace === nothing + Val(static_hasmethod(op, typeof((input, p, _t)))) + elseif outofplace isa Bool + Val(outofplace) else outofplace end - if !isinplace & !outofplace + if !_unwrap_val(_isinplace) & !_unwrap_val(_outofplace) @error """Please provide a funciton with signatures `op(u, p, t)` for applying the operator out-of-place, and/or the signature is `op(v, u, p, t)` for in-place application.""" end - has_mul5 = if isnothing(has_mul5) - has_mul5 = true - for f in (op, op_adjoint, op_inverse, op_adjoint_inverse) - if !isnothing(f) - has_mul5 *= static_hasmethod(f, typeof((output, input, p, t, t, t))) - end - end - + _has_mul5 = if has_mul5 === nothing + __and_val(__has_mul5(op, output, input, p, _t), + __has_mul5(op_adjoint, input, output, p, _t), + __has_mul5(op_inverse, output, input, p, _t), + __has_mul5(op_adjoint_inverse, input, output, p, _t)) + elseif has_mul5 isa Bool + Val(has_mul5) + else has_mul5 end # traits - isreal = T <: Real - selfadjoint = ishermitian | (isreal & issymmetric) - adjointable = !(op_adjoint isa Nothing) | selfadjoint + isreal = _T <: Real + selfadjoint = _unwrap_val(ishermitian) | (isreal & _unwrap_val(issymmetric)) + adjointable = !(op_adjoint isa Nothing) | _unwrap_val(selfadjoint) invertible = !(op_inverse isa Nothing) if selfadjoint & (op_adjoint isa Nothing) - op_adjoint = op + _op_adjoint = op + else + _op_adjoint = op_adjoint end if selfadjoint & invertible & (op_adjoint_inverse isa Nothing) - op_adjoint_inverse = op_inverse + _op_adjoint_inverse = op_inverse + else + _op_adjoint_inverse = op_adjoint_inverse end - traits = (; - islinear = islinear, - isconvertible = isconvertible, - isconstant = isconstant, opnorm = opnorm, - issymmetric = issymmetric, - ishermitian = ishermitian, - isposdef = isposdef, isinplace = isinplace, - outofplace = outofplace, - has_mul5 = has_mul5, - ifcache = ifcache, - T = T, - batch = batch, - size = _size, - sizes = sizes, - eltypes = eltypes, - accepted_kwargs = accepted_kwargs, - kwargs = Dict{Symbol, Any}(),) - - L = FunctionOperator(op, - op_adjoint, - op_inverse, - op_adjoint_inverse, - traits, - p, - t, - cache) + traits = (; islinear, isconvertible, isconstant, opnorm, + issymmetric = _unwrap_val(issymmetric), ishermitian = _unwrap_val(ishermitian), + isposdef, isinplace = _unwrap_val(_isinplace), + outofplace = _unwrap_val(_outofplace), has_mul5 = _unwrap_val(_has_mul5), + ifcache = _unwrap_val(ifcache), T = _T, batch, size = _size, sizes, + accepted_kwargs, kwargs = Dict{Symbol, Any}()) + + L = FunctionOperator{_unwrap_val(_isinplace), _unwrap_val(_outofplace), + _unwrap_val(_has_mul5), _T, typeof(op), typeof(_op_adjoint), typeof(op_inverse), + typeof(_op_adjoint_inverse), typeof(traits), typeof(p), typeof(_t), typeof(cache), + eltype(input), eltype(output)}(op, + _op_adjoint, op_inverse, _op_adjoint_inverse, traits, p, _t, cache) + + # L = FunctionOperator(op, _op_adjoint, op_inverse, _op_adjoint_inverse, traits, p, _t, + # cache, eltype(input), eltype(output)) # create cache - if ifcache & isnothing(L.cache) - L = cache_operator(L, input) + if _unwrap_val(ifcache) & (L.cache === nothing) + L_cached = cache_operator(L, input) + else + L_cached = L end - L + return L_cached +end + +@inline __has_mul5(::Nothing, y, x, p, t) = Val(true) +@inline function __has_mul5(f::F, y, x, p, t) where {F} + return Val(static_hasmethod(f, typeof((y, x, p, t, t, t)))) end +@inline __and_val(vs...) = mapreduce(_unwrap_val, *, vs) function update_coefficients(L::FunctionOperator, u, p, t; kwargs...) # update p, t - @set! L.p = p - @set! L.t = t + L = set_p(L, p) + L = set_t(L, t) # filter and update kwargs filtered_kwargs = get_filtered_kwargs(kwargs, L.traits.accepted_kwargs) - @set! L.traits.kwargs = Dict{Symbol, Any}(filtered_kwargs) + + L = set_traits(L, merge(L.traits, (; kwargs = Dict{Symbol, Any}(filtered_kwargs)))) isconstant(L) && return L - @set! L.op = update_coefficients(L.op, u, p, t; filtered_kwargs...) - @set! L.op_adjoint = update_coefficients(L.op_adjoint, u, p, t; filtered_kwargs...) - @set! L.op_inverse = update_coefficients(L.op_inverse, u, p, t; filtered_kwargs...) - @set! L.op_adjoint_inverse = update_coefficients(L.op_adjoint_inverse, - u, - p, - t; - filtered_kwargs...) + L = set_op(L, update_coefficients(L.op, u, p, t; filtered_kwargs...)) + L = set_op_adjoint(L, update_coefficients(L.op_adjoint, u, p, t; filtered_kwargs...)) + L = set_op_inverse(L, update_coefficients(L.op_inverse, u, p, t; filtered_kwargs...)) + L = set_op_adjoint_inverse(L, + update_coefficients(L.op_adjoint_inverse, u, p, t; filtered_kwargs...)) end function update_coefficients!(L::FunctionOperator, u, p, t; kwargs...) @@ -325,8 +361,8 @@ function update_coefficients!(L::FunctionOperator, u, p, t; kwargs...) end function iscached(L::FunctionOperator) - L.traits.ifcache ? !isnothing(L.cache) : !L.traits.ifcache - !isnothing(L.cache) + # L.traits.ifcache ? !isnothing(L.cache) : !L.traits.ifcache + L.cache !== nothing end # fix method amg bw AbstractArray, AbstractVecOrMat @@ -351,7 +387,9 @@ function _cache_operator(L::FunctionOperator, u::AbstractArray) M = size(L, 1) K = size(u, 2) size_out = u isa AbstractVector ? (M,) : (M, K) - @set! L.traits.sizes = size(u), size_out + + new_traits = merge(L.traits, (; sizes = (size(u), size_out))) + L = set_traits(L, new_traits) u else @@ -374,22 +412,32 @@ end cache_self(L::FunctionOperator, u::AbstractArray) = _cache_self(L, u) cache_self(L::FunctionOperator, u::AbstractVecOrMat) = _cache_self(L, u) -function _cache_self(L::FunctionOperator, u::AbstractArray) - _u = similar(u, L.traits.eltypes[1], L.traits.sizes[1]) - _v = similar(u, L.traits.eltypes[2], L.traits.sizes[2]) +function _cache_self(L::FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, Tt, C, + iType, oType}, + u::AbstractArray) where {iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, + Tt, C, iType, oType} + _u = similar(u, iType, L.traits.sizes[1]) + _v = similar(u, oType, L.traits.sizes[2]) - @set! L.cache = (_u, _v) + return set_cache(L, (_u, _v)) end # fix method amg bw AbstractArray, AbstractVecOrMat cache_internals(L::FunctionOperator, u::AbstractArray) = _cache_internals(L, u) cache_internals(L::FunctionOperator, u::AbstractVecOrMat) = _cache_internals(L, u) -function _cache_internals(L::FunctionOperator, u::AbstractArray) - @set! L.op = cache_operator(L.op, u) - @set! L.op_adjoint = cache_operator(L.op_adjoint, u) - @set! L.op_inverse = cache_operator(L.op_inverse, u) - @set! L.op_adjoint_inverse = cache_operator(L.op_adjoint_inverse, u) +function _cache_internals(L::FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, Tt, + C, iType, oType}, + u::AbstractArray) where {iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, + P, Tt, C, iType, oType} + newop = cache_operator(L.op, u) + newop_adjoint = cache_operator(L.op_adjoint, u) + newop_inverse = cache_operator(L.op_inverse, u) + newop_adjoint_inverse = cache_operator(L.op_adjoint_inverse, u) + + return FunctionOperator{iip, oop, mul5, T, typeof(newop), typeof(newop_adjoint), + typeof(newop_inverse), typeof(newop_adjoint_inverse), Tr, P, Tt, C, iType, oType}(newop, newop_adjoint, newop_inverse, newop_adjoint_inverse, L.traits, L.p, L.t, + L.cache) end function Base.show(io::IO, L::FunctionOperator) @@ -397,14 +445,13 @@ function Base.show(io::IO, L::FunctionOperator) print(io, "FunctionOperator($M × $N)") end Base.size(L::FunctionOperator) = L.traits.size -function Base.adjoint(L::FunctionOperator) - if ishermitian(L) | (isreal(L) & issymmetric(L)) - return L - end - if !(has_adjoint(L)) - return AdjointOperator(L) - end +function Base.adjoint(L::FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, Tt, + C, iType, oType, + }) where {iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, Tt, C, iType, oType} + (ishermitian(L) | (isreal(L) & issymmetric(L))) && return L + + has_adjoint(L) || return AdjointOperator(L) op = L.op_adjoint op_adjoint = L.op @@ -412,31 +459,20 @@ function Base.adjoint(L::FunctionOperator) op_inverse = L.op_adjoint_inverse op_adjoint_inverse = L.op_inverse - traits = L.traits - @set! traits.size = reverse(size(L)) - @set! traits.sizes = reverse(traits.sizes) - @set! traits.eltypes = reverse(traits.eltypes) + traits = merge(L.traits, (; size = reverse(size(L)), sizes = reverse(L.traits.sizes))) - cache = if iscached(L) - cache = reverse(L.cache) - else - nothing - end + cache = iscached(L) ? reverse(L.cache) : nothing - FunctionOperator(op, - op_adjoint, - op_inverse, - op_adjoint_inverse, - traits, - L.p, - L.t, - cache) + return FunctionOperator{iip, oop, mul5, T, typeof(op), typeof(op_adjoint), + typeof(op_inverse), typeof(op_adjoint_inverse), typeof(traits), P, Tt, + typeof(cache), iType, oType}(op, op_adjoint, op_inverse, op_adjoint_inverse, traits, + L.p, L.t, cache) end -function Base.inv(L::FunctionOperator) - if !(has_ldiv(L)) - return InvertedOperator(L) - end +function Base.inv(L::FunctionOperator{iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, Tt, + C, iType, oType, + }) where {iip, oop, mul5, T, F, Fa, Fi, Fai, Tr, P, Tt, C, iType, oType} + has_ldiv(L) || return InvertedOperator(L) op = L.op_inverse op_inverse = L.op @@ -444,33 +480,22 @@ function Base.inv(L::FunctionOperator) op_adjoint = L.op_adjoint_inverse op_adjoint_inverse = L.op_adjoint - traits = L.traits - @set! traits.size = reverse(size(L)) - @set! traits.sizes = reverse(traits.sizes) - @set! traits.eltypes = reverse(traits.eltypes) - - @set! traits.opnorm = if traits.opnorm isa Number - 1 / traits.opnorm - elseif traits.opnorm isa Nothing + opnorm = if L.traits.opnorm isa Number + 1 / L.traits.opnorm + elseif L.traits.opnorm isa Nothing nothing else - (p::Real) -> 1 / traits.opnorm(p) + (p::Real) -> 1 / L.traits.opnorm(p) end + traits = merge(L.traits, + (; size = reverse(size(L)), sizes = reverse(L.traits.sizes), opnorm)) - cache = if iscached(L) - cache = reverse(L.cache) - else - nothing - end + cache = iscached(L) ? reverse(L.cache) : nothing - FunctionOperator(op, - op_adjoint, - op_inverse, - op_adjoint_inverse, - traits, - L.p, - L.t, - cache) + return FunctionOperator{iip, oop, mul5, T, typeof(op), typeof(op_adjoint), + typeof(op_inverse), typeof(op_adjoint_inverse), typeof(traits), P, Tt, + typeof(cache), iType, oType}(op, op_adjoint, op_inverse, op_adjoint_inverse, traits, + L.p, L.t, cache) end Base.convert(::Type{AbstractMatrix}, L::FunctionOperator) = convert(AbstractMatrix, L.op) diff --git a/src/utils.jl b/src/utils.jl index 8b36ae28..1297f64d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -16,8 +16,8 @@ struct NoKwargFilter end function preprocess_update_func(update_func, accepted_kwargs) _update_func = (update_func === nothing) ? DEFAULT_UPDATE_FUNC : update_func _accepted_kwargs = (accepted_kwargs === nothing) ? () : accepted_kwargs - # accepted_kwargs can be passed as nothing to indicate that we should not filter - # (e.g. if the function already accepts all kwargs...). + # accepted_kwargs can be passed as nothing to indicate that we should not filter + # (e.g. if the function already accepts all kwargs...). return (_accepted_kwargs isa NoKwargFilter) ? _update_func : FilterKwargs(_update_func, _accepted_kwargs) end @@ -48,3 +48,6 @@ function (f::FilterKwargs)(args...; kwargs...) f.f(args...; filtered_kwargs...) end # + +_unwrap_val(x) = x +_unwrap_val(::Val{X}) where {X} = X