From b6f6178ed91c3760da372fd1a3101ab0c0d1ccd1 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Tue, 17 Sep 2024 15:20:33 -0500 Subject: [PATCH 1/4] Add mixedduplicated support in enzyme ext --- ext/QuadGKEnzymeExt.jl | 35 +++++++++++++++++++++++++++++++++-- 1 file changed, 33 insertions(+), 2 deletions(-) diff --git a/ext/QuadGKEnzymeExt.jl b/ext/QuadGKEnzymeExt.jl index d10d21a..8074cf3 100644 --- a/ext/QuadGKEnzymeExt.jl +++ b/ext/QuadGKEnzymeExt.jl @@ -84,10 +84,31 @@ function Base.:*(a::ClosureVector, b::Number) return b*a end -function Enzyme.EnzymeRules.reverse(config, ofunc::Const{typeof(quadgk)}, dres::Active, cache, f::Union{Const, Active}, segs::Annotation{T}...; kws...) where {T} +struct MixedClosureVector{F} + f::Base.RefValue{F} +end + +function Base.:+(a::CV, b::CV) where {CV <: MixedClosureVector} + Enzyme.Compiler.recursive_accumulate(a, b, identity)::CV +end + +function Base.:-(a::CV, b::CV) where {CV <: MixedClosureVector} + Enzyme.Compiler.recursive_accumulate(a, b, x->-x)::CV +end + +function Base.:*(a::Number, b::CV) where {CV <: MixedClosureVector} + # b + (a-1) * b = a * b + Enzyme.Compiler.recursive_accumulate(b, b, x->(a-1)*x)::CV +end + +function Base.:*(a::MixedClosureVector, b::Number) + return b*a +end + +function Enzyme.EnzymeRules.reverse(config, ofunc::Const{typeof(quadgk)}, dres::Active, cache, f::Union{Const, Active, MixedDuplicated}, segs::Annotation{T}...; kws...) where {T} df = if f isa Const nothing - else + elseif f isa Active segbuf = cache[1] fwd, rev = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(call)}, Active, typeof(f), Const{T}) _df, _ = quadgk(map(x->x.val, segs)...; kws..., eval_segbuf=segbuf, maxevals=0, norm=f->0) do x @@ -96,6 +117,16 @@ function Enzyme.EnzymeRules.reverse(config, ofunc::Const{typeof(quadgk)}, dres:: return ClosureVector(drev[1][1]) end _df.f + elseif f isa MixedDuplicated + segbuf = cache[1] + fwd, rev = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(call)}, Active, typeof(f), Const{T}) + _df, _ = quadgk(map(x->x.val, segs)...; kws..., eval_segbuf=segbuf, maxevals=0, norm=f->0) do x + fshadow = Rev(Enzyme.make_zero(f.val)) + tape, prim, shad = fwd(Const(call), MixedDuplicated(f.val, fshadow), Const(x)) + drev = rev(Const(call), f, Const(x), dres.val[1], tape) + return MixedClosureVector(fshadow) + end + _df.f[] end dsegs1 = segs[1] isa Const ? nothing : -LinearAlgebra.dot(f.val(segs[1].val), dres.val[1]) dsegsn = segs[end] isa Const ? nothing : LinearAlgebra.dot(f.val(segs[end].val), dres.val[1]) From 5bd4dbf108d2c81ad82cc9332db9488b8d53a6e8 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Tue, 17 Sep 2024 15:29:14 -0500 Subject: [PATCH 2/4] fixup --- ext/QuadGKEnzymeExt.jl | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/ext/QuadGKEnzymeExt.jl b/ext/QuadGKEnzymeExt.jl index 8074cf3..2c24fca 100644 --- a/ext/QuadGKEnzymeExt.jl +++ b/ext/QuadGKEnzymeExt.jl @@ -89,16 +89,19 @@ struct MixedClosureVector{F} end function Base.:+(a::CV, b::CV) where {CV <: MixedClosureVector} - Enzyme.Compiler.recursive_accumulate(a, b, identity)::CV + res = deepcopy(a)::CV + Enzyme.Compiler.recursive_accumulate(res, b, identity)::CV end function Base.:-(a::CV, b::CV) where {CV <: MixedClosureVector} - Enzyme.Compiler.recursive_accumulate(a, b, x->-x)::CV + res = deepcopy(a)::CV + Enzyme.Compiler.recursive_accumulate(res, b, x->-x)::CV end function Base.:*(a::Number, b::CV) where {CV <: MixedClosureVector} # b + (a-1) * b = a * b - Enzyme.Compiler.recursive_accumulate(b, b, x->(a-1)*x)::CV + res = deepcopy(b)::CV + Enzyme.Compiler.recursive_accumulate(res, b, x->(a-1)*x)::CV end function Base.:*(a::MixedClosureVector, b::Number) @@ -126,7 +129,8 @@ function Enzyme.EnzymeRules.reverse(config, ofunc::Const{typeof(quadgk)}, dres:: drev = rev(Const(call), f, Const(x), dres.val[1], tape) return MixedClosureVector(fshadow) end - _df.f[] + Enzyme.Compiler.recursive_accumulate(f.dval, _df.f) + nothing end dsegs1 = segs[1] isa Const ? nothing : -LinearAlgebra.dot(f.val(segs[1].val), dres.val[1]) dsegsn = segs[end] isa Const ? nothing : LinearAlgebra.dot(f.val(segs[end].val), dres.val[1]) From e87bd848d8c20325f63060bcf4042853681e3db0 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 17 Sep 2024 16:47:56 -0500 Subject: [PATCH 3/4] Update QuadGKEnzymeExt.jl --- ext/QuadGKEnzymeExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/QuadGKEnzymeExt.jl b/ext/QuadGKEnzymeExt.jl index 2c24fca..a59ccea 100644 --- a/ext/QuadGKEnzymeExt.jl +++ b/ext/QuadGKEnzymeExt.jl @@ -124,7 +124,7 @@ function Enzyme.EnzymeRules.reverse(config, ofunc::Const{typeof(quadgk)}, dres:: segbuf = cache[1] fwd, rev = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(call)}, Active, typeof(f), Const{T}) _df, _ = quadgk(map(x->x.val, segs)...; kws..., eval_segbuf=segbuf, maxevals=0, norm=f->0) do x - fshadow = Rev(Enzyme.make_zero(f.val)) + fshadow = Ref(Enzyme.make_zero(f.val)) tape, prim, shad = fwd(Const(call), MixedDuplicated(f.val, fshadow), Const(x)) drev = rev(Const(call), f, Const(x), dres.val[1], tape) return MixedClosureVector(fshadow) From ab12c5bd2a5878f453dec06b697560eb580ce38f Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 17 Sep 2024 16:48:40 -0500 Subject: [PATCH 4/4] Update QuadGKEnzymeExt.jl --- ext/QuadGKEnzymeExt.jl | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/ext/QuadGKEnzymeExt.jl b/ext/QuadGKEnzymeExt.jl index a59ccea..9d4c9b7 100644 --- a/ext/QuadGKEnzymeExt.jl +++ b/ext/QuadGKEnzymeExt.jl @@ -90,18 +90,21 @@ end function Base.:+(a::CV, b::CV) where {CV <: MixedClosureVector} res = deepcopy(a)::CV - Enzyme.Compiler.recursive_accumulate(res, b, identity)::CV + Enzyme.Compiler.recursive_accumulate(res, b, identity) + res end function Base.:-(a::CV, b::CV) where {CV <: MixedClosureVector} res = deepcopy(a)::CV - Enzyme.Compiler.recursive_accumulate(res, b, x->-x)::CV + Enzyme.Compiler.recursive_accumulate(res, b, x->-x) + res end function Base.:*(a::Number, b::CV) where {CV <: MixedClosureVector} # b + (a-1) * b = a * b res = deepcopy(b)::CV - Enzyme.Compiler.recursive_accumulate(res, b, x->(a-1)*x)::CV + Enzyme.Compiler.recursive_accumulate(res, b, x->(a-1)*x) + res end function Base.:*(a::MixedClosureVector, b::Number)