Skip to content

Commit

Permalink
Try #193:
Browse files Browse the repository at this point in the history
  • Loading branch information
bors[bot] authored Jun 3, 2020
2 parents 8ee1a6f + 5e35a9d commit bba4537
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 2 deletions.
60 changes: 58 additions & 2 deletions lib/cusolver/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,9 @@ end

using LinearAlgebra
using LinearAlgebra: BlasInt, checksquare
using LinearAlgebra.LAPACK: chkargsok
using LinearAlgebra.LAPACK: chkargsok, chklapackerror

using ..CUBLAS: cublasfill, cublasop, cublasside
using ..CUBLAS: cublasfill, cublasop, cublasside, unsafe_batch

function cusolverDnCreate()
handle = Ref{cusolverDnHandle_t}()
Expand Down Expand Up @@ -918,3 +918,59 @@ for (jname, bname, fname, elty, relty) in ((:syevjBatched!, :cusolverDnSsyevjBat
end
end
end

for (jname, fname, elty) in ((:potrsBatched!, :cusolverDnSpotrsBatched, :Float32),
(:potrsBatched!, :cusolverDnDpotrsBatched, :Float64),
(:potrsBatched!, :cusolverDnCpotrsBatched, :ComplexF32),
(:potrsBatched!, :cusolverDnZpotrsBatched, :ComplexF64)
)
@eval begin
# cusolverStatus_t
# cusolverDnSpotrsBatched(
# cusolverDnHandle_t handle,
# cublasFillMode_t uplo,
# int n,
# int nrhs,
# float *Aarray[],
# int lda,
# float *Barray[],
# int ldb,
# int *info,
# int batchSize);
function $jname(uplo::Char, A::Vector{<:CuMatrix{$elty}}, B::Vector{<:CuVecOrMat{$elty}})
if length(A) != length(B)
throw(DimensionMismatch(""))
end
# Set up information for the solver arguments
cuuplo = cublasfill(uplo)
n = checksquare(A[1])
if size(B[1], 1) != n
throw(DimensionMismatch("first dimension of B[i], $(size(B[1],1)), must match second dimension of A, $n"))
end
nrhs = size(B[1], 2)
# cuSOLVER's Remark 1: only nrhs=1 is supported.
if nrhs != 1
throw(ArgumentError("cuSOLVER only supports vectors for B"))
end
lda = max(1, stride(A[1], 2))
ldb = max(1, stride(B[1], 2))
batchSize = length(A)
devinfo = CuArray{Cint}(undef, 1)

Aptrs = unsafe_batch(A)
Bptrs = unsafe_batch(B)

# Run the solver
$fname(dense_handle(), cuuplo, n, nrhs, Aptrs, lda, Bptrs, ldb, devinfo, batchSize)

# Copy the solver info and delete the device memory
info = @allowscalar devinfo[1]
unsafe_free!(devinfo)
if info < 0
throw(ArgumentError("The $(info)th parameter is wrong"))
end

return B
end
end
end
46 changes: 46 additions & 0 deletions test/cusolver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -426,4 +426,50 @@ k = 1
@test Array(h_r) Array(r)
end

@testset "potrsBatched!" begin
@testset "elty = $elty" for elty in [Float32, Float64, ComplexF32, ComplexF64]
# Test lower
bA = [rand(elty, m, m) for i in 1:n]
bA = [bA[i]*bA[i]' for i in 1:n]
bB = [rand(elty, m) for i in 1:n]

# move to device
bd_A = CuArray{elty, 2}[]
bd_B = CuArray{elty, 1}[]
for i in 1:length(bA)
push!(bd_A, CuArray(bA[i]))
push!(bd_B, CuArray(bB[i]))
end

bd_X = CUSOLVER.potrsBatched!('L', bd_A, bd_B)
bh_X = [collect(bd_X[i]) for i in 1:n]

for i = 1:n
LinearAlgebra.LAPACK.potrs!('L', bA[i], bB[i])
@test bB[i] bh_X[i]
end

# Test upper
bA = [rand(elty, m, m) for i in 1:n]
bA = [bA[i]*bA[i]' for i in 1:n]
bB = [rand(elty, m) for i in 1:n]

# move to device
bd_A = CuArray{elty, 2}[]
bd_B = CuArray{elty, 1}[]
for i in 1:length(bA)
push!(bd_A, CuArray(bA[i]))
push!(bd_B, CuArray(bB[i]))
end

bd_X = CUSOLVER.potrsBatched!('U', bd_A, bd_B)
bh_X = [collect(bd_X[i]) for i in 1:n]

for i = 1:n
LinearAlgebra.LAPACK.potrs!('U', bA[i], bB[i])
@test bB[i] bh_X[i]
end
end
end

end

0 comments on commit bba4537

Please sign in to comment.