
Enzyme is dropping gradients when using custom rule and views #1856

Open
ptiede opened this issue Sep 18, 2024 · 17 comments

ptiede (Contributor) commented Sep 18, 2024

I've managed to get into a situation where a custom rule seems to give incorrect gradients when it is called on views inside a loop. Here is an MWE:

using Enzyme, LinearAlgebra
using EnzymeCore: EnzymeRules

vl = [2, 2]
tot = sum(vl)
rg = [1:2, 3:4]
Nx = 4

iminds = reshape([CartesianIndex(i) for i in 1:2], :)
visinds = [collect(rg[i]) for i in eachindex(rg)]
Bs = Dict((iminds[i]=> ones(ComplexF64, vl[i], Nx*Nx) for i in eachindex(vl)))

function _mul!(out, A, b)
    mul!(out, A, b)
    return nothing 
end


function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1}, ::Const{typeof(_mul!)}, ::Type{<:Const}, out::Duplicated,
                                      A::Const{<:Matrix}, b::Duplicated)
    _mul!(out.val, A.val, b.val)
    return EnzymeRules.AugmentedReturn(nothing, nothing, nothing)
end

@noinline function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
                                       ::Const{typeof(_mul!)},
                                       ::Type{<:Const}, tape, out::Duplicated, A::Const{<:Matrix},
                                       b::Duplicated)

    b.dval .+= real.(A.val' * out.dval)
    out.dval .= 0
    return (nothing, nothing, nothing)
end

@inline function f(Bs, visinds, iminds, tot, x)
    out = similar(x, Complex{eltype(x)}, tot)
    for i in eachindex(iminds, visinds)
        imind = iminds[i]
        visind = visinds[i]
        _mul!(@view(out[visind]), Bs[imind], reshape(@view(x[:, :, imind]), :))
    end
    return sum(abs2, out)
end

x = ones(Nx, Nx, 2)
dx = zero(x)

f(Bs, visinds, iminds, tot, x)
autodiff(set_runtime_activity(Reverse), f, Active, Const(Bs), Const(visinds), Const(iminds), Const(tot), Duplicated(x, fill!(dx, 0)))
@show dx
# dx = [0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0;;; 64.0 64.0 64.0 64.0; 64.0 64.0 64.0 64.0; 64.0 64.0 64.0 64.0; 64.0 64.0 64.0 64.0]

So the gradients for x[:, :, 1] are zeroed rather than also being filled with 64. Funnily enough, if I change vl to be different, e.g.,

vl = [2, 3]
tot = sum(vl)
rg = [1:2, 3:5]
Nx = 4

and keep everything else identical, I get a DimensionMismatch error instead:

ERROR: DimensionMismatch: matrix A has dimensions (16,2), vector B has length 3
  [1] _generic_matvecmul!
    @ ~/.julia/juliaup/julia-1.10.5+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:697
  [2] generic_matvecmul!
    @ ~/.julia/juliaup/julia-1.10.5+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:687 [inlined]
  [3] mul!
    @ ~/.julia/juliaup/julia-1.10.5+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:66 [inlined]
  [4] mul!
    @ ~/.julia/juliaup/julia-1.10.5+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:237 [inlined]
  [5] *
    @ ~/.julia/juliaup/julia-1.10.5+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:57
  [6] reverse
    @ ~/Research/Enzyme/dft.jl:32 [inlined]
  [7] f
    @ ~/Research/Enzyme/dft.jl:42 [inlined]
  [8] diffejulia_f_7370wrap
    @ ~/Research/Enzyme/dft.jl:0
  [9] macro expansion
    @ ~/.julia/packages/Enzyme/0PGOL/src/compiler.jl:7045 [inlined]
 [10] enzyme_call
    @ ~/.julia/packages/Enzyme/0PGOL/src/compiler.jl:6648 [inlined]
 [11] CombinedAdjointThunk
    @ ~/.julia/packages/Enzyme/0PGOL/src/compiler.jl:6525 [inlined]
 [12] autodiff
    @ ~/.julia/packages/Enzyme/0PGOL/src/Enzyme.jl:316 [inlined]
 [13] autodiff(::ReverseMode{…}, ::typeof(f), ::Type{…}, ::Const{…}, ::Const{…}, ::Const{…}, ::Const{…}, ::Duplicated{…})
    @ Enzyme ~/.julia/packages/Enzyme/0PGOL/src/Enzyme.jl:328
 [14] top-level scope
    @ ~/Research/Enzyme/dft.jl:51

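So it looks like the shadow is incorrect: in the reverse pass for the first block, A.val' is 16×2 (the adjoint of the first block's 2×16 matrix), but out.dval has length 3, which is the length of the second block's view.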
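For comparison, here is a minimal sketch of the gradient the MWE should produce, computed in closed form without Enzyme. It assumes the blocks are indexed by plain integers rather than CartesianIndex keys, and reference_grad is a hypothetical helper written only for this check.

```julia
using LinearAlgebra

# Hypothetical helper (not part of Enzyme): closed-form gradient of
#   f(x) = sum_i sum(abs2, B_i * vec(x[:, :, i]))
# with respect to a real array x, i.e. dx[:, :, i] = 2 * real(B_i' * (B_i * vec(x[:, :, i]))).
function reference_grad(Bs, x)
    dx = zero(x)
    for i in axes(x, 3)
        xi = vec(x[:, :, i])
        dx[:, :, i] .= reshape(2 .* real.(Bs[i]' * (Bs[i] * xi)), size(x, 1), size(x, 2))
    end
    return dx
end

Nx = 4
vl = [2, 3]   # block sizes from the second variant of the MWE
Bs = [ones(ComplexF64, vl[i], Nx * Nx) for i in eachindex(vl)]
x = ones(Nx, Nx, 2)
dx_ref = reference_grad(Bs, x)
# dx_ref[:, :, 1] is filled with 64.0 and dx_ref[:, :, 2] with 96.0
```

With vl = [2, 2], both slices come out as 64.0, so neither variant should produce a zero slice.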

Note that when acting on a single array, the rule is correct, e.g.,

function fsimple(Bs, x)
    out = similar(Bs, Complex{eltype(x)}, size(Bs, 1))
    _mul!(out, Bs, reshape(x, :))
    return sum(abs2, out)
end

B = ones(ComplexF64, 2, 16)
xx = ones(4,4)
dxx = zero(xx)
autodiff(set_runtime_activity(Reverse), fsimple, Active, Const(B), Duplicated(xx, dxx))
@show dxx
# dxx = [64.0 64.0 64.0 64.0; 64.0 64.0 64.0 64.0; 64.0 64.0 64.0 64.0; 64.0 64.0 64.0 64.0]
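For reference, the value 64 follows from a short hand derivation, assuming (as in both examples) that x is real and B is complex:

```math
f(x) = \sum_k \bigl|(Bx)_k\bigr|^2, \qquad \nabla_x f = 2\,\operatorname{Re}\bigl(B^{\mathsf H} B x\bigr)
```

```math
B = \mathbf{1}_{2\times 16},\; x = \mathbf{1}_{16}: \quad Bx = 16\,\mathbf{1}_{2}, \quad B^{\mathsf H}(Bx) = 32\,\mathbf{1}_{16}, \quad \nabla_x f = 64\,\mathbf{1}_{16}
```

The rule's update Re(Aᴴ ō) reproduces this if the incoming shadow is ō = 2Bx, which is consistent with the 64 from the single-array run. In the loop MWE each slice of x sees the same B and the same values, so both slices of dx should be 64.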
ptiede (Contributor, issue author) commented Sep 18, 2024

Oh, also, this is on Enzyme#main:

(Enzyme) pkg> st
Status `~/Research/Enzyme/Project.toml`
  [7da242da] Enzyme v0.13.0 `https://github.com/EnzymeAD/Enzyme.jl.git#main`
  [f151be2c] EnzymeCore v0.8.0 `https://github.com/EnzymeAD/Enzyme.jl.git:lib/EnzymeCore#main`

wsmoses (Member) commented Sep 19, 2024

@ptiede, can you simplify this? I tried to do so, but I really don't understand what is supposed to be happening.

ptiede (Contributor, issue author) commented Sep 20, 2024

I've reduced it a bit. It is essentially just a loop over a matrix-vector product where the vector is a view. Let me know if this isn't enough.

using Enzyme
using EnzymeCore: EnzymeRules
using LinearAlgebra

Nx = 2
Bs = [ones(ComplexF64, Nx, Nx*Nx) for i in 1:Nx]

function _mul!(out, A, b)
    mul!(out, A, b)
    return nothing 
end

function _mul(A, b)
    out = similar(A, size(A, 1))
    _mul!(out, A, b)
    return out
end


function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1}, ::Const{typeof(_mul!)}, ::Type{<:Const}, out::Duplicated,
                                      A::Const{<:Matrix}, b::Duplicated)
    _mul!(out.val, A.val, b.val)
    return EnzymeRules.AugmentedReturn(nothing, nothing, nothing)
end

function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
    ::Const{typeof(_mul!)},
    ::Type{<:Const}, tape, out::Duplicated, A::Const{<:Matrix},
    b::Duplicated)

    b.dval .+= real.(A.val' * out.dval)
    out.dval .= 0
    return (nothing, nothing, nothing)
end


@inline function f2(Nx, Bs, x)
    s = zero(eltype(x))
    for i in 1:Nx
        @inbounds s += sum(abs2, _mul(Bs[i], reshape(@view(x[:, :, i]), :)))
    end
    return s
end



x = ones(Nx, Nx, 2)
dx = zero(x)

f2(Nx, Bs, x)
autodiff(set_runtime_activity(Reverse), f2, Active, Const(Nx), Const(Bs), Duplicated(x, fill!(dx, 0)))
@show dx
# dx = [0.0 0.0; 0.0 0.0;;; 32.0 32.0; 32.0 32.0]

# correct answer is fill(16.0, Nx, Nx, Nx)
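One way to confirm the expected value of 16 independently of Enzyme is a central finite-difference check. This is a sketch: f2_plain and fd_gradient are hypothetical helpers that re-implement f2 with a plain matrix-vector product, so no custom rule is involved.

```julia
using LinearAlgebra

# Hypothetical re-implementation of f2 from the reduced example, without _mul! or its rule.
function f2_plain(Nx, Bs, x)
    s = zero(eltype(x))
    for i in 1:Nx
        s += sum(abs2, Bs[i] * vec(@view(x[:, :, i])))
    end
    return s
end

# Central finite differences; for this quadratic objective they are exact up to rounding.
function fd_gradient(f, x; h = 1e-6)
    g = zero(x)
    xp = copy(x)
    for k in eachindex(x)
        orig = xp[k]
        xp[k] = orig + h
        fp = f(xp)
        xp[k] = orig - h
        fm = f(xp)
        xp[k] = orig          # restore before perturbing the next entry
        g[k] = (fp - fm) / (2h)
    end
    return g
end

Nx = 2
Bs = [ones(ComplexF64, Nx, Nx * Nx) for _ in 1:Nx]
x = ones(Nx, Nx, 2)
g = fd_gradient(x -> f2_plain(Nx, Bs, x), x)
```

Every entry of g should be approximately 16.0, in both slices.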

ptiede (Contributor, issue author) commented Sep 20, 2024

Further reduction:

using Enzyme, LinearAlgebra
Nx = 2
Bs = ones(ComplexF64, Nx, Nx*Nx)

function _mul!(out, A, b)
    mul!(out, A, b)
    return nothing 
end

function _mul(A, b)
    out = similar(A, size(A, 1))
    _mul!(out, A, b)
    return out
end


function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1}, ::Const{typeof(_mul!)}, ::Type{<:Const}, out::Duplicated,
                                      A::Const{<:Matrix}, b::Duplicated)
    _mul!(out.val, A.val, b.val)
    return EnzymeRules.AugmentedReturn(nothing, nothing, nothing)
end

function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
    ::Const{typeof(_mul!)},
    ::Type{<:Const}, tape, out::Duplicated, A::Const{<:Matrix},
    b::Duplicated)

    b.dval .+= real.(A.val' * out.dval)
    out.dval .= 0
    return (nothing, nothing, nothing)
end


@inline function f2(Nx, Bs, x)
    s = zero(eltype(x))
    for i in 1:Nx
        @inbounds s += sum(abs2, _mul(Bs, reshape(@view(x[:, :, i]), :)))
    end
    return s
end



x = ones(Nx, Nx, 2)
dx = zero(x)

f2(Nx, Bs, x)
autodiff(set_runtime_activity(Reverse), f2, Active, Const(Nx), Const(Bs), Duplicated(x, fill!(dx, 0)))
@show dx

wsmoses (Member) commented Sep 24, 2024

What is this supposed to give, and why is it wrong?

For one thing, you're storing into constant data (Bs), which means it is guaranteed to have a zero derivative. See https://enzyme.mit.edu/index.fcgi/julia/stable/faq/#Activity-of-temporary-storage

wsmoses closed this as completed Sep 24, 2024
wsmoses reopened this Sep 25, 2024
wsmoses (Member) commented Sep 25, 2024

using Enzyme, LinearAlgebra
using EnzymeCore: EnzymeRules

function _mul!(b)
    return b[1]*b[1] 
end

function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1}, ::Const{typeof(_mul!)}, ::Type{<:Active},
                                      b::Duplicated)
    return EnzymeRules.AugmentedReturn(EnzymeRules.needs_primal(config) ? _mul!(b.val) : nothing, nothing, nothing)
end

function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
    ::Const{typeof(_mul!)},
    out, tape,  b::Duplicated)
    @show b.val
    @show b.val.indices, b.val.offset1, b.val.parent, b.val.stride1
    @show b.dval.indices, b.dval.offset1, b.dval.parent, b.dval.stride1
    b.dval[1] += 2 * b.val[1] * out.val
    @show b.dval
    @show out
    return (nothing,)
end


@inline function f2(b)
    s = zero(eltype(b))
    for i in 1:size(b, 2)
        out = _mul!(@view(b[:, i]))
        s += out
    end
    return s
end


Nx = 1
Ny = 2
b = [1.0 4.0]
db = zero(b)

f2(b)
autodiff(Reverse, f2, Active, Duplicated(b, fill!(db, 0)))
# What Enzyme returns db = [0  16]
# What the correct answer is db = [2 8]
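The expected [2 8] follows directly from differentiating the loop by hand (b is the 1×2 matrix [1.0 4.0]):

```math
f_2(b) = \sum_{i=1}^{2} b_{1i}^2, \qquad \frac{\partial f_2}{\partial b_{1i}} = 2\,b_{1i} \;\Rightarrow\; \nabla_b f_2 = \begin{bmatrix} 2 & 8 \end{bmatrix} \text{ for } b = \begin{bmatrix} 1 & 4 \end{bmatrix}
```

The observed [0 16] equals [0, 2·(2·4)], which is what one would get if both reverse-pass calls received the view of column 2 (b.val[1] = 4 and the column-2 shadow). That is an interpretation of the numbers, not something established in the thread.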

ptiede (Contributor, issue author) commented Sep 25, 2024

After simplification:
; Function Attrs: mustprogress willreturn
define "enzyme_type"="{[-1]:Float@double}" double @preprocess_julia_f2_11891({} addrspace(10)* noundef nonnull align 16 dereferenceable(40) "enzyme_type"="{[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@double, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer}" "enzymejl_parmtype"="125879341910352" "enzymejl_parmtype_ref"="2" %0) local_unnamed_addr #11 !dbg !154 {
top:
  %newstruct = alloca { [1 x [1 x i64]], i64 }, align 8
  %1 = alloca { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }, align 8
  %2 = call {}*** @julia.get_pgcstack() #12
  %ptls_field31 = getelementptr inbounds {}**, {}*** %2, i64 2
  %3 = bitcast {}*** %ptls_field31 to i64***
  %ptls_load3233 = load i64**, i64*** %3, align 8, !tbaa !11
  %4 = getelementptr inbounds i64*, i64** %ptls_load3233, i64 2
  %safepoint = load i64*, i64** %4, align 8, !tbaa !15
  fence syncscope("singlethread") seq_cst
  call void @julia.safepoint(i64* %safepoint) #12, !dbg !155
  fence syncscope("singlethread") seq_cst
  %5 = addrspacecast {} addrspace(10)* %0 to {} addrspace(10)* addrspace(11)*, !dbg !156
  %arraysize_ptr = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %5, i64 4, !dbg !156
  %6 = bitcast {} addrspace(10)* addrspace(11)* %arraysize_ptr to i64 addrspace(11)*, !dbg !156
  %arraysize = load i64, i64 addrspace(11)* %6, align 16, !dbg !156, !tbaa !15, !range !22, !alias.scope !23, !noalias !26, !enzyme_inactive !10
  %.not = icmp eq i64 %arraysize, 0, !dbg !158
  br i1 %.not, label %L67, label %L18.preheader, !dbg !157

L18.preheader:                                    ; preds = %top
  %arraysize_ptr8 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %5, i64 3
  %7 = bitcast {} addrspace(10)* addrspace(11)* %arraysize_ptr8 to i64 addrspace(11)*
  %memcpy_refined_dst = getelementptr inbounds { [1 x [1 x i64]], i64 }, { [1 x [1 x i64]], i64 }* %newstruct, i64 0, i32 0, i64 0, i64 0
  %8 = getelementptr inbounds { [1 x [1 x i64]], i64 }, { [1 x [1 x i64]], i64 }* %newstruct, i64 0, i32 1
  %.fca.0.gep = getelementptr inbounds { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }* %1, i64 0, i32 0
  %.fca.1.0.0.0.gep = getelementptr inbounds { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }* %1, i64 0, i32 1, i32 0, i64 0, i64 0
  %.fca.1.1.gep = getelementptr inbounds { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }* %1, i64 0, i32 1, i32 1
  %.fca.2.gep = getelementptr inbounds { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }* %1, i64 0, i32 2
  %.fca.3.gep = getelementptr inbounds { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }* %1, i64 0, i32 3
  %9 = addrspacecast { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }* %1 to { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } addrspace(11)*
  br label %L18, !dbg !162

L18:                                              ; preds = %L42, %L18.preheader
  %iv = phi i64 [ %iv.next, %L42 ], [ 0, %L18.preheader ]
  %value_phi7 = phi double [ %14, %L42 ], [ 0.000000e+00, %L18.preheader ]
  %iv.next = add nuw nsw i64 %iv, 1, !dbg !165
  %arraysize9 = load i64, i64 addrspace(11)* %7, align 8, !dbg !165, !tbaa !15, !range !22, !alias.scope !23, !noalias !26, !enzyme_inactive !10
  %arraysize12 = load i64, i64 addrspace(11)* %6, align 16, !dbg !169, !tbaa !15, !range !22, !alias.scope !23, !noalias !26, !enzyme_inactive !10
  %10 = add nsw i64 %iv.next, -1, !dbg !172
  %.not34 = icmp ult i64 %10, %arraysize12, !dbg !176
  br i1 %.not34, label %L42, label %L39, !dbg !162

L39:                                              ; preds = %L18
  store i64 %arraysize9, i64* %memcpy_refined_dst, align 8, !dbg !177, !tbaa !68, !alias.scope !70, !noalias !178
  store i64 %iv.next, i64* %8, align 8, !dbg !177, !tbaa !68, !alias.scope !70, !noalias !178
  %11 = addrspacecast { [1 x [1 x i64]], i64 }* %newstruct to { [1 x [1 x i64]], i64 } addrspace(11)*, !dbg !162
  call fastcc void @julia_throw_boundserror_11895({} addrspace(10)* nofree noundef nonnull align 16 dereferenceable(40) %0, { [1 x [1 x i64]], i64 } addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(16) %11) #13, !dbg !162
  unreachable, !dbg !162

L42:                                              ; preds = %L18
  %12 = mul i64 %arraysize9, %10, !dbg !181
  store {} addrspace(10)* %0, {} addrspace(10)** %.fca.0.gep, align 8, !dbg !164, !noalias !191
  store i64 %arraysize9, i64* %.fca.1.0.0.0.gep, align 8, !dbg !164, !noalias !191
  store i64 %iv.next, i64* %.fca.1.1.gep, align 8, !dbg !164, !noalias !191
  store i64 %12, i64* %.fca.2.gep, align 8, !dbg !164, !noalias !191
  store i64 1, i64* %.fca.3.gep, align 8, !dbg !164, !noalias !191
  %13 = call double @julia__mul__11897({ {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(40) %9) #12, !dbg !164
  %14 = fadd double %value_phi7, %13, !dbg !192
  %.not35 = icmp eq i64 %iv.next, %arraysize, !dbg !194
  %15 = add nuw nsw i64 %iv.next, 1, !dbg !195
  br i1 %.not35, label %L67.loopexit, label %L18, !dbg !196

L67.loopexit:                                     ; preds = %L42
  store i64 %arraysize9, i64* %memcpy_refined_dst, align 8, !dbg !177, !tbaa !68, !alias.scope !70, !noalias !178
  store i64 %arraysize, i64* %8, align 8, !dbg !177, !tbaa !68, !alias.scope !70, !noalias !178
  br label %L67, !dbg !197

L67:                                              ; preds = %L67.loopexit, %top
  %value_phi23 = phi double [ 0.000000e+00, %top ], [ %14, %L67.loopexit ]
  ret double %value_phi23, !dbg !197
}

; Function Attrs: mustprogress willreturn
define internal void @diffejulia_f2_11891({} addrspace(10)* noundef nonnull align 16 dereferenceable(40) "enzyme_type"="{[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@double, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer}" "enzymejl_parmtype"="125879341910352" "enzymejl_parmtype_ref"="2" %0, {} addrspace(10)* align 16 "enzyme_type"="{[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@double, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer}" "enzymejl_parmtype"="125879341910352" "enzymejl_parmtype_ref"="2" %"'", double %differeturn) local_unnamed_addr #11 !dbg !205 {
top:
  %1 = alloca { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }, align 8
  %2 = alloca { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }, align 8
  %newstruct.i = alloca [1 x i64], align 8
  %newstruct12.i = alloca [1 x i64], align 8
  %"iv'ac" = alloca i64, align 8
  %"value_phi23'de" = alloca double, align 8
  %3 = getelementptr double, double* %"value_phi23'de", i64 0
  store double 0.000000e+00, double* %3, align 8
  %"value_phi7'de" = alloca double, align 8
  %4 = getelementptr double, double* %"value_phi7'de", i64 0
  store double 0.000000e+00, double* %4, align 8
  %"'de" = alloca double, align 8
  %5 = getelementptr double, double* %"'de", i64 0
  store double 0.000000e+00, double* %5, align 8
  %"'de2" = alloca double, align 8
  %6 = getelementptr double, double* %"'de2", i64 0
  store double 0.000000e+00, double* %6, align 8
  %7 = call {}*** @julia.get_pgcstack()
  %8 = call {}*** @julia.get_pgcstack()
  %9 = call {}*** @julia.get_pgcstack()
  %newstruct = alloca { [1 x [1 x i64]], i64 }, align 8
  %"'ipa" = alloca { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }, align 8
  store { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } zeroinitializer, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }* %"'ipa", align 8
  %10 = alloca { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }, align 8
  %11 = call {}*** @julia.get_pgcstack() #20
  %ptls_field31 = getelementptr inbounds {}**, {}*** %11, i64 2
  %12 = bitcast {}*** %ptls_field31 to i64***
  %ptls_load3233 = load i64**, i64*** %12, align 8, !tbaa !18, !alias.scope !206, !noalias !209
  %13 = getelementptr inbounds i64*, i64** %ptls_load3233, i64 2
  %safepoint = load i64*, i64** %13, align 8, !tbaa !22, !alias.scope !211, !noalias !214
  fence syncscope("singlethread") seq_cst
  call void @julia.safepoint(i64* %safepoint) #20, !dbg !216
  fence syncscope("singlethread") seq_cst
  %14 = addrspacecast {} addrspace(10)* %0 to {} addrspace(10)* addrspace(11)*, !dbg !217
  %arraysize_ptr = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %14, i64 4, !dbg !217
  %15 = bitcast {} addrspace(10)* addrspace(11)* %arraysize_ptr to i64 addrspace(11)*, !dbg !217
  %arraysize = load i64, i64 addrspace(11)* %15, align 16, !dbg !217, !tbaa !22, !range !29, !alias.scope !219, !noalias !222, !enzyme_inactive !17
  %.not = icmp eq i64 %arraysize, 0, !dbg !224
  br i1 %.not, label %L67, label %L18.preheader, !dbg !218

L18.preheader:                                    ; preds = %top
  %arraysize_ptr8 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %14, i64 3
  %16 = bitcast {} addrspace(10)* addrspace(11)* %arraysize_ptr8 to i64 addrspace(11)*
  %memcpy_refined_dst = getelementptr inbounds { [1 x [1 x i64]], i64 }, { [1 x [1 x i64]], i64 }* %newstruct, i64 0, i32 0, i64 0, i64 0
  %17 = getelementptr inbounds { [1 x [1 x i64]], i64 }, { [1 x [1 x i64]], i64 }* %newstruct, i64 0, i32 1
  %".fca.0.gep'ipg" = getelementptr inbounds { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }* %"'ipa", i64 0, i32 0
  %.fca.0.gep = getelementptr inbounds { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }* %10, i64 0, i32 0
  %".fca.1.0.0.0.gep'ipg" = getelementptr inbounds { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }* %"'ipa", i64 0, i32 1, i32 0, i64 0, i64 0
  %.fca.1.0.0.0.gep = getelementptr inbounds { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }* %10, i64 0, i32 1, i32 0, i64 0, i64 0
  %".fca.1.1.gep'ipg" = getelementptr inbounds { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }* %"'ipa", i64 0, i32 1, i32 1
  %.fca.1.1.gep = getelementptr inbounds { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }* %10, i64 0, i32 1, i32 1
  %".fca.2.gep'ipg" = getelementptr inbounds { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }* %"'ipa", i64 0, i32 2
  %.fca.2.gep = getelementptr inbounds { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }* %10, i64 0, i32 2
  %".fca.3.gep'ipg" = getelementptr inbounds { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }* %"'ipa", i64 0, i32 3
  %.fca.3.gep = getelementptr inbounds { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }* %10, i64 0, i32 3
  %"'ipc" = addrspacecast { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }* %"'ipa" to { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } addrspace(11)*
  %18 = addrspacecast { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }* %10 to { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } addrspace(11)*
  %19 = add nsw i64 %arraysize, -1, !dbg !228
  br label %L18, !dbg !228

L18:                                              ; preds = %L42, %L18.preheader
  %iv = phi i64 [ %iv.next, %L42 ], [ 0, %L18.preheader ]
  %iv.next = add nuw nsw i64 %iv, 1, !dbg !231
  %arraysize9 = load i64, i64 addrspace(11)* %16, align 8, !dbg !231, !tbaa !22, !range !29, !alias.scope !219, !noalias !222, !enzyme_inactive !17
  %arraysize12 = load i64, i64 addrspace(11)* %15, align 16, !dbg !235, !tbaa !22, !range !29, !alias.scope !219, !noalias !222, !enzyme_inactive !17
  %20 = add nsw i64 %iv.next, -1, !dbg !238
  %.not34 = icmp ult i64 %20, %arraysize12, !dbg !242
  br i1 %.not34, label %L42, label %L39, !dbg !228

L39:                                              ; preds = %L18
  store i64 %arraysize9, i64* %memcpy_refined_dst, align 8, !dbg !243, !tbaa !75, !alias.scope !77, !noalias !244
  store i64 %iv.next, i64* %17, align 8, !dbg !243, !tbaa !75, !alias.scope !77, !noalias !244
  %21 = addrspacecast { [1 x [1 x i64]], i64 }* %newstruct to { [1 x [1 x i64]], i64 } addrspace(11)*, !dbg !228
  call fastcc void @julia_throw_boundserror_11895({} addrspace(10)* nofree noundef nonnull align 16 dereferenceable(40) %0, { [1 x [1 x i64]], i64 } addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(16) %21) #21, !dbg !228
  unreachable, !dbg !228

L42:                                              ; preds = %L18
  %22 = mul i64 %arraysize9, %20, !dbg !247
  store {} addrspace(10)* %"'", {} addrspace(10)** %".fca.0.gep'ipg", align 8, !dbg !230, !alias.scope !257, !noalias !260
  store {} addrspace(10)* %0, {} addrspace(10)** %.fca.0.gep, align 8, !dbg !230, !alias.scope !262, !noalias !263
  store i64 %arraysize9, i64* %".fca.1.0.0.0.gep'ipg", align 8, !dbg !230, !alias.scope !257, !noalias !260
  store i64 %arraysize9, i64* %.fca.1.0.0.0.gep, align 8, !dbg !230, !alias.scope !262, !noalias !263
  store i64 %iv.next, i64* %".fca.1.1.gep'ipg", align 8, !dbg !230, !alias.scope !257, !noalias !260
  store i64 %iv.next, i64* %.fca.1.1.gep, align 8, !dbg !230, !alias.scope !262, !noalias !263
  store i64 %22, i64* %".fca.2.gep'ipg", align 8, !dbg !230, !alias.scope !257, !noalias !260
  store i64 %22, i64* %.fca.2.gep, align 8, !dbg !230, !alias.scope !262, !noalias !263
  store i64 1, i64* %".fca.3.gep'ipg", align 8, !dbg !230, !alias.scope !257, !noalias !260
  store i64 1, i64* %.fca.3.gep, align 8, !dbg !230, !alias.scope !262, !noalias !263
  %23 = bitcast {}*** %9 to {}**, !dbg !230
  %24 = getelementptr inbounds {}*, {}** %23, i64 -14, !dbg !230
  %25 = getelementptr inbounds {}*, {}** %24, i64 16, !dbg !230
  %26 = bitcast {}** %25 to i8**, !dbg !230
  %27 = load i8*, i8** %26, align 8, !dbg !230
  %28 = call {} addrspace(10)* @julia.gc_alloc_obj({}** %24, i64 80, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 125882009936336 to {}*) to {} addrspace(10)*)), !dbg !230
  %29 = bitcast {} addrspace(10)* %28 to [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }] addrspace(10)*, !dbg !230
  %30 = addrspacecast [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }] addrspace(10)* %29 to [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }] addrspace(11)*, !dbg !230
  %31 = getelementptr inbounds [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }], [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }] addrspace(11)* %30, i64 0, i32 0, !dbg !230
  %32 = load { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } addrspace(11)* %18, align 8, !dbg !230
  %33 = load { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } addrspace(11)* %"'ipc", align 8, !dbg !230
  store { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } %32, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } addrspace(11)* %31, align 8, !dbg !230
  %34 = getelementptr inbounds [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }], [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }] addrspace(11)* %30, i64 0, i32 1, !dbg !230
  store { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } %33, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } addrspace(11)* %34, align 8, !dbg !230
  %35 = extractvalue { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } %32, 0, !dbg !230
  %36 = extractvalue { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } %33, 0, !dbg !230
  call void ({} addrspace(10)*, ...) @julia.write_barrier({} addrspace(10)* %28, {} addrspace(10)* %35, {} addrspace(10)* %36), !dbg !230
  %37 = call {}*** @julia.get_pgcstack()
  %ptls_field3.i = getelementptr inbounds {}**, {}*** %37, i64 2
  %38 = bitcast {}*** %ptls_field3.i to i64***
  %ptls_load45.i = load i64**, i64*** %38, align 8, !tbaa !18
  %39 = getelementptr inbounds i64*, i64** %ptls_load45.i, i64 2
  %safepoint.i = load i64*, i64** %39, align 8, !tbaa !22
  fence syncscope("singlethread") seq_cst
  call void @julia.safepoint(i64* %safepoint.i), !dbg !264
  fence syncscope("singlethread") seq_cst
  %.not35 = icmp eq i64 %iv.next, %arraysize, !dbg !267
  br i1 %.not35, label %L67.loopexit, label %L18, !dbg !269

L67.loopexit:                                     ; preds = %L42
  store i64 %arraysize9, i64* %memcpy_refined_dst, align 8, !dbg !243, !tbaa !75, !alias.scope !77, !noalias !244
  store i64 %arraysize, i64* %17, align 8, !dbg !243, !tbaa !75, !alias.scope !77, !noalias !244
  br label %L67, !dbg !270

L67:                                              ; preds = %L67.loopexit, %top
  br label %invertL67, !dbg !270

inverttop:                                        ; preds = %invertL67, %invertL18.preheader
  fence syncscope("singlethread") seq_cst
  fence syncscope("singlethread") seq_cst
  ret void

invertL18.preheader:                              ; preds = %invertL18
  br label %inverttop

invertL18:                                        ; preds = %julia_reverse_11904.exit
  %40 = load double, double* %"value_phi7'de", align 8
  store double 0.000000e+00, double* %"value_phi7'de", align 8
  %41 = load i64, i64* %"iv'ac", align 8
  %42 = icmp eq i64 %41, 0
  %43 = xor i1 %42, true
  %44 = select fast i1 %43, double %40, double 0.000000e+00
  %45 = load double, double* %"'de", align 8
  %46 = fadd fast double %45, %40
  %47 = select fast i1 %42, double %45, double %46
  store double %47, double* %"'de", align 8
  br i1 %42, label %invertL18.preheader, label %incinvertL18

incinvertL18:                                     ; preds = %invertL18
  %48 = load i64, i64* %"iv'ac", align 8
  %49 = add nsw i64 %48, -1
  store i64 %49, i64* %"iv'ac", align 8
  br label %invertL42

invertL42:                                        ; preds = %mergeinvertL18_L67.loopexit, %incinvertL18
  %50 = load double, double* %"'de", align 8, !dbg !271
  store double 0.000000e+00, double* %"'de", align 8, !dbg !271
  %51 = load double, double* %"value_phi7'de", align 8, !dbg !271
  %52 = fadd fast double %51, %50, !dbg !271
  store double %52, double* %"value_phi7'de", align 8, !dbg !271
  %53 = load double, double* %"'de2", align 8, !dbg !271
  %54 = fadd fast double %53, %50, !dbg !271
  store double %54, double* %"'de2", align 8, !dbg !271
  %55 = load i64, i64* %"iv'ac", align 8, !dbg !230
  %_unwrap = addrspacecast { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }* %10 to { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } addrspace(11)*, !dbg !230
  %56 = load i64, i64* %"iv'ac", align 8, !dbg !230
  %"'ipc_unwrap" = addrspacecast { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }* %"'ipa" to { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } addrspace(11)*, !dbg !230
  %57 = bitcast {}*** %8 to {}**, !dbg !230
  %58 = getelementptr inbounds {}*, {}** %57, i64 -14, !dbg !230
  %59 = getelementptr inbounds {}*, {}** %58, i64 16, !dbg !230
  %60 = bitcast {}** %59 to i8**, !dbg !230
  %61 = load i8*, i8** %60, align 8, !dbg !230
  %62 = call {} addrspace(10)* @julia.gc_alloc_obj({}** %58, i64 80, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 125882009936336 to {}*) to {} addrspace(10)*)), !dbg !230
  %63 = bitcast {} addrspace(10)* %62 to [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }] addrspace(10)*, !dbg !230
  %64 = addrspacecast [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }] addrspace(10)* %63 to [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }] addrspace(11)*, !dbg !230
  %65 = getelementptr inbounds [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }], [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }] addrspace(11)* %64, i64 0, i32 0, !dbg !230
  %66 = load { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } addrspace(11)* %_unwrap, align 8, !dbg !230
  %67 = load { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } addrspace(11)* %"'ipc_unwrap", align 8, !dbg !230
  store { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } %66, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } addrspace(11)* %65, align 8, !dbg !230
  %68 = getelementptr inbounds [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }], [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }] addrspace(11)* %64, i64 0, i32 1, !dbg !230
  store { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } %67, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } addrspace(11)* %68, align 8, !dbg !230
  %69 = extractvalue { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } %66, 0, !dbg !230
  %70 = extractvalue { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } %67, 0, !dbg !230
  call void ({} addrspace(10)*, ...) @julia.write_barrier({} addrspace(10)* %62, {} addrspace(10)* %69, {} addrspace(10)* %70), !dbg !230
  %71 = load double, double* %"'de2", align 8, !dbg !230
  store double 0.000000e+00, double* %"'de2", align 8, !dbg !230
  %72 = bitcast {}*** %7 to {}**, !dbg !230
  %73 = getelementptr inbounds {}*, {}** %72, i64 -14, !dbg !230
  %74 = getelementptr inbounds {}*, {}** %73, i64 16, !dbg !230
  %75 = bitcast {}** %74 to i8**, !dbg !230
  %76 = load i8*, i8** %75, align 8, !dbg !230
  %77 = call {} addrspace(10)* @julia.gc_alloc_obj({}** %73, i64 8, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 125879275533520 to {}*) to {} addrspace(10)*)), !dbg !230
  %78 = bitcast {} addrspace(10)* %77 to [1 x double] addrspace(10)*, !dbg !230
  %79 = addrspacecast [1 x double] addrspace(10)* %78 to [1 x double] addrspace(11)*, !dbg !230
  %80 = getelementptr inbounds [1 x double], [1 x double] addrspace(11)* %79, i64 0, i32 0, !dbg !230
  store double %71, double addrspace(11)* %80, align 8, !dbg !230
  %81 = bitcast { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }* %1 to i8*
  call void @llvm.lifetime.start.p0i8(i64 40, i8* %81)
  %82 = bitcast { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }* %2 to i8*
  call void @llvm.lifetime.start.p0i8(i64 40, i8* %82)
  %83 = bitcast [1 x i64]* %newstruct.i to i8*
  call void @llvm.lifetime.start.p0i8(i64 8, i8* %83)
  %84 = bitcast [1 x i64]* %newstruct12.i to i8*
  call void @llvm.lifetime.start.p0i8(i64 8, i8* %84)
  %85 = call {}*** @julia.get_pgcstack()
  %ptls_field43.i = getelementptr inbounds {}**, {}*** %85, i64 2
  %86 = bitcast {}*** %ptls_field43.i to i64***
  %ptls_load4445.i = load i64**, i64*** %86, align 8, !tbaa !18
  %87 = getelementptr inbounds i64*, i64** %ptls_load4445.i, i64 2
  %safepoint.i4 = load i64*, i64** %87, align 8, !tbaa !22
  fence syncscope("singlethread") seq_cst
  call void @julia.safepoint(i64* %safepoint.i4), !dbg !273
  fence syncscope("singlethread") seq_cst
  %88 = getelementptr inbounds [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }], [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }] addrspace(11)* %64, i64 0, i64 0, !dbg !276
  %89 = call fastcc nonnull {} addrspace(10)* @julia_repr_11919({ {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(80) %88), !dbg !282
  call fastcc void @julia_println_11917({} addrspace(10)* noundef nonnull align 16 addrspacecast ({}* inttoptr (i64 125879058330480 to {}*) to {} addrspace(10)*), {} addrspace(10)* noundef nonnull %89), !dbg !282
  %90 = getelementptr inbounds [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }], [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }] addrspace(11)* %64, i64 0, i64 0, i32 2, !dbg !283
  %getfield_addr.i = getelementptr inbounds [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }], [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }] addrspace(11)* %64, i64 0, i64 0, i32 0, !dbg !283
  %getfield.i = load atomic {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %getfield_addr.i unordered, align 8, !dbg !283, !tbaa !22, !alias.scope !30, !noalias !33, !nonnull !17, !dereferenceable !131, !align !132
  %91 = getelementptr inbounds [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }], [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }] addrspace(11)* %64, i64 0, i64 0, i32 3, !dbg !283
  %92 = getelementptr inbounds [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }], [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }] addrspace(11)* %64, i64 0, i64 0, i32 1, i32 0, i64 0, i64 0, !dbg !286
  %unbox.unpack.unpack.unpack.i = load i64, i64 addrspace(11)* %92, align 8, !dbg !286, !tbaa !22, !alias.scope !30, !noalias !33
  %unbox.elt46.i = getelementptr inbounds [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }], [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }] addrspace(11)* %64, i64 0, i64 0, i32 1, i32 1, !dbg !286
  %unbox.unpack47.i = load i64, i64 addrspace(11)* %unbox.elt46.i, align 8, !dbg !286, !tbaa !22, !alias.scope !30, !noalias !33
  %unbox2.i = load i64, i64 addrspace(11)* %90, align 8, !dbg !286, !tbaa !22, !alias.scope !30, !noalias !33
  %unbox3.i = load i64, i64 addrspace(11)* %91, align 8, !dbg !286, !tbaa !22, !alias.scope !30, !noalias !33
  %.fca.0.0.0.0.gep33.i = getelementptr inbounds { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }, { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }* %1, i64 0, i32 0, i32 0, i64 0, i64 0, !dbg !287
  store i64 %unbox.unpack.unpack.unpack.i, i64* %.fca.0.0.0.0.gep33.i, align 8, !dbg !287, !noalias !288
  %.fca.0.1.gep35.i = getelementptr inbounds { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }, { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }* %1, i64 0, i32 0, i32 1, !dbg !287
  store i64 %unbox.unpack47.i, i64* %.fca.0.1.gep35.i, align 8, !dbg !287, !noalias !288
  %.fca.1.gep37.i = getelementptr inbounds { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }, { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }* %1, i64 0, i32 1, !dbg !287
  store i64 %unbox2.i, i64* %.fca.1.gep37.i, align 8, !dbg !287, !noalias !288
  %.fca.2.gep39.i = getelementptr inbounds { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }, { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }* %1, i64 0, i32 2, !dbg !287
  store {} addrspace(10)* %getfield.i, {} addrspace(10)** %.fca.2.gep39.i, align 8, !dbg !287, !noalias !288
  %.fca.3.gep41.i = getelementptr inbounds { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }, { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }* %1, i64 0, i32 3, !dbg !287
  store i64 %unbox3.i, i64* %.fca.3.gep41.i, align 8, !dbg !287, !noalias !288
  %93 = addrspacecast { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }* %1 to { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 } addrspace(11)*, !dbg !287
  %94 = call fastcc nonnull {} addrspace(10)* @julia_repr_11915({ { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 } addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(40) %93), !dbg !287
  call fastcc void @julia_println_11917({} addrspace(10)* noundef nonnull align 16 addrspacecast ({}* inttoptr (i64 125879269903824 to {}*) to {} addrspace(10)*), {} addrspace(10)* noundef nonnull %94), !dbg !287
  %95 = getelementptr inbounds [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }], [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }] addrspace(11)* %64, i64 0, i64 1, !dbg !291
  %96 = getelementptr inbounds [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }], [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }] addrspace(11)* %64, i64 0, i64 1, i32 2, !dbg !291
  %getfield_addr4.i = getelementptr inbounds { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }, { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } addrspace(11)* %95, i64 0, i32 0, !dbg !291
  %getfield5.i = load atomic {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %getfield_addr4.i unordered, align 8, !dbg !291, !tbaa !22, !alias.scope !30, !noalias !33, !nonnull !17, !dereferenceable !131, !align !132
  %97 = getelementptr inbounds [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }], [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }] addrspace(11)* %64, i64 0, i64 1, i32 3, !dbg !291
  %98 = getelementptr inbounds [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }], [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }] addrspace(11)* %64, i64 0, i64 1, i32 1, i32 0, i64 0, i64 0, !dbg !294
  %unbox6.unpack.unpack.unpack.i = load i64, i64 addrspace(11)* %98, align 8, !dbg !294, !tbaa !22, !alias.scope !30, !noalias !33
  %unbox6.elt51.i = getelementptr inbounds [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }], [2 x { {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 }] addrspace(11)* %64, i64 0, i64 1, i32 1, i32 1, !dbg !294
  %unbox6.unpack52.i = load i64, i64 addrspace(11)* %unbox6.elt51.i, align 8, !dbg !294, !tbaa !22, !alias.scope !30, !noalias !33
  %unbox7.i = load i64, i64 addrspace(11)* %96, align 8, !dbg !294, !tbaa !22, !alias.scope !30, !noalias !33
  %unbox8.i = load i64, i64 addrspace(11)* %97, align 8, !dbg !294, !tbaa !22, !alias.scope !30, !noalias !33
  %.fca.0.0.0.0.gep.i = getelementptr inbounds { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }, { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }* %2, i64 0, i32 0, i32 0, i64 0, i64 0, !dbg !295
  store i64 %unbox6.unpack.unpack.unpack.i, i64* %.fca.0.0.0.0.gep.i, align 8, !dbg !295, !noalias !288
  %.fca.0.1.gep.i = getelementptr inbounds { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }, { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }* %2, i64 0, i32 0, i32 1, !dbg !295
  store i64 %unbox6.unpack52.i, i64* %.fca.0.1.gep.i, align 8, !dbg !295, !noalias !288
  %.fca.1.gep.i = getelementptr inbounds { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }, { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }* %2, i64 0, i32 1, !dbg !295
  store i64 %unbox7.i, i64* %.fca.1.gep.i, align 8, !dbg !295, !noalias !288
  %.fca.2.gep.i = getelementptr inbounds { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }, { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }* %2, i64 0, i32 2, !dbg !295
  store {} addrspace(10)* %getfield5.i, {} addrspace(10)** %.fca.2.gep.i, align 8, !dbg !295, !noalias !288
  %.fca.3.gep.i = getelementptr inbounds { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }, { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }* %2, i64 0, i32 3, !dbg !295
  store i64 %unbox8.i, i64* %.fca.3.gep.i, align 8, !dbg !295, !noalias !288
  %99 = addrspacecast { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }* %2 to { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 } addrspace(11)*, !dbg !295
  %100 = call fastcc nonnull {} addrspace(10)* @julia_repr_11915({ { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 } addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(40) %99), !dbg !295
  call fastcc void @julia_println_11917({} addrspace(10)* noundef nonnull align 8 addrspacecast ({}* inttoptr (i64 125881971620392 to {}*) to {} addrspace(10)*), {} addrspace(10)* noundef nonnull %100), !dbg !295
  %memcpy_refined_dst.i = getelementptr inbounds [1 x i64], [1 x i64]* %newstruct.i, i64 0, i64 0, !dbg !296
  store i64 1, i64* %memcpy_refined_dst.i, align 8, !dbg !296, !tbaa !75, !alias.scope !77, !noalias !301
  %bitcast.i = load i64, i64 addrspace(11)* %98, align 8, !dbg !302, !tbaa !22, !alias.scope !30, !noalias !33
  %.not.i = icmp eq i64 %bitcast.i, 0, !dbg !310
  br i1 %.not.i, label %L39.i, label %L42.i, !dbg !312

L39.i:                                            ; preds = %invertL42
  %101 = addrspacecast [1 x i64]* %newstruct.i to [1 x i64] addrspace(11)*, !dbg !312
  call fastcc void @julia_throw_boundserror_11911({ {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(40) %95, [1 x i64] addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(8) %101) #22, !dbg !312
  unreachable, !dbg !312

L42.i:                                            ; preds = %invertL42
  %getfield10.i = load atomic {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %getfield_addr4.i unordered, align 8, !dbg !313, !tbaa !22, !alias.scope !30, !noalias !33, !nonnull !17, !dereferenceable !131, !align !132
  %unbox11.i = load i64, i64 addrspace(11)* %96, align 8, !dbg !315, !tbaa !22, !alias.scope !30, !noalias !33
  %102 = addrspacecast {} addrspace(10)* %getfield10.i to double addrspace(13)* addrspace(11)*, !dbg !317
  %arrayptr56.i = load double addrspace(13)*, double addrspace(13)* addrspace(11)* %102, align 16, !dbg !317, !tbaa !22, !alias.scope !319, !noalias !33, !nonnull !17
  %103 = getelementptr inbounds double, double addrspace(13)* %arrayptr56.i, i64 %unbox11.i, !dbg !317
  %arrayref.i = load double, double addrspace(13)* %103, align 8, !dbg !317, !tbaa !138, !alias.scope !141, !noalias !142
  %memcpy_refined_dst13.i = getelementptr inbounds [1 x i64], [1 x i64]* %newstruct12.i, i64 0, i64 0, !dbg !296
  store i64 1, i64* %memcpy_refined_dst13.i, align 8, !dbg !296, !tbaa !75, !alias.scope !77, !noalias !301
  %bitcast14.i = load i64, i64 addrspace(11)* %92, align 8, !dbg !302, !tbaa !22, !alias.scope !30, !noalias !33
  %.not57.i = icmp eq i64 %bitcast14.i, 0, !dbg !310
  br i1 %.not57.i, label %L60.i, label %julia_reverse_11904.exit, !dbg !312

L60.i:                                            ; preds = %L42.i
  %104 = addrspacecast [1 x i64]* %newstruct12.i to [1 x i64] addrspace(11)*, !dbg !312
  call fastcc void @julia_throw_boundserror_11911({ {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(80) %88, [1 x i64] addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(8) %104) #22, !dbg !312
  unreachable, !dbg !312

julia_reverse_11904.exit:                         ; preds = %L42.i
  %getfield16.i = load atomic {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %getfield_addr.i unordered, align 8, !dbg !313, !tbaa !22, !alias.scope !30, !noalias !33, !nonnull !17, !dereferenceable !131, !align !132
  %105 = addrspacecast {} addrspace(10)* %getfield16.i to double addrspace(13)* addrspace(11)*, !dbg !317
  %arrayptr1958.i = load double addrspace(13)*, double addrspace(13)* addrspace(11)* %105, align 16, !dbg !317, !tbaa !22, !alias.scope !319, !noalias !33, !nonnull !17
  %unbox17.i = load i64, i64 addrspace(11)* %90, align 8, !dbg !315, !tbaa !22, !alias.scope !30, !noalias !33
  %106 = getelementptr inbounds double, double addrspace(13)* %arrayptr1958.i, i64 %unbox17.i, !dbg !317
  %arrayref20.i = load double, double addrspace(13)* %106, align 8, !dbg !317, !tbaa !138, !alias.scope !141, !noalias !142
  %107 = fmul double %arrayref20.i, 2.000000e+00, !dbg !320
  %108 = getelementptr inbounds [1 x double], [1 x double] addrspace(11)* %79, i64 0, i64 0, !dbg !326
  %unbox21.i = load double, double addrspace(11)* %108, align 8, !dbg !327, !tbaa !22, !alias.scope !30, !noalias !33
  %109 = fmul double %107, %unbox21.i, !dbg !327
  %110 = fadd double %arrayref.i, %109, !dbg !328
  store double %110, double addrspace(13)* %103, align 8, !dbg !330, !tbaa !138, !alias.scope !141, !noalias !334
  %111 = call fastcc nonnull {} addrspace(10)* @julia_repr_11919({ {} addrspace(10)*, { [1 x [1 x i64]], i64 }, i64, i64 } addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(40) %95), !dbg !335
  call fastcc void @julia_println_11917({} addrspace(10)* noundef nonnull align 16 addrspacecast ({}* inttoptr (i64 125879058432336 to {}*) to {} addrspace(10)*), {} addrspace(10)* noundef nonnull %111), !dbg !335
  %112 = call fastcc nonnull {} addrspace(10)* @julia_repr_11913([1 x double] addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(8) %79), !dbg !336
  call fastcc void @julia_println_11917({} addrspace(10)* noundef nonnull align 128 addrspacecast ({}* inttoptr (i64 125878243333248 to {}*) to {} addrspace(10)*), {} addrspace(10)* noundef nonnull %112), !dbg !336
  %113 = bitcast { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }* %1 to i8*, !dbg !337
  call void @llvm.lifetime.end.p0i8(i64 40, i8* %113), !dbg !337
  %114 = bitcast { { [1 x [1 x i64]], i64 }, i64, {} addrspace(10)*, i64 }* %2 to i8*, !dbg !337
  call void @llvm.lifetime.end.p0i8(i64 40, i8* %114), !dbg !337
  %115 = bitcast [1 x i64]* %newstruct.i to i8*, !dbg !337
  call void @llvm.lifetime.end.p0i8(i64 8, i8* %115), !dbg !337
  %116 = bitcast [1 x i64]* %newstruct12.i to i8*, !dbg !337
  call void @llvm.lifetime.end.p0i8(i64 8, i8* %116), !dbg !337
  br label %invertL18

invertL67.loopexit:                               ; preds = %invertL67
  %_unwrap3 = add nsw i64 %arraysize, -1
  br label %mergeinvertL18_L67.loopexit

mergeinvertL18_L67.loopexit:                      ; preds = %invertL67.loopexit
  store i64 %_unwrap3, i64* %"iv'ac", align 8
  br label %invertL42

invertL67:                                        ; preds = %L67
  store double %differeturn, double* %"value_phi23'de", align 8
  %117 = load double, double* %"value_phi23'de", align 8
  store double 0.000000e+00, double* %"value_phi23'de", align 8
  %118 = xor i1 %.not, true
  %119 = select fast i1 %118, double %117, double 0.000000e+00
  %120 = load double, double* %"'de", align 8
  %121 = fadd fast double %120, %117
  %122 = select fast i1 %.not, double %120, double %121
  store double %122, double* %"'de", align 8
  br i1 %.not, label %inverttop, label %invertL67.loopexit
}

b.val = [4.0]
(b.val.indices, b.val.offset1, b.val.parent, b.val.stride1) = ((Base.Slice(Base.OneTo(1)), 2), 1, [1.0 4.0], 1)
(b.dval.indices, b.dval.offset1, b.dval.parent, b.dval.stride1) = ((Base.Slice(Base.OneTo(1)), 2), 1, [0.0 0.0], 1)
b.dval = [8.0]
out = Active{Float64}(1.0)
b.val = [4.0]
(b.val.indices, b.val.offset1, b.val.parent, b.val.stride1) = ((Base.Slice(Base.OneTo(1)), 2), 1, [1.0 4.0], 1)
(b.dval.indices, b.dval.offset1, b.dval.parent, b.dval.stride1) = ((Base.Slice(Base.OneTo(1)), 2), 1, [0.0 8.0], 1)
b.dval = [16.0]
out = Active{Float64}(1.0)

@wsmoses
Member

wsmoses commented Sep 25, 2024

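b is overwritten between the forward and reverse passes, as the config reports via `EnzymeRules.overwritten(config)`.

Saving the original b in the tape (and reading it back in the reverse rule) fixes it: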

EnzymeRules.overwritten(config) = (false, true)
EnzymeRules.overwritten(config) = (false, true)
b.val = [4.0]
(b.val.indices, b.val.offset1, b.val.parent, b.val.stride1) = ((Base.Slice(Base.OneTo(1)), 2), 1, [1.0 4.0], 1)
(b.dval.indices, b.dval.offset1, b.dval.parent, b.dval.stride1) = ((Base.Slice(Base.OneTo(1)), 2), 1, [0.0 0.0], 1)
b.dval = [8.0]
out = Active{Float64}(1.0)
b.val = [1.0]
(b.val.indices, b.val.offset1, b.val.parent, b.val.stride1) = ((Base.Slice(Base.OneTo(1)), 1), 0, [1.0 4.0], 1)
(b.dval.indices, b.dval.offset1, b.dval.parent, b.dval.stride1) = ((Base.Slice(Base.OneTo(1)), 1), 0, [0.0 8.0], 1)
b.dval = [2.0]
out = Active{Float64}(1.0)
db = [2.0 8.0]
wmoses@beast:~/git/Enzyme.jl ((HEAD detached at origin/main)) $ cat la.jl 
using Enzyme, LinearAlgebra
using EnzymeCore: EnzymeRules

Enzyme.API.printall!(true)

function _mul!(b)
    return b[1]*b[1] 
end

function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1}, ::Const{typeof(_mul!)}, ::Type{<:Active},
                                      b::Duplicated)
    @show EnzymeRules.overwritten(config)
    tape = b
    return EnzymeRules.AugmentedReturn(EnzymeRules.needs_primal(config) ? _mul!(b.val) : nothing, nothing, tape)
end

function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
    ::Const{typeof(_mul!)},
    out, tape,  b::Duplicated)
    b = tape
    @show b.val
    @show b.val.indices, b.val.offset1, b.val.parent, b.val.stride1
    @show b.dval.indices, b.dval.offset1, b.dval.parent, b.dval.stride1
    b.dval[1] += 2 * b.val[1] * out.val
    @show b.dval
    @show out
    return (nothing,)
end


@inline function f2(b)
    s = zero(eltype(b))
    for i in 1:size(b, 2)
        out = _mul!(@view(b[:, i]))
        s += out
    end
    return s
end


Nx = 1
Ny = 2
b = [1.0 4.0]
db = zero(b)

f2(b)
autodiff(Reverse, f2, Active, Duplicated(b, fill!(db, 0)))

@show db
# Without storing b in the tape, Enzyme returned db = [0 16]
# The correct answer (returned with the tape fix) is db = [2 8]
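The effect can be reproduced without Enzyme. The sketch below is a hypothetical analogy, not Enzyme's actual implementation: one mutable "view descriptor" (standing in for the `offset1` field of `@view(b[:, i])`, which the printed output shows changing) is reused by every loop iteration, and the reverse steps run only after the loop has finished. Reading the descriptor at reverse time sees the last iteration's offset twice, matching `db = [0 16]`; reading the value saved during the forward pass (the tape) gives `db = [2 8]`.

```julia
# Hypothetical, Enzyme-free analogy of the bug in this thread.
# `ViewDesc` stands in for the view's `offset1`, which is overwritten each iteration.
mutable struct ViewDesc
    offset::Int
end

# Gradient of sum(b[1, i]^2 over columns i), with one deferred reverse step per iteration.
function replay(b::Matrix{Float64}; save_to_tape::Bool)
    db = zero(b)
    desc = ViewDesc(0)
    reverse_steps = Function[]
    for i in 1:size(b, 2)
        desc.offset = i - 1          # forward pass: the shared descriptor is overwritten
        tape = desc.offset           # snapshot stored during the forward pass
        push!(reverse_steps, () -> begin
            # The buggy variant reads the descriptor now, after the loop has finished
            k = 1 + (save_to_tape ? tape : desc.offset)
            db[k] += 2 * b[k]        # d(b[k]^2)/d(b[k]) times a seed of 1.0
        end)
    end
    foreach(step -> step(), reverse(reverse_steps))  # reverse pass: last iteration first
    return db
end

replay([1.0 4.0]; save_to_tape = false)  # [0.0 16.0], like the original rule
replay([1.0 4.0]; save_to_tape = true)   # [2.0 8.0], like the rule that reads from the tape
```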

@wsmoses wsmoses closed this as completed Sep 25, 2024
@ptiede
Contributor Author

ptiede commented Sep 25, 2024

Wait, isn't the original code still a problem? This still gives me a zero in the first element:

function _mul!(A, b)
    return A[1]*b[1] 
end

function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1}, ::Const{typeof(_mul!)}, ::Type{<:Active},
                                      A::Const{<:Matrix}, b::Duplicated)
    tape = (A, b)
    return EnzymeRules.AugmentedReturn(EnzymeRules.needs_primal(config) ? _mul!(A.val, b.val) : nothing, nothing, tape)
end

function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
    ::Const{typeof(_mul!)},
    out, tape, A::Const{<:Matrix}, b::Duplicated)

    @show EnzymeRules.overwritten(config)
    @show b.val
    @show b.val.indices, b.val.offset1, b.val.parent, b.val.stride1
    @show b.dval.indices, b.dval.offset1, b.dval.parent, b.dval.stride1
    b.dval[1] += real(A.val'[1] * out.val)
    @show b.dval
    @show out
    @show A.val
    return (nothing, nothing)
end


@inline function f2(A, b)
    s = zero(eltype(b))
    for i in 1:size(b, 2)
        out = similar(A, size(A, 1))
        s = _mul!(A, @view(b[:, i]))
        @inbounds s += sum(abs, out)
    end
    return s
end


Nx = 1
Ny = 2
b = [1.0 4.0]
db = zero(b)

A = ones(Float64, Nx, Nx)


f2(A, b)
autodiff(Reverse, f2, Active, Const(A), Duplicated(b, fill!(db, 0)))
# db = [0.0, 1.0]

@wsmoses
Member

wsmoses commented Sep 25, 2024

There are a couple of small bugs in that: you don't unwrap the tape in the reverse pass, and you write `s = _mul!(...)` instead of `s += _mul!(...)`.

With those fixed, it works:

using Enzyme, LinearAlgebra

function _mul!(A, b)
    return A[1]*b[1] 
end

function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1}, ::Const{typeof(_mul!)}, ::Type{<:Active},
                                      A::Const{<:Matrix}, b::Duplicated)
    tape = (A, b)
    return EnzymeRules.AugmentedReturn(EnzymeRules.needs_primal(config) ? _mul!(A.val, b.val) : nothing, nothing, tape)
end

function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
    ::Const{typeof(_mul!)},
    out, tape, A::Const{<:Matrix}, b::Duplicated)
    # Unwrap the tape: use the arguments saved in the forward pass, not the ones passed to reverse
    (A, b) = tape

    @show EnzymeRules.overwritten(config)
    @show b.val
    @show b.val.indices, b.val.offset1, b.val.parent, b.val.stride1
    @show b.dval.indices, b.dval.offset1, b.dval.parent, b.dval.stride1
    b.dval[1] += real(A.val'[1] * out.val)
    @show b.dval
    @show out
    @show A.val
    return (nothing, nothing)
end


@inline function f2(A, b)
    s = zero(eltype(b))
    for i in 1:size(b, 2)
        out = similar(A, size(A, 1))
        s += _mul!(A, @view(b[:, i]))
    end
    return s
end


Nx = 1
Ny = 2
b = [1.0 4.0]
db = zero(b)

A = ones(Float64, Nx, Nx)


f2(A, b)
autodiff(Reverse, f2, Active, Const(A), Duplicated(b, fill!(db, 0)))
# expected: db = [1.0 1.0]
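Because custom rules like these are easy to get subtly wrong, it helps to compare against an independent gradient. A minimal sketch, assuming plain central finite differences (not an Enzyme feature) and a hypothetical `f2_plain` that computes the same value as `f2(A, b)` without the custom rule: the exact gradient is `A[1]` for every entry of `b`, i.e. `[1.0 1.0]`.

```julia
# Central finite-difference gradient of a scalar function `f` at array `x`.
function fd_gradient(f, x::AbstractArray{Float64}; h = 1e-6)
    g = zero(x)
    for k in eachindex(x)
        xp = copy(x); xp[k] += h
        xm = copy(x); xm[k] -= h
        g[k] = (f(xp) - f(xm)) / (2h)
    end
    return g
end

# Hypothetical rule-free version of f2(A, b): the sum of A[1] * b[1, i] over the columns.
f2_plain(A, b) = sum(A[1] * b[1, i] for i in 1:size(b, 2))

A_chk = ones(1, 1)
b_chk = [1.0 4.0]
g_chk = fd_gradient(b -> f2_plain(A_chk, b), b_chk)  # approximately [1.0 1.0]
```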

@ptiede
Contributor Author

ptiede commented Sep 25, 2024

For instance, this still gives me something weird:

function _mul!(A, b)
    return A*b[1] 
end

function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1}, ::Const{typeof(_mul!)}, ::Type{<:Active},
                                      A::Const, b::Duplicated)
    tape = (copy(A.val), copy(b.val))
    return EnzymeRules.AugmentedReturn(EnzymeRules.needs_primal(config) ? _mul!(A.val, b.val) : nothing, nothing, tape)
end

function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
    ::Const{typeof(_mul!)},
    out, tape, A::Const, b::Duplicated)

    b1 = tape[2]
    A = tape[1]
    @show EnzymeRules.overwritten(config)
    @show b.val
    @show b.val.indices, b.val.offset1, b.val.parent, b.val.stride1
    @show b.dval.indices, b.dval.offset1, b.dval.parent, b.dval.stride1
    b.dval[1] += (A * out.val)
    @show b.dval
    @show out
    @show A
    return (nothing, nothing)
end


@inline function f2(A, b)
    s = zero(eltype(b))
    for i in 1:size(b, 2)
        out = _mul!(A, @view(b[:, i]))
        @inbounds s += out
    end
    return s
end


Nx = 1
Ny = 2
b = [1.0 4.0]
db = zero(b)



f2(1.0, b)
autodiff(Reverse, f2, Active, Const(1.0), Duplicated(b, fill!(db, 0)))

@wsmoses
Member

wsmoses commented Sep 25, 2024

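You can't `copy` here: `copy(A.val)` / `copy(b.val)` creates a new array that no longer points to the same memory as `b`, so it can't be used to update `b` in place. Store the reference itself (e.g. the `Duplicated` `b`) in the tape instead.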
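A small, Enzyme-free illustration of that point using plain Julia arrays, where the hypothetical `db` stands in for the shadow of `b`: an in-place update through a view lands in the parent array, while the same update applied to a `copy` of the view only changes the copy.

```julia
db = zeros(1, 2)          # stand-in for the shadow (gradient) storage of b
dview = @view db[:, 2]    # a view into column 2, like the view passed to the rule
dcopy = copy(dview)       # an independent copy, as copy(...) in augmented_primal would make

dcopy[1] += 8.0           # changes only the copy; db is untouched
dview[1] += 8.0           # writes through to db[1, 2]

db                        # [0.0 8.0]
dcopy                     # [8.0]
```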

@ptiede
Contributor Author

ptiede commented Sep 25, 2024

Ok, I need help understanding the overwritten stuff. So in the forward pass I should check whether b is overwritten, and if it is, store the forward-pass b in the tape and use that in the reverse rule instead of the b passed to the reverse rule? In other rules (e.g., here) I see you copy the value in the forward pass. When should I copy and when not?

@wsmoses
Member

wsmoses commented Sep 25, 2024

using Enzyme, LinearAlgebra

vl = [2, 2]
tot = sum(vl)
rg = [1:2, 3:4]
Nx = 4

iminds = reshape([CartesianIndex(i) for i in 1:2], :)
visinds = [collect(rg[i]) for i in eachindex(rg)]
Bs = Dict((iminds[i]=> ones(ComplexF64, vl[i], Nx*Nx) for i in eachindex(vl)))

function _mul!(out, A, b)
    mul!(out, A, b)
    return nothing 
end


function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfigWidth{1}, ::Const{typeof(_mul!)}, ::Type{<:Const}, out::Duplicated,
                                      A::Const{<:Matrix}, b::Duplicated)
    _mul!(out.val, A.val, b.val)
    tape = (out, b)
    return EnzymeRules.AugmentedReturn(nothing, nothing, tape)
end

@noinline function EnzymeRules.reverse(config::EnzymeRules.RevConfigWidth{1},
                                       ::Const{typeof(_mul!)},
                                       ::Type{<:Const}, tape, out::Duplicated, A::Const{<:Matrix},
                                       b::Duplicated)

    (out, b) = tape

    b.dval .+= real.(A.val' * out.dval)
    out.dval .= 0
    return (nothing, nothing, nothing)
end

@inline function f(Bs, visinds, iminds, tot, x)
    out = similar(x, Complex{eltype(x)}, tot)
    for i in eachindex(iminds, visinds)
        imind = iminds[i]
        visind = visinds[i]
        _mul!(@view(out[visind]), Bs[imind], reshape(@view(x[:, :, imind]), :))
    end
    return sum(abs2, out)
end


@inline function f2(Bs, visinds, iminds, tot, x)
    out = similar(x, Complex{eltype(x)}, tot)
    for i in eachindex(iminds, visinds)
        imind = iminds[i]
        visind = visinds[i]
        mul!(@view(out[visind]), Bs[imind], reshape(@view(x[:, :, imind]), :))
    end
    return sum(abs2, out)
end

x = ones(Nx, Nx, 2)
dx = zero(x)

f(Bs, visinds, iminds, tot, x)
autodiff(set_runtime_activity(Reverse), f, Active, Const(Bs), Const(visinds), Const(iminds), Const(tot), Duplicated(x, fill!(dx, 0)))
@show dx


x = ones(Nx, Nx, 2)
dx = zero(x)

f2(Bs, visinds, iminds, tot, x)
autodiff(set_runtime_activity(Reverse), f2, Active, Const(Bs), Const(visinds), Const(iminds), Const(tot), Duplicated(x, fill!(dx, 0)))
@show dx

The version of your original code above works successfully.
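Why storing `(out, b)` is enough here: the tape keeps the `Duplicated` arguments themselves, and the shadow of `reshape(@view(x[:, :, imind]), :)` is a view into `dx` (the printed output earlier in the thread shows `b.dval.parent` being the shadow array). A plain-Julia sketch, with array sizes assumed to match the MWE, showing that an in-place broadcast through such a reshaped view, as in `b.dval .+= ...`, updates only the corresponding slice of `dx`:

```julia
dx = zeros(4, 4, 2)                       # stand-in for the gradient buffer of x
bshadow = reshape(@view(dx[:, :, 2]), :)  # 16-element vector aliasing dx[:, :, 2]

bshadow .+= 3.0                           # in-place update, like `b.dval .+= ...` in the rule

all(dx[:, :, 2] .== 3.0)                  # true: the update reached dx
all(dx[:, :, 1] .== 0.0)                  # true: the other slice is untouched
```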

@ptiede
Contributor Author

ptiede commented Sep 25, 2024

Thanks! I think I'm confused about when to store a copy in the tape and when to store the overwritten arguments themselves in augmented_primal.

@wsmoses
Member

wsmoses commented Sep 25, 2024

Yeah, okay. The more I think about this, the more I think we should make Enzyme automatically do the nice thing here for top-level pointers.

Nevertheless, the above should be a good workaround.

@wsmoses wsmoses reopened this Sep 25, 2024
@ptiede
Contributor Author

ptiede commented Sep 25, 2024

Ok no problem! I'll implement the fix in my code. Thanks for all the help.

Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants