Skip to content

Commit

Permalink
Merge pull request #80 from ACEsuit/frules
Browse files Browse the repository at this point in the history
WIP: First set of Pushforwards
  • Loading branch information
cortner authored Dec 21, 2023
2 parents 5bff8f6 + 58abc1e commit 6a477ec
Show file tree
Hide file tree
Showing 5 changed files with 187 additions and 0 deletions.
9 changes: 9 additions & 0 deletions docs/src/experimental.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,12 @@ We implement custom pullbacks for most bases. These take the form
pb_evaluate!(∂X, basis, ∂B, X, args...)
```
and analogously for the `evaluate_***` variants. The `args...` can differ between different basis sets, e.g., they may rely on intermediate results from the evaluation of the basis. The `rrule` implementations are wrappers for these.

## Explicit Forward Mode Differentiation

We have started to implement custom pushforwards. These take the form
```julia
B, ∂B = pfwd_evaluate(basis, X, ΔX)
pfwd_evaluate!(B, ∂B, basis, X, ΔX)
```
and analogously for other functions. There are currently no `frule` wrappers, but we plan to provide these in due course.
57 changes: 57 additions & 0 deletions src/ace/sparseprodpool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,63 @@ end
# end


# --------------------- Pushforwards

# This implementation of the pushforward doesn't yet do batching.
# This means that the output is a single vector A, while the inputs
# BB[i] are of size nX x nBi, with the nX dimension being pooled.
# It is ASSUMED that ∂BB[i][j, :] / ∂X[j'] = 0 if j ≠ j'.
# Therefore the ΔBB[i] are also of size nX x nBi.
# This is a simplification that may have to be revisited at some point.
#
# The output will be A, ∂A where
# A is size (nA,)
# ∂A is size (nA, nX)

# Promotion helper used to determine the element type of pushforward outputs.
# Generic fallback: defer directly to `promote_type`.
function _my_promote_type(args...)
    return promote_type(args...)
end

# Specialisation for a scalar number type `T1` followed by an `SVector{N, S}`
# type: `T1` is first wrapped as `SVector{N, T1}` and only then passed to
# `promote_type` together with `T2` and any remaining types.
function _my_promote_type(T1::Type{<: Number}, T2::Type{SVector{N, S}},
                          args...) where {N, S}
    return promote_type(SVector{N, T1}, T2, args...)
end



# Allocating pushforward for `PooledSparseProduct`.
#
# Asserts that `BB[1]` and `ΔBB[1]` are two-dimensional and that every
# `BB[t]` has the same size as the matching `ΔBB[t]`. It then acquires two
# zero-filled output arrays from `basis.pool` and hands over to the in-place
# `pfwd_evaluate!`, whose return value is returned unchanged.
#
# Output sizes: `A` is `(length(basis),)` and `∂A` is
# `(length(basis), size(ΔBB[1], 1))`.
function pfwd_evaluate(basis::PooledSparseProduct, BB, ΔBB)
    @assert length(size(BB[1])) == 2
    @assert length(size(ΔBB[1])) == 2
    @assert all(size(BB[t]) == size(ΔBB[t]) for t = 1:length(BB))

    out_dims = (length(basis), size(ΔBB[1], 1))

    # values: element type promoted over all inputs in BB
    valT = promote_type(eltype.(BB)...)
    A = acquire!(basis.pool, :A, (out_dims[1],), valT)
    fill!(A, zero(valT))

    # derivatives: additionally promoted with the element types in ΔBB
    derT = _my_promote_type(valT, eltype.(ΔBB)...)
    ∂A = acquire!(basis.pool, :∂A_pfwd, out_dims, derT)
    fill!(∂A, zero(derT))

    return pfwd_evaluate!(A, ∂A, basis, BB, ΔBB)
end

# In-place pushforward for `PooledSparseProduct`.
#
# For every basis function `i` (index tuple `ϕ = basis.spec[i]`) and every
# row `j = 1:size(BB[1], 1)` of the inputs this accumulates
#     A[i]     += ∏bb
#     ∂A[i, j] += Σ_t ∇∏bb[t] * ΔBB[t][j, ϕ[t]]
# where `bb[t] = BB[t][j, ϕ[t]]` and `(∏bb, ∇∏bb)` is the pair returned by
# `Polynomials4ML._static_prod_ed(bb)` (presumably the product of the entries
# of `bb` and its gradient with respect to them).
#
# Both outputs are updated with `+=`, so `A` and `∂A` must be zero-filled by
# the caller (the allocating `pfwd_evaluate` does this).
# NOTE(review): the update of `∂A` runs under `@inbounds`; the size of `∂A`
# is not checked here and is the caller's responsibility.
#
# Returns the tuple `(A, ∂A)`.
function pfwd_evaluate!(A, ∂A, basis::PooledSparseProduct{NB}, BB, ΔBB) where {NB}
    nX = size(BB[1], 1)
    for (i, ϕ) in enumerate(basis.spec)
        for j = 1:nX
            bb = ntuple(t -> BB[t][j, ϕ[t]], NB)
            Δbb = ntuple(t -> ΔBB[t][j, ϕ[t]], NB)
            ∏bb, ∇∏bb = Polynomials4ML._static_prod_ed(bb)
            # Reuse the product returned alongside the gradient instead of
            # recomputing it with `prod(bb)`; the SparseSymmProd pushforward
            # uses the first return value the same way.
            A[i] += ∏bb
            @inbounds for t = 1:NB
                ∂A[i, j] += ∇∏bb[t] * Δbb[t]
            end
        end
    end
    return A, ∂A
end



# --------------------- connect with ChainRules
# todo ...

Expand Down
38 changes: 38 additions & 0 deletions src/ace/sparsesymmprod.jl
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,44 @@ function _pb_pb_evaluate_AA!(spec::Vector{NTuple{N, Int}},
end


# -------------- Pushforwards / frules

# Allocating pushforward for `SparseSymmProd`.
#
# Acquires the output arrays from `basis.pool`, runs the in-place
# `pfwd_evaluate!` on the unwrapped arrays and returns the (still wrapped)
# pair `(AA, ∂AA)`. `AA` has size `(length(basis),)`, `∂AA` has size
# `(length(basis), size(ΔA, 2))`.
#
# Only `∂AA` is zero-filled here, since the kernel accumulates into it with
# `+=`; the entries of `AA` are written by plain assignment in the kernel.
function pfwd_evaluate(basis::SparseSymmProd,
                       A::AbstractVector{<: Number},
                       ΔA::AbstractMatrix)
    len_AA = length(basis)
    valT = eltype(A)
    AA = acquire!(basis.pool, :AA, (len_AA,), valT)

    derT = _my_promote_type(valT, eltype(ΔA))
    ∂AA = acquire!(basis.pool, :∂AA, (len_AA, size(ΔA, 2)), derT)
    fill!(∂AA, zero(derT))

    pfwd_evaluate!(unwrap(AA), unwrap(∂AA), basis, A, ΔA)
    return AA, ∂AA
end


# In-place pushforward for `SparseSymmProd`.
#
# Generated function: for `n = 1, …, NB` the emitted code calls
# `_pfwd_AA_N!` with `basis.ranges[n]` and `basis.specs[n]`, using `n` as a
# literal index. Throws an error if `basis.hasconst` is set, since that case
# is not implemented. Returns the tuple `(AA, ∂AA)`.
@generated function pfwd_evaluate!(AA, ∂AA, basis::SparseSymmProd{NB}, A, ΔA) where {NB}
    # Build the unrolled sequence of kernel calls with a plain loop
    # (no closures or comprehensions inside the generator body).
    kernel_calls = Expr(:block)
    for n in 1:NB
        push!(kernel_calls.args,
              :(_pfwd_AA_N!(AA, ∂AA, A, ΔA, basis.ranges[$n], basis.specs[$n])))
    end
    return quote
        if basis.hasconst
            error("no implementation with hasconst")
        end
        $kernel_calls
        return AA, ∂AA
    end
end

# Pushforward kernel for all basis functions whose spec is an `N`-tuple.
#
# `rg_N` and `spec_N` are traversed in lockstep: for output index `i` with
# index tuple `bb` the kernel gathers `aa[t] = A[bb[t]]`, obtains
# `(∏aa, ∇∏aa)` from `Polynomials4ML._static_prod_ed(aa)`, stores `∏aa` in
# `AA[i]` and accumulates `∇∏aa[t] * ΔA[bb[t], j]` into `∂AA[i, j]` for every
# column `j` of `ΔA`. Because of the `+=`, `∂AA` must be zero-filled by the
# caller. Returns `nothing`.
function _pfwd_AA_N!(AA, ∂AA, A, ΔA,
                     rg_N, spec_N::Vector{NTuple{N, Int}}) where {N}
    ncols = size(ΔA, 2)
    for (iAA, idx) in zip(rg_N, spec_N)
        factors = ntuple(t -> A[idx[t]], N)
        val, grad = Polynomials4ML._static_prod_ed(factors)
        AA[iAA] = val
        for t in 1:N
            for j in 1:ncols
                ∂AA[iAA, j] += grad[t] * ΔA[idx[t], j]
            end
        end
    end
    return nothing
end


# -------------- Lux integration
# it needs an extra lux interface reason as in the case of the `basis`
# should it not be enough to just overload valtype?
Expand Down
24 changes: 24 additions & 0 deletions test/ace/test_prodbasis2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -204,3 +204,27 @@ println_slim(@test l_AA2 ≈ basis2(bA))

println()
##


@info("Testing basic pushforward")

using ForwardDiff

# Compare the custom pushforward of `SparseSymmProd` against a reference
# obtained from the ForwardDiff jacobian applied to the random tangent `ΔA`.
for ntest = 1:10
    # declare every loop variable local so none of them can clash with a
    # global of the same name defined earlier in the test file
    local M, BO, nX, spec, A, ΔA, basis, AA1, ∂AA1, AA2, ∂AA2
    M = rand(4:7)
    BO = rand(2:5)
    nX = rand(6:12)
    spec = generate_SO2_spec(BO, 2*M)
    A = randn(Float64, 2*M+1)
    ΔA = randn(length(A), nX)

    basis = SparseSymmProd(spec)
    AA1 = basis(A)
    ∂AA1 = ForwardDiff.jacobian(basis, A) * ΔA
    AA2, ∂AA2 = P4ML.pfwd_evaluate(basis, A, ΔA)
    # `@test` needs a boolean expression; compare with `≈` because the two
    # results are floating-point values computed along different code paths
    print_tf( @test AA1 ≈ AA2 )
    print_tf( @test ∂AA1 ≈ ∂AA2 )
end

##
59 changes: 59 additions & 0 deletions test/ace/test_sparseprodpool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ function _generate_basis(; order=3, len = 50)
return PooledSparseProduct(spec)
end


function _rand_input1(basis::PooledSparseProduct{ORDER}) where {ORDER}
NN = [ maximum(b[i] for b in basis.spec) for i = 1:ORDER ]
BB = ntuple(i -> randn(NN[i]), ORDER)
Expand Down Expand Up @@ -138,3 +139,61 @@ for ntest = 1:20
end
println()


##
@info("Testing pushforward for PooledSparseProduct")

using ForwardDiff

# Generate random inputs for the pushforward test of a `PooledSparseProduct`.
#
# For each factor `i = 1:ORDER` the number of columns is the largest index
# that `basis.spec` uses in position `i`; the number of rows is `nX`
# (random in 7:12 by default). Returns `(BB, ΔBB)`, two `ORDER`-tuples of
# `nX × ncols[i]` standard-normal matrices.
function _rand_input1_pfwd(basis::PooledSparseProduct{ORDER};
                           nX = rand(7:12)) where {ORDER}
    ncols = [ maximum(b -> b[i], basis.spec) for i in 1:ORDER ]
    rand_block(i) = randn(nX, ncols[i])
    # draw all of BB first, then all of ΔBB
    BB = ntuple(rand_block, ORDER)
    ΔBB = ntuple(rand_block, ORDER)
    return BB, ΔBB
end

# Reference pushforward for a single (non-batched) input via ForwardDiff.
#
# Evaluates `A1 = basis(BB)` and, for each slot `i`, the jacobian of
# `basis` with respect to `BB[i]` while all other slots are held fixed.
# The directional derivative is the sum over `i` of that jacobian times
# `ΔBB[i]`. Returns `(A1, ∂A1)`.
function fwddiff1_pfwd(basis::PooledSparseProduct{NB}, BB, ΔBB) where {NB}
    A1 = basis(BB)
    # copy of the tuple `tup` with entry `k` replaced by `new`
    replace_at(tup, new, k) = ntuple(a -> a == k ? new : tup[a], length(tup))
    jacs = map(1:NB) do i
        ForwardDiff.jacobian(B -> basis(replace_at(BB, B, i)), BB[i])
    end
    ∂A1 = sum(jacs[i] * ΔBB[i] for i in 1:NB)
    return A1, ∂A1
end

# Reference pushforward for batched inputs, built row by row.
#
# Row `j` of every `BB[t]` / `ΔBB[t]` is passed to `fwddiff1_pfwd`. The
# per-row values are summed into `A`, and the per-row derivative vectors are
# concatenated horizontally into `∂A`, one column per row `j`.
# Returns `(A, ∂A)`.
function fwddiff_pfwd(basis::PooledSparseProduct{NB}, BB, ΔBB) where {NB}
    nrows = size(BB[1], 1)
    per_row = map(1:nrows) do j
        row_BB = ntuple(t -> BB[t][j, :], NB)
        row_ΔBB = ntuple(t -> ΔBB[t][j, :], NB)
        fwddiff1_pfwd(basis, row_BB, row_ΔBB)
    end
    vals = map(first, per_row)
    derivs = map(last, per_row)
    A = sum(vals)
    ∂A = reduce(hcat, derivs)
    return A, ∂A
end


# Compare the custom pushforward of `PooledSparseProduct` against the
# ForwardDiff-based reference `fwddiff_pfwd` on random bases and inputs.
for ntest = 1:10
    local order, basis, BB, ΔBB, A1, ∂A1, A2, ∂A2
    order = rand(2:4)
    basis = _generate_basis(; order=order)
    BB, ΔBB = _rand_input1_pfwd(basis)
    A1, ∂A1 = fwddiff_pfwd(basis, BB, ΔBB)
    A2, ∂A2 = P4ML.pfwd_evaluate(basis, BB, ΔBB)
    # `@test` needs a boolean expression; compare with `≈` because the two
    # results are floating-point values computed along different code paths
    print_tf(@test A2 ≈ A1)
    print_tf(@test ∂A2 ≈ ∂A1)
end

##

# # quick performance and allocation check
# using ObjectPools: unwrap
# order = 3
# basis = _generate_basis(; order=order)
# BB, ΔBB = _rand_input1_pfwd(basis)
# A, ∂A = P4ML.pfwd_evaluate(basis, BB, ΔBB)
# @btime Polynomials4ML.pfwd_evaluate!($(unwrap(A)), $(unwrap(∂A)), $basis, $BB, $ΔBB)

0 comments on commit 6a477ec

Please sign in to comment.