Skip to content

Commit

Permalink
Bugfix in setaxes!
Browse files Browse the repository at this point in the history
  • Loading branch information
fverdugo committed Aug 21, 2020
1 parent c71015f commit bbb97b5
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 2 deletions.
47 changes: 45 additions & 2 deletions src/Arrays/CachedArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,14 +132,57 @@ function similar(::Type{CachedArray{T,N,A}},s::Tuple{Vararg{Int}}) where {T,N,A}
end

function setaxes!(a::CachedArray,ax)
s = map(length,ax)
if s != size(a.array)
if ! _same_axes(axes(a.array),ax)
s = map(length,ax)
if haskey(a.buffer,s)
a.array = a.buffer[s]
if ! _same_axes(axes(a.array),ax)
a.array = similar(a.array,ax)
a.buffer[s] = a.array
end
else
a.array = similar(a.array,ax)
a.buffer[s] = a.array
end
end
nothing
end

function _same_axes(a,b)
a === b || a == b
end

function _same_axes(a::NTuple{N,BlockedUnitRange},b::NTuple{N,BlockedUnitRange}) where N
if a === b
true
else
all(map(_same_axes_1d,a,b))
end
end

_same_axes_1d(a::BlockedUnitRange,b::BlockedUnitRange) = blocklasts(a) == blocklasts(b)

function _same_axes(a::NTuple{N,TwoLevelBlockedUnitRange},b::NTuple{N,TwoLevelBlockedUnitRange}) where N
if a === b
true
else
all(map(_same_axes_1d,a,b))
end
end

function _same_axes_1d(a::TwoLevelBlockedUnitRange,b::TwoLevelBlockedUnitRange)
r = _same_axes_1d(a.global_range,b.global_range)
la = length(a.local_ranges)
lb = length(b.local_ranges)
if la!=lb
return false
else
for i in 1:la
@inbounds ra = a.local_ranges[i]
@inbounds rb = b.local_ranges[i]
r = r && _same_axes_1d(ra,rb)
end
return r
end
end

27 changes: 27 additions & 0 deletions test/ArraysTests/BlockArraysCooTests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,25 @@ cc = CachedArray(c)

axs = (blockedrange([2,3,3]), blockedrange([2,3,3]))
setaxes!(cc,axs)
@test map(blocklasts,axes(cc.array)) == map(blocklasts,axs)
@test cc.array === c

axs = (blockedrange([4,5,3]), blockedrange([2,4,3]))
setaxes!(cc,axs)
@test map(blocklasts,axes(cc.array)) == map(blocklasts,axs)
fill!(cc.array,0)

axs = (blockedrange([5,4,3]), blockedrange([4,2,3]))
setaxes!(cc,axs)
@test map(blocklasts,axes(cc.array)) == map(blocklasts,axs)

axs1 = (blockedrange([5,4,3]), blockedrange([4,2,3]))
axs2 = (blockedrange([4,5,3]), blockedrange([2,4,3]))
axs3 = (blockedrange([4,5,3]), blockedrange([2,4,3]))
@test Arrays._same_axes(axs1,axs2) == false
@test Arrays._same_axes(axs2,axs2)
@test Arrays._same_axes(axs2,axs3)

blocks = [ 10*[1,2], 20*[1,2,3] ]
blockids = [(1,),(3,)]
axs = (blockedrange([2,4,3]),)
Expand All @@ -106,10 +119,12 @@ b = BlockArrayCoo(blocks,blockids,axs)
cb = CachedArray(b)

setaxes!(cb,axs)
@test map(blocklasts,axes(cb.array)) == map(blocklasts,axs)
@test cb.array === b

axs = (blockedrange([3,2,3]),)
setaxes!(cb,axs)
@test map(blocklasts,axes(cb.array)) == map(blocklasts,axs)
@test size(cb) == (8,)

c = copy(a)
Expand Down Expand Up @@ -254,4 +269,16 @@ mul!(rS,aS,bS,3,2)
@test isa(rS,BlockArrayCoo)
@test isa(rS[Block(2)],BlockArrayCoo)

axs1 = (blockedrange([5,4,3]), blockedrange([4,2,3]))
axs2 = (blockedrange([4,5,3]), blockedrange([2,4,3]))
axsA = (blockedrange([axs1[1],axs1[1]]),blockedrange([axs2[2],axs2[2]]))
axsB = (blockedrange([axs2[1],axs2[1]]),blockedrange([axs1[2],axs1[2]]))
axsC = (blockedrange([axs2[1],axs2[1]]),blockedrange([axs1[2],axs1[2]]))
@test Arrays._same_axes(axsA,axsB) == false
@test Arrays._same_axes(axsA,axsA)
@test Arrays._same_axes(axsB,axsC)

#using BenchmarkTools
#@btime Arrays._same_axes($axsA,$axsA)

end # module

0 comments on commit bbb97b5

Please sign in to comment.