Skip to content
This repository has been archived by the owner on Mar 12, 2021. It is now read-only.

Commit

Permalink
Add early frees to CUSOLVER wrappers.
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed Feb 7, 2019
1 parent d40b1ba commit b0b4db9
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 9 deletions.
2 changes: 1 addition & 1 deletion src/solver/CUSOLVER.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import CUDAdrv: CUDAdrv, CuContext, CuStream_t, CuPtr, PtrOrCuPtr, CU_NULL
import CUDAapi

using ..CuArrays
using ..CuArrays: libcusolver, active_context, _getindex
using ..CuArrays: libcusolver, active_context, _getindex, unsafe_free!

using LinearAlgebra
using SparseArrays
Expand Down
81 changes: 73 additions & 8 deletions src/solver/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,19 @@ for (bname, fname,elty) in ((:cusolverDnSpotrf_bufferSize, :cusolverDnSpotrf, :F
CuPtr{$elty}, Cint, Ptr{Cint}),
dense_handle(), cuuplo, n, A, lda, bufSize)

buffer = CuArray{$elty}(undef, bufSize[])
buffer = CuArray{$elty}(undef, bufSize[])
devinfo = CuArray{Cint}(undef, 1)
@check ccall(($(string(fname)), libcusolver), cusolverStatus_t,
(cusolverDnHandle_t, cublasFillMode_t, Cint,
CuPtr{$elty}, Cint, CuPtr{$elty}, Cint, CuPtr{Cint}),
dense_handle(), cuuplo, n, A, lda, buffer,
bufSize[], devinfo)
unsafe_free!(buffer)

info = BlasInt(_getindex(devinfo, 1))
unsafe_free!(devinfo)
chkargsok(info)

A, info
end
end
Expand Down Expand Up @@ -72,8 +76,11 @@ for (fname,elty) in ((:cusolverDnSpotrs, :Float32),
CuPtr{$elty}, Cint, CuPtr{$elty}, Cint, CuPtr{Cint}),
dense_handle(), cuuplo, n, nrhs, A, lda, B,
ldb, devinfo)

info = _getindex(devinfo, 1)
unsafe_free!(devinfo)
chkargsok(BlasInt(info))

B
end
end
Expand Down Expand Up @@ -102,12 +109,16 @@ for (bname, fname,elty) in ((:cusolverDnSgetrf_bufferSize, :cusolverDnSgetrf, :F
Cint, CuPtr{$elty}, CuPtr{Cint}, CuPtr{Cint}),
dense_handle(), m, n, A, lda, buffer,
devipiv, devinfo)
unsafe_free!(buffer)

info = _getindex(devinfo, 1)
unsafe_free!(devinfo)
if info < 0
throw(ArgumentError("The $(info)th parameter is wrong"))
elseif info > 0
throw(LinearAlgebra.SingularException(info))
end

A, devipiv
end
end
Expand All @@ -127,6 +138,7 @@ for (bname, fname,elty) in ((:cusolverDnSgeqrf_bufferSize, :cusolverDnSgeqrf, :F
(cusolverDnHandle_t, Cint, Cint, CuPtr{$elty}, Cint,
Ptr{Cint}), dense_handle(), m, n, A,
lda, bufSize)

buffer = CuArray{$elty}(undef, bufSize[])
tau = CuArray{$elty}(undef, min(m, n))
devinfo = CuArray{Cint}(undef, 1)
Expand All @@ -135,10 +147,14 @@ for (bname, fname,elty) in ((:cusolverDnSgeqrf_bufferSize, :cusolverDnSgeqrf, :F
Cint, CuPtr{$elty}, CuPtr{$elty}, Cint, CuPtr{Cint}),
dense_handle(), m, n, A, lda, tau, buffer,
bufSize[], devinfo)
unsafe_free!(buffer)

info = _getindex(devinfo, 1)
unsafe_free!(devinfo)
if info < 0
throw(ArgumentError("The $(info)th parameter is wrong"))
end

A, tau
end
end
Expand Down Expand Up @@ -169,12 +185,16 @@ for (bname, fname,elty) in ((:cusolverDnSsytrf_bufferSize, :cusolverDnSsytrf, :F
CuPtr{$elty}, Cint, CuPtr{Cint}, CuPtr{$elty}, Cint,
CuPtr{Cint}), dense_handle(), cuuplo, n, A,
lda, devipiv, buffer, bufSize[], devinfo)
unsafe_free!(buffer)

info = _getindex(devinfo, 1)
unsafe_free!(devinfo)
if info < 0
throw(ArgumentError("The $(info)th parameter is wrong"))
elseif info > 0
throw(LinearAlgebra.SingularException(info))
end

A, devipiv
end
end
Expand Down Expand Up @@ -208,7 +228,9 @@ for (fname,elty) in ((:cusolverDnSgetrs, :Float32),
CuPtr{$elty}, Cint, CuPtr{Cint}, CuPtr{$elty}, Cint,
CuPtr{Cint}), dense_handle(), cutrans, n, nrhs,
A, lda, ipiv, B, ldb, devinfo)

