Skip to content

Commit

Permalink
Make collect(::CategoricalArray) return a CategoricalArray (#252)
Browse files Browse the repository at this point in the history
This fixes an inconsistency, as currently `collect(::Type{<:CategoricalValue}, ::AbstractArray)` returns
a `CategoricalArray` but not `collect(::CategoricalArray)`. This is due to the fact that we define
`similar` methods for `Array` but not for ranges, which is what `collect` uses internally.
  • Loading branch information
nalimilan authored Apr 8, 2020
1 parent dcc24cf commit 8044adf
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 9 deletions.
5 changes: 3 additions & 2 deletions src/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,8 @@ similar(A::CategoricalArray{S, M, Q}, ::Type{Union{CategoricalValue{T}, Missing}
dims::NTuple{N, Int}) where {T, N, S, M, Q} =
CategoricalArray{Union{T, Missing}, N, Q}(undef, dims)

for A in (:Array, :Vector, :Matrix) # to fix ambiguities
# AbstractRange methods are needed since collect uses 1:1 as dummy array
for A in (:Array, :Vector, :Matrix, :AbstractRange)
@eval begin
similar(A::$A, ::Type{CategoricalValue{T, R}},
dims::NTuple{N, Int}=size(A)) where {T, R, N} =
Expand Down Expand Up @@ -780,7 +781,7 @@ function in(x::CategoricalValue, y::CategoricalArray{T, N, R}) where {T, N, R}
end

Array(A::CategoricalArray{T}) where {T} = Array{T}(A)
collect(A::CategoricalArray{T}) where {T} = Array{T}(A)
collect(A::CategoricalArray) = copy(A)

function float(A::CategoricalArray{T}) where T
if !isconcretetype(T)
Expand Down
24 changes: 17 additions & 7 deletions test/13_arraycommon.jl
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,14 @@ end
y = similar(Vector{T}, (3,))
@test isa(y, Vector{T})
@test size(y) == (3,)

y = similar(1:1, Union{CategoricalValue{String, UInt8}, T})
@test isa(y, CategoricalVector{Union{String, T}, UInt8})
@test size(y) == (1,)

y = similar(1:1, Union{CategoricalValue{String, UInt8}, T}, (3,))
@test isa(y, CategoricalVector{Union{String, T}, UInt8})
@test size(y) == (3,)
end

@testset "copy" begin
Expand Down Expand Up @@ -1145,18 +1153,20 @@ end
end
end

@testset "collect of CategoricalArray produces Array" begin
@testset "collect of CategoricalArray produces CategoricalArray" begin
x = [1,1,2,2]
y = categorical(x)
z = collect(y)
@test typeof(x) == typeof(z)
@test z == x
for z in (collect(y), collect(eltype(y), y), collect(Iterators.take(y, 4)))
@test typeof(y) == typeof(z)
@test z == y == x
end

x = [1,1,2,missing]
y = categorical(x)
z = collect(y)
@test typeof(x) == typeof(z)
@test z x
for z in (collect(y), collect(eltype(y), y), collect(Iterators.take(y, 4)))
@test typeof(y) == typeof(z)
@test z y x
end
end

@testset "Array(::CategoricalArray{T}) produces Array{T}" begin
Expand Down

0 comments on commit 8044adf

Please sign in to comment.