Adding A New AD #1105

Open
willtebbutt opened this issue Sep 5, 2024 · 4 comments

@willtebbutt

Question❓

I'm interested in getting (the soon-to-be-renamed) Tapir.jl to work nicely with SciMLSensitivity.jl, as I understand this to be the way to get it to play nicely with the SciML ecosystem more broadly (please correct me if I'm wrong on this point!).

I can't see from the docs, or from a quick dive into some of the internals of this package, what the right way to go about this is. It looks like there are a few functions that I need to add methods for, but I'm not at all sure.

Could someone point me in the right direction?

@ChrisRackauckas
Member

You'd just define a `_vecjacobian!` overload: https://github.com/SciML/SciMLSensitivity.jl/blob/v7.66.2/src/derivative_wrappers.jl#L656-L752

along with its caches: https://github.com/SciML/SciMLSensitivity.jl/blob/v7.66.2/src/adjoint_common.jl#L212-L215

That code builds whatever caches are required to make `_vecjacobian!` non-allocating.
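
Roughly speaking, given a cotangent `λ`, such an overload has to produce `λᵀ ∂f/∂u` and `λᵀ ∂f/∂p` for the RHS `f!(du, u, p, t)`, ideally writing into preallocated buffers so that repeated calls do not allocate. The minimal sketch below computes both products by hand for a toy RHS, with no AD tool involved. `ToyVJPCache` and `toy_vecjacobian!` are hypothetical names chosen for illustration, not SciMLSensitivity's actual types or signatures.

```julia
# Toy in-place RHS: du[i] = p[i] * u[i]^2, so by hand
#   (λᵀ ∂f/∂u)[i] = 2 p[i] u[i] λ[i]   and   (λᵀ ∂f/∂p)[i] = u[i]^2 λ[i].
function f!(du, u, p, t)
    @. du = p * u^2
    return nothing
end

# Hypothetical preallocated buffers, built once and reused on every call.
struct ToyVJPCache{T}
    out::Vector{T}   # forward output of f! (plays the role of tmp3 above)
    dout::Vector{T}  # output cotangent seeded with λ (plays the role of tmp4)
end
ToyVJPCache(u::Vector{T}) where {T} = ToyVJPCache{T}(similar(u), similar(u))

# Hypothetical VJP: writes λᵀ ∂f/∂u into dλ and λᵀ ∂f/∂p into dp.
function toy_vecjacobian!(dλ, dp, cache::ToyVJPCache, λ, u, p, t)
    # Forward sweep, kept to mirror what an AD tool runs first;
    # this hand-written pullback does not actually need its value.
    f!(cache.out, u, p, t)
    cache.dout .= λ  # seed the output cotangent
    @inbounds for i in eachindex(u)
        dλ[i] = cache.dout[i] * 2 * p[i] * u[i]
        dp[i] = cache.dout[i] * u[i]^2
    end
    return nothing
end

u = [1.0, 2.0]; p = [3.0, 4.0]; λ = [1.0, 0.5]
cache = ToyVJPCache(u)
dλ = similar(u); dp = similar(p)
toy_vecjacobian!(dλ, dp, cache, λ, u, p, 0.0)
# dλ == [6.0, 8.0], dp == [1.0, 2.0]
```

In a real backend the loop body is what the AD tool generates; the cache is what lets the same buffers be reused across the many VJP calls made during an adjoint solve.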

You would probably also want to opt it into callback tracking (https://github.com/SciML/SciMLSensitivity.jl/blob/v7.66.2/src/callback_tracking.jl), since Tapir supports mutation.

Similarly, `GaussAdjoint` needs a special cache:

```julia
elseif sensealg.autojacvec isa EnzymeVJP
    pf = let f = unwrappedf
        if DiffEqBase.isinplace(prob)
            function (out, u, _p, t)
                f(out, u, _p, t)
                nothing
            end
        else # !DiffEqBase.isinplace(prob)
            function (out, u, _p, t)
                out .= f(u, _p, t)
                nothing
            end
        end
    end
    paramjac_config = zero(y), zero(y), Enzyme.make_zero(pf)
    pJ = nothing
```
and overload:
```julia
elseif sensealg.autojacvec isa EnzymeVJP
    tmp3, tmp4, tmp6 = paramjac_config
    vtmp4 = vec(tmp4)
    vtmp4 .= λ
    out .= 0
    Enzyme.autodiff(
        Enzyme.Reverse, Enzyme.Duplicated(pf, tmp6), Enzyme.Const,
        Enzyme.Duplicated(tmp3, tmp4),
        Enzyme.Const(y), Enzyme.Duplicated(p, out), Enzyme.Const(t))
```
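
The `pf` closure in the cache code presents both in-place (`f(du, u, p, t)`) and out-of-place (`f(u, p, t)`) right-hand sides through a single in-place signature, so one VJP code path can handle both. A self-contained sketch of that pattern, assuming a hypothetical helper `make_inplace` and an explicit `iip` flag standing in for `DiffEqBase.isinplace(prob)`:

```julia
# Hypothetical helper: wrap an RHS so it is always called as (out, u, p, t).
# `iip` stands in for DiffEqBase.isinplace(prob).
function make_inplace(f, iip::Bool)
    if iip
        return function (out, u, p, t)
            f(out, u, p, t)      # already in place: forward the call
            return nothing
        end
    else
        return function (out, u, p, t)
            out .= f(u, p, t)    # out of place: copy the result into out
            return nothing
        end
    end
end

f_oop(u, p, t) = p .* u
f_iip!(du, u, p, t) = (du .= p .* u; nothing)

u = [1.0, 2.0]; p = [3.0, 4.0]
out1 = zeros(2); out2 = zeros(2)
make_inplace(f_oop, false)(out1, u, p, 0.0)
make_inplace(f_iip!, true)(out2, u, p, 0.0)
# out1 and out2 both hold [3.0, 8.0]
```

Returning `nothing` from the wrapper matches the original closures, which is why the `Enzyme.autodiff` call can declare the return activity as `Enzyme.Const`.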

and `QuadratureAdjoint` needs a fairly similar one:

```julia
elseif sensealg.autojacvec isa EnzymeVJP
    pf = let f = unwrappedf
        if DiffEqBase.isinplace(prob) && prob isa RODEProblem
            function (out, u, _p, t, W)
                f(out, u, _p, t, W)
                nothing
            end
        elseif DiffEqBase.isinplace(prob)
            function (out, u, _p, t)
                f(out, u, _p, t)
                nothing
            end
        elseif !DiffEqBase.isinplace(prob) && prob isa RODEProblem
            function (out, u, _p, t, W)
                out .= f(u, _p, t, W)
                nothing
            end
        else # !DiffEqBase.isinplace(prob)
            function (out, u, _p, t)
                out .= f(u, _p, t)
                nothing
            end
        end
    end
    paramjac_config = zero(y), zero(y), Enzyme.make_zero(pf)
    pJ = nothing
```
and overload:
```julia
elseif sensealg.autojacvec isa EnzymeVJP
    tmp3, tmp4, tmp6 = paramjac_config
    tmp4 .= λ
    out .= 0
    Enzyme.autodiff(
        Enzyme.Reverse, Enzyme.Duplicated(pf, tmp6), Enzyme.Const,
        Enzyme.Duplicated(tmp3, tmp4),
        Enzyme.Const(y), Enzyme.Duplicated(p, out), Enzyme.Const(t))
```
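
Both overloads reset `out` to zero before calling `Enzyme.autodiff`. Enzyme accumulates (`+=`) gradients into the shadow of a `Duplicated` argument rather than overwriting it, so a buffer that is reused across calls has to be cleared first; a Tapir-based overload would need to respect whatever convention that tool uses. A toy, AD-free illustration of the accumulate-into-the-shadow behaviour, using a hypothetical pullback `accumulate_vjp_p!` for `du = p .* u.^2`:

```julia
# Hypothetical toy pullback with respect to p for du = p .* u.^2.
# Like a Duplicated shadow in Enzyme, it ADDS into dp instead of overwriting it.
function accumulate_vjp_p!(dp, λ, u)
    @. dp += λ * u^2
    return nothing
end

u = [1.0, 2.0]
λ = [1.0, 0.5]
dp = zeros(2)

accumulate_vjp_p!(dp, λ, u)   # dp == [1.0, 2.0]
accumulate_vjp_p!(dp, λ, u)   # stale result not cleared: dp == [2.0, 4.0]

dp .= 0                       # the analogue of `out .= 0`
accumulate_vjp_p!(dp, λ, u)   # dp == [1.0, 2.0] again
```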

@willtebbutt
Author

Lovely, thanks for this. Is there a recommended way to test that my overloads are implemented correctly?

@ChrisRackauckas
Member

You'll see that this file runs through lots of combinations:

https://github.com/SciML/SciMLSensitivity.jl/blob/master/test/adjoint.jl#L46-L196

If you add a VJP option for Tapir, you should be able to just add `autojacvec = AutoTapir()` along with a test that its result matches the others. There are about five sets of these.

Because this grew organically over time, it's a bit verbose, but we'll make that into a loop some day 😅
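
A loop of that kind might look like the following minimal sketch: each "backend" is checked against a single reference gradient instead of getting its own hand-written test block. The backends here are hypothetical stand-ins (a hand-derived gradient and central finite differences) on a toy loss `g`; in the actual test file, each entry would instead be a sensitivity-algorithm configuration such as one using `autojacvec = AutoTapir()`.

```julia
using Test

# Toy scalar loss standing in for the ODE-solve loss in the real tests.
g(p) = sum(abs2, p) + p[1] * p[2]

# Stand-in backend 1: hand-derived gradient.
grad_exact(p) = [2 * p[1] + p[2], 2 * p[2] + p[1]]

# Stand-in backend 2: central finite differences.
function grad_fd(p; h = 1e-6)
    gr = similar(p)
    for i in eachindex(p)
        e = zeros(length(p)); e[i] = h
        gr[i] = (g(p .+ e) - g(p .- e)) / (2h)
    end
    return gr
end

backends = ["exact" => grad_exact, "central differences" => grad_fd]

p0 = [1.0, -2.0]
reference = grad_exact(p0)

@testset "gradients agree" begin
    for (name, gradfn) in backends
        @testset "$name" begin
            @test gradfn(p0) ≈ reference rtol=1e-6
        end
    end
end
```

Adding a new backend then means adding one entry to `backends` rather than copying a block of assertions.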

@willtebbutt
Author

Haha excellent -- thanks for the info!
