Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make some functions more AD friendly #91

Open
wants to merge 2 commits into
base: master
Choose a base branch
from

Conversation

DhairyaLGandhi
Copy link

Zygote has a property called literal_indexed_iterate which types with some iteration can implement to allow for cleaner accumulation of gradients when working with AD. However, this adds a dependency on Zygote, which might be costly for a base package.

Package extensions also cannot be used since it would basically overwrite methods causing an amount of piracy. It is also disallowed as of Julia 1.10. This therefore is a simple way to still benefit from AD-able code gen while not having to introduce (any) complexity.

@jgreener64
Copy link
Collaborator

Seems okay to me. Is there a Zygote issue discussing why changing broadcast to map is required here? It might be worth referencing that in a code comment otherwise this could get changed back in future.

Beyond this PR we could think about adding a Zygote test if we want to make sure we don't break AD compat.

@DhairyaLGandhi
Copy link
Author

Zygote hasn't changed here, what is required is overloading Zygote.literal_indexed_iterate. I was trying to avoid the dependency on Zygote. I also feel it would be better if the implementation of the function didn't change during AD or otherwise. That can make it harder to debug gradient issues.

@DhairyaLGandhi
Copy link
Author

I like the idea of adding tests here

@jgreener64
Copy link
Collaborator

Definitely agree about not depending on Zygote and the same implementation with/without AD. I'm just wondering why that overload is required at all, sounds like something that could be tracked/improved in Zygote?

@mfherbst
Copy link
Member

Lgtm modulo adding a test. Would it make sense to make this part of AtomsBaseTesting to test also in downstream codes?

@jgreener64
Copy link
Collaborator

It could be an optional extra or emit a warning in AtomsBaseTesting, but I don't think we should make Zygote compat a required part of the interface.

Tests to check that the systems in AtomsBase are Zygote-compatible would be useful to avoid regression though, I guess taking on Zygote as a test dependency is fine.

@rkurchin
Copy link
Collaborator

I was working on some tests to add to this and am running into a missing adjoint issue...

box = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]u"m"
bcs = [Periodic(), Periodic(), DirichletZero()]
elements = [:C, :C]
atoms = [Atom(elements[i], positions[i]) for i in 1:2]

# distance between first two particles
function dist(sys::AbstractSystem)
    sepvec = diff(position(sys))[1]
    sqrt(dot(sepvec, sepvec))
end

gradient(0) do x
    positions = [[0, 0, 0], [x, 0.5, 0.5]]u"m"
    atoms = [Atom(elements[i], positions[i]) for i in 1:2]
    flexible = FlexibleSystem(atoms, box, bcs)
    dist(flexible)
end

And I get a super long stacktrace that starts with:

ERROR: Need an adjoint for constructor StaticArrays.SVector{3, Quantity{Float64, 𝐋, Unitful.FreeUnits{(m,), 𝐋, nothing}}}. Gradient is of type StaticArrays.SVector{3, Float64}

I found a whole chain of discussions across several PR's on various packages (1 -> 2 -> 3 -> 4 -> 5 -> 6), and I don't follow enough of the nitty-gritty details to know for sure if that last one will fix this or not when merged (it also seems like it might depend on the Julia version? I was doing this on 1.9.2), but hopefully @DhairyaLGandhi can lend some insight?

@jgreener64
Copy link
Collaborator

I get a different issue, related to mutation with position.(sys). I am on Julia 1.10.0 and the latest StaticArrays, ChainRules etc.

using AtomsBase, Zygote, Unitful, LinearAlgebra

box = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]u"m"
bcs = [Periodic(), Periodic(), DirichletZero()]
elements = [:C, :C]

function dist(sys::AbstractSystem)
    sepvec = diff(position(sys))[1]
    sqrt(dot(sepvec, sepvec))
end

gradient(0) do x
    positions = [[0, 0, 0], [x, 0.5, 0.5]]u"m"
    atoms = [Atom(elements[i], positions[i]) for i in 1:2]
    flexible = FlexibleSystem(atoms, box, bcs)
    dist(flexible)
end
ERROR: Mutating arrays is not supported -- called copyto!(Vector{Atom{3, Quantity{Float64, 𝐋, Unitful.FreeUnits{(m,), 𝐋, nothing}}, Quantity{Float64, 𝐋 𝐓^-1, Unitful.FreeUnits{(a₀, s^-1), 𝐋 𝐓^-1, nothing}}, Quantity{Float64, 𝐌, Unitful.FreeUnits{(u,), 𝐌, nothing}}}}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
  https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] _throw_mutation_error(f::Function, args::Vector{Atom{3, Quantity{…}, Quantity{…}, Quantity{…}}})
    @ Zygote ~/.julia/dev/Zygote/src/lib/array.jl:70
  [3] (::Zygote.var"#543#544"{Vector{…}})(::Vector{@NamedTuple{…}})
    @ Zygote ~/.julia/dev/Zygote/src/lib/array.jl:85
  [4] (::Zygote.var"#2633#back#545"{Zygote.var"#543#544"{…}})(Δ::Vector{@NamedTuple{…}})
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
  [5] _collect(c::Any, itr::Any, ::Base.EltypeUnknown, isz::Union{Base.HasLength, Base.HasShape})
    @ Base ./array.jl:765 [inlined]
  [6] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Vector{@NamedTuple{…}})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
  [7] collect(itr::Base.Generator{Base.Iterators.Zip{Tuple{Vector{…}, Vector{…}}}, Base.var"#4#5"{Zygote.var"#1366#1372"}})
    @ Base ./array.jl:759 [inlined]
  [8] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Vector{@NamedTuple{…}})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
  [9] broadcastable
    @ ./broadcast.jl:743 [inlined]
 [10] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Vector{@NamedTuple{…}})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [11] broadcasted
    @ ./broadcast.jl:1339 [inlined]
 [12] position
    @ ~/.julia/dev/AtomsBase/src/interface.jl:139 [inlined]
 [13] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Vector{StaticArraysCore.SVector{…}})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [14] dist
    @ ./REPL[5]:2 [inlined]
 [15] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [16] #1
    @ ./REPL[6]:5 [inlined]
 [17] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [18] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:45
 [19] gradient(::Function, ::Int64, ::Vararg{Int64})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:97
 [20] top-level scope
    @ REPL[6]:1
Some type information was truncated. Use `show(err)` to see complete types.

@DhairyaLGandhi
Copy link
Author

So far, I've looked at build_graph from Chemellia/AtomGraphs.jl#11 as the test case for a code path to AD. I don't think unitful is fully supported to AD

@jgreener64
Copy link
Collaborator

I have always found it hard to get units to play well with AD, and don't use them when taking gradients in my own code.

There is https://github.com/SBuercklin/UnitfulChainRules.jl which may be useful.

@cortner
Copy link
Member

cortner commented Sep 23, 2024

Update on this PR?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants