diff --git a/lib/cusolver/wrappers.jl b/lib/cusolver/wrappers.jl
index 83d3afc3e3..925ecbcc99 100644
--- a/lib/cusolver/wrappers.jl
+++ b/lib/cusolver/wrappers.jl
@@ -256,9 +256,9 @@ end
 
 using LinearAlgebra
 using LinearAlgebra: BlasInt, checksquare
-using LinearAlgebra.LAPACK: chkargsok
+using LinearAlgebra.LAPACK: chkargsok, chklapackerror
 
-using ..CUBLAS: cublasfill, cublasop, cublasside
+using ..CUBLAS: cublasfill, cublasop, cublasside, unsafe_batch
 
 function cusolverDnCreate()
     handle = Ref{cusolverDnHandle_t}()
@@ -918,3 +918,82 @@ for (jname, bname, fname, elty, relty) in ((:syevjBatched!, :cusolverDnSsyevjBat
         end
     end
 end
+
+for (jname, fname, elty) in ((:potrsBatched!, :cusolverDnSpotrsBatched, :Float32),
+                             (:potrsBatched!, :cusolverDnDpotrsBatched, :Float64),
+                             (:potrsBatched!, :cusolverDnCpotrsBatched, :ComplexF32),
+                             (:potrsBatched!, :cusolverDnZpotrsBatched, :ComplexF64)
+                             )
+    @eval begin
+        # cusolverStatus_t
+        # cusolverDnSpotrsBatched(
+        #     cusolverDnHandle_t handle,
+        #     cublasFillMode_t uplo,
+        #     int n,
+        #     int nrhs,
+        #     float *Aarray[],
+        #     int lda,
+        #     float *Barray[],
+        #     int ldb,
+        #     int *info,
+        #     int batchSize);
+        #
+        # potrsBatched!(uplo, A, B) -> B
+        #
+        # Hand the whole batch to the cuSOLVER routine above in a single call, pairing
+        # `A[i]` with `B[i]`, and return `B`. `uplo` is translated with `cublasfill`.
+        #
+        # Throws `DimensionMismatch` when the batch lengths or the member sizes
+        # disagree, `ArgumentError` when a `B[i]` has more than one column, and
+        # whatever `chklapackerror` raises for the `info` value written by the solver.
+        function $jname(uplo::Char, A::Vector{<:CuMatrix{$elty}}, B::Vector{<:CuVecOrMat{$elty}})
+            batchSize = length(A)
+            if batchSize != length(B)
+                throw(DimensionMismatch("A and B must hold the same number of elements, got $batchSize and $(length(B))"))
+            end
+            cuuplo = cublasfill(uplo)
+            # An empty batch has nothing to solve; returning here also avoids
+            # indexing A[1] below.
+            batchSize == 0 && return B
+
+            # The solver takes one n, lda and ldb for the entire batch, so every
+            # member must agree with the first one, not just A[1] and B[1].
+            n = checksquare(A[1])
+            for i in eachindex(A)
+                if checksquare(A[i]) != n
+                    throw(DimensionMismatch("A[$i] is $(size(A[i], 1))x$(size(A[i], 2)), but A[1] is $(n)x$(n)"))
+                end
+                if size(B[i], 1) != n
+                    throw(DimensionMismatch("first dimension of B[i], $(size(B[i],1)), must match second dimension of A, $n"))
+                end
+                # cuSOLVER's Remark 1: only nrhs=1 is supported.
+                if size(B[i], 2) != 1
+                    throw(ArgumentError("cuSOLVER only supports vectors for B"))
+                end
+            end
+            nrhs = 1
+            lda = max(1, stride(A[1], 2))
+            ldb = max(1, stride(B[1], 2))
+
+            devinfo = CuArray{Cint}(undef, 1)
+            # NOTE(review): assumes `unsafe_batch` returns freshly allocated device
+            # arrays that this call owns and may free -- confirm against CUBLAS.
+            Aptrs = unsafe_batch(A)
+            Bptrs = unsafe_batch(B)
+
+            # Run the solver and read back its status; the temporaries are released
+            # even if the call throws.
+            info = try
+                $fname(dense_handle(), cuuplo, n, nrhs, Aptrs, lda, Bptrs, ldb, devinfo, batchSize)
+                @allowscalar devinfo[1]
+            finally
+                unsafe_free!(Bptrs)
+                unsafe_free!(Aptrs)
+                unsafe_free!(devinfo)
+            end
+            chklapackerror(BlasInt(info))
+
+            return B
+        end
+    end
+end
diff --git a/test/cusolver.jl b/test/cusolver.jl
index af24a7fd02..b3d445f817 100644
--- a/test/cusolver.jl
+++ b/test/cusolver.jl
@@ -426,4 +426,35 @@ k = 1
         @test Array(h_r) ≈ Array(r)
     end
 
+    @testset "potrsBatched!" begin
+        @testset "elty = $elty" for elty in [Float32, Float64, ComplexF32, ComplexF64]
+            @testset "uplo = $uplo" for uplo in ('L', 'U')
+                bA = [rand(elty, m, m) for i in 1:n]
+                bA = [bA[i]*bA[i]' for i in 1:n]
+                bB = [rand(elty, m) for i in 1:n]
+
+                # move to device
+                bd_A = CuArray{elty, 2}[CuArray(a) for a in bA]
+                bd_B = CuArray{elty, 1}[CuArray(b) for b in bB]
+
+                bd_X = CUSOLVER.potrsBatched!(uplo, bd_A, bd_B)
+                bh_X = [collect(bd_X[i]) for i in 1:n]
+
+                # The host reference is given the same inputs as the device call.
+                for i = 1:n
+                    LinearAlgebra.LAPACK.potrs!(uplo, bA[i], bB[i])
+                    @test bB[i] ≈ bh_X[i]
+                end
+            end
+
+            # argument validation: batch length mismatch, matrix-shaped B, empty batch
+            bd_A = CuArray{elty, 2}[CuArray(rand(elty, m, m)) for i in 1:n]
+            bd_B = CuArray{elty, 1}[CuArray(rand(elty, m)) for i in 1:n]
+            @test_throws DimensionMismatch CUSOLVER.potrsBatched!('L', bd_A, bd_B[1:end-1])
+            bd_B2 = CuArray{elty, 2}[CuArray(rand(elty, m, 2)) for i in 1:n]
+            @test_throws ArgumentError CUSOLVER.potrsBatched!('L', bd_A, bd_B2)
+            @test isempty(CUSOLVER.potrsBatched!('L', CuArray{elty, 2}[], CuArray{elty, 1}[]))
+        end
+    end
+
 end