# Hash-index accelerated searches for the `==` predicate (methods from src/HashIndex.jl).

# Count the elements of a hash-indexed `AcceleratedArray` that are `==` to `f.x`.
#
# The index dictionary is probed directly rather than scanning `a`:
#   * every value produced by `other_equal(f.x)` is probed, because such values may be
#     `==` to `f.x` without being `isequal` to it (for example `-0.0`);
#   * `f.x` itself is probed only when `f(f.x)` holds, which excludes values that are
#     not `==` to themselves (for example `NaN`).
# Returns the summed length of the values stored under every key that was found.
function Base.count(f::Fix2{typeof(==)}, a::AcceleratedArray{<:Any, <:Any, <:Any, <:HashIndex})
    dict = a.index.dict
    total = 0

    # Keys that might be == but not isequal to the needle (like -0.0)
    for candidate in other_equal(f.x)
        found, token = gettoken(dict, candidate)
        found || continue
        total += length(@inbounds gettokenvalue(dict, token))
    end

    # The needle itself, unless it is not == to itself (like NaN)
    if f(f.x)
        found, token = gettoken(dict, f.x)
        if found
            total += length(@inbounds gettokenvalue(dict, token))
        end
    end

    return total
end

# Collect the keys of a hash-indexed `AcceleratedArray` whose elements are `==` to `f.x`.
#
# Probes the same keys as the `count` method for `==`: the values from
# `other_equal(f.x)` first, then `f.x` itself when `f(f.x)` holds. The values stored
# under each key found are appended to a `Vector{keytype(a)}`. The result is sorted
# only when more than one key contributed to it.
function Base.findall(f::Fix2{typeof(==)}, a::AcceleratedArray{<:Any, <:Any, <:Any, <:HashIndex})
    dict = a.index.dict
    matches = Vector{keytype(a)}()
    nfound = 0  # number of dictionary keys that contributed to `matches`

    # Keys that might be == but not isequal to the needle (like -0.0)
    for candidate in other_equal(f.x)
        found, token = gettoken(dict, candidate)
        found || continue
        append!(matches, @inbounds gettokenvalue(dict, token))
        nfound += 1
    end

    # The needle itself, unless it is not == to itself (like NaN)
    if f(f.x)
        found, token = gettoken(dict, f.x)
        if found
            append!(matches, @inbounds gettokenvalue(dict, token))
            nfound += 1
        end
    end

    # Results merged from several keys are re-sorted; a single group is returned as stored
    nfound > 1 && sort!(matches)

    return matches
end

# TODO: findall for arbitrary predicates by just checking each unique key? (Sometimes faster, sometimes slower?)
# Predicate helpers (definitions from src/predicates.jl).

# Search for other things that are == but not isequal (for example -0.0).
# The generic fallback yields nothing: an empty tuple.
other_equal(::Any) = ()

# For floating-point values only the two signed zeros qualify: each zero yields the
# oppositely signed zero, converted to the type of `x`; every other value yields an
# empty `MaybeVector` of that type.
function other_equal(x::AbstractFloat)
    T = typeof(x)
    if isequal(x, 0.0)
        return MaybeVector(convert(T, -0.0))
    elseif isequal(x, -0.0)
        return MaybeVector(convert(T, 0.0))
    end
    return MaybeVector{T}()
end

# ismissing
# Each method forwards to the corresponding `isequal(missing)` method for `AcceleratedArray`.
function Base.count(f::typeof(ismissing), a::AcceleratedArray)
    return count(isequal(missing), a)
end

function Base.findall(f::typeof(ismissing), a::AcceleratedArray)
    return findall(isequal(missing), a)
end

function Base.findfirst(f::typeof(ismissing), a::AcceleratedArray)
    return findfirst(isequal(missing), a)
end

function Base.findlast(f::typeof(ismissing), a::AcceleratedArray)
    return findlast(isequal(missing), a)
end

function Base.filter(f::typeof(ismissing), a::AcceleratedArray)
    return filter(isequal(missing), a)
end

# TODO isnan, iszero, isone, isfinite (all are slightly strange when considering things like Complex, Matrix and String)