From 314674cc599e8fd007800ec7d93d220ffdb921d2 Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Fri, 16 Jun 2023 15:30:32 -0400 Subject: [PATCH 01/10] trait isconcrete --- src/SciMLOperators.jl | 1 + src/func.jl | 4 ++++ src/interface.jl | 35 ++++++++++++++++++++++++++++++++++- src/matrix.jl | 2 ++ src/scalar.jl | 1 + 5 files changed, 42 insertions(+), 1 deletion(-) diff --git a/src/SciMLOperators.jl b/src/SciMLOperators.jl index bb7fe173..aabb7a71 100644 --- a/src/SciMLOperators.jl +++ b/src/SciMLOperators.jl @@ -98,6 +98,7 @@ export update_coefficients!, issquare, islinear, + isconcrete, has_adjoint, has_expmv, diff --git a/src/func.jl b/src/func.jl index 0d9c0d49..3ddc9ae8 100644 --- a/src/func.jl +++ b/src/func.jl @@ -138,6 +138,7 @@ function FunctionOperator(op, has_mul5::Union{Nothing,Bool}=nothing, isconstant::Bool = false, islinear::Bool = false, + isconcrete::Bool = false, batch::Bool = false, ifcache::Bool = true, @@ -248,6 +249,7 @@ function FunctionOperator(op, traits = (; islinear = islinear, + isconcrete = isconcrete, isconstant = isconstant, opnorm = opnorm, @@ -468,6 +470,8 @@ function Base.inv(L::FunctionOperator) ) end +Base.convert(::Type{AbstractMatrix}, L::FunctionOperator) = convert(AbstractMatrix, L.op) + function Base.resize!(L::FunctionOperator, n::Integer) # input/output to `L` must be `AbstractVector`s diff --git a/src/interface.jl b/src/interface.jl index a428ac43..32cb0081 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -187,37 +187,53 @@ Base.oneunit(LType::Type{<:AbstractSciMLOperator}) = one(LType) Base.iszero(::AbstractSciMLOperator) = false # TODO """ +$SIGNATURES + Check if `adjoint(L)` is lazily defined. """ has_adjoint(L::AbstractSciMLOperator) = false # L', adjoint(L) """ +$SIGNATURES + Check if `expmv!(v, L, u, t)`, equivalent to `mul!(v, exp(t * A), u)`, is defined for `Number` `t`, and `AbstractArray`s `u, v` of appropriate sizes. """ has_expmv!(L::AbstractSciMLOperator) = false # expmv!(v, L, t, u) """ +$SIGNATURES + Check if `expmv(L, u, t)`, equivalent to `exp(t * A) * u`, is defined for `Number` `t`, and `AbstractArray` `u` of appropriate size. """ has_expmv(L::AbstractSciMLOperator) = false # v = exp(L, t, u) """ +$SIGNATURES + Check if `exp(L)` is defined lazily defined. """ has_exp(L::AbstractSciMLOperator) = islinear(L) """ +$SIGNATURES + Check if `L * u` is defined for `AbstractArray` `u` of appropriate size. """ has_mul(L::AbstractSciMLOperator) = true # du = L*u """ +$SIGNATURES + Check if `mul!(v, L, u)` is defined for `AbstractArray`s `u, v` of appropriate sizes. """ has_mul!(L::AbstractSciMLOperator) = true # mul!(du, L, u) """ +$SIGNATURES + Check if `L \\ u` is defined for `AbstractArray` `u` of appropriate size. """ has_ldiv(L::AbstractSciMLOperator) = false # du = L\u """ +$SIGNATURES + Check if `ldiv!(v, L, u)` is defined for `AbstractArray`s `u, v` of appropriate sizes. """ @@ -244,7 +260,24 @@ isconstant(::Union{ ) = true isconstant(L::AbstractSciMLOperator) = all(isconstant, getops(L)) -#islinear(L) = false +""" +$SIGNATURES + +Checks if `L` can be cheaply converted to an `AbstractMatrix` +""" +isconcrete(L::AbstractSciMLOperator) = all(isconcrete, getops(L)) + +isconcrete(::Union{ + # LinearAlgebra + AbstractMatrix, + UniformScaling, + Factorization, + + # Base + Number, + } + ) = true + """ $SIGNATURES diff --git a/src/matrix.jl b/src/matrix.jl index 992e5cca..4d92a755 100644 --- a/src/matrix.jl +++ b/src/matrix.jl @@ -103,6 +103,8 @@ end has_ldiv, has_ldiv!, ) + +isconcrete(::MatrixOperator) = true islinear(::MatrixOperator) = true function Base.show(io::IO, L::MatrixOperator) diff --git a/src/scalar.jl b/src/scalar.jl index dc477f79..ca43d54b 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -32,6 +32,7 @@ Base.adjoint(α::AbstractSciMLScalarOperator) = conj(α) Base.transpose(α::AbstractSciMLScalarOperator) = α has_mul!(::AbstractSciMLScalarOperator) = true +isconcrete(::AbstractSciMLScalarOperator) = true islinear(::AbstractSciMLScalarOperator) = true has_adjoint(::AbstractSciMLScalarOperator) = true From 118556d2e95ec13788b518b34b83664770a44547 Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Fri, 16 Jun 2023 15:33:53 -0400 Subject: [PATCH 02/10] isconcrete method for functionop --- src/func.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/func.jl b/src/func.jl index 3ddc9ae8..71e5279b 100644 --- a/src/func.jl +++ b/src/func.jl @@ -518,6 +518,7 @@ function getops(L::FunctionOperator) end islinear(L::FunctionOperator) = L.traits.islinear +isconcrete(L::FunctionOperator) = L.traits.isconcrete isconstant(L::FunctionOperator) = L.traits.isconstant has_adjoint(L::FunctionOperator) = !(L.op_adjoint isa Nothing) has_mul(::FunctionOperator{iip}) where{iip} = true From 7bfd8bdb643882850eabf558cb758189dc31dc64 Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Sun, 25 Jun 2023 14:57:26 -0400 Subject: [PATCH 03/10] isconvertible, convert --- src/SciMLOperators.jl | 3 ++- src/basic.jl | 1 + src/batch.jl | 6 +++++ src/func.jl | 7 ++--- src/interface.jl | 60 ++++++++++++++++++++++++++++++++++--------- src/matrix.jl | 10 ++++++-- src/tensor.jl | 1 + 7 files changed, 70 insertions(+), 18 deletions(-) diff --git a/src/SciMLOperators.jl b/src/SciMLOperators.jl index aabb7a71..4409e920 100644 --- a/src/SciMLOperators.jl +++ b/src/SciMLOperators.jl @@ -98,7 +98,8 @@ export update_coefficients!, issquare, islinear, - isconcrete, + concretize, + isconvertible, has_adjoint, has_expmv, diff --git a/src/basic.jl b/src/basic.jl index a4ad0260..3e5dd74f 100644 --- a/src/basic.jl +++ b/src/basic.jl @@ -761,6 +761,7 @@ end getops(L::InvertedOperator) = (L.L,) islinear(L::InvertedOperator) = islinear(L.L) +isconvertible(::InvertedOperator) = false has_mul(L::InvertedOperator) = has_ldiv(L.L) has_mul!(L::InvertedOperator) = has_ldiv!(L.L) diff --git a/src/batch.jl b/src/batch.jl index a30d8634..53a990a3 100644 --- a/src/batch.jl +++ b/src/batch.jl @@ -67,6 +67,12 @@ function Base.conj(L::BatchedDiagonalOperator) # TODO - test this thoroughly ) end +function Base.convert(::Type{AbstractMatrix}, L::BatchedDiagonalOperator) + m, n = size(L) + msg = """$L cannot be represented by an $m × $n AbstractMatrix""" + throw(ArgumentError(msg)) +end + LinearAlgebra.issymmetric(L::BatchedDiagonalOperator) = true function LinearAlgebra.ishermitian(L::BatchedDiagonalOperator) if isreal(L) diff --git a/src/func.jl b/src/func.jl index 71e5279b..0c98a15d 100644 --- a/src/func.jl +++ b/src/func.jl @@ -111,6 +111,7 @@ uniform across `op`, `op_adjoint`, `op_inverse`, `op_adjoint_inverse`. * `has_mul5` - `true` if the operator provides a five-argument `mul!` via the signature `op(v, u, p, t, α, β; )`. This trait is inferred if no value is provided. * `isconstant` - `true` if the operator is constant, and doesn't need to be updated via `update_coefficients[!]` during operator evaluation. * `islinear` - `true` if the operator is linear. Defaults to `false`. +* `isconvertible` - `true` a cheap `convert(AbstractMatrix, L.op)` method is available. Defaults to `false`. * `batch` - Boolean indicating if the input/output arrays comprise of batched column-vectors stacked in a matrix. If `true`, the input/output arrays must be `AbstractVecOrMat`s, and the length of the second dimension (the batch dimension) must be the same. The batch dimension is not involved in size computation. For example, with `batch = true`, and `size(output), size(input) = (M, K), (N, K)`, the `FunctionOperator` size is set to `(M, N)`. If `batch = false`, which is the default, the `input`/`output` arrays may of any size so long as `ndims(input) == ndims(output)`, and the `size` of `FunctionOperator` is set to `(length(input), length(output))`. * `ifcache` - Allocate cache arrays in constructor. Defaults to `true`. Cache can be generated afterwards by calling `cache_operator(L, input, output)` * `cache` - Pregenerated cache arrays for in-place evaluations. Expected to be of type and shape `(similar(input), similar(output),)`. The constructor generates cache if no values are provided. Cache generation by the constructor can be disabled by setting the kwarg `ifcache = false`. @@ -138,7 +139,7 @@ function FunctionOperator(op, has_mul5::Union{Nothing,Bool}=nothing, isconstant::Bool = false, islinear::Bool = false, - isconcrete::Bool = false, + isconvertible::Bool = false, batch::Bool = false, ifcache::Bool = true, @@ -249,7 +250,7 @@ function FunctionOperator(op, traits = (; islinear = islinear, - isconcrete = isconcrete, + isconvertible = isconvertible, isconstant = isconstant, opnorm = opnorm, @@ -518,7 +519,7 @@ function getops(L::FunctionOperator) end islinear(L::FunctionOperator) = L.traits.islinear -isconcrete(L::FunctionOperator) = L.traits.isconcrete +isconvertible(L::FunctionOperator) = L.traits.isconvertible isconstant(L::FunctionOperator) = L.traits.isconstant has_adjoint(L::FunctionOperator) = !(L.op_adjoint isa Nothing) has_mul(::FunctionOperator{iip}) where{iip} = true diff --git a/src/interface.jl b/src/interface.jl index 32cb0081..6f3ac1e8 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -261,22 +261,58 @@ isconstant(::Union{ isconstant(L::AbstractSciMLOperator) = all(isconstant, getops(L)) """ -$SIGNATURES + isconvertible(L) -> Bool -Checks if `L` can be cheaply converted to an `AbstractMatrix` +Checks if `L` can be cheaply converted to an `AbstractMatrix` via eager fusion. """ -isconcrete(L::AbstractSciMLOperator) = all(isconcrete, getops(L)) +isconvertible(L::AbstractSciMLOperator) = all(isconvertible, getops(L)) -isconcrete(::Union{ - # LinearAlgebra - AbstractMatrix, - UniformScaling, - Factorization, +isconvertible(::Union{ + # LinearAlgebra + AbstractMatrix, + UniformScaling, + Factorization, - # Base - Number, - } - ) = true + # Base + Number, + + # SciMLOperators + AbstractSciMLScalarOperator, + } + ) = true + +""" + concretize(L) -> AbstractMatrix + + concretize(L) -> Number + +Convert `SciMLOperator` to a concrete type via eager fusion. This method is a +no-op for types that are already concrete. +""" +concretize(L::Union{ + # LinearAlgebra + AbstractMatrix, + Factorization, + + # Base + Number, + + # SciMLOperators + AbstractSciMLOperator, + } + ) = convert(AbstractMatrix, L) + +concretize(L::Union{ + # LinearAlgebra + UniformScaling, + + # Base + Number, + + # SciMLOperators + AbstractSciMLScalarOperator, + } + ) = convert(Number, L) """ $SIGNATURES diff --git a/src/matrix.jl b/src/matrix.jl index 4d92a755..3902fe45 100644 --- a/src/matrix.jl +++ b/src/matrix.jl @@ -104,7 +104,7 @@ end has_ldiv!, ) -isconcrete(::MatrixOperator) = true +isconvertible(::MatrixOperator) = true islinear(::MatrixOperator) = true function Base.show(io::IO, L::MatrixOperator) @@ -164,7 +164,7 @@ SparseArrays.issparse(L::MatrixOperator) = issparse(L.A) # TODO - add tests for MatrixOperator indexing # propagate_inbounds here for the getindex fallback -Base.@propagate_inbounds Base.convert(::Type{AbstractMatrix}, L::MatrixOperator) = L.A +Base.@propagate_inbounds Base.convert(::Type{AbstractMatrix}, L::MatrixOperator) = convert(AbstractMatrix, L.A) Base.@propagate_inbounds Base.setindex!(L::MatrixOperator, v, i::Int) = (L.A[i] = v) Base.@propagate_inbounds Base.setindex!(L::MatrixOperator, v, I::Vararg{Int, N}) where{N} = (L.A[I...] = v) @@ -539,6 +539,12 @@ function Base.resize!(L::AffineOperator, n::Integer) L end +function Base.convert(::Type{AbstractMatrix}, L::AffineOperator) + m, n = size(L) + msg = """$L cannot be represented by an $m × $n AbstractMatrix""" + throw(ArgumentError(msg)) +end + has_adjoint(L::AffineOperator) = false has_mul(L::AffineOperator) = has_mul(L.A) has_mul!(L::AffineOperator) = has_mul!(L.A) diff --git a/src/tensor.jl b/src/tensor.jl index 6bf78543..6034bba1 100644 --- a/src/tensor.jl +++ b/src/tensor.jl @@ -121,6 +121,7 @@ end getops(L::TensorProductOperator) = L.ops islinear(L::TensorProductOperator) = reduce(&, islinear.(L.ops)) +isconvertible(::TensorProductOperator) = false Base.iszero(L::TensorProductOperator) = reduce(|, iszero.(L.ops)) has_adjoint(L::TensorProductOperator) = reduce(&, has_adjoint.(L.ops)) has_mul(L::TensorProductOperator) = reduce(&, has_mul.(L.ops)) From 641e0e661c51595543969bac20a9ff7cb532189c Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Sun, 25 Jun 2023 15:14:42 -0400 Subject: [PATCH 04/10] isconvertible method for affine, invertible, batchdiag --- src/batch.jl | 1 + src/matrix.jl | 2 ++ 2 files changed, 3 insertions(+) diff --git a/src/batch.jl b/src/batch.jl index 53a990a3..64492f7c 100644 --- a/src/batch.jl +++ b/src/batch.jl @@ -97,6 +97,7 @@ function isconstant(L::BatchedDiagonalOperator) update_func_isconstant(L.update_func) & update_func_isconstant(L.update_func!) end islinear(::BatchedDiagonalOperator) = true +isconvertible(::BatchedDiagonalOperator) = false has_adjoint(L::BatchedDiagonalOperator) = true has_ldiv(L::BatchedDiagonalOperator) = all(x -> !iszero(x), L.diag) has_ldiv!(L::BatchedDiagonalOperator) = has_ldiv(L) diff --git a/src/matrix.jl b/src/matrix.jl index 3902fe45..30687f2c 100644 --- a/src/matrix.jl +++ b/src/matrix.jl @@ -324,6 +324,7 @@ end getops(L::InvertibleOperator) = (L.L, L.F,) islinear(L::InvertibleOperator) = islinear(L.L) +isconvertible(L::InvertibleOperator) = isconvertible(L.L) @forward InvertibleOperator.L ( # LinearAlgebra @@ -512,6 +513,7 @@ end getops(L::AffineOperator) = (L.A, L.B, L.b) islinear(::AffineOperator) = false +isconvertible(::AffineOperator) = false function Base.show(io::IO, L::AffineOperator) show(io, L.A) From 660169998d727fb4fca7cc3d9a142b9d9d7812ea Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Sun, 2 Jul 2023 14:59:07 -0400 Subject: [PATCH 05/10] switch out convert calls for concretize --- src/interface.jl | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 9f93610c..cd4fcedd 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -294,9 +294,6 @@ concretize(L::Union{ AbstractMatrix, Factorization, - # Base - Number, - # SciMLOperators AbstractSciMLOperator, } @@ -418,22 +415,22 @@ expmv!(v,L::AbstractSciMLOperator,u,p,t) = mul!(v,exp(L,t),u) function Base.conj(L::AbstractSciMLOperator) isreal(L) && return L @warn """using convert-based fallback for Base.conj""" - convert(AbstractMatrix, L) |> conj + concretize(L) |> conj end function Base.:(==)(L1::AbstractSciMLOperator, L2::AbstractSciMLOperator) @warn """using convert-based fallback for Base.==""" size(L1) != size(L2) && return false - convert(AbstractMatrix, L1) == convert(AbstractMatrix, L1) + concretize(L1) == concretize(L2) end Base.@propagate_inbounds function Base.getindex(L::AbstractSciMLOperator, I::Vararg{Any,N}) where {N} @warn """using convert-based fallback for Base.getindex""" - convert(AbstractMatrix, L)[I...] + concretize(L)[I...] end function Base.getindex(L::AbstractSciMLOperator, I::Vararg{Int, N}) where {N} @warn """using convert-based fallback for Base.getindex""" - convert(AbstractMatrix, L)[I...] + concretize(L)[I...] end function Base.resize!(L::AbstractSciMLOperator, n::Integer) @@ -444,7 +441,7 @@ LinearAlgebra.exp(L::AbstractSciMLOperator) = exp(Matrix(L)) function LinearAlgebra.opnorm(L::AbstractSciMLOperator, p::Real=2) @warn """using convert-based fallback in LinearAlgebra.opnorm.""" - opnorm(convert(AbstractMatrix, L), p) + opnorm(concretize(L), p) end for pred in ( @@ -454,17 +451,17 @@ for pred in ( ) @eval function LinearAlgebra.$pred(L::AbstractSciMLOperator) @warn """using convert-based fallback in $($pred).""" - $pred(convert(AbstractMatrix, L)) + $pred(concretize(L)) end end function LinearAlgebra.mul!(v::AbstractArray, L::AbstractSciMLOperator, u::AbstractArray) @warn """using convert-based fallback in mul!.""" - mul!(v, convert(AbstractMatrix, L), u) + mul!(v, concretize(L), u) end function LinearAlgebra.mul!(v::AbstractArray, L::AbstractSciMLOperator, u::AbstractArray, α, β) @warn """using convert-based fallback in mul!.""" - mul!(v, convert(AbstractMatrix, L), u, α, β) + mul!(v, concretize(L), u, α, β) end # From a2367db54f872d61fff69381c7041b5e8acb9950 Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Sun, 2 Jul 2023 15:31:03 -0400 Subject: [PATCH 06/10] sum, prod fallback methods --- src/interface.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/interface.jl b/src/interface.jl index cd4fcedd..c4b57b51 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -444,6 +444,15 @@ function LinearAlgebra.opnorm(L::AbstractSciMLOperator, p::Real=2) opnorm(concretize(L), p) end +for op in ( + :sum, :prod, + ) + @eval function Base.$op(L::AbstractSciMLOperator; kwargs...) + @warn """using convert-based fallback in $($op).""" + $op(convert(AbstractMatrix, L); kwargs...) + end +end + for pred in ( :issymmetric, :ishermitian, From 54ef36071ce29dfdb92fb2b23343a51afac9447f Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sun, 2 Jul 2023 15:57:35 -0400 Subject: [PATCH 07/10] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index c35bf1f1..ce5eaeb9 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SciMLOperators" uuid = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" authors = ["Vedant Puri "] -version = "0.3.2" +version = "0.3.3" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" From ff96ee39f88d4d48c81d27ebe783fc1013e66698 Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Sun, 2 Jul 2023 16:21:05 -0400 Subject: [PATCH 08/10] concretize in sum/prod --- src/interface.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/interface.jl b/src/interface.jl index c4b57b51..8437f0f3 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -449,7 +449,7 @@ for op in ( ) @eval function Base.$op(L::AbstractSciMLOperator; kwargs...) @warn """using convert-based fallback in $($op).""" - $op(convert(AbstractMatrix, L); kwargs...) + $op(concretize(L); kwargs...) end end From b01cd284fa879b6625e5e1d289d9fcedfafde8ac Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Sun, 2 Jul 2023 17:09:33 -0400 Subject: [PATCH 09/10] bump Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index ce5eaeb9..5008d940 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SciMLOperators" uuid = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" authors = ["Vedant Puri "] -version = "0.3.3" +version = "0.3.4" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" From 5adda295a57583aa212705d0ecbad953618fd5c0 Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Sun, 2 Jul 2023 14:59:07 -0400 Subject: [PATCH 10/10] switch out convert calls for concretize --- src/interface.jl | 9 --------- 1 file changed, 9 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 1259c532..8437f0f3 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -453,15 +453,6 @@ for op in ( end end -for op in ( - :sum, :prod, - ) - @eval function Base.$op(L::AbstractSciMLOperator; kwargs...) - @warn """using convert-based fallback in $($op).""" - $op(convert(AbstractMatrix, L); kwargs...) - end -end - for pred in ( :issymmetric, :ishermitian,