Skip to content

Commit

Permalink
Merge pull request #205 from vpuri3/isconcrete
Browse files Browse the repository at this point in the history
concretize methods
  • Loading branch information
ChrisRackauckas authored Jul 18, 2023
2 parents 303622a + 03d6d39 commit a5f1eba
Show file tree
Hide file tree
Showing 8 changed files with 105 additions and 11 deletions.
2 changes: 2 additions & 0 deletions src/SciMLOperators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ export update_coefficients!,

issquare,
islinear,
concretize,
isconvertible,

has_adjoint,
has_expmv,
Expand Down
1 change: 1 addition & 0 deletions src/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,7 @@ end

getops(L::InvertedOperator) = (L.L,)
islinear(L::InvertedOperator) = islinear(L.L)
isconvertible(::InvertedOperator) = false

has_mul(L::InvertedOperator) = has_ldiv(L.L)
has_mul!(L::InvertedOperator) = has_ldiv!(L.L)
Expand Down
7 changes: 7 additions & 0 deletions src/batch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@ function Base.conj(L::BatchedDiagonalOperator) # TODO - test this thoroughly
)
end

function Base.convert(::Type{AbstractMatrix}, L::BatchedDiagonalOperator)
m, n = size(L)
msg = """$L cannot be represented by an $m × $n AbstractMatrix"""
throw(ArgumentError(msg))
end

