Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix setindex! with SubDArray source #74

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 124 additions & 25 deletions src/DistributedArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,7 @@ end

function Base.convert{T,N}(::Type{DArray}, SD::SubArray{T,N})
D = SD.parent
DArray(SD.dims, procs(D)) do I
DArray(size(SD), procs(D)) do I
TR = typeof(SD.indexes[1])
lindices = Array(TR, 0)
for (i,r) in zip(I, SD.indexes)
Expand Down Expand Up @@ -635,34 +635,133 @@ function Base.setindex!(a::Array, d::DArray,
return a
end

function Base.setindex!(a::Array, s::SubDArray,
I::Union{UnitRange{Int},Colon,Vector{Int},StepRange{Int,Int}}...)
n = length(I)
d = s.parent
J = s.indexes
if length(J) < n
a[I...] = convert(Array,s)
return a
# We also want to optimize setindex! with a SubDArray source, but this is hard
# and only works on 0.5.
if VERSION > v"0.5.0-dev+5230"
# Similar to Base.indexin, but just create a logical mask. Note that this
# must return a logical mask in order to support merging multiple masks
# together into one linear index since we need to know how many elements to
# skip at the end. In many cases range intersection would be much faster
# than generating a logical mask, but that loses the endpoint information.
# indexin_mask(a, b): for every element of `a`, report whether it belongs to `b`.

# Single-number target: elementwise equality against `b`.
function indexin_mask(a, b::Number)
    return a .== b
end

# Integer-range target: one membership test per element of `a`.
function indexin_mask(a, r::Range{Int})
    return [i in r for i in a]
end

# Integer-array target: build an IntSet from `b` and dispatch again on the set.
function indexin_mask(a, b::AbstractArray{Int})
    lookup = IntSet(b)
    return indexin_mask(a, lookup)
end

# Any other array target: build a Set from `b` and dispatch again on the set.
function indexin_mask(a, b::AbstractArray)
    lookup = Set(b)
    return indexin_mask(a, lookup)
end

# Fallback: `b` only has to support `in`.
function indexin_mask(a, b)
    return [i in b for i in a]
end

import Base: tail
# Given a tuple of indices and a tuple of masks, restrict the indices to the
# valid regions. This is, effectively, reversing Base.setindex_shape_check.
# We can't just use indexing into MergedIndices here because getindex is much
# pickier about singleton dimensions than setindex! is.
# Base case: no destination indices and no masks left to pair up.
restrict_indices(::Tuple{}, ::Tuple{}) = ()

# Walk the destination indices `a` and the logical masks `b` in lockstep and
# return a flat tuple of indices restricted to the masked-in positions.
#
# Singleton entries on either side are allowed to be unmatched on the other
# side, which is why the two tuples may be consumed at different rates.
# Throws `DimensionMismatch` when an index and a mask cannot be lined up.
function restrict_indices(a::Tuple{Any, Vararg{Any}}, b::Tuple{Any, Vararg{Any}})
    if (length(a[1]) == length(b[1]) == 1) || (length(a[1]) > 1 && length(b[1]) > 1)
        # Index and mask correspond: keep only the positions the mask selects.
        return (vec(a[1])[vec(b[1])], restrict_indices(Base.tail(a), Base.tail(b))...)
    elseif length(a[1]) == 1
        # Unmatched singleton index: pass it through untouched. The recursive
        # result must be splatted, otherwise the caller receives a nested tuple
        # `(a[1], (rest...,))` that cannot be used as `A[idxs...]`.
        return (a[1], restrict_indices(Base.tail(a), b)...)
    elseif length(b[1]) == 1 && b[1][1]
        # Unmatched singleton mask that is `true`: drop it and continue.
        return restrict_indices(a, Base.tail(b))
    else
        throw(DimensionMismatch("this should be caught by setindex_shape_check; please submit an issue"))
    end
end
offs = [isa(J[i],Int) ? J[i]-1 : first(J[i])-1 for i=1:n]
@sync for i = 1:length(d.pids)
K_c = Any[d.indexes[i]...]
K = [ intersect(J[j],K_c[j]) for j=1:n ]
if !any(isempty, K)
idxs = [ I[j][K[j]-offs[j]] for j=1:n ]
if isequal(K, K_c)
# whole chunk
@async a[idxs...] = chunk(d, i)
else
# partial chunk
@async a[idxs...] =
remotecall_fetch(d.pids[i]) do
view(localpart(d), [K[j]-first(K_c[j])+1 for j=1:n]...)
end
# The final indices are funky - they're allowed to accumulate together.
# An easy (albeit very inefficient) fix for too many masks is to use the
# outer product to merge them. But we can do that lazily with a custom type:
# One remaining index but several remaining masks: combine the masks into a
# single lazy mask and apply it to the flattened index.
function restrict_indices(a::Tuple{Any}, b::Tuple{Any, Any, Vararg{Any}})
    combined_mask = ProductIndices(b, map(length, b))
    selected = vec(a[1])[vec(combined_mask)]
    return (selected,)
end
# But too many indices is much harder; this requires merging the indices
# in `a` before applying the final mask in `b`.
# Several remaining indices but only one remaining mask.
function restrict_indices(a::Tuple{Any, Any, Vararg{Any}}, b::Tuple{Any})
    if length(a[1]) == 1
        # Leading singleton index: pass it through and keep going. The recursive
        # result is splatted so the returned tuple of indices stays flat; without
        # the splat the caller would receive `(a[1], (rest...,))`.
        return (a[1], restrict_indices(Base.tail(a), b)...)
    else
        # When one mask spans multiple indices, we need to merge the indices
        # together. At this point, we can just use indexing to merge them since
        # there's no longer special handling of singleton dimensions
        return (view(MergedIndices(a, map(length, a)), b[1]),)
    end
end

# Lazy N-dimensional Bool array built from a collection of per-dimension masks.
# The element at (i1, ..., iN) is the AND of `indices[k][ik]` over all k (see
# the getindex method for this type); nothing is materialized up front.
immutable ProductIndices{I,N} <: AbstractArray{Bool, N}
# The wrapped masks, one per dimension.
indices::I
# The reported array size; restrict_indices passes `map(length, indices)`.
sz::NTuple{N,Int}
end
# The size is whatever was stored at construction time.
Base.size(P::ProductIndices) = P.sz
# Plain indexing wrapped in a named function so that it can be handed to `map`
# without breaking the propagation of inbounds.
Base.@propagate_inbounds function propagate_getindex(A, I...)
    return A[I...]
end
# Look up each coordinate in its own mask and AND the results together.
Base.@propagate_inbounds function Base.getindex{_,N}(P::ProductIndices{_,N}, I::Vararg{Int, N})
    flags = map(propagate_getindex, P.indices, I)
    return Bool((&)(flags...))
end

# Lazy N-dimensional array of CartesianIndex{N} built from a collection of
# per-dimension indices. The element at (i1, ..., iN) is the CartesianIndex
# made of `indices[k][ik]` for each k (see the getindex method for this type).
immutable MergedIndices{I,N} <: AbstractArray{CartesianIndex{N}, N}
# The wrapped indices, one per dimension.
indices::I
# The reported array size; restrict_indices passes `map(length, indices)`.
sz::NTuple{N,Int}
end
# The size is whatever was stored at construction time.
Base.size(M::MergedIndices) = M.sz
# Look up each coordinate in its own index and bundle them into one CartesianIndex.
Base.@propagate_inbounds function Base.getindex{_,N}(M::MergedIndices{_,N}, I::Vararg{Int, N})
    coords = map(propagate_getindex, M.indices, I)
    return CartesianIndex(coords)
end
# Additionally, we optimize bounds checking when using MergedIndices as an
# array index since checking, e.g., A[1:500, 1:500] is *way* faster than
# checking an array of 500^2 elements of CartesianIndex{2}. This optimization
# also applies to reshapes of MergedIndices since the outer shape of the
# container doesn't affect the index elements themselves. We can go even
# farther and say that even restricted views of MergedIndices must be valid
# over the entire array. This is overly strict in general, but in this
# use-case all the merged indices must be valid at some point, so it's ok.
# A reshaped array whose parent is a MergedIndices.
typealias ReshapedMergedIndices{T,N,M<:MergedIndices} Base.ReshapedArray{T,N,M}
# A SubArray whose parent is a MergedIndices or a reshape of one.
typealias SubMergedIndices{T,N,M<:Union{MergedIndices, ReshapedMergedIndices}} SubArray{T,N,M}
# Any of the three forms above; this is the type the checkbounds_indices
# methods below dispatch on.
typealias MergedIndicesOrSub Union{MergedIndices, ReshapedMergedIndices, SubMergedIndices}
import Base: checkbounds_indices
# Bounds-check a (possibly reshaped or sub-viewed) MergedIndices by checking the
# indices it wraps instead of every CartesianIndex element. `parent` is applied
# twice to reach the underlying MergedIndices through up to two wrapper layers.
# NOTE(review): the three methods differ only in the type of `inds`; presumably
# the extra signatures exist to avoid dispatch ambiguities with Base -- confirm.
@inline function checkbounds_indices(::Type{Bool}, inds::Tuple{}, I::Tuple{MergedIndicesOrSub,Vararg{Any}})
    merged = parent(parent(I[1]))
    return checkbounds_indices(Bool, inds, (merged.indices..., tail(I)...))
end
@inline function checkbounds_indices(::Type{Bool}, inds::Tuple{Any}, I::Tuple{MergedIndicesOrSub,Vararg{Any}})
    merged = parent(parent(I[1]))
    return checkbounds_indices(Bool, inds, (merged.indices..., tail(I)...))
end
@inline function checkbounds_indices(::Type{Bool}, inds::Tuple, I::Tuple{MergedIndicesOrSub,Vararg{Any}})
    merged = parent(parent(I[1]))
    return checkbounds_indices(Bool, inds, (merged.indices..., tail(I)...))
end

# The tricky thing here is that we want to optimize the accesses into the
# distributed array, but in doing so, we lose track of which indices in I we
# should be using.
#
# This function is genuinely intricate, because several index spaces interact.
# There are *6* flavors of indices with four different reference points:
# 1. Find the indices of each portion of the DArray.
# 2. Find the valid subset of indices for the SubArray into that portion.
# 3. Find the portion of the `I` indices that should be used when you access the
# `K` indices in the subarray. This guy is nasty. It’s totally backwards
# from all other arrays, wherein we simply iterate over the source array’s
# elements. You need to *both* know which elements in `J` were skipped
# (`indexin_mask`) and which dimensions should match up (`restrict_indices`)
# 4. If `K` doesn’t correspond to an entire chunk, reinterpret `K` in terms of
# the local portion of the source array
# Copy the contents of the distributed view `s` into `a[I...]`, fetching each
# chunk of the parent DArray separately, and return `a`.
#
# For every chunk of the parent array:
#   * `K_c`    is the index tuple that chunk covers in the parent;
#   * `K`      is the part of the view's indices `J` that falls inside the chunk;
#   * `K_mask` marks, per dimension, which elements of `J` fall inside the chunk;
#   * `idxs`   are the destination indices of `a` selected by those masks.
# Chunks that do not intersect the view are skipped. The shape of `s` is
# checked against the shape selected by `I` before any data is moved.
function Base.setindex!(a::Array, s::SubDArray,
                        I::Union{UnitRange{Int},Colon,Vector{Int},StepRange{Int,Int}}...)
    Base.setindex_shape_check(s, Base.index_lengths(a, I...)...)
    d = s.parent
    J = Base.decolon(d, s.indexes...)
    # The destination indices do not depend on the chunk, so expand colons once.
    dest = Base.decolon(a, I...)
    @sync for i = 1:length(d.pids)
        K_c = d.indexes[i]
        K = map(intersect, J, K_c)
        if !any(isempty, K)
            K_mask = map(indexin_mask, J, K_c)
            idxs = restrict_indices(dest, K_mask)
            if isequal(K, K_c)
                # whole chunk
                @async a[idxs...] = chunk(d, i)
            else
                # partial chunk: re-express `K` relative to the start of the
                # chunk and take a view of the local part on the owning process
                @async a[idxs...] =
                    remotecall_fetch(d.pids[i]) do
                        view(localpart(d), [K[j]-first(K_c[j])+1 for j=1:length(J)]...)
                    end
            end
        end
    end
    return a
end
return a
end

Base.fill!(A::DArray, x) = begin
Expand Down
16 changes: 16 additions & 0 deletions test/darray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,22 @@ facts("test DArray / Array conversion") do
@fact fetch(@spawnat MYID localpart(D)[1,1]) --> D[1,1]
@fact fetch(@spawnat OTHERIDS localpart(D)[1,1]) --> D[1,101]
close(D2)

S2 = convert(Vector{Float64}, D[4, 23:176])
@fact A[4, 23:176] --> S2

S3 = convert(Vector{Float64}, D[23:176, 197])
@fact A[23:176, 197] --> S3

S4 = zeros(4)
setindex!(S4, D[3:4, 99:100], :)
@fact S4 --> vec(D[3:4, 99:100])
@fact S4 --> vec(A[3:4, 99:100])

S5 = zeros(2,2)
setindex!(S5, D[1,1:4], :, 1:2)
@fact vec(S5) --> D[1, 1:4]
@fact vec(S5) --> A[1, 1:4]
end
close(D)
end
Expand Down