Skip to content

Commit

Permalink
Try #185:
Browse files Browse the repository at this point in the history
  • Loading branch information
bors[bot] authored May 29, 2020
2 parents 81e74cc + 91cd76f commit 1005058
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 52 deletions.
53 changes: 19 additions & 34 deletions lib/cusparse/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1830,6 +1830,7 @@ for (fname,elty) in ((:cusparseScsrgemm, :Float32),
(:cusparseCcsrgemm, :ComplexF32),
(:cusparseZcsrgemm, :ComplexF64))
@eval begin
# CSR GEMM
function gemm(transa::SparseChar,
transb::SparseChar,
A::CuSparseMatrixCSR{$elty},
Expand All @@ -1841,7 +1842,7 @@ for (fname,elty) in ((:cusparseScsrgemm, :Float32),
cutransb = cusparseop(transb)
cuinda = cusparseindex(indexA)
cuindb = cusparseindex(indexB)
cuindc = cusparseindex(indexB)
cuindc = cusparseindex(indexC)
cudesca = cusparseMatDescr(CUSPARSE_MATRIX_TYPE_GENERAL, CUSPARSE_FILL_MODE_LOWER, CUSPARSE_DIAG_TYPE_NON_UNIT, cuinda)
cudescb = cusparseMatDescr(CUSPARSE_MATRIX_TYPE_GENERAL, CUSPARSE_FILL_MODE_LOWER, CUSPARSE_DIAG_TYPE_NON_UNIT, cuindb)
cudescc = cusparseMatDescr(CUSPARSE_MATRIX_TYPE_GENERAL, CUSPARSE_FILL_MODE_LOWER, CUSPARSE_DIAG_TYPE_NON_UNIT, cuindc)
Expand All @@ -1865,57 +1866,41 @@ for (fname,elty) in ((:cusparseScsrgemm, :Float32),
C.rowPtr, C.colVal)
C
end
end
end

#CSC GEMM
for (fname,elty) in ((:cusparseScsrgemm, :Float32),
(:cusparseDcsrgemm, :Float64),
(:cusparseCcsrgemm, :ComplexF32),
(:cusparseZcsrgemm, :ComplexF64))
@eval begin
# CSC GEMM, these methods are implemented via the CUSPARSE CSR methods
function gemm(transa::SparseChar,
transb::SparseChar,
A::CuSparseMatrixCSC{$elty},
B::CuSparseMatrixCSC{$elty},
indexA::SparseChar,
indexB::SparseChar,
indexC::SparseChar)
ctransa = 'N'
if transa == 'N'
ctransa = 'T'
end
cutransa = cusparseop(ctransa)
ctransb = 'N'
if transb == 'N'
ctransb = 'T'
end
cutransb = cusparseop(ctransb)
cutransa = cusparseop(transa == 'N' ? 'T' : 'N')
cutransb = cusparseop(transb == 'N' ? 'T' : 'N')
cuinda = cusparseindex(indexA)
cuindb = cusparseindex(indexB)
cuindc = cusparseindex(indexB)
cuindc = cusparseindex(indexC)
cudesca = cusparseMatDescr(CUSPARSE_MATRIX_TYPE_GENERAL, CUSPARSE_FILL_MODE_LOWER, CUSPARSE_DIAG_TYPE_NON_UNIT, cuinda)
cudescb = cusparseMatDescr(CUSPARSE_MATRIX_TYPE_GENERAL, CUSPARSE_FILL_MODE_LOWER, CUSPARSE_DIAG_TYPE_NON_UNIT, cuindb)
cudescc = cusparseMatDescr(CUSPARSE_MATRIX_TYPE_GENERAL, CUSPARSE_FILL_MODE_LOWER, CUSPARSE_DIAG_TYPE_NON_UNIT, cuindc)
m,k = ctransa != 'N' ? A.dims : (A.dims[2],A.dims[1])
kB,n = ctransb != 'N' ? B.dims : (B.dims[2],B.dims[1])
m,k = transa == 'N' ? A.dims : (A.dims[2],A.dims[1])
kB,n = transb == 'N' ? B.dims : (B.dims[2],B.dims[1])
if k != kB
throw(DimensionMismatch("Interior dimension of A, $k, and B, $kB, must match"))
end
nnzC = Ref{Cint}(1)
colPtrC = CUDA.zeros(Cint,n + 1)
cusparseXcsrgemmNnz(handle(), cutransa, cutransb,
m, n, k, Ref(cudesca), A.nnz, A.colPtr, A.rowVal,
Ref(cudescb), B.nnz, B.colPtr, B.rowVal, Ref(cudescc),
colPtrC, nnzC)
rowPtrC = CUDA.zeros(Cint,m + 1)
cusparseXcsrgemmNnz(handle(), cutransa, cutransb, m, n, k,
Ref(cudesca), A.nnz, A.colPtr, A.rowVal,
Ref(cudescb), B.nnz, B.colPtr, B.rowVal,
Ref(cudescc), rowPtrC, nnzC)
nnz = nnzC[]
C = CuSparseMatrixCSC(colPtrC, CUDA.zeros(Cint,nnz), CUDA.zeros($elty,nnz), nnz, (m,n))
$fname(handle(), cutransa,
cutransb, m, n, k, Ref(cudesca), A.nnz, A.nzVal,
A.colPtr, A.rowVal, Ref(cudescb), B.nnz, B.nzVal,
B.colPtr, B.rowVal, Ref(cudescc), C.nzVal,
C.colPtr, C.rowVal)
C
C = CuSparseMatrixCSR(rowPtrC, CUDA.zeros(Cint,nnz), CUDA.zeros($elty,nnz), nnz, (m,n))
$fname(handle(), cutransa, cutransb, m, n, k,
Ref(cudesca), A.nnz, A.nzVal, A.colPtr, A.rowVal,
Ref(cudescb), B.nnz, B.nzVal, B.colPtr, B.rowVal,
Ref(cudescc), C.nzVal, C.rowPtr, C.colVal)
return C
end
end
end
Expand Down
40 changes: 22 additions & 18 deletions test/cusparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -923,33 +923,37 @@ end

@testset "gemm" begin
@testset for elty in [Float32,Float64,ComplexF32,ComplexF64]
A = sparse(rand(elty,m,k))
B = sparse(rand(elty,k,n))
C = A * B
d_A = CuSparseMatrixCSR(A)
d_B = CuSparseMatrixCSR(B)
# CSR
A = sparse(rand(elty,m,k)); d_A = CuSparseMatrixCSR(A)
B = sparse(rand(elty,k,n)); d_B = CuSparseMatrixCSR(B)
d_C = CUSPARSE.gemm('N','N',d_A,d_B,'O','O','O')
r_r = collect(d_C.rowPtr)
r_c = collect(d_C.colVal)
r_v = collect(d_C.nzVal)
h_C = collect(d_C)
@test C h_C
@test A * B collect(d_C)
#
@test_throws DimensionMismatch CUSPARSE.gemm('N','T',d_A,d_B,'O','O','O')
@test_throws DimensionMismatch CUSPARSE.gemm('T','T',d_A,d_B,'O','O','O')
@test_throws DimensionMismatch CUSPARSE.gemm('T','N',d_A,d_B,'O','O','O')
@test_throws DimensionMismatch CUSPARSE.gemm('N','N',d_B,d_A,'O','O','O')
#=A = sparse(rand(elty,m,k))
B = sparse(rand(elty,k,n))
d_A = CuSparseMatrixCSC(A)
d_B = CuSparseMatrixCSC(B)
C = A * B
#
A = sparse(rand(elty,m,k)); d_A = CuSparseMatrixCSR(A)
B = sparse(rand(elty,n,k)); d_B = CuSparseMatrixCSR(B)
d_C = CUSPARSE.gemm('N','T',d_A,d_B,'O','O','O')
@test A * transpose(B) collect(d_C)

# CSC
A = sparse(rand(elty,m,k)); d_A = CuSparseMatrixCSC(A)
B = sparse(rand(elty,k,n)); d_B = CuSparseMatrixCSC(B)
d_C = CUSPARSE.gemm('N','N',d_A,d_B,'O','O','O')
h_C = collect(d_C)
@test_approx_eq(C,h_C)
@test A * B collect(d_C)
#
@test_throws(DimensionMismatch,CUSPARSE.gemm('N','T',d_A,d_B,'O','O','O'))
@test_throws(DimensionMismatch,CUSPARSE.gemm('T','T',d_A,d_B,'O','O','O'))
@test_throws(DimensionMismatch,CUSPARSE.gemm('T','N',d_A,d_B,'O','O','O'))
@test_throws(DimensionMismatch,CUSPARSE.gemm('N','N',d_B,d_A,'O','O','O'))=#
@test_throws(DimensionMismatch,CUSPARSE.gemm('N','N',d_B,d_A,'O','O','O'))
#
A = sparse(rand(elty,m,k)); d_A = CuSparseMatrixCSC(A)
B = sparse(rand(elty,n,k)); d_B = CuSparseMatrixCSC(B)
d_C = CUSPARSE.gemm('N','T',d_A,d_B,'O','O','O')
@test A * transpose(B) collect(d_C)
end
end

Expand Down

0 comments on commit 1005058

Please sign in to comment.