From 4afdfdd4751420bfc1ce813033ad5cef8ac36dd6 Mon Sep 17 00:00:00 2001 From: Takafumi Arakaki Date: Sat, 21 Sep 2019 01:38:35 -0700 Subject: [PATCH] Add rrule for Diagonal * AbstractVector (#108) * Add rrule for Diagonal * AbstractVector * Simplify rrule definitions for Diagonal * AbstractVector Co-Authored-By: Lyndon White * Use consistent indentation --- src/rulesets/LinearAlgebra/structured.jl | 7 +++++++ test/rulesets/LinearAlgebra/structured.jl | 9 +++++++++ 2 files changed, 16 insertions(+) diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl index 3e8934afc..45f026edb 100644 --- a/src/rulesets/LinearAlgebra/structured.jl +++ b/src/rulesets/LinearAlgebra/structured.jl @@ -19,6 +19,13 @@ function rrule(::typeof(diag), A::AbstractMatrix) return diag(A), diag_pullback end +function rrule(::typeof(*), D::Diagonal{<:Real}, V::AbstractVector{<:Real}) + function times_pullback(Ȳ) + return (NO_FIELDS, @thunk(Diagonal(Ȳ .* V)), @thunk(D * Ȳ)) + end + return D * V, times_pullback +end + ##### ##### `Symmetric` ##### diff --git a/test/rulesets/LinearAlgebra/structured.jl b/test/rulesets/LinearAlgebra/structured.jl index 2d82a2d3f..1731aad32 100644 --- a/test/rulesets/LinearAlgebra/structured.jl +++ b/test/rulesets/LinearAlgebra/structured.jl @@ -7,6 +7,15 @@ # Concrete type instead of UnionAll rrule_test(typeof(D), D, (randn(rng, N), randn(rng, N))) end + @testset "::Diagonal * ::AbstractVector" begin + rng, N = MersenneTwister(123456), 3 + rrule_test( + *, + randn(rng, N), + (Diagonal(randn(rng, N)), Diagonal(randn(rng, N))), + (randn(rng, N), randn(rng, N)), + ) + end @testset "diag" begin rng, N = MersenneTwister(123456), 7 rrule_test(diag, randn(rng, N), (randn(rng, N, N), randn(rng, N, N)))