Skip to content

Commit

Permalink
Rework VecJac Operator
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 1, 2023
1 parent 7d23bec commit b764630
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 107 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SparseDiffTools"
uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
authors = ["Pankaj Mishra <[email protected]>", "Chris Rackauckas <[email protected]>"]
version = "2.9.2"
version = "2.10.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
50 changes: 20 additions & 30 deletions ext/SparseDiffToolsZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,21 @@ import SparseDiffTools: __f̂, __jacobian!, __gradient, __gradient!
import ADTypes: AutoZygote, AutoSparseZygote

## Satisfying High-Level Interface for Sparse Jacobians
function __gradient(::Union{AutoSparseZygote, AutoZygote}, f, x, cols)
function __gradient(::Union{AutoSparseZygote, AutoZygote}, f::F, x, cols) where {F}
_, ∂x, _ = Zygote.gradient(__f̂, f, x, cols)
return vec(∂x)
end

function __gradient!(::Union{AutoSparseZygote, AutoZygote}, f!, fx, x, cols)
function __gradient!(::Union{AutoSparseZygote, AutoZygote}, f!::F, fx, x, cols) where {F}
return error("Zygote.jl cannot differentiate in-place (mutating) functions.")
end

# Zygote doesn't provide a way to accumulate directly into `J`. So we modify the code from
# https://github.com/FluxML/Zygote.jl/blob/82c7a000bae7fb0999275e62cc53ddb61aed94c7/src/lib/grad.jl#L140-L157C4
import Zygote: _jvec, _eyelike, _gradcopy!

@views function __jacobian!(J::AbstractMatrix, ::Union{AutoSparseZygote, AutoZygote}, f, x)
@views function __jacobian!(J::AbstractMatrix, ::Union{AutoSparseZygote, AutoZygote}, f::F,
x) where {F}
y, back = Zygote.pullback(_jvec f, x)
δ = _eyelike(y)
for k in LinearIndices(y)
Expand All @@ -36,13 +37,13 @@ import Zygote: _jvec, _eyelike, _gradcopy!
return J
end

function __jacobian!(J, ::Union{AutoSparseZygote, AutoZygote}, f!, fx, x)
function __jacobian!(_, ::Union{AutoSparseZygote, AutoZygote}, f!::F, fx, x) where {F}
return error("Zygote.jl cannot differentiate in-place (mutating) functions.")
end

### Jac, Hes products

function numback_hesvec!(dy, f, x, v, cache1 = similar(v), cache2 = similar(v))
function numback_hesvec!(dy, f::F, x, v, cache1 = similar(v), cache2 = similar(v)) where {F}
g = let f = f
(dx, x) -> dx .= first(Zygote.gradient(f, x))
end
Expand All @@ -57,15 +58,14 @@ function numback_hesvec!(dy, f, x, v, cache1 = similar(v), cache2 = similar(v))
@. dy = (cache1 - cache2) / (2ϵ)
end

function numback_hesvec(f, x, v)
g = x -> first(Zygote.gradient(f, x))
function numback_hesvec(f::F, x, v) where {F}
T = eltype(x)
# Should it be min? max? mean?
ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(x)))
x += ϵ * v
gxp = g(x)
gxp = first(Zygote.gradient(f, x))
x -= 2ϵ * v
gxm = g(x)
gxm = first(Zygote.gradient(f, x))
(gxp - gxm) / (2ϵ)
end

Expand Down Expand Up @@ -94,61 +94,51 @@ end
## VecJac products

# VJP methods
function auto_vecjac!(du, f, x, v)
function auto_vecjac!(du, f::F, x, v) where {F}
!static_hasmethod(f, typeof((x,))) &&
error("For inplace function use autodiff = AutoFiniteDiff()")
du .= reshape(SparseDiffTools.auto_vecjac(f, x, v), size(du))
end

function auto_vecjac(f, x, v)
function auto_vecjac(f::F, x, v) where {F}
y, back = Zygote.pullback(f, x)
return vec(back(reshape(v, size(y)))[1])
return vec(only(back(reshape(v, size(y)))))
end

# overload operator interface
function SparseDiffTools._vecjac(f, u, autodiff::AutoZygote)
cache = ()
function SparseDiffTools._vecjac(f::F, _, u, autodiff::AutoZygote) where {F}
!static_hasmethod(f, typeof((u,))) &&
error("For inplace function use autodiff = AutoFiniteDiff()")
pullback = Zygote.pullback(f, u)

return AutoDiffVJP(f, u, cache, autodiff, pullback)
return AutoDiffVJP(f, u, (), autodiff, pullback)
end

function update_coefficients(L::AutoDiffVJP{<:AutoZygote}, u, p, t; VJP_input = nothing)
VJP_input !== nothing && (@set! L.u = VJP_input)

@set! L.f = update_coefficients(L.f, L.u, p, t)
@set! L.pullback = Zygote.pullback(L.f, L.u)
return L
end

function update_coefficients!(L::AutoDiffVJP{<:AutoZygote}, u, p, t; VJP_input = nothing)
VJP_input !== nothing && copy!(L.u, VJP_input)

update_coefficients!(L.f, L.u, p, t)
L.pullback = Zygote.pullback(L.f, L.u)

return L
end

# Interpret the call as df/du' * v
function (L::AutoDiffVJP{<:AutoZygote})(v, p, t; VJP_input = nothing)
# ignore VJP_input as pullback was computed in update_coefficients(...)
y, back = L.pullback
V = reshape(v, size(y))

return vec(first(back(V)))
return vec(only(back(reshape(v, size(y)))))
end

# prefer non in-place method
function (L::AutoDiffVJP{<:AutoZygote, IIP, true})(dv, v, p, t;
VJP_input = nothing) where {IIP}
function (L::AutoDiffVJP{<:AutoZygote})(dv, v, p, t; VJP_input = nothing)
# ignore VJP_input as pullback was computed in update_coefficients!(...)

_dv = L(v, p, t; VJP_input = VJP_input)
_dv = L(v, p, t; VJP_input)
copy!(dv, _dv)
end

function (L::AutoDiffVJP{<:AutoZygote, true, false})(args...; kwargs...)
error("Zygote requires an out of place method with signature f(u).")
end

end # module
Loading

0 comments on commit b764630

Please sign in to comment.