info = _getindex(devinfo, 1)
unsafe_free!(devinfo)
if info < 0
throw(ArgumentError("The $(info)th parameter is wrong"))
end
Expand Down Expand Up @@ -246,6 +268,7 @@ for (bname, fname, elty) in ((:cusolverDnSormqr_bufferSize, :cusolverDnSormqr, :
lda = n
end
k = length(tau)

bufSize = Ref{Cint}(0)
@check ccall(($(string(bname)),libcusolver), cusolverStatus_t,
(cusolverDnHandle_t, cublasSideMode_t,
Expand All @@ -254,6 +277,7 @@ for (bname, fname, elty) in ((:cusolverDnSormqr_bufferSize, :cusolverDnSormqr, :
dense_handle(), cuside,
cutrans, m, n, k, A,
lda, tau, C, ldc, bufSize)

buffer = CuArray{$elty}(undef, bufSize[])
devinfo = CuArray{Cint}(undef, 1)
@check ccall(($(string(fname)),libcusolver), cusolverStatus_t,
Expand All @@ -264,10 +288,14 @@ for (bname, fname, elty) in ((:cusolverDnSormqr_bufferSize, :cusolverDnSormqr, :
dense_handle(), cuside,
cutrans, m, n, k, A, lda, tau, C, ldc, buffer,
bufSize[], devinfo)
unsafe_free!(buffer)

info = _getindex(devinfo, 1)
unsafe_free!(devinfo)
if info < 0
throw(ArgumentError("The $(info)th parameter is wrong"))
end

side == 'L' ? C : C[:, 1:minimum(size(A))]
end
end
Expand All @@ -284,22 +312,28 @@ for (bname, fname, elty) in ((:cusolverDnSorgqr_bufferSize, :cusolverDnSorgqr, :
n = min(m, size(A, 2))
lda = max(1, stride(A, 2))
k = length(tau)

bufSize = Ref{Cint}(0)
@check ccall(($(string(bname)), libcusolver), cusolverStatus_t,
(cusolverDnHandle_t, Cint, Cint, Cint, CuPtr{$elty}, Cint,
CuPtr{$elty}, Ptr{Cint}),
dense_handle(), m, n, k, A, lda, tau, bufSize)

buffer = CuArray{$elty}(undef, bufSize[])
devinfo = CuArray{Cint}(undef, 1)
@check ccall(($(string(fname)), libcusolver), cusolverStatus_t,
(cusolverDnHandle_t, Cint, Cint, Cint, CuPtr{$elty},
Cint, CuPtr{$elty}, CuPtr{$elty}, Cint, CuPtr{Cint}),
dense_handle(), m, n, k, A,
lda, tau, buffer, bufSize[], devinfo)
unsafe_free!(buffer)

info = _getindex(devinfo, 1)
unsafe_free!(devinfo)
if info < 0
throw(ArgumentError("The $(info)th parameter is wrong"))
end

if n < size(A, 2)
A[:, 1:n]
else
Expand All @@ -318,10 +352,12 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgebrd_bufferSize, :cusolverDnSg
function gebrd!(A::CuMatrix{$elty})
m, n = size(A)
lda = max(1, stride(A, 2))

bufSize = Ref{Cint}(0)
@check ccall(($(string(bname)), libcusolver), cusolverStatus_t,
(cusolverDnHandle_t, Cint, Cint, Ptr{Cint}),
dense_handle(), m, n, bufSize)

buffer = CuArray{$elty}(undef, bufSize[])
devinfo = CuArray{Cint}(undef, 1)
k = min(m, n)
Expand All @@ -335,10 +371,14 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgebrd_bufferSize, :cusolverDnSg
CuPtr{$elty}, CuPtr{$elty}, Cint, CuPtr{Cint}),
dense_handle(), m, n, A, lda, D, E, TAUQ,
TAUP, buffer, bufSize[], devinfo)
unsafe_free!(buffer)

info = _getindex(devinfo, 1)
unsafe_free!(devinfo)
if info < 0
throw(ArgumentError("The $(info)th parameter is wrong"))
end

A, D, E, TAUQ, TAUP
end
end
Expand All @@ -357,15 +397,12 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgesvd_bufferSize, :cusolverDnSg
throw(ArgumentError("CUSOLVER's gesvd currently requires m >= n"))
end
lda = max(1, stride(A, 2))

lwork = Ref{Cint}(0)
@check ccall(($(string(bname)), libcusolver), cusolverStatus_t,
(cusolverDnHandle_t, Cint, Cint, Ptr{Cint}),
dense_handle(), m, n, lwork)

work = CuArray{$elty}(undef, lwork[])
rwork = CuArray{$relty}(undef, min(m, n) - 1)
devinfo = CuArray{Cint}(undef, 1)

if jobu === 'S' && m > n
U = CuArray{$elty}(undef, m, n)
elseif jobu === 'N'
Expand All @@ -386,6 +423,9 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgesvd_bufferSize, :cusolverDnSg
end
ldvt = max(1, stride(Vt, 2))

work = CuArray{$elty}(undef, lwork[])
rwork = CuArray{$relty}(undef, min(m, n) - 1)
devinfo = CuArray{Cint}(undef, 1)
@check ccall(($(string(fname)), libcusolver), cusolverStatus_t,
(cusolverDnHandle_t, Cuchar, Cuchar, Cint,
Cint, CuPtr{$elty}, Cint, CuPtr{$relty},
Expand All @@ -395,7 +435,11 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgesvd_bufferSize, :cusolverDnSg
n, A, lda, S,
U, ldu, Vt, ldvt,
work, lwork[], rwork, devinfo)
unsafe_free!(work)
unsafe_free!(rwork)

info = _getindex(devinfo, 1)
unsafe_free!(devinfo)
if info < 0
throw(ArgumentError("The $(info)th parameter is wrong"))
end
Expand Down Expand Up @@ -437,12 +481,12 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgesvdj_bufferSize, :cusolverDnS
end
ldv = max(1, stride(V, 2))

lwork = Ref{Cint}(0)
params = Ref{gesvdjInfo_t}(C_NULL)
cusolverDnCreateGesvdjInfo(params)
cusolverDnXgesvdjSetTolerance(params[], tol)
cusolverDnXgesvdjSetMaxSweeps(params[], max_sweeps)

lwork = Ref{Cint}(0)
@check ccall(($(string(bname)), libcusolver), cusolverStatus_t,
(cusolverDnHandle_t, cusolverEigMode_t, Cint, Cint,
Cint, CuPtr{$elty}, Cint, CuPtr{$relty},
Expand All @@ -455,7 +499,6 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgesvdj_bufferSize, :cusolverDnS

work = CuArray{$elty}(undef, lwork[])
devinfo = CuArray{Cint}(undef, 1)

@check ccall(($(string(fname)), libcusolver), cusolverStatus_t,
(cusolverDnHandle_t, cusolverEigMode_t, Cint, Cint,
Cint, CuPtr{$elty}, Cint, CuPtr{$relty},
Expand All @@ -465,11 +508,16 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgesvdj_bufferSize, :cusolverDnS
n, A, lda, S,
U, ldu, V, ldv,
work, lwork[], devinfo, params[])
unsafe_free!(work)

info = _getindex(devinfo, 1)
unsafe_free!(devinfo)
if info < 0
throw(ArgumentError("The $(info)th parameter is wrong"))
end

cusolverDnDestroyGesvdjInfo(params[])

U, S, V
end
end
Expand All @@ -488,6 +536,7 @@ for (jname, bname, fname, elty, relty) in ((:syevd!, :cusolverDnSsyevd_bufferSiz
n = checksquare(A)
lda = max(1, stride(A, 2))
W = CuArray{$relty}(undef, n)

bufSize = Ref{Cint}(0)
@check ccall(($(string(bname)), libcusolver), cusolverStatus_t,
(cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t,
Expand All @@ -501,10 +550,14 @@ for (jname, bname, fname, elty, relty) in ((:syevd!, :cusolverDnSsyevd_bufferSiz
Cint, CuPtr{$elty}, Cint, CuPtr{$relty}, CuPtr{$elty},
Cint, CuPtr{Cint}), dense_handle(), cujobz, cuuplo,
n, A, lda, W, buffer, bufSize[], devinfo)
unsafe_free!(buffer)

info = _getindex(devinfo, 1)
unsafe_free!(devinfo)
if info < 0
throw(ArgumentError("The $(info)th parameter is wrong"))
end

if jobz == 'N'
return W
elseif jobz == 'V'
Expand Down Expand Up @@ -534,6 +587,7 @@ for (jname, bname, fname, elty, relty) in ((:sygvd!, :cusolverDnSsygvd_bufferSiz
lda = max(1, stride(A, 2))
ldb = max(1, stride(B, 2))
W = CuArray{$relty}(undef, n)

bufSize = Ref{Cint}(0)
cuitype = cusolverEigType_t(itype)
@check ccall(($(string(bname)), libcusolver), cusolverStatus_t,
Expand All @@ -550,10 +604,14 @@ for (jname, bname, fname, elty, relty) in ((:sygvd!, :cusolverDnSsygvd_bufferSiz
CuPtr{$elty}, Cint, CuPtr{Cint}), dense_handle(), cuitype,
cujobz, cuuplo, n, A, lda, B, ldb, W, buffer,
bufSize[], devinfo)
unsafe_free!(buffer)

info = _getindex(devinfo, 1)
unsafe_free!(devinfo)
if info < 0
throw(ArgumentError("The $(info)th parameter is wrong"))
end

if jobz == 'N'
return W
elseif jobz == 'V'
Expand Down Expand Up @@ -585,17 +643,19 @@ for (jname, bname, fname, elty, relty) in ((:sygvj!, :cusolverDnSsygvj_bufferSiz
lda = max(1, stride(A, 2))
ldb = max(1, stride(B, 2))
W = CuArray{$relty}(undef, n)
bufSize = Ref{Cint}(0)
params = Ref{syevjInfo_t}(C_NULL)
cusolverDnCreateSyevjInfo(params)
cusolverDnXsyevjSetTolerance(params[], tol)
cusolverDnXsyevjSetMaxSweeps(params[], max_sweeps)

bufSize = Ref{Cint}(0)
@check ccall(($(string(bname)), libcusolver), cusolverStatus_t,
(cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t,
cublasFillMode_t, Cint, CuPtr{$elty}, Cint, CuPtr{$elty},
Cint, CuPtr{$relty}, Ptr{Cint}, syevjInfo_t),
dense_handle(), Cint(itype), cujobz, cuuplo, n, A, lda, B,
ldb, W, bufSize, params[])

buffer = CuArray{$elty}(undef, bufSize[])
devinfo = CuArray{Cint}(undef, 1)
@check ccall(($(string(fname)), libcusolver), cusolverStatus_t,
Expand All @@ -605,11 +665,16 @@ for (jname, bname, fname, elty, relty) in ((:sygvj!, :cusolverDnSsygvj_bufferSiz
syevjInfo_t), dense_handle(), Cint(itype), cujobz,
cuuplo, n, A, lda, B, ldb, W, buffer, bufSize[],
devinfo, params[])
unsafe_free!(buffer)

info = _getindex(devinfo, 1)
unsafe_free!(devinfo)
if info < 0
throw(ArgumentError("The $(info)th parameter is wrong"))
end

cusolverDnDestroySyevjInfo(params[])

if jobz == 'N'
return W
elseif jobz == 'V'
Expand Down

0 comments on commit b0b4db9

Please sign in to comment.