Skip to content

Commit

Permalink
Merge pull request #274 from avik-pal/ap/rework_jacvec
Browse files Browse the repository at this point in the history
Fix type stability and allow JacVec to be non-square jacobians
  • Loading branch information
ChrisRackauckas authored Nov 15, 2023
2 parents 9a713c6 + 84c0dc3 commit 9a42b50
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 94 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SparseDiffTools"
uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
authors = ["Pankaj Mishra <[email protected]>", "Chris Rackauckas <[email protected]>"]
version = "2.11.0"
version = "2.12.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -49,7 +49,7 @@ LinearAlgebra = "1.6"
PackageExtensionCompat = "1"
Random = "1.6"
Reexport = "1"
SciMLOperators = "0.2.11, 0.3"
SciMLOperators = "0.3.7"
Setfield = "1"
SparseArrays = "1.6"
StaticArrayInterface = "1.3"
Expand Down
1 change: 1 addition & 0 deletions src/SparseDiffTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ include("coloring/greedy_star1_coloring.jl")
include("coloring/greedy_star2_coloring.jl")
include("coloring/matrix2graph.jl")

include("differentiation/common.jl")
include("differentiation/compute_jacobian_ad.jl")
include("differentiation/compute_hessian_ad.jl")
include("differentiation/jaches_products.jl")
Expand Down
71 changes: 71 additions & 0 deletions src/differentiation/common.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
mutable struct JacFunctionWrapper{iip, oop, mode, F, FU, P, T} <: Function
f::F
fu::FU
p::P
t::T
end

function SciMLOperators.update_coefficients!(L::JacFunctionWrapper{iip, oop, mode}, _,
p, t) where {iip, oop, mode}
mode == 1 && (L.t = t)
mode == 2 && (L.p = p)
return L
end
function SciMLOperators.update_coefficients(L::JacFunctionWrapper{iip, oop, mode}, _, p,
t) where {iip, oop, mode}
return JacFunctionWrapper{iip, oop, mode, typeof(L.f), typeof(L.fu), typeof(p),
typeof(t)}(L.f, L.fu, p,
t)
end

__internal_iip(::JacFunctionWrapper{iip}) where {iip} = iip
__internal_oop(::JacFunctionWrapper{iip, oop}) where {iip, oop} = oop

(f::JacFunctionWrapper{true, oop, 1})(fu, u) where {oop} = f.f(fu, u, f.p, f.t)
(f::JacFunctionWrapper{true, oop, 2})(fu, u) where {oop} = f.f(fu, u, f.p)
(f::JacFunctionWrapper{true, oop, 3})(fu, u) where {oop} = f.f(fu, u)
(f::JacFunctionWrapper{true, true, 1})(u) = f.f(u, f.p, f.t)
(f::JacFunctionWrapper{true, true, 2})(u) = f.f(u, f.p)
(f::JacFunctionWrapper{true, true, 3})(u) = f.f(u)
(f::JacFunctionWrapper{true, false, 1})(u) = (f.f(f.fu, u, f.p, f.t); copy(f.fu))
(f::JacFunctionWrapper{true, false, 2})(u) = (f.f(f.fu, u, f.p); copy(f.fu))
(f::JacFunctionWrapper{true, false, 3})(u) = (f.f(f.fu, u); copy(f.fu))

(f::JacFunctionWrapper{false, true, 1})(fu, u) = (vec(fu) .= vec(f.f(u, f.p, f.t)))
(f::JacFunctionWrapper{false, true, 2})(fu, u) = (vec(fu) .= vec(f.f(u, f.p)))
(f::JacFunctionWrapper{false, true, 3})(fu, u) = (vec(fu) .= vec(f.f(u)))
(f::JacFunctionWrapper{false, true, 1})(u) = f.f(u, f.p, f.t)
(f::JacFunctionWrapper{false, true, 2})(u) = f.f(u, f.p)
(f::JacFunctionWrapper{false, true, 3})(u) = f.f(u)

function JacFunctionWrapper(f::F, fu_, u, p, t) where {F}
# The warning instead of error ensures a non-breaking change for users relying on an
# undefined / undocumented feature
fu = fu_ === nothing ? copy(u) : copy(fu_)
if t !== nothing
iip = static_hasmethod(f, typeof((fu, u, p, t)))
oop = static_hasmethod(f, typeof((u, p, t)))
if !iip && !oop
@warn """`p` and `t` provided but `f(u, p, t)` or `f(fu, u, p, t)` not defined
for `f`! Will fallback to `f(u)` or `f(fu, u)`.""" maxlog=1
else
return JacFunctionWrapper{iip, oop, 1, F, typeof(fu), typeof(p), typeof(t)}(f,
fu, p, t)
end
elseif p !== nothing
iip = static_hasmethod(f, typeof((fu, u, p)))
oop = static_hasmethod(f, typeof((u, p)))
if !iip && !oop
@warn """`p` provided but `f(u, p)` or `f(fu, u, p)` not defined for `f`! Will
fallback to `f(u)` or `f(fu, u)`.""" maxlog=1
else
return JacFunctionWrapper{iip, oop, 2, F, typeof(fu), typeof(p), typeof(t)}(f,
fu, p, t)
end
end
iip = static_hasmethod(f, typeof((fu, u)))
oop = static_hasmethod(f, typeof((u,)))
!iip && !oop && throw(ArgumentError("`f(u)` or `f(fu, u)` not defined for `f`"))
return JacFunctionWrapper{iip, oop, 3, F, typeof(fu), typeof(p), typeof(t)}(f,
fu, p, t)
end
64 changes: 50 additions & 14 deletions src/differentiation/jaches_products.jl
Original file line number Diff line number Diff line change
Expand Up @@ -223,36 +223,72 @@ function Base.resize!(L::FwdModeAutoDiffVecProd, n::Integer)
end
end

