Add rule for prod
#335
Conversation
Codecov Report

```
@@            Coverage Diff             @@
##           master     #335      +/-   ##
==========================================
+ Coverage   97.64%   97.73%   +0.09%
==========================================
  Files          18       18
  Lines        1018     1061      +43
==========================================
+ Hits          994     1037      +43
  Misses         24       24
==========================================
```

Continue to review full report at Codecov.
LGTM overall, just some comments.
```julia
function rrule(::typeof(prod), x::AbstractArray{T}; dims=:) where {T<:CommutativeMulNumber}
```
Apologies for diving in with my usual request, but is there a sensible way that we could restrict the type here, since the tests currently only look at `Array`s? e.g. I imagine that at least one of a `Fill`, `Diagonal`, `StaticArray` etc. will do something weird here. Would `StridedArray` suffice for your use case?
There is a test with `PermutedDimsArray`, which isn't a `StridedArray`, and I think it ought to work fine with `StaticArray`s, although I have not tested that in any depth. `Diagonal` seems to work, although I struggle to imagine why calling `prod` on one would be a good idea, but weird things happen:

```julia
julia> unthunk(rrule(prod, Diagonal(SA[1,2,3,4]))[2](1.0)[2])
4×4 Diagonal{Float64, MVector{4, Float64}}:
 0.0   ⋅    ⋅    ⋅
  ⋅   0.0   ⋅    ⋅
  ⋅    ⋅   0.0   ⋅
  ⋅    ⋅    ⋅   0.0

julia> unthunk(rrule(prod, Fill(2,3))[2](1.0)[2])
3-element Vector{Float64}:
 4.0
 4.0
 4.0
```

`Fill` makes a `Vector` gradient. Somehow `rrule(sum, Fill(2,3))` makes a `Fill`, because it simply broadcasts rather than calling `similar`. Is this something the package aims to guarantee? I don't see a test for it. Elsewhere it chooses `similar` over broadcasting to avoid other issues.
> There is a test with `PermutedDimsArray`, which isn't a `StridedArray`

Apologies, it's late here. They are quite clearly in the tests...

> `Fill` makes a `Vector` gradient

We would definitely want the output of `rrule` w.r.t. a `Fill` argument to be either another `Fill` or an appropriate `Composite`. This is the kind of thing that e.g. Zygote can probably get right without a rule in ChainRules, so I think the ideal solution here is just not to implement a rule that covers `Fill`.

Also, should the result with `Diagonal` have zeros on the diagonal?
> Also, should the result with `Diagonal` have zeros on the diagonal?

Something is broken on the last commit, but not this; this one you can do in your head: it's a product of mostly zeros, so the gradient with respect to the nonzero entries still vanishes.
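To spell that reasoning out (my own illustration, not code from the PR): the gradient of `prod` with respect to `x[i]` is the product of all the *other* entries, so once `x` contains two or more zeros, every one of those cofactor products still has a zero factor:

```julia
# ∂prod(x)/∂x[i] equals the product of all entries except x[i].
x = [1.0, 0.0, 0.0, 2.0]  # two zeros, like most slices of a Diagonal
grad = [prod(x[j] for j in eachindex(x) if j != i) for i in eachindex(x)]
# every cofactor product still contains a zero, so grad is all zeros
```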
> `Fill` makes a `Vector` gradient
> We would definitely want

This could be arranged at the cost of more complexity... although possibly `Fill` ought to define `similar` more like that of `Diagonal` if it wishes to be preserved under such things? Although clearly not all gradients are going to preserve this structure:

```julia
julia> gradient(sum∘cumsum, Fill(1,3))[1]
3-element Vector{Int64}:
 3
 2
 1

julia> gradient(x -> sum(cumsum(x, dims=1)), Diagonal([1,2,3]))[1]  # what should this produce?
```

And here's how much Zygote can figure out right now:

```julia
julia> function myprod(xs)
         out = 1.0
         for x in xs
           out *= x
         end
         out
       end
myprod (generic function with 1 method)

julia> gradient(myprod, Fill(1,3))[1]
3-element Vector{Float64}:
 1.0
 1.0
 1.0
```
> Something is broken on the last commit, but not this; this one you can do in your head: it's a product of mostly zeros, so the gradient with respect to the nonzero entries still vanishes.

Yes, good point.

> `julia> gradient(sum∘cumsum, Fill(1,3))[1]`

This should produce a `Composite` with an appropriate `value` field.

> `julia> gradient(x -> sum(cumsum(x, dims=1)), Diagonal([1,2,3]))[1] # what should this produce?`

This should also produce either a `Diagonal` or an appropriate `Composite`.

But with all of these types, I'm not saying that your PR needs to cover them. I'm purely suggesting it addresses the minimal set of types that you're confident are done correctly, and assumes that AD can do a reasonable job of deriving the others.

> `julia> gradient(myprod, Fill(1,3))[1]`

The answer to this is the result of a bug in Zygote that I should fix -- it looks like an example of what I'm commenting on here, where `getindex` has been implemented for too broad a set of types. Zygote really should be able to derive the rule for this properly. i.e. `getindex` only ever returns the `value` field of a `Fill`, so you shouldn't even need a rule for `getindex` for `Fill`.
So we can get this merged, shall we change this to `StridedArray` and then we can make a follow up later?

bump
Codecov Report

```
@@            Coverage Diff             @@
##           master     #335      +/-   ##
==========================================
- Coverage   98.49%   98.48%   -0.02%
==========================================
  Files          23       23
  Lines        1929     1980      +51
==========================================
+ Hits         1900     1950      +50
- Misses         29       30       +1
==========================================
```

Continue to review full report at Codecov.
```julia
@test_skip test_rrule(prod, xs ⊢ rand(T,4,4))
@test_skip test_rrule(prod, xs ⊢ rand(T,4,4), fkwargs=(dims=2,))
```
Is there a bug in how FiniteDifferences does this, or am I thinking incorrectly about what it should produce?

```julia
using ForwardDiff, FiniteDifferences
xs = Symmetric(reshape(1:16,4,4)./10)
xm = Matrix(xs)
g1 = ForwardDiff.gradient(prod, xm)        # symmetric, but not Symmetric
g2 = grad(central_fdm(5, 1), prod, xm)[1]
g3 = grad(central_fdm(5, 1), prod, xs)[1]  # is Symmetric, and differs
g1 ≈ g2
diag(g1) ≈ diag(g3)
UnitUpperTriangular(g1) ≈ UnitUpperTriangular(g3 ./ 2)  # this seems weird
```

With `dims`:

```julia
g4 = ForwardDiff.gradient(x -> sum(prod(x,dims=1)), xm)  # no longer symmetric
g5 = grad(central_fdm(5, 1), x -> sum(prod(x,dims=1)), xm)[1]
g6 = grad(central_fdm(5, 1), x -> sum(prod(x,dims=1)), xs)[1]
g4 ≈ g5
proj(m) = (m .+ m')./2;
proj(g4) ≈ proj(proj(g4))  # it's a projection
fold(m) = m .+ m' .- Diagonal(m)
fold(g4) ≈ g6
```
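A minimal numeric illustration of what that `fold` does (my own toy example, not from the thread): the two off-diagonal sensitivities are summed into one stored entry, while the diagonal is kept once:

```julia
using LinearAlgebra

fold(m) = m .+ m' .- Diagonal(m)

g = [1.0 2.0; 3.0 4.0]  # a dense, non-symmetric gradient
fold(g)                 # == [1.0 5.0; 5.0 4.0]: off-diagonals 2.0 + 3.0 merge into 5.0
```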
I suggest checking what `FiniteDifferences.to_vec` is outputting.
For now I've left here a simpler test that it does run without error on a `Symmetric`: `unthunk(rrule(prod, Symmetric(ones(T,2,2)))[2](1.0)[2]) == [1 1; 1 1]`. I very much doubt this case is going to see use, but it shouldn't give an error.
Would it be better to have a test with `Symmetric([2.0 3.0; 3.0 2.0])`? Matrices of ones are hard to trust.
OK, I've switched it. It's mostly to check this doesn't give an error; the computation here does not care at all what kind of matrix it gets. Although `ones` is in fact sufficient to see this weirdness:

```julia
julia> grad(central_fdm(5, 1), prod, Symmetric(ones(2,2)))[1]
2×2 Symmetric{Float64, Matrix{Float64}}:
 1.0  2.0
 2.0  1.0

julia> ForwardDiff.gradient(prod, ones(2,2))
2×2 Matrix{Float64}:
 1.0  1.0
 1.0  1.0
```
Status here is that:

```julia
julia> x = rand(10,100); x[1:21:end] .= 0;  # half the columns have a zero

julia> @btime gradient(x -> sum(prod(x, dims=1)), $x)[1]  # Zygote, with NaN
  1.292 μs (3 allocations: 8.83 KiB)
10×100 Matrix{Float64}:
 NaN  1.64248e-6  0.0  2.85863e-5  0.0  …  0.0  0.00412328  0.0  0.000262674

julia> @btime gradient(x -> sum(prod(x, dims=1)), $x)[1]  # this PR, original, with a type instability
  56.000 μs (1706 allocations: 51.06 KiB)
10×100 Matrix{Float64}:
 1.499e-5  2.54822e-5  0.0  0.000129398  …  0.000634891  0.0  0.00343482

julia> @btime gradient(x -> sum(prod(x, dims=1)), $x)[1]  # with Val(Int(dims)), same as dims=1 hard coded
  1.850 μs (5 allocations: 8.86 KiB)
10×100 Matrix{Float64}:
 1.12526e-6  0.000483517  0.0  5.28479e-6  …  0.000555819  0.0  0.00126222

julia> @btime gradient(x -> sum(prod(x, dims=1)), $x)[1]  # version with map, not indexing
  2.704 μs (6 allocations: 8.89 KiB)
10×100 Matrix{Float64}:
 5.36396e-7  0.0370314  0.0         0.022377    0.0  …  8.56615e-6  0.0  1.3721e-5
 0.0         0.0327162  4.17844e-5  0.00593767  0.0     7.05661e-7  0.0  0.0244155
```

Current PR is the 2nd-last variant; the code for the last variant reads:

But my vote is to leave it; someone else can switch to this later if they see a need.
Co-authored-by: Simeon Schaub <[email protected]>
Bump? Failure on nightly looks unrelated, LinearAlgebra/factorization.jl:182

Array concerns can be addressed in a follow-up if and when it occurs

Thanks, sorry about the long delay
988: delete rule for prod to use ChainRules' r=oxinabox a=oxinabox

@mcabbott added a rule for prod into ChainRules JuliaDiff/ChainRules.jl#335. It's better than the one in Zygote as it gets the right answer even if one of the elements is zero. So we can delete the old one here. But leaving the tests in place per our policy, as a double check against regressions in ChainRules.

Co-authored-by: Lyndon White <[email protected]>
This adds a reverse-mode gradient for `prod(x; dims)`, which should correctly treat zero entries.

It ends up a little more complicated than seems ideal. In particular this won't work on CuArrays (at least when there are zeros, I think it will give a scalar access warning, which might be better than `NaN`s; when there aren't, it should work). Is there a mechanism worked out for where a Cu version of `∇prod_dims!` should live, if someone were to write one?

It also probably won't work well for second derivatives.
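For readers, the usual zero-aware recipe behind such a rule (a sketch under my own assumptions, with a hypothetical helper name, not the PR's actual `∇prod_dims!`) splits on the number of zeros in each slice: with none, `y ./ x .* dy` is safe; with exactly one, only the zero entry gets the product of the others; with two or more, everything vanishes:

```julia
# Sketch: zero-aware gradient of prod over a vector. Hypothetical helper,
# not the implementation from this PR.
function myprod_grad(x::AbstractVector{Float64}, dy::Float64)
    y = prod(x)
    nzero = count(iszero, x)
    if nzero == 0
        return (y * dy) ./ x            # no division by zero possible
    elseif nzero > 1
        return zeros(length(x))         # every cofactor product contains a zero
    else
        i = findfirst(iszero, x)
        dx = zeros(length(x))
        dx[i] = prod(x[j] for j in eachindex(x) if j != i) * dy
        return dx
    end
end

myprod_grad([2.0, 3.0, 4.0], 1.0)  # == [12.0, 8.0, 6.0]
myprod_grad([2.0, 0.0, 4.0], 1.0)  # == [0.0, 8.0, 0.0], no NaN
```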