diff --git a/src/evaluation.jl b/src/evaluation.jl index 807adbf..1b908dd 100644 --- a/src/evaluation.jl +++ b/src/evaluation.jl @@ -4,7 +4,6 @@ @inline eval_pow(d::Minkowski, s) = abs(s)^d.p @inline eval_diff(::NonweightedMinowskiMetric, a, b, dim) = a - b -@inline eval_diff(::Chebyshev, ::Any, b, dim) = b @inline eval_diff(m::WeightedMinowskiMetric, a, b, dim) = m.weights[dim] * (a-b) function evaluate_maybe_end(d::Distances.UnionMetrics, a::AbstractVector, diff --git a/src/hyperrectangles.jl b/src/hyperrectangles.jl index 261fac1..2ea3b3c 100644 --- a/src/hyperrectangles.jl +++ b/src/hyperrectangles.jl @@ -28,3 +28,16 @@ get_max_distance_no_end(m, rec, point) = get_min_distance_no_end(m, rec, point) = get_min_max_distance_no_end(distance_function_min, m, rec, point) + +@inline function update_new_min(M::Metric, old_min, hyper_rec, p_dim, split_dim, split_val) + @inbounds begin + lo = hyper_rec.mins[split_dim] + hi = hyper_rec.maxes[split_dim] + end + ddiff = distance_function_min(p_dim, hi, lo) + split_diff = abs(p_dim - split_val) + split_diff_pow = eval_pow(M, split_diff) + ddiff_pow = eval_pow(M, ddiff) + diff_tot = eval_diff(M, split_diff_pow, ddiff_pow, split_dim) + return old_min + diff_tot +end diff --git a/src/kd_tree.jl b/src/kd_tree.jl index a6326ff..6147831 100644 --- a/src/kd_tree.jl +++ b/src/kd_tree.jl @@ -173,8 +173,6 @@ function knn_kernel!(tree::KDTree{V}, split_dim = tree.split_dims[index] p_dim = point[split_dim] split_val = tree.split_vals[index] - lo = hyper_rec.mins[split_dim] - hi = hyper_rec.maxes[split_dim] split_diff = p_dim - split_val M = tree.metric # Point is to the right of the split value @@ -183,21 +181,21 @@ function knn_kernel!(tree::KDTree{V}, far = getleft(index) hyper_rec_far = HyperRectangle(hyper_rec.mins, @inbounds setindex(hyper_rec.maxes, split_val, split_dim)) hyper_rec_close = HyperRectangle(@inbounds(setindex(hyper_rec.mins, split_val, split_dim)), hyper_rec.maxes) - ddiff = max(zero(eltype(V)), p_dim - hi) else close = getleft(index) far = getright(index) hyper_rec_far = HyperRectangle(@inbounds(setindex(hyper_rec.mins, split_val, split_dim)), hyper_rec.maxes) hyper_rec_close = HyperRectangle(hyper_rec.mins, @inbounds setindex(hyper_rec.maxes, split_val, split_dim)) - ddiff = max(zero(eltype(V)), lo - p_dim) end # Always call closer sub tree knn_kernel!(tree, close, point, best_idxs, best_dists, min_dist, hyper_rec_close, skip) - split_diff_pow = eval_pow(M, split_diff) - ddiff_pow = eval_pow(M, ddiff) - diff_tot = eval_diff(M, split_diff_pow, ddiff_pow, split_dim) - new_min = eval_reduce(M, min_dist, diff_tot) + if M isa Chebyshev + new_min = get_min_distance_no_end(M, hyper_rec_far, point) + else + new_min = update_new_min(M, min_dist, hyper_rec, p_dim, split_dim, split_val) + end + if new_min < best_dists[1] knn_kernel!(tree, far, point, best_idxs, best_dists, new_min, hyper_rec_far, skip) end @@ -233,8 +231,6 @@ function inrange_kernel!(tree::KDTree, split_val = tree.split_vals[index] split_dim = tree.split_dims[index] - lo = hyper_rec.mins[split_dim] - hi = hyper_rec.maxes[split_dim] p_dim = point[split_dim] split_diff = p_dim - split_val M = tree.metric @@ -246,13 +242,11 @@ function inrange_kernel!(tree::KDTree, far = getleft(index) hyper_rec_far = HyperRectangle(hyper_rec.mins, @inbounds setindex(hyper_rec.maxes, split_val, split_dim)) hyper_rec_close = HyperRectangle(@inbounds(setindex(hyper_rec.mins, split_val, split_dim)), hyper_rec.maxes) - ddiff = max(zero(p_dim - hi), p_dim - hi) else # Point is to the left of the split value close = getleft(index) far = getright(index) hyper_rec_far = HyperRectangle(@inbounds(setindex(hyper_rec.mins, split_val, split_dim)), hyper_rec.maxes) hyper_rec_close = HyperRectangle(hyper_rec.mins, @inbounds setindex(hyper_rec.maxes, split_val, split_dim)) - ddiff = max(zero(lo - p_dim), lo - p_dim) end # Call closer sub tree count += inrange_kernel!(tree, close, point, r, idx_in_ball, hyper_rec_close, min_dist) @@ -263,10 +257,12 @@ function inrange_kernel!(tree::KDTree, # It would be interesting to benchmark this on some different data sets. # Call further sub tree with the new min distance - split_diff_pow = eval_pow(M, split_diff) - ddiff_pow = eval_pow(M, ddiff) - diff_tot = eval_diff(M, split_diff_pow, ddiff_pow, split_dim) - new_min = eval_reduce(M, min_dist, diff_tot) + if M isa Chebyshev + new_min = get_min_distance_no_end(M, hyper_rec_far, point) + else + new_min = update_new_min(M, min_dist, hyper_rec, p_dim, split_dim, split_val) + end + count += inrange_kernel!(tree, far, point, r, idx_in_ball, hyper_rec_far, new_min) return count end diff --git a/test/runtests.jl b/test/runtests.jl index 584a271..29b2fae 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -80,4 +80,28 @@ using NearestNeighbors: HyperRectangle, get_min_distance_no_end, get_max_distanc @test get_min_distance_no_end(m, hr, point) ≈ NearestNeighbors.eval_pow(m, m(closest_point, point)) @test get_max_distance_no_end(m, hr, point) ≈ NearestNeighbors.eval_pow(m, m(furthest_point, point)) end + + for m in ms + hyper_rec = NearestNeighbors.HyperRectangle{SVector{1, Float32}}(Float32[0.5553872], Float32[0.6169486]) + point = [0.5] + min_dist = NearestNeighbors.get_min_distance_no_end(m, hyper_rec, point) + split_dim = 1 + split_val = 0.5844354f0 + hyper_rec_far = NearestNeighbors.HyperRectangle{SVector{1, Float32}}(Float32[0.5844354], Float32[0.6169486]) + new_min = NearestNeighbors.update_new_min(m, min_dist, hyper_rec, point[split_dim], split_dim, split_val) + new_min_true = NearestNeighbors.get_min_distance_no_end(m, hyper_rec_far, point) + @test new_min ≈ new_min_true + end + + for m in ms + hyper_rec = NearestNeighbors.HyperRectangle{SVector{2, Float64}}([0.07935189250034036, 0.682552911042077], [0.1619776648454222, 0.8046815005307764]) + point = [0.06630748183735935, 0.7541470744398973] + min_dist = NearestNeighbors.get_min_distance_no_end(m, hyper_rec, point) + split_dim = 2 + split_val = 0.7388396209627084 + hyper_rec_far = NearestNeighbors.HyperRectangle{SVector{2, Float64}}([0.07935189250034036, 0.682552911042077], [0.1619776648454222, 0.7388396209627084]) + new_min = NearestNeighbors.update_new_min(m, min_dist, hyper_rec, point[split_dim], split_dim, split_val) + new_min_true = NearestNeighbors.get_min_distance_no_end(m, hyper_rec_far, point) + @test new_min ≈ new_min_true broken = m isa Chebyshev + end end