-
-
Notifications
You must be signed in to change notification settings - Fork 10
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improve mul!
, AddedOperator
, and update_coefficients!
to remove memory allocations
#249
base: master
Are you sure you want to change the base?
Changes from all commits
206c4c7
92e0d00
1383b7b
1aab2cc
f02cac6
dbd3590
a1283f4
9190d18
3fb8a9e
934ae5d
8280c24
4cafaa0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,11 @@ | ||
module SciMLOperatorsStaticArraysCoreExt | ||
import SciMLOperators | ||
import StaticArraysCore | ||
function Base.copyto!(L::SciMLOperators.MatrixOperator, | ||
rhs::Base.Broadcast.Broadcasted{<:StaticArraysCore.StaticArrayStyle}) | ||
(copyto!(L.A, rhs); L) | ||
end | ||
end #module | ||
module SciMLOperatorsStaticArraysCoreExt | ||
|
||
import SciMLOperators | ||
import StaticArraysCore | ||
|
||
function Base.copyto!(L::SciMLOperators.MatrixOperator, | ||
rhs::Base.Broadcast.Broadcasted{<:StaticArraysCore.StaticArrayStyle}) | ||
(copyto!(L.A, rhs); L) | ||
end | ||
|
||
end #module |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -201,7 +201,7 @@ end | |
|
||
for T in SCALINGNUMBERTYPES | ||
@eval function ScaledOperator(λ::$T, L::ScaledOperator) | ||
λ = ScalarOperator(λ) * L.λ | ||
λ = λ * L.λ | ||
ScaledOperator(λ, L.L) | ||
end | ||
|
||
|
@@ -250,7 +250,7 @@ function update_coefficients!(L::ScaledOperator, u, p, t) | |
update_coefficients!(L.L, u, p, t) | ||
update_coefficients!(L.λ, u, p, t) | ||
|
||
L | ||
nothing | ||
end | ||
|
||
getops(L::ScaledOperator) = (L.λ, L.L) | ||
|
@@ -288,13 +288,14 @@ end | |
Base.:*(L::ScaledOperator, u::AbstractVecOrMat) = L.λ * (L.L * u) | ||
Base.:\(L::ScaledOperator, u::AbstractVecOrMat) = L.λ \ (L.L \ u) | ||
|
||
function LinearAlgebra.mul!(v::AbstractVecOrMat, L::ScaledOperator, u::AbstractVecOrMat) | ||
@inline function LinearAlgebra.mul!( | ||
v::AbstractVecOrMat, L::ScaledOperator, u::AbstractVecOrMat) | ||
iszero(L.λ) && return lmul!(false, v) | ||
a = convert(Number, L.λ) | ||
mul!(v, L.L, u, a, false) | ||
end | ||
|
||
function LinearAlgebra.mul!(v::AbstractVecOrMat, | ||
@inline function LinearAlgebra.mul!(v::AbstractVecOrMat, | ||
L::ScaledOperator, | ||
u::AbstractVecOrMat, | ||
α, | ||
|
@@ -326,22 +327,34 @@ struct AddedOperator{T, | |
|
||
function AddedOperator(ops) | ||
@assert !isempty(ops) | ||
_check_AddedOperator_sizes(ops) | ||
T = promote_type(eltype.(ops)...) | ||
new{T, typeof(ops)}(ops) | ||
end | ||
end | ||
|
||
function AddedOperator(ops::AbstractSciMLOperator...) | ||
sz = size(first(ops)) | ||
for op in ops[2:end] | ||
@assert size(op)==sz "Dimension mismatch: cannot add operators of | ||
sizes $(sz), and $(size(op))." | ||
end | ||
AddedOperator(ops) | ||
end | ||
|
||
AddedOperator(L::AbstractSciMLOperator) = L | ||
|
||
@generated function _check_AddedOperator_sizes(ops::Tuple) | ||
ops_types = ops.parameters | ||
N = length(ops_types) | ||
sz_expr_list = () | ||
sz_expr = :(sz = size(first(ops))) | ||
for i in 2:N | ||
sz_expr_list = (sz_expr_list..., :(size(ops[$i]) == sz)) | ||
end | ||
|
||
quote | ||
$sz_expr | ||
@assert all(tuple($(sz_expr_list...))) "Dimension mismatch: cannot add operators of different sizes." | ||
nothing | ||
end | ||
end | ||
|
||
# constructors | ||
Base.:+(A::AbstractSciMLOperator, B::AbstractMatrix) = A + MatrixOperator(B) | ||
Base.:+(A::AbstractMatrix, B::AbstractSciMLOperator) = MatrixOperator(A) + B | ||
|
@@ -371,13 +384,15 @@ for op in (:+, :-) | |
for LT in SCALINGCOMBINETYPES | ||
@eval function Base.$op(L::$LT, λ::$T) | ||
@assert issquare(L) | ||
iszero(λ) && return L | ||
N = size(L, 1) | ||
Id = IdentityOperator(N) | ||
AddedOperator(L, $op(λ) * Id) | ||
end | ||
|
||
@eval function Base.$op(λ::$T, L::$LT) | ||
@assert issquare(L) | ||
iszero(λ) && return $op(L) | ||
N = size(L, 1) | ||
Id = IdentityOperator(N) | ||
AddedOperator(λ * Id, $op(L)) | ||
|
@@ -386,6 +401,23 @@ for op in (:+, :-) | |
end | ||
end | ||
|
||
for T in SCALINGNUMBERTYPES[2:end] | ||
@eval function Base.:*(λ::$T, L::AddedOperator) | ||
ops = map(op -> λ * op, L.ops) | ||
AddedOperator(ops) | ||
end | ||
|
||
@eval function Base.:*(L::AddedOperator, λ::$T) | ||
ops = map(op -> λ * op, L.ops) | ||
AddedOperator(ops) | ||
end | ||
|
||
@eval function Base.:/(L::AddedOperator, λ::$T) | ||
ops = map(op -> op / λ, L.ops) | ||
AddedOperator(ops) | ||
end | ||
end | ||
|
||
function Base.convert(::Type{AbstractMatrix}, L::AddedOperator) | ||
sum(op -> convert(AbstractMatrix, op), L.ops) | ||
end | ||
|
@@ -422,16 +454,32 @@ function update_coefficients(L::AddedOperator, u, p, t) | |
@reset L.ops = ops | ||
end | ||
|
||
@generated function update_coefficients!(L::AddedOperator, u, p, t) | ||
ops_types = L.parameters[2].parameters | ||
N = length(ops_types) | ||
quote | ||
Base.@nexprs $N i->begin | ||
update_coefficients!(L.ops[i], u, p, t) | ||
end | ||
|
||
nothing | ||
end | ||
end | ||
|
||
getops(L::AddedOperator) = L.ops | ||
islinear(L::AddedOperator) = all(islinear, getops(L)) | ||
Base.iszero(L::AddedOperator) = all(iszero, getops(L)) | ||
has_adjoint(L::AddedOperator) = all(has_adjoint, L.ops) | ||
|
||
function cache_internals(L::AddedOperator, u::AbstractVecOrMat) | ||
for i in 1:length(L.ops) | ||
@reset L.ops[i] = cache_operator(L.ops[i], u) | ||
@generated function cache_internals(L::AddedOperator, u::AbstractVecOrMat) | ||
ops_types = L.parameters[2].parameters | ||
N = length(ops_types) | ||
quote | ||
Base.@nexprs $N i->begin | ||
@reset L.ops[i] = cache_operator(L.ops[i], u) | ||
end | ||
L | ||
end | ||
L | ||
end | ||
|
||
getindex(L::AddedOperator, i::Int) = sum(op -> op[i], L.ops) | ||
|
@@ -441,26 +489,33 @@ function Base.:*(L::AddedOperator, u::AbstractVecOrMat) | |
sum(op -> iszero(op) ? zero(u) : op * u, L.ops) | ||
end | ||
|
||
function LinearAlgebra.mul!(v::AbstractVecOrMat, L::AddedOperator, u::AbstractVecOrMat) | ||
mul!(v, first(L.ops), u) | ||
for op in L.ops[2:end] | ||
iszero(op) && continue | ||
mul!(v, op, u, true, true) | ||
@generated function LinearAlgebra.mul!( | ||
v::AbstractVecOrMat, L::AddedOperator, u::AbstractVecOrMat) | ||
ops_types = L.parameters[2].parameters | ||
N = length(ops_types) | ||
quote | ||
mul!(v, L.ops[1], u) | ||
Base.@nexprs $(N - 1) i->begin | ||
mul!(v, L.ops[i + 1], u, true, true) | ||
end | ||
v | ||
end | ||
v | ||
end | ||
|
||
function LinearAlgebra.mul!(v::AbstractVecOrMat, | ||
@generated function LinearAlgebra.mul!(v::AbstractVecOrMat, | ||
L::AddedOperator, | ||
u::AbstractVecOrMat, | ||
α, | ||
β) | ||
lmul!(β, v) | ||
for op in L.ops | ||
iszero(op) && continue | ||
mul!(v, op, u, α, true) | ||
ops_types = L.parameters[2].parameters | ||
N = length(ops_types) | ||
Comment on lines
-459
to
+511
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is this actually required? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A recursive implmeentation is probably cleaner and would get more compilation reuse? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How exactly? |
||
quote | ||
lmul!(β, v) | ||
Base.@nexprs $(N) i->begin | ||
mul!(v, L.ops[i], u, α, true) | ||
end | ||
v | ||
end | ||
v | ||
end | ||
|
||
""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why is this needed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is related to this issue in SparseArrays.jl. We currently need it to avoid extra allocations.