Skip to content

Commit

Permalink
Make LOBPCG GPU-compatible (#711)
Browse files Browse the repository at this point in the history
Co-authored-by: Michael F. Herbst <[email protected]>
  • Loading branch information
GVigne and mfherbst authored Sep 28, 2022
1 parent a96f551 commit d28391e
Show file tree
Hide file tree
Showing 6 changed files with 153 additions and 36 deletions.
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@ version = "0.5.8"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
AtomsBase = "a963bdd2-2df7-4f54-a1ee-49d51e6be12a"
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
Brillouin = "23470ee3-d0df-4052-8b1a-8cbd6363e7f0"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Conda = "8f4d0f93-b110-5947-807f-2305c1781a2d"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
DftFunctionals = "6bd331d2-b28d-4fd3-880e-1a1c7f37947f"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
InteratomicPotentials = "a9efe35a-c65d-452d-b8a8-82646cd5cb04"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
Expand Down Expand Up @@ -48,13 +49,14 @@ spglib_jll = "ac4a9f1e-bdb2-5204-990c-47c8b2f70d4e"
[compat]
AbstractFFTs = "1"
AtomsBase = "0.2.2"
BlockArrays = "0.16.2"
Brillouin = "0.5 - 0.5.8" # Upper bound temporary until memory bug resolved.
ChainRulesCore = "1.15"
Conda = "1"
CUDA = "3"
DftFunctionals = "0.2"
FFTW = "1"
ForwardDiff = "0.10"
GPUArraysCore = "0.1"
InteratomicPotentials = "0.2"
Interpolations = "0.12, 0.13, 0.14"
IterTools = "1"
Expand Down
1 change: 1 addition & 0 deletions src/DFTK.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ include("common/mpi.jl")
include("common/threading.jl")
include("common/printing.jl")
include("common/cis2pi.jl")
include("common/zeros_like.jl")

export PspHgh
include("pseudo/NormConservingPsp.jl")
Expand Down
9 changes: 9 additions & 0 deletions src/common/zeros_like.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Create an array of same type as X filled with zeros, minimizing the number
# of allocations.
function zeros_like(X::AbstractArray, T::Type=eltype(X), dims::Integer...=size(X)...)
Z = similar(X, T, dims...)
Z .= 0
Z
end
zeros_like(X::AbstractArray, dims::Integer...) = zeros_like(X, eltype(X), dims...)
zeros_like(X::Array, T::Type=eltype(X), dims::Integer...=size(X)...) = zeros(T, dims...)
134 changes: 100 additions & 34 deletions src/eigen/lobpcg_hyper_impl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,80 @@
vprintln(args...) = nothing

using LinearAlgebra
using BlockArrays # used for the `mortar` command which makes block matrices
import Base: *
import Base.size, Base.adjoint, Base.Array
include("../workarounds/gpu_arrays.jl")

"""
Simple wrapper to represent a matrix formed by the concatenation of column blocks:
it is mostly equivalent to hcat, but doesn't allocate the full matrix.
LazyHcat only supports a few multiplication routines: furthermore, a multiplication
involving this structure will always yield a plain array (and not a LazyHcat structure).
LazyHcat is a lightweight subset of BlockArrays.jl's functionalities, but has the
advantage to be able to store GPU Arrays (BlockArrays is heavily built on Julia's CPU Array).
"""
struct LazyHcat{T <: Number, D <: Tuple} <: AbstractMatrix{T}
blocks::D
end

function LazyHcat(arrays::AbstractArray...)
@assert length(arrays) != 0
n_ref = size(arrays[1], 1)
@assert all(size.(arrays, 1) .== n_ref)

T = promote_type(map(eltype, arrays)...)

LazyHcat{T, typeof(arrays)}(arrays)
end

function Base.size(A::LazyHcat)
n = size(A.blocks[1], 1)
m = sum(size(block, 2) for block in A.blocks)
(n,m)
end

Base.Array(A::LazyHcat) = hcat(A.blocks...)

Base.adjoint(A::LazyHcat) = Adjoint(A)

@views function Base.:*(Aadj::Adjoint{T, <: LazyHcat}, B::LazyHcat) where {T}
A = Aadj.parent
rows = size(A)[2]
cols = size(B)[2]
ret = similar(A.blocks[1], rows, cols)

orow = 0 # row offset
for (iA, blA) in enumerate(A.blocks)
ocol = 0 # column offset
for (iB, blB) in enumerate(B.blocks)
ret[orow .+ (1:size(blA, 2)), ocol .+ (1:size(blB, 2))] .= blA' * blB
ocol += size(blB, 2)
end
orow += size(blA, 2)
end
ret
end

Base.:*(Aadj::Adjoint{T, <: LazyHcat}, B::AbstractMatrix) where {T} = Aadj * LazyHcat(B)

@views function *(Ablock::LazyHcat, B::AbstractMatrix)
res = Ablock.blocks[1] * B[1:size(Ablock.blocks[1], 2), :] # First multiplication
offset = size(Ablock.blocks[1], 2)
for block in Ablock.blocks[2:end]
mul!(res, block, B[offset .+ (1:size(block, 2)), :], 1, 1)
offset += size(block, 2)
end
res
end

# when X or Y are BlockArrays, this makes the return value be a proper array (not a BlockArray)
function array_mul(X::AbstractArray{T}, Y) where {T}
Z = Array{T}(undef, size(X, 1), size(Y, 2))
mul!(Z, X, Y)
function LinearAlgebra.mul!(res::AbstractMatrix, Ablock::LazyHcat,
B::AbstractVecOrMat, α::Number, β::Number)
mul!(res, Ablock*B, I, α, β)
end

# Perform a Rayleigh-Ritz for the N first eigenvectors.
@timing function rayleigh_ritz(X, AX, N)
XAX = array_mul(X', AX)
XAX = X' * AX
@assert all(!isnan, XAX)
F = eigen(Hermitian(XAX))
F.vectors[:,1:N], F.values[1:N]
Expand Down Expand Up @@ -174,16 +237,16 @@ end
niter = 1
ninners = zeros(Int,0)
while true
BYX = BY'X
# XXX the one(T) instead of plain old 1 is because of https://github.com/JuliaArrays/BlockArrays.jl/issues/176
mul!(X, Y, BYX, -one(T), one(T)) # X -= Y*BY'X
BYX = BY' * X
mul!(X, Y, BYX, -1, 1) # X -= Y*BY'X
# If the orthogonalization has produced results below 2eps, we drop them
# This is to be able to orthogonalize eg [1;0] against [e^iθ;0],
# as can happen in extreme cases in the ortho!(cP, cX)
dropped = drop!(X)
if dropped != []
@views mul!(X[:, dropped], Y, BY' * (X[:, dropped]), -one(T), one(T)) # X -= Y*BY'X
X[:, dropped] .-= Y * (BY' * X[:, dropped])
end

if norm(BYX) < tol && niter > 1
push!(ninners, 0)
break
Expand Down Expand Up @@ -219,11 +282,9 @@ function final_retval(X, AX, resid_history, niter, n_matvec)
residuals = AX .- X*Diagonal(λ)
=λ, X=X,
residual_norms=[norm(residuals[:, i]) for i in 1:size(residuals, 2)],
residual_history=resid_history[:, 1:niter+1],
n_matvec=n_matvec)
residual_history=resid_history[:, 1:niter+1], n_matvec=n_matvec)
end


### The algorithm is Xn+1 = rayleigh_ritz(hcat(Xn, A*Xn, Xn-Xn-1))
### We follow the strategy of Hetmaniuk and Lehoucq, and maintain a B-orthonormal basis Y = (X,R,P)
### After each rayleigh_ritz step, the B-orthonormal X and P are deduced by an orthogonal rotation from Y
Expand All @@ -234,6 +295,7 @@ end
miniter=1, ortho_tol=2eps(real(eltype(X))),
n_conv_check=nothing, display_progress=false)
N, M = size(X)

# If N is too small, we will likely get in trouble
error_message(verb) = "The eigenproblem is too small, and the iterative " *
"eigensolver $verb fail; increase the number of " *
Expand All @@ -252,7 +314,7 @@ end
B_ortho!(X, BX)
end

n_matvec = M # Count number of matrix-vector products
n_matvec = M # Count number of matrix-vector products
AX = similar(X)
AX = mul!(AX, A, X)
@assert all(!isnan, AX)
Expand All @@ -274,7 +336,8 @@ end
end
nlocked = 0
niter = 0 # the first iteration is fake
λs = @views [(X[:,n]'*AX[:,n]) / (X[:,n]'BX[:,n]) for n=1:M]
λs = @views [(X[:, n]'*AX[:, n]) / (X[:, n]'BX[:, n]) for n=1:M]
λs = oftype(X[:, 1], λs) # Offload to GPU if needed
new_X = X
new_AX = AX
new_BX = BX
Expand All @@ -290,23 +353,23 @@ end

# Form Rayleigh-Ritz subspace
if niter > 1
Y = mortar((X, R, P))
AY = mortar((AX, AR, AP))
BY = mortar((BX, BR, BP)) # data shared with (X, R, P) in non-general case
Y = LazyHcat(X, R, P)
AY = LazyHcat(AX, AR, AP)
BY = LazyHcat(BX, BR, BP) # data shared with (X, R, P) in non-general case
else
Y = mortar((X, R))
AY = mortar((AX, AR))
BY = mortar((BX, BR)) # data shared with (X, R) in non-general case
Y = LazyHcat(X, R)
AY = LazyHcat(AX, AR)
BY = LazyHcat(BX, BR) # data shared with (X, R) in non-general case
end
cX, λs = rayleigh_ritz(Y, AY, M-nlocked)

# Update X. By contrast to some other implementations, we
# wait on updating P because we have to know which vectors
# to lock (and therefore the residuals) before computing P
# only for the unlocked vectors. This results in better convergence.
new_X = array_mul(Y, cX)
new_AX = array_mul(AY, cX) # no accuracy loss, since cX orthogonal
new_BX = (B == I) ? new_X : array_mul(BY, cX)
new_X = Y * cX
new_AX = AY * cX # no accuracy loss, since cX orthogonal
new_BX = (B == I) ? new_X : BY * cX
end

### Compute new residuals
Expand All @@ -320,7 +383,7 @@ end
vprintln(niter, " ", resid_history[:, niter+1])
if precon !== I
@timing "preconditioning" begin
precondprep!(precon, X) # update preconditioner if needed; defaults to noop
precondprep!(precon, X) # update preconditioner if needed; defaults to noop
ldiv!(precon, new_R)
end
end
Expand Down Expand Up @@ -360,20 +423,23 @@ end
# orthogonalization, see Hetmaniuk & Lehoucq, and Duersch et. al.
# cP = copy(cX)
# cP[Xn_indices,:] .= 0
e = zeros(eltype(X), size(cX, 1), M - prev_nlocked)
for i in 1:length(Xn_indices)
e[Xn_indices[i], i] = 1
end

lenXn = length(Xn_indices)
e = zeros_like(X, size(cX, 1), M - prev_nlocked)
lower_diag = one(similar(X, lenXn, lenXn))
# e has zeros everywhere except on one of its lower diagonal
e[Xn_indices[1]:last(Xn_indices), 1:lenXn] = lower_diag

cP = cX .- e
cP = cP[:, Xn_indices]
# orthogonalize against all Xn (including newly locked)
ortho!(cP, cX, cX, tol=ortho_tol)

# Get new P
new_P = array_mul( Y, cP)
new_AP = array_mul(AY, cP)
new_P = Y * cP
new_AP = AY * cP
if B != I
new_BP = array_mul(BY, cP)
new_BP = BY * cP
else
new_BP = new_P
end
Expand Down Expand Up @@ -418,8 +484,8 @@ end

# Orthogonalize R wrt all X, newly active P
if niter > 0
Z = mortar((full_X, P))
BZ = mortar((full_BX, BP)) # data shared with (full_X, P) in non-general case
Z = LazyHcat(full_X, P)
BZ = LazyHcat(full_BX, BP) # data shared with (full_X, P) in non-general case
else
Z = full_X
BZ = full_BX
Expand Down
19 changes: 19 additions & 0 deletions src/workarounds/gpu_arrays.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# TODO: remove this when it is implemented in GPUArrays and CUDA
import LinearAlgebra.dot, LinearAlgebra.eigen
using LinearAlgebra
using GPUArraysCore
using CUDA

# https://github.com/JuliaGPU/CUDA.jl/issues/1565
LinearAlgebra.dot(x::AbstractGPUArray, D::Diagonal,y::AbstractGPUArray) = x'*(D*y)

# https://github.com/JuliaGPU/CUDA.jl/issues/1572
function LinearAlgebra.eigen(A::Hermitian{T,AT}) where {T <: Complex,AT <: CuArray}
vals, vects = CUDA.CUSOLVER.heevd!('V','U', A.data)
(vectors = vects, values = vals)
end

function LinearAlgebra.eigen(A::Hermitian{T,AT}) where {T <: Real,AT <: CuArray}
vals, vects = CUDA.CUSOLVER.syevd!('V','U', A.data)
(vectors = vects, values = vals)
end
20 changes: 20 additions & 0 deletions test/lobpcg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,23 @@ end
@test res1.λ[ik] res2.λ[ik] atol=1e-6
end
end

@testset "LOBPCG Internal data structures" begin
a1 = rand(10, 5)
a2 = rand(10, 2)
a3 = rand(10, 7)
b1 = rand(10, 6)
b2 = rand(10, 2)
A = hcat(a1,a2,a3)
B = hcat(b1,b2)
Ablock = DFTK.LazyHcat(a1, a2, a3)
Bblock = DFTK.LazyHcat(b1, b2)
@test Ablock'*Bblock A'*B
@test Ablock'*B A'*B

C = rand(14, 4)
@test Ablock*C A*C

D = rand(10, 4)
@test mul!(D,Ablock, C, 1, 0) A*C
end

0 comments on commit d28391e

Please sign in to comment.