Skip to content

Commit

Permalink
add DataAPI.unwrap (#328)
Browse files Browse the repository at this point in the history
  • Loading branch information
bkamins authored Feb 24, 2021
1 parent 876d0db commit 7df18c8
Show file tree
Hide file tree
Showing 10 changed files with 90 additions and 68 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "CategoricalArrays"
uuid = "324d7699-5711-5eae-9e2f-1d82baa6b597"
version = "0.9.2"
version = "0.9.3"

[deps]
DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
Expand All @@ -13,7 +13,7 @@ StructTypes = "856f2bd8-1eba-4b0a-8007-ebc267875bd4"
Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

[compat]
DataAPI = "1.5"
DataAPI = "1.6"
JSON = "0.15, 0.16, 0.17, 0.18, 0.19, 0.20, 0.21"
JSON3 = "1.1.2"
Missings = "0.4.3"
Expand Down
77 changes: 40 additions & 37 deletions docs/src/using.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ julia> using CategoricalArrays
julia> x = CategoricalArray(["Old", "Young", "Middle", "Young"], ordered=true)
4-element CategoricalArray{String,1,UInt32}:
"Old"
"Young"
"Old"
"Young"
"Middle"
"Young"
"Young"
```

Expand All @@ -22,15 +22,15 @@ By default, the levels are lexically sorted, which is clearly not correct in our
julia> levels(x)
3-element Array{String,1}:
"Middle"
"Old"
"Young"
"Old"
"Young"
julia> levels!(x, ["Young", "Middle", "Old"])
4-element CategoricalArray{String,1,UInt32}:
"Old"
"Young"
"Old"
"Young"
"Middle"
"Young"
"Young"
```

Expand Down Expand Up @@ -69,20 +69,20 @@ To get rid of the `"Old"` group, just call the [`droplevels!`](@ref) function:
```jldoctest using
julia> levels(x)
3-element Array{String,1}:
"Young"
"Young"
"Middle"
"Old"
"Old"
julia> droplevels!(x)
4-element CategoricalArray{String,1,UInt32}:
"Young"
"Young"
"Young"
"Young"
"Middle"
"Young"
"Young"
julia> levels(x)
2-element Array{String,1}:
"Young"
"Young"
"Middle"
```
Expand Down Expand Up @@ -115,10 +115,10 @@ Let's adapt the example developed above to support missing values. Since there a
```jldoctest using
julia> y = CategoricalArray{Union{Missing, String}}(["Old", "Young", "Middle", "Young"], ordered=true)
4-element CategoricalArray{Union{Missing, String},1,UInt32}:
"Old"
"Young"
"Old"
"Young"
"Middle"
"Young"
"Young"
```

Expand All @@ -128,15 +128,15 @@ Levels still need to be reordered manually:
julia> levels(y)
3-element Array{String,1}:
"Middle"
"Old"
"Young"
"Old"
"Young"
julia> levels!(y, ["Young", "Middle", "Old"])
4-element CategoricalArray{Union{Missing, String},1,UInt32}:
"Old"
"Young"
"Old"
"Young"
"Middle"
"Young"
"Young"
```

Expand All @@ -156,9 +156,9 @@ missing
julia> y
4-element CategoricalArray{Union{Missing, String},1,UInt32}:
missing
"Young"
"Young"
"Middle"
"Young"
"Young"
julia> y[1]
missing
Expand All @@ -173,17 +173,17 @@ julia> y[1] = "Old"
julia> y
4-element CategoricalArray{Union{Missing, String},1,UInt32}:
"Old"
"Young"
"Old"
"Young"
"Middle"
"Young"
"Young"
julia> levels!(y, ["Young", "Middle"]; allowmissing=true)
4-element CategoricalArray{Union{Missing, String},1,UInt32}:
missing
"Young"
"Young"
"Middle"
"Young"
"Young"
```

Expand All @@ -205,17 +205,17 @@ If we concatenate the two sets, the levels of the resulting categorical vector a
julia> xy = vcat(x, y)
6-element CategoricalArray{String,1,UInt32}:
"Middle"
"Old"
"Old"
"Middle"
"Young"
"Young"
"Middle"
"Middle"
julia> levels(xy)
3-element Array{String,1}:
"Young"
"Young"
"Middle"
"Old"
"Old"
julia> isordered(xy)
true
Expand All @@ -233,7 +233,7 @@ julia> ordered!(x, false);
julia> levels(x)
2-element Array{String,1}:
"Middle"
"Old"
"Old"
julia> x[1] = y[1]
CategoricalValue{String,UInt32} "Young" (1/2)
Expand All @@ -243,9 +243,9 @@ CategoricalValue{String,UInt32} "Young"
julia> levels(x)
3-element Array{String,1}:
"Young"
"Young"
"Middle"
"Old"
"Old"
```

In cases where levels with incompatible orderings are combined, the ordering of the first array wins and the resulting array is marked as unordered:
Expand Down Expand Up @@ -317,4 +317,7 @@ Do note that in some cases the two sets of levels may have compatible orderings,

`recode!(a[, default], pairs...)` - Replace one or more values in `a` in-place

See [API Index](@ref) for more details.
`unwrap(x)` - Return a value contained in categorical value `x`; if `x` is `Missing`
return `missing`;

See [API Index](@ref) for more details.
5 changes: 4 additions & 1 deletion src/CategoricalArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@ module CategoricalArrays
isordered, ordered!
export cut, recode, recode!

import DataAPI: unwrap
export unwrap

using JSON
using DataAPI
using Missings
using Printf
import StructTypes

# JuliaLang/julia#36810
# JuliaLang/julia#36810
if VERSION < v"1.5.2"
Base.OrderStyle(::Type{Union{}}) = Base.Ordered()
end
Expand Down
3 changes: 2 additions & 1 deletion src/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,8 @@ function merge_pools!(A::CatArrOrSub,
updaterefs::Bool=true,
updatepool::Bool=true)
if isordered(A) && length(pool(A)) > 0 && pool(B) pool(A)
lev = A isa CategoricalValue ? get(B) : first(setdiff(levels(B), levels(A)))
# TODO: extend OrderedLevelsException to take all values in setdiff(levels(B), levels(A))
lev = first(setdiff(levels(B), levels(A)))
throw(OrderedLevelsException(lev, levels(A)))
end
newlevels, ordered = merge_pools(pool(A), pool(B))
Expand Down
4 changes: 4 additions & 0 deletions src/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,7 @@ function categorical(A::AbstractArray, compress::Bool; kwargs...)
throw(ErrorException("categorical(A::AbstractArray, compress, kwargs...) is deprecated: " *
"use categorical(A, compress=compress, kwargs...) instead."))
end

import Base: get

@deprecate get(x::CategoricalValue) DataAPI.unwrap(x)
2 changes: 1 addition & 1 deletion src/pool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ end
end
levels!(pool, newlevs)
end
get!(pool, get(level))
get!(pool, unwrap(level))
end

@inline function Base.push!(pool::CategoricalPool, level)
Expand Down
48 changes: 27 additions & 21 deletions src/value.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,13 @@ unwrap_catvaluetype(::Type{Union{}}) = Union{} # prevent incorrect dispatch to T
unwrap_catvaluetype(::Type{Any}) = Any # prevent recursion in T>:Missing method
unwrap_catvaluetype(::Type{T}) where {T <: CategoricalValue} = leveltype(T)

Base.get(x::CategoricalValue) = levels(x)[level(x)]
"""
unwrap(x::CategoricalValue)
unwrap(x::Missing)
Get the value wrapped by categorical value `x`. If `x` is `Missing` return `missing`.
"""
DataAPI.unwrap(x::CategoricalValue) = levels(x)[level(x)]

"""
levelcode(x::CategoricalValue)
Expand Down Expand Up @@ -71,59 +77,59 @@ Base.promote_rule(::Type{C1}, ::Type{C2}) where {C1<:CategoricalValue, C2<:Categ

# General fallbacks
Base.convert(::Type{S}, x::CategoricalValue) where {S <: SupportedTypes} =
convert(S, get(x))
convert(S, unwrap(x))
Base.convert(::Type{Union{S, Missing}}, x::CategoricalValue) where {S <: SupportedTypes} =
convert(Union{S, Missing}, get(x))
convert(Union{S, Missing}, unwrap(x))
Base.convert(::Type{Union{S, Nothing}}, x::CategoricalValue) where {S <: SupportedTypes} =
convert(Union{S, Nothing}, get(x))
convert(Union{S, Nothing}, unwrap(x))

(::Type{T})(x::T) where {T <: CategoricalValue} = x

Base.Broadcast.broadcastable(x::CategoricalValue) = Ref(x)

function Base.show(io::IO, x::CategoricalValue)
if nonmissingtype(get(io, :typeinfo, Any)) === nonmissingtype(typeof(x))
show(io, get(x))
show(io, unwrap(x))
else
print(io, typeof(x))
print(io, ' ')
show(io, get(x))
show(io, unwrap(x))
if isordered(pool(x))
@printf(io, " (%i/%i)", levelcode(x), length(pool(x)))
end
end
end

Base.print(io::IO, x::CategoricalValue) = print(io, get(x))
Base.string(x::CategoricalValue) = string(get(x))
Base.write(io::IO, x::CategoricalValue) = write(io, get(x))
Base.String(x::CategoricalValue{<:AbstractString}) = String(get(x))
Base.print(io::IO, x::CategoricalValue) = print(io, unwrap(x))
Base.string(x::CategoricalValue) = string(unwrap(x))
Base.write(io::IO, x::CategoricalValue) = write(io, unwrap(x))
Base.String(x::CategoricalValue{<:AbstractString}) = String(unwrap(x))

@inline function Base.:(==)(x::CategoricalValue, y::CategoricalValue)
if pool(x) === pool(y)
return level(x) == level(y)
else
return get(x) == get(y)
return unwrap(x) == unwrap(y)
end
end

Base.:(==)(x::CategoricalValue, y::SupportedTypes) = get(x) == y
Base.:(==)(x::SupportedTypes, y::CategoricalValue) = x == get(y)
Base.:(==)(x::CategoricalValue, y::SupportedTypes) = unwrap(x) == y
Base.:(==)(x::SupportedTypes, y::CategoricalValue) = x == unwrap(y)

@inline function Base.isequal(x::CategoricalValue, y::CategoricalValue)
if pool(x) === pool(y)
return level(x) == level(y)
else
return isequal(get(x), get(y))
return isequal(unwrap(x), unwrap(y))
end
end

Base.isequal(x::CategoricalValue, y::SupportedTypes) = isequal(get(x), y)
Base.isequal(x::SupportedTypes, y::CategoricalValue) = isequal(x, get(y))
Base.isequal(x::CategoricalValue, y::SupportedTypes) = isequal(unwrap(x), y)
Base.isequal(x::SupportedTypes, y::CategoricalValue) = isequal(x, unwrap(y))

Base.in(x::CategoricalValue, y::AbstractRange{T}) where {T<:Integer} = get(x) in y
Base.in(x::CategoricalValue, y::AbstractRange{T}) where {T<:Integer} = unwrap(x) in y

Base.hash(x::CategoricalValue, h::UInt) = hash(get(x), h)
Base.hash(x::CategoricalValue, h::UInt) = hash(unwrap(x), h)

# Method defined even on unordered values so that sort() works
function Base.isless(x::CategoricalValue, y::CategoricalValue)
Expand Down Expand Up @@ -164,14 +170,14 @@ function Base.:<(y::SupportedTypes, x::CategoricalValue)
end

# JSON of CategoricalValue is JSON of the value it refers to
JSON.lower(x::CategoricalValue) = JSON.lower(get(x))
JSON.lower(x::CategoricalValue) = JSON.lower(unwrap(x))
DataAPI.defaultarray(::Type{CategoricalValue{T, R}}, N) where {T, R} =
CategoricalArray{T, N, R}
DataAPI.defaultarray(::Type{Union{CategoricalValue{T, R}, Missing}}, N) where {T, R} =
CategoricalArray{Union{T, Missing}, N, R}

# define appropriate handlers for JSON3 interface
StructTypes.StructType(x::CategoricalValue) = StructTypes.StructType(get(x))
StructTypes.StructType(x::CategoricalValue) = StructTypes.StructType(unwrap(x))
StructTypes.StructType(::Type{<:CategoricalValue{T}}) where {T} = StructTypes.StructType(T)
StructTypes.numbertype(::Type{<:CategoricalValue{T}}) where {T <: Number} = T
StructTypes.construct(::Type{T}, x::CategoricalValue{T}) where {T} = T(get(x))
StructTypes.construct(::Type{T}, x::CategoricalValue{T}) where {T} = T(unwrap(x))
6 changes: 3 additions & 3 deletions test/05_convert.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ using CategoricalArrays: DefaultRefType, level, reftype, leveltype
@test convert(Union{T, U}, v3) === v3
end

@test get(v1) === 1
@test get(v2) === 2
@test get(v3) === 3
@test unwrap(v1) === get(v1) === 1
@test unwrap(v2) === get(v2) === 2
@test unwrap(v3) === get(v3) === 3

@test promote(1, v1) === (1, 1)
@test promote(1.0, v1) === (1.0, 1.0)
Expand Down
4 changes: 2 additions & 2 deletions test/06_show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ using JSON
CategoricalPool([1]),
CategoricalPool([1.0]))
v = CategoricalValue(1, pool)
@test JSON.lower(v) == JSON.lower(get(v))
@test typeof(JSON.lower(v)) == typeof(JSON.lower(get(v)))
@test JSON.lower(v) == JSON.lower(unwrap(v))
@test typeof(JSON.lower(v)) == typeof(JSON.lower(unwrap(v)))
end

using JSON3
Expand Down
5 changes: 5 additions & 0 deletions test/13_arraycommon.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2182,4 +2182,9 @@ end
end
end

@testset "unwrap" begin
x = categorical(["a", missing, "b", missing])
@test unwrap.(x) ["a", missing, "b", missing]
end

end

2 comments on commit 7df18c8

@nalimilan
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/30738

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.9.3 -m "<description of version>" 7df18c8c0d652a347f7dacf4ca3bdd4cd26efe34
git push origin v0.9.3

Please sign in to comment.