From 1e2b3783d26760490af12adcb651d4928b05742b Mon Sep 17 00:00:00 2001 From: cortner Date: Tue, 11 Jun 2024 22:50:22 -0700 Subject: [PATCH] fix pb and pb2 for sparsepoolprod --- src/Polynomials4ML.jl | 5 +- src/ace/sparseprodpool.jl | 307 ++++++++++++-------------------- test/ace/test_sparseprodpool.jl | 140 +-------------- 3 files changed, 120 insertions(+), 332 deletions(-) diff --git a/src/Polynomials4ML.jl b/src/Polynomials4ML.jl index ddc81dd..d420840 100644 --- a/src/Polynomials4ML.jl +++ b/src/Polynomials4ML.jl @@ -13,6 +13,7 @@ import WithAlloc: whatalloc using LuxCore, Random, StaticArrays import ChainRulesCore: rrule, frule, NoTangent, ZeroTangent using HyperDualNumbers: Hyper +using ForwardDiff: Dual, extract_derivative import LuxCore: AbstractExplicitLayer, AbstractExplicitContainerLayer, initialparameters, initialstates @@ -28,10 +29,10 @@ function index end function orthpolybasis end function degree end -function pullback end function pullback! end function pullback end -function pullback! end +function pullback2! end +function pullback2 end function pushforward end function pushforward! end diff --git a/src/ace/sparseprodpool.jl b/src/ace/sparseprodpool.jl index c3ccaeb..21e1454 100644 --- a/src/ace/sparseprodpool.jl +++ b/src/ace/sparseprodpool.jl @@ -134,6 +134,27 @@ function evaluate!(A, basis::PooledSparseProduct{NB}, BB::TupMat, return A end +# special-casing NB = 1 for correctness +function evaluate!(A, basis::PooledSparseProduct{1}, + BB::Tuple{<: AbstractMatrix}, + nX = size(BB[1], 1)) + @assert size(BB[1], 1) >= nX + BB1 = BB[1] + spec = basis.spec + fill!(A, zero(eltype(A))) + @inbounds for (iA, ϕ) in enumerate(spec) + ϕ1 = ϕ[1] + a = zero(eltype(A)) + @simd ivdep for j = 1:nX + b1 = BB1[j, ϕ1] + a += b1 + end + A[iA] = a + end + return A +end + + # special-casing NB = 2 for performance reasons function evaluate!(A, basis::PooledSparseProduct{2}, BB::Tuple{<: AbstractMatrix, <: AbstractMatrix}, @@ -174,7 +195,7 @@ end # TODO: this should probably be replaced with a loop that generates # the code up to a large-ish NB. -pullback!(∂B1, ∂A, basis::PooledSparseProduct{1}, BB::TupMat) = +pullback!(∂B1::AbstractMatrix, ∂A, basis::PooledSparseProduct{1}, BB::TupMat) = pullback!((∂B1,), ∂A, basis, BB) pullback!(∂B1, ∂B2, ∂A, basis::PooledSparseProduct{2}, BB::TupMat) = @@ -187,9 +208,9 @@ pullback!(∂B1, ∂B2, ∂B3, ∂B4, ∂A, basis::PooledSparseProduct{4}, BB::T pullback!((∂B1, ∂B2, ∂B3, ∂B4,), ∂A, basis, BB) -function pullback!(∂BB, # output - ∂A, basis::PooledSparseProduct{NB}, BB::TupMat # inputs - ) where {NB} +function pullback!(∂BB::Tuple, # output + ∂A, basis::PooledSparseProduct{NB}, BB::TupMat # inputs + ) where {NB} nX = size(BB[1], 1) @assert all(nX <= size(BB[i], 1) for i = 1:NB) @assert all(nX <= size(∂BB[i], 1) for i = 1:NB) @@ -206,7 +227,8 @@ function pullback!(∂BB, # output end a, g = _static_prod_ed(b) for i = 1:NB - ∂BB[i][j, ϕ[i]] = muladd(∂A_iA, g[i], ∂BB[i][j, ϕ[i]]) + ϕi = ϕ[i] + ∂BB[i][j, ϕi] = muladd(∂A_iA, g[i], ∂BB[i][j, ϕi]) end end end @@ -218,7 +240,32 @@ end # a cruder code generation strategy. This specialized code # confirms this. -function pullback!(∂BB, ∂A, basis::PooledSparseProduct{2}, BB::TupMat) +# NB = 1 for correctness +function pullback!(∂BB::Tuple, ∂A, basis::PooledSparseProduct{1}, BB::TupMat) + nX = size(BB[1], 1) + NB = 1 + @assert length(∂A) == length(basis) + @assert length(BB) == length(∂BB) == NB + @assert all(nX <= size(BB[i], 1) for i = 1:NB) + @assert all(nX <= size(∂BB[i], 1) for i = 1:NB) + @assert all(size(∂BB[i], 2) >= size(BB[i], 2) for i = 1:NB) + BB1 = BB[1] + ∂BB1 = ∂BB[1] + + fill!(∂BB1, zero(eltype(∂BB1))) + + @inbounds for (iA, ϕ) in enumerate(basis.spec) + ∂A_iA = ∂A[iA] + ϕ1 = ϕ[1] + @simd ivdep for j = 1:nX + # A[iA] += b1 + ∂BB1[j, ϕ1] += ∂A_iA + end + end + return ∂BB +end + +function pullback!(∂BB::Tuple, ∂A, basis::PooledSparseProduct{2}, BB::TupMat) nX = size(BB[1], 1) NB = 2 @assert length(∂A) == length(basis) @@ -247,7 +294,7 @@ function pullback!(∂BB, ∂A, basis::PooledSparseProduct{2}, BB::TupMat) return ∂BB end -function pullback!(∂BB, ∂A, basis::PooledSparseProduct{3}, BB::TupMat; +function pullback!(∂BB::Tuple, ∂A, basis::PooledSparseProduct{3}, BB::TupMat; sizecheck = true) nX = size(BB[1], 1) NB = 3 @@ -285,60 +332,10 @@ function pullback!(∂BB, ∂A, basis::PooledSparseProduct{3}, BB::TupMat; return ∂BB end - -# -------- - -function pullback_x! end - -function whatalloc(::typeof(pullback_x!), - ∂A, basis::PooledSparseProduct{NB}, BB::TupMat) where {NB} - TA = promote_type(eltype.(BB)..., eltype(∂A)) - return ((TA, length(basis)), ntuple(i -> (TA, size(BB[i])...), NB)...) -end - -pullback_x!(A, ∂B1, ∂B2, ∂A, basis::PooledSparseProduct{2}, BB::TupMat) = - pullback_x!(A, (∂B1, ∂B2), ∂A, basis, BB) - -pullback_x!(A, ∂B1, ∂B2, ∂B3, ∂A, basis::PooledSparseProduct{3}, BB::TupMat) = - pullback_x!(A, (∂B1, ∂B2, ∂B3), ∂A, basis, BB) - - -# experimental version that also computes the original object -function pullback_x!(A, ∂BB, ∂A, basis::PooledSparseProduct{2}, BB::TupMat) - nX = size(BB[1], 1) - NB = 2 - @assert length(∂A) == length(basis) - @assert length(BB) == length(∂BB) == 2 - @assert all(nX <= size(BB[i], 1) for i = 1:NB) - @assert all(nX <= size(∂BB[i], 1) for i = 1:NB) - @assert all(size(∂BB[i], 2) >= size(BB[i], 2) for i = 1:NB) - BB1, BB2 = BB - ∂BB1, ∂BB2 = ∂BB - - for i = 1:length(∂BB) - fill!(∂BB[i], zero(eltype(∂BB[i]))) - end - fill!(A, zero(eltype(A))) - - @inbounds for (iA, ϕ) in enumerate(basis.spec) - ∂A_iA = ∂A[iA] - ϕ1 = ϕ[1] - ϕ2 = ϕ[2] - @simd ivdep for j = 1:nX - b1 = BB1[j, ϕ1] - b2 = BB2[j, ϕ2] - ∂BB1[j, ϕ1] = muladd(∂A_iA, b2, ∂BB1[j, ϕ1]) - ∂BB2[j, ϕ2] = muladd(∂A_iA, b1, ∂BB2[j, ϕ2]) - A[j] += b1 * b2 - end - end - return A, ∂BB -end - -function pullback_x!(A, ∂BB, ∂A, basis::PooledSparseProduct{3}, BB::TupMat; +function pullback!(∂BB::Tuple, ∂A, basis::PooledSparseProduct{4}, BB::TupMat; sizecheck = true) nX = size(BB[1], 1) - NB = 3 + NB = 4 if sizecheck @assert all(nX <= size(BB[i], 1) for i = 1:NB) @@ -353,177 +350,93 @@ function pullback_x!(A, ∂BB, ∂A, basis::PooledSparseProduct{3}, BB::TupMat; fill!(∂BB[i], zero(eltype(∂BB[i]))) end - B1 = BB[1]; B2 = BB[2]; B3 = BB[3] - ∂B1 = ∂BB[1]; ∂B2 = ∂BB[2]; ∂B3 = ∂BB[3] + B1 = BB[1]; B2 = BB[2]; B3 = BB[3]; B4 = BB[4] + ∂B1 = ∂BB[1]; ∂B2 = ∂BB[2]; ∂B3 = ∂BB[3]; ∂B4 = ∂BB[4] @inbounds for (iA, ϕ) in enumerate(basis.spec) ∂A_iA = ∂A[iA] ϕ1 = ϕ[1] ϕ2 = ϕ[2] ϕ3 = ϕ[3] + ϕ4 = ϕ[4] @simd ivdep for j = 1:nX b1 = B1[j, ϕ1] b2 = B2[j, ϕ2] b3 = B3[j, ϕ3] - ∂B1[j, ϕ1] = muladd(∂A_iA, b2*b3, ∂B1[j, ϕ1]) - ∂B2[j, ϕ2] = muladd(∂A_iA, b1*b3, ∂B2[j, ϕ2]) - ∂B3[j, ϕ3] = muladd(∂A_iA, b1*b2, ∂B3[j, ϕ3]) - A[j] += b1 * b2 * b3 + b4 = B4[j, ϕ4] + ∂B1[j, ϕ1] = muladd(∂A_iA, b2*b3*b4, ∂B1[j, ϕ1]) + ∂B2[j, ϕ2] = muladd(∂A_iA, b1*b3*b4, ∂B2[j, ϕ2]) + ∂B3[j, ϕ3] = muladd(∂A_iA, b1*b2*b4, ∂B3[j, ϕ3]) + ∂B4[j, ϕ4] = muladd(∂A_iA, b1*b2*b3, ∂B4[j, ϕ4]) end end - return A, ∂BB + return ∂BB end -# -------------------------------------------------------- +# --------------------------------------------------------------- # reverse over reverse -function pullback2(∂2, ∂A, basis::PooledSparseProduct{NB}, BB::TupMat - ) where {NB} +# A = evaluate(basis, BB) +# ∂BB = pullback(∂A, basis, BB) +# ∂∂BB is the perturbation to ∂BB - # ∂2 should be a tuple of length 2 - @assert ∂2 isa NTuple{NB, <: AbstractMatrix} - @assert BB isa NTuple{NB, <: AbstractMatrix} - @assert ∂A isa AbstractVector - - nX = size(BB[1], 1) - @assert all(nX == size(BB[i], 1) for i = 1:NB) - - ∂2_∂A = zeros(length(∂A)) - ∂2_BB = ntuple(i -> zeros(size(BB[i])...), NB) - - for (iA, ϕ) in enumerate(basis.spec) - @simd ivdep for j = 1:nX - b = ntuple(Val(NB)) do i - @inbounds BB[i][j, ϕ[i]] - end - ∂g = ntuple(Val(NB)) do i - @inbounds ∂2[i][j, ϕ[i]] - end - _, g, ∂g_b = _pb_grad_static_prod(∂g, b) - for i = 1:NB - # ∂BB[i][j, ϕ[i]] += ∂A[iA] * g[i] - ∂2_∂A[iA] += ∂2[i][j, ϕ[i]] * g[i] - ∂2_BB[i][j, ϕ[i]] += ∂A[iA] * ∂g_b[i] - end - end - end - return ∂2_∂A, ∂2_BB +function whatalloc(::typeof(pullback2!), ∂∂BB, ∂A, + basis::PooledSparseProduct{NB}, BB) where {NB} + TA = promote_type(eltype.(BB)..., eltype(∂A), eltype.(∂∂BB)...) + return ( (TA, size(∂A)...), + ntuple(i -> (TA, size(BB[i])...), NB)...) end +pullback2!(∇_∂A, ∇_BB1::AbstractMatrix, ∂∂BB, ∂A, basis::PooledSparseProduct{1}, BB) = + pullback2!(∇_∂A, (∇_BB1,), ∂∂BB, ∂A, basis, BB) -function pullback2(∂2, ∂A, basis::PooledSparseProduct{1}, BB::TupMat) - - # ∂2 should be a tuple of length 2 - @assert ∂2 isa Tuple{<: AbstractMatrix} - @assert BB isa Tuple{<: AbstractMatrix} - @assert ∂A isa AbstractVector - - nX = size(BB[1], 1) - - ∂2_∂A = zeros(length(∂A)) - ∂2_BB = (zeros(size(BB[1])...), ) - - for (iA, ϕ) in enumerate(basis.spec) - @simd ivdep for j = 1:nX - ϕ1 = ϕ[1] - b1 = BB[1][j, ϕ1] - # A[iA] += b1 - # ∂BB[1][j, ϕ1] += ∂A[iA] - ∂2_∂A[iA] += ∂2[1][j, ϕ1] - end - end - return ∂2_∂A, ∂2_BB -end +pullback2!(∇_∂A, ∇_BB1, ∇_BB2, ∂∂BB, ∂A, basis::PooledSparseProduct{2}, BB) = + pullback2!(∇_∂A, (∇_BB1, ∇_BB2), ∂∂BB, ∂A, basis, BB) +pullback2!(∇_∂A, ∇_BB1, ∇_BB2, ∇_BB3, ∂∂BB, ∂A, basis::PooledSparseProduct{3}, BB) = + pullback2!(∇_∂A, (∇_BB1, ∇_BB2, ∇_BB3), ∂∂BB, ∂A, basis, BB) -function pullback2(∂2, ∂A, basis::PooledSparseProduct{2}, BB::TupMat) +pullback2!(∇_∂A, ∇_BB1, ∇_BB2, ∇_BB3, ∇_BB4, ∂∂BB, ∂A, basis::PooledSparseProduct{4}, BB) = + pullback2!(∇_∂A, (∇_BB1, ∇_BB2, ∇_BB3, ∇_BB4), ∂∂BB, ∂A, basis, BB) - # ∂2 should be a tuple of length 2 - @assert ∂2 isa Tuple{<: AbstractMatrix, <: AbstractMatrix} - @assert BB isa Tuple{<: AbstractMatrix, <: AbstractMatrix} - @assert ∂A isa AbstractVector - - nX = size(BB[1], 1) +function pullback2!(∇_∂A, ∇_BB::Tuple, # outputs + ∂∂BB, # perturbation + ∂A, basis::PooledSparseProduct{NB}, BB) where {NB} - ∂2_∂A = zeros(length(∂A)) - ∂2_BB = ntuple(i -> zeros(size(BB[i])...), 2) - - for (iA, ϕ) in enumerate(basis.spec) - @simd ivdep for j = 1:nX - ϕ1 = ϕ[1] - ϕ2 = ϕ[2] - b1 = BB[1][j, ϕ1] - b2 = BB[2][j, ϕ2] - # A[iA] += b1 * b2 - # ∂BB[1][j, ϕ1] += ∂A[iA] * b2 - # ∂BB[2][j, ϕ2] += ∂A[iA] * b1 - ∂2_∂A[iA] += ∂2[1][j, ϕ1] * b2 + ∂2[2][j, ϕ2] * b1 - ∂2_BB[1][j, ϕ1] += ∂2[2][j, ϕ2] * ∂A[iA] - ∂2_BB[2][j, ϕ2] += ∂2[1][j, ϕ1] * ∂A[iA] - end + function _dual(i) + T = promote_type(eltype(BB[i]), eltype(∂∂BB[i])) + return Dual{T}(zero(T), one(T)) end - return ∂2_∂A, ∂2_BB -end - - -# --------------------- Pushforwards - -# This implementation of the pushforward doesn't yet do batching -# this means that the output is a single vector A, the inputs -# BB[i] are nX x nBi -> with the nX pooled -# It is ASSUMED that ∂BB[i][j, :] / ∂X[j'] = 0 if j ≠ j' -# Therefore ΔBB[i] are also nX x nBi -# This is a simplification that may have to be revisited at some point. -# -# The output will be A, ∂A where -# A is size (nA,) -# ∂A is size (nA, nX) - -_my_promote_type(args...) = promote_type(args...) - -_my_promote_type(T1::Type{<: Number}, T2::Type{SVector{N, S}}, args... - ) where {N, S} = - promote_type(SVector{N, T1}, T2, args...) - - -function pushforward(basis::PooledSparseProduct, BB, ΔBB) - @assert length(size(BB[1])) == 2 - @assert length(size(ΔBB[1])) == 2 - @assert all(size(BB[t]) == size(ΔBB[t]) for t = 1:length(BB)) - - nX = size(ΔBB[1], 1) - nA = length(basis) - - TA = promote_type(eltype.(BB)...) - A = zeros(TA, nA) - - T∂A = _my_promote_type(TA, eltype.(ΔBB)...) - ∂A = zeros(T∂A, (nA, nX)) - fill!(∂A, zero(T∂A)) - - return pushforward!(A, ∂A, basis, BB, ΔBB) -end - - -function pushforward!(A, ∂A, basis::PooledSparseProduct{NB}, BB, ΔBB) where {NB} - nX = size(BB[1], 1) - for (i, ϕ) in enumerate(basis.spec) - for j = 1:nX - bb = ntuple(t -> BB[t][j, ϕ[t]], NB) - Δbb = ntuple(t -> ΔBB[t][j, ϕ[t]], NB) - ∏bb, ∇∏bb = Polynomials4ML._static_prod_ed(bb) - A[i] += prod(bb) - @inbounds for t = 1:NB - ∂A[i, j] += ∇∏bb[t] * Δbb[t] + @no_escape begin + dd = ntuple(_dual, NB) + BB_d = ntuple(i -> @alloc(typeof(dd[i]), size(BB[i])...), NB) + for i = 1:NB + @inbounds for t = 1:length(BB[i]) + BB_d[i][t] = BB[i][t] + dd[i] * ∂∂BB[i][t] + end + end + A_d = @withalloc evaluate!(basis, BB_d) + ∂BB_d = @withalloc pullback!(∂A, basis, BB_d) + @inbounds for t = 1:length(A_d) + ∇_∂A[t] = extract_derivative(eltype(∇_∂A), A_d[t]) + end + @inbounds for i = 1:NB + for t = 1:length(∂BB_d[i]) + Ti = eltype(∇_BB[i]) + ∇_BB[i][t] = extract_derivative(Ti, ∂BB_d[i][t]) end end end - return A, ∂A + return ∇_∂A, ∇_BB end +# --------------------------------------------------------------- +# Pushforward + + # --------------------- connect with ChainRules # can this be generalized again? diff --git a/test/ace/test_sparseprodpool.jl b/test/ace/test_sparseprodpool.jl index 7f0dfa4..00e1efa 100644 --- a/test/ace/test_sparseprodpool.jl +++ b/test/ace/test_sparseprodpool.jl @@ -43,7 +43,7 @@ println() ## @info("Test pooling of multiple inputs") -nX = 64 +nX = 17 for ntest = 1:30 local bBB, bA1, bA2, bA3, basis @@ -76,11 +76,9 @@ test_withalloc(basis; batch=false) @info("Testing rrule") using LinearAlgebra: dot -@warn("order = 1 tests currently fail in an unexplained way") - for ntest = 1:30 local bBB, bA2, u, basis, nX - order = mod1(ntest, 3)+1 + order = mod1(ntest, 4) basis = _generate_basis(; order=order) bBB = _generate_input(basis) nX = size(bBB[1], 1) @@ -103,9 +101,9 @@ println() @info("Testing pullback2 for PooledSparseProduct") import ChainRulesCore: rrule, NoTangent -for ntest = 1:20 +for ntest = 1:20 local basis, val, pb, bBB, A - ORDER = mod1(ntest, 3)+1 + ORDER = mod1(ntest, 4) basis = _generate_basis(;order = ORDER) bBB = _generate_input(basis) ∂A = randn(length(basis)) @@ -138,136 +136,12 @@ for ntest = 1:20 return dot(∂_∂A, bV) + sum(dot(bUU[i], ∂2_BB[i]) for i = 1:ORDER) end - print_tf(@test all( fdtest(F, dF, 0.0; verbose=false) )) + print_tf(@test fdtest(F, dF, 0.0; verbose=false) ) + # print_tf(@test all( fdtest(F, dF, 0.0; verbose=false) )) + # fdtest(F, dF, 0.0; verbose=true) end println() -## - -ORDER = 3 -basis = _generate_basis(;order = ORDER, len=300) -bBB = _generate_input(basis) -∂A = randn(length(basis)) -∂2 = ntuple(i -> randn(size(bBB[i])), length(bBB)) - - -## - -using ForwardDiff: Dual, extract_derivative -using Bumper, WithAlloc - -function auto_pb_pb(∂BB, ∂A, basis, BB) - # φ = ∂BB ⋅ pullback(∂A, basis, BB) - # = (∂bBB ⋅ ∇_BB) (∂A ⋅ evaluate(basis, BB)) - # ∇_∂A φ = (∂BB ⋅ ∇_BB) evaluate(basis, BB) - # ∇_BB φ = (∂BB ⋅ ∇_BB) ∇_BB (∂A ⋅ evaluate(basis, BB)) - # = (∂BB ⋅ ∇_BB) pullback(∂A, basis, BB) - d = Dual{Float64}(0.0, 1.0) - BB_d = ntuple(i -> BB[i] .+ d .* ∂BB[i], length(BB)) - @no_escape begin - A_d, ∂BB_d = @withalloc Polynomials4ML.pullback_x!(∂A, basis, BB_d) - # A_d = @withalloc evaluate!(basis, BB_d) - # ∂BB_d = @withalloc pullback!(∂A, basis, BB_d) - ∇_∂A = extract_derivative.(Float64, A_d) - ∇_BB = ntuple(i -> extract_derivative.(Float64, ∂BB_d[i]), length(∂BB_d)) - end - return ∇_∂A, ∇_BB -end - - - -function auto_pb_pb!(∇_∂A, ∇_BB, ∂BB, ∂A, basis::PooledSparseProduct{2}, BB) - @assert all(eltype(BB[i]) == eltype(BB[1]) for i = 2:length(BB)) - @no_escape begin - T = eltype(BB[1]) - d = Dual{T}(zero(T), one(T)) - TD = typeof(d) - B1 = BB[1] - B2 = BB[2] - B1_d = @alloc(TD, size(B1)...) - B2_d = @alloc(TD, size(B2)...) - @inbounds for t = 1:length(B1) - B1_d[t] = B1[t] + d * ∂BB[1][t] - end - @inbounds for t = 1:length(B2) - B2_d[t] = B2[t] + d * ∂BB[2][t] - end - BB_d = (B1_d, B2_d) - # A_d = @withalloc evaluate!(basis, BB_d) - # ∂BB_d = @withalloc pullback!(∂A, basis, BB_d) - A_d, ∂BB_d = @withalloc Polynomials4ML.pullback_x!(∂A, basis, BB_d) - @inbounds for i = 1:length(A_d) - ∇_∂A[i] = extract_derivative(T, A_d[i]) - end - @inbounds for i = 1:length(∂BB_d) - for j = 1:length(∂BB_d[i]) - ∇_BB[i][j] = extract_derivative(T, ∂BB_d[i][j]) - end - end - end - return ∇_∂A, ∇_BB -end - -function auto_pb_pb!(∇_∂A, ∇_BB, ∂BB, ∂A, basis::PooledSparseProduct{3}, BB) - @assert all(eltype(BB[i]) == eltype(BB[1]) for i = 2:length(BB)) - @no_escape begin - T = eltype(BB[1]) - d = Dual{T}(zero(T), one(T)) - TD = typeof(d) - B1 = BB[1] - B2 = BB[2] - B3 = BB[3] - B1_d = @alloc(TD, size(B1)...) - B2_d = @alloc(TD, size(B2)...) - B3_d = @alloc(TD, size(B3)...) - @inbounds for t = 1:length(B1) - B1_d[t] = B1[t] + d * ∂BB[1][t] - end - @inbounds for t = 1:length(B2) - B2_d[t] = B2[t] + d * ∂BB[2][t] - end - @inbounds for t = 1:length(B3) - B3_d[t] = B3[t] + d * ∂BB[3][t] - end - BB_d = (B1_d, B2_d, B3_d) - # A_d = @withalloc evaluate!(basis, BB_d) - # ∂BB_d = @withalloc pullback!(∂A, basis, BB_d) - A_d, ∂BB_d = @withalloc Polynomials4ML.pullback_x!(∂A, basis, BB_d) - @inbounds for i = 1:length(A_d) - ∇_∂A[i] = extract_derivative(T, A_d[i]) - end - @inbounds for i = 1:length(∂BB_d) - for j = 1:length(∂BB_d[i]) - ∇_BB[i][j] = extract_derivative(T, ∂BB_d[i][j]) - end - end - end - return ∇_∂A, ∇_BB -end - - -∇_∂A, ∇_BB = auto_pb_pb(∂2, ∂A, basis, bBB); -∇_∂A2 = deepcopy(∇_∂A); ∇_BB2 = deepcopy(∇_BB) -auto_pb_pb!(∇_∂A2, ∇_BB2, ∂2, ∂A, basis, bBB) - -# @info("pb² for PooledSparseProduct, order = $ORDER, len = $(length(basis))") -# print("pullback : ") -# @btime Polynomials4ML.pullback($∂A, $basis, $bBB) -# print(" pullback2 : ") -# @btime Polynomials4ML.pullback2($∂2, $∂A, $basis, $bBB); -# print(" auto_pb_pb : ") -# @btime auto_pb_pb($∂2, $∂A, $basis, $bBB); -# print(" auto_pb_pb! : ") -# @btime auto_pb_pb!($∇_∂A2, $∇_BB2, $∂2, $∂A, $basis, $bBB); - -## - -∇_∂A1, ∇_BB1 = auto_pb_pb(∂2, ∂A, basis, bBB) -∇_∂A2, ∇_BB2 = P4ML.pullback2(∂2, ∂A, basis, bBB) -∇_∂A3 = deepcopy(∇_∂A2); ∇_BB3 = deepcopy(∇_BB2) -auto_pb_pb!(∇_∂A3, ∇_BB3, ∂2, ∂A, basis, bBB) -∇_∂A1 ≈ ∇_∂A2 ≈ ∇_∂A3 -all(∇_BB1 .≈ ∇_BB2 .≈ ∇_BB3) ## @info("Testing pushforward for PooledSparseProduct")