From 9cf27c987d1374cbf5285c8c199b81fd1606bc89 Mon Sep 17 00:00:00 2001 From: Tony Kelman Date: Sun, 7 Feb 2016 02:24:03 -0800 Subject: [PATCH] Fix multiplication, division between sparse and scalar for corner cases when the scalar is zero, Inf, or NaN Returning dense data for sparse input is potentially really bad for performance, but necessary to satisfy full(op(A, B)) == op(full(A), full(B)) --- base/sparse/sparsematrix.jl | 57 ++++++++++++++++++++++++++++++++++--- test/sparsedir/sparse.jl | 12 ++++++++ 2 files changed, 65 insertions(+), 4 deletions(-) diff --git a/base/sparse/sparsematrix.jl b/base/sparse/sparsematrix.jl index 714f92f178d85..d1b38e77e1dce 100644 --- a/base/sparse/sparsematrix.jl +++ b/base/sparse/sparsematrix.jl @@ -1084,18 +1084,67 @@ end # macro (.-)(A::Number, B::SparseMatrixCSC) = A .- full(B) ( -)(A::Array , B::SparseMatrixCSC) = A - full(B) +# multiplication and division by scalars need to be careful about 0, Inf, NaN +# corner cases where we might need to return dense data (as a SparseMatrixCSC +# for type stability) +function densify_with_default(A::SparseMatrixCSC, spvals, defaultvalue) + # return a SparseMatrixCSC C with the same dimensions as A, structural + # nonzero values spvals in the same locations that A has structural + # nonzeros, and nonzero value defaultvalue everywhere else + m, n = size(A) + Arowval = A.rowval + Acolptr = A.colptr + Cnnz = m * n + Cnzval = fill(defaultvalue, Cnnz) + Crowval = similar(Arowval, Cnnz) + Ccolptr = similar(Acolptr) + Ccolptr[1] = 1 + for col = 1:n + Ccolptr[col+1] = 1 + col * m + Crowval[Ccolptr[col] : Ccolptr[col+1]-1] = 1:m + for k in nzrange(A, col) + Cnzval[sub2ind((m, n), Arowval[k], col)] = spvals[k] + end + end + return SparseMatrixCSC(m, n, Ccolptr, Crowval, Cnzval) +end + (.*)(A::AbstractArray, B::AbstractArray) = broadcast_zpreserving(MulFun(), A, B) -(.*)(A::SparseMatrixCSC, B::Number) = SparseMatrixCSC(A.m, A.n, copy(A.colptr), copy(A.rowval), A.nzval .* B) -(.*)(A::Number, B::SparseMatrixCSC) = SparseMatrixCSC(B.m, B.n, copy(B.colptr), copy(B.rowval), A .* B.nzval) +function (.*)(A::SparseMatrixCSC, B::Number) + if isfinite(B) + SparseMatrixCSC(A.m, A.n, copy(A.colptr), copy(A.rowval), A.nzval .* B) + else + densify_with_default(A, A.nzval .* B, zero(eltype(A)) .* B) + end +end +function (.*)(A::Number, B::SparseMatrixCSC) + if isfinite(A) + SparseMatrixCSC(B.m, B.n, copy(B.colptr), copy(B.rowval), A .* B.nzval) + else + densify_with_default(B, A .* B.nzval, A .* zero(eltype(B))) + end +end -(./)(A::SparseMatrixCSC, B::Number) = SparseMatrixCSC(A.m, A.n, copy(A.colptr), copy(A.rowval), A.nzval ./ B) +function (./)(A::SparseMatrixCSC, B::Number) + if B == 0 || isnan(B) + densify_with_default(A, A.nzval ./ B, zero(eltype(A)) ./ B) + else + SparseMatrixCSC(A.m, A.n, copy(A.colptr), copy(A.rowval), A.nzval ./ B) + end +end (./)(A::Number, B::SparseMatrixCSC) = (./)(A, full(B)) (./)(A::SparseMatrixCSC, B::Array) = (./)(full(A), B) (./)(A::Array, B::SparseMatrixCSC) = (./)(A, full(B)) (./)(A::SparseMatrixCSC, B::SparseMatrixCSC) = (./)(full(A), full(B)) (.\)(A::SparseMatrixCSC, B::Number) = (.\)(full(A), B) -(.\)(A::Number, B::SparseMatrixCSC) = SparseMatrixCSC(B.m, B.n, copy(B.colptr), copy(B.rowval), A .\ B.nzval ) +function (.\)(A::Number, B::SparseMatrixCSC) + if A == 0 || isnan(A) + densify_with_default(B, A .\ B.nzval, A .\ zero(eltype(B))) + else + SparseMatrixCSC(B.m, B.n, copy(B.colptr), copy(B.rowval), A .\ B.nzval) + end +end (.\)(A::SparseMatrixCSC, B::Array) = (.\)(full(A), B) (.\)(A::Array, B::SparseMatrixCSC) = (.\)(A, full(B)) (.\)(A::SparseMatrixCSC, B::SparseMatrixCSC) = (.\)(full(A), full(B)) diff --git a/test/sparsedir/sparse.jl b/test/sparsedir/sparse.jl index 71e99e62af391..4bd181d98db2d 100644 --- a/test/sparsedir/sparse.jl +++ b/test/sparsedir/sparse.jl @@ -1200,3 +1200,15 @@ let @test_throws LinAlg.SingularException LowerTriangular(A)\ones(n) @test_throws LinAlg.SingularException UpperTriangular(A)\ones(n) end + +# Inf/NaN corner cases in sparse .* scalar, scalar .* sparse, +# sparse ./ scalar, scalar .\ sparse +for A in (4*speye(5,3), 3*sparse(ones(Int, 4,6)), + SparseMatrixCSC(4, 3, [1,3,5,8], [1,2,2,3,2,3,4], + [0.0, -0.0, -Inf, Inf, NaN, -NaN, 2.0])), + B in (0.0, -0.0, -Inf, Inf, NaN, -NaN, 2.0) + @test_approx_eq_eps full(A .* B) full(A) .* B 0 + @test_approx_eq_eps full(B .* A) B .* full(A) 0 + @test_approx_eq_eps full(A ./ B) full(A) ./ B 0 + @test_approx_eq_eps full(B .\ A) B .\ full(A) 0 +end