function JacVec(f, u::AbstractArray, p = nothing, t = nothing;
"""
JacVec(f, u, [p, t]; fu = nothing, autodiff = AutoForwardDiff(), tag = DeivVecTag(),
kwargs...)
Returns SciMLOperators.FunctionOperator which computes jacobian-vector product `df/du * v`.
!!! note
For non-square jacobians with inplace `f`, `fu` must be specified, else `JacVec` assumes
a square jacobian.
```julia
L = JacVec(f, u)
L * v # = df/du * v
mul!(w, L, v) # = df/du * v
L(v, p, t) # = df/dw * v
L(x, v, p, t) # = df/dw * v
```
## Allowed Function Signatures for `f`
For Out of Place Functions:
```julia
f(u, p, t) # t !== nothing
f(u, p) # p !== nothing
f(u) # Otherwise
```
For In Place Functions:
```julia
f(du, u, p, t) # t !== nothing
f(du, u, p) # p !== nothing
f(du, u) # Otherwise
```
"""
function JacVec(f, u::AbstractArray, p = nothing, t = nothing; fu = nothing,
autodiff = AutoForwardDiff(), tag = DeivVecTag(), kwargs...)
ff = JacFunctionWrapper(f, fu, u, p, t)
fu === nothing && (fu = __internal_oop(ff) ? ff(u) : u)

cache, vecprod, vecprod! = if autodiff isa AutoFiniteDiff
cache1 = similar(u)
cache1 = similar(fu)
cache2 = similar(u)

(cache1, cache2), num_jacvec, num_jacvec!
elseif autodiff isa AutoForwardDiff
cache1 = Dual{
typeof(ForwardDiff.Tag(tag, eltype(u))), eltype(u), 1,
}.(u, ForwardDiff.Partials.(tuple.(u)))

cache2 = copy(cache1)
cache2 = Dual{
typeof(ForwardDiff.Tag(tag, eltype(fu))), eltype(fu), 1,
}.(fu, ForwardDiff.Partials.(tuple.(fu)))

(cache1, cache2), auto_jacvec, auto_jacvec!
else
error("Set autodiff to either AutoForwardDiff(), or AutoFiniteDiff()")
end

outofplace = static_hasmethod(f, typeof((u,)))
isinplace = static_hasmethod(f, typeof((u, u)))
op = FwdModeAutoDiffVecProd(ff, u, cache, vecprod, vecprod!)

if !(isinplace) & !(outofplace)
error("$f must have signature f(u), or f(du, u).")
end

L = FwdModeAutoDiffVecProd(f, u, cache, vecprod, vecprod!)

return FunctionOperator(L, u, u; isinplace, outofplace, p, t, islinear = true,
kwargs...)
return FunctionOperator(op, u, fu; isinplace = Val(true), outofplace = Val(true), p, t,
islinear = true, kwargs...)
end

function HesVec(f, u::AbstractArray, p = nothing, t = nothing;
Expand Down
87 changes: 9 additions & 78 deletions src/differentiation/vecjac_products.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ end
"""
VecJac(f, u, [p, t]; fu = nothing, autodiff = AutoFiniteDiff())
Returns SciMLOperators.FunctionOperator which computes vector-jacobian product `df/du * v`.
Returns SciMLOperators.FunctionOperator which computes vector-jacobian product
`(df/du)ᵀ * v`.
!!! note
Expand All @@ -45,11 +46,11 @@ Returns SciMLOperators.FunctionOperator which computes vector-jacobian product `
```julia
L = VecJac(f, u)
L * v # = df/du * v
mul!(w, L, v) # = df/du * v
L * v # = (df/du)ᵀ * v
mul!(w, L, v) # = (df/du)ᵀ * v
L(v, p, t; VJP_input = w) # = df/dw * v
L(x, v, p, t; VJP_input = w) # = df/dw * v
L(v, p, t; VJP_input = w) # = (df/du)ᵀ * v
L(x, v, p, t; VJP_input = w) # = (df/du)ᵀ * v
```
## Allowed Function Signatures for `f`
Expand All @@ -72,7 +73,7 @@ f(du, u) # Otherwise
"""
function VecJac(f, u::AbstractArray, p = nothing, t = nothing; fu = nothing,
autodiff = AutoFiniteDiff(), kwargs...)
ff = VecJacFunctionWrapper(f, fu, u, p, t)
ff = JacFunctionWrapper(f, fu, u, p, t)

if !__internal_oop(ff) && autodiff isa AutoZygote
msg = "Zygote requires an out of place method with signature f(u)."
Expand All @@ -83,82 +84,12 @@ function VecJac(f, u::AbstractArray, p = nothing, t = nothing; fu = nothing,

op = _vecjac(ff, fu, u, autodiff)

# FIXME: FunctionOperator is terribly type unstable. It makes it `::Any`
# NOTE: We pass `p`, `t` to Function Operator but we always use the cached version from
# VecJacFunctionWrapper
return FunctionOperator(op, fu, u; p, t, isinplace = true, outofplace = true,
# JacFunctionWrapper
return FunctionOperator(op, fu, u; p, t, isinplace = Val(true), outofplace = Val(true),
islinear = true, accepted_kwargs = (:VJP_input,), kwargs...)
end

mutable struct VecJacFunctionWrapper{iip, oop, mode, F, FU, P, T} <: Function
f::F
fu::FU
p::P
t::T
end

function SciMLOperators.update_coefficients!(L::VecJacFunctionWrapper{iip, oop, mode}, _,
p, t) where {iip, oop, mode}
mode == 1 && (L.t = t)
mode == 2 && (L.p = p)
return L
end
function SciMLOperators.update_coefficients(L::VecJacFunctionWrapper{iip, oop, mode}, _, p,
t) where {iip, oop, mode}
return VecJacFunctionWrapper{iip, oop, mode, typeof(L.f), typeof(L.fu), typeof(p),
typeof(t)}(L.f, L.fu, p,
t)
end

__internal_iip(::VecJacFunctionWrapper{iip}) where {iip} = iip
__internal_oop(::VecJacFunctionWrapper{iip, oop}) where {iip, oop} = oop

(f::VecJacFunctionWrapper{true, oop, 1})(fu, u) where {oop} = f.f(fu, u, f.p, f.t)
(f::VecJacFunctionWrapper{true, oop, 2})(fu, u) where {oop} = f.f(fu, u, f.p)
(f::VecJacFunctionWrapper{true, oop, 3})(fu, u) where {oop} = f.f(fu, u)
(f::VecJacFunctionWrapper{true, true, 1})(u) = f.f(u, f.p, f.t)
(f::VecJacFunctionWrapper{true, true, 2})(u) = f.f(u, f.p)
(f::VecJacFunctionWrapper{true, true, 3})(u) = f.f(u)
(f::VecJacFunctionWrapper{true, false, 1})(u) = (f.f(f.fu, u, f.p, f.t); copy(f.fu))
(f::VecJacFunctionWrapper{true, false, 2})(u) = (f.f(f.fu, u, f.p); copy(f.fu))
(f::VecJacFunctionWrapper{true, false, 3})(u) = (f.f(f.fu, u); copy(f.fu))

(f::VecJacFunctionWrapper{false, true, 1})(fu, u) = (vec(fu) .= vec(f.f(u, f.p, f.t)))
(f::VecJacFunctionWrapper{false, true, 2})(fu, u) = (vec(fu) .= vec(f.f(u, f.p)))
(f::VecJacFunctionWrapper{false, true, 3})(fu, u) = (vec(fu) .= vec(f.f(u)))
(f::VecJacFunctionWrapper{false, true, 1})(u) = f.f(u, f.p, f.t)
(f::VecJacFunctionWrapper{false, true, 2})(u) = f.f(u, f.p)
(f::VecJacFunctionWrapper{false, true, 3})(u) = f.f(u)

function VecJacFunctionWrapper(f::F, fu_, u, p, t) where {F}
fu = fu_ === nothing ? copy(u) : copy(fu_)
if t !== nothing
iip = static_hasmethod(f, typeof((fu, u, p, t)))
oop = static_hasmethod(f, typeof((u, p, t)))
if !iip && !oop
throw(ArgumentError("`f(u, p, t)` or `f(fu, u, p, t)` not defined for `f`"))
end
return VecJacFunctionWrapper{iip, oop, 1, F, typeof(fu), typeof(p), typeof(t)}(f,
fu, p, t)
elseif p !== nothing
iip = static_hasmethod(f, typeof((fu, u, p)))
oop = static_hasmethod(f, typeof((u, p)))
if !iip && !oop
throw(ArgumentError("`f(u, p)` or `f(fu, u, p)` not defined for `f`"))
end
return VecJacFunctionWrapper{iip, oop, 2, F, typeof(fu), typeof(p), typeof(t)}(f,
fu, p, t)
else
iip = static_hasmethod(f, typeof((fu, u)))
oop = static_hasmethod(f, typeof((u,)))
if !iip && !oop
throw(ArgumentError("`f(u)` or `f(fu, u)` not defined for `f`"))
end
return VecJacFunctionWrapper{iip, oop, 3, F, typeof(fu), typeof(p), typeof(t)}(f,
fu, p, t)
end
end

function _vecjac(f::F, fu, u, autodiff::AutoFiniteDiff) where {F}
cache = (similar(fu), similar(fu))
pullback = nothing
Expand Down

0 comments on commit 9a42b50

Please sign in to comment.