Skip to content

Commit

Permalink
faster circshift! for SparseMatrixCSC (#30317)
Browse files Browse the repository at this point in the history
* implement circshift! for SparseMatrixCSC

* factor helper function shifter!, implement efficient circshift! for SparseVector

* add some @inbounds for improved performance

* remove allocations completely, giving a large improvement for small matrices

* some renaming to avoid polluting the module namespace

* remove useless reallocation and fix bug with different in/out types, better tests

* avoid action if iszero(r) and/or iszero(c), move sparse vector shifting helpers to sparsevector.jl

* Make shift amounts deterministic in tests, move sparse vector tests into sparsevector.jl

* comment fix

* for some reason, copy!(a::SparseVector, b::SparseVector) does not work
  • Loading branch information
abraunst authored and stevengj committed Dec 25, 2018
1 parent 8eca27d commit 94993e9
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 1 deletion.
2 changes: 1 addition & 1 deletion stdlib/SparseArrays/src/SparseArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import Base: @get!, acos, acosd, acot, acotd, acsch, asech, asin, asind, asinh,
vcat, hcat, hvcat, cat, imag, argmax, kron, length, log, log1p, max, min,
maximum, minimum, one, promote_eltype, real, reshape, rot180,
rotl90, rotr90, round, setindex!, similar, size, transpose,
vec, permute!, map, map!, Array, diff
vec, permute!, map, map!, Array, diff, circshift!, circshift

using Random: GLOBAL_RNG, AbstractRNG, randsubseq, randsubseq!

Expand Down
43 changes: 43 additions & 0 deletions stdlib/SparseArrays/src/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3511,3 +3511,46 @@ end
(+)(A::SparseMatrixCSC, J::UniformScaling) = A + sparse(J, size(A)...)
(-)(A::SparseMatrixCSC, J::UniformScaling) = A - sparse(J, size(A)...)
(-)(J::UniformScaling, A::SparseMatrixCSC) = sparse(J, size(A)...) - A

## circular shift

function circshift!(O::SparseMatrixCSC, X::SparseMatrixCSC, (r,c)::Base.DimsInteger{2})
nnz = length(X.nzval)

iszero(nnz) && return copy!(O, X)

##### column shift
c = mod(c, X.n)
if iszero(c)
copy!(O, X)
else
##### readjust output
resize!(O.colptr, X.n + 1)
resize!(O.rowval, nnz)
resize!(O.nzval, nnz)
O.colptr[X.n + 1] = nnz + 1

# exchange left and right blocks
nleft = X.colptr[X.n - c + 1] - 1
nright = nnz - nleft
@inbounds for i=c+1:X.n
O.colptr[i] = X.colptr[i-c] + nright
end
@inbounds for i=1:c
O.colptr[i] = X.colptr[X.n - c + i] - nleft
end
# rotate rowval and nzval by the right number of elements
circshift!(O.rowval, X.rowval, (nright,))
circshift!(O.nzval, X.nzval, (nright,))
end
##### row shift
r = mod(r, X.m)
iszero(r) && return O
@inbounds for i=1:O.n
subvector_shifter!(O.rowval, O.nzval, O.colptr[i], O.colptr[i+1]-1, O.m, r)
end
return O
end

circshift!(O::SparseMatrixCSC, X::SparseMatrixCSC, (r,)::Base.DimsInteger{1}) = circshift!(O, X, (r,0))
circshift!(O::SparseMatrixCSC, X::SparseMatrixCSC, r::Real) = circshift!(O, X, (Integer(r),0))
39 changes: 39 additions & 0 deletions stdlib/SparseArrays/src/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1975,3 +1975,42 @@ function fill!(A::Union{SparseVector, SparseMatrixCSC}, x)
end
return A
end



# in-place swaps (dense) blocks start:split and split+1:fin in col
function _swap!(col::AbstractVector, start::Integer, fin::Integer, split::Integer)
split == fin && return
reverse!(col, start, split)
reverse!(col, split + 1, fin)
reverse!(col, start, fin)
return
end


# in-place shifts a sparse subvector by r. Used also by sparsematrix.jl
function subvector_shifter!(R::AbstractVector, V::AbstractVector, start::Integer, fin::Integer, m::Integer, r::Integer)
split = fin
@inbounds for j = start:fin
# shift positions ...
R[j] += r
if R[j] <= m
split = j
else
R[j] -= m
end
end
# ...but rowval should be sorted within columns
_swap!(R, start, fin, split)
_swap!(V, start, fin, split)
end


function circshift!(O::SparseVector, X::SparseVector, (r,)::Base.DimsInteger{1})
O .= X
subvector_shifter!(O.nzind, O.nzval, 1, length(O.nzind), O.n, mod(r, X.n))
return O
end


circshift!(O::SparseVector, X::SparseVector, r::Real,) = circshift!(O, X, (Integer(r),))
27 changes: 27 additions & 0 deletions stdlib/SparseArrays/test/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2410,4 +2410,31 @@ end
@test one(A) isa SparseMatrixCSC{Int}
end

@testset "circshift" begin
m,n = 17,15
A = sprand(m, n, 0.5)
for rshift in (-1, 0, 1, 10), cshift in (-1, 0, 1, 10)
shifts = (rshift, cshift)
# using dense circshift to compare
B = circshift(Matrix(A), shifts)
# sparse circshift
C = circshift(A, shifts)
@test C == B
# sparse circshift should not add structural zeros
@test nnz(C) == nnz(A)
# test circshift!
D = similar(A)
circshift!(D, A, shifts)
@test D == B
@test nnz(D) == nnz(A)
# test different in/out types
A2 = floor.(100A)
E1 = spzeros(Int64, m, n)
E2 = spzeros(Int64, m, n)
circshift!(E1, A2, shifts)
circshift!(E2, Matrix(A2), shifts)
@test E1 == E2
end
end

end # module
22 changes: 22 additions & 0 deletions stdlib/SparseArrays/test/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1265,4 +1265,26 @@ end
end
end

@testset "SparseVector circshift" begin
n = 100
v = sprand(n, 0.5)
for shift in (0,-1,1,5,-7,n+10)
x = circshift(Vector(v), shift)
w = circshift(v, shift)
@test nnz(v) == nnz(w)
@test w == x
# test circshift!
v1 = similar(v)
circshift!(v1, v, shift)
@test v1 == x
# test different in/out types
y1 = spzeros(Int64, n)
y2 = spzeros(Int64, n)
v2 = floor.(100v)
circshift!(y1, v2, shift)
circshift!(y2, Vector(v2), shift)
@test y1 == y2
end
end

end # module

0 comments on commit 94993e9

Please sign in to comment.