From cc2dd79e533e11a5e992fdac8ff66b7537beecb3 Mon Sep 17 00:00:00 2001 From: Shane Cheng Date: Wed, 5 Jul 2023 10:27:29 +1000 Subject: [PATCH] Made the result idxs and dists datatypes selectable --- src/ball_tree.jl | 10 +++++----- src/brute_tree.jl | 8 ++++---- src/kd_tree.jl | 8 ++++---- src/knn.jl | 10 +++++----- src/tree_ops.jl | 4 ++-- test/test_knn.jl | 7 ++++--- 6 files changed, 24 insertions(+), 23 deletions(-) diff --git a/src/ball_tree.jl b/src/ball_tree.jl index 5115f9e..8f931f3 100644 --- a/src/ball_tree.jl +++ b/src/ball_tree.jl @@ -150,9 +150,9 @@ end function _knn(tree::BallTree, point::AbstractVector, - best_idxs::AbstractVector{Int}, + best_idxs::AbstractVector{T}, best_dists::AbstractVector, - skip::F) where {F} + skip::F) where {F, T <: Integer} knn_kernel!(tree, 1, point, best_idxs, best_dists, skip) return end @@ -161,9 +161,9 @@ end function knn_kernel!(tree::BallTree{V}, index::Int, point::AbstractArray, - best_idxs::AbstractVector{Int}, + best_idxs::AbstractVector{T}, best_dists::AbstractVector, - skip::F) where {V, F} + skip::F) where {V, F, T <: Integer} if isleaf(tree.tree_data.n_internal_nodes, index) add_points_knn!(best_dists, best_idxs, tree, index, point, true, skip) return @@ -194,7 +194,7 @@ end function _inrange(tree::BallTree{V}, point::AbstractVector, radius::Number, - idx_in_ball::Union{Nothing, Vector{Int}}) where {V} + idx_in_ball::Union{Nothing, Vector{T}}) where {V, T <: Integer} ball = HyperSphere(convert(V, point), convert(eltype(V), radius)) # The "query ball" return inrange_kernel!(tree, 1, point, ball, idx_in_ball) # Call the recursive range finder end diff --git a/src/brute_tree.jl b/src/brute_tree.jl index f37df0c..522e222 100644 --- a/src/brute_tree.jl +++ b/src/brute_tree.jl @@ -31,9 +31,9 @@ end function _knn(tree::BruteTree{V}, point::AbstractVector, - best_idxs::AbstractVector{Int}, + best_idxs::AbstractVector{T}, best_dists::AbstractVector, - skip::F) where {V, F} + skip::F) where {V, F, T <: Integer} knn_kernel!(tree, point, best_idxs, best_dists, skip) return @@ -41,9 +41,9 @@ end function knn_kernel!(tree::BruteTree{V}, point::AbstractVector, - best_idxs::AbstractVector{Int}, + best_idxs::AbstractVector{T}, best_dists::AbstractVector, - skip::F) where {V, F} + skip::F) where {V, F, T <: Integer} for i in 1:length(tree.data) if skip(i) continue diff --git a/src/kd_tree.jl b/src/kd_tree.jl index faa9257..9964866 100644 --- a/src/kd_tree.jl +++ b/src/kd_tree.jl @@ -158,9 +158,9 @@ end function _knn(tree::KDTree, point::AbstractVector, - best_idxs::AbstractVector{Int}, + best_idxs::AbstractVector{T}, best_dists::AbstractVector, - skip::F) where {F} + skip::F) where {F, T <: Integer} init_min = get_min_distance(tree.hyper_rec, point) knn_kernel!(tree, 1, point, best_idxs, best_dists, init_min, skip) @simd for i in eachindex(best_dists) @@ -171,10 +171,10 @@ end function knn_kernel!(tree::KDTree{V}, index::Int, point::AbstractVector, - best_idxs::AbstractVector{Int}, + best_idxs::AbstractVector{T}, best_dists::AbstractVector, min_dist, - skip::F) where {V, F} + skip::F) where {V, F, T <: Integer} # At a leaf node. Go through all points in node and add those in range if isleaf(tree.tree_data.n_internal_nodes, index) add_points_knn!(best_dists, best_idxs, tree, index, point, false, skip) diff --git a/src/knn.jl b/src/knn.jl index b76bf75..021f085 100644 --- a/src/knn.jl +++ b/src/knn.jl @@ -14,12 +14,12 @@ in the order of increasing distance to the point. `skip` is an optional predicat to determine if a point that would be returned should be skipped based on its index. """ -function knn(tree::NNTree{V}, points::Vector{T}, k::Int, sortres=false, skip::F=always_false) where {V, T <: AbstractVector, F<:Function} +function knn(tree::NNTree{V}, points::Vector{T}, k::Int, sortres=false, skip::F=always_false; idxs_type::DataType = Int, dists_type = get_T(eltype(V))) where {V, T <: AbstractVector, F<:Function} check_input(tree, points) check_k(tree, k) n_points = length(points) - dists = [Vector{get_T(eltype(V))}(undef, k) for _ in 1:n_points] - idxs = [Vector{Int}(undef, k) for _ in 1:n_points] + dists = [Vector{dists_type}(undef, k) for _ in 1:n_points] + idxs = [Vector{idxs_type}(undef, k) for _ in 1:n_points] for i in 1:n_points knn_point!(tree, points[i], sortres, dists[i], idxs[i], skip) end @@ -27,11 +27,11 @@ function knn(tree::NNTree{V}, points::Vector{T}, k::Int, sortres=false, skip::F= end function knn_point!(tree::NNTree{V}, point::AbstractVector{T}, sortres, dist, idx, skip::F) where {V, T <: Number, F} - fill!(idx, -1) + fill!(idx, 0) fill!(dist, typemax(get_T(eltype(V)))) _knn(tree, point, idx, dist, skip) if skip !== always_false - skipped_idxs = findall(==(-1), idx) + skipped_idxs = findall(==(0), idx) deleteat!(idx, skipped_idxs) deleteat!(dist, skipped_idxs) end diff --git a/src/tree_ops.jl b/src/tree_ops.jl index 8835841..1c3aba7 100644 --- a/src/tree_ops.jl +++ b/src/tree_ops.jl @@ -90,9 +90,9 @@ end # Checks the distance function and add those points that are among the k best. # Uses a heap for fast insertion. -@inline function add_points_knn!(best_dists::AbstractVector, best_idxs::AbstractVector{Int}, +@inline function add_points_knn!(best_dists::AbstractVector, best_idxs::AbstractVector{T}, tree::NNTree, index::Int, point::AbstractVector, - do_end::Bool, skip::F) where {F} + do_end::Bool, skip::F) where {F, T <: Integer} for z in get_leaf_range(tree.tree_data, index) idx = tree.reordered ? z : tree.indices[z] dist_d = evaluate(tree.metric, tree.data[idx], point, do_end) diff --git a/test/test_knn.jl b/test/test_knn.jl index 26fed29..77586b3 100644 --- a/test/test_knn.jl +++ b/test/test_knn.jl @@ -31,9 +31,9 @@ import Distances.evaluate @test idxs[1] == 8 @test idxs[2] == 3 - idxs, dists = knn(tree, [SVector{2, Float64}(0.8,0.8), SVector{2, Float64}(0.1,0.8)], 1, true) - @test idxs[1][1] == 8 - @test idxs[2][1] == 3 + idxs, dists = knn(tree, [SVector{2, Float64}(0.8,0.8), SVector{2, Float64}(0.1,0.8)], 1, true; idxs_type = UInt32, dists_type = Float16) + @test typeof(idxs[1][1]) == UInt32 + @test typeof(dists[2][1]) == Float16 idxs, dists = nn(tree, [SVector{2, Float64}(0.8,0.8), SVector{2, Float64}(0.1,0.8)]) @test idxs[1] == 8 @@ -91,3 +91,4 @@ end @test nearest == [1, 3] @test distance ≈ [0.02239688629947563, 0.13440059522389006] end +