From 912663f6de3cbf2eb13b01e3a0d7d1c50e70a7e4 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 9 Mar 2021 18:29:41 +0000 Subject: [PATCH] add prod --- src/rulesets/Base/mapreduce.jl | 14 ++++++++++++++ test/rulesets/Base/mapreduce.jl | 7 +++++++ 2 files changed, 21 insertions(+) diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index 33eae96d1..9a8e84834 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -48,3 +48,17 @@ function rrule( end return y, sum_abs2_pullback end + +#### +#### prod +#### + +function rrule(::typeof(prod), x::AbstractArray{T}; dims=:) where {T<:CommutativeMulNumber} + y = prod(x; dims=dims) + function prod_pullback(ȳ) + # broadcasting the two works out the size no-matter `dims` + x̄ = y .* ȳ ./ x + return (NO_FIELDS, x̄) + end + return y, prod_pullback +end diff --git a/test/rulesets/Base/mapreduce.jl b/test/rulesets/Base/mapreduce.jl index 8e53ffdec..abecbc05d 100644 --- a/test/rulesets/Base/mapreduce.jl +++ b/test/rulesets/Base/mapreduce.jl @@ -20,4 +20,11 @@ end end end # sum abs2 + + @testset "prod" begin + test_rrule(prod, randn(5)) + test_rrule(prod, randn(5, 6)) + test_rrule(prod, randn(5, 6); fkwargs=(;dims=2)) + test_rrule(prod, randn(5, 6); fkwargs=(;dims=1)) + end end