diff --git a/Project.toml b/Project.toml index 7732e7d1d..5d88267fe 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.66.0" +version = "1.67.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 6c66d19ee..8c616345a 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -11,7 +11,7 @@ function frule((_, _), ::typeof(zero), x) end function rrule(::typeof(zero), x) - zero_pullback(_) = (NoTangent(), ZeroTangent()) + zero_pullback = Returns((NoTangent(), ZeroTangent())) return (zero(x), zero_pullback) end @@ -22,7 +22,7 @@ function frule((_, _), ::typeof(one), x) end function rrule(::typeof(one), x) - one_pullback(_) = (NoTangent(), ZeroTangent()) + one_pullback = Returns((NoTangent(), ZeroTangent())) return (one(x), one_pullback) end diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index 25c755f55..0b9e0a721 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -4,6 +4,7 @@ end @testset "base.jl" begin @testset "zero/one" begin + @test last(rrule(zero, 0.1)) === last(rrule(one, 0.2f0)) for f in [zero, one] for x in [1.0, 1.0im, [10.0+im 11.0-im; 12.0+2im 13.0-3im]] test_frule(f, x)