Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mixedduplicated support in enzyme ext #120

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 40 additions & 2 deletions ext/QuadGKEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,37 @@
return b*a
end

function Enzyme.EnzymeRules.reverse(config, ofunc::Const{typeof(quadgk)}, dres::Active, cache, f::Union{Const, Active}, segs::Annotation{T}...; kws...) where {T}
struct MixedClosureVector{F}
f::Base.RefValue{F}
end

function Base.:+(a::CV, b::CV) where {CV <: MixedClosureVector}
res = deepcopy(a)::CV
Enzyme.Compiler.recursive_accumulate(res, b, identity)
res

Check warning on line 94 in ext/QuadGKEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/QuadGKEnzymeExt.jl#L91-L94

Added lines #L91 - L94 were not covered by tests
end

function Base.:-(a::CV, b::CV) where {CV <: MixedClosureVector}
res = deepcopy(a)::CV
Enzyme.Compiler.recursive_accumulate(res, b, x->-x)
res

Check warning on line 100 in ext/QuadGKEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/QuadGKEnzymeExt.jl#L97-L100

Added lines #L97 - L100 were not covered by tests
end

function Base.:*(a::Number, b::CV) where {CV <: MixedClosureVector}

Check warning on line 103 in ext/QuadGKEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/QuadGKEnzymeExt.jl#L103

Added line #L103 was not covered by tests
# b + (a-1) * b = a * b
res = deepcopy(b)::CV
Enzyme.Compiler.recursive_accumulate(res, b, x->(a-1)*x)
Comment on lines +105 to +106
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
res = deepcopy(b)::CV
Enzyme.Compiler.recursive_accumulate(res, b, x->(a-1)*x)
res = Enzyme.make_zero(b)::CV
Enzyme.Compiler.recursive_accumulate(res, b, x->a*x)

This addresses @stevengj's concern in #110 (comment) and can perhaps be ported to the ClosureVector equivalent too

res

Check warning on line 107 in ext/QuadGKEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/QuadGKEnzymeExt.jl#L105-L107

Added lines #L105 - L107 were not covered by tests
end

function Base.:*(a::MixedClosureVector, b::Number)
return b*a

Check warning on line 111 in ext/QuadGKEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/QuadGKEnzymeExt.jl#L110-L111

Added lines #L110 - L111 were not covered by tests
end

function Enzyme.EnzymeRules.reverse(config, ofunc::Const{typeof(quadgk)}, dres::Active, cache, f::Union{Const, Active, MixedDuplicated}, segs::Annotation{T}...; kws...) where {T}

Check warning on line 114 in ext/QuadGKEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/QuadGKEnzymeExt.jl#L114

Added line #L114 was not covered by tests
df = if f isa Const
nothing
else
elseif f isa Active

Check warning on line 117 in ext/QuadGKEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/QuadGKEnzymeExt.jl#L117

Added line #L117 was not covered by tests
segbuf = cache[1]
fwd, rev = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(call)}, Active, typeof(f), Const{T})
_df, _ = quadgk(map(x->x.val, segs)...; kws..., eval_segbuf=segbuf, maxevals=0, norm=f->0) do x
Expand All @@ -96,6 +123,17 @@
return ClosureVector(drev[1][1])
end
_df.f
elseif f isa MixedDuplicated
segbuf = cache[1]
fwd, rev = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(call)}, Active, typeof(f), Const{T})
_df, _ = quadgk(map(x->x.val, segs)...; kws..., eval_segbuf=segbuf, maxevals=0, norm=f->0) do x
fshadow = Ref(Enzyme.make_zero(f.val))
tape, prim, shad = fwd(Const(call), MixedDuplicated(f.val, fshadow), Const(x))
drev = rev(Const(call), f, Const(x), dres.val[1], tape)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain why it's OK to use f and not MixedDuplicated(f.val, fshadow) in the reverse pass here? Both alternatives give correct results, but I can't wrap my head around why. When adapting the code for Duplicated I have to use Duplicated(f.val, fshadow) in both fwd and rev, otherwise I get incorrect results.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A typo lol

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry this was hacked on my phone during a cubs game

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that's nuts ⚾ 🤯
...and leaves me even more confused about getting correct results on every test case I've tried so far, but glad to know my intuition was right

return MixedClosureVector(fshadow)

Check warning on line 133 in ext/QuadGKEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/QuadGKEnzymeExt.jl#L126-L133

Added lines #L126 - L133 were not covered by tests
end
Enzyme.Compiler.recursive_accumulate(f.dval, _df.f)
nothing

Check warning on line 136 in ext/QuadGKEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/QuadGKEnzymeExt.jl#L135-L136

Added lines #L135 - L136 were not covered by tests
end
dsegs1 = segs[1] isa Const ? nothing : -LinearAlgebra.dot(f.val(segs[1].val), dres.val[1])
dsegsn = segs[end] isa Const ? nothing : LinearAlgebra.dot(f.val(segs[end].val), dres.val[1])
Expand Down
Loading