Skip to content

Commit

Permalink
properly qualify stack
Browse files Browse the repository at this point in the history
  • Loading branch information
aplavin committed Oct 18, 2022
1 parent 200dfc5 commit fbb344d
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 17 deletions.
2 changes: 1 addition & 1 deletion src/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ end
@static if VERSION > v"1.9-DEV"
# TODO: only supports dims=: for now
function Base.stack(A::KeyedArray; dims::Colon=:)
data = @invoke Base.stack(A::AbstractArray; dims)
data = @invoke stack(A::AbstractArray; dims)
if !allequal(named_axiskeys(a) for a in A)
throw(DimensionMismatch("stack expects uniform axiskeys for all arrays"))
end
Expand Down
8 changes: 4 additions & 4 deletions src/stack.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

using LazyStack
import LazyStack

# for stack_iter
LazyStack.no_wraps(a::KeyedArray) = LazyStack.no_wraps(NamedDims.unname(parent(a)))
Expand All @@ -23,14 +23,14 @@ stack_keys(xs::Tuple{Vararg{<:KeyedArray}}) =

# array of arrays: first strip off outer containers...
function LazyStack.stack(xs::KeyedArray{<:AbstractArray})
KeyedArray(stack(parent(xs)), stack_keys(xs))
KeyedArray(LazyStack.stack(parent(xs)), stack_keys(xs))
end
function LazyStack.stack(xs::KeyedArray{<:AbstractArray,N,<:NamedDimsArray{L}}) where {L,N}
data = stack(parent(parent(xs)))
data = LazyStack.stack(parent(parent(xs)))
KeyedArray(LazyStack.ensure_named(data, LazyStack.getnames(xs)), stack_keys(xs))
end
function LazyStack.stack(xs::NamedDimsArray{L,<:AbstractArray,N,<:KeyedArray}) where {L,N}
data = stack(parent(parent(xs)))
data = LazyStack.stack(parent(parent(xs)))
LazyStack.ensure_named(KeyedArray(data, stack_keys(xs)), LazyStack.getnames(xs))
end

Expand Down
2 changes: 1 addition & 1 deletion test/_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ A3 = wrapdims(rand(Int8, 3,4,2), r='a':'c', c=2:5, p=[10.0, 20.0])
KeyedArray([5, 6], a=[:x, :y]),
], b=10:12)

sk = Base.stack(arr)::KeyedArray
sk = stack(arr)::KeyedArray
@test sk == [1 3 5; 2 4 6]
@test named_axiskeys(sk) == (a=[:x, :y], b=10:12)
end
Expand Down
20 changes: 10 additions & 10 deletions test/_packages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,25 +129,25 @@ end
end
end
@testset "stack" begin
using LazyStack
using LazyStack: stack as lstack

rin = [wrapdims(1:3, a='a':'c') for i=1:4]

@test axiskeys(stack(rin), :a) == 'a':'c'
@test axiskeys(stack(:b, rin...), :a) == 'a':'c' # tuple
@test axiskeys(stack(z for z in rin), :a) == 'a':'c' # generator
@test axiskeys(lstack(rin), :a) == 'a':'c'
@test axiskeys(lstack(:b, rin...), :a) == 'a':'c' # tuple
@test axiskeys(lstack(z for z in rin), :a) == 'a':'c' # generator

rout = wrapdims([[1,2], [3,4]], b=10:11)
@test axiskeys(stack(rout), :b) == 10:11
@test axiskeys(lstack(rout), :b) == 10:11

rboth = wrapdims(rin, b=10:13)
@test axiskeys(stack(rboth), :a) == 'a':'c'
@test axiskeys(stack(rboth), :b) == 10:13
@test axiskeys(lstack(rboth), :a) == 'a':'c'
@test axiskeys(lstack(rboth), :b) == 10:13

nts = [(i=i, j="j", k=33) for i=1:3]
@test axiskeys(stack(nts), 1) == [:i, :j, :k]
@test axiskeys(stack(:z, nts...), 1) == [:i, :j, :k]
@test axiskeys(stack(n for n in nts), 1) == [:i, :j, :k]
@test axiskeys(lstack(nts), 1) == [:i, :j, :k]
@test axiskeys(lstack(:z, nts...), 1) == [:i, :j, :k]
@test axiskeys(lstack(n for n in nts), 1) == [:i, :j, :k]

end
@testset "dates" begin
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Test, AxisKeys, NamedDims
using Statistics, OffsetArrays, Tables, UniqueVectors, LazyStack
using Statistics, OffsetArrays, Tables, UniqueVectors
using ChainRulesCore: ProjectTo, NoTangent
using ChainRulesTestUtils: test_rrule
using FiniteDifferences
Expand Down

0 comments on commit fbb344d

Please sign in to comment.