From dbe1ae0a7e5a53a3afc92c5edb5876c522506258 Mon Sep 17 00:00:00 2001
From: Matt Bauman
Date: Mon, 23 Apr 2018 19:56:46 -0400
Subject: [PATCH] Customizable lazy fused broadcasting in pure Julia
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This patch represents the combined efforts of four individuals, over 60 commits, and an iterated design over (at least) three pull requests that spanned nearly an entire year (closes #22063, #23692, #25377 by superseding them).

This introduces a pure Julia data structure that represents a fused broadcast expression. For example, the expression `2 .* (x .+ 1)` lowers to:

```julia
julia> Meta.@lower 2 .* (x .+ 1)
:($(Expr(:thunk, CodeInfo(:(begin
      Core.SSAValue(0) = (Base.getproperty)(Base.Broadcast, :materialize)
      Core.SSAValue(1) = (Base.getproperty)(Base.Broadcast, :make)
      Core.SSAValue(2) = (Base.getproperty)(Base.Broadcast, :make)
      Core.SSAValue(3) = (Core.SSAValue(2))(+, x, 1)
      Core.SSAValue(4) = (Core.SSAValue(1))(*, 2, Core.SSAValue(3))
      Core.SSAValue(5) = (Core.SSAValue(0))(Core.SSAValue(4))
      return Core.SSAValue(5)
end)))))
```

Or, slightly more readably, as:

```julia
using .Broadcast: materialize, make
materialize(make(*, 2, make(+, x, 1)))
```

The `Broadcast.make` function serves two purposes. Its primary purpose is to construct the `Broadcast.Broadcasted` objects that hold onto the function, the tuple of arguments (potentially including nested `Broadcasted` arguments), and sometimes a set of `axes` to include knowledge of the outer shape. The secondary purpose, however, is to allow an "out" for objects that _don't_ want to participate in fusion. For example, if `x` is a range in the above `2 .* (x .+ 1)` expression, it needn't allocate an array and operate elementwise — it can just compute and return a new range. Thus custom structures are able to specialize `Broadcast.make(f, args...)` just as they'd specialize on `f` normally to return an immediate result.

`Broadcast.materialize` is the identity for everything _except_ `Broadcasted` objects, for which it allocates an appropriate result and computes the broadcast. It does two things: it `instantiate`s the outermost `Broadcasted` object to compute its axes and then `copy`s it.

Similarly, an in-place fused broadcast like `y .= 2 .* (x .+ 1)` uses the exact same expression tree to compute the right-hand side as above, and then uses `materialize!(y, make(*, 2, make(+, x, 1)))` to `instantiate` the `Broadcasted` expression tree and then `copyto!` it into the given destination.

Altogether, this forms a complete API for custom types to extend and customize the behavior of broadcast (fixes #22060). It uses the existing `BroadcastStyle`s throughout to simplify dispatch on many arguments:

* Custom types can opt out of broadcast fusion by specializing `Broadcast.make(f, args...)` or `Broadcast.make(::BroadcastStyle, f, args...)`.
* The `Broadcasted` object computes and stores the type of the combined `BroadcastStyle` of its arguments as its first type parameter, allowing for easy dispatch and specialization.
* Custom broadcast storage is still allocated via `broadcast_similar`; however, instead of passing just a function as the first argument, the entire `Broadcasted` object is passed as the final argument. This potentially allows for much more runtime specialization dependent upon the exact expression given.
* Custom broadcast implementations for a `CustomStyle` are defined by specializing `copy(bc::Broadcasted{CustomStyle})` or `copyto!(dest::AbstractArray, bc::Broadcasted{CustomStyle})` (see the illustrative sketch below).
* Fallback broadcast specializations for a given output object of type `Dest` (for the `DefaultArrayStyle` or another such style that hasn't implemented assignments into such an object) are defined by specializing `copyto!(dest::Dest, bc::Broadcasted{Nothing})`.

As it fully supports range broadcasting, this now deprecates `(1:5) + 2` in favor of `(1:5) .+ 2`, just as had previously been done for all `AbstractArray`s in general.

As a first-mover proof of concept, LinearAlgebra uses this new system to improve broadcasting over structured arrays. Before, broadcasting over a structured matrix would result in a sparse array. Now, broadcasting over a structured matrix will _either_ return an appropriately structured matrix _or_ a dense array. This does incur a type instability (in the form of a discriminated union) in some situations, but thanks to type-based introspection of the `Broadcasted` wrapper, commonly used functions can be special-cased to be type stable. For example:

```julia
julia> f(d) = round.(Int, d)
f (generic function with 1 method)

julia> @inferred f(Diagonal(rand(3)))
3×3 Diagonal{Int64,Array{Int64,1}}:
 0  ⋅  ⋅
 ⋅  0  ⋅
 ⋅  ⋅  1

julia> @inferred Diagonal(rand(3)) .* 3
ERROR: return type Diagonal{Float64,Array{Float64,1}} does not match inferred return type Union{Array{Float64,2}, Diagonal{Float64,Array{Float64,1}}}
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] top-level scope

julia> @inferred Diagonal(1:4) .+ Bidiagonal(rand(4), rand(3), 'U') .* Tridiagonal(1:3, 1:4, 1:3)
4×4 Tridiagonal{Float64,Array{Float64,1}}:
 1.30771  0.838589   ⋅          ⋅
 0.0      3.89109   0.0459757   ⋅
  ⋅       0.0       4.48033    2.51508
  ⋅        ⋅        0.0        6.23739
```

In addition to the issues referenced above, it fixes:

* Fixes #19313, #22053, #23445, and #24586: Literals are no longer treated specially in a fused broadcast; they're just arguments in a `Broadcasted` object like everything else.
* Fixes #21094: Since broadcasting is now represented by a pure Julia data structure, it can be created within `@generated` functions and serialized.
* Fixes #26097: The fallback destination-array specialization method of `copyto!` is specifically implemented as `Broadcasted{Nothing}` and will not be confused by `nothing` arguments.
* Fixes the broadcast-specific element of #25499: The default base broadcast implementation no longer depends upon `Base._return_type` to allocate its array (except in the empty or concretely-typed cases). Note that the sparse implementation (#19595) is still dependent upon inference and is _not_ fixed.
* Fixes #25340: Functions are treated like normal values, just like the other arguments, and are only evaluated once.
* Fixes #22255, and is performant with 12+ fused broadcasts. Okay, that one was fixed on master already, but this fixes it now, too.
* Fixes #25521.
* The performance of this patch has been thoroughly tested through its iterative development process in #25377. There remain [two classes of performance regressions](#25377) that Nanosoldier flagged.
* #25691: Constant literals still lose their constant-ness upon going through the broadcast machinery. I believe quite a large number of functions would need to be marked as `@pure` to support this -- including functions that are intended to be specialized.
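
To make the extension points above concrete, here is a minimal sketch of what a package might write against this API. It is illustrative only and not part of the patch: the `MyVec` type, `MyVecStyle`, and all method bodies are hypothetical, and the sketch assumes the `Broadcast.make`/`Broadcasted` interface as introduced here.

```julia
using Base.Broadcast: AbstractArrayStyle, Broadcasted, DefaultArrayStyle, combine_eltypes

# A hypothetical array wrapper used only for this sketch
struct MyVec{T} <: AbstractVector{T}
    data::Vector{T}
end
Base.size(v::MyVec) = size(v.data)
Base.getindex(v::MyVec, i::Int) = v.data[i]

# Give MyVec its own style; fused expressions containing a MyVec carry
# MyVecStyle as the first type parameter of their Broadcasted wrapper.
struct MyVecStyle <: AbstractArrayStyle{1} end
Base.BroadcastStyle(::Type{<:MyVec}) = MyVecStyle()
MyVecStyle(::Val{1}) = MyVecStyle()                      # same-dimension combination
MyVecStyle(::Val{N}) where {N} = DefaultArrayStyle{N}()  # defer when mixed with higher dims

# Out-of-place broadcast for this style: allocate a plain Vector, let the
# generic copyto! fallback run the elementwise loop, then rewrap the result.
# (A real implementation would also handle a non-concrete ElType, as the
# default `copy` method does.)
function Base.copy(bc::Broadcasted{MyVecStyle})
    ElType = combine_eltypes(bc.f, bc.args)
    dest = similar(Vector{ElType}, axes(bc))
    return MyVec(copyto!(dest, bc))
end

# Opting out of fusion for a specific call, in the spirit of the range
# specializations in this patch: negating a MyVec needn't build a lazy tree.
Base.Broadcast.make(::MyVecStyle, ::typeof(-), v::MyVec) = MyVec(-v.data)
```

With these definitions, an expression like `v .* 2 .+ 1` for `v::MyVec` carries `MyVecStyle` through the fused `Broadcasted` tree and lands in the `copy` specialization, producing another `MyVec`, while `.-v` bypasses fusion entirely through the `make` specialization.
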
(For bookkeeping, this is the squashed version of the [teh-jn/lazydotfuse](https://github.com/JuliaLang/julia/pull/25377) branch as of a1d4e7ec9756ada74fb48f2c514615b9d981cf5c. Squashed and separated out to make it easier to review and commit) Co-authored-by: Tim Holy Co-authored-by: Jameson Nash Co-authored-by: Andrew Keller --- NEWS.md | 13 +- base/bitarray.jl | 40 - base/broadcast.jl | 906 ++++++++++++------ base/compiler/ssair/inlining2.jl | 8 +- base/compiler/ssair/slot2ssa.jl | 4 +- base/deprecated.jl | 4 + base/float.jl | 10 - base/range.jl | 66 -- base/reducedim.jl | 4 +- base/sort.jl | 7 +- base/statistics.jl | 2 +- doc/src/base/arrays.md | 2 +- doc/src/manual/interfaces.md | 173 ++-- src/julia-syntax.scm | 123 +-- stdlib/LinearAlgebra/src/LinearAlgebra.jl | 3 + stdlib/LinearAlgebra/src/bidiag.jl | 11 - stdlib/LinearAlgebra/src/diagonal.jl | 1 - .../LinearAlgebra/src/structuredbroadcast.jl | 180 ++++ stdlib/LinearAlgebra/src/triangular.jl | 3 - stdlib/LinearAlgebra/src/tridiag.jl | 23 +- stdlib/LinearAlgebra/src/uniformscaling.jl | 6 +- .../LinearAlgebra/test/structuredbroadcast.jl | 101 ++ stdlib/SparseArrays/src/higherorderfns.jl | 258 +++-- stdlib/SparseArrays/test/higherorderfns.jl | 80 +- test/bitarray.jl | 35 + test/broadcast.jl | 114 ++- test/core.jl | 4 +- test/numbers.jl | 3 +- test/ranges.jl | 87 +- 29 files changed, 1407 insertions(+), 864 deletions(-) create mode 100644 stdlib/LinearAlgebra/src/structuredbroadcast.jl create mode 100644 stdlib/LinearAlgebra/test/structuredbroadcast.jl diff --git a/NEWS.md b/NEWS.md index 93d446f42874e..0da63c9973120 100644 --- a/NEWS.md +++ b/NEWS.md @@ -388,11 +388,6 @@ This section lists changes that do not have deprecation warnings. Its return value has been removed. Use the `process_running` function to determine if a process has already exited. - * Broadcasting has been redesigned with an extensible public interface. The new API is - documented at https://docs.julialang.org/en/latest/manual/interfaces/#Interfaces-1. - `AbstractArray` types that specialized broadcasting using the old internal API will - need to switch to the new API. ([#20740]) - * The logging system has been redesigned - `info` and `warn` are deprecated and replaced with the logging macros `@info`, `@warn`, `@debug` and `@error`. The `logging` function is also deprecated and replaced with @@ -418,6 +413,14 @@ This section lists changes that do not have deprecation warnings. * `findn(x::AbstractArray)` has been deprecated in favor of `findall(!iszero, x)`, which now returns cartesian indices for multidimensional arrays (see below, [#25532]). + * Broadcasting operations are no longer fused into a single operation by Julia's parser. + Instead, a lazy `Broadcasted` wrapper is created, and the parser will call + `copy(bc::Broadcasted)` or `copyto!(dest, bc::Broadcasted)` + to evaluate the wrapper. Consequently, package authors generally need to specialize + `copy` and `copyto!` methods rather than `broadcast` and `broadcast!`. + See the [Interfaces chapter](https://docs.julialang.org/en/latest/manual/interfaces/#Interfaces-1) + for more information. + * `find` has been renamed to `findall`. `findall`, `findfirst`, `findlast`, `findnext` now take and/or return the same type of indices as `keys`/`pairs` for `AbstractArray`, `AbstractDict`, `AbstractString`, `Tuple` and `NamedTuple` objects ([#24774], [#25545]). 
diff --git a/base/bitarray.jl b/base/bitarray.jl index bac2ad07d6a79..898980b92d4ac 100644 --- a/base/bitarray.jl +++ b/base/bitarray.jl @@ -1097,19 +1097,6 @@ function (-)(B::BitArray) end broadcast(::typeof(sign), B::BitArray) = copy(B) -function broadcast(::typeof(~), B::BitArray) - C = similar(B) - Bc = B.chunks - if !isempty(Bc) - Cc = C.chunks - for i = 1:length(Bc) - Cc[i] = ~Bc[i] - end - Cc[end] &= _msk_end(B) - end - return C -end - """ flipbits!(B::BitArray{N}) -> BitArray{N} @@ -1166,33 +1153,6 @@ end (/)(B::BitArray, x::Number) = (/)(Array(B), x) (/)(x::Number, B::BitArray) = (/)(x, Array(B)) -# broadcast specializations for &, |, and xor/⊻ -broadcast(::typeof(&), B::BitArray, x::Bool) = x ? copy(B) : falses(size(B)) -broadcast(::typeof(&), x::Bool, B::BitArray) = broadcast(&, B, x) -broadcast(::typeof(|), B::BitArray, x::Bool) = x ? trues(size(B)) : copy(B) -broadcast(::typeof(|), x::Bool, B::BitArray) = broadcast(|, B, x) -broadcast(::typeof(xor), B::BitArray, x::Bool) = x ? .~B : copy(B) -broadcast(::typeof(xor), x::Bool, B::BitArray) = broadcast(xor, B, x) -for f in (:&, :|, :xor) - @eval begin - function broadcast(::typeof($f), A::BitArray, B::BitArray) - F = BitArray(undef, promote_shape(size(A),size(B))...) - Fc = F.chunks - Ac = A.chunks - Bc = B.chunks - (isempty(Ac) || isempty(Bc)) && return F - for i = 1:length(Fc) - Fc[i] = ($f)(Ac[i], Bc[i]) - end - Fc[end] &= _msk_end(F) - return F - end - broadcast(::typeof($f), A::DenseArray{Bool}, B::BitArray) = broadcast($f, BitArray(A), B) - broadcast(::typeof($f), B::BitArray, A::DenseArray{Bool}) = broadcast($f, B, BitArray(A)) - end -end - - ## promotion to complex ## # TODO? diff --git a/base/broadcast.jl b/base/broadcast.jl index 55dd8313172b0..ef81b60f89f4b 100644 --- a/base/broadcast.jl +++ b/base/broadcast.jl @@ -3,11 +3,10 @@ module Broadcast using .Base.Cartesian -using .Base: Indices, OneTo, linearindices, tail, to_shape, - _msk_end, unsafe_bitgetindex, bitcache_chunks, bitcache_size, dumpbitcache, - isoperator, promote_typejoin, unalias -import .Base: broadcast, broadcast! -export BroadcastStyle, broadcast_indices, broadcast_similar, broadcastable, +using .Base: Indices, OneTo, linearindices, tail, to_shape, isoperator, promote_typejoin, + _msk_end, unsafe_bitgetindex, bitcache_chunks, bitcache_size, dumpbitcache, unalias +import .Base: broadcast, broadcast!, copy, copyto! +export BroadcastStyle, broadcast_axes, broadcast_similar, broadcastable, broadcast_getindex, broadcast_setindex!, dotview, @__dot__ ### Objects with customized broadcasting behavior should declare a BroadcastStyle @@ -149,47 +148,223 @@ BroadcastStyle(a::AbstractArrayStyle{N}, ::DefaultArrayStyle{N}) where N = a BroadcastStyle(a::AbstractArrayStyle{M}, ::DefaultArrayStyle{N}) where {M,N} = typeof(a)(_max(Val(M),Val(N))) +### Lazy-wrapper for broadcasting + +# `Broadcasted` wrap the arguments to `broadcast(f, args...)`. A statement like +# y = x .* (x .+ 1) +# will result in code that is essentially +# y = copy(Broadcasted(*, x, Broadcasted(+, x, 1))) +# `broadcast!` results in `copyto!(dest, Broadcasted(...))`. 
+ +# The use of `Nothing` in place of a `BroadcastStyle` has a different +# application, in the fallback method +# copyto!(dest, bc::Broadcasted) = copyto!(dest, convert(Broadcasted{Nothing}, bc)) +# This allows methods +# copyto!(dest::DestType, bc::Broadcasted{Nothing}) +# that specialize on `DestType` to be easily disambiguated from +# methods that instead specialize on `BroadcastStyle`, +# copyto!(dest::AbstractArray, bc::Broadcasted{MyStyle}) + +struct Broadcasted{Style<:Union{Nothing,BroadcastStyle}, Axes, F, Args<:Tuple} + f::F + args::Args + axes::Axes # the axes of the resulting object (may be bigger than implied by `args` if this is nested inside a larger `Broadcasted`) +end + +Broadcasted(f::F, args::Args, axes=nothing) where {F, Args<:Tuple} = + Broadcasted{typeof(combine_styles(args...))}(f, args, axes) +function Broadcasted{Style}(f::F, args::Args, axes=nothing) where {Style, F, Args<:Tuple} + # using Core.Typeof rather than F preserves inferrability when f is a type + Broadcasted{Style, typeof(axes), Core.Typeof(f), Args}(f, args, axes) +end + +Base.convert(::Type{Broadcasted{NewStyle}}, bc::Broadcasted{Style,Axes,F,Args}) where {NewStyle,Style,Axes,F,Args} = + Broadcasted{NewStyle,Axes,F,Args}(bc.f, bc.args, bc.axes) + +Base.show(io::IO, bc::Broadcasted{Style}) where {Style} = print(io, Broadcasted, '{', Style, "}(", bc.f, ", ", bc.args, ')') + ## Allocating the output container """ - broadcast_similar(f, ::BroadcastStyle, ::Type{ElType}, inds, As...) + broadcast_similar(::BroadcastStyle, ::Type{ElType}, inds, bc) Allocate an output object for [`broadcast`](@ref), appropriate for the indicated -[`Broadcast.BroadcastStyle`](@ref). `ElType` and `inds` specify the desired element type and indices of the -container. -`f` is the broadcast operation, and `As...` are the arguments supplied to `broadcast`. +[`Broadcast.BroadcastStyle`](@ref). `ElType` and `inds` specify the desired element type and axes of the +container. The final `bc` argument is the `Broadcasted` object representing the fused broadcast operation +and its arguments. """ -broadcast_similar(f, ::DefaultArrayStyle{N}, ::Type{ElType}, inds::Indices{N}, As...) where {N,ElType} = +broadcast_similar(::DefaultArrayStyle{N}, ::Type{ElType}, inds::Indices{N}, bc) where {N,ElType} = similar(Array{ElType}, inds) -broadcast_similar(f, ::DefaultArrayStyle{N}, ::Type{Bool}, inds::Indices{N}, As...) where N = +broadcast_similar(::DefaultArrayStyle{N}, ::Type{Bool}, inds::Indices{N}, bc) where N = similar(BitArray, inds) # In cases of conflict we fall back on Array -broadcast_similar(f, ::ArrayConflict, ::Type{ElType}, inds::Indices, As...) where ElType = +broadcast_similar(::ArrayConflict, ::Type{ElType}, inds::Indices, bc) where ElType = similar(Array{ElType}, inds) -broadcast_similar(f, ::ArrayConflict, ::Type{Bool}, inds::Indices, As...) = +broadcast_similar(::ArrayConflict, ::Type{Bool}, inds::Indices, bc) = similar(BitArray, inds) -## Computing the result's indices. Most types probably won't need to specialize this. -broadcast_indices() = () -broadcast_indices(::Type{T}) where T = () -broadcast_indices(A) = broadcast_indices(combine_styles(A), A) -broadcast_indices(::Style{Tuple}, A) = (OneTo(length(A)),) -broadcast_indices(::DefaultArrayStyle{0}, A::Ref) = () -broadcast_indices(::BroadcastStyle, A) = Base.axes(A) +## Computing the result's axes. Most types probably won't need to specialize this. 
+broadcast_axes() = () +broadcast_axes(A::Tuple) = (OneTo(length(A)),) +broadcast_axes(A::Ref) = () +@inline broadcast_axes(A) = axes(A) """ - Base.broadcast_indices(::SrcStyle, A) + Base.broadcast_axes(A) -Compute the indices for objects `A` with [`BroadcastStyle`](@ref) `SrcStyle`. -If needed, you can specialize this method for your styles. -You should only need to provide a custom implementation for non-AbstractArrayStyles. +Compute the axes for `A`. + +This should only be specialized for objects that do not define axes but want to participate in broadcasting. """ -broadcast_indices +broadcast_axes ### End of methods that users will typically have to specialize ### +@inline Base.axes(bc::Broadcasted) = _axes(bc, bc.axes) +_axes(::Broadcasted, axes::Tuple) = axes +@inline _axes(bc::Broadcasted, ::Nothing) = combine_axes(bc.args...) +_axes(bc::Broadcasted{Style{Tuple}}, ::Nothing) = (Base.OneTo(length(longest_tuple(nothing, bc.args))),) +_axes(bc::Broadcasted{<:AbstractArrayStyle{0}}, ::Nothing) = () + +BroadcastStyle(::Type{<:Broadcasted{Style}}) where {Style} = Style() +BroadcastStyle(::Type{<:Broadcasted{S}}) where {S<:Union{Nothing,Unknown}} = + throw(ArgumentError("Broadcasted{Unknown} wrappers do not have a style assigned")) + +argtype(::Type{Broadcasted{Style,Axes,F,Args}}) where {Style,Axes,F,Args} = Args +argtype(bc::Broadcasted) = argtype(typeof(bc)) + +const NestedTuple = Tuple{<:Broadcasted,Vararg{Any}} +not_nested(bc::Broadcasted) = _not_nested(bc.args) +_not_nested(t::Tuple) = _not_nested(tail(t)) +_not_nested(::NestedTuple) = false +_not_nested(::Tuple{}) = true + +## Instantiation fills in the "missing" fields in Broadcasted. +instantiate(x) = x + +""" + Broadcast.instantiate(bc::Broadcasted) + +Construct the axes and indexing helpers for the lazy Broadcasted object `bc`. + +Custom `BroadcastStyle`s may override this default in cases where it is fast and easy +to compute the resulting `axes` and indexing helpers on-demand, leaving those fields +of the `Broadcasted` object empty (populated with `nothing`). If they do so, however, +they must provide their own `Base.axes(::Broadcasted{Style})` and +`Base.getindex(::Broadcasted{Style}, I::Union{Int,CartesianIndex})` methods as appropriate. +""" +@inline function instantiate(bc::Broadcasted{Style}) where {Style} + if bc.axes isa Nothing # Not done via dispatch to make it easier to extend instantiate(::Broadcasted{Style}) + axes = combine_axes(bc.args...) + else + axes = bc.axes + check_broadcast_axes(axes, bc.args...) + end + return Broadcasted{Style}(bc.f, bc.args, axes) +end +instantiate(bc::Broadcasted{<:Union{AbstractArrayStyle{0}, Style{Tuple}}}) = bc + +## Flattening + +""" + bcf = flatten(bc) + +Create a "flat" representation of a lazy-broadcast operation. +From + f.(a, g.(b, c), d) +we produce the equivalent of + h.(a, b, c, d) +where + h(w, x, y, z) = f(w, g(x, y), z) +In terms of its internal representation, + Broadcasted(f, a, Broadcasted(g, b, c), d) +becomes + Broadcasted(h, a, b, c, d) + +This is an optional operation that may make custom implementation of broadcasting easier in +some cases. 
+""" +function flatten(bc::Broadcasted{Style}) where {Style} + isflat(bc.args) && return bc + # concatenate the nested arguments into {a, b, c, d} + args = cat_nested(x->x.args, bc) + # build a function `makeargs` that takes a "flat" argument list and + # and creates the appropriate input arguments for `f`, e.g., + # makeargs = (w, x, y, z) -> (w, g(x, y), z) + # + # `makeargs` is built recursively and looks a bit like this: + # makeargs(w, x, y, z) = (w, makeargs1(x, y, z)...) + # = (w, g(x, y), makeargs2(z)...) + # = (w, g(x, y), z) + let makeargs = make_makeargs(bc) + newf = @inline function(args::Vararg{Any,N}) where N + bc.f(makeargs(args...)...) + end + return Broadcasted{Style}(newf, args, bc.axes) + end +end + +isflat(args::NestedTuple) = false +isflat(args::Tuple) = isflat(tail(args)) +isflat(args::Tuple{}) = true + +cat_nested(fieldextractor, bc::Broadcasted) = cat_nested(fieldextractor, fieldextractor(bc), ()) + +cat_nested(fieldextractor, t::Tuple, rest) = + (t[1], cat_nested(fieldextractor, tail(t), rest)...) +cat_nested(fieldextractor, t::Tuple{<:Broadcasted,Vararg{Any}}, rest) = + cat_nested(fieldextractor, cat_nested(fieldextractor, fieldextractor(t[1]), tail(t)), rest) +cat_nested(fieldextractor, t::Tuple{}, tail) = cat_nested(fieldextractor, tail, ()) +cat_nested(fieldextractor, t::Tuple{}, tail::Tuple{}) = () + +make_makeargs(bc::Broadcasted) = make_makeargs(()->(), bc.args) +@inline function make_makeargs(makeargs, t::Tuple) + let makeargs = make_makeargs(makeargs, tail(t)) + return @inline function(head, tail::Vararg{Any,N}) where N + (head, makeargs(tail...)...) + end + end +end +@inline function make_makeargs(makeargs, t::Tuple{<:Broadcasted,Vararg{Any}}) + bc = t[1] + let makeargs = make_makeargs(makeargs, tail(t)) + let makeargs = make_makeargs(makeargs, bc.args) + headargs, tailargs = make_headargs(bc.args), make_tailargs(bc.args) + return @inline function(args::Vararg{Any,N}) where N + args1 = makeargs(args...) + a, b = headargs(args1...), tailargs(args1...) + (bc.f(a...), b...) + end + end + end +end +make_makeargs(makeargs, ::Tuple{}) = makeargs + +@inline function make_headargs(t::Tuple) + let headargs = make_headargs(tail(t)) + return @inline function(head, tail::Vararg{Any,N}) where N + (head, headargs(tail...)...) + end + end +end +@inline function make_headargs(::Tuple{}) + return @inline function(tail::Vararg{Any,N}) where N + () + end +end + +@inline function make_tailargs(t::Tuple) + let tailargs = make_tailargs(tail(t)) + return @inline function(head, tail::Vararg{Any,N}) where N + tailargs(tail...) + end + end +end +@inline function make_tailargs(::Tuple{}) + return @inline function(tail::Vararg{Any,N}) where N + tail + end +end + ## Broadcasting utilities ## -# special cases defined for performance -broadcast(f, x::Number...) = f(x...) -@inline broadcast(f, t::NTuple{N,Any}, ts::Vararg{NTuple{N,Any}}) where {N} = map(f, t, ts...) ## logic for deciding the BroadcastStyle # Dimensionality: computing max(M,N) in the type domain so we preserve inferrability @@ -204,6 +379,7 @@ longest(t1::Tuple, ::Tuple{}) = (true, longest(Base.tail(t1), ())...) longest(::Tuple{}, ::Tuple{}) = () # combine_styles operates on values (arbitrarily many) +combine_styles() = DefaultArrayStyle{0}() combine_styles(c) = result_style(BroadcastStyle(typeof(c))) combine_styles(c1, c2) = result_style(combine_styles(c1), combine_styles(c2)) @inline combine_styles(c1, c2, cs...) 
= result_style(combine_styles(c1), combine_styles(c2, cs...)) @@ -236,8 +412,8 @@ One of these should be undefined (and thus return Broadcast.Unknown).""") end # Indices utilities -combine_indices(A, B...) = broadcast_shape(broadcast_indices(A), combine_indices(B...)) -combine_indices(A) = broadcast_indices(A) +@inline combine_axes(A, B...) = broadcast_shape(broadcast_axes(A), combine_axes(B...)) +combine_axes(A) = broadcast_axes(A) # shape (i.e., tuple-of-indices) inputs broadcast_shape(shape::Tuple) = shape @@ -269,119 +445,124 @@ function check_broadcast_shape(shp, Ashp::Tuple) _bcsm(shp[1], Ashp[1]) || throw(DimensionMismatch("array could not be broadcast to match destination")) check_broadcast_shape(tail(shp), tail(Ashp)) end -check_broadcast_indices(shp, A) = check_broadcast_shape(shp, broadcast_indices(A)) +check_broadcast_axes(shp, A) = check_broadcast_shape(shp, broadcast_axes(A)) # comparing many inputs -@inline function check_broadcast_indices(shp, A, As...) - check_broadcast_indices(shp, A) - check_broadcast_indices(shp, As...) +@inline function check_broadcast_axes(shp, A, As...) + check_broadcast_axes(shp, A) + check_broadcast_axes(shp, As...) end ## Indexing manipulations +""" + newindex(argument, I) + newindex(I, keep, default) + +Recompute index `I` such that it appropriately constrains broadcasted dimensions to the source. + +Two methods are supported, both allowing for `I` to be specified as either a `CartesianIndex` or +an `Int`. + +* `newindex(argument, I)` dynamically constrains `I` based upon the axes of `argument`. +* `newindex(I, keep, default)` constrains `I` using the pre-computed tuples `keeps` and `defaults`. + * `keep` is a tuple of `Bool`s, where `keep[d] == true` means that dimension `d` in `I` should be preserved as is + * `default` is a tuple of Integers, specifying what index to use in dimension `d` when `keep[d] == false`. + Any remaining indices in `I` beyond the length of the `keep` tuple are truncated. The `keep` and `default` + tuples may be created by `newindexer(argument)`. +""" +Base.@propagate_inbounds newindex(arg, I::CartesianIndex) = CartesianIndex(_newindex(broadcast_axes(arg), I.I)) +Base.@propagate_inbounds newindex(arg, I::Int) = CartesianIndex(_newindex(broadcast_axes(arg), (I,))) +Base.@propagate_inbounds _newindex(ax::Tuple, I::Tuple) = (ifelse(Base.unsafe_length(ax[1])==1, ax[1][1], I[1]), _newindex(tail(ax), tail(I))...) +Base.@propagate_inbounds _newindex(ax::Tuple{}, I::Tuple) = () +Base.@propagate_inbounds _newindex(ax::Tuple, I::Tuple{}) = (ax[1][1], _newindex(tail(ax), ())...) +Base.@propagate_inbounds _newindex(ax::Tuple{}, I::Tuple{}) = () -# newindex(I, keep, Idefault) replaces a CartesianIndex `I` with something that -# is appropriate for a particular broadcast array/scalar. `keep` is a -# NTuple{N,Bool}, where keep[d] == true means that one should preserve -# I[d]; if false, replace it with Idefault[d]. # If dot-broadcasting were already defined, this would be `ifelse.(keep, I, Idefault)`. @inline newindex(I::CartesianIndex, keep, Idefault) = CartesianIndex(_newindex(I.I, keep, Idefault)) +@inline newindex(i::Int, keep::Tuple{Bool}, idefault) = ifelse(keep[1], i, idefault[1]) @inline _newindex(I, keep, Idefault) = (ifelse(keep[1], I[1], Idefault[1]), _newindex(tail(I), tail(keep), tail(Idefault))...) 
@inline _newindex(I, keep::Tuple{}, Idefault) = () # truncate if keep is shorter than I -# newindexer(shape, A) generates `keep` and `Idefault` (for use by -# `newindex` above) for a particular array `A`, given the -# broadcast indices `shape` -# `keep` is equivalent to map(==, axes(A), shape) (but see #17126) -@inline newindexer(shape, A) = shapeindexer(shape, broadcast_indices(A)) -@inline shapeindexer(shape, indsA::Tuple{}) = (), () -@inline function shapeindexer(shape, indsA::Tuple) +# newindexer(A) generates `keep` and `Idefault` (for use by `newindex` above) +# for a particular array `A`; `shapeindexer` does so for its axes. +@inline newindexer(A) = shapeindexer(broadcast_axes(A)) +@inline shapeindexer(ax) = _newindexer(ax) +@inline _newindexer(indsA::Tuple{}) = (), () +@inline function _newindexer(indsA::Tuple) ind1 = indsA[1] - keep, Idefault = shapeindexer(tail(shape), tail(indsA)) - (shape[1] == ind1, keep...), (first(ind1), Idefault...) + keep, Idefault = _newindexer(tail(indsA)) + (length(ind1)!=1, keep...), (first(ind1), Idefault...) end -# Equivalent to map(x->newindexer(shape, x), As) (but see #17126) -map_newindexer(shape, ::Tuple{}) = (), () -@inline function map_newindexer(shape, As) - A1 = As[1] - keeps, Idefaults = map_newindexer(shape, tail(As)) - keep, Idefault = newindexer(shape, A1) - (keep, keeps...), (Idefault, Idefaults...) -end -@inline function map_newindexer(shape, A, Bs) - keeps, Idefaults = map_newindexer(shape, Bs) - keep, Idefault = newindexer(shape, A) - (keep, keeps...), (Idefault, Idefaults...) +@inline function Base.getindex(bc::Broadcasted, I) + @boundscheck checkbounds(bc, I) + @inbounds _broadcast_getindex(bc, I) end -Base.@propagate_inbounds _broadcast_getindex(::Type{T}, I) where T = T -Base.@propagate_inbounds _broadcast_getindex(A, I) = _broadcast_getindex(combine_styles(A), A, I) -Base.@propagate_inbounds _broadcast_getindex(::DefaultArrayStyle{0}, A, I) = A[] -Base.@propagate_inbounds _broadcast_getindex(::Any, A, I) = A[I] -Base.@propagate_inbounds _broadcast_getindex(::Style{Tuple}, A::Tuple{Any}, I) = A[1] +@inline Base.checkbounds(bc::Broadcasted, I) = + Base.checkbounds_indices(Bool, axes(bc), (I,)) || Base.throw_boundserror(bc, (I,)) -## Broadcasting core -# nargs encodes the number of As arguments (which matches the number -# of keeps). The first two type parameters are to ensure specialization. 
-@generated function _broadcast!(f, B::AbstractArray, keeps::K, Idefaults::ID, A::AT, Bs::BT, ::Val{N}, iter) where {K,ID,AT,BT,N} - nargs = N + 1 - quote - $(Expr(:meta, :inline)) - # destructure the keeps and As tuples - A_1 = A - @nexprs $N i->(A_{i+1} = Bs[i]) - @nexprs $nargs i->(keep_i = keeps[i]) - @nexprs $nargs i->(Idefault_i = Idefaults[i]) - @simd for I in iter - # reverse-broadcast the indices - @nexprs $nargs i->(I_i = newindex(I, keep_i, Idefault_i)) - # extract array values - @nexprs $nargs i->(@inbounds val_i = _broadcast_getindex(A_i, I_i)) - # call the function and store the result - result = @ncall $nargs f val - @inbounds B[I] = result - end - return B - end -end -# For BitArray outputs, we cache the result in a "small" Vector{Bool}, -# and then copy in chunks into the output -@generated function _broadcast!(f, B::BitArray, keeps::K, Idefaults::ID, A::AT, Bs::BT, ::Val{N}, iter) where {K,ID,AT,BT,N} - nargs = N + 1 - quote - $(Expr(:meta, :inline)) - # destructure the keeps and As tuples - A_1 = A - @nexprs $N i->(A_{i+1} = Bs[i]) - @nexprs $nargs i->(keep_i = keeps[i]) - @nexprs $nargs i->(Idefault_i = Idefaults[i]) - C = Vector{Bool}(undef, bitcache_size) - Bc = B.chunks - ind = 1 - cind = 1 - @simd for I in iter - # reverse-broadcast the indices - @nexprs $nargs i->(I_i = newindex(I, keep_i, Idefault_i)) - # extract array values - @nexprs $nargs i->(@inbounds val_i = _broadcast_getindex(A_i, I_i)) - # call the function and store the result - @inbounds C[ind] = @ncall $nargs f val - ind += 1 - if ind > bitcache_size - dumpbitcache(Bc, cind, C) - cind += bitcache_chunks - ind = 1 - end - end - if ind > 1 - @inbounds C[ind:bitcache_size] = false - dumpbitcache(Bc, cind, C) - end - return B - end +""" + _broadcast_getindex(A, I) + +Index into `A` with `I`, collapsing broadcasted indices to their singleton indices as appropriate +""" +Base.@propagate_inbounds _broadcast_getindex(A::Union{Ref,AbstractArray{<:Any,0},Number}, I) = A[] # Scalar-likes can just ignore all indices +Base.@propagate_inbounds _broadcast_getindex(::Ref{Type{T}}, I) where {T} = T +# Tuples are statically known to be singleton or vector-like +Base.@propagate_inbounds _broadcast_getindex(A::Tuple{Any}, I) = A[1] +Base.@propagate_inbounds _broadcast_getindex(A::Tuple, I) = A[I[1]] +# Everything else falls back to dynamically dropping broadcasted indices based upon its axes +Base.@propagate_inbounds _broadcast_getindex(A, I) = A[newindex(A, I)] + +# In some cases, it's more efficient to sort out which dimensions should be dropped +# ahead of time (often when the size checks aren't able to be lifted out of the loop). +# The Extruded struct computes that information ahead of time and stores it as a pair +# of tuples to optimize indexing later. This is most commonly needed for `Array` and +# other `AbstractArray` subtypes that wrap `Array` and dynamically ask it for its size. +struct Extruded{T, K, D} + x::T + keeps::K # A tuple of booleans, specifying which indices should be passed normally + defaults::D # A tuple of integers, specifying the index to use when keeps[i] is false (as defaults[i]) +end +@inline broadcast_axes(b::Extruded) = broadcast_axes(b.x) +Base.@propagate_inbounds _broadcast_getindex(b::Extruded, i) = b.x[newindex(i, b.keeps, b.defaults)] +extrude(x::AbstractArray) = Extruded(x, newindexer(x)...) 
+extrude(x) = x + +# For Broadcasted +Base.@propagate_inbounds function _broadcast_getindex(bc::Broadcasted{<:Any,<:Any,<:Any,<:Any}, I) + args = _getindex(bc.args, I) + return _broadcast_getindex_evalf(bc.f, args...) +end +# Hack around losing Type{T} information in the final args tuple. Julia actually +# knows (in `code_typed`) the _value_ of these types, statically displaying them, +# but inference is currently skipping inferring the type of the types as they are +# transiently placed in a tuple as the argument list is lispily constructed. These +# additional methods recover type stability when a `Type` appears in one of the +# first two arguments of a function. +Base.@propagate_inbounds function _broadcast_getindex(bc::Broadcasted{<:Any,<:Any,<:Any,<:Tuple{Ref{Type{T}},Vararg{Any}}}, I) where {T} + args = _getindex(tail(bc.args), I) + return _broadcast_getindex_evalf(bc.f, T, args...) +end +Base.@propagate_inbounds function _broadcast_getindex(bc::Broadcasted{<:Any,<:Any,<:Any,<:Tuple{Any,Ref{Type{T}},Vararg{Any}}}, I) where {T} + arg1 = _broadcast_getindex(bc.args[1], I) + args = _getindex(tail(tail(bc.args)), I) + return _broadcast_getindex_evalf(bc.f, arg1, T, args...) +end +Base.@propagate_inbounds function _broadcast_getindex(bc::Broadcasted{<:Any,<:Any,<:Any,<:Tuple{Ref{Type{T}},Ref{Type{S}},Vararg{Any}}}, I) where {T,S} + args = _getindex(tail(tail(bc.args)), I) + return _broadcast_getindex_evalf(bc.f, T, S, args...) end +# Utilities for _broadcast_getindex +Base.@propagate_inbounds _getindex(args::Tuple, I) = (_broadcast_getindex(args[1], I), _getindex(tail(args), I)...) +Base.@propagate_inbounds _getindex(args::Tuple{Any}, I) = (_broadcast_getindex(args[1], I),) +Base.@propagate_inbounds _getindex(args::Tuple{}, I) = () + +@inline _broadcast_getindex_evalf(f::Tf, args::Vararg{Any,N}) where {Tf,N} = f(args...) # not propagate_inbounds + """ broadcastable(x) @@ -410,129 +591,27 @@ julia> broadcastable("hello") # Strings break convention of matching iteration a Base.RefValue{String}("hello") ``` """ -broadcastable(x::Union{Symbol,AbstractString,Function,UndefInitializer,Nothing,RoundingMode,Missing}) = Ref(x) +broadcastable(x::Union{Symbol,AbstractString,Function,UndefInitializer,Nothing,RoundingMode,Missing,Val}) = Ref(x) broadcastable(x::Ptr) = Ref{Ptr}(x) # Cannot use Ref(::Ptr) until ambiguous deprecation goes through broadcastable(::Type{T}) where {T} = Ref{Type{T}}(T) -broadcastable(x::Union{AbstractArray,Number,Ref,Tuple}) = x +broadcastable(x::Union{AbstractArray,Number,Ref,Tuple,Broadcasted}) = x # In the future, default to collecting arguments. TODO: uncomment once deprecations are removed # broadcastable(x) = collect(x) # broadcastable(::Union{AbstractDict, NamedTuple}) = error("intentionally unimplemented to allow development in 1.x") -""" - broadcast!(f, dest, As...) +## Computation of inferred result type, for empty and concretely inferred cases only +_broadcast_getindex_eltype(bc::Broadcasted) = Base._return_type(bc.f, eltypes(bc.args)) +_broadcast_getindex_eltype(A) = eltype(A) # Tuple, Array, etc. -Like [`broadcast`](@ref), but store the result of -`broadcast(f, As...)` in the `dest` array. -Note that `dest` is only used to store the result, and does not supply -arguments to `f` unless it is also listed in the `As`, -as in `broadcast!(f, A, A, B)` to perform `A[:] = broadcast(f, A, B)`. -""" -@inline function broadcast!(f::Tf, dest, As::Vararg{Any,N}) where {Tf,N} - As′ = map(broadcastable, As) - broadcast!(f, dest, combine_styles(As′...), As′...) 
-end -@inline broadcast!(f::Tf, dest, ::BroadcastStyle, As::Vararg{Any,N}) where {Tf,N} = broadcast!(f, dest, nothing, As...) +eltypes(::Tuple{}) = Tuple{} +eltypes(t::Tuple{Any}) = Tuple{_broadcast_getindex_eltype(t[1])} +eltypes(t::Tuple{Any,Any}) = Tuple{_broadcast_getindex_eltype(t[1]), _broadcast_getindex_eltype(t[2])} +eltypes(t::Tuple) = Tuple{_broadcast_getindex_eltype(t[1]), eltypes(tail(t)).types...} -# Default behavior (separated out so that it can be called by users who want to extend broadcast!). -@inline function broadcast!(f, dest, ::Nothing, As::Vararg{Any, N}) where N - if f isa typeof(identity) && N == 1 - A = As[1] - if A isa AbstractArray && Base.axes(dest) == Base.axes(A) - return copyto!(dest, A) - end - end - _broadcast!(f, dest, As...) - return dest -end +# Inferred eltype of result of broadcast(f, args...) +combine_eltypes(f, args::Tuple) = Base._return_type(f, eltypes(args)) -# Optimization for the case where all arguments are 0-dimensional -@inline function broadcast!(f, dest, ::AbstractArrayStyle{0}, As::Vararg{Any, N}) where N - if dest isa AbstractArray - if f isa typeof(identity) && N == 1 - return fill!(dest, As[1][]) - else - @inbounds for I in eachindex(dest) - dest[I] = f(map(getindex, As)...) - end - return dest - end - end - _broadcast!(f, dest, As...) - return dest -end - -# For broadcasted assignments like `broadcast!(f, A, ..., A, ...)`, where `A` -# appears on both the LHS and the RHS of the `.=`, then we know we're only -# going to make one pass through the array, and even though `A` is aliasing -# against itself, the mutations won't affect the result as the indices on the -# LHS and RHS will always match. This is not true in general, but with the `.op=` -# syntax it's fairly common for an argument to be `===` a source. -broadcast_unalias(dest, src) = dest === src ? src : unalias(dest, src) - -# This indirection allows size-dependent implementations. -@inline function _broadcast!(f, C, A, Bs::Vararg{Any,N}) where N - shape = broadcast_indices(C) - @boundscheck check_broadcast_indices(shape, A, Bs...) - A′ = broadcast_unalias(C, A) - Bs′ = map(B->broadcast_unalias(C, B), Bs) - keeps, Idefaults = map_newindexer(shape, A′, Bs′) - iter = CartesianIndices(shape) - _broadcast!(f, C, keeps, Idefaults, A′, Bs′, Val(N), iter) - return C -end - -# broadcast with element type adjusted on-the-fly. This widens the element type of -# B as needed (allocating a new container and copying previously-computed values) to -# accommodate any incompatible new elements. -@generated function _broadcast!(f, B::AbstractArray, keeps::K, Idefaults::ID, As::AT, ::Val{nargs}, iter, st, count) where {K,ID,AT,nargs} - quote - $(Expr(:meta, :noinline)) - # destructure the keeps and As tuples - @nexprs $nargs i->(A_i = As[i]) - @nexprs $nargs i->(keep_i = keeps[i]) - @nexprs $nargs i->(Idefault_i = Idefaults[i]) - while !done(iter, st) - I, st = next(iter, st) - # reverse-broadcast the indices - @nexprs $nargs i->(I_i = newindex(I, keep_i, Idefault_i)) - # extract array values - @nexprs $nargs i->(@inbounds val_i = _broadcast_getindex(A_i, I_i)) - # call the function - V = @ncall $nargs f val - # store the result - if V isa eltype(B) - @inbounds B[I] = V - else - # This element type doesn't fit in B. 
Allocate a new B with wider eltype, - # copy over old values, and continue - newB = Base.similar(B, promote_typejoin(eltype(B), typeof(V))) - for II in Iterators.take(iter, count) - newB[II] = B[II] - end - newB[I] = V - return _broadcast!(f, newB, keeps, Idefaults, As, Val(nargs), iter, st, count+1) - end - count += 1 - end - return B - end -end - -maptoTuple(f) = Tuple{} -maptoTuple(f, a, b...) = Tuple{f(a), maptoTuple(f, b...).types...} - -# An element type satisfying for all A: -# broadcast_getindex( -# combine_styles(A), -# A, broadcast_indices(A) -# )::_broadcast_getindex_eltype(A) -_broadcast_getindex_eltype(A) = _broadcast_getindex_eltype(combine_styles(A), A) -_broadcast_getindex_eltype(::BroadcastStyle, A) = eltype(A) # Tuple, Array, etc. -_broadcast_getindex_eltype(::DefaultArrayStyle{0}, ::Ref{T}) where {T} = T - -# Inferred eltype of result of broadcast(f, xs...) -combine_eltypes(f, A, As...) = - Base._return_type(f, maptoTuple(_broadcast_getindex_eltype, A, As...)) +## Broadcasting core """ broadcast(f, As...) @@ -610,77 +689,294 @@ julia> string.(("one","two","three","four"), ": ", 1:4) ``` """ -@inline function broadcast(f, A, Bs...) - A′ = broadcastable(A) - Bs′ = map(broadcastable, Bs) - broadcast(f, combine_styles(A′, Bs′...), nothing, nothing, A′, Bs′...) +broadcast(f::Tf, As...) where {Tf} = copy(instantiate(make(f, As...))) + +# special cases defined for performance +@inline broadcast(f, x::Number...) = f(x...) +@inline broadcast(f, t::NTuple{N,Any}, ts::Vararg{NTuple{N,Any}}) where {N} = map(f, t, ts...) + +""" + broadcast!(f, dest, As...) + +Like [`broadcast`](@ref), but store the result of +`broadcast(f, As...)` in the `dest` array. +Note that `dest` is only used to store the result, and does not supply +arguments to `f` unless it is also listed in the `As`, +as in `broadcast!(f, A, A, B)` to perform `A[:] = broadcast(f, A, B)`. +""" +broadcast!(f::Tf, dest, As::Vararg{Any,N}) where {Tf,N} = (materialize!(dest, make(f, As...)); dest) + +""" + Broadcast.materialize(bc) + +Take a lazy `Broadcasted` object and compute the result +""" +@inline materialize(bc::Broadcasted) = copy(instantiate(bc)) +materialize(x) = x +@inline function materialize!(dest, bc::Broadcasted{Style}) where {Style} + return copyto!(dest, instantiate(Broadcasted{Style}(bc.f, bc.args, axes(dest)))) +end +@inline function materialize!(dest, x) + return copyto!(dest, instantiate(Broadcasted(identity, (x,), axes(dest)))) end -# In the scalar case we unwrap the arguments and just call `f` -@inline broadcast(f, ::AbstractArrayStyle{0}, ::Nothing, ::Nothing, A, Bs...) = f(A[], map(getindex, Bs)...) +## general `copy` methods +@inline copy(bc::Broadcasted{<:AbstractArrayStyle{0}}) = bc[CartesianIndex()] +copy(bc::Broadcasted{<:Union{Nothing,Unknown}}) = + throw(ArgumentError("broadcasting requires an assigned BroadcastStyle")) -@inline broadcast(f, s::BroadcastStyle, ::Nothing, ::Nothing, A, Bs...) = - broadcast(f, s, combine_eltypes(f, A, Bs...), combine_indices(A, Bs...), A, Bs...) +const NonleafHandlingStyles = Union{DefaultArrayStyle,ArrayConflict} -const NonleafHandlingTypes = Union{DefaultArrayStyle,ArrayConflict} +@inline function copy(bc::Broadcasted{Style}) where {Style} + ElType = combine_eltypes(bc.f, bc.args) + if Base.isconcretetype(ElType) + # We can trust it and defer to the simpler `copyto!` + return copyto!(broadcast_similar(Style(), ElType, axes(bc), bc), bc) + end + # When ElType is not concrete, use narrowing. 
Use the first output + # value to determine the starting output eltype; copyto_nonleaf! + # will widen `dest` as needed to accommodate later values. + bc′ = preprocess(nothing, bc) + iter = CartesianIndices(axes(bc′)) + state = start(iter) + if done(iter, state) + # if empty, take the ElType at face value + return broadcast_similar(Style(), ElType, axes(bc′), bc′) + end + # Initialize using the first value + I, state = next(iter, state) + @inbounds val = bc′[I] + dest = broadcast_similar(Style(), typeof(val), axes(bc′), bc′) + @inbounds dest[I] = val + # Now handle the remaining values + return copyto_nonleaf!(dest, bc′, iter, state, 1) +end -@inline function broadcast(f, s::NonleafHandlingTypes, ::Type{ElType}, inds::Indices, As...) where ElType - if !Base.isconcretetype(ElType) - return broadcast_nonleaf(f, s, ElType, inds, As...) +## general `copyto!` methods +# The most general method falls back to a method that replaces Style->Nothing +# This permits specialization on typeof(dest) without introducing ambiguities +@inline copyto!(dest::AbstractArray, bc::Broadcasted) = copyto!(dest, convert(Broadcasted{Nothing}, bc)) + +# Performance optimization for the Scalar case +@inline function copyto!(dest::AbstractArray, bc::Broadcasted{<:AbstractArrayStyle{0}}) + if not_nested(bc) + if bc.f === identity && bc.args isa Tuple{Any} # only a single input argument to broadcast! + # broadcast!(identity, dest, val) is equivalent to fill!(dest, val) + return fill!(dest, bc.args[1][]) + else + args = bc.args + @inbounds for I in eachindex(dest) + dest[I] = bc.f(map(getindex, args)...) + end + return dest + end end - dest = broadcast_similar(f, s, ElType, inds, As...) - broadcast!(f, dest, As...) + # Fall back to the default implementation + return copyto!(dest, instantiate(bc)) end -@inline function broadcast(f, s::BroadcastStyle, ::Type{ElType}, inds::Indices, As...) where ElType - dest = broadcast_similar(f, s, ElType, inds, As...) - broadcast!(f, dest, As...) +# For broadcasted assignments like `broadcast!(f, A, ..., A, ...)`, where `A` +# appears on both the LHS and the RHS of the `.=`, then we know we're only +# going to make one pass through the array, and even though `A` is aliasing +# against itself, the mutations won't affect the result as the indices on the +# LHS and RHS will always match. This is not true in general, but with the `.op=` +# syntax it's fairly common for an argument to be `===` a source. +broadcast_unalias(dest, src) = dest === src ? src : unalias(dest, src) +broadcast_unalias(::Nothing, src) = src + +# Preprocessing a `Broadcasted` does two things: +# * unaliases any arguments from `dest` +# * "extrudes" the arguments where it is advantageous to pre-compute the broadcasted indices +@inline preprocess(dest, bc::Broadcasted{Style}) where {Style} = Broadcasted{Style}(bc.f, preprocess_args(dest, bc.args), bc.axes) +preprocess(dest, x) = extrude(broadcast_unalias(dest, x)) + +@inline preprocess_args(dest, args::Tuple) = (preprocess(dest, args[1]), preprocess_args(dest, tail(args))...) 
+preprocess_args(dest, args::Tuple{Any}) = (preprocess(dest, args[1]),) +preprocess_args(dest, args::Tuple{}) = () + +# Specialize this method if all you want to do is specialize on typeof(dest) +@inline function copyto!(dest::AbstractArray, bc::Broadcasted{Nothing}) + axes(dest) == axes(bc) || throwdm(axes(dest), axes(bc)) + # Performance optimization: broadcast!(identity, dest, A) is equivalent to copyto!(dest, A) if indices match + if bc.f === identity && bc.args isa Tuple{<:AbstractArray} # only a single input argument to broadcast! + A = bc.args[1] + if axes(dest) == axes(A) + return copyto!(dest, A) + end + end + bc′ = preprocess(dest, bc) + @simd for I in CartesianIndices(axes(bc′)) + @inbounds dest[I] = bc′[I] + end + return dest end -# When ElType is not concrete, use narrowing. Use the first element of each input to determine -# the starting output eltype; the _broadcast! method will widen `dest` as needed to -# accommodate later values. -function broadcast_nonleaf(f, s::NonleafHandlingTypes, ::Type{ElType}, shape::Indices, As...) where ElType - nargs = length(As) - iter = CartesianIndices(shape) - if isempty(iter) - return Base.similar(Array{ElType}, shape) +# Performance optimization: for BitArray outputs, we cache the result +# in a "small" Vector{Bool}, and then copy in chunks into the output +function copyto!(dest::BitArray, bc::Broadcasted{Nothing}) + axes(dest) == axes(bc) || throwdm(axes(dest), axes(bc)) + ischunkedbroadcast(dest, bc) && return chunkedcopyto!(dest, bc) + tmp = Vector{Bool}(undef, bitcache_size) + destc = dest.chunks + ind = cind = 1 + bc′ = preprocess(dest, bc) + @simd for I in CartesianIndices(axes(bc′)) + @inbounds tmp[ind] = bc′[I] + ind += 1 + if ind > bitcache_size + dumpbitcache(destc, cind, tmp) + cind += bitcache_chunks + ind = 1 + end end - keeps, Idefaults = map_newindexer(shape, As) - st = start(iter) - I, st = next(iter, st) - val = f([ _broadcast_getindex(As[i], newindex(I, keeps[i], Idefaults[i])) for i=1:nargs ]...) - if val isa Bool - dest = Base.similar(BitArray, shape) - else - dest = Base.similar(Array{typeof(val)}, shape) + if ind > 1 + @inbounds tmp[ind:bitcache_size] = false + dumpbitcache(destc, cind, tmp) end - dest[I] = val - _broadcast!(f, dest, keeps, Idefaults, As, Val(nargs), iter, st, 1) -end - -@inline broadcast(f, ::Style{Tuple}, ::Nothing, ::Nothing, A, Bs...) = - tuplebroadcast(f, longest_tuple(A, Bs...), A, Bs...) -@inline tuplebroadcast(f, ::NTuple{N,Any}, As...) where {N} = - ntuple(k -> f(tuplebroadcast_getargs(As, k)...), Val(N)) -@inline tuplebroadcast(f, ::NTuple{N,Any}, ::Ref{Type{T}}, As...) where {N,T} = - ntuple(k -> f(T, tuplebroadcast_getargs(As, k)...), Val(N)) -longest_tuple(A::Tuple, B::Tuple, Bs...) = longest_tuple(_longest_tuple(A, B), Bs...) -longest_tuple(A, B::Tuple, Bs...) = longest_tuple(B, Bs...) -longest_tuple(A::Tuple, B, Bs...) = longest_tuple(A, Bs...) -longest_tuple(A, B, Bs...) = longest_tuple(Bs...) -longest_tuple(A::Tuple) = A + return dest +end + +# For some BitArray operations, we can work at the level of chunks. The trivial +# implementation just walks over the UInt64 chunks in a linear fashion. +# This requires three things: +# 1. The function must be known to work at the level of chunks +# 2. The only arrays involved must be BitArrays or scalars +# 3. There must not be any broadcasting beyond scalar — all array sizes must match +# We could eventually allow for all broadcasting and other array types, but that +# requires very careful consideration of all the edge effects. 
+const ChunkableOp = Union{typeof(&), typeof(|), typeof(xor), typeof(~)} +const BroadcastedChunkableOp{Style<:Union{Nothing,BroadcastStyle}, Axes, F<:ChunkableOp, Args<:Tuple} = Broadcasted{Style,Axes,F,Args} +ischunkedbroadcast(R, bc::BroadcastedChunkableOp) = ischunkedbroadcast(R, bc.args) +ischunkedbroadcast(R, args) = false +ischunkedbroadcast(R, args::Tuple{<:BitArray,Vararg{Any}}) = size(R) == size(args[1]) && ischunkedbroadcast(R, tail(args)) +ischunkedbroadcast(R, args::Tuple{<:Bool,Vararg{Any}}) = ischunkedbroadcast(R, tail(args)) +ischunkedbroadcast(R, args::Tuple{<:BroadcastedChunkableOp,Vararg{Any}}) = ischunkedbroadcast(R, args[1]) && ischunkedbroadcast(R, tail(args)) +ischunkedbroadcast(R, args::Tuple{}) = true + +liftchunks(::Tuple{}) = () +liftchunks(args::Tuple{<:BitArray,Vararg{Any}}) = (args[1].chunks, liftchunks(tail(args))...) +# Transform scalars to repeated scalars the size of a chunk +liftchunks(args::Tuple{<:Bool,Vararg{Any}}) = (ifelse(args[1], typemax(UInt64), UInt64(0)), liftchunks(tail(args))...) +ithchunk(i) = () +Base.@propagate_inbounds ithchunk(i, c::Vector{UInt64}, args...) = (c[i], ithchunk(i, args...)...) +Base.@propagate_inbounds ithchunk(i, b::UInt64, args...) = (b, ithchunk(i, args...)...) +function chunkedcopyto!(dest::BitArray, bc::Broadcasted) + isempty(dest) && return dest + f = flatten(bc) + args = liftchunks(f.args) + dc = dest.chunks + @simd for i in eachindex(dc) + @inbounds dc[i] = f.f(ithchunk(i, args...)...) + end + @inbounds dc[end] &= Base._msk_end(dest) + return dest +end + + +@noinline throwdm(axdest, axsrc) = + throw(DimensionMismatch("destination axes $axdest are not compatible with source axes $axsrc")) + +function copyto_nonleaf!(dest, bc::Broadcasted, iter, state, count) + T = eltype(dest) + while !done(iter, state) + I, state = next(iter, state) + @inbounds val = bc[I] + S = typeof(val) + if S <: T + @inbounds dest[I] = val + else + # This element type doesn't fit in dest. 
Allocate a new dest with wider eltype, + # copy over old values, and continue + newdest = Base.similar(dest, promote_typejoin(T, S)) + for II in Iterators.take(iter, count) + newdest[II] = dest[II] + end + newdest[I] = val + return copyto_nonleaf!(newdest, bc, iter, state, count+1) + end + count += 1 + end + return dest +end + +## Tuple methods + +@inline copy(bc::Broadcasted{Style{Tuple}}) = + tuplebroadcast(longest_tuple(nothing, bc.args), bc) +@inline tuplebroadcast(::NTuple{N,Any}, bc) where {N} = ntuple(k -> @inbounds(_broadcast_getindex(bc, k)), Val(N)) +# This is a little tricky: find the longest tuple (first arg) within the list of arguments (second arg) +# Start with nothing as a placeholder and go until we find the first tuple in the argument list +longest_tuple(::Nothing, t::Tuple{Tuple,Vararg{Any}}) = longest_tuple(t[1], tail(t)) +# Or recurse through nested broadcast expressions +longest_tuple(::Nothing, t::Tuple{Broadcasted,Vararg{Any}}) = longest_tuple(longest_tuple(nothing, t[1].args), tail(t)) +longest_tuple(::Nothing, t::Tuple) = longest_tuple(nothing, tail(t)) +# And then compare it against all other tuples we find in the argument list or nested broadcasts +longest_tuple(l::Tuple, t::Tuple{Tuple,Vararg{Any}}) = longest_tuple(_longest_tuple(l, t[1]), tail(t)) +longest_tuple(l::Tuple, t::Tuple) = longest_tuple(l, tail(t)) +longest_tuple(l::Tuple, ::Tuple{}) = l +longest_tuple(l::Tuple, t::Tuple{Broadcasted}) = longest_tuple(l, t[1].args) +longest_tuple(l::Tuple, t::Tuple{Broadcasted,Vararg{Any}}) = longest_tuple(longest_tuple(l, t[1].args), tail(t)) # Support only 1-tuples and N-tuples where there are no conflicts in N _longest_tuple(A::Tuple{Any}, B::Tuple{Any}) = A -_longest_tuple(A::NTuple{N,Any}, B::NTuple{N,Any}) where N = A -_longest_tuple(A::NTuple{N,Any}, B::Tuple{Any}) where N = A _longest_tuple(A::Tuple{Any}, B::NTuple{N,Any}) where N = B +_longest_tuple(A::NTuple{N,Any}, B::Tuple{Any}) where N = A +_longest_tuple(A::NTuple{N,Any}, B::NTuple{N,Any}) where N = A @noinline _longest_tuple(A, B) = throw(DimensionMismatch("tuples $A and $B could not be broadcast to a common size")) -tuplebroadcast_getargs(::Tuple{}, k) = () -@inline tuplebroadcast_getargs(As, k) = - (_broadcast_getindex(first(As), k), tuplebroadcast_getargs(tail(As), k)...) 
+## scalar-range broadcast operations ## +# DefaultArrayStyle and \ are not available at the time of range.jl +make(::DefaultArrayStyle{1}, ::typeof(-), r::OrdinalRange) = range(-first(r), step=-step(r), length=length(r)) +make(::DefaultArrayStyle{1}, ::typeof(-), r::StepRangeLen) = StepRangeLen(-r.ref, -r.step, length(r), r.offset) +make(::DefaultArrayStyle{1}, ::typeof(-), r::LinRange) = LinRange(-r.start, -r.stop, length(r)) + +make(::DefaultArrayStyle{1}, ::typeof(+), x::Real, r::AbstractUnitRange) = range(x + first(r), length=length(r)) +make(::DefaultArrayStyle{1}, ::typeof(+), r::AbstractUnitRange, x::Real) = range(first(r) + x, length=length(r)) +# For #18336 we need to prevent promotion of the step type: +make(::DefaultArrayStyle{1}, ::typeof(+), r::AbstractRange, x::Number) = range(first(r) + x, step=step(r), length=length(r)) +make(::DefaultArrayStyle{1}, ::typeof(+), x::Number, r::AbstractRange) = range(x + first(r), step=step(r), length=length(r)) +make(::DefaultArrayStyle{1}, ::typeof(+), r::StepRangeLen{T}, x::Number) where T = + StepRangeLen{typeof(T(r.ref)+x)}(r.ref + x, r.step, length(r), r.offset) +make(::DefaultArrayStyle{1}, ::typeof(+), x::Number, r::StepRangeLen{T}) where T = + StepRangeLen{typeof(x+T(r.ref))}(x + r.ref, r.step, length(r), r.offset) +make(::DefaultArrayStyle{1}, ::typeof(+), r::LinRange, x::Number) = LinRange(r.start + x, r.stop + x, length(r)) +make(::DefaultArrayStyle{1}, ::typeof(+), x::Number, r::LinRange) = LinRange(x + r.start, x + r.stop, length(r)) +make(::DefaultArrayStyle{1}, ::typeof(+), r1::AbstractRange, r2::AbstractRange) = r1 + r2 + +make(::DefaultArrayStyle{1}, ::typeof(-), r::AbstractUnitRange, x::Number) = range(first(r)-x, length=length(r)) +make(::DefaultArrayStyle{1}, ::typeof(-), r::AbstractRange, x::Number) = range(first(r)-x, step=step(r), length=length(r)) +make(::DefaultArrayStyle{1}, ::typeof(-), x::Number, r::AbstractRange) = range(x-first(r), step=-step(r), length=length(r)) +make(::DefaultArrayStyle{1}, ::typeof(-), r::StepRangeLen{T}, x::Number) where T = + StepRangeLen{typeof(T(r.ref)-x)}(r.ref - x, r.step, length(r), r.offset) +make(::DefaultArrayStyle{1}, ::typeof(-), x::Number, r::StepRangeLen{T}) where T = + StepRangeLen{typeof(x-T(r.ref))}(x - r.ref, -r.step, length(r), r.offset) +make(::DefaultArrayStyle{1}, ::typeof(-), r::LinRange, x::Number) = LinRange(r.start - x, r.stop - x, length(r)) +make(::DefaultArrayStyle{1}, ::typeof(-), x::Number, r::LinRange) = LinRange(x - r.start, x - r.stop, length(r)) +make(::DefaultArrayStyle{1}, ::typeof(-), r1::AbstractRange, r2::AbstractRange) = r1 - r2 + +make(::DefaultArrayStyle{1}, ::typeof(*), x::Number, r::AbstractRange) = range(x*first(r), step=x*step(r), length=length(r)) +make(::DefaultArrayStyle{1}, ::typeof(*), x::Number, r::StepRangeLen{T}) where {T} = + StepRangeLen{typeof(x*T(r.ref))}(x*r.ref, x*r.step, length(r), r.offset) +make(::DefaultArrayStyle{1}, ::typeof(*), x::Number, r::LinRange) = LinRange(x * r.start, x * r.stop, r.len) +# separate in case of noncommutative multiplication +make(::DefaultArrayStyle{1}, ::typeof(*), r::AbstractRange, x::Number) = range(first(r)*x, step=step(r)*x, length=length(r)) +make(::DefaultArrayStyle{1}, ::typeof(*), r::StepRangeLen{T}, x::Number) where {T} = + StepRangeLen{typeof(T(r.ref)*x)}(r.ref*x, r.step*x, length(r), r.offset) +make(::DefaultArrayStyle{1}, ::typeof(*), r::LinRange, x::Number) = LinRange(r.start * x, r.stop * x, r.len) + +make(::DefaultArrayStyle{1}, ::typeof(/), r::AbstractRange, x::Number) = 
range(first(r)/x, step=step(r)/x, length=length(r)) +make(::DefaultArrayStyle{1}, ::typeof(/), r::StepRangeLen{T}, x::Number) where {T} = + StepRangeLen{typeof(T(r.ref)/x)}(r.ref/x, r.step/x, length(r), r.offset) +make(::DefaultArrayStyle{1}, ::typeof(/), r::LinRange, x::Number) = LinRange(r.start / x, r.stop / x, r.len) + +make(::DefaultArrayStyle{1}, ::typeof(\), x::Number, r::AbstractRange) = range(x\first(r), step=x\step(r), length=length(r)) +make(::DefaultArrayStyle{1}, ::typeof(\), x::Number, r::StepRangeLen) = StepRangeLen(x\r.ref, x\r.step, length(r), r.offset) +make(::DefaultArrayStyle{1}, ::typeof(\), x::Number, r::LinRange) = LinRange(x \ r.start, x \ r.stop, r.len) + +make(::DefaultArrayStyle{1}, ::typeof(big), r::UnitRange) = big(r.start):big(last(r)) +make(::DefaultArrayStyle{1}, ::typeof(big), r::StepRange) = big(r.start):big(r.step):big(last(r)) +make(::DefaultArrayStyle{1}, ::typeof(big), r::StepRangeLen) = StepRangeLen(big(r.ref), big(r.step), length(r), r.offset) +make(::DefaultArrayStyle{1}, ::typeof(big), r::LinRange) = LinRange(big(r.start), big(r.stop), length(r)) """ @@ -739,16 +1035,14 @@ julia> broadcast_getindex(A, [1 2 1; 1 2 2], [1, 2]) ``` """ broadcast_getindex(src::AbstractArray, I::AbstractArray...) = - broadcast_getindex!(Base.similar(Array{eltype(src)}, combine_indices(I...)), - src, - I...) + broadcast_getindex!(Base.similar(Array{eltype(src)}, combine_axes(I...)), src, I...) @generated function broadcast_getindex!(dest::AbstractArray, src::AbstractArray, I::AbstractArray...) N = length(I) Isplat = Expr[:(I[$d]) for d = 1:N] quote @nexprs $N d->(I_d = I[d]) - check_broadcast_indices(Base.axes(dest), $(Isplat...)) # unnecessary if this function is never called directly + check_broadcast_axes(Base.axes(dest), $(Isplat...)) # unnecessary if this function is never called directly checkbounds(src, $(Isplat...)) @nexprs $N d->(@nexprs $N k->(Ibcast_d_k = Base.axes(I_k, d) == OneTo(1))) @nloops $N i dest d->(@nexprs $N k->(j_d_k = Ibcast_d_k ? 1 : i_d)) begin @@ -779,7 +1073,7 @@ See [`broadcast_getindex`](@ref) for examples of the treatment of `inds`. quote @nexprs $N d->(I_d = I[d]) checkbounds(A, $(Isplat...)) - shape = combine_indices($(Isplat...)) + shape = combine_axes($(Isplat...)) @nextract $N shape d->(length(shape) < d ? OneTo(1) : shape[d]) @nexprs $N d->(@nexprs $N k->(Ibcast_d_k = Base.axes(I_k, d) == 1:1)) if !isa(x, AbstractArray) @@ -892,4 +1186,26 @@ macro __dot__(x) esc(__dot__(x)) end +@inline make_kwsyntax(f, args...; kwargs...) = make((args...)->f(args...; kwargs...), args...) +@inline function make(f, args...) + args′ = map(broadcastable, args) + make(combine_styles(args′...), f, args′...) +end +# Due to the current Type{T}/DataType specialization heuristics within Tuples, +# the totally generic varargs make(f, args...) method above loses Type{T}s in +# mapping broadcastable across the args. These additional methods with explicit +# arguments ensure we preserve Type{T}s in the first or second argument position. +@inline function make(f, arg1, args...) + arg1′ = broadcastable(arg1) + args′ = map(broadcastable, args) + make(combine_styles(arg1′, args′...), f, arg1′, args′...) +end +@inline function make(f, arg1, arg2, args...) + arg1′ = broadcastable(arg1) + arg2′ = broadcastable(arg2) + args′ = map(broadcastable, args) + make(combine_styles(arg1′, arg2′, args′...), f, arg1′, arg2′, args′...) +end +@inline make(::S, f, args...) 
where S<:BroadcastStyle = Broadcasted{S}(f, args) + end # module diff --git a/base/compiler/ssair/inlining2.jl b/base/compiler/ssair/inlining2.jl index 02c2fb435048a..085cb2733fb83 100644 --- a/base/compiler/ssair/inlining2.jl +++ b/base/compiler/ssair/inlining2.jl @@ -56,7 +56,7 @@ function batch_inline!(todo::Vector{InliningTodo}, ir::IRCode, linetable::Vector if first_bb != block new_range = first_bb+1:block - bb_rename[new_range] = (1:length(new_range)) .+ length(new_cfg_blocks) + bb_rename[new_range] = (1+length(new_cfg_blocks)):(length(new_range)+length(new_cfg_blocks)) append!(new_cfg_blocks, map(copy, ir.cfg.blocks[new_range])) push!(merged_orig_blocks, last(new_range)) end @@ -79,12 +79,12 @@ function batch_inline!(todo::Vector{InliningTodo}, ir::IRCode, linetable::Vector orig_succs = copy(new_cfg_blocks[end].succs) empty!(new_cfg_blocks[end].succs) if need_split_before - bb_rename_range = (1:length(inlinee_cfg.blocks)) .+ length(new_cfg_blocks) + bb_rename_range = (1+length(new_cfg_blocks)):(length(inlinee_cfg.blocks)+length(new_cfg_blocks)) push!(new_cfg_blocks[end].succs, length(new_cfg_blocks)+1) append!(new_cfg_blocks, inlinee_cfg.blocks) else # Merge the last block that was already there with the first block we're adding - bb_rename_range = (1:length(inlinee_cfg.blocks)) .+ (length(new_cfg_blocks) - 1) + bb_rename_range = length(new_cfg_blocks):(length(inlinee_cfg.blocks)+length(new_cfg_blocks)-1) append!(new_cfg_blocks[end].succs, inlinee_cfg.blocks[1].succs) append!(new_cfg_blocks, inlinee_cfg.blocks[2:end]) end @@ -130,7 +130,7 @@ function batch_inline!(todo::Vector{InliningTodo}, ir::IRCode, linetable::Vector end end new_range = (first_bb + 1):length(ir.cfg.blocks) - bb_rename[new_range] = (1:length(new_range)) .+ length(new_cfg_blocks) + bb_rename[new_range] = (1+length(new_cfg_blocks)):(length(new_range)+length(new_cfg_blocks)) append!(new_cfg_blocks, ir.cfg.blocks[new_range]) # Rename edges original bbs diff --git a/base/compiler/ssair/slot2ssa.jl b/base/compiler/ssair/slot2ssa.jl index e18faa5ed9a2d..72cb54b13647e 100644 --- a/base/compiler/ssair/slot2ssa.jl +++ b/base/compiler/ssair/slot2ssa.jl @@ -371,12 +371,12 @@ function domsort_ssa!(ir::IRCode, domtree::DomTree) crit_edge_breaks_fixup = Tuple{Int, Int}[] for (new_bb, bb) in pairs(result_order) if bb == 0 - new_bbs[new_bb] = BasicBlock((1:1) .+ bb_start_off, [new_bb-1], [result_stmts[bb_start_off].dest]) + new_bbs[new_bb] = BasicBlock((bb_start_off+1):(bb_start_off+1), [new_bb-1], [result_stmts[bb_start_off].dest]) bb_start_off += 1 continue end old_inst_range = ir.cfg.blocks[bb].stmts - inst_range = (1:length(old_inst_range)) .+ bb_start_off + inst_range = (bb_start_off+1):(bb_start_off+length(old_inst_range)) inst_rename[old_inst_range] = Any[SSAValue(x) for x in inst_range] for (nidx, idx) in zip(inst_range, old_inst_range) stmt = ir.stmts[idx] diff --git a/base/deprecated.jl b/base/deprecated.jl index ec73aa06b4950..8efe89cbe319c 100644 --- a/base/deprecated.jl +++ b/base/deprecated.jl @@ -1115,6 +1115,10 @@ end @deprecate indices(a) axes(a) @deprecate indices(a, d) axes(a, d) +# And similar _indices names in Broadcast +@eval Broadcast Base.@deprecate_binding broadcast_indices broadcast_axes true +@eval Broadcast Base.@deprecate_binding check_broadcast_indices check_broadcast_axes false + # PR #25046 export reload, workspace reload(name::AbstractString) = error("`reload($(repr(name)))` is discontinued, consider Revise.jl for an alternative workflow.") diff --git a/base/float.jl b/base/float.jl index 
6b468f238279d..db86677c6a8a0 100644 --- a/base/float.jl +++ b/base/float.jl @@ -875,13 +875,3 @@ float(r::StepRangeLen{T}) where {T} = function float(r::LinRange) LinRange(float(r.start), float(r.stop), length(r)) end - -# big, broadcast over arrays -# TODO: do the definitions below primarily pertaining to integers belong in float.jl? -function big end # no prior definitions of big in sysimg.jl, necessitating this -broadcast(::typeof(big), r::UnitRange) = big(r.start):big(last(r)) -broadcast(::typeof(big), r::StepRange) = big(r.start):big(r.step):big(last(r)) -broadcast(::typeof(big), r::StepRangeLen) = StepRangeLen(big(r.ref), big(r.step), length(r), r.offset) -function broadcast(::typeof(big), r::LinRange) - LinRange(big(r.start), big(r.stop), length(r)) -end diff --git a/base/range.jl b/base/range.jl index 950e44e300a55..91a8060ced324 100644 --- a/base/range.jl +++ b/base/range.jl @@ -719,67 +719,6 @@ end StepRangeLen{T,R,S}(-r.ref, -r.step, length(r), r.offset) -(r::LinRange) = LinRange(-r.start, -r.stop, length(r)) -*(x::Number, r::AbstractRange) = range(x*first(r), step=x*step(r), length=length(r)) -*(x::Number, r::StepRangeLen{T}) where {T} = - StepRangeLen{typeof(x*T(r.ref))}(x*r.ref, x*r.step, length(r), r.offset) -*(x::Number, r::LinRange) = LinRange(x * r.start, x * r.stop, r.len) -# separate in case of noncommutative multiplication -*(r::AbstractRange, x::Number) = range(first(r)*x, step=step(r)*x, length=length(r)) -*(r::StepRangeLen{T}, x::Number) where {T} = - StepRangeLen{typeof(T(r.ref)*x)}(r.ref*x, r.step*x, length(r), r.offset) -*(r::LinRange, x::Number) = LinRange(r.start * x, r.stop * x, r.len) - -/(r::AbstractRange, x::Number) = range(first(r)/x, step=step(r)/x, length=length(r)) -/(r::StepRangeLen{T}, x::Number) where {T} = - StepRangeLen{typeof(T(r.ref)/x)}(r.ref/x, r.step/x, length(r), r.offset) -/(r::LinRange, x::Number) = LinRange(r.start / x, r.stop / x, r.len) -# also, separate in case of noncommutative multiplication (division) -\(x::Number, r::AbstractRange) = range(x\first(r), step=x\step(r), length=x\length(r)) -\(x::Number, r::StepRangeLen) = StepRangeLen(x\r.ref, x\r.step, length(r), r.offset) -\(x::Number, r::LinRange) = LinRange(x \ r.start, x \ r.stop, r.len) - -## scalar-range broadcast operations ## - -broadcast(::typeof(-), r::OrdinalRange) = range(-first(r), step=-step(r), length=length(r)) -broadcast(::typeof(-), r::StepRangeLen) = StepRangeLen(-r.ref, -r.step, length(r), r.offset) -broadcast(::typeof(-), r::LinRange) = LinRange(-r.start, -r.stop, length(r)) - -broadcast(::typeof(+), x::Real, r::AbstractUnitRange) = range(x + first(r), length=length(r)) -# For #18336 we need to prevent promotion of the step type: -broadcast(::typeof(+), x::Number, r::AbstractUnitRange) = range(x + first(r), step=step(r), length=length(r)) -broadcast(::typeof(+), x::Number, r::AbstractRange) = (x+first(r)):step(r):(x+last(r)) -function broadcast(::typeof(+), x::Number, r::StepRangeLen{T}) where T - newref = x + r.ref - StepRangeLen{typeof(T(r.ref) + x)}(newref, r.step, length(r), r.offset) -end -function broadcast(::typeof(+), x::Number, r::LinRange) - LinRange(x + r.start, x + r.stop, r.len) -end -broadcast(::typeof(+), r::AbstractRange, x::Number) = broadcast(+, x, r) # assumes addition is commutative - -broadcast(::typeof(-), x::Number, r::AbstractRange) = (x-first(r)):-step(r):(x-last(r)) -broadcast(::typeof(-), x::Number, r::StepRangeLen) = broadcast(+, x, -r) -function broadcast(::typeof(-), x::Number, r::LinRange) - LinRange(x - r.start, x - r.stop, r.len) 
-end - -broadcast(::typeof(-), r::AbstractRange, x::Number) = broadcast(+, -x, r) # assumes addition is commutative - -broadcast(::typeof(*), x::Number, r::AbstractRange) = range(x*first(r), step=x*step(r), length=length(r)) -broadcast(::typeof(*), x::Number, r::StepRangeLen) = StepRangeLen(x*r.ref, x*r.step, length(r), r.offset) -broadcast(::typeof(*), x::Number, r::LinRange) = LinRange(x * r.start, x * r.stop, r.len) -# separate in case of noncommutative multiplication -broadcast(::typeof(*), r::AbstractRange, x::Number) = range(first(r)*x, step=step(r)*x, length=length(r)) -broadcast(::typeof(*), r::StepRangeLen, x::Number) = StepRangeLen(r.ref*x, r.step*x, length(r), r.offset) -broadcast(::typeof(*), r::LinRange, x::Number) = LinRange(r.start * x, r.stop * x, r.len) - -broadcast(::typeof(/), r::AbstractRange, x::Number) = range(first(r)/x, step=step(r)/x, length=length(r)) -broadcast(::typeof(/), r::StepRangeLen, x::Number) = StepRangeLen(r.ref/x, r.step/x, length(r), r.offset) -broadcast(::typeof(/), r::LinRange, x::Number) = LinRange(r.start / x, r.stop / x, r.len) -# also, separate in case of noncommutative multiplication (division) -broadcast(::typeof(\), x::Number, r::AbstractRange) = range(x\first(r), step=x\step(r), length=x\length(r)) -broadcast(::typeof(\), x::Number, r::StepRangeLen) = StepRangeLen(x\r.ref, x\r.step, length(r), r.offset) -broadcast(::typeof(\), x::Number, r::LinRange) = LinRange(x \ r.start, x \ r.stop, r.len) # promote eltype if at least one container wouldn't change, otherwise join container types. el_same(::Type{T}, a::Type{<:AbstractArray{T,n}}, b::Type{<:AbstractArray{T,n}}) where {T,n} = a @@ -851,8 +790,6 @@ promote_rule(a::Type{LinRange{T}}, ::Type{OR}) where {T,OR<:OrdinalRange} = promote_rule(::Type{LinRange{L}}, b::Type{StepRangeLen{T,R,S}}) where {L,T,R,S} = promote_rule(StepRangeLen{L,L,L}, b) -# +/- of ranges is defined in operators.jl (to be able to use @eval etc.) - ## concatenation ## function vcat(rs::AbstractRange{T}...) where T @@ -960,6 +897,3 @@ function +(r1::StepRangeLen{T,S}, r2::StepRangeLen{T,S}) where {T,S} end -(r1::StepRangeLen, r2::StepRangeLen) = +(r1, -r2) - -broadcast(::typeof(+), r1::AbstractRange, r2::AbstractRange) = r1 + r2 -broadcast(::typeof(-), r1::AbstractRange, r2::AbstractRange) = r1 - r2 diff --git a/base/reducedim.jl b/base/reducedim.jl index a556fbe3667aa..2b0dc06321cad 100644 --- a/base/reducedim.jl +++ b/base/reducedim.jl @@ -218,7 +218,7 @@ function _mapreducedim!(f, op, R::AbstractArray, A::AbstractArray) return R end indsAt, indsRt = safe_tail(axes(A)), safe_tail(axes(R)) # handle d=1 manually - keep, Idefault = Broadcast.shapeindexer(indsAt, indsRt) + keep, Idefault = Broadcast.shapeindexer(indsRt) if reducedim1(R, A) # keep the accumulator as a local variable when reducing along the first dimension i1 = first(indices1(R)) @@ -667,7 +667,7 @@ function findminmax!(f, Rval, Rind, A::AbstractArray{T,N}) where {T,N} # If we're reducing along dimension 1, for efficiency we can make use of a temporary. # Otherwise, keep the result in Rval/Rind so that we traverse A in storage order. 
indsAt, indsRt = safe_tail(axes(A)), safe_tail(axes(Rval)) - keep, Idefault = Broadcast.shapeindexer(indsAt, indsRt) + keep, Idefault = Broadcast.shapeindexer(indsRt) ks = keys(A) k, kss = next(ks, start(ks)) zi = zero(eltype(ks)) diff --git a/base/sort.jl b/base/sort.jl index 586ec4dedbf53..adf90a5162b78 100644 --- a/base/sort.jl +++ b/base/sort.jl @@ -95,9 +95,12 @@ issorted(itr; function partialsort!(v::AbstractVector, k::Union{Int,OrdinalRange}, o::Ordering) inds = axes(v, 1) sort!(v, first(inds), last(inds), PartialQuickSort(k), o) - @views v[k] + maybeview(v, k) end +maybeview(v, k) = view(v, k) +maybeview(v, k::Integer) = v[k] + """ partialsort!(v, k; by=, lt=, rev=false) @@ -716,7 +719,7 @@ function partialsortperm!(ix::AbstractVector{<:Integer}, v::AbstractVector, # do partial quicksort sort!(ix, PartialQuickSort(k), Perm(ord(lt, by, rev, order), v)) - @views ix[k] + maybeview(ix, k) end ## sortperm: the permutation to sort an array ## diff --git a/base/statistics.jl b/base/statistics.jl index 3b0bbb5b9f9ac..350e64639a034 100644 --- a/base/statistics.jl +++ b/base/statistics.jl @@ -145,7 +145,7 @@ function centralize_sumabs2!(R::AbstractArray{S}, A::AbstractArray, means::Abstr return R end indsAt, indsRt = safe_tail(axes(A)), safe_tail(axes(R)) # handle d=1 manually - keep, Idefault = Broadcast.shapeindexer(indsAt, indsRt) + keep, Idefault = Broadcast.shapeindexer(indsRt) if reducedim1(R, A) i1 = first(indices1(R)) @inbounds for IA in CartesianIndices(indsAt) diff --git a/doc/src/base/arrays.md b/doc/src/base/arrays.md index 3357292838d57..2b1a8d7236e56 100644 --- a/doc/src/base/arrays.md +++ b/doc/src/base/arrays.md @@ -69,7 +69,7 @@ For specializing broadcast on custom types, see ```@docs Base.BroadcastStyle Base.broadcast_similar -Base.broadcast_indices +Base.broadcast_axes Base.Broadcast.AbstractArrayStyle Base.Broadcast.ArrayStyle Base.Broadcast.DefaultArrayStyle diff --git a/doc/src/manual/interfaces.md b/doc/src/manual/interfaces.md index d13bd756da854..e237818334ed7 100644 --- a/doc/src/manual/interfaces.md +++ b/doc/src/manual/interfaces.md @@ -435,22 +435,22 @@ V = view(A, [1,2,4], :) # is not strided, as the spacing between rows is not f -## [Broadcasting](@id man-interfaces-broadcasting) +## [Customizing broadcasting](@id man-interfaces-broadcasting) | Methods to implement | Brief description | |:-------------------- |:----------------- | | `Base.BroadcastStyle(::Type{SrcType}) = SrcStyle()` | Broadcasting behavior of `SrcType` | -| `Base.broadcast_similar(f, ::DestStyle, ::Type{ElType}, inds, As...)` | Allocation of output container | +| `Base.broadcast_similar(::DestStyle, ::Type{ElType}, inds, bc)` | Allocation of output container | | **Optional methods** | | | | `Base.BroadcastStyle(::Style1, ::Style2) = Style12()` | Precedence rules for mixing styles | -| `Base.broadcast_indices(::StyleA, A)` | Declaration of the indices of `A` for broadcasting purposes (defaults to [`axes(A)`](@ref)) | +| `Base.broadcast_axes(::StyleA, A)` | Declaration of the indices of `A` for broadcasting purposes (defaults to [`axes(A)`](@ref)) | | `Base.broadcastable(x)` | Convert `x` to an object that has `axes` and supports indexing | | **Bypassing default machinery** | | -| `broadcast(f, As...)` | Complete bypass of broadcasting machinery | -| `broadcast(f, ::DestStyle, ::Nothing, ::Nothing, As...)` | Bypass after container type is computed | -| `broadcast(f, ::DestStyle, ::Type{ElType}, inds::Tuple, As...)` | Bypass after container type, eltype, and indices are computed | -| 
`broadcast!(f, dest::DestType, ::Nothing, As...)` | Bypass in-place broadcast, specialization on destination type | -| `broadcast!(f, dest, ::BroadcastStyle, As...)` | Bypass in-place broadcast, specialization on `BroadcastStyle` | +| `Base.copy(bc::Broadcasted{DestStyle})` | Custom implementation of `broadcast` | +| `Base.copyto!(dest, bc::Broadcasted{DestStyle})` | Custom implementation of `broadcast!`, specializing on `DestStyle` | +| `Base.copyto!(dest::DestType, bc::Broadcasted{Nothing})` | Custom implementation of `broadcast!`, specializing on `DestType` | +| `Base.Broadcast.make(f, args...)` | Override the default lazy behavior within a fused expression | +| `Base.Broadcast.instantiate(bc::Broadcasted{DestStyle})` | Override the computation of the wrapper's axes and indexers | [Broadcasting](@ref) is triggered by an explicit call to `broadcast` or `broadcast!`, or implicitly by "dot" operations like `A .+ b` or `f.(x, y)`. Any object that has [`axes`](@ref) and supports @@ -463,16 +463,16 @@ in an `Array`. This basic framework is extensible in three major ways: Not all types support `axes` and indexing, but many are convenient to allow in broadcast. The [`Base.broadcastable`](@ref) function is called on each argument to broadcast, allowing -it to return something different that supports `axes` and indexing if it does not. By +it to return something different that supports `axes` and indexing. By default, this is the identity function for all `AbstractArray`s and `Number`s — they already support `axes` and indexing. For a handful of other types (including but not limited to types themselves, functions, special singletons like `missing` and `nothing`, and dates), `Base.broadcastable` returns the argument wrapped in a `Ref` to act as a 0-dimensional "scalar" for the purposes of broadcasting. Custom types can similarly specialize `Base.broadcastable` to define their shape, but they should follow the convention that -`collect(Base.broadcastable(x)) == collect(x)`. A notable exception are `AbstractString`s; -they are special-cased to behave as scalars for the purposes of broadcast even though they -are iterable collections of their characters. +`collect(Base.broadcastable(x)) == collect(x)`. A notable exception is `AbstractString`; +strings are special-cased to behave as scalars for the purposes of broadcast even though +they are iterable collections of their characters. The next two steps (selecting the output array and implementation) are dependent upon determining a single answer for a given set of arguments. Broadcast must take all the varied @@ -483,12 +483,11 @@ styles into a single answer — the "destination style". ### Broadcast Styles -`Base.BroadcastStyle` is the abstract type from which all styles are -derived. When used as a function it has two possible forms, -unary (single-argument) and binary. -The unary variant states that you intend to -implement specific broadcasting behavior and/or output type, -and do not wish to rely on the default fallback ([`Broadcast.DefaultArrayStyle`](@ref)). +`Base.BroadcastStyle` is the abstract type from which all broadcast styles are derived. When used as a +function it has two possible forms, unary (single-argument) and binary. The unary variant states +that you intend to implement specific broadcasting behavior and/or output type, and do not wish to +rely on the default fallback [`Broadcast.DefaultArrayStyle`](@ref). 
+ To override these defaults, you can define a custom `BroadcastStyle` for your object: ```julia @@ -507,27 +506,30 @@ leverage one of the general broadcast wrappers: When your broadcast operation involves several arguments, individual argument styles get combined to determine a single `DestStyle` that controls the type of the output container. -For more detail, see [below](@ref writing-binary-broadcasting-rules). +For more details, see [below](@ref writing-binary-broadcasting-rules). ### Selecting an appropriate output array -The actual allocation of the result array is handled by `Base.broadcast_similar`: +The broadcast style is computed for every broadcasting operation to allow for +dispatch and specialization. The actual allocation of the result array is +handled by `Base.broadcast_similar`, using this style as its first argument. ```julia -Base.broadcast_similar(f, ::DestStyle, ::Type{ElType}, inds, As...) +Base.broadcast_similar(::DestStyle, ::Type{ElType}, inds, bc) ``` -`f` is the operation being performed and `DestStyle` signals the final result from -combining the input styles. -`As...` is the list of input objects. You may not need to use `f` or `As...` -unless they help you build the appropriate object; the fallback definition is +The fallback definition is ```julia -broadcast_similar(f, ::DefaultArrayStyle{N}, ::Type{ElType}, inds::Indices{N}, As...) where {N,ElType} = +broadcast_similar(::DefaultArrayStyle{N}, ::Type{ElType}, inds::Indices{N}, bc) where {N,ElType} = similar(Array{ElType}, inds) ``` -However, if needed you can specialize on any or all of these arguments. +However, if needed you can specialize on any or all of these arguments. The final argument +`bc` is a lazy representation of a (potentially fused) broadcast operation, a `Broadcasted` +object. For these purposes, the most important fields of the wrapper are +`f` and `args`, describing the function and argument list, respectively. Note that the argument +list can — and often does — include other nested `Broadcasted` wrappers. For a complete example, let's say you have created a type, `ArrayAndChar`, that stores an array and a single character: @@ -553,20 +555,21 @@ Base.BroadcastStyle(::Type{<:ArrayAndChar}) = Broadcast.ArrayStyle{ArrayAndChar} ``` -This forces us to also define a `broadcast_similar` method: -```jldoctest ArrayAndChar; filter = r"(^find_aac \(generic function with 2 methods\)$|^$)" -function Base.broadcast_similar(f, ::Broadcast.ArrayStyle{ArrayAndChar}, ::Type{ElType}, inds, As...) where ElType +This means we must also define a corresponding `broadcast_similar` method: +```jldoctest +function Base.broadcast_similar(::Broadcast.ArrayStyle{ArrayAndChar}, ::Type{ElType}, inds, bc) where ElType # Scan the inputs for the ArrayAndChar: - A = find_aac(As...) + A = find_aac(bc) # Use the char field of A to create the output ArrayAndChar(similar(Array{ElType}, inds), A.char) end -"`A = find_aac(As...)` returns the first ArrayAndChar among the arguments." -find_aac(A::ArrayAndChar, B...) = A -find_aac(A, B...) = find_aac(B...); -# output - +"`A = find_aac(As)` returns the first ArrayAndChar among the arguments." 
+find_aac(bc::Base.Broadcast.Broadcasted) = find_aac(bc.args) +find_aac(args::Tuple) = find_aac(find_aac(args[1]), Base.tail(args)) +find_aac(x) = x +find_aac(a::ArrayAndChar, rest) = a +find_aac(::Any, rest) = find_aac(rest) ``` From these definitions, one obtains the following behavior: @@ -589,58 +592,86 @@ julia> a .+ [5,10] ### [Extending broadcast with custom implementations](@id extending-in-place-broadcast) -Finally, it's worth noting that sometimes it's easier simply to bypass the machinery for -computing result types and container sizes, and just do everything manually. For example, -you can convert a `UnitRange{Int}` `r` to a `UnitRange{BigInt}` with `big.(r)`; the definition -of this method is approximately +In general, a broadcast operation is represented by a lazy `Broadcasted` container that holds onto +the function to be applied alongside its arguments. Those arguments may themselves be more nested +`Broadcasted` containers, forming a large expression tree to be evaluated. A nested tree of +`Broadcasted` containers is directly constructed by the implicit dot syntax; `5 .+ 2.*x` is +transiently represented by `Broadcasted(+, 5, Broadcasted(*, 2, x))`, for example. This is +invisible to users as it is immediately realized through a call to `copy`, but it is this container +that provides the basis for broadcast's extensibility for authors of custom types. The built-in +broadcast machinery will then determine the result type and size based upon the arguments, allocate +it, and then finally copy the realization of the `Broadcasted` object into it with a default +`copyto!(::AbstractArray, ::Broadcasted)` method. The built-in fallback `broadcast` and +`broadcast!` methods similarly construct a transient `Broadcasted` representation of the operation +so they can follow the same codepath. This allows custom array implementations to +provide their own `copyto!` specialization to customize and +optimize broadcasting. This is again determined by the computed broadcast style. This is such +an important part of the operation that it is stored as the first type parameter of the +`Broadcasted` type, allowing for dispatch and specialization. + +For some types, the machinery to "fuse" operations across nested levels of broadcasting +is not available or could be done more efficiently incrementally. In such cases, you may +need or want to evaluate `x .* (x .+ 1)` as if it had been +written `broadcast(*, x, broadcast(+, x, 1))`, where the inner operation is evaluated before +tackling the outer operation. This sort of eager operation is directly supported by a bit +of indirection; instead of directly constructing `Broadcasted` objects, Julia lowers the +fused expression `x .* (x .+ 1)` to `Broadcast.make(*, x, Broadcast.make(+, x, 1))`. Now, +by default, `make` just calls the `Broadcasted` constructor to create the lazy representation +of the fused expression tree, but you can choose to override it for a particular combination +of function and arguments. + +As an example, the builtin `AbstractRange` objects use this machinery to optimize pieces +of broadcasted expressions that can be eagerly evaluated purely in terms of the start, +step, and length (or stop) instead of computing every single element. Just like all the +other machinery, `make` also computes and exposes the combined broadcast style of its +arguments, so instead of specializing on `make(f, args...)`, you can specialize on +`make(::DestStyle, f, args...)` for any combination of style, function, and arguments. 
+ +For example, the following definition supports the negation of ranges: ```julia -Broadcast.broadcast(::typeof(big), r::UnitRange) = big(first(r)):big(last(r)) +make(::DefaultArrayStyle{1}, ::typeof(-), r::OrdinalRange) = range(-first(r), step=-step(r), length=length(r)) ``` -This exploits Julia's ability to dispatch on a particular function type. (This kind of -explicit definition can indeed be necessary if the output container does not support `setindex!`.) -You can optionally choose to implement the actual broadcasting yourself, but allow -the internal machinery to compute the container type, element type, and indices by specializing - -```julia -Broadcast.broadcast(::typeof(somefunction), ::MyStyle, ::Type{ElType}, inds, As...) -``` +### [Extending in-place broadcasting](@id extending-in-place-broadcast) -Extending `broadcast!` (in-place broadcast) should be done with care, as it is easy to introduce -ambiguities between packages. To avoid these ambiguities, we adhere to the following conventions. - -First, if you want to specialize on the destination type, say `DestType`, then you should -define a method with the following signature: +In-place broadcasting can be supported by defining the appropriate `copyto!(dest, bc::Broadcasted)` +method. Because you might want to specialize either on `dest` or the specific subtype of `bc`, +to avoid ambiguities between packages we recommend the following convention. +If you wish to specialize on a particular style `DestStyle`, define a method for ```julia -broadcast!(f, dest::DestType, ::Nothing, As...) +copyto!(dest, bc::Broadcasted{DestStyle}) ``` +Optionally, with this form you can also specialize on the type of `dest`. -Note that no bounds should be placed on the types of `f` and `As...`. - -Second, if specialized `broadcast!` behavior is desired depending on the input types, -you should write [binary broadcasting rules](@ref writing-binary-broadcasting-rules) to -determine a custom `BroadcastStyle` given the input types, say `MyBroadcastStyle`, and you should define a method with the following -signature: +If instead you want to specialize on the destination type `DestType` without specializing +on `DestStyle`, then you should define a method with the following signature: ```julia -broadcast!(f, dest, ::MyBroadcastStyle, As...) +copyto!(dest::DestType, bc::Broadcasted{Nothing}) ``` -Note the lack of bounds on `f`, `dest`, and `As...`. +This leverages a fallback implementation of `copyto!` that converts the wrapper into a +`Broadcasted{Nothing}`. Consequently, specializing on `DestType` has lower precedence than +methods that specialize on `DestStyle`. -Third, simultaneously specializing on both the type of `dest` and the `BroadcastStyle` is fine. In this case, -it is also allowed to specialize on the types of the source arguments (`As...`). For example, these method signatures are OK: +Similarly, you can completely override out-of-place broadcasting with a `copy(::Broadcasted)` +method. -```julia -broadcast!(f, dest::DestType, ::MyBroadcastStyle, As...) -broadcast!(f, dest::DestType, ::MyBroadcastStyle, As::AbstractArray...) -broadcast!(f, dest::DestType, ::Broadcast.DefaultArrayStyle{0}, As::Number...) -``` +#### Working with `Broadcasted` objects + +In order to implement such a `copy` or `copyto!`, method, of course, you must +work with the `Broadcasted` wrapper to compute each element. 
There are two main +ways of doing so: +* `Broadcast.flatten` recomputes the potentially nested operation into a single + function and flat list of arguments. You are responsible for implementing the + broadcasting shape rules yourself, but this may be helpful in limited situations. +* Iterating over the `CartesianIndices` of the `axes(::Broadcasted)` and using + indexing with the resulting `CartesianIndex` object to compute the result. -#### [Writing binary broadcasting rules](@id writing-binary-broadcasting-rules) +### [Writing binary broadcasting rules](@id writing-binary-broadcasting-rules) The precedence rules are defined by binary `BroadcastStyle` calls: diff --git a/src/julia-syntax.scm b/src/julia-syntax.scm index d911737cb088c..ea7956bb1f369 100644 --- a/src/julia-syntax.scm +++ b/src/julia-syntax.scm @@ -1671,53 +1671,11 @@ `(block ,@stmts ,nuref)) expr)) -; fuse nested calls to expr == f.(args...) into a single broadcast call, +; lazily fuse nested calls to expr == f.(args...) into a single broadcast call, ; or a broadcast! call if lhs is non-null. (define (expand-fuse-broadcast lhs rhs) (define (fuse? e) (and (pair? e) (eq? (car e) 'fuse))) - (define (anyfuse? exprs) - (if (null? exprs) #f (if (fuse? (car exprs)) #t (anyfuse? (cdr exprs))))) - (define (to-lambda f args kwargs) ; convert f to anonymous function with hygienic tuple args - (define (genarg arg) (if (vararg? arg) (list '... (gensy)) (gensy))) - ; (To do: optimize the case where f is already an anonymous function, in which - ; case we only need to hygienicize the arguments? But it is quite tricky - ; to fully handle splatted args, typed args, keywords, etcetera. And probably - ; the extra function call is harmless because it will get inlined anyway.) - (let ((genargs (map genarg args))) ; hygienic formal parameters - (if (null? kwargs) - `(-> ,(cons 'tuple genargs) (call ,f ,@genargs)) ; no keyword args - `(-> ,(cons 'tuple genargs) (call ,f (parameters ,@kwargs) ,@genargs))))) - (define (from-lambda f) ; convert (-> (tuple args...) (call func args...)) back to func - (if (and (pair? f) (eq? (car f) '->) (pair? (cadr f)) (eq? (caadr f) 'tuple) - (pair? (caddr f)) (eq? (caaddr f) 'call) (equal? (cdadr f) (cdr (cdaddr f)))) - (car (cdaddr f)) - f)) - (define (fuse-args oldargs) ; replace (fuse f args) with args in oldargs list - (define (fargs newargs oldargs) - (if (null? oldargs) - newargs - (fargs (if (fuse? (car oldargs)) - (append (reverse (caddar oldargs)) newargs) - (cons (car oldargs) newargs)) - (cdr oldargs)))) - (reverse (fargs '() oldargs))) - (define (fuse-funcs f args) ; for (fuse g a) in args, merge/inline g into f - ; any argument A of f that is (fuse g a) gets replaced by let A=(body of g): - (define (fuse-lets fargs args lets) - (if (null? args) - lets - (if (fuse? (car args)) - (fuse-lets (cdr fargs) (cdr args) (cons (list '= (car fargs) (caddr (cadar args))) lets)) - (fuse-lets (cdr fargs) (cdr args) lets)))) - (let ((fargs (cdadr f)) - (fbody (caddr f))) - `(-> - (tuple ,@(fuse-args (map (lambda (oldarg arg) (if (fuse? arg) - `(fuse _ ,(cdadr (cadr arg))) - oldarg)) - fargs args))) - (let (block ,@(reverse (fuse-lets fargs args '()))) ,fbody)))) - (define (dot-to-fuse e) ; convert e == (. f (tuple args)) to (fuse f args) + (define (dot-to-fuse e (top #f)) ; convert e == (. 
f (tuple args)) to (fuse f args) (define (make-fuse f args) ; check for nested (fuse f args) exprs and combine (define (split-kwargs args) ; return (cons keyword-args positional-args) extracted from args (define (sk args kwargs pargs) @@ -1729,78 +1687,43 @@ (if (has-parameters? args) (sk (reverse (cdr args)) (cdar args) '()) (sk (reverse args) '() '()))) - (let* ((kws.args (split-kwargs args)) - (kws (car kws.args)) - (args (cdr kws.args)) ; fusing occurs on positional args only - (args_ (map dot-to-fuse args))) - (if (anyfuse? args_) - `(fuse ,(fuse-funcs (to-lambda f args kws) args_) ,(fuse-args args_)) - `(fuse ,(to-lambda f args kws) ,args_)))) + (let* ((kws+args (split-kwargs args)) ; fusing occurs on positional args only + (kws (car kws+args)) + (kws (if (null? kws) kws (list (cons 'parameters kws)))) + (args (map dot-to-fuse (cdr kws+args))) + (make `(call (|.| (top Broadcast) ,(if (null? kws) ''make ''make_kwsyntax)) ,@kws ,f ,@args))) + (if top (cons 'fuse make) make))) (if (and (pair? e) (eq? (car e) '|.|)) (let ((f (cadr e)) (x (caddr e))) (cond ((or (atom? x) (eq? (car x) 'quote) (eq? (car x) 'inert) (eq? (car x) '$)) `(call (top getproperty) ,f ,x)) ((eq? (car x) 'tuple) - (make-fuse f (cdr x))) + (if (and (eq? f '^) (length= x 3) (integer? (caddr x))) + (make-fuse (expand-forms '(top literal_pow)) + (list '^ (cadr x) (expand-forms `(call (call (core apply_type) (top Val) ,(caddr x)))))) + (make-fuse f (cdr x)))) (else (error (string "invalid syntax \"" (deparse e) "\""))))) (if (and (pair? e) (eq? (car e) 'call) (dotop? (cadr e))) - (make-fuse (undotop (cadr e)) (cddr e)) + (let ((f (undotop (cadr e))) (x (cddr e))) + (if (and (eq? f '^) (length= x 2) (integer? (cadr x))) + (make-fuse (expand-forms '(top literal_pow)) + (list '^ (car x) (expand-forms `(call (call (core apply_type) (top Val) ,(cadr x)))))) + (make-fuse f x))) e))) - ; given e == (fuse lambda args), compress the argument list by removing (pure) - ; duplicates in args, inlining literals, and moving any varargs to the end: - (define (compress-fuse e) - (define (findfarg arg args fargs) ; for arg in args, return corresponding farg - (if (eq? arg (car args)) - (car fargs) - (findfarg arg (cdr args) (cdr fargs)))) - (if (fuse? e) - (let ((f (cadr e)) - (args (caddr e))) - (define (cf old-fargs old-args new-fargs new-args renames varfarg vararg) - (if (null? old-args) - (let ((nfargs (if (null? varfarg) new-fargs (cons varfarg new-fargs))) - (nargs (if (null? vararg) new-args (cons vararg new-args)))) - `(fuse (-> (tuple ,@(reverse nfargs)) ,(replace-vars (caddr f) renames)) - ,(reverse nargs))) - (let ((farg (car old-fargs)) (arg (car old-args))) - (cond - ((and (vararg? farg) (vararg? arg)) ; arg... must be the last argument - (if (null? varfarg) - (cf (cdr old-fargs) (cdr old-args) - new-fargs new-args renames farg arg) - (if (eq? (cadr vararg) (cadr arg)) - (cf (cdr old-fargs) (cdr old-args) - new-fargs new-args (cons (cons (cadr farg) (cadr varfarg)) renames) - varfarg vararg) - (error "multiple splatted args cannot be fused into a single broadcast")))) - ((julia-scalar? arg) ; inline numeric literals etc. - (cf (cdr old-fargs) (cdr old-args) - new-fargs new-args - (cons (cons farg arg) renames) - varfarg vararg)) - ((and (symbol? arg) (memq arg new-args)) ; combine duplicate args - ; (note: calling memq for every arg is O(length(args)^2) ... - ; ... 
would be better to replace with a hash table if args is long) - (cf (cdr old-fargs) (cdr old-args) - new-fargs new-args - (cons (cons farg (findfarg arg new-args new-fargs)) renames) - varfarg vararg)) - (else - (cf (cdr old-fargs) (cdr old-args) - (cons farg new-fargs) (cons arg new-args) renames varfarg vararg)))))) - (cf (cdadr f) args '() '() '() '() '())) - e)) ; (not (fuse? e)) - (let ((e (compress-fuse (dot-to-fuse rhs))) ; an expression '(fuse func args) if expr is a dot call + (let ((e (dot-to-fuse rhs #t)) ; an expression '(fuse func args) if expr is a dot call (lhs-view (ref-to-view lhs))) ; x[...] expressions on lhs turn in to view(x, ...) to update x in-place (if (fuse? e) + ; expanded to a fuse op call (if (null? lhs) - (expand-forms `(call (top broadcast) ,(from-lambda (cadr e)) ,@(caddr e))) - (expand-forms `(call (top broadcast!) ,(from-lambda (cadr e)) ,lhs-view ,@(caddr e)))) + (expand-forms `(call (|.| (top Broadcast) 'materialize) ,(cdr e))) + (expand-forms `(call (|.| (top Broadcast) 'materialize!) ,lhs-view ,(cdr e)))) + ; expanded to something else (like a getfield) (if (null? lhs) (expand-forms e) (expand-forms `(call (top broadcast!) (top identity) ,lhs-view ,e)))))) + (define (expand-where body var) (let* ((bounds (analyze-typevar var)) (v (car bounds))) diff --git a/stdlib/LinearAlgebra/src/LinearAlgebra.jl b/stdlib/LinearAlgebra/src/LinearAlgebra.jl index 5e8fd1ac1c517..2e6f801dc992c 100644 --- a/stdlib/LinearAlgebra/src/LinearAlgebra.jl +++ b/stdlib/LinearAlgebra/src/LinearAlgebra.jl @@ -19,6 +19,8 @@ import Base: USE_BLAS64, abs, acos, acosh, acot, acoth, acsc, acsch, adjoint, as StridedReshapedArray, strides, stride, tan, tanh, transpose, trunc, typed_hcat, vec using Base: hvcat_fill, iszero, IndexLinear, _length, promote_op, promote_typeof, @propagate_inbounds, @pure, reduce, typed_vcat +using Base.Broadcast: Broadcasted + # We use `_length` because of non-1 indices; releases after julia 0.5 # can go back to `length`. `_length(A)` is equivalent to `length(linearindices(A))`. @@ -327,6 +329,7 @@ include("special.jl") include("bitarray.jl") include("ldlt.jl") include("schur.jl") +include("structuredbroadcast.jl") include("deprecated.jl") const ⋅ = dot diff --git a/stdlib/LinearAlgebra/src/bidiag.jl b/stdlib/LinearAlgebra/src/bidiag.jl index d0a5ed25523de..9d2d617b78023 100644 --- a/stdlib/LinearAlgebra/src/bidiag.jl +++ b/stdlib/LinearAlgebra/src/bidiag.jl @@ -174,8 +174,6 @@ AbstractMatrix{T}(A::Bidiagonal) where {T} = convert(Bidiagonal{T}, A) convert(T::Type{<:Bidiagonal}, m::AbstractMatrix) = m isa T ? m : T(m) -broadcast(::typeof(big), B::Bidiagonal) = Bidiagonal(big.(B.dv), big.(B.ev), B.uplo) - # For B<:Bidiagonal, similar(B[, neweltype]) should yield a Bidiagonal matrix. # On the other hand, similar(B, [neweltype,] shape...) should yield a sparse matrix. # The first method below effects the former, and the second the latter. 
@@ -237,18 +235,9 @@ function size(M::Bidiagonal, d::Integer) end #Elementary operations -broadcast(::typeof(abs), M::Bidiagonal) = Bidiagonal(abs.(M.dv), abs.(M.ev), M.uplo) -broadcast(::typeof(round), M::Bidiagonal) = Bidiagonal(round.(M.dv), round.(M.ev), M.uplo) -broadcast(::typeof(trunc), M::Bidiagonal) = Bidiagonal(trunc.(M.dv), trunc.(M.ev), M.uplo) -broadcast(::typeof(floor), M::Bidiagonal) = Bidiagonal(floor.(M.dv), floor.(M.ev), M.uplo) -broadcast(::typeof(ceil), M::Bidiagonal) = Bidiagonal(ceil.(M.dv), ceil.(M.ev), M.uplo) for func in (:conj, :copy, :real, :imag) @eval ($func)(M::Bidiagonal) = Bidiagonal(($func)(M.dv), ($func)(M.ev), M.uplo) end -broadcast(::typeof(round), ::Type{T}, M::Bidiagonal) where {T<:Integer} = Bidiagonal(round.(T, M.dv), round.(T, M.ev), M.uplo) -broadcast(::typeof(trunc), ::Type{T}, M::Bidiagonal) where {T<:Integer} = Bidiagonal(trunc.(T, M.dv), trunc.(T, M.ev), M.uplo) -broadcast(::typeof(floor), ::Type{T}, M::Bidiagonal) where {T<:Integer} = Bidiagonal(floor.(T, M.dv), floor.(T, M.ev), M.uplo) -broadcast(::typeof(ceil), ::Type{T}, M::Bidiagonal) where {T<:Integer} = Bidiagonal(ceil.(T, M.dv), ceil.(T, M.ev), M.uplo) adjoint(B::Bidiagonal) = Adjoint(B) transpose(B::Bidiagonal) = Transpose(B) diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index 1470fec31e406..ed01cadd9f92f 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -112,7 +112,6 @@ isposdef(D::Diagonal) = all(x -> x > 0, D.diag) factorize(D::Diagonal) = D -broadcast(::typeof(abs), D::Diagonal) = Diagonal(abs.(D.diag)) real(D::Diagonal) = Diagonal(real(D.diag)) imag(D::Diagonal) = Diagonal(imag(D.diag)) diff --git a/stdlib/LinearAlgebra/src/structuredbroadcast.jl b/stdlib/LinearAlgebra/src/structuredbroadcast.jl new file mode 100644 index 0000000000000..24c410b2a299b --- /dev/null +++ b/stdlib/LinearAlgebra/src/structuredbroadcast.jl @@ -0,0 +1,180 @@ +## Broadcast styles +import Base.Broadcast +using Base.Broadcast: DefaultArrayStyle, broadcast_similar, tail + +struct StructuredMatrixStyle{T} <: Broadcast.AbstractArrayStyle{2} end +StructuredMatrixStyle{T}(::Val{2}) where {T} = StructuredMatrixStyle{T}() +StructuredMatrixStyle{T}(::Val{N}) where {T,N} = Broadcast.DefaultArrayStyle{N}() + +const StructuredMatrix = Union{Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal,LowerTriangular,UnitLowerTriangular,UpperTriangular,UnitUpperTriangular} +Broadcast.BroadcastStyle(::Type{T}) where {T<:StructuredMatrix} = StructuredMatrixStyle{T}() + +# Promotion of broadcasts between structured matrices. This is slightly unusual +# as we define them symmetrically. This allows us to have a fallback to DefaultArrayStyle{2}(). +# Diagonal can cavort with all the other structured matrix types. 
+# Bidiagonal doesn't know if it's upper or lower, so it becomes Tridiagonal
+Broadcast.BroadcastStyle(::StructuredMatrixStyle{<:Diagonal}, ::StructuredMatrixStyle{<:Diagonal}) =
+    StructuredMatrixStyle{Diagonal}()
+Broadcast.BroadcastStyle(::StructuredMatrixStyle{<:Diagonal}, ::StructuredMatrixStyle{<:Union{Bidiagonal,SymTridiagonal,Tridiagonal}}) =
+    StructuredMatrixStyle{Tridiagonal}()
+Broadcast.BroadcastStyle(::StructuredMatrixStyle{<:Diagonal}, ::StructuredMatrixStyle{<:Union{LowerTriangular,UnitLowerTriangular}}) =
+    StructuredMatrixStyle{LowerTriangular}()
+Broadcast.BroadcastStyle(::StructuredMatrixStyle{<:Diagonal}, ::StructuredMatrixStyle{<:Union{UpperTriangular,UnitUpperTriangular}}) =
+    StructuredMatrixStyle{UpperTriangular}()
+
+Broadcast.BroadcastStyle(::StructuredMatrixStyle{<:Bidiagonal}, ::StructuredMatrixStyle{<:Union{Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal}}) =
+    StructuredMatrixStyle{Tridiagonal}()
+Broadcast.BroadcastStyle(::StructuredMatrixStyle{<:SymTridiagonal}, ::StructuredMatrixStyle{<:Union{Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal}}) =
+    StructuredMatrixStyle{Tridiagonal}()
+Broadcast.BroadcastStyle(::StructuredMatrixStyle{<:Tridiagonal}, ::StructuredMatrixStyle{<:Union{Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal}}) =
+    StructuredMatrixStyle{Tridiagonal}()
+
+Broadcast.BroadcastStyle(::StructuredMatrixStyle{<:LowerTriangular}, ::StructuredMatrixStyle{<:Union{Diagonal,LowerTriangular,UnitLowerTriangular}}) =
+    StructuredMatrixStyle{LowerTriangular}()
+Broadcast.BroadcastStyle(::StructuredMatrixStyle{<:UpperTriangular}, ::StructuredMatrixStyle{<:Union{Diagonal,UpperTriangular,UnitUpperTriangular}}) =
+    StructuredMatrixStyle{UpperTriangular}()
+Broadcast.BroadcastStyle(::StructuredMatrixStyle{<:UnitLowerTriangular}, ::StructuredMatrixStyle{<:Union{Diagonal,LowerTriangular,UnitLowerTriangular}}) =
+    StructuredMatrixStyle{LowerTriangular}()
+Broadcast.BroadcastStyle(::StructuredMatrixStyle{<:UnitUpperTriangular}, ::StructuredMatrixStyle{<:Union{Diagonal,UpperTriangular,UnitUpperTriangular}}) =
+    StructuredMatrixStyle{UpperTriangular}()
+
+# All other combinations fall back to the default style
+Broadcast.BroadcastStyle(::StructuredMatrixStyle, ::StructuredMatrixStyle) = DefaultArrayStyle{2}()
+
+# And a definition akin to similar using the structured type:
+structured_broadcast_alloc(bc, ::Type{<:Diagonal}, ::Type{ElType}, n) where {ElType} =
+    Diagonal(Array{ElType}(undef, n))
+# Bidiagonal is tricky as we need to know if it's upper or lower. The promotion
+# system will return Tridiagonal when there's more than one Bidiagonal, but when
+# there's only one, we need to figure out whether it is upper or lower.
+find_bidiagonal() = throw(ArgumentError("could not find Bidiagonal within broadcast expression"))
+find_bidiagonal(a::Bidiagonal, rest...) = a
+find_bidiagonal(bc::Broadcast.Broadcasted, rest...) = find_bidiagonal(find_bidiagonal(bc.args...), rest...)
+find_bidiagonal(x, rest...) = find_bidiagonal(rest...)
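To make the role of these helpers concrete, here is a small hedged sketch (arbitrary values, not part of the patch): with a single `Bidiagonal` in the fused expression the combined style stays `StructuredMatrixStyle{Bidiagonal}` and `find_bidiagonal` supplies the `uplo` flag used by the allocation defined just below, while two `Bidiagonal`s promote to the `Tridiagonal` style per the rules above.

```julia
using LinearAlgebra

B = Bidiagonal([1.0, 2.0, 3.0], [0.1, 0.2], :U)

2 .* B   # one Bidiagonal in the expression tree: find_bidiagonal locates B,
         # so the result should again be a Bidiagonal with uplo == 'U'
B .+ B   # two Bidiagonals: the style rules above promote the result to Tridiagonal
```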
+function structured_broadcast_alloc(bc, ::Type{<:Bidiagonal}, ::Type{ElType}, n) where {ElType} + ex = find_bidiagonal(bc) + return Bidiagonal(Array{ElType}(undef, n),Array{ElType}(undef, n-1), ex.uplo) +end +structured_broadcast_alloc(bc, ::Type{<:SymTridiagonal}, ::Type{ElType}, n) where {ElType} = + SymTridiagonal(Array{ElType}(undef, n),Array{ElType}(undef, n-1)) +structured_broadcast_alloc(bc, ::Type{<:Tridiagonal}, ::Type{ElType}, n) where {ElType} = + Tridiagonal(Array{ElType}(undef, n-1),Array{ElType}(undef, n),Array{ElType}(undef, n-1)) +structured_broadcast_alloc(bc, ::Type{<:LowerTriangular}, ::Type{ElType}, n) where {ElType} = + LowerTriangular(Array{ElType}(undef, n, n)) +structured_broadcast_alloc(bc, ::Type{<:UpperTriangular}, ::Type{ElType}, n) where {ElType} = + UpperTriangular(Array{ElType}(undef, n, n)) +structured_broadcast_alloc(bc, ::Type{<:UnitLowerTriangular}, ::Type{ElType}, n) where {ElType} = + UnitLowerTriangular(Array{ElType}(undef, n, n)) +structured_broadcast_alloc(bc, ::Type{<:UnitUpperTriangular}, ::Type{ElType}, n) where {ElType} = + UnitUpperTriangular(Array{ElType}(undef, n, n)) + +# A _very_ limited list of structure-preserving functions known at compile-time. This list is +# derived from the formerly-implemented `broadcast` methods in 0.6. Note that this must +# preserve both zeros and ones (for Unit***erTriangular) and symmetry (for SymTridiagonal) +const TypeFuncs = Union{typeof(round),typeof(trunc),typeof(floor),typeof(ceil)} +isstructurepreserving(bc::Broadcasted) = isstructurepreserving(bc.f, bc.args...) +isstructurepreserving(::Union{typeof(abs),typeof(big)}, ::StructuredMatrix) = true +isstructurepreserving(::TypeFuncs, ::StructuredMatrix) = true +isstructurepreserving(::TypeFuncs, ::Ref{<:Type}, ::StructuredMatrix) = true +isstructurepreserving(f, args...) = false + +_iszero(n::Number) = iszero(n) +_iszero(x) = x == 0 +fzeropreserving(bc) = (v = fzero(bc); !ismissing(v) && _iszero(v)) +# Very conservatively only allow Numbers and Types in this speculative zero-test pass +fzero(x::Number) = x +fzero(::Type{T}) where T = T +fzero(S::StructuredMatrix) = zero(eltype(S)) +fzero(x) = missing +function fzero(bc::Broadcast.Broadcasted) + args = map(fzero, bc.args) + return any(ismissing, args) ? missing : bc.f(args...) 
+end + +function Broadcast.broadcast_similar(::StructuredMatrixStyle{T}, ::Type{ElType}, inds, bc) where {T,ElType} + if isstructurepreserving(bc) || (fzeropreserving(bc) && !(T <: Union{SymTridiagonal,UnitLowerTriangular,UnitUpperTriangular})) + return structured_broadcast_alloc(bc, T, ElType, length(inds[1])) + end + return broadcast_similar(DefaultArrayStyle{2}(), ElType, inds, bc) +end + +function copyto!(dest::Diagonal, bc::Broadcasted{<:StructuredMatrixStyle}) + axs = axes(dest) + axes(bc) == axs || Broadcast.throwdm(axes(bc), axs) + for i in axs[1] + dest.diag[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i, i)) + end + return dest +end + +function copyto!(dest::Bidiagonal, bc::Broadcasted{<:StructuredMatrixStyle}) + axs = axes(dest) + axes(bc) == axs || Broadcast.throwdm(axes(bc), axs) + for i in axs[1] + dest.dv[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i, i)) + end + if dest.uplo == 'U' + for i = 1:size(dest, 1)-1 + dest.ev[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i, i+1)) + end + else + for i = 1:size(dest, 1)-1 + dest.ev[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i+1, i)) + end + end + return dest +end + +function copyto!(dest::SymTridiagonal, bc::Broadcasted{<:StructuredMatrixStyle}) + axs = axes(dest) + axes(bc) == axs || Broadcast.throwdm(axes(bc), axs) + for i in axs[1] + dest.dv[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i, i)) + end + for i = 1:size(dest, 1)-1 + dest.ev[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i, i+1)) + end + return dest +end + +function copyto!(dest::Tridiagonal, bc::Broadcasted{<:StructuredMatrixStyle}) + axs = axes(dest) + axes(bc) == axs || Broadcast.throwdm(axes(bc), axs) + for i in axs[1] + dest.d[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i, i)) + end + for i = 1:size(dest, 1)-1 + dest.du[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i, i+1)) + dest.dl[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i+1, i)) + end + return dest +end + +function copyto!(dest::LowerTriangular, bc::Broadcasted{<:StructuredMatrixStyle}) + axs = axes(dest) + axes(bc) == axs || Broadcast.throwdm(axes(bc), axs) + for j in axs[2] + for i in j:axs[1][end] + dest.data[i,j] = Broadcast._broadcast_getindex(bc, CartesianIndex(i, j)) + end + end + return dest +end + +function copyto!(dest::UpperTriangular, bc::Broadcasted{<:StructuredMatrixStyle}) + axs = axes(dest) + axes(bc) == axs || Broadcast.throwdm(axes(bc), axs) + for j in axs[2] + for i in 1:j + dest.data[i,j] = Broadcast._broadcast_getindex(bc, CartesianIndex(i, j)) + end + end + return dest +end + +# We can also implement `map` and its promotion in terms of broadcast with a stricter dimension check +function map(f, A::StructuredMatrix, Bs::StructuredMatrix...) + sz = size(A) + all(map(B->size(B)==sz, Bs)) || throw(DimensionMismatch("dimensions must match")) + return f.(A, Bs...) 
+end diff --git a/stdlib/LinearAlgebra/src/triangular.jl b/stdlib/LinearAlgebra/src/triangular.jl index b6bfcb81b13ad..c6ef222cacc7b 100644 --- a/stdlib/LinearAlgebra/src/triangular.jl +++ b/stdlib/LinearAlgebra/src/triangular.jl @@ -37,11 +37,8 @@ for t in (:LowerTriangular, :UnitLowerTriangular, :UpperTriangular, copy(A::$t) = $t(copy(A.data)) - broadcast(::typeof(big), A::$t) = $t(big.(A.data)) - real(A::$t{<:Real}) = A real(A::$t{<:Complex}) = (B = real(A.data); $t(B)) - broadcast(::typeof(abs), A::$t) = $t(abs.(A.data)) end end diff --git a/stdlib/LinearAlgebra/src/tridiag.jl b/stdlib/LinearAlgebra/src/tridiag.jl index 1baf99202c5e1..e29266452a372 100644 --- a/stdlib/LinearAlgebra/src/tridiag.jl +++ b/stdlib/LinearAlgebra/src/tridiag.jl @@ -115,18 +115,9 @@ similar(S::SymTridiagonal, ::Type{T}) where {T} = SymTridiagonal(similar(S.dv, T # similar(S::SymTridiagonal, ::Type{T}, dims::Union{Dims{1},Dims{2}}) where {T} = spzeros(T, dims...) #Elementary operations -broadcast(::typeof(abs), M::SymTridiagonal) = SymTridiagonal(abs.(M.dv), abs.(M.ev)) -broadcast(::typeof(round), M::SymTridiagonal) = SymTridiagonal(round.(M.dv), round.(M.ev)) -broadcast(::typeof(trunc), M::SymTridiagonal) = SymTridiagonal(trunc.(M.dv), trunc.(M.ev)) -broadcast(::typeof(floor), M::SymTridiagonal) = SymTridiagonal(floor.(M.dv), floor.(M.ev)) -broadcast(::typeof(ceil), M::SymTridiagonal) = SymTridiagonal(ceil.(M.dv), ceil.(M.ev)) for func in (:conj, :copy, :real, :imag) @eval ($func)(M::SymTridiagonal) = SymTridiagonal(($func)(M.dv), ($func)(M.ev)) end -broadcast(::typeof(round), ::Type{T}, M::SymTridiagonal) where {T<:Integer} = SymTridiagonal(round.(T, M.dv), round.(T, M.ev)) -broadcast(::typeof(trunc), ::Type{T}, M::SymTridiagonal) where {T<:Integer} = SymTridiagonal(trunc.(T, M.dv), trunc.(T, M.ev)) -broadcast(::typeof(floor), ::Type{T}, M::SymTridiagonal) where {T<:Integer} = SymTridiagonal(floor.(T, M.dv), floor.(T, M.ev)) -broadcast(::typeof(ceil), ::Type{T}, M::SymTridiagonal) where {T<:Integer} = SymTridiagonal(ceil.(T, M.dv), ceil.(T, M.ev)) transpose(S::SymTridiagonal) = S adjoint(S::SymTridiagonal{<:Real}) = S @@ -497,24 +488,11 @@ similar(M::Tridiagonal, ::Type{T}) where {T} = Tridiagonal(similar(M.dl, T), sim copyto!(dest::Tridiagonal, src::Tridiagonal) = (copyto!(dest.dl, src.dl); copyto!(dest.d, src.d); copyto!(dest.du, src.du); dest) #Elementary operations -broadcast(::typeof(abs), M::Tridiagonal) = Tridiagonal(abs.(M.dl), abs.(M.d), abs.(M.du)) -broadcast(::typeof(round), M::Tridiagonal) = Tridiagonal(round.(M.dl), round.(M.d), round.(M.du)) -broadcast(::typeof(trunc), M::Tridiagonal) = Tridiagonal(trunc.(M.dl), trunc.(M.d), trunc.(M.du)) -broadcast(::typeof(floor), M::Tridiagonal) = Tridiagonal(floor.(M.dl), floor.(M.d), floor.(M.du)) -broadcast(::typeof(ceil), M::Tridiagonal) = Tridiagonal(ceil.(M.dl), ceil.(M.d), ceil.(M.du)) for func in (:conj, :copy, :real, :imag) @eval function ($func)(M::Tridiagonal) Tridiagonal(($func)(M.dl), ($func)(M.d), ($func)(M.du)) end end -broadcast(::typeof(round), ::Type{T}, M::Tridiagonal) where {T<:Integer} = - Tridiagonal(round.(T, M.dl), round.(T, M.d), round.(T, M.du)) -broadcast(::typeof(trunc), ::Type{T}, M::Tridiagonal) where {T<:Integer} = - Tridiagonal(trunc.(T, M.dl), trunc.(T, M.d), trunc.(T, M.du)) -broadcast(::typeof(floor), ::Type{T}, M::Tridiagonal) where {T<:Integer} = - Tridiagonal(floor.(T, M.dl), floor.(T, M.d), floor.(T, M.du)) -broadcast(::typeof(ceil), ::Type{T}, M::Tridiagonal) where {T<:Integer} = - Tridiagonal(ceil.(T, M.dl), ceil.(T, 
M.d), ceil.(T, M.du)) adjoint(S::Tridiagonal) = Adjoint(S) transpose(S::Tridiagonal) = Transpose(S) @@ -577,6 +555,7 @@ function Base.replace_in_print_matrix(A::Tridiagonal,i::Integer,j::Integer,s::Ab i==j-1||i==j||i==j+1 ? s : Base.replace_with_centered_mark(s) end + #tril and triu istriu(M::Tridiagonal) = iszero(M.dl) diff --git a/stdlib/LinearAlgebra/src/uniformscaling.jl b/stdlib/LinearAlgebra/src/uniformscaling.jl index a1644e951100c..5c2ba8720111f 100644 --- a/stdlib/LinearAlgebra/src/uniformscaling.jl +++ b/stdlib/LinearAlgebra/src/uniformscaling.jl @@ -208,10 +208,10 @@ end \(x::Number, J::UniformScaling) = UniformScaling(x\J.λ) -broadcast(::typeof(*), x::Number,J::UniformScaling) = UniformScaling(x*J.λ) -broadcast(::typeof(*), J::UniformScaling,x::Number) = UniformScaling(J.λ*x) +Broadcast.make(::typeof(*), x::Number,J::UniformScaling) = UniformScaling(x*J.λ) +Broadcast.make(::typeof(*), J::UniformScaling,x::Number) = UniformScaling(J.λ*x) -broadcast(::typeof(/), J::UniformScaling,x::Number) = UniformScaling(J.λ/x) +Broadcast.make(::typeof(/), J::UniformScaling,x::Number) = UniformScaling(J.λ/x) ==(J1::UniformScaling,J2::UniformScaling) = (J1.λ == J2.λ) diff --git a/stdlib/LinearAlgebra/test/structuredbroadcast.jl b/stdlib/LinearAlgebra/test/structuredbroadcast.jl new file mode 100644 index 0000000000000..c8bef049fd01a --- /dev/null +++ b/stdlib/LinearAlgebra/test/structuredbroadcast.jl @@ -0,0 +1,101 @@ +module TestStructuredBroadcast +using Test, LinearAlgebra + +@testset "broadcast[!] over combinations of scalars, structured matrices, and dense vectors/matrices" begin + N = 10 + s = rand() + fV = rand(N) + fA = rand(N, N) + Z = copy(fA) + D = Diagonal(rand(N)) + B = Bidiagonal(rand(N), rand(N - 1), :U) + T = Tridiagonal(rand(N - 1), rand(N), rand(N - 1)) + U = UpperTriangular(rand(N,N)) + L = LowerTriangular(rand(N,N)) + structuredarrays = (D, B, T, U, L) + fstructuredarrays = map(Array, structuredarrays) + for (X, fX) in zip(structuredarrays, fstructuredarrays) + @test (Q = broadcast(sin, X); typeof(Q) == typeof(X) && Q == broadcast(sin, fX)) + @test broadcast!(sin, Z, X) == broadcast(sin, fX) + @test (Q = broadcast(cos, X); Q isa Matrix && Q == broadcast(cos, fX)) + @test broadcast!(cos, Z, X) == broadcast(cos, fX) + @test (Q = broadcast(*, s, X); typeof(Q) == typeof(X) && Q == broadcast(*, s, fX)) + @test broadcast!(*, Z, s, X) == broadcast(*, s, fX) + @test (Q = broadcast(+, fV, fA, X); Q isa Matrix && Q == broadcast(+, fV, fA, fX)) + @test broadcast!(+, Z, fV, fA, X) == broadcast(+, fV, fA, fX) + @test (Q = broadcast(*, s, fV, fA, X); Q isa Matrix && Q == broadcast(*, s, fV, fA, fX)) + @test broadcast!(*, Z, s, fV, fA, X) == broadcast(*, s, fV, fA, fX) + for (Y, fY) in zip(structuredarrays, fstructuredarrays) + @test broadcast(+, X, Y) == broadcast(+, fX, fY) + @test broadcast!(+, Z, X, Y) == broadcast(+, fX, fY) + @test broadcast(*, X, Y) == broadcast(*, fX, fY) + @test broadcast!(*, Z, X, Y) == broadcast(*, fX, fY) + end + end + diagonals = (D, B, T) + fdiagonals = map(Array, diagonals) + for (X, fX) in zip(diagonals, fdiagonals) + for (Y, fY) in zip(diagonals, fdiagonals) + @test broadcast(+, X, Y)::Union{Diagonal,Bidiagonal,Tridiagonal} == broadcast(+, fX, fY) + @test broadcast!(+, Z, X, Y) == broadcast(+, fX, fY) + @test broadcast(*, X, Y)::Union{Diagonal,Bidiagonal,Tridiagonal} == broadcast(*, fX, fY) + @test broadcast!(*, Z, X, Y) == broadcast(*, fX, fY) + end + end +end + +@testset "broadcast! 
where the destination is a structured matrix" begin + N = 5 + A = rand(N, N) + sA = A + copy(A') + D = Diagonal(rand(N)) + B = Bidiagonal(rand(N), rand(N - 1), :U) + T = Tridiagonal(rand(N - 1), rand(N), rand(N - 1)) + @test broadcast!(sin, copy(D), D) == Diagonal(sin.(D)) + @test broadcast!(sin, copy(B), B) == Bidiagonal(sin.(B), :U) + @test broadcast!(sin, copy(T), T) == Tridiagonal(sin.(T)) + @test broadcast!(*, copy(D), D, A) == Diagonal(broadcast(*, D, A)) + @test broadcast!(*, copy(B), B, A) == Bidiagonal(broadcast(*, B, A), :U) + @test broadcast!(*, copy(T), T, A) == Tridiagonal(broadcast(*, T, A)) +end + +@testset "map[!] over combinations of structured matrices" begin + N = 10 + fA = rand(N, N) + Z = copy(fA) + D = Diagonal(rand(N)) + B = Bidiagonal(rand(N), rand(N - 1), :U) + T = Tridiagonal(rand(N - 1), rand(N), rand(N - 1)) + U = UpperTriangular(rand(N,N)) + L = LowerTriangular(rand(N,N)) + structuredarrays = (D, B, T, U, L) + fstructuredarrays = map(Array, structuredarrays) + for (X, fX) in zip(structuredarrays, fstructuredarrays) + @test (Q = map(sin, X); typeof(Q) == typeof(X) && Q == map(sin, fX)) + @test map!(sin, Z, X) == map(sin, fX) + @test (Q = map(cos, X); Q isa Matrix && Q == map(cos, fX)) + @test map!(cos, Z, X) == map(cos, fX) + @test (Q = map(+, fA, X); Q isa Matrix && Q == map(+, fA, fX)) + @test map!(+, Z, fA, X) == map(+, fA, fX) + for (Y, fY) in zip(structuredarrays, fstructuredarrays) + @test map(+, X, Y) == map(+, fX, fY) + @test map!(+, Z, X, Y) == map(+, fX, fY) + @test map(*, X, Y) == map(*, fX, fY) + @test map!(*, Z, X, Y) == map(*, fX, fY) + @test map(+, X, fA, Y) == map(+, fX, fA, fY) + @test map!(+, Z, X, fA, Y) == map(+, fX, fA, fY) + end + end + diagonals = (D, B, T) + fdiagonals = map(Array, diagonals) + for (X, fX) in zip(diagonals, fdiagonals) + for (Y, fY) in zip(diagonals, fdiagonals) + @test map(+, X, Y)::Union{Diagonal,Bidiagonal,Tridiagonal} == broadcast(+, fX, fY) + @test map!(+, Z, X, Y) == broadcast(+, fX, fY) + @test map(*, X, Y)::Union{Diagonal,Bidiagonal,Tridiagonal} == broadcast(*, fX, fY) + @test map!(*, Z, X, Y) == broadcast(*, fX, fY) + end + end +end + +end diff --git a/stdlib/SparseArrays/src/higherorderfns.jl b/stdlib/SparseArrays/src/higherorderfns.jl index bca1585985ad8..3ee447c8a2c54 100644 --- a/stdlib/SparseArrays/src/higherorderfns.jl +++ b/stdlib/SparseArrays/src/higherorderfns.jl @@ -4,15 +4,16 @@ module HigherOrderFns # This module provides higher order functions specialized for sparse arrays, # particularly map[!]/broadcast[!] for SparseVectors and SparseMatrixCSCs at present. -import Base: map, map!, broadcast, broadcast! +import Base: map, map!, broadcast, copy, copyto! using Base: front, tail, to_shape using ..SparseArrays: SparseVector, SparseMatrixCSC, AbstractSparseVector, AbstractSparseMatrix, AbstractSparseArray, indtype, nnz, nzrange -using Base.Broadcast: BroadcastStyle +using Base.Broadcast: BroadcastStyle, Broadcasted, flatten using LinearAlgebra # This module is organized as follows: +# (0) Define BroadcastStyle rules and convenience types for dispatch # (1) Define a common interface to SparseVectors and SparseMatrixCSCs sufficient for # map[!]/broadcast[!]'s purposes. The methods below are written against this interface. # (2) Define entry points for map[!] (short children of _map_[not]zeropres!). @@ -29,11 +30,79 @@ using LinearAlgebra # (12) Define map[!] methods handling combinations of sparse and structured matrices. 
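As a quick orientation for the new `(0)` section in the hunk below: sparse containers now hook into the generic `BroadcastStyle` promotion machinery. The following is a minimal sketch of how those rules combine, assuming this patch is applied; the inputs and values are illustrative only and are not part of the patch:

```julia
using SparseArrays

v = sparsevec([1, 2], [1.0, 2.0], 2)    # SparseVector    -> SparseVecStyle
A = sparse([1.0 0 0 0; 0.0 2.0 0 0])    # SparseMatrixCSC -> SparseMatStyle

v .+ 1           # SparseVecStyle alone: the result remains a SparseVector
v .+ A           # SparseVecStyle with SparseMatStyle promotes to SparseMatStyle
v .+ ones(2, 4)  # mixing with a dense Matrix promotes to PromoteToSparse
v .+ (1.0, 2.0)  # tuples demote to DefaultArrayStyle{1}, giving a dense Vector
```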
+# (0) BroadcastStyle rules and convenience types for dispatch + +SparseVecOrMat = Union{SparseVector,SparseMatrixCSC} + +# broadcast container type promotion for combinations of sparse arrays and other types +struct SparseVecStyle <: Broadcast.AbstractArrayStyle{1} end +struct SparseMatStyle <: Broadcast.AbstractArrayStyle{2} end +Broadcast.BroadcastStyle(::Type{<:SparseVector}) = SparseVecStyle() +Broadcast.BroadcastStyle(::Type{<:SparseMatrixCSC}) = SparseMatStyle() +const SPVM = Union{SparseVecStyle,SparseMatStyle} + +# SparseVecStyle handles 0-1 dimensions, SparseMatStyle 0-2 dimensions. +# SparseVecStyle promotes to SparseMatStyle for 2 dimensions. +# Fall back to DefaultArrayStyle for higher dimensionality. +SparseVecStyle(::Val{0}) = SparseVecStyle() +SparseVecStyle(::Val{1}) = SparseVecStyle() +SparseVecStyle(::Val{2}) = SparseMatStyle() +SparseVecStyle(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}() +SparseMatStyle(::Val{0}) = SparseMatStyle() +SparseMatStyle(::Val{1}) = SparseMatStyle() +SparseMatStyle(::Val{2}) = SparseMatStyle() +SparseMatStyle(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}() + +Broadcast.BroadcastStyle(::SparseMatStyle, ::SparseVecStyle) = SparseMatStyle() + +# Tuples promote to dense +Broadcast.BroadcastStyle(::SparseVecStyle, ::Broadcast.Style{Tuple}) = Broadcast.DefaultArrayStyle{1}() +Broadcast.BroadcastStyle(::SparseMatStyle, ::Broadcast.Style{Tuple}) = Broadcast.DefaultArrayStyle{2}() + +struct PromoteToSparse <: Broadcast.AbstractArrayStyle{2} end +PromoteToSparse(::Val{0}) = PromoteToSparse() +PromoteToSparse(::Val{1}) = PromoteToSparse() +PromoteToSparse(::Val{2}) = PromoteToSparse() +PromoteToSparse(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}() + +const StructuredMatrix = Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal} +Broadcast.BroadcastStyle(::Type{<:Adjoint{T,<:Union{SparseVector,SparseMatrixCSC}} where T}) = PromoteToSparse() +Broadcast.BroadcastStyle(::Type{<:Transpose{T,<:Union{SparseVector,SparseMatrixCSC}} where T}) = PromoteToSparse() + +Broadcast.BroadcastStyle(s::SPVM, ::Broadcast.AbstractArrayStyle{0}) = s +Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{1}) = PromoteToSparse() +Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{2}) = PromoteToSparse() + +Broadcast.BroadcastStyle(::SPVM, ::LinearAlgebra.StructuredMatrixStyle{<:StructuredMatrix}) = PromoteToSparse() +Broadcast.BroadcastStyle(::PromoteToSparse, ::LinearAlgebra.StructuredMatrixStyle{<:StructuredMatrix}) = PromoteToSparse() + +Broadcast.BroadcastStyle(::PromoteToSparse, ::SPVM) = PromoteToSparse() +Broadcast.BroadcastStyle(::PromoteToSparse, ::Broadcast.Style{Tuple}) = Broadcast.DefaultArrayStyle{2}() + +# FIXME: currently sparse broadcasts are only well-tested on known array types, while any AbstractArray +# could report itself as a DefaultArrayStyle(). +# See https://github.com/JuliaLang/julia/pull/23939#pullrequestreview-72075382 for more details +is_supported_sparse_broadcast() = true +is_supported_sparse_broadcast(::AbstractArray, rest...) = false +is_supported_sparse_broadcast(::AbstractSparseArray, rest...) = is_supported_sparse_broadcast(rest...) +is_supported_sparse_broadcast(::StructuredMatrix, rest...) = is_supported_sparse_broadcast(rest...) +is_supported_sparse_broadcast(::Array, rest...) = is_supported_sparse_broadcast(rest...) +is_supported_sparse_broadcast(t::Union{Transpose, Adjoint}, rest...) = is_supported_sparse_broadcast(t.parent, rest...) +is_supported_sparse_broadcast(x, rest...) 
= axes(x) === () && is_supported_sparse_broadcast(rest...) +is_supported_sparse_broadcast(x::Ref, rest...) = is_supported_sparse_broadcast(rest...) + +# Dispatch on broadcast operations by number of arguments +const Broadcasted0{Style<:Union{Nothing,BroadcastStyle},Axes,F} = + Broadcasted{Style,Axes,F,Tuple{}} +const SpBroadcasted1{Style<:SPVM,Axes,F,Args<:Tuple{SparseVecOrMat}} = + Broadcasted{Style,Axes,F,Args} +const SpBroadcasted2{Style<:SPVM,Axes,F,Args<:Tuple{SparseVecOrMat,SparseVecOrMat}} = + Broadcasted{Style,Axes,F,Args} + # (1) The definitions below provide a common interface to sparse vectors and matrices # sufficient for the purposes of map[!]/broadcast[!]. This interface treats sparse vectors # as n-by-one sparse matrices which, though technically incorrect, is how broacast[!] views # sparse vectors in practice. -SparseVecOrMat = Union{SparseVector,SparseMatrixCSC} @inline numrows(A::SparseVector) = A.n @inline numrows(A::SparseMatrixCSC) = A.m @inline numcols(A::SparseVector) = 1 @@ -85,18 +154,18 @@ function _noshapecheck_map(f::Tf, A::SparseVecOrMat, Bs::Vararg{SparseVecOrMat,N fofzeros = f(_zeros_eltypes(A, Bs...)...) fpreszeros = _iszero(fofzeros) maxnnzC = fpreszeros ? min(length(A), _sumnnzs(A, Bs...)) : length(A) - entrytypeC = Base.Broadcast.combine_eltypes(f, A, Bs...) + entrytypeC = Base.Broadcast.combine_eltypes(f, (A, Bs...)) indextypeC = _promote_indtype(A, Bs...) C = _allocres(size(A), indextypeC, entrytypeC, maxnnzC) return fpreszeros ? _map_zeropres!(f, C, A, Bs...) : _map_notzeropres!(f, fofzeros, C, A, Bs...) end # (3) broadcast[!] entry points -broadcast(f::Tf, A::SparseVector) where {Tf} = _noshapecheck_map(f, A) -broadcast(f::Tf, A::SparseMatrixCSC) where {Tf} = _noshapecheck_map(f, A) +copy(bc::SpBroadcasted1) = _noshapecheck_map(bc.f, bc.args[1]) -@inline function broadcast!(f::Tf, C::SparseVecOrMat, ::Nothing) where Tf +@inline function copyto!(C::SparseVecOrMat, bc::Broadcasted0{Nothing}) isempty(C) && return _finishempty!(C) + f = bc.f fofnoargs = f() if _iszero(fofnoargs) # f() is zero, so empty C trimstorage!(C, 0) @@ -109,19 +178,12 @@ broadcast(f::Tf, A::SparseMatrixCSC) where {Tf} = _noshapecheck_map(f, A) return C end -# the following three similar defs are necessary for type stability in the mixed vector/matrix case -broadcast(f::Tf, A::SparseVector, Bs::Vararg{SparseVector,N}) where {Tf,N} = - _aresameshape(A, Bs...) ? _noshapecheck_map(f, A, Bs...) : _diffshape_broadcast(f, A, Bs...) -broadcast(f::Tf, A::SparseMatrixCSC, Bs::Vararg{SparseMatrixCSC,N}) where {Tf,N} = - _aresameshape(A, Bs...) ? _noshapecheck_map(f, A, Bs...) : _diffshape_broadcast(f, A, Bs...) -broadcast(f::Tf, A::SparseVecOrMat, Bs::Vararg{SparseVecOrMat,N}) where {Tf,N} = - _diffshape_broadcast(f, A, Bs...) function _diffshape_broadcast(f::Tf, A::SparseVecOrMat, Bs::Vararg{SparseVecOrMat,N}) where {Tf,N} fofzeros = f(_zeros_eltypes(A, Bs...)...) fpreszeros = _iszero(fofzeros) indextypeC = _promote_indtype(A, Bs...) - entrytypeC = Base.Broadcast.combine_eltypes(f, A, Bs...) - shapeC = to_shape(Base.Broadcast.combine_indices(A, Bs...)) + entrytypeC = Base.Broadcast.combine_eltypes(f, (A, Bs...)) + shapeC = to_shape(Base.Broadcast.combine_axes(A, Bs...)) maxnnzC = fpreszeros ? _checked_maxnnzbcres(shapeC, A, Bs...) : _densennz(shapeC) C = _allocres(shapeC, indextypeC, entrytypeC, maxnnzC) return fpreszeros ? _broadcast_zeropres!(f, C, A, Bs...) : @@ -141,6 +203,10 @@ end @inline _aresameshape(A, B) = size(A) == size(B) @inline _aresameshape(A, B, Cs...) 
= _aresameshape(A, B) ? _aresameshape(B, Cs...) : false @inline _checksameshape(As...) = _aresameshape(As...) || throw(DimensionMismatch("argument shapes must match")) +@inline _all_args_isa(t::Tuple{Any}, ::Type{T}) where T = isa(t[1], T) +@inline _all_args_isa(t::Tuple{Any,Vararg{Any}}, ::Type{T}) where T = isa(t[1], T) & _all_args_isa(tail(t), T) +@inline _all_args_isa(t::Tuple{Broadcasted}, ::Type{T}) where T = _all_args_isa(t[1].args, T) +@inline _all_args_isa(t::Tuple{Broadcasted,Vararg{Any}}, ::Type{T}) where T = _all_args_isa(t[1].args, T) & _all_args_isa(tail(t), T) @inline _densennz(shape::NTuple{1}) = shape[1] @inline _densennz(shape::NTuple{2}) = shape[1] * shape[2] _maxnnzfrom(shape::NTuple{1}, A) = nnz(A) * div(shape[1], A.n) @@ -887,37 +953,56 @@ end # (10) broadcast over combinations of broadcast scalars and sparse vectors/matrices -# broadcast container type promotion for combinations of sparse arrays and other types -struct SparseVecStyle <: Broadcast.AbstractArrayStyle{1} end -struct SparseMatStyle <: Broadcast.AbstractArrayStyle{2} end -Broadcast.BroadcastStyle(::Type{<:SparseVector}) = SparseVecStyle() -Broadcast.BroadcastStyle(::Type{<:SparseMatrixCSC}) = SparseMatStyle() -const SPVM = Union{SparseVecStyle,SparseMatStyle} +# broadcast entry points for combinations of sparse arrays and other (scalar) types +@inline function copy(bc::Broadcasted{<:SPVM}) + bcf = flatten(bc) + return _copy(bcf.f, bcf.args...) +end -# SparseVecStyle handles 0-1 dimensions, SparseMatStyle 0-2 dimensions. -# SparseVecStyle promotes to SparseMatStyle for 2 dimensions. -# Fall back to DefaultArrayStyle for higher dimensionality. -SparseVecStyle(::Val{0}) = SparseVecStyle() -SparseVecStyle(::Val{1}) = SparseVecStyle() -SparseVecStyle(::Val{2}) = SparseMatStyle() -SparseVecStyle(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}() -SparseMatStyle(::Val{0}) = SparseMatStyle() -SparseMatStyle(::Val{1}) = SparseMatStyle() -SparseMatStyle(::Val{2}) = SparseMatStyle() -SparseMatStyle(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}() +_copy(f, args::SparseVector...) = _shapecheckbc(f, args...) +_copy(f, args::SparseMatrixCSC...) = _shapecheckbc(f, args...) +_copy(f, args::SparseVecOrMat...) = _diffshape_broadcast(f, args...) +# Otherwise, we incorporate scalars into the function and re-dispatch +function _copy(f, args...) + parevalf, passedargstup = capturescalars(f, args) + return _copy(parevalf, passedargstup...) +end -Broadcast.BroadcastStyle(::SparseMatStyle, ::SparseVecStyle) = SparseMatStyle() +function _shapecheckbc(f, args...) + _aresameshape(args...) ? _noshapecheck_map(f, args...) : _diffshape_broadcast(f, args...) +end -# Tuples promote to dense -Broadcast.BroadcastStyle(::SparseVecStyle, ::Broadcast.Style{Tuple}) = Broadcast.DefaultArrayStyle{1}() -Broadcast.BroadcastStyle(::SparseMatStyle, ::Broadcast.Style{Tuple}) = Broadcast.DefaultArrayStyle{2}() -# broadcast entry points for combinations of sparse arrays and other (scalar) types -function broadcast(f, ::SPVM, ::Nothing, ::Nothing, mixedargs::Vararg{Any,N}) where N - parevalf, passedargstup = capturescalars(f, mixedargs) - return broadcast(parevalf, passedargstup...) +@inline function copyto!(dest::SparseVecOrMat, bc::Broadcasted{<:SPVM}) + if bc.f === identity && bc isa SpBroadcasted1 && Base.axes(dest) == (A = bc.args[1]; Base.axes(A)) + return copyto!(dest, A) + end + bcf = flatten(bc) + As = map(arg->Base.unalias(dest, arg), bcf.args) + return _copyto!(bcf.f, dest, As...) 
+end + +@inline function _copyto!(f, dest, As::SparseVecOrMat...) + _aresameshape(dest, As...) && return _noshapecheck_map!(f, dest, As...) + Base.Broadcast.check_broadcast_axes(axes(dest), As...) + fofzeros = f(_zeros_eltypes(As...)...) + if _iszero(fofzeros) + return _broadcast_zeropres!(f, dest, As...) + else + return _broadcast_notzeropres!(f, fofzeros, dest, As...) + end +end + +@inline function _copyto!(f, dest, args...) + # args contains nothing but SparseVecOrMat and scalars + # See below for capturescalars + parevalf, passedsrcargstup = capturescalars(f, args) + _copyto!(parevalf, dest, passedsrcargstup...) +end + +struct CapturedScalars{F, Args, Order} + args::Args end -# for broadcast! see (11) # capturescalars takes a function (f) and a tuple of mixed sparse vectors/matrices and # broadcast scalar arguments (mixedargs), and returns a function (parevalf, i.e. partially @@ -930,6 +1015,13 @@ end return (parevalf, passedsrcargstup) end end +# Work around losing Type{T}s as DataTypes within the tuple that makeargs creates +@inline capturescalars(f, mixedargs::Tuple{Ref{Type{T}}, Vararg{Any}}) where {T} = + capturescalars((args...)->f(T, args...), Base.tail(mixedargs)) +@inline capturescalars(f, mixedargs::Tuple{SparseVecOrMat, Ref{Type{T}}, Vararg{Any}}) where {T} = + capturescalars((a1, args...)->f(a1, T, args...), (mixedargs[1], Base.tail(Base.tail(mixedargs))...)) +@inline capturescalars(f, mixedargs::Tuple{Union{Ref,AbstractArray{0}}, Ref{Type{T}}, Vararg{Any}}) where {T} = + capturescalars((args...)->f(mixedargs[1], T, args...), Base.tail(Base.tail(mixedargs))) nonscalararg(::SparseVecOrMat) = true nonscalararg(::Any) = false @@ -942,11 +1034,17 @@ end @inline function _capturescalars(arg, mixedargs...) let (rest, f) = _capturescalars(mixedargs...) if nonscalararg(arg) - return (arg, rest...), (head, tail...) -> (head, f(tail...)...) # pass-through to broadcast + return (arg, rest...), @inline function(head, tail...) + (head, f(tail...)...) + end # pass-through to broadcast elseif scalarwrappedarg(arg) - return rest, (tail...) -> (arg[], f(tail...)...) # unwrap and add back scalararg after (in makeargs) + return rest, @inline function(tail...) + (arg[], f(tail...)...) # TODO: This can put a Type{T} in a tuple + end # unwrap and add back scalararg after (in makeargs) else - return rest, (tail...) -> (arg, f(tail...)...) # add back scalararg after (in makeargs) + return rest, @inline function(tail...) + (arg, f(tail...)...) + end # add back scalararg after (in makeargs) end end end @@ -972,69 +1070,18 @@ broadcast(f::Tf, A::SparseMatrixCSC, ::Type{T}) where {Tf,T} = broadcast(x -> f( # vectors/matrices, promote all structured matrices and dense vectors/matrices to sparse # and rebroadcast. otherwise, divert to generic AbstractArray broadcast code. 
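The promote-and-rebroadcast path described in the comment above, implemented by the `copy`/`copyto!` methods in the hunks that follow, can be pictured with a small hedged sketch. It assumes this patch is applied; the inputs are made up for illustration:

```julia
using SparseArrays, LinearAlgebra

S = sprand(4, 4, 0.5)
D = Diagonal(ones(4))

S .+ D        # the structured argument is sparsified first, so this runs through
              # the sparse broadcast kernels and yields a SparseMatrixCSC
S .* rand(4)  # dense vectors are likewise promoted to sparse before rebroadcasting

Z = similar(S)
Z .= S .+ D   # in place: the PromoteToSparse copyto! method sparsifies the
              # arguments and forwards to the existing sparse broadcast! methods
```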
-struct PromoteToSparse <: Broadcast.AbstractArrayStyle{2} end -PromoteToSparse(::Val{0}) = PromoteToSparse() -PromoteToSparse(::Val{1}) = PromoteToSparse() -PromoteToSparse(::Val{2}) = PromoteToSparse() -PromoteToSparse(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}() - -const StructuredMatrix = Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal} -Broadcast.BroadcastStyle(::Type{<:StructuredMatrix}) = PromoteToSparse() -Broadcast.BroadcastStyle(::Type{<:Adjoint{T,<:Union{SparseVector,SparseMatrixCSC}} where T}) = PromoteToSparse() -Broadcast.BroadcastStyle(::Type{<:Transpose{T,<:Union{SparseVector,SparseMatrixCSC}} where T}) = PromoteToSparse() - -Broadcast.BroadcastStyle(s::SPVM, ::Broadcast.AbstractArrayStyle{0}) = s -Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{1}) = PromoteToSparse() -Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{2}) = PromoteToSparse() - -Broadcast.BroadcastStyle(::PromoteToSparse, ::SPVM) = PromoteToSparse() -Broadcast.BroadcastStyle(::PromoteToSparse, ::Broadcast.Style{Tuple}) = Broadcast.DefaultArrayStyle{2}() - -# FIXME: currently sparse broadcasts are only well-tested on known array types, while any AbstractArray -# could report itself as a DefaultArrayStyle(). -# See https://github.com/JuliaLang/julia/pull/23939#pullrequestreview-72075382 for more details -is_supported_sparse_broadcast() = true -is_supported_sparse_broadcast(::AbstractArray, rest...) = false -is_supported_sparse_broadcast(::AbstractSparseArray, rest...) = is_supported_sparse_broadcast(rest...) -is_supported_sparse_broadcast(::StructuredMatrix, rest...) = is_supported_sparse_broadcast(rest...) -is_supported_sparse_broadcast(::Array, rest...) = is_supported_sparse_broadcast(rest...) -is_supported_sparse_broadcast(t::Union{Transpose, Adjoint}, rest...) = is_supported_sparse_broadcast(t.parent, rest...) -is_supported_sparse_broadcast(x, rest...) = axes(x) === () && is_supported_sparse_broadcast(rest...) -is_supported_sparse_broadcast(x::Ref, rest...) = is_supported_sparse_broadcast(rest...) -function broadcast(f, s::PromoteToSparse, ::Nothing, ::Nothing, As::Vararg{Any,N}) where {N} - if is_supported_sparse_broadcast(As...) - return broadcast(f, map(_sparsifystructured, As)...) +function copy(bc::Broadcasted{PromoteToSparse}) + bcf = flatten(bc) + if is_supported_sparse_broadcast(bcf.args...) + broadcast(bcf.f, map(_sparsifystructured, bcf.args)...) else - return broadcast(f, Broadcast.ArrayConflict(), nothing, nothing, As...) + return copy(convert(Broadcasted{Broadcast.DefaultArrayStyle{2}}, bc)) end end -# For broadcast! with ::Any inputs, we need a layer of indirection to determine whether -# the inputs can be promoted to SparseVecOrMat. If it's just SparseVecOrMat and scalars, -# we can handle it here, otherwise see below for the promotion machinery. -function broadcast!(f::Tf, dest::SparseVecOrMat, ::SPVM, A::SparseVecOrMat, Bs::Vararg{SparseVecOrMat,N}) where {Tf,N} - if f isa typeof(identity) && N == 0 && Base.axes(dest) == Base.axes(A) - return copyto!(dest, A) - end - A′ = Base.unalias(dest, A) - Bs′ = map(B->Base.unalias(dest, B), Bs) - _aresameshape(dest, A′, Bs′...) && return _noshapecheck_map!(f, dest, A′, Bs′...) - Base.Broadcast.check_broadcast_indices(axes(dest), A′, Bs′...) - fofzeros = f(_zeros_eltypes(A′, Bs′...)...) - fpreszeros = _iszero(fofzeros) - fpreszeros ? _broadcast_zeropres!(f, dest, A′, Bs′...) : - _broadcast_notzeropres!(f, fofzeros, dest, A′, Bs′...) 
- return dest -end -function broadcast!(f::Tf, dest::SparseVecOrMat, ::SPVM, mixedsrcargs::Vararg{Any,N}) where {Tf,N} - # mixedsrcargs contains nothing but SparseVecOrMat and scalars - parevalf, passedsrcargstup = capturescalars(f, mixedsrcargs) - broadcast!(parevalf, dest, passedsrcargstup...) - return dest -end -function broadcast!(f::Tf, dest::SparseVecOrMat, ::PromoteToSparse, mixedsrcargs::Vararg{Any,N}) where {Tf,N} - broadcast!(f, dest, map(_sparsifystructured, mixedsrcargs)...) - return dest +@inline function copyto!(dest::SparseVecOrMat, bc::Broadcasted{PromoteToSparse}) + bcf = flatten(bc) + broadcast!(bcf.f, dest, map(_sparsifystructured, bcf.args)...) end _sparsifystructured(M::AbstractMatrix) = SparseMatrixCSC(M) @@ -1047,8 +1094,7 @@ _sparsifystructured(x) = x # (12) map[!] over combinations of sparse and structured matrices -SparseOrStructuredMatrix = Union{SparseMatrixCSC,StructuredMatrix} -map(f::Tf, A::StructuredMatrix) where {Tf} = _noshapecheck_map(f, _sparsifystructured(A)) +SparseOrStructuredMatrix = Union{SparseMatrixCSC,LinearAlgebra.StructuredMatrix} map(f::Tf, A::SparseOrStructuredMatrix, Bs::Vararg{SparseOrStructuredMatrix,N}) where {Tf,N} = (_checksameshape(A, Bs...); _noshapecheck_map(f, _sparsifystructured(A), map(_sparsifystructured, Bs)...)) map!(f::Tf, C::SparseMatrixCSC, A::SparseOrStructuredMatrix, Bs::Vararg{SparseOrStructuredMatrix,N}) where {Tf,N} = diff --git a/stdlib/SparseArrays/test/higherorderfns.jl b/stdlib/SparseArrays/test/higherorderfns.jl index b8e2c26d33349..8744f80a39dbe 100644 --- a/stdlib/SparseArrays/test/higherorderfns.jl +++ b/stdlib/SparseArrays/test/higherorderfns.jl @@ -125,9 +125,9 @@ end @test broadcast!(cos, Z, X) == sparse(broadcast!(cos, fZ, fX)) # --> test shape checks for broadcast! entry point # TODO strengthen this test, avoiding dependence on checking whether - # check_broadcast_indices throws to determine whether sparse broadcast should throw + # check_broadcast_axes throws to determine whether sparse broadcast should throw try - Base.Broadcast.check_broadcast_indices(axes(Z), spzeros((shapeX .- 1)...)) + Base.Broadcast.check_broadcast_axes(axes(Z), spzeros((shapeX .- 1)...)) catch @test_throws DimensionMismatch broadcast!(sin, Z, spzeros((shapeX .- 1)...)) end @@ -149,9 +149,9 @@ end @test broadcast!(cos, V, X) == sparse(broadcast!(cos, fV, fX)) # --> test shape checks for broadcast! 
entry point # TODO strengthen this test, avoiding dependence on checking whether - # check_broadcast_indices throws to determine whether sparse broadcast should throw + # check_broadcast_axes throws to determine whether sparse broadcast should throw try - Base.Broadcast.check_broadcast_indices(axes(V), spzeros((shapeX .- 1)...)) + Base.Broadcast.check_broadcast_axes(axes(V), spzeros((shapeX .- 1)...)) catch @test_throws DimensionMismatch broadcast!(sin, V, spzeros((shapeX .- 1)...)) end @@ -184,9 +184,9 @@ end @test broadcast(*, X, Y) == sparse(broadcast(*, fX, fY)) @test broadcast(f, X, Y) == sparse(broadcast(f, fX, fY)) # TODO strengthen this test, avoiding dependence on checking whether - # check_broadcast_indices throws to determine whether sparse broadcast should throw + # check_broadcast_axes throws to determine whether sparse broadcast should throw try - Base.Broadcast.combine_indices(spzeros((shapeX .- 1)...), Y) + Base.Broadcast.combine_axes(spzeros((shapeX .- 1)...), Y) catch @test_throws DimensionMismatch broadcast(+, spzeros((shapeX .- 1)...), Y) end @@ -207,9 +207,9 @@ end @test broadcast!(f, Z, X, Y) == sparse(broadcast!(f, fZ, fX, fY)) # --> test shape checks for both broadcast and broadcast! entry points # TODO strengthen this test, avoiding dependence on checking whether - # check_broadcast_indices throws to determine whether sparse broadcast should throw + # check_broadcast_axes throws to determine whether sparse broadcast should throw try - Base.Broadcast.check_broadcast_indices(axes(Z), spzeros((shapeX .- 1)...), Y) + Base.Broadcast.check_broadcast_axes(axes(Z), spzeros((shapeX .- 1)...), Y) catch @test_throws DimensionMismatch broadcast!(f, Z, spzeros((shapeX .- 1)...), Y) end @@ -247,9 +247,9 @@ end @test broadcast(*, X, Y, Z) == sparse(broadcast(*, fX, fY, fZ)) @test broadcast(f, X, Y, Z) == sparse(broadcast(f, fX, fY, fZ)) # TODO strengthen this test, avoiding dependence on checking whether - # check_broadcast_indices throws to determine whether sparse broadcast should throw + # check_broadcast_axes throws to determine whether sparse broadcast should throw try - Base.Broadcast.combine_indices(spzeros((shapeX .- 1)...), Y, Z) + Base.Broadcast.combine_axes(spzeros((shapeX .- 1)...), Y, Z) catch @test_throws DimensionMismatch broadcast(+, spzeros((shapeX .- 1)...), Y, Z) end @@ -267,6 +267,8 @@ end fQ = broadcast(f, fX, fY, fZ); Q = sparse(fQ) broadcast!(f, Q, X, Y, Z); Q = sparse(fQ) # warmup for @allocated @test_broken (@allocated broadcast!(f, Q, X, Y, Z)) == 0 + broadcast!(f, Q, X, Y, Z); Q = sparse(fQ) # warmup for @allocated + @test (@allocated broadcast!(f, Q, X, Y, Z)) <= 16 # the preceding test allocates 16 bytes in the entry point for broadcast!, but # none of the earlier tests of the same code path allocate. no allocation shows # up with --track-allocation=user. allocation shows up on the first line of the @@ -277,9 +279,9 @@ end @test broadcast!(f, Q, X, Y, Z) == sparse(broadcast!(f, fQ, fX, fY, fZ)) # --> test shape checks for both broadcast and broadcast! 
entry points # TODO strengthen this test, avoiding dependence on checking whether - # check_broadcast_indices throws to determine whether sparse broadcast should throw + # check_broadcast_axes throws to determine whether sparse broadcast should throw try - Base.Broadcast.check_broadcast_indices(axes(Q), spzeros((shapeX .- 1)...), Y, Z) + Base.Broadcast.check_broadcast_axes(axes(Q), spzeros((shapeX .- 1)...), Y, Z) catch @test_throws DimensionMismatch broadcast!(f, Q, spzeros((shapeX .- 1)...), Y, Z) end @@ -350,21 +352,11 @@ end @test broadcast!(*, X, sparseargs...) == sparse(broadcast!(*, fX, denseargs...)) @test isa(@inferred(broadcast!(*, X, sparseargs...)), SparseMatrixCSC{elT}) X = sparse(fX) # reset / warmup for @allocated test + # It'd be nice for this to be zero, but there's currently some constant overhead @test_broken (@allocated broadcast!(*, X, sparseargs...)) == 0 - # This test (and the analog below) fails for three reasons: - # (1) In all cases, generating the closures that capture the scalar arguments - # results in allocation, not sure why. - # (2) In some cases, though _broadcast_eltype (which wraps _return_type) - # consistently provides the correct result eltype when passed the closure - # that incorporates the scalar arguments to broadcast (and, with #19667, - # is inferable, so the overall return type from broadcast is inferred), - # in some cases inference seems unable to determine the return type of - # direct calls to that closure. This issue causes variables in both the - # broadcast[!] entry points (fofzeros = f(_zeros_eltypes(args...)...)) and - # the driver routines (Cx in _map_zeropres! and _broadcast_zeropres!) to have - # inferred type Any, resulting in allocation and lackluster performance. - # (3) The sparseargs... splat in the call above allocates a bit, but of course - # that issue is negligible and perhaps could be accounted for in the test. + X = sparse(fX) # reset / warmup for @allocated test + # And broadcasting over Transposes currently requires making a CSC copy, so we must account for that in the bounds + @test (@allocated broadcast!(*, X, sparseargs...)) <= (sum(x->isa(x, Transpose) ? Base.summarysize(x)*2+128 : 0, sparseargs) + 128) end end # test combinations at the limit of inference (eight arguments net) @@ -385,7 +377,8 @@ end @test isa(@inferred(broadcast!(*, X, sparseargs...)), SparseMatrixCSC{elT}) X = sparse(fX) # reset / warmup for @allocated test @test_broken (@allocated broadcast!(*, X, sparseargs...)) == 0 - # please see the note a few lines above re. 
this @test_broken + X = sparse(fX) # reset / warmup for @allocated test + @test (@allocated broadcast!(*, X, sparseargs...)) <= 128 end end @@ -404,20 +397,12 @@ end structuredarrays = (D, B, T, S) fstructuredarrays = map(Array, structuredarrays) for (X, fX) in zip(structuredarrays, fstructuredarrays) - @test (Q = broadcast(sin, X); Q isa SparseMatrixCSC && Q == sparse(broadcast(sin, fX))) - @test broadcast!(sin, Z, X) == sparse(broadcast(sin, fX)) - @test (Q = broadcast(cos, X); Q isa SparseMatrixCSC && Q == sparse(broadcast(cos, fX))) - @test broadcast!(cos, Z, X) == sparse(broadcast(cos, fX)) - @test (Q = broadcast(*, s, X); Q isa SparseMatrixCSC && Q == sparse(broadcast(*, s, fX))) - @test broadcast!(*, Z, s, X) == sparse(broadcast(*, s, fX)) @test (Q = broadcast(+, V, A, X); Q isa SparseMatrixCSC && Q == sparse(broadcast(+, fV, fA, fX))) @test broadcast!(+, Z, V, A, X) == sparse(broadcast(+, fV, fA, fX)) @test (Q = broadcast(*, s, V, A, X); Q isa SparseMatrixCSC && Q == sparse(broadcast(*, s, fV, fA, fX))) @test broadcast!(*, Z, s, V, A, X) == sparse(broadcast(*, s, fV, fA, fX)) for (Y, fY) in zip(structuredarrays, fstructuredarrays) - @test (Q = broadcast(+, X, Y); Q isa SparseMatrixCSC && Q == sparse(broadcast(+, fX, fY))) @test broadcast!(+, Z, X, Y) == sparse(broadcast(+, fX, fY)) - @test (Q = broadcast(*, X, Y); Q isa SparseMatrixCSC && Q == sparse(broadcast(*, fX, fY))) @test broadcast!(*, Z, X, Y) == sparse(broadcast(*, fX, fY)) end end @@ -426,9 +411,7 @@ end densearrays = (C, M) fD, fB = Array(D), Array(B) for X in densearrays - @test broadcast(+, D, X)::SparseMatrixCSC == sparse(broadcast(+, fD, X)) @test broadcast!(+, Z, D, X) == sparse(broadcast(+, fD, X)) - @test broadcast(*, s, B, X)::SparseMatrixCSC == sparse(broadcast(*, s, fB, X)) @test broadcast!(*, Z, s, B, X) == sparse(broadcast(*, s, fB, X)) @test broadcast(+, V, B, X)::SparseMatrixCSC == sparse(broadcast(+, fV, fB, X)) @test broadcast!(+, Z, V, B, X) == sparse(broadcast(+, fV, fB, X)) @@ -446,25 +429,6 @@ end @test A .+ ntuple(identity, N) isa Matrix end -@testset "broadcast! where the destination is a structured matrix" begin - # Where broadcast!'s destination is a structured matrix, broadcast! should fall back - # to the generic AbstractArray broadcast! code (at least for now). - N, p = 5, 0.4 - A = sprand(N, N, p) - sA = A + copy(A') - D = Diagonal(rand(N)) - B = Bidiagonal(rand(N), rand(N - 1), :U) - T = Tridiagonal(rand(N - 1), rand(N), rand(N - 1)) - @test broadcast!(sin, copy(D), D) == Diagonal(sin.(D)) - @test broadcast!(sin, copy(B), B) == Bidiagonal(sin.(B), :U) - @test broadcast!(sin, copy(T), T) == Tridiagonal(sin.(T)) - @test broadcast!(*, copy(D), D, A) == Diagonal(broadcast(*, D, A)) - @test broadcast!(*, copy(B), B, A) == Bidiagonal(broadcast(*, B, A), :U) - @test broadcast!(*, copy(T), T, A) == Tridiagonal(broadcast(*, T, A)) - # SymTridiagonal (and similar symmetric matrix types) do not support setindex! - # off the diagonal, and so cannot serve as a destination for broadcast! -end - @testset "map[!] 
over combinations of sparse and structured matrices" begin N, p = 10, 0.4 A = sprand(N, N, p) @@ -476,16 +440,12 @@ end structuredarrays = (D, B, T, S) fstructuredarrays = map(Array, structuredarrays) for (X, fX) in zip(structuredarrays, fstructuredarrays) - @test (Q = map(sin, X); Q isa SparseMatrixCSC && Q == sparse(map(sin, fX))) @test map!(sin, Z, X) == sparse(map(sin, fX)) - @test (Q = map(cos, X); Q isa SparseMatrixCSC && Q == sparse(map(cos, fX))) @test map!(cos, Z, X) == sparse(map(cos, fX)) @test (Q = map(+, A, X); Q isa SparseMatrixCSC && Q == sparse(map(+, fA, fX))) @test map!(+, Z, A, X) == sparse(map(+, fA, fX)) for (Y, fY) in zip(structuredarrays, fstructuredarrays) - @test (Q = map(+, X, Y); Q isa SparseMatrixCSC && Q == sparse(map(+, fX, fY))) @test map!(+, Z, X, Y) == sparse(map(+, fX, fY)) - @test (Q = map(*, X, Y); Q isa SparseMatrixCSC && Q == sparse(map(*, fX, fY))) @test map!(*, Z, X, Y) == sparse(map(*, fX, fY)) @test (Q = map(+, X, A, Y); Q isa SparseMatrixCSC && Q == sparse(map(+, fX, fA, fY))) @test map!(+, Z, X, A, Y) == sparse(map(+, fX, fA, fY)) diff --git a/test/bitarray.jl b/test/bitarray.jl index 572918a2dda77..806ecc6b4aa0f 100644 --- a/test/bitarray.jl +++ b/test/bitarray.jl @@ -1014,6 +1014,41 @@ timesofar("unary arithmetic") @check_bit_operation broadcast(^, b1, 1im) Matrix{ComplexF64} @check_bit_operation broadcast(^, b1, 0x1*im) Matrix{ComplexF64} end + + @testset "Matrix/Vector" begin + b1 = bitrand(n1, n2) + b2 = bitrand(n1) + b3 = bitrand(n2) + + @check_bit_operation broadcast(&, b1, b2) BitMatrix + @check_bit_operation broadcast(&, b1, transpose(b3)) BitMatrix + @check_bit_operation broadcast(&, b2, b1) BitMatrix + @check_bit_operation broadcast(&, transpose(b3), b1) BitMatrix + @check_bit_operation broadcast(|, b1, b2) BitMatrix + @check_bit_operation broadcast(|, b1, transpose(b3)) BitMatrix + @check_bit_operation broadcast(|, b2, b1) BitMatrix + @check_bit_operation broadcast(|, transpose(b3), b1) BitMatrix + @check_bit_operation broadcast(xor, b1, b2) BitMatrix + @check_bit_operation broadcast(xor, b1, transpose(b3)) BitMatrix + @check_bit_operation broadcast(xor, b2, b1) BitMatrix + @check_bit_operation broadcast(xor, transpose(b3), b1) BitMatrix + @check_bit_operation broadcast(+, b1, b2) Matrix{Int} + @check_bit_operation broadcast(+, b1, transpose(b3)) Matrix{Int} + @check_bit_operation broadcast(+, b2, b1) Matrix{Int} + @check_bit_operation broadcast(+, transpose(b3), b1) Matrix{Int} + @check_bit_operation broadcast(-, b1, b2) Matrix{Int} + @check_bit_operation broadcast(-, b1, transpose(b3)) Matrix{Int} + @check_bit_operation broadcast(-, b2, b1) Matrix{Int} + @check_bit_operation broadcast(-, transpose(b3), b1) Matrix{Int} + @check_bit_operation broadcast(*, b1, b2) BitMatrix + @check_bit_operation broadcast(*, b1, transpose(b3)) BitMatrix + @check_bit_operation broadcast(*, b2, b1) BitMatrix + @check_bit_operation broadcast(*, transpose(b3), b1) BitMatrix + @check_bit_operation broadcast(/, b1, b2) Matrix{Float64} + @check_bit_operation broadcast(/, b1, transpose(b3)) Matrix{Float64} + @check_bit_operation broadcast(/, b2, b1) Matrix{Float64} + @check_bit_operation broadcast(/, transpose(b3), b1) Matrix{Float64} + end end timesofar("binary arithmetic") diff --git a/test/broadcast.jl b/test/broadcast.jl index 966fc1a5a9e22..35f127658ad6f 100644 --- a/test/broadcast.jl +++ b/test/broadcast.jl @@ -2,7 +2,7 @@ module TestBroadcastInternals -using Base.Broadcast: check_broadcast_indices, check_broadcast_shape, newindex, _bcs +using 
Base.Broadcast: check_broadcast_axes, check_broadcast_shape, newindex, _bcs using Base: OneTo using Test, Random @@ -19,22 +19,22 @@ using Test, Random @test_throws DimensionMismatch _bcs((-1:1, 2:6), (-1:1, 2:5)) @test_throws DimensionMismatch _bcs((-1:1, 2:5), (2, 2:5)) -@test @inferred(Broadcast.combine_indices(zeros(3,4), zeros(3,4))) == (OneTo(3),OneTo(4)) -@test @inferred(Broadcast.combine_indices(zeros(3,4), zeros(3))) == (OneTo(3),OneTo(4)) -@test @inferred(Broadcast.combine_indices(zeros(3), zeros(3,4))) == (OneTo(3),OneTo(4)) -@test @inferred(Broadcast.combine_indices(zeros(3), zeros(1,4), zeros(1))) == (OneTo(3),OneTo(4)) - -check_broadcast_indices((OneTo(3),OneTo(5)), zeros(3,5)) -check_broadcast_indices((OneTo(3),OneTo(5)), zeros(3,1)) -check_broadcast_indices((OneTo(3),OneTo(5)), zeros(3)) -check_broadcast_indices((OneTo(3),OneTo(5)), zeros(3,5), zeros(3)) -check_broadcast_indices((OneTo(3),OneTo(5)), zeros(3,5), 1) -check_broadcast_indices((OneTo(3),OneTo(5)), 5, 2) -@test_throws DimensionMismatch check_broadcast_indices((OneTo(3),OneTo(5)), zeros(2,5)) -@test_throws DimensionMismatch check_broadcast_indices((OneTo(3),OneTo(5)), zeros(3,4)) -@test_throws DimensionMismatch check_broadcast_indices((OneTo(3),OneTo(5)), zeros(3,4,2)) -@test_throws DimensionMismatch check_broadcast_indices((OneTo(3),OneTo(5)), zeros(3,5), zeros(2)) -check_broadcast_indices((-1:1, 6:9), 1) +@test @inferred(Broadcast.combine_axes(zeros(3,4), zeros(3,4))) == (OneTo(3),OneTo(4)) +@test @inferred(Broadcast.combine_axes(zeros(3,4), zeros(3))) == (OneTo(3),OneTo(4)) +@test @inferred(Broadcast.combine_axes(zeros(3), zeros(3,4))) == (OneTo(3),OneTo(4)) +@test @inferred(Broadcast.combine_axes(zeros(3), zeros(1,4), zeros(1))) == (OneTo(3),OneTo(4)) + +check_broadcast_axes((OneTo(3),OneTo(5)), zeros(3,5)) +check_broadcast_axes((OneTo(3),OneTo(5)), zeros(3,1)) +check_broadcast_axes((OneTo(3),OneTo(5)), zeros(3)) +check_broadcast_axes((OneTo(3),OneTo(5)), zeros(3,5), zeros(3)) +check_broadcast_axes((OneTo(3),OneTo(5)), zeros(3,5), 1) +check_broadcast_axes((OneTo(3),OneTo(5)), 5, 2) +@test_throws DimensionMismatch check_broadcast_axes((OneTo(3),OneTo(5)), zeros(2,5)) +@test_throws DimensionMismatch check_broadcast_axes((OneTo(3),OneTo(5)), zeros(3,4)) +@test_throws DimensionMismatch check_broadcast_axes((OneTo(3),OneTo(5)), zeros(3,4,2)) +@test_throws DimensionMismatch check_broadcast_axes((OneTo(3),OneTo(5)), zeros(3,5), zeros(2)) +check_broadcast_axes((-1:1, 6:9), 1) check_broadcast_shape((-1:1, 6:9), (-1:1, 6:9)) check_broadcast_shape((-1:1, 6:9), (-1:1, 1)) @@ -167,7 +167,9 @@ rt = Base.return_types(broadcast!, Tuple{Function, Array{Float64, 3}, Array{Floa @test length(rt) == 1 && rt[1] == Array{Float64, 3} # f.(args...) syntax (#15032) -let x = [1,3.2,4.7], y = [3.5, pi, 1e-4], α = 0.2342 +let x = [1, 3.2, 4.7], + y = [3.5, pi, 1e-4], + α = 0.2342 @test sin.(x) == broadcast(sin, x) @test sin.(α) == broadcast(sin, α) @test sin.(3.2) == broadcast(sin, 3.2) == sin(3.2) @@ -237,12 +239,12 @@ let x = sin.(1:10), a = [x] @test atan2.(x, cos.(x)) == atan2.(a..., cos.(x)) == broadcast(atan2, x, cos.(a...)) == broadcast(atan2, a..., cos.(a...)) @test ((args...)->cos(args[1])).(x) == cos.(x) == ((y,args...)->cos(y)).(x) end -@test atan2.(3,4) == atan2(3,4) == (() -> atan2(3,4)).() +@test atan2.(3, 4) == atan2(3, 4) == (() -> atan2(3, 4)).() # fusion with keyword args: let x = [1:4;] f17300kw(x; y=0) = x + y @test f17300kw.(x) == x - @test f17300kw.(x, y=1) == f17300kw.(x; y=1) == f17300kw.(x; [(:y,1)]...) 
== x .+ 1 + @test f17300kw.(x, y=1) == f17300kw.(x; y=1) == f17300kw.(x; [(:y,1)]...) == x .+ 1 == [2, 3, 4, 5] @test f17300kw.(sin.(x), y=1) == f17300kw.(sin.(x); y=1) == sin.(x) .+ 1 @test sin.(f17300kw.(x, y=1)) == sin.(f17300kw.(x; y=1)) == sin.(x .+ 1) end @@ -408,7 +410,7 @@ StrangeType18623(x,y) = (x,y) let f(A, n) = broadcast(x -> +(x, n), A) @test @inferred(f([1.0], 1)) == [2.0] - g() = (a = 1; Broadcast.combine_eltypes(x -> x + a, 1.0)) + g() = (a = 1; Broadcast.combine_eltypes(x -> x + a, (1.0,))) @test @inferred(g()) === Float64 end @@ -428,7 +430,7 @@ abstract type ArrayData{T,N} <: AbstractArray{T,N} end Base.getindex(A::ArrayData, i::Integer...) = A.data[i...] Base.setindex!(A::ArrayData, v::Any, i::Integer...) = setindex!(A.data, v, i...) Base.size(A::ArrayData) = size(A.data) -Base.broadcast_similar(f, ::Broadcast.ArrayStyle{A}, ::Type{T}, inds::Tuple, As...) where {A,T} = +Base.broadcast_similar(::Broadcast.ArrayStyle{A}, ::Type{T}, inds::Tuple, bc) where {A,T} = A(Array{T}(undef, length.(inds))) struct Array19745{T,N} <: ArrayData{T,N} @@ -488,14 +490,21 @@ Base.BroadcastStyle(a2::Broadcast.ArrayStyle{AD2C}, a1::Broadcast.ArrayStyle{AD1 @testset "broadcasting for custom AbstractArray" begin a = randn(10) aa = Array19745(a) - @test a .+ 1 == @inferred(aa .+ 1) - @test a .* a' == @inferred(aa .* aa') + fadd(aa) = aa .+ 1 + fadd2(aa) = aa .+ 1 .* 2 + fprod(aa) = aa .* aa' + @test a .+ 1 == @inferred(fadd(aa)) + @test a .+ 1 .* 2 == @inferred(fadd2(aa)) + @test a .* a' == @inferred(fprod(aa)) @test isa(aa .+ 1, Array19745) + @test isa(aa .+ 1 .* 2, Array19745) @test isa(aa .* aa', Array19745) a1 = AD1(rand(2,3)) a2 = AD2(rand(2)) @test a1 .+ 1 isa AD1 @test a2 .+ 1 isa AD2 + @test a1 .+ 1 .* 2 isa AD1 + @test a2 .+ 1 .* 2 isa AD2 @test a1 .+ a2 isa Array @test a2 .+ a1 isa Array @test a1 .+ a2 .+ a1 isa Array @@ -504,6 +513,8 @@ Base.BroadcastStyle(a2::Broadcast.ArrayStyle{AD2C}, a1::Broadcast.ArrayStyle{AD1 a2 = AD2P(rand(2)) @test a1 .+ 1 isa AD1P @test a2 .+ 1 isa AD2P + @test a1 .+ 1 .* 2 isa AD1P + @test a2 .+ 1 .* 2 isa AD2P @test a1 .+ a2 isa AD1P @test a2 .+ a1 isa AD1P @test a1 .+ a2 .+ a1 isa AD1P @@ -512,6 +523,8 @@ Base.BroadcastStyle(a2::Broadcast.ArrayStyle{AD2C}, a1::Broadcast.ArrayStyle{AD1 a2 = AD2B(rand(2)) @test a1 .+ 1 isa AD1B @test a2 .+ 1 isa AD2B + @test a1 .+ 1 .* 2 isa AD1B + @test a2 .+ 1 .* 2 isa AD2B @test a1 .+ a2 isa AD1B @test a2 .+ a1 isa AD1B @test a1 .+ a2 .+ a1 isa AD1B @@ -520,6 +533,8 @@ Base.BroadcastStyle(a2::Broadcast.ArrayStyle{AD2C}, a1::Broadcast.ArrayStyle{AD1 a2 = AD2C(rand(2)) @test a1 .+ 1 isa AD1C @test a2 .+ 1 isa AD2C + @test a1 .+ 1 .* 2 isa AD1C + @test a2 .+ 1 .* 2 isa AD2C @test_throws ErrorException a1 .+ a2 end @@ -532,7 +547,7 @@ end # Test that broadcast's promotion mechanism handles closures accepting more than one argument. # (See issue #19641 and referenced issues and pull requests.) -let f() = (a = 1; Broadcast.combine_eltypes((x, y) -> x + y + a, 1.0, 1.0)) +let f() = (a = 1; Broadcast.combine_eltypes((x, y) -> x + y + a, (1.0, 1.0))) @test @inferred(f()) == Float64 end @@ -637,3 +652,52 @@ let n = 1 @test ceil.(Int, n ./ (1,)) == (1,) @test ceil.(Int, 1 ./ (1,)) == (1,) end + + +# lots of splatting! +let x = [[1, 4], [2, 5], [3, 6]] + y = .+(x..., .*(x..., x...)..., x[1]..., x[2]..., x[3]...) + @test y == [14463, 14472] + + z = zeros(2) + z .= .+(x..., .*(x..., x...)..., x[1]..., x[2]..., x[3]...) 
+ @test z == Float64[14463, 14472] +end + +# Issue #21094 +@generated function foo21094(out, x) + quote + out .= x .+ x + out + end +end +@test foo21094([0.0], [1.0]) == [2.0] + +# Issue #22053 +struct T22053 + t +end +Broadcast.BroadcastStyle(::Type{T22053}) = Broadcast.Style{T22053}() +Broadcast.broadcast_axes(::T22053) = () +Broadcast.broadcastable(t::T22053) = t +function Base.copy(bc::Broadcast.Broadcasted{Broadcast.Style{T22053}}) + all(x->isa(x, T22053), bc.args) && return 1 + return 0 +end +Base.:*(::T22053, ::T22053) = 2 +let x = T22053(1) + @test x*x == 2 + @test x.*x == 1 +end + +# Issue https://github.com/JuliaLang/julia/pull/25377#discussion_r159956996 +let X = Any[1,2] + X .= nothing + @test X[1] == X[2] == nothing +end + +# Ensure that broadcast styles with custom indexing work +let X = zeros(2, 3) + X .= (1, 2) + @test X == [1 1 1; 2 2 2] +end diff --git a/test/core.jl b/test/core.jl index ba9f4b9d3a61a..668221ed7bf0b 100644 --- a/test/core.jl +++ b/test/core.jl @@ -1941,11 +1941,11 @@ test5884() # issue #5924 let - function Test() + function test5924() func = function () end func end - @test Test()() === nothing + @test test5924()() === nothing end # issue #6031 diff --git a/test/numbers.jl b/test/numbers.jl index b324aa7592728..b4a8910d365d5 100644 --- a/test/numbers.jl +++ b/test/numbers.jl @@ -2415,7 +2415,7 @@ Base.literal_pow(::typeof(^), ::PR20530, ::Val{p}) where {p} = 2 p = 2 @test x^p == 1 @test x^2 == 2 - @test [x,x,x].^2 == [2,2,2] + @test [x, x, x].^2 == [2, 2, 2] for T in (Float16, Float32, Float64, BigFloat, Int8, Int, BigInt, Complex{Int}, Complex{Float64}) for p in -4:4 v = eval(:($T(2)^$p)) @@ -2430,6 +2430,7 @@ Base.literal_pow(::typeof(^), ::PR20530, ::Val{p}) where {p} = 2 end @test PR20889(2)^3 == 5 @test [2,4,8].^-2 == [0.25, 0.0625, 0.015625] + @test [2, 4, 8].^-2 .* 4 == [1.0, 0.25, 0.0625] # nested literal_pow @test ℯ^-2 == exp(-2) ≈ inv(ℯ^2) ≈ (ℯ^-1)^2 ≈ sqrt(ℯ^-4) end module M20889 # do we get the expected behavior without importing Base.^? diff --git a/test/ranges.jl b/test/ranges.jl index c0ebdeee04c6e..9b32428dea766 100644 --- a/test/ranges.jl +++ b/test/ranges.jl @@ -477,15 +477,15 @@ end @test sum(0:0.1:10) == 505. 
end @testset "broadcasted operations with scalars" begin - @test broadcast(-, 1:3, 2) == -1:1 - @test broadcast(-, 1:3, 0.25) == 1-0.25:3-0.25 - @test broadcast(+, 1:3, 2) == 3:5 - @test broadcast(+, 1:3, 0.25) == 1+0.25:3+0.25 - @test broadcast(+, 1:2:6, 1) == 2:2:6 - @test broadcast(+, 1:2:6, 0.3) == 1+0.3:2:5+0.3 - @test broadcast(-, 1:2:6, 1) == 0:2:4 - @test broadcast(-, 1:2:6, 0.3) == 1-0.3:2:5-0.3 - @test broadcast(-, 2, 1:3) == 1:-1:-1 + @test broadcast(-, 1:3, 2) === -1:1 + @test broadcast(-, 1:3, 0.25) === 1-0.25:3-0.25 + @test broadcast(+, 1:3, 2) === 3:5 + @test broadcast(+, 1:3, 0.25) === 1+0.25:3+0.25 + @test broadcast(+, 1:2:6, 1) === 2:2:6 + @test broadcast(+, 1:2:6, 0.3) === 1+0.3:2:5+0.3 + @test broadcast(-, 1:2:6, 1) === 0:2:4 + @test broadcast(-, 1:2:6, 0.3) === 1-0.3:2:5-0.3 + @test broadcast(-, 2, 1:3) === 1:-1:-1 end @testset "operations between ranges and arrays" begin @test all(([1:5;] + (5:-1:1)) .== 6) @@ -551,27 +551,33 @@ end @test [0.0:prevfloat(0.1):0.3;] == [0.0, prevfloat(0.1), prevfloat(0.2), 0.3] @test [0.0:nextfloat(0.1):0.3;] == [0.0, nextfloat(0.1), nextfloat(0.2)] end -@testset "issue #7420 for type $T" for T = (Float32, Float64,), # BigFloat), - a = -5:25, - s = [-5:-1; 1:25; ], - d = 1:25, - n = -1:15 - - denom = convert(T, d) - strt = convert(T, a)/denom - Δ = convert(T, s)/denom - stop = convert(T, (a + (n - 1) * s)) / denom - vals = T[a:s:(a + (n - 1) * s); ] ./ denom - r = strt:Δ:stop - @test [r;] == vals - @test [range(strt, stop=stop, length=length(r));] == vals - n = length(r) - @test [r[1:n];] == [r;] - @test [r[2:n];] == [r;][2:end] - @test [r[1:3:n];] == [r;][1:3:n] - @test [r[2:2:n];] == [r;][2:2:n] - @test [r[n:-1:2];] == [r;][n:-1:2] - @test [r[n:-2:1];] == [r;][n:-2:1] + +function loop_range_values(::Type{T}) where T + for a = -5:25, + s = [-5:-1; 1:25; ], + d = 1:25, + n = -1:15 + + denom = convert(T, d) + strt = convert(T, a)/denom + Δ = convert(T, s)/denom + stop = convert(T, (a + (n - 1) * s)) / denom + vals = T[a:s:(a + (n - 1) * s); ] ./ denom + r = strt:Δ:stop + @test [r;] == vals + @test [range(strt, stop=stop, length=length(r));] == vals + n = length(r) + @test [r[1:n];] == [r;] + @test [r[2:n];] == [r;][2:end] + @test [r[1:3:n];] == [r;][1:3:n] + @test [r[2:2:n];] == [r;][2:2:n] + @test [r[n:-1:2];] == [r;][n:-1:2] + @test [r[n:-2:1];] == [r;][n:-2:1] + end +end + +@testset "issue #7420 for type $T" for T = (Float32, Float64,) # BigFloat), + loop_range_values(T) end @testset "issue #20373 (unliftable ranges with exact end points)" begin @@ -990,7 +996,10 @@ end for _r in (1:2:100, 1:100, 1f0:2f0:100f0, 1.0:2.0:100.0, range(1, stop=100, length=10), range(1f0, stop=100f0, length=10)) float_r = float(_r) - big_r = big.(_r) + big_r = broadcast(big, _r) + big_rdot = big.(_r) + @test big_rdot == big_r + @test typeof(big_r) == typeof(big_rdot) @test typeof(big_r).name === typeof(_r).name if eltype(_r) <: AbstractFloat @test isa(float_r, typeof(_r)) @@ -1217,6 +1226,22 @@ end @test map(BigFloat, x) === x end +@testset "broadcasting returns ranges" begin + x, r = 2, 1:5 + @test @inferred(x .+ r) === 3:7 + @test @inferred(r .+ x) === 3:7 + @test @inferred(r .- x) === -1:3 + @test @inferred(x .- r) === 1:-1:-3 + @test @inferred(x .* r) === 2:2:10 + @test @inferred(r .* x) === 2:2:10 + @test @inferred(r ./ x) === 0.5:0.5:2.5 + @test @inferred(x ./ r) == 2 ./ [r;] && isa(x ./ r, Vector{Float64}) + @test @inferred(r .\ x) == 2 ./ [r;] && isa(x ./ r, Vector{Float64}) + @test @inferred(x .\ r) === 0.5:0.5:2.5 + + @test @inferred(2 .* (r .+ 
1) .+ 2) === 6:2:14 +end + @testset "Bad range calls" begin @test_throws ArgumentError range(1) @test_throws ArgumentError range(nothing)