Skip to content

Commit

Permalink
Allow more general eltypes in sparse array multiplication (#33205)
Browse files Browse the repository at this point in the history
  • Loading branch information
pablosanjose authored and fredrikekre committed Sep 20, 2019
1 parent c5cc3f2 commit 34d2b87
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 87 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ Standard library changes

#### SparseArrays

* Products involving sparse arrays now allow more general sparse `eltype`s, such as `StaticArrays` ([#33205])

#### Dates

Expand Down
90 changes: 33 additions & 57 deletions stdlib/SparseArrays/src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,6 @@
import LinearAlgebra: checksquare
using Random: rand!

## sparse matrix multiplication

*(A::AbstractSparseMatrixCSC{TvA,TiA}, B::AbstractSparseMatrixCSC{TvB,TiB}) where {TvA,TiA,TvB,TiB} =
*(sppromote(A, B)...)
*(A::AbstractSparseMatrixCSC{TvA,TiA}, transB::Transpose{<:Any,<:AbstractSparseMatrixCSC{TvB,TiB}}) where {TvA,TiA,TvB,TiB} =
(B = transB.parent; (pA, pB) = sppromote(A, B); *(pA, transpose(pB)))
*(A::AbstractSparseMatrixCSC{TvA,TiA}, adjB::Adjoint{<:Any,<:AbstractSparseMatrixCSC{TvB,TiB}}) where {TvA,TiA,TvB,TiB} =
(B = adjB.parent; (pA, pB) = sppromote(A, B); *(pA, adjoint(pB)))
*(transA::Transpose{<:Any,<:AbstractSparseMatrixCSC{TvA,TiA}}, B::AbstractSparseMatrixCSC{TvB,TiB}) where {TvA,TiA,TvB,TiB} =
(A = transA.parent; (pA, pB) = sppromote(A, B); *(transpose(pA), pB))
*(adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC{TvA,TiA}}, B::AbstractSparseMatrixCSC{TvB,TiB}) where {TvA,TiA,TvB,TiB} =
(A = adjA.parent; (pA, pB) = sppromote(A, B); *(adjoint(pA), pB))
*(transA::Transpose{<:Any,<:AbstractSparseMatrixCSC{TvA,TiA}}, transB::Transpose{<:Any,<:AbstractSparseMatrixCSC{TvB,TiB}}) where {TvA,TiA,TvB,TiB} =
(A = transA.parent; B = transB.parent; (pA, pB) = sppromote(A, B); *(transpose(pA), transpose(pB)))
*(adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC{TvA,TiA}}, adjB::Adjoint{<:Any,<:AbstractSparseMatrixCSC{TvB,TiB}}) where {TvA,TiA,TvB,TiB} =
(A = adjA.parent; B = adjB.parent; (pA, pB) = sppromote(A, B); *(adjoint(pA), adjoint(pB)))

function sppromote(A::AbstractSparseMatrixCSC{TvA,TiA}, B::AbstractSparseMatrixCSC{TvB,TiB}) where {TvA,TiA,TvB,TiB}
Tv = promote_type(TvA, TvB)
Ti = promote_type(TiA, TiB)
A = convert(SparseMatrixCSC{Tv,Ti}, A)
B = convert(SparseMatrixCSC{Tv,Ti}, B)
A, B
end

# In matrix-vector multiplication, the correct orientation of the vector is assumed.
const AdjOrTransStridedMatrix{T} = Union{StridedMatrix{T},Adjoint{<:Any,<:StridedMatrix{T}},Transpose{<:Any,<:StridedMatrix{T}}}

Expand All @@ -50,10 +25,10 @@ function mul!(C::StridedVecOrMat, A::AbstractSparseMatrixCSC, B::Union{StridedVe
end
C
end
*(A::AbstractSparseMatrixCSC{TA,S}, x::StridedVector{Tx}) where {TA,S,Tx} =
(T = promote_op(matprod, TA, Tx); mul!(similar(x, T, size(A, 1)), A, x, one(T), zero(T)))
*(A::AbstractSparseMatrixCSC{TA,S}, B::StridedMatrix{Tx}) where {TA,S,Tx} =
(T = promote_op(matprod, TA, Tx); mul!(similar(B, T, (size(A, 1), size(B, 2))), A, B, one(T), zero(T)))
*(A::AbstractSparseMatrixCSC{TA}, x::StridedVector{Tx}) where {TA,Tx} =
(T = promote_op(matprod, TA, Tx); mul!(similar(x, T, size(A, 1)), A, x, true, false))
*(A::AbstractSparseMatrixCSC{TA}, B::StridedMatrix{Tx}) where {TA,Tx} =
(T = promote_op(matprod, TA, Tx); mul!(similar(B, T, (size(A, 1), size(B, 2))), A, B, true, false))

function mul!(C::StridedVecOrMat, adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC}, B::Union{StridedVector,AdjOrTransStridedMatrix}, α::Number, β::Number)
A = adjA.parent
Expand All @@ -76,10 +51,10 @@ function mul!(C::StridedVecOrMat, adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC}
end
C
end
*(adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC{TA,S}}, x::StridedVector{Tx}) where {TA,S,Tx} =
(T = promote_op(matprod, TA, Tx); mul!(similar(x, T, size(adjA, 1)), adjA, x, one(T), zero(T)))
*(adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC{TA,S}}, B::AdjOrTransStridedMatrix{Tx}) where {TA,S,Tx} =
(T = promote_op(matprod, TA, Tx); mul!(similar(B, T, (size(adjA, 1), size(B, 2))), adjA, B, one(T), zero(T)))
*(adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC}, x::StridedVector{Tx}) where {Tx} =
(T = promote_op(matprod, eltype(adjA), Tx); mul!(similar(x, T, size(adjA, 1)), adjA, x, true, false))
*(adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC}, B::AdjOrTransStridedMatrix) =
(T = promote_op(matprod, eltype(adjA), eltype(B)); mul!(similar(B, T, (size(adjA, 1), size(B, 2))), adjA, B, true, false))

function mul!(C::StridedVecOrMat, transA::Transpose{<:Any,<:AbstractSparseMatrixCSC}, B::Union{StridedVector,AdjOrTransStridedMatrix}, α::Number, β::Number)
A = transA.parent
Expand All @@ -102,19 +77,19 @@ function mul!(C::StridedVecOrMat, transA::Transpose{<:Any,<:AbstractSparseMatrix
end
C
end
*(transA::Transpose{<:Any,<:AbstractSparseMatrixCSC{TA,S}}, x::StridedVector{Tx}) where {TA,S,Tx} =
(T = promote_op(matprod, TA, Tx); mul!(similar(x, T, size(transA, 1)), transA, x, one(T), zero(T)))
*(transA::Transpose{<:Any,<:AbstractSparseMatrixCSC{TA,S}}, B::AdjOrTransStridedMatrix{Tx}) where {TA,S,Tx} =
(T = promote_op(matprod, TA, Tx); mul!(similar(B, T, (size(transA, 1), size(B, 2))), transA, B, one(T), zero(T)))
*(transA::Transpose{<:Any,<:AbstractSparseMatrixCSC}, x::StridedVector{Tx}) where {Tx} =
(T = promote_op(matprod, eltype(transA), Tx); mul!(similar(x, T, size(transA, 1)), transA, x, true, false))
*(transA::Transpose{<:Any,<:AbstractSparseMatrixCSC}, B::AdjOrTransStridedMatrix) =
(T = promote_op(matprod, eltype(transA), eltype(B)); mul!(similar(B, T, (size(transA, 1), size(B, 2))), transA, B, true, false))

# For compatibility with dense multiplication API. Should be deleted when dense multiplication
# API is updated to follow BLAS API.
mul!(C::StridedVecOrMat, A::AbstractSparseMatrixCSC, B::Union{StridedVector,AdjOrTransStridedMatrix}) =
mul!(C, A, B, one(eltype(B)), zero(eltype(C)))
mul!(C, A, B, true, false)
mul!(C::StridedVecOrMat, adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC}, B::Union{StridedVector,AdjOrTransStridedMatrix}) =
mul!(C, adjA, B, one(eltype(B)), zero(eltype(C)))
mul!(C, adjA, B, true, false)
mul!(C::StridedVecOrMat, transA::Transpose{<:Any,<:AbstractSparseMatrixCSC}, B::Union{StridedVector,AdjOrTransStridedMatrix}) =
mul!(C, transA, B, one(eltype(B)), zero(eltype(C)))
mul!(C, transA, B, true, false)

function mul!(C::StridedVecOrMat, X::AdjOrTransStridedMatrix, A::AbstractSparseMatrixCSC, α::Number, β::Number)
mX, nX = size(X)
Expand All @@ -131,8 +106,8 @@ function mul!(C::StridedVecOrMat, X::AdjOrTransStridedMatrix, A::AbstractSparseM
end
C
end
*(X::AdjOrTransStridedMatrix{TX}, A::AbstractSparseMatrixCSC{TvA,TiA}) where {TX,TvA,TiA} =
(T = promote_op(matprod, TX, TvA); mul!(similar(X, T, (size(X, 1), size(A, 2))), X, A, one(T), zero(T)))
*(X::AdjOrTransStridedMatrix, A::AbstractSparseMatrixCSC{TvA}) where {TvA} =
(T = promote_op(matprod, eltype(X), TvA); mul!(similar(X, T, (size(X, 1), size(A, 2))), X, A, true, false))

function mul!(C::StridedVecOrMat, X::AdjOrTransStridedMatrix, adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC}, α::Number, β::Number)
A = adjA.parent
Expand All @@ -150,8 +125,8 @@ function mul!(C::StridedVecOrMat, X::AdjOrTransStridedMatrix, adjA::Adjoint{<:An
end
C
end
*(X::AdjOrTransStridedMatrix{TX}, adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC{TvA,TiA}}) where {TX,TvA,TiA} =
(T = promote_op(matprod, TX, TvA); mul!(similar(X, T, (size(X, 1), size(adjA, 2))), X, adjA, one(T), zero(T)))
*(X::AdjOrTransStridedMatrix, adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC}) =
(T = promote_op(matprod, eltype(X), eltype(adjA)); mul!(similar(X, T, (size(X, 1), size(adjA, 2))), X, adjA, true, false))

function mul!(C::StridedVecOrMat, X::AdjOrTransStridedMatrix, transA::Transpose{<:Any,<:AbstractSparseMatrixCSC}, α::Number, β::Number)
A = transA.parent
Expand All @@ -169,8 +144,8 @@ function mul!(C::StridedVecOrMat, X::AdjOrTransStridedMatrix, transA::Transpose{
end
C
end
*(X::AdjOrTransStridedMatrix{TX}, transA::Transpose{<:Any,<:AbstractSparseMatrixCSC{TvA,TiA}}) where {TX,TvA,TiA} =
(T = promote_op(matprod, TX, TvA); mul!(similar(X, T, (size(X, 1), size(transA, 2))), X, transA, one(T), zero(T)))
*(X::AdjOrTransStridedMatrix, transA::Transpose{<:Any,<:AbstractSparseMatrixCSC}) =
(T = promote_op(matprod, eltype(X), eltype(transA)); mul!(similar(X, T, (size(X, 1), size(transA, 2))), X, transA, true, false))

function (*)(D::Diagonal, A::AbstractSparseMatrixCSC)
T = Base.promote_op(*, eltype(D), eltype(A))
Expand All @@ -184,13 +159,13 @@ end
# Sparse matrix multiplication as described in [Gustavson, 1978]:
# http://dl.acm.org/citation.cfm?id=355796

*(A::AbstractSparseMatrixCSC{Tv,Ti}, B::AbstractSparseMatrixCSC{Tv,Ti}) where {Tv,Ti} = spmatmul(A,B)
*(A::AbstractSparseMatrixCSC{Tv,Ti}, B::Adjoint{<:Any,<:AbstractSparseMatrixCSC{Tv,Ti}}) where {Tv,Ti} = spmatmul(A, copy(B))
*(A::AbstractSparseMatrixCSC{Tv,Ti}, B::Transpose{<:Any,<:AbstractSparseMatrixCSC{Tv,Ti}}) where {Tv,Ti} = spmatmul(A, copy(B))
*(A::Transpose{<:Any,<:AbstractSparseMatrixCSC{Tv,Ti}}, B::AbstractSparseMatrixCSC{Tv,Ti}) where {Tv,Ti} = spmatmul(copy(A), B)
*(A::Adjoint{<:Any,<:AbstractSparseMatrixCSC{Tv,Ti}}, B::AbstractSparseMatrixCSC{Tv,Ti}) where {Tv,Ti} = spmatmul(copy(A), B)
*(A::Adjoint{<:Any,<:AbstractSparseMatrixCSC{Tv,Ti}}, B::Adjoint{<:Any,<:AbstractSparseMatrixCSC{Tv,Ti}}) where {Tv,Ti} = spmatmul(copy(A), copy(B))
*(A::Transpose{<:Any,<:AbstractSparseMatrixCSC{Tv,Ti}}, B::Transpose{<:Any,<:AbstractSparseMatrixCSC{Tv,Ti}}) where {Tv,Ti} = spmatmul(copy(A), copy(B))
*(A::AbstractSparseMatrixCSC, B::AbstractSparseMatrixCSC) = spmatmul(A,B)
*(A::AbstractSparseMatrixCSC, B::Adjoint{<:Any,<:AbstractSparseMatrixCSC}) = spmatmul(A, copy(B))
*(A::AbstractSparseMatrixCSC, B::Transpose{<:Any,<:AbstractSparseMatrixCSC}) = spmatmul(A, copy(B))
*(A::Transpose{<:Any,<:AbstractSparseMatrixCSC}, B::AbstractSparseMatrixCSC) = spmatmul(copy(A), B)
*(A::Adjoint{<:Any,<:AbstractSparseMatrixCSC}, B::AbstractSparseMatrixCSC) = spmatmul(copy(A), B)
*(A::Adjoint{<:Any,<:AbstractSparseMatrixCSC}, B::Adjoint{<:Any,<:AbstractSparseMatrixCSC}) = spmatmul(copy(A), copy(B))
*(A::Transpose{<:Any,<:AbstractSparseMatrixCSC}, B::Transpose{<:Any,<:AbstractSparseMatrixCSC}) = spmatmul(copy(A), copy(B))

# Gustavsen's matrix multiplication algorithm revisited.
# The result rowval vector is already sorted by construction.
Expand All @@ -200,7 +175,9 @@ end
# done by a quicksort of the row indices or by a full scan of the dense result vector.
# The last is faster, if more than ≈ 1/32 of the result column is nonzero.
# TODO: extend to SparseMatrixCSCUnion to allow for SubArrays (view(X, :, r)).
function spmatmul(A::AbstractSparseMatrixCSC{Tv,Ti}, B::AbstractSparseMatrixCSC{Tv,Ti}) where {Tv,Ti}
function spmatmul(A::AbstractSparseMatrixCSC{TvA,TiA}, B::AbstractSparseMatrixCSC{TvB,TiB}) where {TvA,TiA,TvB,TiB}
Tv = promote_op(matprod, TvA, TvB)
Ti = promote_type(TiA, TiB)
mA, nA = size(A)
nB = size(B, 2)
nA == size(B, 1) || throw(DimensionMismatch())
Expand Down Expand Up @@ -371,8 +348,8 @@ end

## triangular sparse handling

possible_adjoint(adj::Bool, a::Real ) = a
possible_adjoint(adj::Bool, a ) = adj ? adjoint(a) : a
possible_adjoint(adj::Bool, a::Real) = a
possible_adjoint(adj::Bool, a) = adj ? adjoint(a) : a

const UnitDiagonalTriangular = Union{UnitUpperTriangular,UnitLowerTriangular}

Expand Down Expand Up @@ -776,7 +753,6 @@ mul!(y::StridedVecOrMat, A::SparseMatrixCSCSymmHerm, x::StridedVecOrMat) = mul!(
# C .= α * A * B + β * C
function mul!(C::StridedVecOrMat{T}, sA::SparseMatrixCSCSymmHerm, B::StridedVecOrMat,
α::Number, β::Number) where T

fuplo = sA.uplo == 'U' ? nzrangeup : nzrangelo
_mul!(fuplo, C, sA, B, T(α), T(β))
end
Expand Down
30 changes: 17 additions & 13 deletions stdlib/SparseArrays/src/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -856,16 +856,17 @@ sparse(I,J,v::Number,m,n,combine::Function) = sparse(I, J, fill(v,length(I)), In
## Transposition and permutation methods

"""
halfperm!(X::AbstractSparseMatrixCSC{Tv,Ti}, A::AbstractSparseMatrixCSC{Tv,Ti},
q::AbstractVector{<:Integer}, f::Function = identity) where {Tv,Ti}
halfperm!(X::AbstractSparseMatrixCSC{Tv,Ti}, A::AbstractSparseMatrixCSC{TvA,Ti},
q::AbstractVector{<:Integer}, f::Function = identity) where {Tv,TvA,Ti}
Column-permute and transpose `A`, simultaneously applying `f` to each entry of `A`, storing
the result `(f(A)Q)^T` (`map(f, transpose(A[:,q]))`) in `X`.
`X`'s dimensions must match those of `transpose(A)` (`size(X, 1) == size(A, 2)` and `size(X, 2) == size(A, 1)`), and `X`
must have enough storage to accommodate all allocated entries in `A` (`length(rowvals(X)) >= nnz(A)`
and `length(nonzeros(X)) >= nnz(A)`). Column-permutation `q`'s length must match `A`'s column
count (`length(q) == size(A, 2)`).
Element type `Tv` of `X` must match `f(::TvA)`, where `TvA` is the element type of `A`.
`X`'s dimensions must match those of `transpose(A)` (`size(X, 1) == size(A, 2)` and
`size(X, 2) == size(A, 1)`), and `X` must have enough storage to accommodate all allocated
entries in `A` (`length(rowvals(X)) >= nnz(A)` and `length(nonzeros(X)) >= nnz(A)`).
Column-permutation `q`'s length must match `A`'s column count (`length(q) == size(A, 2)`).
This method is the parent of several methods performing transposition and permutation
operations on [`SparseMatrixCSC`](@ref)s. As this method performs no argument checking,
Expand All @@ -876,8 +877,8 @@ algorithms for sparse matrices: multiplication and permuted transposition," ACM
250-269 (1978). The algorithm runs in `O(size(A, 1), size(A, 2), nnz(A))` time and requires no space
beyond that passed in.
"""
function halfperm!(X::AbstractSparseMatrixCSC{Tv,Ti}, A::AbstractSparseMatrixCSC{Tv,Ti},
q::AbstractVector{<:Integer}, f::Function = identity) where {Tv,Ti}
function halfperm!(X::AbstractSparseMatrixCSC{Tv,Ti}, A::AbstractSparseMatrixCSC{TvA,Ti},
q::AbstractVector{<:Integer}, f::Function = identity) where {Tv,TvA,Ti}
_computecolptrs_halfperm!(X, A)
_distributevals_halfperm!(X, A, q, f)
return X
Expand All @@ -886,7 +887,7 @@ end
Helper method for `halfperm!`. Computes `transpose(A[:,q])`'s column pointers, storing them
shifted one position forward in `getcolptr(X)`; `_distributevals_halfperm!` fixes this shift.
"""
function _computecolptrs_halfperm!(X::AbstractSparseMatrixCSC{Tv,Ti}, A::AbstractSparseMatrixCSC{Tv,Ti}) where {Tv,Ti}
function _computecolptrs_halfperm!(X::AbstractSparseMatrixCSC{Tv,Ti}, A::AbstractSparseMatrixCSC{TvA,Ti}) where {Tv,TvA,Ti}
# Compute `transpose(A[:,q])`'s column counts. Store shifted forward one position in getcolptr(X).
fill!(getcolptr(X), 0)
@inbounds for k in 1:nnz(A)
Expand All @@ -908,7 +909,7 @@ distributing `rowvals(A)` and `f`-transformed `nonzeros(A)` into `rowvals(X)` an
respectively. Simultaneously fixes the one-position-forward shift in `getcolptr(X)`.
"""
@noinline function _distributevals_halfperm!(X::AbstractSparseMatrixCSC{Tv,Ti},
A::AbstractSparseMatrixCSC{Tv,Ti}, q::AbstractVector{<:Integer}, f::Function) where {Tv,Ti}
A::AbstractSparseMatrixCSC{TvA,Ti}, q::AbstractVector{<:Integer}, f::Function) where {Tv,TvA,Ti}
@inbounds for Xi in 1:size(A, 2)
Aj = q[Xi]
for Ak in nzrange(A, Aj)
Expand Down Expand Up @@ -944,7 +945,8 @@ end
transpose!(X::AbstractSparseMatrixCSC{Tv,Ti}, A::AbstractSparseMatrixCSC{Tv,Ti}) where {Tv,Ti} = ftranspose!(X, A, identity)
adjoint!(X::AbstractSparseMatrixCSC{Tv,Ti}, A::AbstractSparseMatrixCSC{Tv,Ti}) where {Tv,Ti} = ftranspose!(X, A, conj)

function ftranspose(A::AbstractSparseMatrixCSC{Tv,Ti}, f::Function) where {Tv,Ti}
# manually specifying eltype allows to avoid calling return_type of f on TvA
function ftranspose(A::AbstractSparseMatrixCSC{TvA,Ti}, f::Function, eltype::Type{Tv} = TvA) where {Tv,TvA,Ti}
X = SparseMatrixCSC(size(A, 2), size(A, 1),
ones(Ti, size(A, 1)+1),
Vector{Ti}(undef, nnz(A)),
Expand All @@ -953,8 +955,10 @@ function ftranspose(A::AbstractSparseMatrixCSC{Tv,Ti}, f::Function) where {Tv,Ti
end
adjoint(A::AbstractSparseMatrixCSC) = Adjoint(A)
transpose(A::AbstractSparseMatrixCSC) = Transpose(A)
Base.copy(A::Adjoint{<:Any,<:AbstractSparseMatrixCSC}) = ftranspose(A.parent, x -> copy(adjoint(x)))
Base.copy(A::Transpose{<:Any,<:AbstractSparseMatrixCSC}) = ftranspose(A.parent, x -> copy(transpose(x)))
Base.copy(A::Adjoint{<:Any,<:AbstractSparseMatrixCSC}) =
ftranspose(A.parent, x -> adjoint(copy(x)), eltype(A))
Base.copy(A::Transpose{<:Any,<:AbstractSparseMatrixCSC}) =
ftranspose(A.parent, x -> transpose(copy(x)), eltype(A))
function Base.permutedims(A::AbstractSparseMatrixCSC, (a,b))
(a, b) == (2, 1) && return ftranspose(A, identity)
(a, b) == (1, 2) && return copy(A)
Expand Down
Loading

0 comments on commit 34d2b87

Please sign in to comment.