Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: adjoints through observable functions #689

Merged
merged 26 commits into from
May 25, 2024

Conversation

DhairyaLGandhi
Copy link
Member

@DhairyaLGandhi DhairyaLGandhi commented May 6, 2024

Checklist

  • Appropriate tests were added
  • Any code changes were done in a way that does not break public API
  • All documentation related to code changes were updated
  • The new code follows the
    contributor guidelines, in particular the SciML Style Guide and
    COLPRAC.
  • Any new documentation only uses public API

Additional context

Currently, ADing through observables errors, however this allows us to AD through the observable function via symbolic indexing and accumulate and return grads against sol

julia> gs3 = gradient(sol) do sol
    sum(sol[sys.w])
end
((u = [[0.0, 1.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0]    [0.0, 1.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0]], u_analytic = nothing, errors = nothing, t = nothing, k = nothing, prob = (f = nothing, u0 = nothing, tspan = nothing, p = ([0.0, 2990.0, 0.0],), kwargs = nothing, problem_type = nothing), alg = nothing, interp = nothing, dense = nothing, tslocation = nothing, stats = nothing, alg_choice = nothing, retcode = nothing, resid = nothing, original = nothing),)

This needs handling as part of when the observable symbol is in a collection (vector/ tuple/ ...), and also for various ADs like ReverseDiff and Enzyme.

Add any other context about the problem here.

Ideally, this would be handled by removing all the adjoints related to getindex and let AD do the heavy lifting for us. But this is faster to implement in its current form.

Copy link

codecov bot commented May 6, 2024

Codecov Report

Attention: Patch coverage is 0% with 33 lines in your changes are missing coverage. Please review.

Project coverage is 29.16%. Comparing base (a0fab7a) to head (f817b52).
Report is 28 commits behind head on master.

Files Patch % Lines
ext/SciMLBaseZygoteExt.jl 0.00% 33 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master     #689      +/-   ##
==========================================
- Coverage   31.79%   29.16%   -2.64%     
==========================================
  Files          55       55              
  Lines        4535     4574      +39     
==========================================
- Hits         1442     1334     -108     
- Misses       3093     3240     +147     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Comment on lines 185 to 202
@adjoint function literal_getproperty(sol::AbstractTimeseriesSolution,
::Val{:u})
function solu_adjoint(Δ)
zerou = zero(sol.prob.u0)
_Δ = @. ifelse(Δ === nothing, (zerou,), Δ)
(build_solution(sol.prob, sol.alg, sol.t, _Δ),)
end
sol.u, solu_adjoint
end
# @adjoint function literal_getproperty(sol::AbstractTimeseriesSolution,
# ::Val{:u})
# function solu_adjoint(Δ)
# zerou = zero(sol.prob.u0)
# _Δ = @. ifelse(Δ === nothing, (zerou,), Δ)
# (build_solution(sol.prob, sol.alg, sol.t, _Δ),)
# end
# sol.u, solu_adjoint
# end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this removed?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was returning the ODESolution as the adjoint. It is also an issue because it shortcuts the gradients through parameters and instead replaces it with the sol.prob, whereas we need to accumulate the gradients here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a unit test in the downstream set which shows this is fine?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Happy to. In fact, that's why I asked if anything was relying on this behavior previously. Could you suggest what kind of test you have in mind?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to be the root cause of many of the test failures? So that means it's caught by the tests already.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is what the error is referring to. I am missing a branch https://github.com/DhairyaLGandhi/RecursiveArrayTools.jl/tree/dg/noproj which removes an extra projection rule.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does refer to projecting to a VectorOfArray, and that rule wasn't defined for Tangent. Removing it gets us the expected results. If we want to project back to a VectorOfArray type, then that needs to be handled elsewhere.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

now that we have restored the adjoint, I believe this can be resolved

@DhairyaLGandhi
Copy link
Member Author

Needs JuliaDiff/ChainRules.jl#793

@ChrisRackauckas
Copy link
Member

Add your unit tests as a new downstream testset.

@DhairyaLGandhi
Copy link
Member Author

Note that with SciMLSensitivity.jl#dg/ss (and SciML/SciMLStructures.jl#18) https://github.com/SciML/SciMLSensitivity.jl/blob/32f5ae7529a1957661b153f0ca9eff7e4caf0c5a/test/reversediff_output_types.jl#L14 looks like:

julia> gs = gradient(u0 -> loss(u0), u0)
([-0.7779831009550049, 0.40028226620020263],)

@DhairyaLGandhi
Copy link
Member Author

I've added a DAE example in the tests, but switched it off until we get SciMLSensitivity updated as well. The DC motor example fails to initialize currently. If there's a different test case, I can also hook that in.

Project.toml Outdated
@@ -68,6 +68,7 @@ Logging = "1.10"
Makie = "0.20"
Markdown = "1.10"
ModelingToolkit = "8.75, 9"
ModelingToolkitStandardLibrary = "2.7"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be gated int Downstream

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added bounds to test/downstream/Project.toml in 940ea78, should I remove anything from the regular test environment or do i need to declare these in both places?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove it from the regular

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

d061ce4 does that

@DhairyaLGandhi
Copy link
Member Author

@ChrisRackauckas SciMLSensitivity test pass with d061ce4 (latest commit), but the Core (Downstream) tests get cancelled before anything runs. Is that because the Core (Python) tests fail for unrelated reasons?

@gdalle
Copy link
Contributor

gdalle commented May 22, 2024

So what happens here is:

  • The latest ADTypes (v1.2) is installed in the test environment
  • But for the Python test group, the environment resolution forces the previous ADTypes (v0.2)
  • I think the following message refers to ADTypes
1 dependency precompiled but a different version is currently loaded. Restart julia to access the new version
  • So the latest ADTypes is provided to the Python test group even though compatibility forbids it at the moment, because the previously active environment leaks

@DhairyaLGandhi
Copy link
Member Author

Both CI/ Python and CI/ Downgrade seem to be failing on master as well.

@gdalle
Copy link
Contributor

gdalle commented May 23, 2024

The problem I mentioned has not been fixed. It's not a problem with ADTypes per se, it's a problem with environment stacking

@DhairyaLGandhi
Copy link
Member Author

Is there anything left to be done in this PR?

@ChrisRackauckas ChrisRackauckas merged commit 3811745 into SciML:master May 25, 2024
29 of 42 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants