Skip to content
This repository has been archived by the owner on Mar 12, 2021. It is now read-only.

adds wrappers for syevjBatched/heevjBatched family of CUSOLVER functions #695

Merged
merged 1 commit into from
Apr 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions src/solver/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -823,3 +823,61 @@ for (jname, bname, fname, elty, relty) in ((:sygvj!, :cusolverDnSsygvj_bufferSiz
end
end
end

for (jname, bname, fname, elty, relty) in ((:syevjBatched!, :cusolverDnSsyevjBatched_bufferSize, :cusolverDnSsyevjBatched, :Float32, :Float32),
(:syevjBatched!, :cusolverDnDsyevjBatched_bufferSize, :cusolverDnDsyevjBatched, :Float64, :Float64),
(:heevjBatched!, :cusolverDnCheevjBatched_bufferSize, :cusolverDnCheevjBatched, :ComplexF32, :Float32),
(:heevjBatched!, :cusolverDnZheevjBatched_bufferSize, :cusolverDnZheevjBatched, :ComplexF64, :Float64)
)
@eval begin
function $jname(jobz::Char,
uplo::Char,
A::CuArray{$elty};
tol::$relty=eps($relty),
max_sweeps::Int=100)

# Set up information for the solver arguments
cuuplo = cublasfill(uplo)
cujobz = cusolverjob(jobz)
n = checksquare(A)
lda = max(1, stride(A, 2))
batchSize = size(A,3)
W = CuArray{$relty}(undef, n,batchSize)
params = Ref{syevjInfo_t}(C_NULL)
devinfo = CuArray{Cint}(undef, batchSize)

# Initialize the solver parameters
cusolverDnCreateSyevjInfo(params)
cusolverDnXsyevjSetTolerance(params[], tol)
cusolverDnXsyevjSetMaxSweeps(params[], max_sweeps)

# Calculate the workspace size
lwork = @argout(CUSOLVER.$bname(dense_handle(), cujobz, cuuplo, n,
A, lda, W, out(Ref{Cint}(0)), params, batchSize))[]

# Run the solver
@workspace eltyp=$elty size=lwork work->begin
$fname(dense_handle(), cujobz, cuuplo, n, A, lda, W, work,
lwork, devinfo, params[], batchSize)
end

# Copy the solver info and delete the device memory
info = @allowscalar collect(devinfo)
unsafe_free!(devinfo)

# Double check the solver's exit status
for i = 1:batchSize
if info[i] < 0
throw(ArgumentError("The $(info)th parameter of the $(i)th solver is wrong"))
end
end

# Return eigenvalues (in W) and possibly eigenvectors (in A)
if jobz == 'N'
return W
elseif jobz == 'V'
return W, A
end
end
end
end
54 changes: 54 additions & 0 deletions test/solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,60 @@ k = 1
@test Eig.values ≈ h_W
end

@testset "elty = $elty" for elty in [Float32, Float64, ComplexF32, ComplexF64]
@testset "syevjBatched!" begin
# Generate a random symmetric/hermitian matrix
A = rand(elty, m,m,n)
A += permutedims(A, (2,1,3))
d_A = CuArray(A)

# Run the solver
local d_W, d_V
if( elty <: Complex )
d_W, d_V = CUSOLVER.heevjBatched!('V','U', d_A)
else
d_W, d_V = CUSOLVER.syevjBatched!('V','U', d_A)
end

# Pull it back to hardware
h_W = collect(d_W)
h_V = collect(d_V)

# Use non-GPU blas to estimate the eigenvalues as well
for i = 1:n
# Get our eigenvalues
Eig = eigen(LinearAlgebra.Hermitian(A[:,:,i]))

# Compare to the actual ones
@test Eig.values ≈ h_W[:,i]
@test abs.(Eig.vectors'*h_V[:,:,i]) ≈ I
end

# Do it all again, but with the option to not compute eigenvectors
d_A = CuArray(A)

# Run the solver
local d_W
if( elty <: Complex )
d_W = CUSOLVER.heevjBatched!('N','U', d_A)
else
d_W = CUSOLVER.syevjBatched!('N','U', d_A)
end

# Pull it back to hardware
h_W = collect(d_W)

# Use non-GPU blas to estimate the eigenvalues as well
for i = 1:n
# Get the reference results
Eig = eigen(LinearAlgebra.Hermitian(A[:,:,i]))

# Compare to the actual ones
@test Eig.values ≈ h_W[:,i]
end
end
end

@testset "svd with $method method" for
method in (CUSOLVER.QRAlgorithm, CUSOLVER.JacobiAlgorithm),
(_m, _n) in ((m, n), (n, m))
Expand Down