Skip to content

Commit

Permalink
Merge pull request #19724 from Sacha0/mixedbc
Browse files Browse the repository at this point in the history
broadcast[!] over combinations of scalars and sparse vectors/matrices
  • Loading branch information
Sacha0 authored Jan 9, 2017
2 parents 7c34d69 + ce545a6 commit 1494f43
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 12 deletions.
4 changes: 3 additions & 1 deletion base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,9 @@ Note that `dest` is only used to store the result, and does not supply
arguments to `f` unless it is also listed in the `As`,
as in `broadcast!(f, A, A, B)` to perform `A[:] = broadcast(f, A, B)`.
"""
@inline function broadcast!{N}(f, C::AbstractArray, A, Bs::Vararg{Any,N})
@inline broadcast!{N}(f, C::AbstractArray, A, Bs::Vararg{Any,N}) =
broadcast_c!(f, containertype(C, A, Bs...), C, A, Bs...)
@inline function broadcast_c!{N}(f, ::Type, C::AbstractArray, A, Bs::Vararg{Any,N})
shape = indices(C)
@boundscheck check_broadcast_indices(shape, A, Bs...)
keeps, Idefaults = map_newindexer(shape, A, Bs)
Expand Down
61 changes: 50 additions & 11 deletions base/sparse/higherorderfns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ module HigherOrderFns
# This module provides higher order functions specialized for sparse arrays,
# particularly map[!]/broadcast[!] for SparseVectors and SparseMatrixCSCs at present.
import Base: map, map!, broadcast, broadcast!
import Base.Broadcast: containertype, promote_containertype,
broadcast_indices, broadcast_c, broadcast_c!

using Base: tail, to_shape
using ..SparseArrays: SparseVector, SparseMatrixCSC, indtype
using Base: front, tail, to_shape
using ..SparseArrays: SparseVector, SparseMatrixCSC, AbstractSparseArray, indtype

# This module is organized as follows:
# (1) Define a common interface to SparseVectors and SparseMatrixCSCs sufficient for
Expand Down Expand Up @@ -837,15 +839,52 @@ end


# (9) broadcast[!] over combinations of broadcast scalars and sparse vectors/matrices
#
# TODO: The minimal snippet below is not satisfying: A better solution would achieve
# the same for (1) all broadcast scalar types (Base.Broadcast.containertype(x) == Any?) and
# (2) any combination (number, order, type mixture) of broadcast scalars.
#
broadcast{Tf}(f::Tf, x::Union{Number,Bool}, A::SparseMatrixCSC) = broadcast(y -> f(x, y), A)
broadcast{Tf}(f::Tf, A::SparseMatrixCSC, y::Union{Number,Bool}) = broadcast(x -> f(x, y), A)
# NOTE: The following two method definitions work around #19096. These definitions should
# be folded into the two preceding definitions on resolution of #19096.

# broadcast shape promotion for combinations of sparse arrays and other types
broadcast_indices(::Type{AbstractSparseArray}, A) = indices(A)
# broadcast container type promotion for combinations of sparse arrays and other types
containertype{T<:SparseVecOrMat}(::Type{T}) = AbstractSparseArray
# combinations of sparse arrays with broadcast scalars should yield sparse arrays
promote_containertype(::Type{Any}, ::Type{AbstractSparseArray}) = AbstractSparseArray
promote_containertype(::Type{AbstractSparseArray}, ::Type{Any}) = AbstractSparseArray
# combinations of sparse arrays with anything else should fall back to generic dense broadcast
promote_containertype(::Type{Array}, ::Type{AbstractSparseArray}) = Array
promote_containertype(::Type{Tuple}, ::Type{AbstractSparseArray}) = Array
promote_containertype(::Type{AbstractSparseArray}, ::Type{Array}) = Array
promote_containertype(::Type{AbstractSparseArray}, ::Type{Tuple}) = Array

# broadcast[!] entry points for combinations of sparse arrays and other types
@inline function broadcast_c{N}(f, ::Type{AbstractSparseArray}, mixedargs::Vararg{Any,N})
parevalf, passedargstup = capturescalars(f, mixedargs)
return broadcast(parevalf, passedargstup...)
end
@inline function broadcast_c!{N}(f, ::Type{AbstractSparseArray}, dest::SparseVecOrMat, mixedsrcargs::Vararg{Any,N})
parevalf, passedsrcargstup = capturescalars(f, mixedsrcargs)
return broadcast!(parevalf, dest, passedsrcargstup...)
end
# capturescalars takes a function (f) and a tuple of mixed sparse vectors/matrices and
# broadcast scalar arguments (mixedargs), and returns a function (parevalf, i.e. partially
# evaluated f) and a reduced argument tuple (passedargstup) containing only the sparse
# vectors/matrices in mixedargs in their orginal order, and such that the result of
# broadcast(parevalf, passedargstup...) is broadcast(f, mixedargs...)
@inline capturescalars(f, mixedargs) =
capturescalars((passed, tofill) -> f(tofill...), (), mixedargs...)
# Recursion cases for capturescalars
@inline capturescalars(f, passedargstup, scalararg, mixedargs...) =
capturescalars(capturescalar(f, scalararg), passedargstup, mixedargs...)
@inline capturescalars(f, passedargstup, nonscalararg::SparseVecOrMat, mixedargs...) =
capturescalars(passnonscalar(f), (passedargstup..., nonscalararg), mixedargs...)
@inline passnonscalar(f) = (passed, tofill) -> f(Base.front(passed), (last(passed), tofill...))
@inline capturescalar(f, scalararg) = (passed, tofill) -> f(passed, (scalararg, tofill...))
# Base cases for capturescalars
@inline capturescalars(f, passedargstup, scalararg) =
(capturelastscalar(f, scalararg), passedargstup)
@inline capturescalars(f, passedargstup, nonscalararg::SparseVecOrMat) =
(passlastnonscalar(f), (passedargstup..., nonscalararg))
@inline passlastnonscalar(f) = (passed...) -> f(Base.front(passed), (last(passed),))
@inline capturelastscalar(f, scalararg) = (passed...) -> f(passed, (scalararg,))

# NOTE: The following two method definitions work around #19096.
broadcast{Tf,T}(f::Tf, ::Type{T}, A::SparseMatrixCSC) = broadcast(y -> f(T, y), A)
broadcast{Tf,T}(f::Tf, A::SparseMatrixCSC, ::Type{T}) = broadcast(x -> f(x, T), A)

Expand Down
81 changes: 81 additions & 0 deletions test/sparse/higherorderfns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ end
end
end


@testset "sparse map/broadcast with result eltype not a concrete subtype of Number (#19561/#19589)" begin
intoneorfloatzero(x) = x != 0.0 ? Int(1) : Float64(x)
stringorfloatzero(x) = x != 0.0 ? "Hello" : Float64(x)
Expand All @@ -202,6 +203,86 @@ end
@test broadcast(stringorfloatzero, speye(4)) == sparse(broadcast(stringorfloatzero, eye(4)))
end

@testset "broadcast[!] over combinations of scalars and sparse vectors/matrices" begin
N, M, p = 10, 12, 0.5
elT = Float64
s = Float32(2.0)
V = sprand(elT, N, p)
A = sprand(elT, N, M, p)
fV, fA = Array(V), Array(A)
# test combinations involving one to three scalars and one to five sparse vectors/matrices
spargseq, dargseq = Iterators.cycle((A, V)), Iterators.cycle((fA, fV))
for nargs in 1:5 # number of tensor arguments
nargsl = cld(nargs, 2) # number in "left half" of tensor arguments
nargsr = fld(nargs, 2) # number in "right half" of tensor arguments
spargsl = tuple(Iterators.take(spargseq, nargsl)...) # "left half" of tensor args
spargsr = tuple(Iterators.take(spargseq, nargsr)...) # "right half" of tensor args
dargsl = tuple(Iterators.take(dargseq, nargsl)...) # "left half" of tensor args, densified
dargsr = tuple(Iterators.take(dargseq, nargsr)...) # "right half" of tensor args, densified
for (sparseargs, denseargs) in ( # argument combinations including scalars
# a few combinations involving one scalar
((s, spargsl..., spargsr...), (s, dargsl..., dargsr...)),
((spargsl..., s, spargsr...), (dargsl..., s, dargsr...)),
((spargsl..., spargsr..., s), (dargsl..., dargsr..., s)),
# a few combinations involving two scalars
((s, spargsl..., s, spargsr...), (s, dargsl..., s, dargsr...)),
((s, spargsl..., spargsr..., s), (s, dargsl..., dargsr..., s)),
((spargsl..., s, spargsr..., s), (dargsl..., s, dargsr..., s)),
((s, s, spargsl..., spargsr...), (s, s, dargsl..., dargsr...)),
((spargsl..., s, s, spargsr...), (dargsl..., s, s, dargsr...)),
((spargsl..., spargsr..., s, s), (dargsl..., dargsr..., s, s)),
# a few combinations involving three scalars
((s, spargsl..., s, spargsr..., s), (s, dargsl..., s, dargsr..., s)),
((s, spargsl..., s, s, spargsr...), (s, dargsl..., s, s, dargsr...)),
((spargsl..., s, s, spargsr..., s), (dargsl..., s, s, dargsr..., s)),
((spargsl..., s, s, s, spargsr...), (dargsl..., s, s, s, dargsr...)), )
# test broadcast entry point
@test broadcast(*, sparseargs...) == sparse(broadcast(*, denseargs...))
@test isa(@inferred(broadcast(*, sparseargs...)), SparseMatrixCSC{elT})
# test broadcast! entry point
fX = broadcast(*, sparseargs...); X = sparse(fX)
@test broadcast!(*, X, sparseargs...) == sparse(broadcast!(*, fX, denseargs...))
@test isa(@inferred(broadcast!(*, X, sparseargs...)), SparseMatrixCSC{elT})
X = sparse(fX) # reset / warmup for @allocated test
@test_broken (@allocated broadcast!(*, X, sparseargs...)) == 0
# This test (and the analog below) fails for three reasons:
# (1) In all cases, generating the closures that capture the scalar arguments
# results in allocation, not sure why.
# (2) In some cases, though _broadcast_eltype (which wraps _return_type)
# consistently provides the correct result eltype when passed the closure
# that incorporates the scalar arguments to broadcast (and, with #19667,
# is inferable, so the overall return type from broadcast is inferred),
# in some cases inference seems unable to determine the return type of
# direct calls to that closure. This issue causes variables in both the
# broadcast[!] entry points (fofzeros = f(_zeros_eltypes(args...)...)) and
# the driver routines (Cx in _map_zeropres! and _broadcast_zeropres!) to have
# inferred type Any, resulting in allocation and lackluster performance.
# (3) The sparseargs... splat in the call above allocates a bit, but of course
# that issue is negligible and perhaps could be accounted for in the test.
end
end
# test combinations at the limit of inference (eight arguments net)
for (sparseargs, denseargs) in (
((s, s, s, A, s, s, s, s), (s, s, s, fA, s, s, s, s)), # seven scalars, one sparse matrix
((s, s, V, s, s, A, s, s), (s, s, fV, s, s, fA, s, s)), # six scalars, two sparse vectors/matrices
((s, s, V, s, A, s, V, s), (s, s, fV, s, fA, s, fV, s)), # five scalars, three sparse vectors/matrices
((s, V, s, A, s, V, s, A), (s, fV, s, fA, s, fV, s, fA)), # four scalars, four sparse vectors/matrices
((s, V, A, s, V, A, s, A), (s, fV, fA, s, fV, fA, s, fA)), # three scalars, five sparse vectors/matrices
((V, A, V, s, A, V, A, s), (fV, fA, fV, s, fA, fV, fA, s)), # two scalars, six sparse vectors/matrices
((V, A, V, A, s, V, A, V), (fV, fA, fV, fA, s, fV, fA, fV)) ) # one scalar, seven sparse vectors/matrices
# test broadcast entry point
@test broadcast(*, sparseargs...) == sparse(broadcast(*, denseargs...))
@test isa(@inferred(broadcast(*, sparseargs...)), SparseMatrixCSC{elT})
# test broadcast! entry point
fX = broadcast(*, sparseargs...); X = sparse(fX)
@test broadcast!(*, X, sparseargs...) == sparse(broadcast!(*, fX, denseargs...))
@test isa(@inferred(broadcast!(*, X, sparseargs...)), SparseMatrixCSC{elT})
X = sparse(fX) # reset / warmup for @allocated test
@test_broken (@allocated broadcast!(*, X, sparseargs...)) == 0
# please see the note a few lines above re. this @test_broken
end
end

# Older tests of sparse broadcast, now largely covered by the tests above
@testset "assorted tests of sparse broadcast over two input arguments" begin
N, p = 10, 0.3
Expand Down

0 comments on commit 1494f43

Please sign in to comment.