Skip to content

Commit

Permalink
Merge pull request #175 from JuliaDiff/ox/zygote_fixes
Browse files Browse the repository at this point in the history
Fix errors revealed by Zygote's tests
  • Loading branch information
oxinabox authored Apr 29, 2020
2 parents 02e7857 + 11d3610 commit df431b9
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 23 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.5.0"
version = "0.5.1"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -11,7 +11,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ChainRulesCore = "0.7"
ChainRulesTestUtils = "0.2.1"
ChainRulesTestUtils = "0.2.2"
Compat = "3"
FiniteDifferences = "0.9"
Reexport = "0.2"
Expand Down
22 changes: 18 additions & 4 deletions src/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

@scalar_rule(abs(x::Real), sign(x))
@scalar_rule(abs2(x), 2x)
@scalar_rule(exp(x), Ω)
@scalar_rule(exp(x::Real), Ω)
@scalar_rule(exp10(x), Ω * log(oftype(x, 10)))
@scalar_rule(exp2(x), Ω * log(oftype(x, 2)))
@scalar_rule(expm1(x), exp(x))
Expand Down Expand Up @@ -45,7 +45,8 @@
@scalar_rule(sinh(x), cosh(x))
@scalar_rule(tanh(x), 1-Ω^2)

@scalar_rule(acosh(x), inv(sqrt(x^2 - 1)))
# Can't multiply though sqrt in acosh because of negative complex case for x
@scalar_rule(acosh(x), inv(sqrt(x - 1) * sqrt(x + 1)))
@scalar_rule(acoth(x), inv(1 - x^2))
@scalar_rule(acsch(x), -inv(x^2 * sqrt(1 + x^-2)))
@scalar_rule(acsch(x::Real), -inv(abs(x) * sqrt(1 + x^2)))
Expand All @@ -66,7 +67,9 @@
@scalar_rule(-(x, y), (One(), -1))
@scalar_rule(/(x, y), (inv(y), -(x / y / y)))
@scalar_rule(\(x, y), (-(y / x / x), inv(x)))
@scalar_rule(^(x, y), (ifelse(iszero(y), zero(Ω), y * x^(y - 1)), Ω * log(x)))

#log(complex(x)) is require so it give correct complex answer for x<0
@scalar_rule(^(x, y), (ifelse(iszero(y), zero(Ω), y * x^(y - 1)), Ω * log(complex(x))))

@scalar_rule(cbrt(x), inv(3 * Ω^2))
@scalar_rule(inv(x), -Ω^2)
Expand Down Expand Up @@ -117,7 +120,7 @@ end

function rrule(::typeof(*), x::Number, y::Number)
function times_pullback(ΔΩ)
return (NO_FIELDS, @thunk(ΔΩ * y'), @thunk(x' * ΔΩ))
return (NO_FIELDS, @thunk(ΔΩ * y), @thunk(x * ΔΩ))
end
return x * y, times_pullback
end
Expand All @@ -132,3 +135,14 @@ function rrule(::typeof(identity), x)
end
return x, identity_pullback
end

function rrule(::typeof(identity), x::Tuple)
# `identity(::Tuple)` returns multiple outputs;because that is how we think of
# returning a tuple, so its pullback needs to accept multiple inputs.
# `identity(::Tuple)` has one input, so its pullback should return 1 matching output
# see https://github.com/JuliaDiff/ChainRulesCore.jl/issues/152
function identity_pullback(ȳs...)
return (NO_FIELDS, Composite{typeof(x)}(ȳs...))
end
return x, identity_pullback
end
23 changes: 11 additions & 12 deletions src/rulesets/LinearAlgebra/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,17 @@ end
##### `det`
#####

function frule((_, ẋ), ::typeof(det), x)
function frule((_, ẋ), ::typeof(det), x::Union{Number, AbstractMatrix})
Ω = det(x)
# TODO Performance optimization: probably there is an efficent
# way to compute this trace without during the full compution within
return Ω, Ω * tr(inv(x) * ẋ)
end

