diff --git a/base/sparse/higherorderfns.jl b/base/sparse/higherorderfns.jl index 858d4527fdb49..e4fa8b7613ea6 100644 --- a/base/sparse/higherorderfns.jl +++ b/base/sparse/higherorderfns.jl @@ -582,7 +582,7 @@ function _broadcast_zeropres!{Tf}(f::Tf, C::SparseVecOrMat, A::SparseVecOrMat, B Bk, stopBk = numcols(B) == 1 ? (colstartind(B, 1), colboundind(B, 1)) : (colstartind(B, j), colboundind(B, j)) Ax = Ak < stopAk ? storedvals(A)[Ak] : zero(eltype(A)) fvAzB = f(Ax, zero(eltype(B))) - if fvAzB == zero(eltype(C)) + if _iszero(fvAzB) # either A's jth column is empty, or A's jth column contains a nonzero value # Ax but f(Ax, zero(eltype(B))) is nonetheless zero, so we can scan through # B's jth column without storing every entry in C's jth column @@ -623,7 +623,7 @@ function _broadcast_zeropres!{Tf}(f::Tf, C::SparseVecOrMat, A::SparseVecOrMat, B Bk, stopBk = numcols(B) == 1 ? (colstartind(B, 1), colboundind(B, 1)) : (colstartind(B, j), colboundind(B, j)) Bx = Bk < stopBk ? storedvals(B)[Bk] : zero(eltype(B)) fzAvB = f(zero(eltype(A)), Bx) - if fzAvB == zero(eltype(C)) + if _iszero(fzAvB) # either B's jth column is empty, or B's jth column contains a nonzero value # Bx but f(zero(eltype(A)), Bx) is nonetheless zero, so we can scan through # A's jth column without storing every entry in C's jth column @@ -701,7 +701,7 @@ function _broadcast_notzeropres!{Tf}(f::Tf, fillvalue, C::SparseVecOrMat, A::Spa Bk, stopBk = numcols(B) == 1 ? (colstartind(B, 1), colboundind(B, 1)) : (colstartind(B, j), colboundind(B, j)) Ax = Ak < stopAk ? storedvals(A)[Ak] : zero(eltype(A)) fvAzB = f(Ax, zero(eltype(B))) - if fvAzB == zero(eltype(C)) + if _iszero(fvAzB) while Bk < stopBk Cx = f(Ax, storedvals(B)[Bk]) Cx != fillvalue && (storedvals(C)[jo + storedinds(B)[Bk]] = Cx) @@ -726,7 +726,7 @@ function _broadcast_notzeropres!{Tf}(f::Tf, fillvalue, C::SparseVecOrMat, A::Spa Bk, stopBk = numcols(B) == 1 ? (colstartind(B, 1), colboundind(B, 1)) : (colstartind(B, j), colboundind(B, j)) Bx = Bk < stopBk ? storedvals(B)[Bk] : zero(eltype(B)) fzAvB = f(zero(eltype(A)), Bx) - if fzAvB == zero(eltype(C)) + if _iszero(fzAvB) while Ak < stopAk Cx = f(storedvals(A)[Ak], Bx) Cx != fillvalue && (storedvals(C)[jo + storedinds(A)[Ak]] = Cx) @@ -771,7 +771,7 @@ function _broadcast_zeropres!{Tf,N}(f::Tf, C::SparseVecOrMat, As::Vararg{SparseV rows = _initrowforcol_all(j, rowsentinel, isemptys, expandsverts, ks, As) defaultCx = f(defargs...) activerow = min(rows...) - if defaultCx == zero(eltype(C)) # zero-preserving column scan + if _iszero(defaultCx) # zero-preserving column scan while activerow < rowsentinel # activerows = _isactiverow_all(activerow, rows) # Cx = f(_gatherbcargs(activerows, defargs, ks, As)...) diff --git a/test/sparse/higherorderfns.jl b/test/sparse/higherorderfns.jl index a8589f205ade6..591b86ad4bdc3 100644 --- a/test/sparse/higherorderfns.jl +++ b/test/sparse/higherorderfns.jl @@ -265,14 +265,24 @@ end end end - @testset "sparse map/broadcast with result eltype not a concrete subtype of Number (#19561/#19589)" begin - intoneorfloatzero(x) = x != 0.0 ? Int(1) : Float64(x) - stringorfloatzero(x) = x != 0.0 ? "Hello" : Float64(x) - @test map(intoneorfloatzero, speye(4)) == sparse(map(intoneorfloatzero, eye(4))) - @test map(stringorfloatzero, speye(4)) == sparse(map(stringorfloatzero, eye(4))) - @test broadcast(intoneorfloatzero, speye(4)) == sparse(broadcast(intoneorfloatzero, eye(4))) - @test broadcast(stringorfloatzero, speye(4)) == sparse(broadcast(stringorfloatzero, eye(4))) + N = 4 + A, fA = speye(N), eye(N) + B, fB = spzeros(1, N), zeros(1, N) + intorfloat_zeropres(xs...) = all(iszero, xs) ? zero(Float64) : Int(1) + stringorfloat_zeropres(xs...) = all(iszero, xs) ? zero(Float64) : "hello" + intorfloat_notzeropres(xs...) = all(iszero, xs) ? Int(1) : zero(Float64) + stringorfloat_notzeropres(xs...) = all(iszero, xs) ? "hello" : zero(Float64) + for fn in (intorfloat_zeropres, intorfloat_notzeropres, + stringorfloat_zeropres, stringorfloat_notzeropres) + @test map(fn, A) == sparse(map(fn, fA)) + @test broadcast(fn, A) == sparse(broadcast(fn, fA)) + @test broadcast(fn, A, B) == sparse(broadcast(fn, fA, fB)) + @test broadcast(fn, B, A) == sparse(broadcast(fn, fB, fA)) + end + for fn in (intorfloat_zeropres, stringorfloat_zeropres) + @test broadcast(fn, A, B, A) == sparse(broadcast(fn, fA, fB, fA)) + end end @testset "broadcast[!] over combinations of scalars and sparse vectors/matrices" begin