Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mixedduplicated support in enzyme ext #120

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from

Conversation

wsmoses
Copy link
Contributor

@wsmoses wsmoses commented Sep 17, 2024

No description provided.

Copy link

codecov bot commented Sep 17, 2024

Codecov Report

Attention: Patch coverage is 0% with 26 lines in your changes missing coverage. Please review.

Project coverage is 86.83%. Comparing base (ce727e1) to head (ab12c5b).

Files with missing lines Patch % Lines
ext/QuadGKEnzymeExt.jl 0.00% 26 Missing ⚠️

❗ There is a different number of reports uploaded between BASE (ce727e1) and HEAD (ab12c5b). Click for more details.

HEAD has 6 uploads less than BASE
Flag BASE (ce727e1) HEAD (ab12c5b)
9 3
Additional details and impacted files
@@            Coverage Diff             @@
##           master     #120      +/-   ##
==========================================
- Coverage   92.36%   86.83%   -5.54%     
==========================================
  Files           8        8              
  Lines         786      767      -19     
==========================================
- Hits          726      666      -60     
- Misses         60      101      +41     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@danielwe
Copy link
Contributor

Thanks looking into this! The MWE in #119 now errors with the following:

julia> include("quadgkmixed.jl");
chebyshevintegral(domain, coeffs) = 2.0
ERROR: LoadError: UndefVarError: `Rev` not defined
Stacktrace:
  [1] (::QuadGKEnzymeExt.var"#18#25"{Active{Tuple{Float64, Float64}}, MixedDuplicated{Fun{Chebyshev{IntervalSets.ClosedInterval{Float64}, Float64}, Float64, Vector{Float64}}}})(x::Float64)
    @ QuadGKEnzymeExt ~/src/upstream-packages/QuadGK.jl/ext/QuadGKEnzymeExt.jl:127
  [2] evalrule(f::QuadGKEnzymeExt.var"#18#25"{Active{…}, MixedDuplicated{…}}, a::Float64, b::Float64, x::Vector{Float64}, w::Vector{Float64}, wg::Vector{Float64}, nrm::QuadGKEnzymeExt.var"#20#27")
    @ QuadGK ~/src/upstream-packages/QuadGK.jl/src/evalrule.jl:0
  [3] (::QuadGK.var"#6#9"{QuadGKEnzymeExt.var"#18#25"{…}, QuadGKEnzymeExt.var"#20#27", Vector{…}, Vector{…}, Vector{…}})(seg::QuadGK.Segment{Float64, Float64, Float64})
    @ QuadGK ~/src/upstream-packages/QuadGK.jl/src/adapt.jl:36
  [4] iterate
    @ ./generator.jl:47 [inlined]
  [5] _collect
    @ ./array.jl:854 [inlined]
  [6] collect_similar
    @ ./array.jl:763 [inlined]
  [7] map
    @ ./abstractarray.jl:3285 [inlined]
  [8] do_quadgk(f::QuadGKEnzymeExt.var"#18#25"{…}, s::Tuple{…}, n::Int64, atol::Nothing, rtol::Nothing, maxevals::Int64, nrm::QuadGKEnzymeExt.var"#20#27", _segbuf::Nothing, eval_segbuf::Vector{…})
    @ QuadGK ~/src/upstream-packages/QuadGK.jl/src/adapt.jl:35
  [9] (::QuadGK.var"#50#51"{Nothing, Nothing, Int64, Int64, QuadGKEnzymeExt.var"#20#27", Nothing, Vector{QuadGK.Segment{…}}})(f::Function, s::Tuple{Float64, Float64}, ::Function)
    @ QuadGK ~/src/upstream-packages/QuadGK.jl/src/api.jl:83
 [10] handle_infinities(workfunc::QuadGK.var"#50#51"{…}, f::QuadGKEnzymeExt.var"#18#25"{…}, s::Tuple{…})
    @ QuadGK ~/src/upstream-packages/QuadGK.jl/src/adapt.jl:189
 [11] quadgk(::QuadGKEnzymeExt.var"#18#25"{…}, ::Float64, ::Vararg{…}; atol::Nothing, rtol::Nothing, maxevals::Int64, order::Int64, norm::Function, segbuf::Nothing, eval_segbuf::Vector{…})
    @ QuadGK ~/src/upstream-packages/QuadGK.jl/src/api.jl:82
 [12] reverse(::EnzymeCore.EnzymeRules.ConfigWidth{…}, ::Const{…}, ::Active{…}, ::Tuple{…}, ::MixedDuplicated{…}, ::Active{…}, ::Vararg{…}; kws::@Kwargs{})
    @ QuadGKEnzymeExt ~/src/upstream-packages/QuadGK.jl/ext/QuadGKEnzymeExt.jl:126
 [13] macro expansion
    @ ~/.julia/packages/Enzyme/TiboG/src/compiler.jl:7187 [inlined]
 [14] enzyme_call
    @ ~/.julia/packages/Enzyme/TiboG/src/compiler.jl:6794 [inlined]
 [15] AdjointThunk
    @ ~/.julia/packages/Enzyme/TiboG/src/compiler.jl:6677 [inlined]
 [16] runtime_generic_rev(activity::Type{…}, width::Val{…}, ModifiedBetween::Val{…}, tape::Enzyme.Compiler.Tape{…}, f::typeof(quadgk), df::Nothing, primal_1::Fun{…}, shadow_1_1::Base.RefValue{…}, primal_2::Float64, shadow_2_1::Base.RefValue{…}, primal_3::Float64, shadow_3_1::Base.RefValue{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/TiboG/src/rules/jitrules.jl:468
 [17] chebyshevintegral
    @ ~/issues/quadgkmixed.jl:5 [inlined]
 [18] chebyshevintegral
    @ ~/issues/quadgkmixed.jl:0 [inlined]
 [19] diffejulia_chebyshevintegral_3229_inner_1wrap
    @ ~/issues/quadgkmixed.jl:0
 [20] macro expansion
    @ ~/.julia/packages/Enzyme/TiboG/src/compiler.jl:7187 [inlined]
 [21] enzyme_call
    @ ~/.julia/packages/Enzyme/TiboG/src/compiler.jl:6794 [inlined]
 [22] CombinedAdjointThunk
    @ ~/.julia/packages/Enzyme/TiboG/src/compiler.jl:6671 [inlined]
 [23] autodiff
    @ ~/.julia/packages/Enzyme/TiboG/src/Enzyme.jl:320 [inlined]
 [24] autodiff(::ReverseMode{false, FFIABI, false, false}, ::typeof(chebyshevintegral), ::Type{Active}, ::Const{IntervalSets.ClosedInterval{Float64}}, ::Duplicated{Vector{Float64}})
    @ Enzyme ~/.julia/packages/Enzyme/TiboG/src/Enzyme.jl:332
 [25] top-level scope
    @ ~/issues/quadgkmixed.jl:13
 [26] include(fname::String)
    @ Base.MainInclude ./client.jl:489
 [27] top-level scope
    @ REPL[4]:1
in expression starting at /home/daniel/issues/quadgkmixed.jl:13
Some type information was truncated. Use `show(err)` to see complete types.

@danielwe
Copy link
Contributor

Also, regular Duplicated is probably at least as important as MixedDuplicated, but I suppose that's pretty straightforward to add once the latter is in place?

@danielwe
Copy link
Contributor

danielwe commented Sep 17, 2024

Fixing the Rev -> Ref typo, I now get a different error:

julia> include("quadgkmixed.jl");
chebyshevintegral(domain, coeffs) = 2.0
ERROR: LoadError: TypeError: in typeassert, expected QuadGKEnzymeExt.MixedClosureVector{var"#7#8"{Fun{Chebyshev{IntervalSets.ClosedInterval{Float64}, Float64}, Float64, Vector{Float64}}}}, got a value of type Nothing
Stacktrace:
  [1] *(a::Float64, b::QuadGKEnzymeExt.MixedClosureVector{var"#7#8"{Fun{Chebyshev{IntervalSets.ClosedInterval{Float64}, Float64}, Float64, Vector{Float64}}}})
    @ QuadGKEnzymeExt ~/src/upstream-packages/QuadGK.jl/ext/QuadGKEnzymeExt.jl:104
  [2] *(a::QuadGKEnzymeExt.MixedClosureVector{var"#7#8"{Fun{Chebyshev{IntervalSets.ClosedInterval{Float64}, Float64}, Float64, Vector{Float64}}}}, b::Float64)
    @ QuadGKEnzymeExt ~/src/upstream-packages/QuadGK.jl/ext/QuadGKEnzymeExt.jl:108
  [3] evalrule(f::QuadGKEnzymeExt.var"#94#101"{…}, a::Float64, b::Float64, x::Vector{…}, w::Vector{…}, wg::Vector{…}, nrm::QuadGKEnzymeExt.var"#96#103")
    @ QuadGK ~/src/upstream-packages/QuadGK.jl/src/evalrule.jl:26
  [4] (::QuadGK.var"#6#9"{QuadGKEnzymeExt.var"#94#101"{…}, QuadGKEnzymeExt.var"#96#103", Vector{…}, Vector{…}, Vector{…}})(seg::QuadGK.Segment{Float64, Float64, Float64})
    @ QuadGK ~/src/upstream-packages/QuadGK.jl/src/adapt.jl:36
  [5] iterate
    @ ./generator.jl:47 [inlined]
  [6] _collect
    @ ./array.jl:854 [inlined]
  [7] collect_similar
    @ ./array.jl:763 [inlined]
  [8] map
    @ ./abstractarray.jl:3285 [inlined]
  [9] do_quadgk(f::QuadGKEnzymeExt.var"#94#101"{…}, s::Tuple{…}, n::Int64, atol::Nothing, rtol::Nothing, maxevals::Int64, nrm::QuadGKEnzymeExt.var"#96#103", _segbuf::Nothing, eval_segbuf::Vector{…})
    @ QuadGK ~/src/upstream-packages/QuadGK.jl/src/adapt.jl:35
 [10] (::QuadGK.var"#50#51"{Nothing, Nothing, Int64, Int64, QuadGKEnzymeExt.var"#96#103", Nothing, Vector{QuadGK.Segment{…}}})(f::Function, s::Tuple{Float64, Float64}, ::Function)
    @ QuadGK ~/src/upstream-packages/QuadGK.jl/src/api.jl:83
 [11] handle_infinities(workfunc::QuadGK.var"#50#51"{…}, f::QuadGKEnzymeExt.var"#94#101"{…}, s::Tuple{…})
    @ QuadGK ~/src/upstream-packages/QuadGK.jl/src/adapt.jl:189
 [12] quadgk(::QuadGKEnzymeExt.var"#94#101"{…}, ::Float64, ::Vararg{…}; atol::Nothing, rtol::Nothing, maxevals::Int64, order::Int64, norm::Function, segbuf::Nothing, eval_segbuf::Vector{…})
    @ QuadGK ~/src/upstream-packages/QuadGK.jl/src/api.jl:82
 [13] reverse(::EnzymeCore.EnzymeRules.ConfigWidth{…}, ::Const{…}, ::Active{…}, ::Tuple{…}, ::MixedDuplicated{…}, ::Active{…}, ::Vararg{…}; kws::@Kwargs{})
    @ QuadGKEnzymeExt ~/src/upstream-packages/QuadGK.jl/ext/QuadGKEnzymeExt.jl:126
 [14] macro expansion
    @ ~/.julia/packages/Enzyme/TiboG/src/compiler.jl:7187 [inlined]
 [15] enzyme_call
    @ ~/.julia/packages/Enzyme/TiboG/src/compiler.jl:6794 [inlined]
 [16] AdjointThunk
    @ ~/.julia/packages/Enzyme/TiboG/src/compiler.jl:6677 [inlined]
 [17] runtime_generic_rev(activity::Type{…}, width::Val{…}, ModifiedBetween::Val{…}, tape::Enzyme.Compiler.Tape{…}, f::typeof(quadgk), df::Nothing, primal_1::var"#7#8"{…}, shadow_1_1::Base.RefValue{…}, primal_2::Float64, shadow_2_1::Base.RefValue{…}, primal_3::Float64, shadow_3_1::Base.RefValue{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/TiboG/src/rules/jitrules.jl:468
 [18] chebyshevintegral
    @ ~/issues/quadgkmixed.jl:5 [inlined]
 [19] chebyshevintegral
    @ ~/issues/quadgkmixed.jl:0 [inlined]
 [20] diffejulia_chebyshevintegral_9738_inner_1wrap
    @ ~/issues/quadgkmixed.jl:0
 [21] macro expansion
    @ ~/.julia/packages/Enzyme/TiboG/src/compiler.jl:7187 [inlined]
 [22] enzyme_call
    @ ~/.julia/packages/Enzyme/TiboG/src/compiler.jl:6794 [inlined]
 [23] CombinedAdjointThunk
    @ ~/.julia/packages/Enzyme/TiboG/src/compiler.jl:6671 [inlined]
 [24] autodiff
    @ ~/.julia/packages/Enzyme/TiboG/src/Enzyme.jl:320 [inlined]
 [25] autodiff(::ReverseMode{false, FFIABI, false, false}, ::typeof(chebyshevintegral), ::Type{Active}, ::Const{IntervalSets.ClosedInterval{Float64}}, ::Duplicated{Vector{Float64}})
    @ Enzyme ~/.julia/packages/Enzyme/TiboG/src/Enzyme.jl:332
 [26] top-level scope
    @ ~/issues/quadgkmixed.jl:13
 [27] include(fname::String)
    @ Base.MainInclude ./client.jl:489
 [28] top-level scope
    @ REPL[5]:1
in expression starting at /home/daniel/issues/quadgkmixed.jl:13
Some type information was truncated. Use `show(err)` to see complete types.

EDIT: My first attempt was Rev -> rev, resulting in a gibberish comment, sorry about that

@wsmoses
Copy link
Contributor Author

wsmoses commented Sep 17, 2024

Well that was a dumb mistake, should've been Ref

@danielwe
Copy link
Contributor

With your latest change fixing the returns from the MixedClosureVector operations I get an answer, but it's incorrect:

julia> include("quadgkmixed.jl");
chebyshevintegral(domain, coeffs) = 2.0
dcoeffs = [0.0]

Should be dcoeffs = [2.0], as the integral is equal to 2coeffs[1]

@danielwe
Copy link
Contributor

danielwe commented Sep 17, 2024

Oh, I think I see what's happening. When you make_zero(f.val), you're also zeroing out the captured variables that are Const. In the case of ApproxFun.Fun, that means you're creating a Fun with a trivial domain 0.0..0.0, and currently the result of evaluating a Fun outside it's domain is that it returns zero (even though this is left undefined in the package docs).

So it seems like something more context-sensitive than make_zero is needed. Actually, does the MixedDuplicated type separate between Active and Const variables in its subparts at all?

EDIT: Separating between Active and Const probably isn't relevant here, neither should be zeroed out in this context.

@danielwe
Copy link
Contributor

Btw. the non-erroring but incorrect example used quadgk(x -> f(x), ...) rather than quadgk(f, ...) as in #119. Otherwise you get a MethodError for (::AdjointThunk{...})(::Fun{...}), but that's a separate issue---unless an appropriate rule for Fun would fix the issue with zeroing the domain.

I'll try to make an example that doesn't use ApproxFun.

@danielwe
Copy link
Contributor

danielwe commented Sep 17, 2024

The new ApproxFun-free MWE in #119 has the same issue explained above: make_zero zeroing out non-duplicated parts of the closure leading to incorrect/vanishing derivatives for the Duplicated parts.

MWE repeated for convenience:

using Enzyme, QuadGK

function polyintegral(coeffs, scale)
    f(x) = scale * evalpoly(x, coeffs)
    return first(quadgk(f, -1.0, 1.0))
end

coeffs = [1.0]
scale = 1.0
@show polyintegral(coeffs, scale)

dcoeffs = make_zero(coeffs)
autodiff(Reverse, polyintegral, Active, Duplicated(coeffs, dcoeffs), Const(scale))
@show dcoeffs

Output with the current state of the PR:

julia> include("quadgkmixed.jl");
polyintegral(coeffs, scale) = 2.0
dcoeffs = [0.0]

Should be dcoeffs = [2.0].

@danielwe
Copy link
Contributor

Trying to adapt your example to the Duplicated case, I'm realizing that Enzyme.Compiler.recursive_accumulate isn't as recursive as it perhaps should be. It doesn't seem to recurse into nested mutable containers. Maybe this is the real issue and the zeroing of Const/Active memory is a red herring.

Example:

julia> a, b = Ref([1.0]), Ref([1.0])
(Base.RefValue{Vector{Float64}}([1.0]), Base.RefValue{Vector{Float64}}([1.0]))

julia> Enzyme.Compiler.recursive_accumulate(a, b)

julia> a, b
(Base.RefValue{Vector{Float64}}([1.0]), Base.RefValue{Vector{Float64}}([1.0]))

@wsmoses
Copy link
Contributor Author

wsmoses commented Sep 17, 2024

Hm yeah I was hoping we had a utility that did the latter already, but it appears not.

I suppose the solution here is to extend it to do so, if you want to give it a whirl

@wsmoses
Copy link
Contributor Author

wsmoses commented Sep 17, 2024

Okay you may be able to use the same trick we do here: https://github.com/EnzymeAD/Enzyme.jl/blob/786a998f0dc5343703c5420eae40cb790575e218/src/Enzyme.jl#L297

make_zero of the existing shadow to fill the iddict of mutable locations, then in place accumulate all the leaf values

@danielwe
Copy link
Contributor

danielwe commented Sep 18, 2024

OK, here's a proof of concept that's correct in simple cases like the MWE:

function accumulate!(a::T, b::T, args...) where {T}
    anodes, bnodes = IdDict(), IdDict()
    Enzyme.make_zero(T, anodes, a)
    Enzyme.make_zero(T, bnodes, b)
    anodes_vector = sort!(collect(keys(anodes)); by=nameof  typeof)
    bnodes_vector = sort!(collect(keys(bnodes)); by=nameof  typeof)
    for (anode, bnode) in zip(anodes_vector, bnodes_vector)
        if ismutable(anode) && ismutable(bnode)
            Enzyme.Compiler.recursive_accumulate(anode, bnode, args...)
        end
    end
    return nothing
end

The problem is ensuring commensurate iteration orders for the iddicts. Here I'm simply sorting by type name, which will obviously break as soon as the captured variables contain two or more variables of the same type. If this hack using make_zero is the way we're going we'll need to cook up an IdDict that preserves insertion order, which I suppose could just be a thin wrapper around OrderedCollections.OrderedDict, so not a big hassle. OrderedCollections is already a transitive dependency of QuadGK.

It's unclear to me whether this will always be correct or if you can sometimes get double accumulation in deeply nested structures, since make_zero adds all values in the tree to the iddict, not just the leaves.

Finally, there's a question of how wasteful the extra allocations from make_zero are. The accumulation within a quadrature can be a pretty hot part of the code, but I haven't profiled to see if this becomes a bottleneck.

EDIT: moved parenthetical comment to review

Comment on lines +105 to +106
res = deepcopy(b)::CV
Enzyme.Compiler.recursive_accumulate(res, b, x->(a-1)*x)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
res = deepcopy(b)::CV
Enzyme.Compiler.recursive_accumulate(res, b, x->(a-1)*x)
res = Enzyme.make_zero(b)::CV
Enzyme.Compiler.recursive_accumulate(res, b, x->a*x)

This addresses @stevengj's concern in #110 (comment) and can perhaps be ported to the ClosureVector equivalent too

_df, _ = quadgk(map(x->x.val, segs)...; kws..., eval_segbuf=segbuf, maxevals=0, norm=f->0) do x
fshadow = Ref(Enzyme.make_zero(f.val))
tape, prim, shad = fwd(Const(call), MixedDuplicated(f.val, fshadow), Const(x))
drev = rev(Const(call), f, Const(x), dres.val[1], tape)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain why it's OK to use f and not MixedDuplicated(f.val, fshadow) in the reverse pass here? Both alternatives give correct results, but I can't wrap my head around why. When adapting the code for Duplicated I have to use Duplicated(f.val, fshadow) in both fwd and rev, otherwise I get incorrect results.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A typo lol

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry this was hacked on my phone during a cubs game

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that's nuts ⚾ 🤯
...and leaves me even more confused about getting correct results on every test case I've tried so far, but glad to know my intuition was right

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants