StaticArrays adjoint constructor #570
Comments
To clarify, Zygote does not support mutating operations, but it does support mutable data structures.
The gradient of

julia> similar(SA[1,2]) isa MArray
true

julia> gradient(x -> x[1], SA[1,2])[1]
2-element MArray{Tuple{2},Int64,1,2} with indices SOneTo(2):
 1
 0

In your example, it's trying and failing to get back from that to the argument(s). The manual has a section on making this work for new structs, but StaticArrays has many more complicated constructors than the examples there. However, the

julia> gradient((x,y) -> sum(SA[x,y,x,y][1:3]), 1,2)
(2, 1)
Alright, I figured as much :) It seems I've found some ways to get around the need for mutation (for now).
I get that this is the wrong repo to ask this, but why does
Okay I see; StaticArrays is converting internally in between
I've read this multiple times, but as I (tried to) explain in the OP, I'm having trouble figuring out exactly which adjoint Zygote needs.
Yes, this is unfortunate :(
Unfortunately, just on the surface. I'm getting errors here (though I suspect this is something different):

julia> gradient(n->SA[n;n][1,1], 1)
ERROR: BoundsError: attempt to access (2,)
  at index [2]

Looking at the original error with fresh eyes, and with @mcabbott's comment in mind, it seems to say that we need the adjoint for the constructor

@Zygote.adjoint StaticArrays.SArray{TS,Ty,C,D}(ma::MArray{TS,Ty,C,D}) where {TS, Ty, C, D} = SArray(ma), y->(y,)

I've tried all kinds of variations in the type arguments, but nothing I do seems to have any effect. Am I completely off the rails here?
Because the adjoint of
On the forward pass, it's indeed getting a scalar out of an array, as usual. It's clearer if we imagine this array being bigger than one element. But on the reverse pass, it's now asking "by how much does each element of that array influence the final answer?" One implementation of that is to start with an array of zeros, and then precisely where
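A minimal plain-Julia sketch of that reverse pass. `getindex_pullback` is a hypothetical helper written for illustration, not Zygote's actual implementation; the assumption that the zeros array is built with `similar` is only suggested by the MArray returned in the `gradient(x -> x[1], SA[1,2])` example. A Base range stands in for an SArray here: it is immutable, yet `similar` hands back a mutable `Vector`, much as `similar(SA[1,2])` hands back an `MArray`.

```julia
# Hypothetical helper (not Zygote's real code): the reverse pass of y = x[i].
function getindex_pullback(x::AbstractArray, i, dy)
    dx = similar(x, typeof(dy))  # new array shaped like x; its type may differ from x's
    fill!(dx, zero(dy))          # start from an array of zeros...
    dx[i] = dy                   # ...and put the incoming gradient exactly where x was read
    return dx
end

# A Base range is immutable, playing the role of an SArray
x = 1:3
dx = getindex_pullback(x, 2, 1.0)
dx == [0.0, 1.0, 0.0]   # true
dx isa Vector{Float64}  # true: `similar` returned a mutable Vector, not a range
```

This is one way an array type that never appears in the user's own code can still show up in a gradient.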
No, it needs to convert an array (any array) back into the individual arguments taken by the constructor.
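To make that concrete without depending on Zygote or StaticArrays, here is a hand-written sketch of the chain rule for the earlier `sum(SA[x,y,x,y][1:3])` example. `build` stands in for the `SA[...]` constructor; `build`, `build_pullback` and `grad_f` are hypothetical names for illustration only. The key step is the constructor's pullback, which splits the array cotangent back into one gradient per argument; since `x` and `y` each appear twice, their contributions are summed.

```julia
# Stand-in for a constructor like SA[...]: pack scalar arguments into an array
build(args...) = collect(args)

# Its pullback turns an array cotangent back into one gradient per argument
build_pullback(dy::AbstractVector) = Tuple(dy)

# Hand-derived gradient of f(x, y) = sum(build(x, y, x, y)[1:3])
function grad_f(x, y)
    v = build(x, y, x, y)
    # getindex + sum pullback: ones at indices 1:3, zero elsewhere
    dv = zeros(length(v))
    dv[1:3] .= 1
    dargs = build_pullback(dv)   # one entry per constructor argument
    # x was passed in positions 1 and 3, y in positions 2 and 4: accumulate
    return (dargs[1] + dargs[3], dargs[2] + dargs[4])
end

grad_f(1, 2)   # (2.0, 1.0)
```

This reproduces the `(2, 1)` that Zygote returns for the same function.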
These examples can be made to work by defining these gradients, although I may well have overlooked some other subtlety:
Here
This is starting to become understandable! Thanks a lot! This adjoint seems to work well:
There's still one thing that's confusing regarding the

I think it makes sense for these adjoints to live in either library; how are things like this usually resolved with Zygote? Do we want adjoints from all kinds of projects here, the other way around, or is it the user's responsibility to write them when they happen to use both libraries?
I'm not really sure. Arrays may be special and weird, but FillArrays (next line) are more surprising... and perhaps one needs

It would be nice to have all of these working. You could talk Zygote into depending on StaticArrays (so it can see these types); alternatively, the gradients can be defined using only ZygoteRules (a small package for this purpose), which StaticArrays could depend on.
Then how do I add an adjoint for LinearAlgebra.diagm where a Base.Pair is used?
Those seem to be easier, e.g.

using LinearAlgebra, Zygote
Zygote.@adjoint diagm(x::AbstractVector) = diagm(x), dy -> (diag(dy),)
Zygote.@adjoint diagm(pr::Pair) = diagm(pr), dy -> ((first=nothing, second=diag(dy, first(pr))),)
gradient(x -> sum(sin, diagm(x)), rand(4))
gradient(x -> sum(sin, diagm(1 => x)), rand(4))

But perhaps open a new issue (or better yet, a draft PR) for this.
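A package-free sanity check of the math behind these two adjoints, assuming only the LinearAlgebra standard library: for f(x) = sum(sin, diagm(k => x)) the cotangent of the matrix is cos.(M) elementwise, and the proposed pullback returns its k-th diagonal (k = 0 corresponds to the plain `diagm(x)` case). The sketch compares that against a central finite difference; `sumsin_diagm` and `fd_grad` are hypothetical helpers, and the step size and tolerance are arbitrary choices.

```julia
using LinearAlgebra

# f(x) = sum(sin, diagm(k => x)); k = 0 matches the plain diagm(x) case
sumsin_diagm(x, k) = sum(sin, diagm(k => x))

# Central finite-difference gradient (hypothetical helper, for checking only)
function fd_grad(g, x; h = 1e-6)
    map(eachindex(x)) do i
        e = zeros(length(x))
        e[i] = h
        (g(x .+ e) - g(x .- e)) / (2h)
    end
end

x = [0.1, 0.2, 0.3, 0.4]
for k in (0, 1)
    M  = diagm(k => x)
    dy = cos.(M)        # cotangent of the matrix: d/dM_ij of sum(sin, M) is cos(M_ij)
    dx = diag(dy, k)    # what the proposed pullback returns for x
    @assert isapprox(dx, fd_grad(z -> sumsin_diagm(z, k), x); rtol = 1e-6)
end
```

Off-diagonal zeros contribute the constant sin(0) = 0, so only the k-th diagonal of the cotangent matters, which is exactly what `diag(dy, k)` extracts.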
762: basic sparse handling (r=DhairyaLGandhi a=DhairyaLGandhi). Relates to #570, SciML/DifferentialEquations.jl#649, #742. Co-authored-by: Dhairya Gandhi
See also JuliaArrays/StaticArrays.jl#1068 for a PR to StaticArrays that fixes this.
I'm trying to use Zygote.jl together with StaticArrays.jl, but am getting an error message I don't understand. Here's a very minimal example
I have tried to add adjoints, but I don't understand which type the constructor adjoint should be for, what it should take in, and whether the type of the delta matters. None of the adjoints I've tried to define helps, so there are obviously things I don't understand here. It is also very confusing that the error message mentions MArray when I've gone out of my way to avoid mutable arrays, since Zygote doesn't support mutation.