Skip to content
This repository has been archived by the owner on Mar 12, 2021. It is now read-only.

Implement a few sparse array basics #572

Merged
merged 4 commits into from
Apr 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 69 additions & 2 deletions src/sparse/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@ export CuSparseMatrixCSC, CuSparseMatrixCSR,
CuSparseVector

import Base: length, size, ndims, eltype, similar, pointer, stride,
copy, convert, reinterpret, show, summary, copyto!, get!, fill!, collect
copy, convert, reinterpret, show, summary, copyto!, getindex, get!, fill!, collect

using LinearAlgebra
import LinearAlgebra: BlasFloat, Hermitian, HermOrSym, issymmetric, Transpose, Adjoint,
ishermitian, istriu, istril, Symmetric, UpperTriangular, LowerTriangular

using SparseArrays
import SparseArrays: sparse, SparseMatrixCSC
import SparseArrays: sparse, SparseMatrixCSC, nnz, nonzeros, nonzeroinds,
_spgetindex

abstract type AbstractCuSparseArray{Tv, N} <: AbstractSparseArray{Tv, Cint, N} end
const AbstractCuSparseVector{Tv} = AbstractCuSparseArray{Tv,1}
Expand Down Expand Up @@ -166,6 +167,11 @@ function size(g::CuSparseMatrix, d::Integer)
end
end

nnz(g::AbstractCuSparseArray) = g.nnz
nonzeros(g::AbstractCuSparseArray) = g.nzVal

nonzeroinds(g::AbstractCuSparseVector) = g.iPtr

issymmetric(M::Union{CuSparseMatrixCSC,CuSparseMatrixCSR}) = false
ishermitian(M::Union{CuSparseMatrixCSC,CuSparseMatrixCSR}) = false
issymmetric(M::Symmetric{CuSparseMatrixCSC}) = true
Expand All @@ -177,6 +183,67 @@ istriu(M::LowerTriangular{T,S}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix}
istril(M::LowerTriangular{T,S}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix} = true
eltype(g::CuSparseMatrix{T}) where T = T

# getindex (mostly adapted from stdlib/SparseArrays)

# Translations
getindex(A::AbstractCuSparseVector, ::Colon) = copy(A)
getindex(A::AbstractCuSparseMatrix, ::Colon, ::Colon) = copy(A)
getindex(A::AbstractCuSparseMatrix, i, ::Colon) = getindex(A, i, 1:size(A, 2))
getindex(A::AbstractCuSparseMatrix, ::Colon, i) = getindex(A, 1:size(A, 1), i)
getindex(A::AbstractCuSparseMatrix, I::Tuple{Integer,Integer}) = getindex(A, I[1], I[2])

# Column slices
function getindex(x::CuSparseMatrixCSC, ::Colon, j::Integer)
checkbounds(x, :, j)
r1 = convert(Int, x.colPtr[j])
r2 = convert(Int, x.colPtr[j+1]) - 1
CuSparseVector(x.rowVal[r1:r2], x.nzVal[r1:r2], size(x, 1))
end

function getindex(x::CuSparseMatrixCSR, i::Integer, ::Colon)
checkbounds(x, :, i)
c1 = convert(Int, x.rowPtr[i])
c2 = convert(Int, x.rowPtr[i+1]) - 1
CuSparseVector(x.colVal[c1:c2], x.nzVal[c1:c2], size(x, 2))
end

# Row slices
# TODO optimize
getindex(A::CuSparseMatrixCSC, i::Integer, ::Colon) = CuSparseVector(sparse(A[i, 1:end]))
# TODO optimize
getindex(A::CuSparseMatrixCSR, ::Colon, j::Integer) = CuSparseVector(sparse(A[1:end, j]))

function getindex(A::CuSparseMatrixCSC{T}, i0::Integer, i1::Integer) where T
m, n = size(A)
if !(1 <= i0 <= m && 1 <= i1 <= n)
throw(BoundsError())
end
r1 = Int(A.colPtr[i1])
r2 = Int(A.colPtr[i1+1]-1)
(r1 > r2) && return zero(T)
r1 = searchsortedfirst(A.rowVal, i0, r1, r2, Base.Order.Forward)
((r1 > r2) || (A.rowVal[r1] != i0)) ? zero(T) : A.nzVal[r1]
end

