Skip to content

Commit

Permalink
tests pass, only real
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed May 7, 2021
1 parent d7daed0 commit ad282d7
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 7 deletions.
4 changes: 2 additions & 2 deletions src/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ end

function ∇prod_dims(dims, x, dy, y=prod(x; dims=dims))
T = promote_type(eltype(x), eltype(dy))
dx = fill!(similar(x, T, size(x)), zero(T))
dx = fill!(similar(x, T, axes(x)), zero(T))
∇prod_dims!(dx, dims, x, dy, y)
return dx
end
Expand All @@ -102,7 +102,7 @@ end

function ∇prod(x, dy::Number=1, y::Number=prod(x))
T = promote_type(eltype(x), eltype(dy))
dx = fill!(similar(x, T, size(x)), zero(T))
dx = fill!(similar(x, T, axes(x)), zero(T)) # axes(x) makes MArray on StaticArrays, Array for structured matrices
∇prod!(dx, x, dy, y)
return dx
end
Expand Down
13 changes: 8 additions & 5 deletions test/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,18 +43,18 @@
test_rrule(prod, xp xpbar; fkwargs=(dims=dims,), check_inferred=true)
end
end
@testset "structured" begin

@testset "structured wrappers" begin
# Adjoint -- like PermutedDimsArray this may actually be used
xa = adjoint(rand(T,4,4))
test_rrule(prod, xa rand(T,4,4))
@test_skip test_rrule(prod, xa rand(T,4,4), fkwargs=(dims=2,)) # seems broken?
#=
xa = rand(3,3)'
xa = randn(3,3)'
Zygote.gradient(x -> sum(prod(x,dims=1)), xa)[1]
ForwardDiff.gradient(x -> sum(prod(x,dims=1)), xa)
Zygote.gradient(x -> sum(prod(x,dims=1)), Matrix(xa))[1]
# These all agree, so the fault is in the testing, somehow. No doubt I'm holding it wrong.
# These all agree, so the fault is in the testing, somehow.
=#
@test unthunk(rrule(prod, adjoint(rand(T,3,3)))[2](1.0)[2]) isa Matrix
@test unthunk(rrule(prod, adjoint(rand(T,3,3)), dims=1)[2](ones(1,3))[2]) isa Matrix
Expand All @@ -71,10 +71,13 @@ Zygote.gradient(x -> sum(prod(x,dims=1)), Matrix(xa))[1]
@test_skip test_rrule(prod, xs rand(T,4,4))
@test_skip test_rrule(prod, xs rand(T,4,4), fkwargs=(dims=2,))
#=
xs = Symmetric(rand(3,3))
xs = Symmetric(100randn(3,3))
Zygote.gradient(x -> sum(prod(x,dims=1)), xs)[1]
ForwardDiff.gradient(x -> sum(prod(x,dims=1)), Matrix(xs))
Zygote.gradient(x -> sum(prod(x,dims=1)), Matrix(xs))[1]
# These all agree. This time test_rrule, besides an error, gives a complaint that aa ≈ bb fails, where:
aa, bb = [19520.328243416912 -10637.452753538959 10525.400561900045; -2510.9032814879456 -3998.8778331597046 -9169.556884964955; 635.6849617151897 -2346.1691173613817 1445.40000910585], [19520.32824346742 -13148.35603493111 11161.0855235898; -13148.35603493111 -3998.877833179008 -11515.726002337731; 11161.0855235898 -11515.726002337731 1445.4000090979655]
bb ≈ (aa .+ aa') .- Diagonal(aa) # not quite a projection. Is that right?
=#
@test unthunk(rrule(prod, Symmetric(rand(T,3,3)))[2](1.0)[2]) isa Matrix
end
Expand Down

0 comments on commit ad282d7

Please sign in to comment.