Skip to content

Commit

Permalink
Merge pull request #110 from ACEsuit/co/xscal
Browse files Browse the repository at this point in the history
General Adaptive Scalar Basis + Sparsification
  • Loading branch information
cortner authored Feb 2, 2022
2 parents 63f6243 + b6f85ad commit bb2f9dd
Show file tree
Hide file tree
Showing 17 changed files with 772 additions and 55 deletions.
2 changes: 2 additions & 0 deletions src/ACE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ include("Rn1pbasis.jl")
include("scal1pbasis.jl")
include("discrete1pbasis.jl")

include("xscal1pbasis.jl")

include("product_1pbasis.jl")

# basis selectors used to specify finite subsets of basis functions
Expand Down
26 changes: 14 additions & 12 deletions src/Rn1pbasis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,6 @@ function Rn1pBasis{T, TT, TJ, VSYM, NSYM}(R::TransformedPolys{T, TT, TJ},
end



# # -------- temporary hack for 1.6, should not be needed from 1.7 onwards

# function acquire_B!(basis::Rn1pBasis, args...)
# VT = valtype(basis, args...)
# return acquire!(basis.B_pool, length(basis), VT)
# end

# function release_B!(basis::Rn1pBasis, B)
# return release!(basis.B_pool, B)
# end

# ---------------------- Implementation of Rn1pBasis

Base.length(basis::Rn1pBasis) = length(basis.R)
Expand All @@ -72,10 +60,24 @@ function Base.show(io::IO, basis::Rn1pBasis)
end


# ------------- specs and sparsification

get_spec(basis::Rn1pBasis, n::Integer) = NamedTuple{(_nsym(basis),)}((n,))

get_spec(basis::Rn1pBasis) = get_spec.(Ref(basis), 1:length(basis))


function sparsify!(basis::Rn1pBasis, spec)
maxn = maximum(_n(b, basis) for b in spec)
# generate a new radial basis
if maxn > length(basis)
basis.P = ACE.OrthPolys.TransformedPolys(maxn, basis.P)
end
return basis
end

# ---------------

==(P1::Rn1pBasis, P2::Rn1pBasis) =
( (P1.R == P2.R) && (typeof(P1) == typeof(P2)) )

Expand Down
17 changes: 13 additions & 4 deletions src/Ylm1pbasis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,12 @@ end

get_spec(basis::Ylm1pBasis) = get_spec.(Ref(basis), 1:length(basis))

function Base.show(io::IO, basis::Ylm1pBasis)
print(io, "Ylm1pBasis{$(_varsym(basis)), $(_lsym(basis)), $(_msym(basis))}($(basis.SH.alp.L), \"$(basis.label)\")")
end


function sparsify!(basis::Ylm1pBasis{T}, keep) where {T}
maxL = maximum(_l(b, basis) for b in keep)
basis.SH = SHBasis(maxL, T)
return basis
end

# function get_spec(basis::Ylm1pBasis{T, VS, L, M}) where {T, VS, L, M}
# @assert length(basis) == (_maxL(basis) + 1)^2
Expand All @@ -87,6 +88,14 @@ end
# return spec
# end

#------


function Base.show(io::IO, basis::Ylm1pBasis)
print(io, "Ylm1pBasis{$(_varsym(basis)), $(_lsym(basis)), $(_msym(basis))}($(basis.SH.alp.L), \"$(basis.label)\")")
end


==(P1::Ylm1pBasis, P2::Ylm1pBasis) =
( (P1.SH == P2.SH) && (typeof(P1) == typeof(P2)) )

Expand Down
28 changes: 28 additions & 0 deletions src/auxiliary.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,31 @@ function _read_dict_1pspec(D::Dict)
NTPROTO = namedtuple(D["SYMS"]...)
return [ NTPROTO(binds) for binds in D["inds"] ]
end



#
# This is a useful utility function used e.g. in scal1pbasis.jl
#

"""
returns an `SVector{N}` of the form `x * e_I` where `e_I` is the Ith canonical basis vector.
"""
@generated function __e(::SVector{N}, ::Val{I}, x::T) where {N, I, T}
code = "SA["
for i = 1:N
if i == I
code *= "x,"
else
code *= "0,"
end
end
code *= "]"
quote
$( Meta.parse(code) )
end
end

__e(xx::SVector{N, T}, valI::Val{I}) where {N, T, I} = __e(xx, valI, one(T))

__e(::Number, ::Any, x) = x
41 changes: 41 additions & 0 deletions src/pibasis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,18 @@ sparsify(spec::PIBasisSpec, Ikeep::AbstractVector{<: Integer}) =



function _fix_A_indices!(spec::PIBasisSpec, new_inds::AbstractVector{<: Integer})
for iAA = 1:size(spec.iAA2iA, 1)
for α = 1:spec.orders[iAA]
= spec.iAA2iA[iAA, α]
new_vα = new_inds[vα]
@assert new_vα > 0
spec.iAA2iA[iAA, α] = new_vα
end
end
return nothing
end

# --------------------------------- PIBasis implementation


Expand Down Expand Up @@ -193,6 +205,35 @@ function sparsify!(basis::PIBasis, Ikeep::AbstractVector{<: Integer})
return basis
end

"""
This should allow the 1p basis to sparsify itself, then feed back to the
pibasis what the correct indices are.
"""
function clean_1pbasis!(basis::PIBasis)
spec = get_spec(basis)
B1p = basis.basis1p
spec1p = NamedTuple[]
for bb in spec
append!(spec1p, bb)
end
identity.(unique!(spec1p))
# sparsify the product 1p basis
_, new_inds = sparsify!(basis.basis1p, spec1p)
# now fix the indexing of the PIBasis specification
_fix_A_indices!(basis.spec, new_inds)
return basis
end

# syms = symbols(B1p)
# rgs = Dict{Symbol, Any}([sym => [] for sym in syms]...)
# for bb in spec, b in bb, sym in keys(b)
# push!(rgs[sym], getproperty(b, sym))
# end
# for sym in syms
# rgs[sym] = identity.(unique(rgs[sym]))
# end


# -------------------


Expand Down
18 changes: 13 additions & 5 deletions src/polynomials/orthpolys.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ OrthPolyBasis(pl, tl::T, pr, tr::T, A::Vector{T}, B::Vector{T}, C::Vector{T},
OrthPolyBasis(pl, tl, pr, tr, A, B, C, tdf, ww,
VectorPool{T}(), VectorPool{T}())


valtype(P::OrthPolyBasis{T}, x::TX = one(T)) where {T, TX <: Number} =
promote_type(T, TX)

Expand All @@ -131,14 +130,16 @@ write_dict(J::OrthPolyBasis{T}) where {T} = Dict(
"tl" => J.tl,
"A" => J.A,
"B" => J.B,
"C" => J.C
"C" => J.C,
"tdf" => J.tdf,
"ww" => J.ww
)

OrthPolyBasis(D::Dict, T=read_dict(D["T"])) =
OrthPolyBasis(
D["pl"], D["tl"], D["pr"], D["tr"],
Vector{T}(D["A"]), Vector{T}(D["B"]), Vector{T}(D["C"]),
T[], T[]
T.(D["tdf"]), T.(D["ww"])
)

read_dict(::Val{:ACE_OrthPolyBasis}, D::Dict) = OrthPolyBasis(D)
Expand All @@ -151,6 +152,9 @@ function ACE.rand_radial(J::OrthPolyBasis)
return rand(J.tdf)
end

OrthPolyBasis(N::Integer, J::OrthPolyBasis) =
OrthPolyBasis(N, J.pcut, J.tcut, J.pin, J.tin, J.tdf, J.ww)

function OrthPolyBasis(N::Integer,
pcut::Integer,
tcut::T,
Expand All @@ -171,7 +175,7 @@ function OrthPolyBasis(N::Integer,
end

if minimum(tdf) < tl || maximum(tdf) > tr
@warn("OrthoPolyBasis: t range outside [tl, tr]")
@warn("OrthPolyBasis: t range outside [tl, tr]")
end

A = zeros(T, N)
Expand Down Expand Up @@ -288,7 +292,7 @@ end
A utility function to generate a jacobi-type basis
"""
function discrete_jacobi(N; pcut=0, tcut=1.0, pin=0, tin=-1.0, Nquad = 1000)
function discrete_jacobi(N; pcut=0, tcut=1.0, pin=0, tin=-1.0, Nquad = 3 * N)
tl, tr = minmax(tin, tcut)
dt = (tr - tl) / Nquad
tdf = range(tl + dt/2, tr - dt/2, length=Nquad)
Expand All @@ -315,6 +319,10 @@ function TransformedPolys(J::OrthPolyBasis{T}, trans, rl, ru) where {T}
return TransformedPolys(J, trans, T(rl), T(ru), B_pool, B_pool)
end

TransformedPolys(maxN::Integer, P::TransformedPolys) =
TransformedPolys(OrthPolyBasis(maxN, P.J), P.trans, P.rl, P.ru)


==(J1::TransformedPolys, J2::TransformedPolys) = (
(J1.J == J2.J) &&
(J1.trans == J2.trans) &&
Expand Down
40 changes: 40 additions & 0 deletions src/polynomials/snap.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#
# WORK IN PROGRESS - CURRENTLY NOT INCLUDED
#
#

module SNAP

"""
`SNAPRnlBasis:` define radial basis for hyperspherical harmonics.
```math
R_{nl}(r) = (-i)^l 2^{l+1/2} l! \cdot
\bigg( \frac{(n+1) (n-l)!}{\pi (n+l+1)!} \bigg)^{1/2} \cdot
\sin^l\Big( \frac{\pi r}{2 r_0}\Big) \cdot
C_{n-l}^{l+1}\Big( \cos\Big(\frac{\pi r}{2 r_0}\Big)\Big).
```
"""
struct SNAPRnlBasis{T, VSYM, LSYM, MSYM}
maxn::Int
maxl::Int
spec
B_pool::VectorPool{Complex{T}}
dB_pool::VectorPool{T}
end

_varsym(::SNAPRnlBasis{T, VSYM, LSYM, MSYM}) where {T, VSYM, LSYM, MSYM} = VSYM
_nsym(::SNAPRnlBasis{T, VSYM, LSYM, MSYM}) where {T, VSYM, LSYM, MSYM} = NSYM
_n(b, basis::SNAPRnlBasis) = getproperty(b, _nsym(basis))
_rr(X, Rn::SNAPRnlBasis) = getproperty(X, _varsym(Rn))




function evaluate!(Rnl, basis::SNAPRnlBasis, X::AbstractState)
rr = _rr(X)
r = norm(rr)


end

end
49 changes: 49 additions & 0 deletions src/product_1pbasis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -255,8 +255,57 @@ function rand_radial(basis::Product1pBasis)
return nothing
end

# -------------- sparsification

function sparsify!(basis1p::Product1pBasis, keep::AbstractVector{<: NamedTuple})
# spec, keep, new_spec will be lists of named tuples,
# e.g. [ (n = , l = , m = ), ... ]
spec = get_spec(basis1p)
new_spec = eltype(spec)[]
new_inds = Vector{Int}(undef, length(spec))
for (ib, b) in enumerate(spec)
if b in keep
push!(new_spec, b)
new_inds[ib] = length(new_spec)
end
end

# now we need to recompute the indices array, this can be easily done via
# set_spec!(basis::Product1pBasis{NB}, spec), but before we do that
# we should sparsify the basis components as well
for bas_i in basis1p.bases
_sparsify_component!(bas_i, new_spec)
end

# finally fix the basis1pspec internally:
set_spec!(basis1p, new_spec)

# return the old to new index mapping so that the pibasis can fix itself.
return basis1p, new_inds
end


using NamedTupleTools: select

"""
this performs some generic work to sparsify a 1p-basis component.
but the actual sparsificatin happens in the individual basis implementations
"""
function _sparsify_component!(basis1p, keep)
# get rid of all info we don't need
syms = symbols(basis1p)
keep1 = unique( select.(keep, Ref(syms)) )
# double-check that keep1 is compatible
spec = get_spec(basis1p)
@assert all(b in spec for b in keep1)
# now get the basis spec and get the list of indices to keep
if length(keep1) < length(spec)
# Ikeep = findall( [b in keep1 for b in spec] )
# sparsify!(basis1p, Ikeep)
sparsify!(basis1p, keep1)
end
return basis1p
end


# --------------- AD codes
Expand Down
Loading

0 comments on commit bb2f9dd

Please sign in to comment.