StaticArrays adjoint constructor #570

Open
martinhath opened this issue Mar 31, 2020 · 10 comments

Comments

@martinhath

I'm trying to use Zygote.jl together with StaticArrays.jl, but I'm getting an error message I don't understand. Here's a minimal example:

julia> gradient(n->SMatrix{1,1}(n)[1], 1)
ERROR: Need an adjoint for constructor SArray{Tuple{1,1},Int64,2,1}. Gradient is of type MArray{Tuple{1,1},Int64,2,1}

I have tried to add adjoints, but I don't understand which type the constructor adjoint should be defined for, what it should take in, or whether the type of the delta matters. None of the adjoints I've tried to define help, so there are obviously things I don't understand here. It is also very confusing that the error message mentions MArray when I've gone out of my way to avoid mutable arrays, since Zygote doesn't support mutation.

@AzamatB
Contributor

AzamatB commented Mar 31, 2020

To clarify, Zygote does not support mutating operations, but it does support mutable data structures (Array is a prime example of that). So MArray should be fine as long as you don't mutate it in the code you want to differentiate.

@mcabbott
Member

The gradient of getindex calls similar(x) to get a fresh array to write into; it needs mutation internally:

julia> similar(SA[1,2]) isa MArray
true

julia> gradient(x -> x[1], SA[1,2])[1]
2-element MArray{Tuple{2},Int64,1,2} with indices SOneTo(2):
 1
 0

In your example it's trying and failing to get back from that to the constructor's argument(s). The manual has a section on making this work for new structs, but StaticArrays has many constructors that are more complicated than the examples there. However, the SA form seems to work:

julia> gradient((x,y) -> sum(SA[x,y,x,y][1:3]), 1,2)
(2, 1)

@martinhath
Author

> To clarify, Zygote does not support mutating operations, but it does support mutable data structures (Array is a prime example of that). So MArray should be fine as long as you don't mutate it in the code you want to differentiate.

Alright, I figured as much :) It seems I've found some ways to work around the need for mutation (for now).

> The gradient of getindex calls similar(x) to get a fresh array to write into; it needs mutation internally:

I get that this is the wrong repo to ask this, but why does getindex without ! need mutation? What do you mean by "a fresh array to write into" when we're getting a scalar?

> In your example it's trying and failing to get back from that to the argument(s)

Okay I see; StaticArrays is converting internally between SArray and MArray

> the manual has a section on making this work for new structs.

I've read this multiple times, but as I (tried to) explain in the OP, I'm having trouble working out exactly which adjoint Zygote needs.

> but StaticArrays has lots of more complicated constructors than the examples there.

Yes, this is unfortunate :(

> However the SA form seems to work:

Unfortunately, just on the surface. I'm getting errors here (though I suspect this is something different):

julia> gradient(n->SA[n;n][1,1], 1)
ERROR: BoundsError: attempt to access (2,)
  at index [2]

Looking at the original error with fresh eyes, and with @mcabbott's comment in mind, it seems to say that we need the adjoint for the constructor SArray{Tuple{1,1},Int64,2,1}(MArray{Tuple{1,1},Int64,2,1}), since Zygote can't know that the dual isn't supposed to change with this transformation. I've tried this:

@Zygote.adjoint StaticArrays.SArray{TS,Ty,C,D}(ma::MArray{TS,Ty,C,D}) where {TS, Ty, C, D} = SArray(ma), y->(y,)

with all kinds of variations in the type arguments, but nothing I do seems to have any effect. Am I completely off the rails here?

@AzamatB
Contributor

AzamatB commented Mar 31, 2020

> why does getindex without ! need mutation

Because the adjoint of getindex is setindex!

@mcabbott
Member

mcabbott commented Mar 31, 2020

> why does getindex without ! need mutation? What do you mean by "a fresh array to write into" when we're getting a scalar?

On the forward pass, it's indeed getting a scalar out of an array, as usual. It's clearer if we imagine this array being bigger than one element. But on the reverse pass, it's now asking "by how much do elements of that array influence the final answer?". One implementation of that is to start with an array of zeros and then, precisely where getindex read out an element, write into this new array via setindex! with the same indices. That's where it wants an MArray, and that part seems to work fine.
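
As a rough sketch of the reverse pass described above: the helper below is a hand-written pullback in plain Julia, not Zygote's actual implementation, and the name getindex_with_pullback is made up for illustration. It reads one element on the forward pass, and on the reverse pass writes the incoming sensitivity into a zero-filled copy obtained from similar(x).

```julia
# Hypothetical helper (not part of Zygote): returns the forward value together
# with a pullback closure that maps the output sensitivity back onto the array.
function getindex_with_pullback(x::AbstractArray, i::Integer)
    y = x[i]                                    # forward pass: read one element
    function pullback(dy)
        dx = fill!(similar(x), zero(eltype(x))) # fresh, mutable array of zeros
        dx[i] = dy                              # setindex! at the index that was read
        return dx
    end
    return y, pullback
end

x = [10.0, 20.0, 30.0]
y, back = getindex_with_pullback(x, 2)
dx = back(1.0)   # only element 2 influences y, so only dx[2] is nonzero
```

The mutation happens only inside the pullback, on an array the pullback itself allocated, which is why similar(x) has to return something writable.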

> > In your example it's trying and failing to get back from that to the argument(s)
>
> Okay I see; StaticArrays is converting internally in between S and MArray

No, it needs to convert an array (any array) back into the individual arguments taken by the constructor.

julia> Zygote.gradient(v -> v[1], SVector(5,5))[1] # ∇getindex returns one MArray...
2-element MArray{Tuple{2},Int64,1,2} with indices SOneTo(2):
 1
 0

julia> Zygote.gradient((a,b) -> SVector(a,b)[1], 5,5) # ... and then after that, "∇SVector" should return two scalars
ERROR: Need an adjoint for constructor SArray{Tuple{2},Int64,1,2}. Gradient is of type MArray{Tuple{2},Int64,1,2}

julia> gradient(v -> SVector{2}(v)[1], [5,5]) # ... or in this case, "∇SVector" should return a vector
ERROR: Need an adjoint for constructor SArray{Tuple{2},Int64,1,2}. Gradient is of type MArray

julia> Zygote.gradient((a,b) -> sum(SVector(a,b)), 5,5) # here "∇SVector" gets a FillArray instead, must still make two scalars
ERROR: Need an adjoint for constructor SArray{Tuple{2},Int64,1,2}. Gradient is of type FillArrays.Fill

These examples can be made to work by defining these gradients, although I may well have overlooked some other subtlety:

@Zygote.adjoint (T::Type{<:SVector})(xs::Number...) = T(xs...), dv -> (nothing, dv...)
@Zygote.adjoint (T::Type{<:SVector})(x::AbstractVector) = T(x), dv -> (nothing, dv)

Here nothing is needed because Zygote treats the type itself as the first argument: it could be some container type with its own parameters, although here it isn't. And nothing is (again by Zygote convention) a generalised zero, which here indicates that there is no gradient with respect to the type.
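
To make the shape of that returned tuple concrete, here is a hand-rolled sketch in plain Julia. It assumes a plain Vector as a stand-in for SVector so that it runs without StaticArrays or Zygote, and ctor_with_pullback is a made-up name, not an API of either package.

```julia
# Hypothetical stand-in for a constructor rule: the callable type T is treated
# as the first argument, followed by the scalar arguments xs.
function ctor_with_pullback(T::Type{<:Vector}, xs::Number...)
    v = T(collect(xs))                  # forward pass: build the array
    pullback(dv) = (nothing, dv...)     # one slot for T, then one per scalar
    return v, pullback
end

a, b = 5.0, 7.0
v, ctor_back = ctor_with_pullback(Vector, a, b)

dv = [1.0, 0.0]          # array sensitivity, e.g. what a getindex pullback yields for v[1]
grads = ctor_back(dv)    # first entry belongs to the type, the rest to a and b
scalar_grads = grads[2:end]
```

Splatting dv works for any iterable array, which is why the same pullback can accept an MArray or a Fill-like array and still hand back one number per constructor argument.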

@martinhath
Author

This is starting to become understandable! Thanks a lot! This adjoint seems to work well:

@Zygote.adjoint (T::Type{<:SArray})(x::Number...) = T(x...), y->(nothing, y...)

There's still one thing that confuses me about the nothing in the returned tuple: looking at the array constructors in Zygote, there doesn't seem to be any such thing, but without it I'm getting out-of-bounds errors. How come?


I think it makes sense for these adjoints to be in either library; how are things like this usually resolved with Zygote? Do we want a bunch of adjoints from all kinds of projects here, the other way around, or is this seen as user responsibility to write themselves when they happen to use both libraries?

@mcabbott
Member

I'm not really sure, Arrays may be special and weird, but FillArrays (next line) are more surprising... and perhaps one needs Δ::NamedTuple here too...

It would be nice to have all of these working. You may talk Zygote into depending on StaticArrays (to see these types), alternatively the gradients can be defined with only ZygoteRules (a small package for this purpose) which StaticArrays could depend on.

@Yansf677

> This is starting to become understandable! Thanks a lot! This adjoint seems to work well:
>
> @Zygote.adjoint (T::Type{<:SArray})(x::Number...) = T(x...), y->(nothing, y...)
>
> There's still one thing that's confusing regarding the nothing in the end; Looking at the array constructors in Zygote there doesn't seem to be any such thing, but without it I'm getting out-of-bounds errors. How come?
>
> I think it makes sense for these adjoints to be in either library; how are things like this usually resolved with Zygote? Do we want a bunch of adjoints from all kinds of projects here, the other way around, or is this seen as user responsibility to write themselves when they happen to use both libraries?

Then how can I add an adjoint for LinearAlgebra.diagm when a Base.Pair is used?

@mcabbott
Member

> How to add adjoint for LinearAlgebra.diagm where Base.Pair is used?

Those seem to be easier, e.g.

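using LinearAlgebra, Zygote

# Gradient of diagm(x): read the main diagonal back out of dy.
Zygote.@adjoint diagm(x::AbstractVector) = diagm(x), dy -> (diag(dy),)
# Gradient of diagm(k => x): nothing for the offset k, diag(dy, k) for the vector.
Zygote.@adjoint diagm(pr::Pair) = diagm(pr), dy -> ((first=nothing, second=diag(dy, first(pr))),)

gradient(x -> sum(sin, diagm(x)), rand(4))
gradient(x -> sum(sin, diagm(1 => x)), rand(4))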

But perhaps open a new issue (or better yet, a draft PR) for this.
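
A quick sanity check on why diag(dy, k) is a plausible pullback for diagm(k => x): since x -> diagm(k => x) is linear, its adjoint should satisfy the inner-product identity dot(diagm(k => x), dy) == dot(x, diag(dy, k)). The sketch below verifies this numerically using only LinearAlgebra; the values of k, x and dy are arbitrary choices for illustration.

```julia
using LinearAlgebra

k = 1
x = [1.0, 2.0, 3.0, 4.0]
A = diagm(k => x)                        # 5×5 matrix with x on the first superdiagonal
dy = reshape(collect(1.0:25.0), 5, 5)    # arbitrary sensitivity with the same shape as A

lhs = dot(A, dy)            # inner product taken in the output (matrix) space
rhs = dot(x, diag(dy, k))   # inner product taken in the input (vector) space

@assert lhs ≈ rhs           # holds when diag(·, k) is the adjoint of x -> diagm(k => x)
```

The same check with k = 0 covers the first rule, since diagm(x) places x on the main diagonal.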

@stevengj

stevengj commented Sep 9, 2022

See also JuliaArrays/StaticArrays.jl#1068 for a PR to StaticArrays that fixes this.
