
test_rrule fails when the pullback returns an array buffer allocated in rrule #264

Open
chengchingwen opened this issue Oct 21, 2022 · 1 comment

@chengchingwen

MWE:

julia> using ChainRulesCore, ChainRulesTestUtils

julia> f(x) = x .* fill(2, size(x))
f (generic function with 1 method)

julia> f1(x) = f(x)
f1 (generic function with 1 method)

julia> f2(x) = f(x)
f2 (generic function with 1 method)

julia> function ChainRulesCore.rrule(::typeof(f1), x)
    tmp = similar(x)
    fill!(tmp, 2)
    y = x .* tmp
    function pullback(Ȳ)
        tmp .*= Ȳ  # in place: rescales the buffer captured from the forward pass
        ∂ = tmp
        return (NoTangent(), ∂)
    end
    return y, pullback
end

julia> function ChainRulesCore.rrule(::typeof(f2), x)
    tmp = similar(x)
    fill!(tmp, 2)
    y = x .* tmp
    function pullback(Ȳ)
        ∂ = tmp .* Ȳ  # allocates a new array on every call; `tmp` is only read
        return (NoTangent(), ∂)
    end
    return y, pullback
end

julia> test_rrule(f2, [1.0,2.0,3.0])
Test.DefaultTestSet("test_rrule: f2 on Vector{Float64}", Any[], 7, false, false, true, 1.666336561747233e9, 1.666336562127431e9)

julia> test_rrule(f1, [1.0,2.0,3.0])
test_rrule: f1 on Vector{Float64}: Test Failed at /home/peter/.julia/packages/ChainRulesTestUtils/YbVdW/src/check_result.jl:24
Expression: isapprox(actual, expected; kwargs...)
Evaluated: isapprox([32.0, 132.5192, 18.7272], [-7.999999999999777, 16.27999999999986, -6.120000000000349]; rtol = 1.0e-9, atol = 1.0e-9)
Stacktrace:
[...]

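It seems that test_rrule calls a single pullback multiple times without re-running rrule. Because f1's pullback multiplies into, and returns, the same `tmp` buffer, each call starts from the values left behind by the previous one, so the results no longer match the finite-difference estimate.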
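The suspected mechanism can be reproduced without ChainRulesCore or ChainRulesTestUtils. This is only an illustration under assumptions. `make_pullback_mutating` and `make_pullback_fresh` are hypothetical helpers that copy the bodies of the f1 and f2 rrules. The two explicit calls stand in for whatever repeated calls test_rrule actually makes.

```julia
# Package-free sketch of the two pullback styles in the MWE.
# `make_pullback_mutating` and `make_pullback_fresh` are hypothetical helpers.

# Like the f1 rrule: the pullback rescales the captured buffer in place.
function make_pullback_mutating(x)
    tmp = similar(x)
    fill!(tmp, 2)
    y = x .* tmp
    pullback(dy) = (tmp .*= dy; tmp)   # mutates `tmp`, then returns that same array
    return y, pullback
end

# Like the f2 rrule: the pullback only reads `tmp` and allocates its result.
function make_pullback_fresh(x)
    tmp = similar(x)
    fill!(tmp, 2)
    y = x .* tmp
    pullback(dy) = tmp .* dy
    return y, pullback
end

x = [1.0, 2.0, 3.0]
dy = [3.0, 3.0, 3.0]

_, pb1 = make_pullback_mutating(x)
first_call = copy(pb1(dy))   # copy, because the returned array is the buffer itself
second_call = pb1(dy)        # starts from the already-rescaled buffer
println(first_call)          # [6.0, 6.0, 6.0]
println(second_call)         # [18.0, 18.0, 18.0]

_, pb2 = make_pullback_fresh(x)
println(pb2(dy) == pb2(dy))  # true: repeated calls agree
```

Returning `tmp` itself also means the caller receives an alias of the pullback's internal state. Any later in-place change to that gradient would therefore modify the buffer as well.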

@mzgubic
Member

mzgubic commented Oct 21, 2022

Key point from discussion in Julia Slack:

f1 will run into problems with higher-order AD because its pullback mutates an array.
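This sketch does not run an AD system. Reverse-mode tools generally do not support array mutation, and Zygote, for example, raises an error when it encounters it. As a result, a pullback written like f1's cannot itself be differentiated there. A related hazard can be seen in isolation: an in-place update forces the cotangent into the element type chosen during the forward pass. In the example, complex numbers are an assumed stand-in for the wrapped number types, such as dual numbers, that an outer AD pass might feed into the pullback. The exact error for those types would differ.

```julia
# Hypothetical sketch: an in-place pullback pins the cotangent to the buffer's
# element type. Complex numbers stand in for a "richer" cotangent type.

x = [1.0, 2.0, 3.0]
tmp = fill!(similar(x), 2)          # Vector{Float64}, as in the f1/f2 rrules
dy = [1.0 + 1.0im, 1.0, 1.0]        # cotangent with a wider element type

# f2 style: allocates a result, so the element type can widen to ComplexF64.
fresh = tmp .* dy

# f1 style: the result must fit into the existing Float64 buffer.
mutating_ok = try
    tmp .*= dy
    true
catch err
    err isa InexactError || rethrow()
    false
end

println(eltype(fresh))   # ComplexF64
println(mutating_ok)     # false
```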