function getindex(A::CuSparseMatrixCSR{T}, i0::Integer, i1::Integer) where T
m, n = size(A)
if !(1 <= i0 <= m && 1 <= i1 <= n)
throw(BoundsError())
end
c1 = Int(A.rowPtr[i0])
c2 = Int(A.rowPtr[i0+1]-1)
(c1 > c2) && return zero(T)
c1 = searchsortedfirst(A.colVal, i1, c1, c2, Base.Order.Forward)
((c1 > c2) || (A.colVal[c1] != i1)) ? zero(T) : A.nzVal[c1]
end

# Called for indexing into `CuSparseVector`s
function _spgetindex(m::Integer, nzind::CuVector{Ti}, nzval::CuVector{Tv},
i::Integer) where {Tv,Ti}
ii = searchsortedfirst(nzind, convert(Ti, i))
(ii <= m && nzind[ii] == i) ? nzval[ii] : zero(Tv)
end

function collect(Vec::CuSparseVector)
SparseVector(Vec.dims[1], collect(Vec.iPtr), collect(Vec.nzVal))
end
Expand Down
36 changes: 36 additions & 0 deletions test/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,19 @@ blockdim = 5
@test size(d_x,1) == m
@test size(d_x,2) == 1
@test ndims(d_x) == 1
CuArrays.@allowscalar begin
@test Array(d_x[:]) == x[:]
@test d_x[firstindex(d_x)] == x[firstindex(x)]
@test d_x[div(end, 2)] == x[div(end, 2)]
@test d_x[end] == x[end]
@test Array(d_x[firstindex(d_x):end]) == x[firstindex(x):end]
end
@test_throws BoundsError d_x[firstindex(d_x) - 1]
@test_throws BoundsError d_x[end + 1]
@test nnz(d_x) == nnz(x)
@test Array(nonzeros(d_x)) == nonzeros(x)
@test Array(SparseArrays.nonzeroinds(d_x)) == SparseArrays.nonzeroinds(x)
@test nnz(d_x) == length(nonzeros(d_x))
x = sprand(m,n,0.2)
d_x = CuSparseMatrixCSC(x)
@test length(d_x) == m*n
Expand All @@ -26,6 +39,29 @@ blockdim = 5
@test size(d_x,2) == n
@test size(d_x,3) == 1
@test ndims(d_x) == 2
CuArrays.@allowscalar begin
@test Array(d_x[:]) == x[:]
@test d_x[firstindex(d_x)] == x[firstindex(x)]
@test d_x[div(end, 2)] == x[div(end, 2)]
@test d_x[end] == x[end]
@test d_x[firstindex(d_x), firstindex(d_x)] == x[firstindex(x), firstindex(x)]
@test d_x[div(end, 2), div(end, 2)] == x[div(end, 2), div(end, 2)]
@test d_x[end, end] == x[end, end]
@test Array(d_x[firstindex(d_x):end, firstindex(d_x):end]) == x[:, :]
end
@test_throws BoundsError d_x[firstindex(d_x) - 1]
@test_throws BoundsError d_x[end + 1]
@test_throws BoundsError d_x[firstindex(d_x) - 1, firstindex(d_x) - 1]
@test_throws BoundsError d_x[end + 1, end + 1]
@test_throws BoundsError d_x[firstindex(d_x) - 1:end + 1, :]
@test_throws BoundsError d_x[firstindex(d_x) - 1, :]
@test_throws BoundsError d_x[end + 1, :]
@test_throws BoundsError d_x[:, firstindex(d_x) - 1:end + 1]
@test_throws BoundsError d_x[:, firstindex(d_x) - 1]
@test_throws BoundsError d_x[:, end + 1]
@test nnz(d_x) == nnz(x)
@test Array(nonzeros(d_x)) == nonzeros(x)
@test nnz(d_x) == length(nonzeros(d_x))
@test !issymmetric(d_x)
@test !ishermitian(d_x)
@test_throws ArgumentError size(d_x,0)
Expand Down