From c9a09f9a3a08340202937aacac566756b42ec468 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 12 Nov 2023 21:11:46 -0500 Subject: [PATCH 1/3] Fix type stability and allow JacVec to be non-square jacobians --- Project.toml | 4 +- src/SparseDiffTools.jl | 1 + src/differentiation/common.jl | 68 ++++++++++++++++++++ src/differentiation/jaches_products.jl | 55 +++++++++++++--- src/differentiation/vecjac_products.jl | 87 +++----------------------- 5 files changed, 125 insertions(+), 90 deletions(-) create mode 100644 src/differentiation/common.jl diff --git a/Project.toml b/Project.toml index 96ee1e50..34aa591b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SparseDiffTools" uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804" authors = ["Pankaj Mishra ", "Chris Rackauckas "] -version = "2.11.0" +version = "2.12.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -49,7 +49,7 @@ LinearAlgebra = "1.6" PackageExtensionCompat = "1" Random = "1.6" Reexport = "1" -SciMLOperators = "0.2.11, 0.3" +SciMLOperators = "0.3.7" Setfield = "1" SparseArrays = "1.6" StaticArrayInterface = "1.3" diff --git a/src/SparseDiffTools.jl b/src/SparseDiffTools.jl index 5d0e7c73..0d971507 100644 --- a/src/SparseDiffTools.jl +++ b/src/SparseDiffTools.jl @@ -40,6 +40,7 @@ include("coloring/greedy_star1_coloring.jl") include("coloring/greedy_star2_coloring.jl") include("coloring/matrix2graph.jl") +include("differentiation/common.jl") include("differentiation/compute_jacobian_ad.jl") include("differentiation/compute_hessian_ad.jl") include("differentiation/jaches_products.jl") diff --git a/src/differentiation/common.jl b/src/differentiation/common.jl new file mode 100644 index 00000000..4a38178d --- /dev/null +++ b/src/differentiation/common.jl @@ -0,0 +1,68 @@ +mutable struct JacFunctionWrapper{iip, oop, mode, F, FU, P, T} <: Function + f::F + fu::FU + p::P + t::T +end + +function SciMLOperators.update_coefficients!(L::JacFunctionWrapper{iip, oop, mode}, _, + p, t) where 
{iip, oop, mode} + mode == 1 && (L.t = t) + mode == 2 && (L.p = p) + return L +end +function SciMLOperators.update_coefficients(L::JacFunctionWrapper{iip, oop, mode}, _, p, + t) where {iip, oop, mode} + return JacFunctionWrapper{iip, oop, mode, typeof(L.f), typeof(L.fu), typeof(p), + typeof(t)}(L.f, L.fu, p, + t) +end + +__internal_iip(::JacFunctionWrapper{iip}) where {iip} = iip +__internal_oop(::JacFunctionWrapper{iip, oop}) where {iip, oop} = oop + +(f::JacFunctionWrapper{true, oop, 1})(fu, u) where {oop} = f.f(fu, u, f.p, f.t) +(f::JacFunctionWrapper{true, oop, 2})(fu, u) where {oop} = f.f(fu, u, f.p) +(f::JacFunctionWrapper{true, oop, 3})(fu, u) where {oop} = f.f(fu, u) +(f::JacFunctionWrapper{true, true, 1})(u) = f.f(u, f.p, f.t) +(f::JacFunctionWrapper{true, true, 2})(u) = f.f(u, f.p) +(f::JacFunctionWrapper{true, true, 3})(u) = f.f(u) +(f::JacFunctionWrapper{true, false, 1})(u) = (f.f(f.fu, u, f.p, f.t); copy(f.fu)) +(f::JacFunctionWrapper{true, false, 2})(u) = (f.f(f.fu, u, f.p); copy(f.fu)) +(f::JacFunctionWrapper{true, false, 3})(u) = (f.f(f.fu, u); copy(f.fu)) + +(f::JacFunctionWrapper{false, true, 1})(fu, u) = (vec(fu) .= vec(f.f(u, f.p, f.t))) +(f::JacFunctionWrapper{false, true, 2})(fu, u) = (vec(fu) .= vec(f.f(u, f.p))) +(f::JacFunctionWrapper{false, true, 3})(fu, u) = (vec(fu) .= vec(f.f(u))) +(f::JacFunctionWrapper{false, true, 1})(u) = f.f(u, f.p, f.t) +(f::JacFunctionWrapper{false, true, 2})(u) = f.f(u, f.p) +(f::JacFunctionWrapper{false, true, 3})(u) = f.f(u) + +function JacFunctionWrapper(f::F, fu_, u, p, t) where {F} + fu = fu_ === nothing ? 
copy(u) : copy(fu_) + if t !== nothing + iip = static_hasmethod(f, typeof((fu, u, p, t))) + oop = static_hasmethod(f, typeof((u, p, t))) + if !iip && !oop + throw(ArgumentError("`f(u, p, t)` or `f(fu, u, p, t)` not defined for `f`")) + end + return JacFunctionWrapper{iip, oop, 1, F, typeof(fu), typeof(p), typeof(t)}(f, + fu, p, t) + elseif p !== nothing + iip = static_hasmethod(f, typeof((fu, u, p))) + oop = static_hasmethod(f, typeof((u, p))) + if !iip && !oop + throw(ArgumentError("`f(u, p)` or `f(fu, u, p)` not defined for `f`")) + end + return JacFunctionWrapper{iip, oop, 2, F, typeof(fu), typeof(p), typeof(t)}(f, + fu, p, t) + else + iip = static_hasmethod(f, typeof((fu, u))) + oop = static_hasmethod(f, typeof((u,))) + if !iip && !oop + throw(ArgumentError("`f(u)` or `f(fu, u)` not defined for `f`")) + end + return JacFunctionWrapper{iip, oop, 3, F, typeof(fu), typeof(p), typeof(t)}(f, + fu, p, t) + end +end diff --git a/src/differentiation/jaches_products.jl b/src/differentiation/jaches_products.jl index 7c9d8d4b..57e9212c 100644 --- a/src/differentiation/jaches_products.jl +++ b/src/differentiation/jaches_products.jl @@ -223,7 +223,46 @@ function Base.resize!(L::FwdModeAutoDiffVecProd, n::Integer) end end -function JacVec(f, u::AbstractArray, p = nothing, t = nothing; +""" + JacVec(f, u, [p, t]; fu = nothing, autodiff = AutoForwardDiff(), tag = DeivVecTag(), + kwargs...) + +Returns SciMLOperators.FunctionOperator which computes jacobian-vector product `df/du * v`. + +!!! note + + For non-square jacobians with inplace `f`, `fu` must be specified, else `JacVec` assumes + a square jacobian. 
+ +```julia +L = JacVec(f, u) + +L * v # = df/du * v +mul!(w, L, v) # = df/du * v + +L(v, p, t) # = df/du * v +L(x, v, p, t) # = df/du * v +``` + +## Allowed Function Signatures for `f` + +For Out of Place Functions: + +```julia +f(u, p, t) # t !== nothing +f(u, p) # p !== nothing +f(u) # Otherwise +``` + +For In Place Functions: + +```julia +f(du, u, p, t) # t !== nothing +f(du, u, p) # p !== nothing +f(du, u) # Otherwise +``` +""" +function JacVec(f, u::AbstractArray, p = nothing, t = nothing; fu = nothing, autodiff = AutoForwardDiff(), tag = DeivVecTag(), kwargs...) cache, vecprod, vecprod! = if autodiff isa AutoFiniteDiff cache1 = similar(u) @@ -242,17 +281,13 @@ function JacVec(f, u::AbstractArray, p = nothing, t = nothing; error("Set autodiff to either AutoForwardDiff(), or AutoFiniteDiff()") end - outofplace = static_hasmethod(f, typeof((u,))) - isinplace = static_hasmethod(f, typeof((u, u))) - - if !(isinplace) & !(outofplace) - error("$f must have signature f(u), or f(du, u).") - end + ff = JacFunctionWrapper(f, fu, u, p, t) + fu === nothing && (fu = __internal_oop(ff) ? ff(u) : u) - L = FwdModeAutoDiffVecProd(f, u, cache, vecprod, vecprod!) + op = FwdModeAutoDiffVecProd(ff, u, cache, vecprod, vecprod!) - return FunctionOperator(L, u, u; isinplace, outofplace, p, t, islinear = true, - kwargs...) + return FunctionOperator(op, u, fu; isinplace = Val(true), outofplace = Val(true), p, t, + islinear = true, kwargs...) end function HesVec(f, u::AbstractArray, p = nothing, t = nothing; diff --git a/src/differentiation/vecjac_products.jl b/src/differentiation/vecjac_products.jl index 835f9e19..4836d958 100644 --- a/src/differentiation/vecjac_products.jl +++ b/src/differentiation/vecjac_products.jl @@ -35,7 +35,8 @@ end """ VecJac(f, u, [p, t]; fu = nothing, autodiff = AutoFiniteDiff()) -Returns SciMLOperators.FunctionOperator which computes vector-jacobian product `df/du * v`. 
+Returns SciMLOperators.FunctionOperator which computes vector-jacobian product +`(df/du)ᵀ * v`. !!! note @@ -45,11 +46,11 @@ Returns SciMLOperators.FunctionOperator which computes vector-jacobian product ` ```julia L = VecJac(f, u) -L * v # = df/du * v -mul!(w, L, v) # = df/du * v +L * v # = (df/du)ᵀ * v +mul!(w, L, v) # = (df/du)ᵀ * v -L(v, p, t; VJP_input = w) # = df/dw * v -L(x, v, p, t; VJP_input = w) # = df/dw * v +L(v, p, t; VJP_input = w) # = (df/du)ᵀ * v +L(x, v, p, t; VJP_input = w) # = (df/du)ᵀ * v ``` ## Allowed Function Signatures for `f` @@ -72,7 +73,7 @@ f(du, u) # Otherwise """ function VecJac(f, u::AbstractArray, p = nothing, t = nothing; fu = nothing, autodiff = AutoFiniteDiff(), kwargs...) - ff = VecJacFunctionWrapper(f, fu, u, p, t) + ff = JacFunctionWrapper(f, fu, u, p, t) if !__internal_oop(ff) && autodiff isa AutoZygote msg = "Zygote requires an out of place method with signature f(u)." @@ -83,82 +84,12 @@ function VecJac(f, u::AbstractArray, p = nothing, t = nothing; fu = nothing, op = _vecjac(ff, fu, u, autodiff) - # FIXME: FunctionOperator is terribly type unstable. It makes it `::Any` # NOTE: We pass `p`, `t` to Function Operator but we always use the cached version from - # VecJacFunctionWrapper - return FunctionOperator(op, fu, u; p, t, isinplace = true, outofplace = true, + # JacFunctionWrapper + return FunctionOperator(op, fu, u; p, t, isinplace = Val(true), outofplace = Val(true), islinear = true, accepted_kwargs = (:VJP_input,), kwargs...) 
end -mutable struct VecJacFunctionWrapper{iip, oop, mode, F, FU, P, T} <: Function - f::F - fu::FU - p::P - t::T -end - -function SciMLOperators.update_coefficients!(L::VecJacFunctionWrapper{iip, oop, mode}, _, - p, t) where {iip, oop, mode} - mode == 1 && (L.t = t) - mode == 2 && (L.p = p) - return L -end -function SciMLOperators.update_coefficients(L::VecJacFunctionWrapper{iip, oop, mode}, _, p, - t) where {iip, oop, mode} - return VecJacFunctionWrapper{iip, oop, mode, typeof(L.f), typeof(L.fu), typeof(p), - typeof(t)}(L.f, L.fu, p, - t) -end - -__internal_iip(::VecJacFunctionWrapper{iip}) where {iip} = iip -__internal_oop(::VecJacFunctionWrapper{iip, oop}) where {iip, oop} = oop - -(f::VecJacFunctionWrapper{true, oop, 1})(fu, u) where {oop} = f.f(fu, u, f.p, f.t) -(f::VecJacFunctionWrapper{true, oop, 2})(fu, u) where {oop} = f.f(fu, u, f.p) -(f::VecJacFunctionWrapper{true, oop, 3})(fu, u) where {oop} = f.f(fu, u) -(f::VecJacFunctionWrapper{true, true, 1})(u) = f.f(u, f.p, f.t) -(f::VecJacFunctionWrapper{true, true, 2})(u) = f.f(u, f.p) -(f::VecJacFunctionWrapper{true, true, 3})(u) = f.f(u) -(f::VecJacFunctionWrapper{true, false, 1})(u) = (f.f(f.fu, u, f.p, f.t); copy(f.fu)) -(f::VecJacFunctionWrapper{true, false, 2})(u) = (f.f(f.fu, u, f.p); copy(f.fu)) -(f::VecJacFunctionWrapper{true, false, 3})(u) = (f.f(f.fu, u); copy(f.fu)) - -(f::VecJacFunctionWrapper{false, true, 1})(fu, u) = (vec(fu) .= vec(f.f(u, f.p, f.t))) -(f::VecJacFunctionWrapper{false, true, 2})(fu, u) = (vec(fu) .= vec(f.f(u, f.p))) -(f::VecJacFunctionWrapper{false, true, 3})(fu, u) = (vec(fu) .= vec(f.f(u))) -(f::VecJacFunctionWrapper{false, true, 1})(u) = f.f(u, f.p, f.t) -(f::VecJacFunctionWrapper{false, true, 2})(u) = f.f(u, f.p) -(f::VecJacFunctionWrapper{false, true, 3})(u) = f.f(u) - -function VecJacFunctionWrapper(f::F, fu_, u, p, t) where {F} - fu = fu_ === nothing ? 
copy(u) : copy(fu_) - if t !== nothing - iip = static_hasmethod(f, typeof((fu, u, p, t))) - oop = static_hasmethod(f, typeof((u, p, t))) - if !iip && !oop - throw(ArgumentError("`f(u, p, t)` or `f(fu, u, p, t)` not defined for `f`")) - end - return VecJacFunctionWrapper{iip, oop, 1, F, typeof(fu), typeof(p), typeof(t)}(f, - fu, p, t) - elseif p !== nothing - iip = static_hasmethod(f, typeof((fu, u, p))) - oop = static_hasmethod(f, typeof((u, p))) - if !iip && !oop - throw(ArgumentError("`f(u, p)` or `f(fu, u, p)` not defined for `f`")) - end - return VecJacFunctionWrapper{iip, oop, 2, F, typeof(fu), typeof(p), typeof(t)}(f, - fu, p, t) - else - iip = static_hasmethod(f, typeof((fu, u))) - oop = static_hasmethod(f, typeof((u,))) - if !iip && !oop - throw(ArgumentError("`f(u)` or `f(fu, u)` not defined for `f`")) - end - return VecJacFunctionWrapper{iip, oop, 3, F, typeof(fu), typeof(p), typeof(t)}(f, - fu, p, t) - end -end - function _vecjac(f::F, fu, u, autodiff::AutoFiniteDiff) where {F} cache = (similar(fu), similar(fu)) pullback = nothing From 0a5398c8f0f0cdeea849f2aae1841a3719dd5acc Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 13 Nov 2023 11:40:10 -0500 Subject: [PATCH 2/3] Ensure non-breaking change --- src/differentiation/common.jl | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/src/differentiation/common.jl b/src/differentiation/common.jl index 4a38178d..2117986f 100644 --- a/src/differentiation/common.jl +++ b/src/differentiation/common.jl @@ -39,30 +39,33 @@ __internal_oop(::JacFunctionWrapper{iip, oop}) where {iip, oop} = oop (f::JacFunctionWrapper{false, true, 3})(u) = f.f(u) function JacFunctionWrapper(f::F, fu_, u, p, t) where {F} + # The warning instead of error ensures a non-breaking change for users relying on an + # undefined / undocumented feature fu = fu_ === nothing ? 
copy(u) : copy(fu_) if t !== nothing iip = static_hasmethod(f, typeof((fu, u, p, t))) oop = static_hasmethod(f, typeof((u, p, t))) if !iip && !oop - throw(ArgumentError("`f(u, p, t)` or `f(fu, u, p, t)` not defined for `f`")) + @warn """`p` and `t` provided but `f(u, p, t)` or `f(fu, u, p, t)` not defined + for `f`! Will fallback to `f(u)` or `f(fu, u)`.""" maxlog=1 + else + return JacFunctionWrapper{iip, oop, 1, F, typeof(fu), typeof(p), typeof(t)}(f, + fu, p, t) end - return JacFunctionWrapper{iip, oop, 1, F, typeof(fu), typeof(p), typeof(t)}(f, - fu, p, t) elseif p !== nothing iip = static_hasmethod(f, typeof((fu, u, p))) oop = static_hasmethod(f, typeof((u, p))) if !iip && !oop - throw(ArgumentError("`f(u, p)` or `f(fu, u, p)` not defined for `f`")) + @warn """`p` provided but `f(u, p)` or `f(fu, u, p)` not defined for `f`! Will + fallback to `f(u)` or `f(fu, u)`.""" maxlog=1 + else + return JacFunctionWrapper{iip, oop, 2, F, typeof(fu), typeof(p), typeof(t)}(f, + fu, p, t) end - return JacFunctionWrapper{iip, oop, 2, F, typeof(fu), typeof(p), typeof(t)}(f, - fu, p, t) - else - iip = static_hasmethod(f, typeof((fu, u))) - oop = static_hasmethod(f, typeof((u,))) - if !iip && !oop - throw(ArgumentError("`f(u)` or `f(fu, u)` not defined for `f`")) - end - return JacFunctionWrapper{iip, oop, 3, F, typeof(fu), typeof(p), typeof(t)}(f, - fu, p, t) end + iip = static_hasmethod(f, typeof((fu, u))) + oop = static_hasmethod(f, typeof((u,))) + !iip && !oop && throw(ArgumentError("`f(u)` or `f(fu, u)` not defined for `f`")) + return JacFunctionWrapper{iip, oop, 3, F, typeof(fu), typeof(p), typeof(t)}(f, + fu, p, t) end From 84c0dc308f7609ce0a780405ea3e0abe51d40712 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 13 Nov 2023 12:54:07 -0500 Subject: [PATCH 3/3] Fix cache sizes --- src/differentiation/jaches_products.jl | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/differentiation/jaches_products.jl 
b/src/differentiation/jaches_products.jl index 57e9212c..2c4b0d79 100644 --- a/src/differentiation/jaches_products.jl +++ b/src/differentiation/jaches_products.jl @@ -264,8 +264,11 @@ f(du, u) # Otherwise """ function JacVec(f, u::AbstractArray, p = nothing, t = nothing; fu = nothing, autodiff = AutoForwardDiff(), tag = DeivVecTag(), kwargs...) + ff = JacFunctionWrapper(f, fu, u, p, t) + fu === nothing && (fu = __internal_oop(ff) ? ff(u) : u) + cache, vecprod, vecprod! = if autodiff isa AutoFiniteDiff - cache1 = similar(u) + cache1 = similar(fu) cache2 = similar(u) (cache1, cache2), num_jacvec, num_jacvec! @@ -273,17 +276,15 @@ function JacVec(f, u::AbstractArray, p = nothing, t = nothing; fu = nothing, cache1 = Dual{ typeof(ForwardDiff.Tag(tag, eltype(u))), eltype(u), 1, }.(u, ForwardDiff.Partials.(tuple.(u))) - - cache2 = copy(cache1) + cache2 = Dual{ + typeof(ForwardDiff.Tag(tag, eltype(fu))), eltype(fu), 1, + }.(fu, ForwardDiff.Partials.(tuple.(fu))) (cache1, cache2), auto_jacvec, auto_jacvec! else error("Set autodiff to either AutoForwardDiff(), or AutoFiniteDiff()") end - ff = JacFunctionWrapper(f, fu, u, p, t) - fu === nothing && (fu = __internal_oop(ff) ? ff(u) : u) - op = FwdModeAutoDiffVecProd(ff, u, cache, vecprod, vecprod!) return FunctionOperator(op, u, fu; isinplace = Val(true), outofplace = Val(true), p, t,