Skip to content

Commit

Permalink
Merge pull request #77 from ACEsuit/pool3
Browse files Browse the repository at this point in the history
Custom Sparse Prod Pool kernel for 3 Embeddings
  • Loading branch information
cortner authored Nov 15, 2023
2 parents 6c02075 + 70d16c2 commit 2df252c
Showing 1 changed file with 38 additions and 2 deletions.
40 changes: 38 additions & 2 deletions src/ace/sparseprodpool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ function _pullback_evaluate(∂A, basis::PooledSparseProduct{NB}, BB::TupMat) wh
return ∂BB
end

using Base.Cartesian: @nexprs

function _pullback_evaluate!(∂BB, ∂A, basis::PooledSparseProduct{NB}, BB::TupMat) where {NB}
nX = size(BB[1], 1)
Expand Down Expand Up @@ -301,9 +302,9 @@ function _pullback_evaluate!(∂BB, ∂A, basis::PooledSparseProduct{2}, BB::Tup

@inbounds for (iA, ϕ) in enumerate(basis.spec)
∂A_iA = ∂A[iA]
ϕ1 = ϕ[1]
ϕ2 = ϕ[2]
@simd ivdep for j = 1:nX
ϕ1 = ϕ[1]
ϕ2 = ϕ[2]
b1 = BB[1][j, ϕ1]
b2 = BB[2][j, ϕ2]
∂BB[1][j, ϕ1] = muladd(∂A_iA, b2, ∂BB[1][j, ϕ1])
Expand All @@ -313,6 +314,41 @@ function _pullback_evaluate!(∂BB, ∂A, basis::PooledSparseProduct{2}, BB::Tup
return nothing
end


function _pullback_evaluate!(∂BB, ∂A, basis::PooledSparseProduct{3}, BB::TupMat;
sizecheck = true)
nX = size(BB[1], 1)
NB = 3

if sizecheck
@assert all(nX <= size(BB[i], 1) for i = 1:NB)
@assert all(nX <= size(∂BB[i], 1) for i = 1:NB)
@assert all(size(∂BB[i], 2) >= size(BB[i], 2) for i = 1:NB)
@assert length(∂A) == length(basis)
@assert length(BB) == NB
@assert length(∂BB) == NB
end

B1 = BB[1]; B2 = BB[2]; B3 = BB[3]
∂B1 = ∂BB[1]; ∂B2 = ∂BB[2]; ∂B3 = ∂BB[3]

@inbounds for (iA, ϕ) in enumerate(basis.spec)
∂A_iA = ∂A[iA]
ϕ1 = ϕ[1]
ϕ2 = ϕ[2]
ϕ3 = ϕ[3]
@simd ivdep for j = 1:nX
b1 = B1[j, ϕ1]
b2 = B2[j, ϕ2]
b3 = B3[j, ϕ3]
∂B1[j, ϕ1] = muladd(∂A_iA, b2*b3, ∂B1[j, ϕ1])
∂B2[j, ϕ2] = muladd(∂A_iA, b1*b3, ∂B2[j, ϕ2])
∂B3[j, ϕ3] = muladd(∂A_iA, b1*b2, ∂B3[j, ϕ3])
end
end
return nothing
end

import ForwardDiff

function _pb_pb_evaluate(basis::PooledSparseProduct{NB}, ∂2,
Expand Down

0 comments on commit 2df252c

Please sign in to comment.