-
Notifications
You must be signed in to change notification settings - Fork 18
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make some functions more AD friendly #91
base: master
Are you sure you want to change the base?
Conversation
Seems okay to me. Is there a Zygote issue discussing why changing broadcast to map is required here? It might be worth referencing that in a code comment otherwise this could get changed back in future. Beyond this PR we could think about adding a Zygote test if we want to make sure we don't break AD compat. |
Zygote hasn't changed here, what is required is overloading |
I like the idea of adding tests here |
Definitely agree about not depending on Zygote and the same implementation with/without AD. I'm just wondering why that overload is required at all, sounds like something that could be tracked/improved in Zygote? |
Lgtm modulo adding a test. Would it make sense to make this part of AtomsBaseTesting to test also in downstream codes? |
It could be an optional extra or emit a warning in AtomsBaseTesting, but I don't think we should make Zygote compat a required part of the interface. Tests to check that the systems in AtomsBase are Zygote-compatible would be useful to avoid regression though, I guess taking on Zygote as a test dependency is fine. |
I was working on some tests to add to this and am running into a missing adjoint issue... box = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]u"m"
bcs = [Periodic(), Periodic(), DirichletZero()]
elements = [:C, :C]
atoms = [Atom(elements[i], positions[i]) for i in 1:2]
# distance between first two particles
function dist(sys::AbstractSystem)
sepvec = diff(position(sys))[1]
sqrt(dot(sepvec, sepvec))
end
gradient(0) do x
positions = [[0, 0, 0], [x, 0.5, 0.5]]u"m"
atoms = [Atom(elements[i], positions[i]) for i in 1:2]
flexible = FlexibleSystem(atoms, box, bcs)
dist(flexible)
end And I get a super long stacktrace that starts with: ERROR: Need an adjoint for constructor StaticArrays.SVector{3, Quantity{Float64, 𝐋, Unitful.FreeUnits{(m,), 𝐋, nothing}}}. Gradient is of type StaticArrays.SVector{3, Float64} I found a whole chain of discussions across several PR's on various packages (1 -> 2 -> 3 -> 4 -> 5 -> 6), and I don't follow enough of the nitty-gritty details to know for sure if that last one will fix this or not when merged (it also seems like it might depend on the Julia version? I was doing this on 1.9.2), but hopefully @DhairyaLGandhi can lend some insight? |
I get a different issue, related to mutation with using AtomsBase, Zygote, Unitful, LinearAlgebra
box = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]u"m"
bcs = [Periodic(), Periodic(), DirichletZero()]
elements = [:C, :C]
function dist(sys::AbstractSystem)
sepvec = diff(position(sys))[1]
sqrt(dot(sepvec, sepvec))
end
gradient(0) do x
positions = [[0, 0, 0], [x, 0.5, 0.5]]u"m"
atoms = [Atom(elements[i], positions[i]) for i in 1:2]
flexible = FlexibleSystem(atoms, box, bcs)
dist(flexible)
end
|
So far, I've looked at |
I have always found it hard to get units to play well with AD, and don't use them when taking gradients in my own code. There is https://github.com/SBuercklin/UnitfulChainRules.jl which may be useful. |
Update on this PR? |
Zygote has a property called
literal_indexed_iterate
which types with some iteration can implement to allow for cleaner accumulation of gradients when working with AD. However, this adds a dependency on Zygote, which might be costly for a base package.Package extensions also cannot be used since it would basically overwrite methods causing an amount of piracy. It is also disallowed as of Julia 1.10. This therefore is a simple way to still benefit from AD-able code gen while not having to introduce (any) complexity.