From ccd9dcf99c7dc0cb9bf6501f15bbb3dc38c871b2 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 12 May 2021 14:05:24 -0400 Subject: [PATCH] adjoint test may work --- test/rulesets/Base/mapreduce.jl | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/test/rulesets/Base/mapreduce.jl b/test/rulesets/Base/mapreduce.jl index 3956a0a5c..a15b23054 100644 --- a/test/rulesets/Base/mapreduce.jl +++ b/test/rulesets/Base/mapreduce.jl @@ -48,16 +48,10 @@ # Adjoint -- like PermutedDimsArray this may actually be used xa = adjoint(rand(T,4,4)) test_rrule(prod, xa ⊢ rand(T,4,4)) - @test_skip test_rrule(prod, xa ⊢ rand(T,4,4), fkwargs=(dims=2,)) # seems broken? -#= -xa = randn(3,3)' -Zygote.gradient(x -> sum(prod(x,dims=1)), xa)[1] -ForwardDiff.gradient(x -> sum(prod(x,dims=1)), xa) -Zygote.gradient(x -> sum(prod(x,dims=1)), Matrix(xa))[1] -# These all agree, so the fault is in the testing, somehow. -=# + test_rrule(prod, xa ⊢ rand(T,4,4), fkwargs=(dims=2,)) @test unthunk(rrule(prod, adjoint(rand(T,3,3)))[2](1.0)[2]) isa Matrix @test unthunk(rrule(prod, adjoint(rand(T,3,3)), dims=1)[2](ones(1,3))[2]) isa Matrix + # Diagonal -- a stupid thing to do, product of zeros! Shouldn't be an error though: @test iszero(unthunk(rrule(prod, Diagonal(rand(T,3)))[2](1.0)[2])) @test iszero(unthunk(rrule(prod, Diagonal(rand(T,3)), dims=1)[2](ones(1,3))[2]))