LinearAlgebra.issymmetric(L::BatchedDiagonalOperator) = true
function LinearAlgebra.ishermitian(L::BatchedDiagonalOperator)
if isreal(L)
Expand All @@ -91,6 +97,7 @@ function isconstant(L::BatchedDiagonalOperator)
update_func_isconstant(L.update_func) & update_func_isconstant(L.update_func!)
end
islinear(::BatchedDiagonalOperator) = true
isconvertible(::BatchedDiagonalOperator) = false
has_adjoint(L::BatchedDiagonalOperator) = true
has_ldiv(L::BatchedDiagonalOperator) = all(x -> !iszero(x), L.diag)
has_ldiv!(L::BatchedDiagonalOperator) = has_ldiv(L)
Expand Down
6 changes: 6 additions & 0 deletions src/func.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ uniform across `op`, `op_adjoint`, `op_inverse`, `op_adjoint_inverse`.
* `has_mul5` - `true` if the operator provides a five-argument `mul!` via the signature `op(v, u, p, t, α, β; <accepted_kwargs>)`. This trait is inferred if no value is provided.
* `isconstant` - `true` if the operator is constant, and doesn't need to be updated via `update_coefficients[!]` during operator evaluation.
* `islinear` - `true` if the operator is linear. Defaults to `false`.
* `isconvertible` - `true` a cheap `convert(AbstractMatrix, L.op)` method is available. Defaults to `false`.
* `batch` - Boolean indicating if the input/output arrays comprise of batched column-vectors stacked in a matrix. If `true`, the input/output arrays must be `AbstractVecOrMat`s, and the length of the second dimension (the batch dimension) must be the same. The batch dimension is not involved in size computation. For example, with `batch = true`, and `size(output), size(input) = (M, K), (N, K)`, the `FunctionOperator` size is set to `(M, N)`. If `batch = false`, which is the default, the `input`/`output` arrays may of any size so long as `ndims(input) == ndims(output)`, and the `size` of `FunctionOperator` is set to `(length(input), length(output))`.
* `ifcache` - Allocate cache arrays in constructor. Defaults to `true`. Cache can be generated afterwards by calling `cache_operator(L, input, output)`
* `cache` - Pregenerated cache arrays for in-place evaluations. Expected to be of type and shape `(similar(input), similar(output),)`. The constructor generates cache if no values are provided. Cache generation by the constructor can be disabled by setting the kwarg `ifcache = false`.
Expand Down Expand Up @@ -138,6 +139,7 @@ function FunctionOperator(op,
has_mul5::Union{Nothing,Bool}=nothing,
isconstant::Bool = false,
islinear::Bool = false,
isconvertible::Bool = false,

batch::Bool = false,
ifcache::Bool = true,
Expand Down Expand Up @@ -248,6 +250,7 @@ function FunctionOperator(op,

traits = (;
islinear = islinear,
isconvertible = isconvertible,
isconstant = isconstant,

opnorm = opnorm,
Expand Down Expand Up @@ -480,6 +483,8 @@ function Base.inv(L::FunctionOperator)
)
end

Base.convert(::Type{AbstractMatrix}, L::FunctionOperator) = convert(AbstractMatrix, L.op)

function Base.resize!(L::FunctionOperator, n::Integer)

# input/output to `L` must be `AbstractVector`s
Expand Down Expand Up @@ -526,6 +531,7 @@ function getops(L::FunctionOperator)
end

islinear(L::FunctionOperator) = L.traits.islinear
isconvertible(L::FunctionOperator) = L.traits.isconvertible
isconstant(L::FunctionOperator) = L.traits.isconstant
has_adjoint(L::FunctionOperator) = !(L.op_adjoint isa Nothing)
has_mul(::FunctionOperator{iip}) where{iip} = true
Expand Down
86 changes: 76 additions & 10 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -187,37 +187,53 @@ Base.oneunit(LType::Type{<:AbstractSciMLOperator}) = one(LType)
Base.iszero(::AbstractSciMLOperator) = false # TODO

"""
$SIGNATURES
Check if `adjoint(L)` is lazily defined.
"""
has_adjoint(L::AbstractSciMLOperator) = false # L', adjoint(L)
"""
$SIGNATURES
Check if `expmv!(v, L, u, t)`, equivalent to `mul!(v, exp(t * A), u)`, is
defined for `Number` `t`, and `AbstractArray`s `u, v` of appropriate sizes.
"""
has_expmv!(L::AbstractSciMLOperator) = false # expmv!(v, L, t, u)
"""
$SIGNATURES
Check if `expmv(L, u, t)`, equivalent to `exp(t * A) * u`, is defined for
`Number` `t`, and `AbstractArray` `u` of appropriate size.
"""
has_expmv(L::AbstractSciMLOperator) = false # v = exp(L, t, u)
"""
$SIGNATURES
Check if `exp(L)` is defined lazily defined.
"""
has_exp(L::AbstractSciMLOperator) = islinear(L)
"""
$SIGNATURES
Check if `L * u` is defined for `AbstractArray` `u` of appropriate size.
"""
has_mul(L::AbstractSciMLOperator) = true # du = L*u
"""
$SIGNATURES
Check if `mul!(v, L, u)` is defined for `AbstractArray`s `u, v` of
appropriate sizes.
"""
has_mul!(L::AbstractSciMLOperator) = true # mul!(du, L, u)
"""
$SIGNATURES
Check if `L \\ u` is defined for `AbstractArray` `u` of appropriate size.
"""
has_ldiv(L::AbstractSciMLOperator) = false # du = L\u
"""
$SIGNATURES
Check if `ldiv!(v, L, u)` is defined for `AbstractArray`s `u, v` of
appropriate sizes.
"""
Expand All @@ -244,7 +260,57 @@ isconstant(::Union{
) = true
isconstant(L::AbstractSciMLOperator) = all(isconstant, getops(L))

#islinear(L) = false
"""
isconvertible(L) -> Bool
Checks if `L` can be cheaply converted to an `AbstractMatrix` via eager fusion.
"""
isconvertible(L::AbstractSciMLOperator) = all(isconvertible, getops(L))

isconvertible(::Union{
# LinearAlgebra
AbstractMatrix,
UniformScaling,
Factorization,

# Base
Number,

# SciMLOperators
AbstractSciMLScalarOperator,
}
) = true

"""
concretize(L) -> AbstractMatrix
concretize(L) -> Number
Convert `SciMLOperator` to a concrete type via eager fusion. This method is a
no-op for types that are already concrete.
"""
concretize(L::Union{
# LinearAlgebra
AbstractMatrix,
Factorization,

# SciMLOperators
AbstractSciMLOperator,
}
) = convert(AbstractMatrix, L)

concretize(L::Union{
# LinearAlgebra
UniformScaling,

# Base
Number,

# SciMLOperators
AbstractSciMLScalarOperator,
}
) = convert(Number, L)

"""
$SIGNATURES
Expand Down Expand Up @@ -349,22 +415,22 @@ expmv!(v,L::AbstractSciMLOperator,u,p,t) = mul!(v,exp(L,t),u)
function Base.conj(L::AbstractSciMLOperator)
isreal(L) && return L
@warn """using convert-based fallback for Base.conj"""
convert(AbstractMatrix, L) |> conj
concretize(L) |> conj
end

function Base.:(==)(L1::AbstractSciMLOperator, L2::AbstractSciMLOperator)
@warn """using convert-based fallback for Base.=="""
size(L1) != size(L2) && return false
convert(AbstractMatrix, L1) == convert(AbstractMatrix, L1)
concretize(L1) == concretize(L2)
end

Base.@propagate_inbounds function Base.getindex(L::AbstractSciMLOperator, I::Vararg{Any,N}) where {N}
@warn """using convert-based fallback for Base.getindex"""
convert(AbstractMatrix, L)[I...]
concretize(L)[I...]
end
function Base.getindex(L::AbstractSciMLOperator, I::Vararg{Int, N}) where {N}
@warn """using convert-based fallback for Base.getindex"""
convert(AbstractMatrix, L)[I...]
concretize(L)[I...]
end

function Base.resize!(L::AbstractSciMLOperator, n::Integer)
Expand All @@ -375,15 +441,15 @@ LinearAlgebra.exp(L::AbstractSciMLOperator) = exp(Matrix(L))

function LinearAlgebra.opnorm(L::AbstractSciMLOperator, p::Real=2)
@warn """using convert-based fallback in LinearAlgebra.opnorm."""
opnorm(convert(AbstractMatrix, L), p)
opnorm(concretize(L), p)
end

for op in (
:sum, :prod,
)
@eval function Base.$op(L::AbstractSciMLOperator; kwargs...)
@warn """using convert-based fallback in $($op)."""
$op(convert(AbstractMatrix, L); kwargs...)
$op(concretize(L); kwargs...)
end
end

Expand All @@ -394,17 +460,17 @@ for pred in (
)
@eval function LinearAlgebra.$pred(L::AbstractSciMLOperator)
@warn """using convert-based fallback in $($pred)."""
$pred(convert(AbstractMatrix, L))
$pred(concretize(L))
end
end

function LinearAlgebra.mul!(v::AbstractArray, L::AbstractSciMLOperator, u::AbstractArray)
@warn """using convert-based fallback in mul!."""
mul!(v, convert(AbstractMatrix, L), u)
mul!(v, concretize(L), u)
end

function LinearAlgebra.mul!(v::AbstractArray, L::AbstractSciMLOperator, u::AbstractArray, α, β)
@warn """using convert-based fallback in mul!."""
mul!(v, convert(AbstractMatrix, L), u, α, β)
mul!(v, concretize(L), u, α, β)
end
#
12 changes: 11 additions & 1 deletion src/matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ end
has_ldiv,
has_ldiv!,
)

isconvertible(::MatrixOperator) = true
islinear(::MatrixOperator) = true

function Base.show(io::IO, L::MatrixOperator)
Expand Down Expand Up @@ -162,7 +164,7 @@ SparseArrays.issparse(L::MatrixOperator) = issparse(L.A)

# TODO - add tests for MatrixOperator indexing
# propagate_inbounds here for the getindex fallback
Base.@propagate_inbounds Base.convert(::Type{AbstractMatrix}, L::MatrixOperator) = L.A
Base.@propagate_inbounds Base.convert(::Type{AbstractMatrix}, L::MatrixOperator) = convert(AbstractMatrix, L.A)
Base.@propagate_inbounds Base.setindex!(L::MatrixOperator, v, i::Int) = (L.A[i] = v)
Base.@propagate_inbounds Base.setindex!(L::MatrixOperator, v, I::Vararg{Int, N}) where{N} = (L.A[I...] = v)

Expand Down Expand Up @@ -322,6 +324,7 @@ end

getops(L::InvertibleOperator) = (L.L, L.F,)
islinear(L::InvertibleOperator) = islinear(L.L)
isconvertible(L::InvertibleOperator) = isconvertible(L.L)

@forward InvertibleOperator.L (
# LinearAlgebra
Expand Down Expand Up @@ -510,6 +513,7 @@ end
getops(L::AffineOperator) = (L.A, L.B, L.b)

islinear(::AffineOperator) = false
isconvertible(::AffineOperator) = false

function Base.show(io::IO, L::AffineOperator)
show(io, L.A)
Expand Down Expand Up @@ -537,6 +541,12 @@ function Base.resize!(L::AffineOperator, n::Integer)
L
end

function Base.convert(::Type{AbstractMatrix}, L::AffineOperator)
m, n = size(L)
msg = """$L cannot be represented by an $m × $n AbstractMatrix"""
throw(ArgumentError(msg))
end

has_adjoint(L::AffineOperator) = false
has_mul(L::AffineOperator) = has_mul(L.A)
has_mul!(L::AffineOperator) = has_mul!(L.A)
Expand Down
1 change: 1 addition & 0 deletions src/scalar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ Base.adjoint(α::AbstractSciMLScalarOperator) = conj(α)
Base.transpose::AbstractSciMLScalarOperator) = α

has_mul!(::AbstractSciMLScalarOperator) = true
isconcrete(::AbstractSciMLScalarOperator) = true
islinear(::AbstractSciMLScalarOperator) = true
has_adjoint(::AbstractSciMLScalarOperator) = true

Expand Down
1 change: 1 addition & 0 deletions src/tensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ end

getops(L::TensorProductOperator) = L.ops
islinear(L::TensorProductOperator) = reduce(&, islinear.(L.ops))
isconvertible(::TensorProductOperator) = false
Base.iszero(L::TensorProductOperator) = reduce(|, iszero.(L.ops))
has_adjoint(L::TensorProductOperator) = reduce(&, has_adjoint.(L.ops))
has_mul(L::TensorProductOperator) = reduce(&, has_mul.(L.ops))
Expand Down

0 comments on commit a5f1eba

Please sign in to comment.