From ad282d75fbdd1452205ee164f06bd06497d4b058 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 6 May 2021 21:01:42 -0400 Subject: [PATCH] tests pass, only real --- src/rulesets/Base/mapreduce.jl | 4 ++-- test/rulesets/Base/mapreduce.jl | 13 ++++++++----- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index 5b3e713be..0a7330b15 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -82,7 +82,7 @@ end function ∇prod_dims(dims, x, dy, y=prod(x; dims=dims)) T = promote_type(eltype(x), eltype(dy)) - dx = fill!(similar(x, T, size(x)), zero(T)) + dx = fill!(similar(x, T, axes(x)), zero(T)) ∇prod_dims!(dx, dims, x, dy, y) return dx end @@ -102,7 +102,7 @@ end function ∇prod(x, dy::Number=1, y::Number=prod(x)) T = promote_type(eltype(x), eltype(dy)) - dx = fill!(similar(x, T, size(x)), zero(T)) + dx = fill!(similar(x, T, axes(x)), zero(T)) # axes(x) makes MArray on StaticArrays, Array for structured matrices ∇prod!(dx, x, dy, y) return dx end diff --git a/test/rulesets/Base/mapreduce.jl b/test/rulesets/Base/mapreduce.jl index 2a1ae87fe..ef34b3e0f 100644 --- a/test/rulesets/Base/mapreduce.jl +++ b/test/rulesets/Base/mapreduce.jl @@ -43,18 +43,18 @@ test_rrule(prod, xp ⊢ xpbar; fkwargs=(dims=dims,), check_inferred=true) end end - - @testset "structured" begin + + @testset "structured wrappers" begin # Adjoint -- like PermutedDimsArray this may actually be used xa = adjoint(rand(T,4,4)) test_rrule(prod, xa ⊢ rand(T,4,4)) @test_skip test_rrule(prod, xa ⊢ rand(T,4,4), fkwargs=(dims=2,)) # seems broken? #= -xa = rand(3,3)' +xa = randn(3,3)' Zygote.gradient(x -> sum(prod(x,dims=1)), xa)[1] ForwardDiff.gradient(x -> sum(prod(x,dims=1)), xa) Zygote.gradient(x -> sum(prod(x,dims=1)), Matrix(xa))[1] -# These all agree, so the fault is in the testing, somehow. No doubt I'm holding it wrong. +# These all agree, so the fault is in the testing, somehow. =# @test unthunk(rrule(prod, adjoint(rand(T,3,3)))[2](1.0)[2]) isa Matrix @test unthunk(rrule(prod, adjoint(rand(T,3,3)), dims=1)[2](ones(1,3))[2]) isa Matrix @@ -71,10 +71,13 @@ Zygote.gradient(x -> sum(prod(x,dims=1)), Matrix(xa))[1] @test_skip test_rrule(prod, xs ⊢ rand(T,4,4)) @test_skip test_rrule(prod, xs ⊢ rand(T,4,4), fkwargs=(dims=2,)) #= -xs = Symmetric(rand(3,3)) +xs = Symmetric(100randn(3,3)) Zygote.gradient(x -> sum(prod(x,dims=1)), xs)[1] ForwardDiff.gradient(x -> sum(prod(x,dims=1)), Matrix(xs)) Zygote.gradient(x -> sum(prod(x,dims=1)), Matrix(xs))[1] +# These all agree. This time test_rrule, besides an error, gives a complaint that aa ≈ bb fails, where: +aa, bb = [19520.328243416912 -10637.452753538959 10525.400561900045; -2510.9032814879456 -3998.8778331597046 -9169.556884964955; 635.6849617151897 -2346.1691173613817 1445.40000910585], [19520.32824346742 -13148.35603493111 11161.0855235898; -13148.35603493111 -3998.877833179008 -11515.726002337731; 11161.0855235898 -11515.726002337731 1445.4000090979655] +bb ≈ (aa .+ aa') .- Diagonal(aa) # not quite a projection. Is that right? =# @test unthunk(rrule(prod, Symmetric(rand(T,3,3)))[2](1.0)[2]) isa Matrix end