Skip to content

Commit

Permalink
Fix sparse broadcast[!] for some cases where the output eltype is not…
Browse files Browse the repository at this point in the history
… a concrete subtype of Number (JuliaLang#19561, later part).
  • Loading branch information
Sacha0 committed Mar 2, 2017
1 parent 2bc66e1 commit 3dc07d6
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 12 deletions.
10 changes: 5 additions & 5 deletions base/sparse/higherorderfns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,7 @@ function _broadcast_zeropres!{Tf}(f::Tf, C::SparseVecOrMat, A::SparseVecOrMat, B
Bk, stopBk = numcols(B) == 1 ? (colstartind(B, 1), colboundind(B, 1)) : (colstartind(B, j), colboundind(B, j))
Ax = Ak < stopAk ? storedvals(A)[Ak] : zero(eltype(A))
fvAzB = f(Ax, zero(eltype(B)))
if fvAzB == zero(eltype(C))
if _iszero(fvAzB)
# either A's jth column is empty, or A's jth column contains a nonzero value
# Ax but f(Ax, zero(eltype(B))) is nonetheless zero, so we can scan through
# B's jth column without storing every entry in C's jth column
Expand Down Expand Up @@ -623,7 +623,7 @@ function _broadcast_zeropres!{Tf}(f::Tf, C::SparseVecOrMat, A::SparseVecOrMat, B
Bk, stopBk = numcols(B) == 1 ? (colstartind(B, 1), colboundind(B, 1)) : (colstartind(B, j), colboundind(B, j))
Bx = Bk < stopBk ? storedvals(B)[Bk] : zero(eltype(B))
fzAvB = f(zero(eltype(A)), Bx)
if fzAvB == zero(eltype(C))
if _iszero(fzAvB)
# either B's jth column is empty, or B's jth column contains a nonzero value
# Bx but f(zero(eltype(A)), Bx) is nonetheless zero, so we can scan through
# A's jth column without storing every entry in C's jth column
Expand Down Expand Up @@ -701,7 +701,7 @@ function _broadcast_notzeropres!{Tf}(f::Tf, fillvalue, C::SparseVecOrMat, A::Spa
Bk, stopBk = numcols(B) == 1 ? (colstartind(B, 1), colboundind(B, 1)) : (colstartind(B, j), colboundind(B, j))
Ax = Ak < stopAk ? storedvals(A)[Ak] : zero(eltype(A))
fvAzB = f(Ax, zero(eltype(B)))
if fvAzB == zero(eltype(C))
if _iszero(fvAzB)
while Bk < stopBk
Cx = f(Ax, storedvals(B)[Bk])
Cx != fillvalue && (storedvals(C)[jo + storedinds(B)[Bk]] = Cx)
Expand All @@ -726,7 +726,7 @@ function _broadcast_notzeropres!{Tf}(f::Tf, fillvalue, C::SparseVecOrMat, A::Spa
Bk, stopBk = numcols(B) == 1 ? (colstartind(B, 1), colboundind(B, 1)) : (colstartind(B, j), colboundind(B, j))
Bx = Bk < stopBk ? storedvals(B)[Bk] : zero(eltype(B))
fzAvB = f(zero(eltype(A)), Bx)
if fzAvB == zero(eltype(C))
if _iszero(fzAvB)
while Ak < stopAk
Cx = f(storedvals(A)[Ak], Bx)
Cx != fillvalue && (storedvals(C)[jo + storedinds(A)[Ak]] = Cx)
Expand Down Expand Up @@ -771,7 +771,7 @@ function _broadcast_zeropres!{Tf,N}(f::Tf, C::SparseVecOrMat, As::Vararg{SparseV
rows = _initrowforcol_all(j, rowsentinel, isemptys, expandsverts, ks, As)
defaultCx = f(defargs...)
activerow = min(rows...)
if defaultCx == zero(eltype(C)) # zero-preserving column scan
if _iszero(defaultCx) # zero-preserving column scan
while activerow < rowsentinel
# activerows = _isactiverow_all(activerow, rows)
# Cx = f(_gatherbcargs(activerows, defargs, ks, As)...)
Expand Down
24 changes: 17 additions & 7 deletions test/sparse/higherorderfns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -265,14 +265,24 @@ end
end
end


@testset "sparse map/broadcast with result eltype not a concrete subtype of Number (#19561/#19589)" begin
intoneorfloatzero(x) = x != 0.0 ? Int(1) : Float64(x)
stringorfloatzero(x) = x != 0.0 ? "Hello" : Float64(x)
@test map(intoneorfloatzero, speye(4)) == sparse(map(intoneorfloatzero, eye(4)))
@test map(stringorfloatzero, speye(4)) == sparse(map(stringorfloatzero, eye(4)))
@test broadcast(intoneorfloatzero, speye(4)) == sparse(broadcast(intoneorfloatzero, eye(4)))
@test broadcast(stringorfloatzero, speye(4)) == sparse(broadcast(stringorfloatzero, eye(4)))
N = 4
A, fA = speye(N), eye(N)
B, fB = spzeros(1, N), zeros(1, N)
intorfloat_zeropres(xs...) = all(iszero, xs) ? zero(Float64) : Int(1)
stringorfloat_zeropres(xs...) = all(iszero, xs) ? zero(Float64) : "hello"
intorfloat_notzeropres(xs...) = all(iszero, xs) ? Int(1) : zero(Float64)
stringorfloat_notzeropres(xs...) = all(iszero, xs) ? "hello" : zero(Float64)
for fn in (intorfloat_zeropres, intorfloat_notzeropres,
stringorfloat_zeropres, stringorfloat_notzeropres)
@test map(fn, A) == sparse(map(fn, fA))
@test broadcast(fn, A) == sparse(broadcast(fn, fA))
@test broadcast(fn, A, B) == sparse(broadcast(fn, fA, fB))
@test broadcast(fn, B, A) == sparse(broadcast(fn, fB, fA))
end
for fn in (intorfloat_zeropres, stringorfloat_zeropres)
@test broadcast(fn, A, B, A) == sparse(broadcast(fn, fA, fB, fA))
end
end

@testset "broadcast[!] over combinations of scalars and sparse vectors/matrices" begin
Expand Down

0 comments on commit 3dc07d6

Please sign in to comment.