Skip to content

Commit

Permalink
Make use of fast sparse outer products from Julia (closes #4)
Browse files Browse the repository at this point in the history
The feature was added to Julia in time for v1.2 in JuliaLang/julia#24980,
so get rid of the custom `outer()` method here and rewrite `quadprod()`
in terms of just standard matrix methods. Julia v1.2 is the
minimum-supported version at this point, so no need to worry about
backporting the functionality.

In the future, this function may yet still go away since the
implementation is nearly trivial at this point, but that can be a
follow-up PR.
  • Loading branch information
jmert committed Jul 8, 2020
1 parent 23fc8fa commit 203cf1e
Showing 1 changed file with 4 additions and 87 deletions.
91 changes: 4 additions & 87 deletions src/numerics.jl
Original file line number Diff line number Diff line change
@@ -1,100 +1,17 @@
using SparseArrays

"""
Computes the outer product between a given column of a sparse matrix and a vector.
"""
function outer end

"""
outer(A::SparseMatrixCSC, n::Integer, w::AbstractVector)
Performs the equivalent of ``\\vec a_n \\vec w^\\dagger`` where ``\\vec a_n`` is the
column `A[:,n]`.
"""
function outer(A::SparseMatrixCSC{Tv,Ti}, n::Integer, w::AbstractVector{Tv}) where {Tv,Ti}
colptrn = nzrange(A, n)
rowvalA = rowvals(A)
nzvalsA = nonzeros(A)

nnza = length(colptrn)
nnzw = length(w)
numnz = nnza * nnzw

colptr = Vector{Ti}(undef, nnzw+1)
rowval = Vector{Ti}(undef, numnz)
nzvals = Vector{Tv}(undef, numnz)

idx = 0
@inbounds for jj = 1:nnzw
colptr[jj] = idx + 1

wv = conj(w[jj])
iszero(wv) && continue

for ii = colptrn
idx += 1
rowval[idx] = rowvalA[ii] # copy row index from A
nzvals[idx] = wv * nzvalsA[ii] # outer product values
end
end
@inbounds colptr[nnzw+1] = idx + 1
return SparseMatrixCSC(size(A,1), nnzw, colptr, rowval, nzvals)
end

"""
outer(w::AbstractVector, A::SparseMatrixCSC, n::Integer)
Performs the equivalent of ``\\vec w \\vec{a}_n^\\dagger`` where ``\\vec a_n`` is the
column `A[:,n]`.
"""
function outer(w::AbstractVector{Tv}, A::SparseMatrixCSC{Tv,Ti}, n::Integer) where {Tv,Ti}
colptrn = nzrange(A, n)
rowvalA = rowvals(A)
nzvalsA = nonzeros(A)

nnza = length(colptrn)
nnzw = length(w)
numnz = nnza * nnzw

colptr = zeros(Ti, size(A,1)+1)
rowval = Vector{Ti}(undef, numnz)
nzvals = Vector{Tv}(undef, numnz)

idx = 0
@inbounds colptr[1] = 1 # col 1 always at index 1
@inbounds for jj = colptrn
av = conj(nzvalsA[jj])
rv = rowvalA[jj]

for ii = 1:nnzw
wv = w[ii]
iszero(wv) && continue

idx += 1
colptr[rv+1] += 1 # count num of entries in column
rowval[idx] = ii
nzvals[idx] = w[ii] * av # outer product values
end
end
cumsum!(colptr, colptr) # offsets are sum of all previous

return SparseMatrixCSC(nnzw, size(A,1), colptr, rowval, nzvals)
end

"""
quadprod(A, b, n, dir=:col)
Computes the quadratic product ``ABA^T`` efficiently for the case where ``B`` is all zero
Computes the quadratic product ``ABA^\\top`` efficiently for the case where ``B`` is all zero
except for the `n`th column or row vector `b`, for `dir = :col` or `dir = :row`,
respectively.
"""
function quadprod(A, b, n, dir::Symbol=:col)
@inline function quadprod(A, b, n, dir::Symbol=:col)
if dir == :col
w = A * b
return outer(w, A, n)
return (A * sparse(b)) * view(A, :, n)'
elseif dir == :row
w = A * b
return outer(A, n, w)
return view(A, :, n) * (A * sparse(b))'
else
error("Unrecognized direction `dir = $(repr(dir))`.")
end
Expand Down

0 comments on commit 203cf1e

Please sign in to comment.