Skip to content

Commit

Permalink
Rename Broadcast.*_indices to *_axes as appropriate
Browse files Browse the repository at this point in the history
  • Loading branch information
mbauman committed Apr 23, 2018
1 parent a2b9015 commit c8bb374
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 62 deletions.
56 changes: 28 additions & 28 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using .Base.Cartesian
using .Base: Indices, OneTo, linearindices, tail, to_shape, isoperator, promote_typejoin,
_msk_end, unsafe_bitgetindex, bitcache_chunks, bitcache_size, dumpbitcache, unalias
import .Base: broadcast, broadcast!, copy, copyto!
export BroadcastStyle, broadcast_indices, broadcast_similar, broadcastable,
export BroadcastStyle, broadcast_axes, broadcast_similar, broadcastable,
broadcast_getindex, broadcast_setindex!, dotview, @__dot__

### Objects with customized broadcasting behavior should declare a BroadcastStyle
Expand Down Expand Up @@ -188,7 +188,7 @@ Base.show(io::IO, bc::Broadcasted{Style}) where {Style} = print(io, Broadcasted,
broadcast_similar(::BroadcastStyle, ::Type{ElType}, inds, As...)
Allocate an output object for [`broadcast`](@ref), appropriate for the indicated
[`Broadcast.BroadcastStyle`](@ref). `ElType` and `inds` specify the desired element type and indices of the
[`Broadcast.BroadcastStyle`](@ref). `ElType` and `inds` specify the desired element type and axes of the
container. `As...` are the input arguments supplied to `broadcast`.
"""
broadcast_similar(::DefaultArrayStyle{N}, ::Type{ElType}, inds::Indices{N}, bc) where {N,ElType} =
Expand All @@ -201,25 +201,25 @@ broadcast_similar(::ArrayConflict, ::Type{ElType}, inds::Indices, bc) where ElTy
broadcast_similar(::ArrayConflict, ::Type{Bool}, inds::Indices, bc) =
similar(BitArray, inds)

## Computing the result's indices. Most types probably won't need to specialize this.
broadcast_indices() = ()
broadcast_indices(A::Tuple) = (OneTo(length(A)),)
broadcast_indices(A::Ref) = ()
broadcast_indices(A) = axes(A)
## Computing the result's axes. Most types probably won't need to specialize this.
broadcast_axes() = ()
broadcast_axes(A::Tuple) = (OneTo(length(A)),)
broadcast_axes(A::Ref) = ()
broadcast_axes(A) = axes(A)
"""
Base.broadcast_indices(::SrcStyle, A)
Base.broadcast_axes(A)
Compute the indices for objects `A` with [`BroadcastStyle`](@ref) `SrcStyle`.
If needed, you can specialize this method for your styles.
You should only need to provide a custom implementation for non-AbstractArrayStyles.
Compute the axes for `A`.
This should only be specialized for objects that do not define axes but want to participate in broadcasting.
"""
broadcast_indices
broadcast_axes

### End of methods that users will typically have to specialize ###

Base.axes(bc::Broadcasted) = _axes(bc, bc.axes)
_axes(::Broadcasted, axes::Tuple) = axes
_axes(bc::Broadcasted, ::Nothing) = combine_indices(bc.args...)
_axes(bc::Broadcasted, ::Nothing) = combine_axes(bc.args...)
_axes(bc::Broadcasted{Style{Tuple}}, ::Nothing) = (Base.OneTo(length(longest_tuple(nothing, bc.args))),)
_axes(bc::Broadcasted{<:AbstractArrayStyle{0}}, ::Nothing) = ()

Expand Down Expand Up @@ -252,10 +252,10 @@ they must provide their own `Base.axes(::Broadcasted{Style})` and
"""
@inline function instantiate(bc::Broadcasted{Style}) where {Style}
if bc.axes isa Nothing # Not done via dispatch to make it easier to extend instantiate(::Broadcasted{Style})
axes = combine_indices(bc.args...)
axes = combine_axes(bc.args...)
else
axes = bc.axes
check_broadcast_indices(axes, bc.args...)
check_broadcast_axes(axes, bc.args...)
end
return Broadcasted{Style}(bc.f, bc.args, axes)
end
Expand Down Expand Up @@ -411,8 +411,8 @@ One of these should be undefined (and thus return Broadcast.Unknown).""")
end

# Indices utilities
combine_indices(A, B...) = broadcast_shape(broadcast_indices(A), combine_indices(B...))
combine_indices(A) = broadcast_indices(A)
combine_axes(A, B...) = broadcast_shape(broadcast_axes(A), combine_axes(B...))
combine_axes(A) = broadcast_axes(A)

# shape (i.e., tuple-of-indices) inputs
broadcast_shape(shape::Tuple) = shape
Expand Down Expand Up @@ -444,11 +444,11 @@ function check_broadcast_shape(shp, Ashp::Tuple)
_bcsm(shp[1], Ashp[1]) || throw(DimensionMismatch("array could not be broadcast to match destination"))
check_broadcast_shape(tail(shp), tail(Ashp))
end
check_broadcast_indices(shp, A) = check_broadcast_shape(shp, broadcast_indices(A))
check_broadcast_axes(shp, A) = check_broadcast_shape(shp, broadcast_axes(A))
# comparing many inputs
@inline function check_broadcast_indices(shp, A, As...)
check_broadcast_indices(shp, A)
check_broadcast_indices(shp, As...)
@inline function check_broadcast_axes(shp, A, As...)
check_broadcast_axes(shp, A)
check_broadcast_axes(shp, As...)
end

## Indexing manipulations
Expand All @@ -468,8 +468,8 @@ an `Int`.
Any remaining indices in `I` beyond the length of the `keep` tuple are truncated. The `keep` and `default`
tuples may be created by `newindexer(argument)`.
"""
Base.@propagate_inbounds newindex(arg, I::CartesianIndex) = CartesianIndex(_newindex(broadcast_indices(arg), I.I))
Base.@propagate_inbounds newindex(arg, I::Int) = CartesianIndex(_newindex(broadcast_indices(arg), (I,)))
Base.@propagate_inbounds newindex(arg, I::CartesianIndex) = CartesianIndex(_newindex(broadcast_axes(arg), I.I))
Base.@propagate_inbounds newindex(arg, I::Int) = CartesianIndex(_newindex(broadcast_axes(arg), (I,)))
Base.@propagate_inbounds _newindex(ax::Tuple, I::Tuple) = (ifelse(Base.unsafe_length(ax[1])==1, ax[1][1], I[1]), _newindex(tail(ax), tail(I))...)
Base.@propagate_inbounds _newindex(ax::Tuple{}, I::Tuple) = ()
Base.@propagate_inbounds _newindex(ax::Tuple, I::Tuple{}) = (ax[1][1], _newindex(tail(ax), ())...)
Expand All @@ -484,7 +484,7 @@ Base.@propagate_inbounds _newindex(ax::Tuple{}, I::Tuple{}) = ()

# newindexer(A) generates `keep` and `Idefault` (for use by `newindex` above)
# for a particular array `A`; `shapeindexer` does so for its axes.
@inline newindexer(A) = shapeindexer(broadcast_indices(A))
@inline newindexer(A) = shapeindexer(broadcast_axes(A))
@inline shapeindexer(ax) = _newindexer(ax)
@inline _newindexer(indsA::Tuple{}) = (), ()
@inline function _newindexer(indsA::Tuple)
Expand Down Expand Up @@ -525,7 +525,7 @@ struct Extruded{T, K, D}
keeps::K # A tuple of booleans, specifying which indices should be passed normally
defaults::D # A tuple of integers, specifying the index to use when keeps[i] is false (as defaults[i])
end
@inline broadcast_indices(b::Extruded) = broadcast_indices(b.x)
@inline broadcast_axes(b::Extruded) = broadcast_axes(b.x)
Base.@propagate_inbounds _broadcast_getindex(b::Extruded, i) = b.x[newindex(i, b.keeps, b.defaults)]
extrude(x::AbstractArray) = Extruded(x, newindexer(x)...)
extrude(x) = x
Expand Down Expand Up @@ -1034,14 +1034,14 @@ julia> broadcast_getindex(A, [1 2 1; 1 2 2], [1, 2])
```
"""
broadcast_getindex(src::AbstractArray, I::AbstractArray...) =
broadcast_getindex!(Base.similar(Array{eltype(src)}, combine_indices(I...)), src, I...)
broadcast_getindex!(Base.similar(Array{eltype(src)}, combine_axes(I...)), src, I...)

@generated function broadcast_getindex!(dest::AbstractArray, src::AbstractArray, I::AbstractArray...)
N = length(I)
Isplat = Expr[:(I[$d]) for d = 1:N]
quote
@nexprs $N d->(I_d = I[d])
check_broadcast_indices(Base.axes(dest), $(Isplat...)) # unnecessary if this function is never called directly
check_broadcast_axes(Base.axes(dest), $(Isplat...)) # unnecessary if this function is never called directly
checkbounds(src, $(Isplat...))
@nexprs $N d->(@nexprs $N k->(Ibcast_d_k = Base.axes(I_k, d) == OneTo(1)))
@nloops $N i dest d->(@nexprs $N k->(j_d_k = Ibcast_d_k ? 1 : i_d)) begin
Expand Down Expand Up @@ -1072,7 +1072,7 @@ See [`broadcast_getindex`](@ref) for examples of the treatment of `inds`.
quote
@nexprs $N d->(I_d = I[d])
checkbounds(A, $(Isplat...))
shape = combine_indices($(Isplat...))
shape = combine_axes($(Isplat...))
@nextract $N shape d->(length(shape) < d ? OneTo(1) : shape[d])
@nexprs $N d->(@nexprs $N k->(Ibcast_d_k = Base.axes(I_k, d) == 1:1))
if !isa(x, AbstractArray)
Expand Down
4 changes: 4 additions & 0 deletions base/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1114,6 +1114,10 @@ end
@deprecate indices(a) axes(a)
@deprecate indices(a, d) axes(a, d)

# And similar _indices names in Broadcast
@eval Broadcast Base.@deprecate_binding broadcast_indices broadcast_axes false
@eval Broadcast Base.@deprecate_binding check_broadcast_indices check_broadcast_axes false

# PR #25046
export reload, workspace
reload(name::AbstractString) = error("`reload($(repr(name)))` is discontinued, consider Revise.jl for an alternative workflow.")
Expand Down
2 changes: 1 addition & 1 deletion doc/src/base/arrays.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ For specializing broadcast on custom types, see
```@docs
Base.BroadcastStyle
Base.broadcast_similar
Base.broadcast_indices
Base.broadcast_axes
Base.Broadcast.AbstractArrayStyle
Base.Broadcast.ArrayStyle
Base.Broadcast.DefaultArrayStyle
Expand Down
2 changes: 1 addition & 1 deletion doc/src/manual/interfaces.md
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ V = view(A, [1,2,4], :) # is not strided, as the spacing between rows is not f
| `Base.broadcast_similar(::DestStyle, ::Type{ElType}, inds, bc)` | Allocation of output container |
| **Optional methods** | | |
| `Base.BroadcastStyle(::Style1, ::Style2) = Style12()` | Precedence rules for mixing styles |
| `Base.broadcast_indices(::StyleA, A)` | Declaration of the indices of `A` for broadcasting purposes (defaults to [`axes(A)`](@ref)) |
| `Base.broadcast_axes(::StyleA, A)` | Declaration of the indices of `A` for broadcasting purposes (defaults to [`axes(A)`](@ref)) |
| `Base.broadcastable(x)` | Convert `x` to an object that has `axes` and supports indexing |
| **Bypassing default machinery** | |
| `Base.copy(bc::Broadcasted{DestStyle})` | Custom implementation of `broadcast` |
Expand Down
4 changes: 2 additions & 2 deletions stdlib/SparseArrays/src/higherorderfns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ function _diffshape_broadcast(f::Tf, A::SparseVecOrMat, Bs::Vararg{SparseVecOrMa
fpreszeros = _iszero(fofzeros)
indextypeC = _promote_indtype(A, Bs...)
entrytypeC = Base.Broadcast.combine_eltypes(f, (A, Bs...))
shapeC = to_shape(Base.Broadcast.combine_indices(A, Bs...))
shapeC = to_shape(Base.Broadcast.combine_axes(A, Bs...))
maxnnzC = fpreszeros ? _checked_maxnnzbcres(shapeC, A, Bs...) : _densennz(shapeC)
C = _allocres(shapeC, indextypeC, entrytypeC, maxnnzC)
return fpreszeros ? _broadcast_zeropres!(f, C, A, Bs...) :
Expand Down Expand Up @@ -984,7 +984,7 @@ end

@inline function _copyto!(f, dest, As::SparseVecOrMat...)
_aresameshape(dest, As...) && return _noshapecheck_map!(f, dest, As...)
Base.Broadcast.check_broadcast_indices(axes(dest), As...)
Base.Broadcast.check_broadcast_axes(axes(dest), As...)
fofzeros = f(_zeros_eltypes(As...)...)
if _iszero(fofzeros)
return _broadcast_zeropres!(f, dest, As...)
Expand Down
24 changes: 12 additions & 12 deletions stdlib/SparseArrays/test/higherorderfns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,9 @@ end
@test broadcast!(cos, Z, X) == sparse(broadcast!(cos, fZ, fX))
# --> test shape checks for broadcast! entry point
# TODO strengthen this test, avoiding dependence on checking whether
# check_broadcast_indices throws to determine whether sparse broadcast should throw
# check_broadcast_axes throws to determine whether sparse broadcast should throw
try
Base.Broadcast.check_broadcast_indices(axes(Z), spzeros((shapeX .- 1)...))
Base.Broadcast.check_broadcast_axes(axes(Z), spzeros((shapeX .- 1)...))
catch
@test_throws DimensionMismatch broadcast!(sin, Z, spzeros((shapeX .- 1)...))
end
Expand All @@ -149,9 +149,9 @@ end
@test broadcast!(cos, V, X) == sparse(broadcast!(cos, fV, fX))
# --> test shape checks for broadcast! entry point
# TODO strengthen this test, avoiding dependence on checking whether
# check_broadcast_indices throws to determine whether sparse broadcast should throw
# check_broadcast_axes throws to determine whether sparse broadcast should throw
try
Base.Broadcast.check_broadcast_indices(axes(V), spzeros((shapeX .- 1)...))
Base.Broadcast.check_broadcast_axes(axes(V), spzeros((shapeX .- 1)...))
catch
@test_throws DimensionMismatch broadcast!(sin, V, spzeros((shapeX .- 1)...))
end
Expand Down Expand Up @@ -184,9 +184,9 @@ end
@test broadcast(*, X, Y) == sparse(broadcast(*, fX, fY))
@test broadcast(f, X, Y) == sparse(broadcast(f, fX, fY))
# TODO strengthen this test, avoiding dependence on checking whether
# check_broadcast_indices throws to determine whether sparse broadcast should throw
# check_broadcast_axes throws to determine whether sparse broadcast should throw
try
Base.Broadcast.combine_indices(spzeros((shapeX .- 1)...), Y)
Base.Broadcast.combine_axes(spzeros((shapeX .- 1)...), Y)
catch
@test_throws DimensionMismatch broadcast(+, spzeros((shapeX .- 1)...), Y)
end
Expand All @@ -207,9 +207,9 @@ end
@test broadcast!(f, Z, X, Y) == sparse(broadcast!(f, fZ, fX, fY))
# --> test shape checks for both broadcast and broadcast! entry points
# TODO strengthen this test, avoiding dependence on checking whether
# check_broadcast_indices throws to determine whether sparse broadcast should throw
# check_broadcast_axes throws to determine whether sparse broadcast should throw
try
Base.Broadcast.check_broadcast_indices(axes(Z), spzeros((shapeX .- 1)...), Y)
Base.Broadcast.check_broadcast_axes(axes(Z), spzeros((shapeX .- 1)...), Y)
catch
@test_throws DimensionMismatch broadcast!(f, Z, spzeros((shapeX .- 1)...), Y)
end
Expand Down Expand Up @@ -247,9 +247,9 @@ end
@test broadcast(*, X, Y, Z) == sparse(broadcast(*, fX, fY, fZ))
@test broadcast(f, X, Y, Z) == sparse(broadcast(f, fX, fY, fZ))
# TODO strengthen this test, avoiding dependence on checking whether
# check_broadcast_indices throws to determine whether sparse broadcast should throw
# check_broadcast_axes throws to determine whether sparse broadcast should throw
try
Base.Broadcast.combine_indices(spzeros((shapeX .- 1)...), Y, Z)
Base.Broadcast.combine_axes(spzeros((shapeX .- 1)...), Y, Z)
catch
@test_throws DimensionMismatch broadcast(+, spzeros((shapeX .- 1)...), Y, Z)
end
Expand Down Expand Up @@ -279,9 +279,9 @@ end
@test broadcast!(f, Q, X, Y, Z) == sparse(broadcast!(f, fQ, fX, fY, fZ))
# --> test shape checks for both broadcast and broadcast! entry points
# TODO strengthen this test, avoiding dependence on checking whether
# check_broadcast_indices throws to determine whether sparse broadcast should throw
# check_broadcast_axes throws to determine whether sparse broadcast should throw
try
Base.Broadcast.check_broadcast_indices(axes(Q), spzeros((shapeX .- 1)...), Y, Z)
Base.Broadcast.check_broadcast_axes(axes(Q), spzeros((shapeX .- 1)...), Y, Z)
catch
@test_throws DimensionMismatch broadcast!(f, Q, spzeros((shapeX .- 1)...), Y, Z)
end
Expand Down
36 changes: 18 additions & 18 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

module TestBroadcastInternals

using Base.Broadcast: check_broadcast_indices, check_broadcast_shape, newindex, _bcs
using Base.Broadcast: check_broadcast_axes, check_broadcast_shape, newindex, _bcs
using Base: OneTo
using Test, Random

Expand All @@ -19,22 +19,22 @@ using Test, Random
@test_throws DimensionMismatch _bcs((-1:1, 2:6), (-1:1, 2:5))
@test_throws DimensionMismatch _bcs((-1:1, 2:5), (2, 2:5))

@test @inferred(Broadcast.combine_indices(zeros(3,4), zeros(3,4))) == (OneTo(3),OneTo(4))
@test @inferred(Broadcast.combine_indices(zeros(3,4), zeros(3))) == (OneTo(3),OneTo(4))
@test @inferred(Broadcast.combine_indices(zeros(3), zeros(3,4))) == (OneTo(3),OneTo(4))
@test @inferred(Broadcast.combine_indices(zeros(3), zeros(1,4), zeros(1))) == (OneTo(3),OneTo(4))

check_broadcast_indices((OneTo(3),OneTo(5)), zeros(3,5))
check_broadcast_indices((OneTo(3),OneTo(5)), zeros(3,1))
check_broadcast_indices((OneTo(3),OneTo(5)), zeros(3))
check_broadcast_indices((OneTo(3),OneTo(5)), zeros(3,5), zeros(3))
check_broadcast_indices((OneTo(3),OneTo(5)), zeros(3,5), 1)
check_broadcast_indices((OneTo(3),OneTo(5)), 5, 2)
@test_throws DimensionMismatch check_broadcast_indices((OneTo(3),OneTo(5)), zeros(2,5))
@test_throws DimensionMismatch check_broadcast_indices((OneTo(3),OneTo(5)), zeros(3,4))
@test_throws DimensionMismatch check_broadcast_indices((OneTo(3),OneTo(5)), zeros(3,4,2))
@test_throws DimensionMismatch check_broadcast_indices((OneTo(3),OneTo(5)), zeros(3,5), zeros(2))
check_broadcast_indices((-1:1, 6:9), 1)
@test @inferred(Broadcast.combine_axes(zeros(3,4), zeros(3,4))) == (OneTo(3),OneTo(4))
@test @inferred(Broadcast.combine_axes(zeros(3,4), zeros(3))) == (OneTo(3),OneTo(4))
@test @inferred(Broadcast.combine_axes(zeros(3), zeros(3,4))) == (OneTo(3),OneTo(4))
@test @inferred(Broadcast.combine_axes(zeros(3), zeros(1,4), zeros(1))) == (OneTo(3),OneTo(4))

check_broadcast_axes((OneTo(3),OneTo(5)), zeros(3,5))
check_broadcast_axes((OneTo(3),OneTo(5)), zeros(3,1))
check_broadcast_axes((OneTo(3),OneTo(5)), zeros(3))
check_broadcast_axes((OneTo(3),OneTo(5)), zeros(3,5), zeros(3))
check_broadcast_axes((OneTo(3),OneTo(5)), zeros(3,5), 1)
check_broadcast_axes((OneTo(3),OneTo(5)), 5, 2)
@test_throws DimensionMismatch check_broadcast_axes((OneTo(3),OneTo(5)), zeros(2,5))
@test_throws DimensionMismatch check_broadcast_axes((OneTo(3),OneTo(5)), zeros(3,4))
@test_throws DimensionMismatch check_broadcast_axes((OneTo(3),OneTo(5)), zeros(3,4,2))
@test_throws DimensionMismatch check_broadcast_axes((OneTo(3),OneTo(5)), zeros(3,5), zeros(2))
check_broadcast_axes((-1:1, 6:9), 1)

check_broadcast_shape((-1:1, 6:9), (-1:1, 6:9))
check_broadcast_shape((-1:1, 6:9), (-1:1, 1))
Expand Down Expand Up @@ -678,7 +678,7 @@ struct T22053
t
end
Broadcast.BroadcastStyle(::Type{T22053}) = Broadcast.Style{T22053}()
Broadcast.broadcast_indices(::T22053) = ()
Broadcast.broadcast_axes(::T22053) = ()
Broadcast.broadcastable(t::T22053) = t
function Base.copy(bc::Broadcast.Broadcasted{Broadcast.Style{T22053}})
all(x->isa(x, T22053), bc.args) && return 1
Expand Down

0 comments on commit c8bb374

Please sign in to comment.