diff --git a/base/sparse/sparsematrix.jl b/base/sparse/sparsematrix.jl index 714f92f178d85..d1b38e77e1dce 100644 --- a/base/sparse/sparsematrix.jl +++ b/base/sparse/sparsematrix.jl @@ -1084,18 +1084,67 @@ end # macro (.-)(A::Number, B::SparseMatrixCSC) = A .- full(B) ( -)(A::Array , B::SparseMatrixCSC) = A - full(B) +# multiplication and division by scalars need to be careful about 0, Inf, NaN +# corner cases where we might need to return dense data (as a SparseMatrixCSC +# for type stability) +function densify_with_default(A::SparseMatrixCSC, spvals, defaultvalue) + # return a SparseMatrixCSC C with the same dimensions as A, structural + # nonzero values spvals in the same locations that A has structural + # nonzeros, and nonzero value defaultvalue everywhere else + m, n = size(A) + Arowval = A.rowval + Acolptr = A.colptr + Cnnz = m * n + Cnzval = fill(defaultvalue, Cnnz) + Crowval = similar(Arowval, Cnnz) + Ccolptr = similar(Acolptr) + Ccolptr[1] = 1 + for col = 1:n + Ccolptr[col+1] = 1 + col * m + Crowval[Ccolptr[col] : Ccolptr[col+1]-1] = 1:m + for k in nzrange(A, col) + Cnzval[sub2ind((m, n), Arowval[k], col)] = spvals[k] + end + end + return SparseMatrixCSC(m, n, Ccolptr, Crowval, Cnzval) +end + (.*)(A::AbstractArray, B::AbstractArray) = broadcast_zpreserving(MulFun(), A, B) -(.*)(A::SparseMatrixCSC, B::Number) = SparseMatrixCSC(A.m, A.n, copy(A.colptr), copy(A.rowval), A.nzval .* B) -(.*)(A::Number, B::SparseMatrixCSC) = SparseMatrixCSC(B.m, B.n, copy(B.colptr), copy(B.rowval), A .* B.nzval) +function (.*)(A::SparseMatrixCSC, B::Number) + if isfinite(B) + SparseMatrixCSC(A.m, A.n, copy(A.colptr), copy(A.rowval), A.nzval .* B) + else + densify_with_default(A, A.nzval .* B, zero(eltype(A)) .* B) + end +end +function (.*)(A::Number, B::SparseMatrixCSC) + if isfinite(A) + SparseMatrixCSC(B.m, B.n, copy(B.colptr), copy(B.rowval), A .* B.nzval) + else + densify_with_default(B, A .* B.nzval, A .* zero(eltype(B))) + end +end -(./)(A::SparseMatrixCSC, B::Number) = SparseMatrixCSC(A.m, A.n, copy(A.colptr), copy(A.rowval), A.nzval ./ B) +function (./)(A::SparseMatrixCSC, B::Number) + if B == 0 || isnan(B) + densify_with_default(A, A.nzval ./ B, zero(eltype(A)) ./ B) + else + SparseMatrixCSC(A.m, A.n, copy(A.colptr), copy(A.rowval), A.nzval ./ B) + end +end (./)(A::Number, B::SparseMatrixCSC) = (./)(A, full(B)) (./)(A::SparseMatrixCSC, B::Array) = (./)(full(A), B) (./)(A::Array, B::SparseMatrixCSC) = (./)(A, full(B)) (./)(A::SparseMatrixCSC, B::SparseMatrixCSC) = (./)(full(A), full(B)) (.\)(A::SparseMatrixCSC, B::Number) = (.\)(full(A), B) -(.\)(A::Number, B::SparseMatrixCSC) = SparseMatrixCSC(B.m, B.n, copy(B.colptr), copy(B.rowval), A .\ B.nzval ) +function (.\)(A::Number, B::SparseMatrixCSC) + if A == 0 || isnan(A) + densify_with_default(B, A .\ B.nzval, A .\ zero(eltype(B))) + else + SparseMatrixCSC(B.m, B.n, copy(B.colptr), copy(B.rowval), A .\ B.nzval) + end +end (.\)(A::SparseMatrixCSC, B::Array) = (.\)(full(A), B) (.\)(A::Array, B::SparseMatrixCSC) = (.\)(A, full(B)) (.\)(A::SparseMatrixCSC, B::SparseMatrixCSC) = (.\)(full(A), full(B)) diff --git a/test/sparsedir/sparse.jl b/test/sparsedir/sparse.jl index 71e99e62af391..4bd181d98db2d 100644 --- a/test/sparsedir/sparse.jl +++ b/test/sparsedir/sparse.jl @@ -1200,3 +1200,15 @@ let @test_throws LinAlg.SingularException LowerTriangular(A)\ones(n) @test_throws LinAlg.SingularException UpperTriangular(A)\ones(n) end + +# Inf/NaN corner cases in sparse .* scalar, scalar .* sparse, +# sparse ./ scalar, scalar .\ sparse +for A in (4*speye(5,3), 3*sparse(ones(Int, 4,6)), + SparseMatrixCSC(4, 3, [1,3,5,8], [1,2,2,3,2,3,4], + [0.0, -0.0, -Inf, Inf, NaN, -NaN, 2.0])), + B in (0.0, -0.0, -Inf, Inf, NaN, -NaN, 2.0) + @test_approx_eq_eps full(A .* B) full(A) .* B 0 + @test_approx_eq_eps full(B .* A) B .* full(A) 0 + @test_approx_eq_eps full(A ./ B) full(A) ./ B 0 + @test_approx_eq_eps full(B .\ A) B .\ full(A) 0 +end