function rrule(::typeof(det), x)
function rrule(::typeof(det), x::Union{Number, AbstractMatrix})
Ω = det(x)
function det_pullback(ΔΩ)
return NO_FIELDS, @thunk(Ω * ΔΩ * inv(x)')
return NO_FIELDS, Ω * ΔΩ * transpose(inv(x))
end
return Ω, det_pullback
end
Expand All @@ -59,15 +59,15 @@ end
##### `logdet`
#####

function frule((_, Δx), ::typeof(logdet), x)
function frule((_, Δx), ::typeof(logdet), x::Union{Number, AbstractMatrix})
Ω = logdet(x)
return Ω, tr(inv(x) * Δx)
end

function rrule(::typeof(logdet), x)
function rrule(::typeof(logdet), x::Union{Number, AbstractMatrix})
Ω = logdet(x)
function logdet_pullback(ΔΩ)
return (NO_FIELDS, @thunk(ΔΩ * inv(x)'))
return (NO_FIELDS, ΔΩ * transpose(inv(x)))
end
return Ω, logdet_pullback
end
Expand All @@ -81,6 +81,8 @@ function frule((_, Δx), ::typeof(tr), x)
end

function rrule(::typeof(tr), x)
# This should really be a FillArray
# see https://github.com/JuliaDiff/ChainRules.jl/issues/46
function tr_pullback(ΔΩ)
return (NO_FIELDS, @thunk Diagonal(fill(ΔΩ, size(x, 1))))
end
Expand Down Expand Up @@ -121,14 +123,11 @@ function rrule(::typeof(/), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R
C, dC_pb = rrule(adjoint, Cᵀ)
function slash_pullback(Ȳ)
# Optimization note: dAᵀ, dBᵀ, dC are calculated no matter which partial you want
# this is not a problem if you want the 2nd or 3rd, but if you want the first, it
# is fairly wasteful
_, dC = dC_pb(Ȳ)
_, dBᵀ, dAᵀ = dS_pb(extern(dC))
_, dBᵀ, dAᵀ = dS_pb(unthunk(dC))

# need to extern as dAᵀ, dBᵀ are generally `Thunk`s, which don't support adjoint
∂A = @thunk last(dA_pb(extern(dAᵀ)))
∂B = @thunk last(dA_pb(extern(dBᵀ)))
∂A = last(dA_pb(unthunk(dAᵀ)))
∂B = last(dA_pb(unthunk(dBᵀ)))

(NO_FIELDS, ∂A, ∂B)
end
Expand Down
11 changes: 9 additions & 2 deletions src/rulesets/LinearAlgebra/structured.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,15 @@
#####

function rrule(::Type{<:Diagonal}, d::AbstractVector)
function Diagonal_pullback(ȳ)
return (NO_FIELDS, @thunk(diag(ȳ)))
function Diagonal_pullback(ȳ::AbstractMatrix)
return (NO_FIELDS, diag(ȳ))
end
function Diagonal_pullback(ȳ::Composite)
# TODO: Assert about the primal type in the Composite, It should be Diagonal
# infact it should be exactly the type of `Diagonal(d)`
# but right now Zygote loses primal type information so we can't use it.
# See https://github.com/FluxML/Zygote.jl/issues/603
return (NO_FIELDS, ȳ.diag)
end
return Diagonal(d), Diagonal_pullback
end
Expand Down
39 changes: 36 additions & 3 deletions test/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,15 @@
test_scalar(acsc, 1/x)
test_scalar(acot, 1/x)
end
@testset "Inverse hyperbolic" for x = (0.5, Complex(0.5, 0.25))
@testset "Inverse hyperbolic" for x = (0.5, Complex(0.5, 0.25), Complex(-2.1 -3.1im))
test_scalar(asinh, x)
test_scalar(acosh, x + 1) # +1 accounts for domain
test_scalar(acosh, x + 1) # +1 accounts for domain for real
test_scalar(atanh, x)
test_scalar(asech, x)
test_scalar(acsch, x)
test_scalar(acoth, x + 1)
end

@testset "Inverse degrees" for x = (0.5, Complex(0.5, 0.25))
test_scalar(asind, x)
test_scalar(acosd, x)
Expand Down Expand Up @@ -100,7 +101,23 @@
end
end

@testset "*(x, y)" begin
@testset "*(x, y) (scalar)" begin
# This is pretty important so testing it fairly heavily
test_points = (0.0, -2.1, 3.2, 3.7+2.12im, 14.2-7.1im)
@testset "$x * $y; (perturbed by: $perturb)" for
x in test_points, y in test_points, perturb in test_points

# give small off-set so as can't slip in symmetry
== 0.5 + perturb
== 0.6 + perturb
Δz = perturb

frule_test(*, (x, ẋ), (y, ẏ))
rrule_test(*, Δz, (x, x̄), (y, ȳ))
end
end

@testset "matmul *(x, y)" begin
x, y = rand(3, 2), rand(2, 5)
z, pullback = rrule(*, x, y)

Expand All @@ -125,10 +142,26 @@
rrule_test(f, Δz, (x, x̄), (y, ȳ))
end

@testset "x^n for x<0" begin
rng = MersenneTwister(123456)
x = -15*rand(rng)
Δx, x̄ = 10rand(rng, 2)
y, Δy, ȳ = rand(rng, 3)
Δz = rand(rng)

frule_test(^, (-x, Δx), (y, Δy))
rrule_test(^, Δz, (-x, x̄), (y, ȳ))
end

@testset "identity" begin
rng = MersenneTwister(1)
rrule_test(identity, randn(rng), (randn(rng), randn(rng)))
rrule_test(identity, randn(rng, 4), (randn(rng, 4), randn(rng, 4)))

rrule_test(
identity, Tuple(randn(rng, 3)),
(Composite{Tuple}(randn(rng, 3)...), Composite{Tuple}(randn(rng, 3)...))
)
end

@testset "Constants" for x in (-0.1, 6.4, 1.0+0.5im, -10.0+0im, 0+200im)
Expand Down
8 changes: 8 additions & 0 deletions test/rulesets/LinearAlgebra/structured.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,15 @@
rrule_test(Diagonal, D, (randn(rng, N), randn(rng, N)))
# Concrete type instead of UnionAll
rrule_test(typeof(D), D, (randn(rng, N), randn(rng, N)))

# TODO: replace this with a `rrule_test` once we have that working
# see https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/24
res, pb = rrule(Diagonal, [1, 4])
@test pb(10*res) == (NO_FIELDS, [10, 40])
comp = Composite{typeof(res)}(; diag=10*res.diag) # this is the structure of Diagonal
@test pb(comp) == (NO_FIELDS, [10, 40])
end

@testset "::Diagonal * ::AbstractVector" begin
rng, N = MersenneTwister(123456), 3
rrule_test(
Expand Down

2 comments on commit df431b9

@oxinabox
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/13840

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.5.1 -m "<description of version>" df431b9120b34717c13af0fbfb845f731c8e07ad
git push origin v0.5.1

Please sign in to comment.