diff --git a/src/solver/CUSOLVER.jl b/src/solver/CUSOLVER.jl index 6c2ccb32..a34e5eb4 100644 --- a/src/solver/CUSOLVER.jl +++ b/src/solver/CUSOLVER.jl @@ -4,7 +4,7 @@ import CUDAdrv: CUDAdrv, CuContext, CuStream_t, CuPtr, PtrOrCuPtr, CU_NULL import CUDAapi using ..CuArrays -using ..CuArrays: libcusolver, active_context, _getindex +using ..CuArrays: libcusolver, active_context, _getindex, unsafe_free! using LinearAlgebra using SparseArrays diff --git a/src/solver/dense.jl b/src/solver/dense.jl index abf92fb3..be575612 100644 --- a/src/solver/dense.jl +++ b/src/solver/dense.jl @@ -34,15 +34,19 @@ for (bname, fname,elty) in ((:cusolverDnSpotrf_bufferSize, :cusolverDnSpotrf, :F CuPtr{$elty}, Cint, Ptr{Cint}), dense_handle(), cuuplo, n, A, lda, bufSize) - buffer = CuArray{$elty}(undef, bufSize[]) + buffer = CuArray{$elty}(undef, bufSize[]) devinfo = CuArray{Cint}(undef, 1) @check ccall(($(string(fname)), libcusolver), cusolverStatus_t, (cusolverDnHandle_t, cublasFillMode_t, Cint, CuPtr{$elty}, Cint, CuPtr{$elty}, Cint, CuPtr{Cint}), dense_handle(), cuuplo, n, A, lda, buffer, bufSize[], devinfo) + unsafe_free!(buffer) + info = BlasInt(_getindex(devinfo, 1)) + unsafe_free!(devinfo) chkargsok(info) + A, info end end @@ -72,8 +76,11 @@ for (fname,elty) in ((:cusolverDnSpotrs, :Float32), CuPtr{$elty}, Cint, CuPtr{$elty}, Cint, CuPtr{Cint}), dense_handle(), cuuplo, n, nrhs, A, lda, B, ldb, devinfo) + info = _getindex(devinfo, 1) + unsafe_free!(devinfo) chkargsok(BlasInt(info)) + B end end @@ -102,12 +109,16 @@ for (bname, fname,elty) in ((:cusolverDnSgetrf_bufferSize, :cusolverDnSgetrf, :F Cint, CuPtr{$elty}, CuPtr{Cint}, CuPtr{Cint}), dense_handle(), m, n, A, lda, buffer, devipiv, devinfo) + unsafe_free!(buffer) + info = _getindex(devinfo, 1) + unsafe_free!(devinfo) if info < 0 throw(ArgumentError("The $(info)th parameter is wrong")) elseif info > 0 throw(LinearAlgebra.SingularException(info)) end + A, devipiv end end @@ -127,6 +138,7 @@ for (bname, fname,elty) in ((:cusolverDnSgeqrf_bufferSize, :cusolverDnSgeqrf, :F (cusolverDnHandle_t, Cint, Cint, CuPtr{$elty}, Cint, Ptr{Cint}), dense_handle(), m, n, A, lda, bufSize) + buffer = CuArray{$elty}(undef, bufSize[]) tau = CuArray{$elty}(undef, min(m, n)) devinfo = CuArray{Cint}(undef, 1) @@ -135,10 +147,14 @@ for (bname, fname,elty) in ((:cusolverDnSgeqrf_bufferSize, :cusolverDnSgeqrf, :F Cint, CuPtr{$elty}, CuPtr{$elty}, Cint, CuPtr{Cint}), dense_handle(), m, n, A, lda, tau, buffer, bufSize[], devinfo) + unsafe_free!(buffer) + info = _getindex(devinfo, 1) + unsafe_free!(devinfo) if info < 0 throw(ArgumentError("The $(info)th parameter is wrong")) end + A, tau end end @@ -169,12 +185,16 @@ for (bname, fname,elty) in ((:cusolverDnSsytrf_bufferSize, :cusolverDnSsytrf, :F CuPtr{$elty}, Cint, CuPtr{Cint}, CuPtr{$elty}, Cint, CuPtr{Cint}), dense_handle(), cuuplo, n, A, lda, devipiv, buffer, bufSize[], devinfo) + unsafe_free!(buffer) + info = _getindex(devinfo, 1) + unsafe_free!(devinfo) if info < 0 throw(ArgumentError("The $(info)th parameter is wrong")) elseif info > 0 throw(LinearAlgebra.SingularException(info)) end + A, devipiv end end @@ -208,7 +228,9 @@ for (fname,elty) in ((:cusolverDnSgetrs, :Float32), CuPtr{$elty}, Cint, CuPtr{Cint}, CuPtr{$elty}, Cint, CuPtr{Cint}), dense_handle(), cutrans, n, nrhs, A, lda, ipiv, B, ldb, devinfo) + info = _getindex(devinfo, 1) + unsafe_free!(devinfo) if info < 0 throw(ArgumentError("The $(info)th parameter is wrong")) end @@ -246,6 +268,7 @@ for (bname, fname, elty) in ((:cusolverDnSormqr_bufferSize, :cusolverDnSormqr, : lda = n end k = length(tau) + bufSize = Ref{Cint}(0) @check ccall(($(string(bname)),libcusolver), cusolverStatus_t, (cusolverDnHandle_t, cublasSideMode_t, @@ -254,6 +277,7 @@ for (bname, fname, elty) in ((:cusolverDnSormqr_bufferSize, :cusolverDnSormqr, : dense_handle(), cuside, cutrans, m, n, k, A, lda, tau, C, ldc, bufSize) + buffer = CuArray{$elty}(undef, bufSize[]) devinfo = CuArray{Cint}(undef, 1) @check ccall(($(string(fname)),libcusolver), cusolverStatus_t, @@ -264,10 +288,14 @@ for (bname, fname, elty) in ((:cusolverDnSormqr_bufferSize, :cusolverDnSormqr, : dense_handle(), cuside, cutrans, m, n, k, A, lda, tau, C, ldc, buffer, bufSize[], devinfo) + unsafe_free!(buffer) + info = _getindex(devinfo, 1) + unsafe_free!(devinfo) if info < 0 throw(ArgumentError("The $(info)th parameter is wrong")) end + side == 'L' ? C : C[:, 1:minimum(size(A))] end end @@ -284,11 +312,13 @@ for (bname, fname, elty) in ((:cusolverDnSorgqr_bufferSize, :cusolverDnSorgqr, : n = min(m, size(A, 2)) lda = max(1, stride(A, 2)) k = length(tau) + bufSize = Ref{Cint}(0) @check ccall(($(string(bname)), libcusolver), cusolverStatus_t, (cusolverDnHandle_t, Cint, Cint, Cint, CuPtr{$elty}, Cint, CuPtr{$elty}, Ptr{Cint}), dense_handle(), m, n, k, A, lda, tau, bufSize) + buffer = CuArray{$elty}(undef, bufSize[]) devinfo = CuArray{Cint}(undef, 1) @check ccall(($(string(fname)), libcusolver), cusolverStatus_t, @@ -296,10 +326,14 @@ for (bname, fname, elty) in ((:cusolverDnSorgqr_bufferSize, :cusolverDnSorgqr, : Cint, CuPtr{$elty}, CuPtr{$elty}, Cint, CuPtr{Cint}), dense_handle(), m, n, k, A, lda, tau, buffer, bufSize[], devinfo) + unsafe_free!(buffer) + info = _getindex(devinfo, 1) + unsafe_free!(devinfo) if info < 0 throw(ArgumentError("The $(info)th parameter is wrong")) end + if n < size(A, 2) A[:, 1:n] else @@ -318,10 +352,12 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgebrd_bufferSize, :cusolverDnSg function gebrd!(A::CuMatrix{$elty}) m, n = size(A) lda = max(1, stride(A, 2)) + bufSize = Ref{Cint}(0) @check ccall(($(string(bname)), libcusolver), cusolverStatus_t, (cusolverDnHandle_t, Cint, Cint, Ptr{Cint}), dense_handle(), m, n, bufSize) + buffer = CuArray{$elty}(undef, bufSize[]) devinfo = CuArray{Cint}(undef, 1) k = min(m, n) @@ -335,10 +371,14 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgebrd_bufferSize, :cusolverDnSg CuPtr{$elty}, CuPtr{$elty}, Cint, CuPtr{Cint}), dense_handle(), m, n, A, lda, D, E, TAUQ, TAUP, buffer, bufSize[], devinfo) + unsafe_free!(buffer) + info = _getindex(devinfo, 1) + unsafe_free!(devinfo) if info < 0 throw(ArgumentError("The $(info)th parameter is wrong")) end + A, D, E, TAUQ, TAUP end end @@ -357,15 +397,12 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgesvd_bufferSize, :cusolverDnSg throw(ArgumentError("CUSOLVER's gesvd currently requires m >= n")) end lda = max(1, stride(A, 2)) + lwork = Ref{Cint}(0) @check ccall(($(string(bname)), libcusolver), cusolverStatus_t, (cusolverDnHandle_t, Cint, Cint, Ptr{Cint}), dense_handle(), m, n, lwork) - work = CuArray{$elty}(undef, lwork[]) - rwork = CuArray{$relty}(undef, min(m, n) - 1) - devinfo = CuArray{Cint}(undef, 1) - if jobu === 'S' && m > n U = CuArray{$elty}(undef, m, n) elseif jobu === 'N' @@ -386,6 +423,9 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgesvd_bufferSize, :cusolverDnSg end ldvt = max(1, stride(Vt, 2)) + work = CuArray{$elty}(undef, lwork[]) + rwork = CuArray{$relty}(undef, min(m, n) - 1) + devinfo = CuArray{Cint}(undef, 1) @check ccall(($(string(fname)), libcusolver), cusolverStatus_t, (cusolverDnHandle_t, Cuchar, Cuchar, Cint, Cint, CuPtr{$elty}, Cint, CuPtr{$relty}, @@ -395,7 +435,11 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgesvd_bufferSize, :cusolverDnSg n, A, lda, S, U, ldu, Vt, ldvt, work, lwork[], rwork, devinfo) + unsafe_free!(work) + unsafe_free!(rwork) + info = _getindex(devinfo, 1) + unsafe_free!(devinfo) if info < 0 throw(ArgumentError("The $(info)th parameter is wrong")) end @@ -437,12 +481,12 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgesvdj_bufferSize, :cusolverDnS end ldv = max(1, stride(V, 2)) - lwork = Ref{Cint}(0) params = Ref{gesvdjInfo_t}(C_NULL) cusolverDnCreateGesvdjInfo(params) cusolverDnXgesvdjSetTolerance(params[], tol) cusolverDnXgesvdjSetMaxSweeps(params[], max_sweeps) + lwork = Ref{Cint}(0) @check ccall(($(string(bname)), libcusolver), cusolverStatus_t, (cusolverDnHandle_t, cusolverEigMode_t, Cint, Cint, Cint, CuPtr{$elty}, Cint, CuPtr{$relty}, @@ -455,7 +499,6 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgesvdj_bufferSize, :cusolverDnS work = CuArray{$elty}(undef, lwork[]) devinfo = CuArray{Cint}(undef, 1) - @check ccall(($(string(fname)), libcusolver), cusolverStatus_t, (cusolverDnHandle_t, cusolverEigMode_t, Cint, Cint, Cint, CuPtr{$elty}, Cint, CuPtr{$relty}, @@ -465,11 +508,16 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgesvdj_bufferSize, :cusolverDnS n, A, lda, S, U, ldu, V, ldv, work, lwork[], devinfo, params[]) + unsafe_free!(work) + info = _getindex(devinfo, 1) + unsafe_free!(devinfo) if info < 0 throw(ArgumentError("The $(info)th parameter is wrong")) end + cusolverDnDestroyGesvdjInfo(params[]) + U, S, V end end @@ -488,6 +536,7 @@ for (jname, bname, fname, elty, relty) in ((:syevd!, :cusolverDnSsyevd_bufferSiz n = checksquare(A) lda = max(1, stride(A, 2)) W = CuArray{$relty}(undef, n) + bufSize = Ref{Cint}(0) @check ccall(($(string(bname)), libcusolver), cusolverStatus_t, (cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, @@ -501,10 +550,14 @@ for (jname, bname, fname, elty, relty) in ((:syevd!, :cusolverDnSsyevd_bufferSiz Cint, CuPtr{$elty}, Cint, CuPtr{$relty}, CuPtr{$elty}, Cint, CuPtr{Cint}), dense_handle(), cujobz, cuuplo, n, A, lda, W, buffer, bufSize[], devinfo) + unsafe_free!(buffer) + info = _getindex(devinfo, 1) + unsafe_free!(devinfo) if info < 0 throw(ArgumentError("The $(info)th parameter is wrong")) end + if jobz == 'N' return W elseif jobz == 'V' @@ -534,6 +587,7 @@ for (jname, bname, fname, elty, relty) in ((:sygvd!, :cusolverDnSsygvd_bufferSiz lda = max(1, stride(A, 2)) ldb = max(1, stride(B, 2)) W = CuArray{$relty}(undef, n) + bufSize = Ref{Cint}(0) cuitype = cusolverEigType_t(itype) @check ccall(($(string(bname)), libcusolver), cusolverStatus_t, @@ -550,10 +604,14 @@ for (jname, bname, fname, elty, relty) in ((:sygvd!, :cusolverDnSsygvd_bufferSiz CuPtr{$elty}, Cint, CuPtr{Cint}), dense_handle(), cuitype, cujobz, cuuplo, n, A, lda, B, ldb, W, buffer, bufSize[], devinfo) + unsafe_free!(buffer) + info = _getindex(devinfo, 1) + unsafe_free!(devinfo) if info < 0 throw(ArgumentError("The $(info)th parameter is wrong")) end + if jobz == 'N' return W elseif jobz == 'V' @@ -585,17 +643,19 @@ for (jname, bname, fname, elty, relty) in ((:sygvj!, :cusolverDnSsygvj_bufferSiz lda = max(1, stride(A, 2)) ldb = max(1, stride(B, 2)) W = CuArray{$relty}(undef, n) - bufSize = Ref{Cint}(0) params = Ref{syevjInfo_t}(C_NULL) cusolverDnCreateSyevjInfo(params) cusolverDnXsyevjSetTolerance(params[], tol) cusolverDnXsyevjSetMaxSweeps(params[], max_sweeps) + + bufSize = Ref{Cint}(0) @check ccall(($(string(bname)), libcusolver), cusolverStatus_t, (cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, cublasFillMode_t, Cint, CuPtr{$elty}, Cint, CuPtr{$elty}, Cint, CuPtr{$relty}, Ptr{Cint}, syevjInfo_t), dense_handle(), Cint(itype), cujobz, cuuplo, n, A, lda, B, ldb, W, bufSize, params[]) + buffer = CuArray{$elty}(undef, bufSize[]) devinfo = CuArray{Cint}(undef, 1) @check ccall(($(string(fname)), libcusolver), cusolverStatus_t, @@ -605,11 +665,16 @@ for (jname, bname, fname, elty, relty) in ((:sygvj!, :cusolverDnSsygvj_bufferSiz syevjInfo_t), dense_handle(), Cint(itype), cujobz, cuuplo, n, A, lda, B, ldb, W, buffer, bufSize[], devinfo, params[]) + unsafe_free!(buffer) + info = _getindex(devinfo, 1) + unsafe_free!(devinfo) if info < 0 throw(ArgumentError("The $(info)th parameter is wrong")) end + cusolverDnDestroySyevjInfo(params[]) + if jobz == 'N' return W elseif jobz == 'V'