Skip to content

Commit

Permalink
Reintroduce chkuplo
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison committed Mar 16, 2023
1 parent 523682f commit 9b7aff3
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions lib/cusolver/dense.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@

using LinearAlgebra
using LinearAlgebra: BlasInt, checksquare
using LinearAlgebra.BLAS: chkuplo
using LinearAlgebra.LAPACK: chkargsok, chklapackerror, chktrans, chkside, chkdiag

using ..CUBLAS: unsafe_batch
Expand All @@ -20,6 +21,7 @@ for (bname, fname,elty) in ((:cusolverDnSpotrf_bufferSize, :cusolverDnSpotrf, :F
@eval begin
function potrf!(uplo::Char,
A::StridedCuMatrix{$elty})
VERSION >= v"1.8" && chkuplo(uplo)
n = checksquare(A)
lda = max(1, stride(A, 2))

Expand Down Expand Up @@ -52,6 +54,7 @@ for (fname,elty) in ((:cusolverDnSpotrs, :Float32),
function potrs!(uplo::Char,
A::StridedCuMatrix{$elty},
B::StridedCuVecOrMat{$elty})
VERSION >= v"1.8" && chkuplo(uplo)
n = checksquare(A)
if size(B, 1) != n
throw(DimensionMismatch("first dimension of B, $(size(B,1)), must match second dimension of A, $n"))
Expand Down Expand Up @@ -79,6 +82,7 @@ for (bname, fname,elty) in ((:cusolverDnSpotri_bufferSize, :cusolverDnSpotri, :F
@eval begin
function potri!(uplo::Char,
A::StridedCuMatrix{$elty})
VERSION >= v"1.8" && chkuplo(uplo)
n = checksquare(A)
lda = max(1, stride(A, 2))

Expand Down Expand Up @@ -172,6 +176,7 @@ for (bname, fname,elty) in ((:cusolverDnSsytrf_bufferSize, :cusolverDnSsytrf, :F
@eval begin
function sytrf!(uplo::Char,
A::StridedCuMatrix{$elty})
VERSION >= v"1.8" && chkuplo(uplo)
n = checksquare(A)
lda = max(1, stride(A, 2))

Expand Down Expand Up @@ -487,6 +492,7 @@ for (jname, bname, fname, elty, relty) in ((:syevd!, :cusolverDnSsyevd_bufferSiz
function $jname(jobz::Char,
uplo::Char,
A::StridedCuMatrix{$elty})
VERSION >= v"1.8" && chkuplo(uplo)
n = checksquare(A)
lda = max(1, stride(A, 2))
W = CuArray{$relty}(undef, n)
Expand Down Expand Up @@ -526,6 +532,7 @@ for (jname, bname, fname, elty, relty) in ((:sygvd!, :cusolverDnSsygvd_bufferSiz
uplo::Char,
A::StridedCuMatrix{$elty},
B::StridedCuMatrix{$elty})
VERSION >= v"1.8" && chkuplo(uplo)
nA, nB = checksquare(A, B)
if nB != nA
throw(DimensionMismatch("Dimensions of A ($nA, $nA) and B ($nB, $nB) must match!"))
Expand Down Expand Up @@ -572,6 +579,7 @@ for (jname, bname, fname, elty, relty) in ((:sygvj!, :cusolverDnSsygvj_bufferSiz
B::StridedCuMatrix{$elty};
tol::$relty=eps($relty),
max_sweeps::Int=100)
VERSION >= v"1.8" && chkuplo(uplo)
nA, nB = checksquare(A, B)
if nB != nA
throw(DimensionMismatch("Dimensions of A ($nA, $nA) and B ($nB, $nB) must match!"))
Expand Down Expand Up @@ -625,6 +633,7 @@ for (jname, bname, fname, elty, relty) in ((:syevjBatched!, :cusolverDnSsyevjBat
max_sweeps::Int=100)

# Set up information for the solver arguments
VERSION >= v"1.8" && chkuplo(uplo)
n = checksquare(A)
lda = max(1, stride(A, 2))
batchSize = size(A,3)
Expand Down Expand Up @@ -681,6 +690,7 @@ for (fname, elty) in ((:cusolverDnSpotrsBatched, :Float32),
throw(DimensionMismatch(""))
end
# Set up information for the solver arguments
VERSION >= v"1.8" && chkuplo(uplo)
n = checksquare(A[1])
if size(B[1], 1) != n
throw(DimensionMismatch("first dimension of B[i], $(size(B[1],1)), must match second dimension of A, $n"))
Expand Down Expand Up @@ -719,6 +729,7 @@ for (fname, elty) in ((:cusolverDnSpotrfBatched, :Float32),
function potrfBatched!(uplo::Char, A::Vector{<:StridedCuMatrix{$elty}})

# Set up information for the solver arguments
VERSION >= v"1.8" && chkuplo(uplo)
n = checksquare(A[1])
lda = max(1, stride(A[1], 2))
batchSize = length(A)
Expand Down

0 comments on commit 9b7aff3

Please sign in to comment.