Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parallelizing BallTree Construction #132

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.0'
- '1.3'
- '1'
- 'nightly'
os:
Expand Down
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
name = "NearestNeighbors"
uuid = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
version = "0.4.9"
version = "0.5.0"

[deps]
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
Distances = "0.9, 0.10"
StaticArrays = "0.9, 0.10, 0.11, 0.12, 1.0"
julia = "1.0"
julia = "1.3"

[extras]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
1 change: 1 addition & 0 deletions src/NearestNeighbors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import Distances: Metric, result_type, eval_reduce, eval_end, eval_op, eval_star

using StaticArrays
import Base.show
using Base.Threads

export NNTree, BruteTree, KDTree, BallTree, DataFreeTree
export knn, nn, inrange # TODOs? , allpairs, distmat, npairs
Expand Down
100 changes: 77 additions & 23 deletions src/ball_tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,8 @@ struct BallTree{V <: AbstractVector,N,T,M <: Metric} <: NNTree{V,M}
reordered::Bool # If the data has been reordered
end

# When we create the bounding spheres we need some temporary arrays.
# We create a type to hold them to not allocate these arrays at every
# function call and to reduce the number of parameters in the tree builder.
struct ArrayBuffers{N,T <: AbstractFloat}
center::MVector{N,T}
end

function ArrayBuffers(::Type{Val{N}}, ::Type{T}) where {N, T}
ArrayBuffers(zeros(MVector{N,T}))
end
# minimum number of data points above which parallelization is triggered by default
const DEFAULT_BALLTREE_MIN_PARALLEL_SIZE = 1024

"""
BallTree(data [, metric = Euclidean(), leafsize = 10]) -> balltree
Expand All @@ -33,14 +25,15 @@ function BallTree(data::AbstractVector{V},
leafsize::Int = 10,
reorder::Bool = true,
storedata::Bool = true,
parallel::Bool = true,
parallel_size::Int = DEFAULT_BALLTREE_MIN_PARALLEL_SIZE,
reorderbuffer::Vector{V} = Vector{V}()) where {V <: AbstractArray, M <: Metric}
reorder = !isempty(reorderbuffer) || (storedata ? reorder : false)

tree_data = TreeData(data, leafsize)
n_d = length(V)
n_p = length(data)

array_buffs = ArrayBuffers(Val{length(V)}, get_T(eltype(V)))
indices = collect(1:n_p)

# Bottom up creation of hyper spheres so need spheres even for leafs)
Expand Down Expand Up @@ -70,7 +63,8 @@ function BallTree(data::AbstractVector{V},
if n_p > 0
# Call the recursive BallTree builder
build_BallTree(1, data, data_reordered, hyper_spheres, metric, indices, indices_reordered,
1, length(data), tree_data, array_buffs, reorder)
1, length(data), tree_data, reorder, Val(parallel), parallel_size)

end

if reorder
Expand All @@ -86,6 +80,8 @@ function BallTree(data::AbstractVecOrMat{T},
leafsize::Int = 10,
storedata::Bool = true,
reorder::Bool = true,
parallel::Bool = true,
parallel_size::Int = DEFAULT_BALLTREE_MIN_PARALLEL_SIZE,
reorderbuffer::Matrix{T} = Matrix{T}(undef, 0, 0)) where {T <: AbstractFloat, M <: Metric}
dim = size(data, 1)
npoints = size(data, 2)
Expand All @@ -96,7 +92,7 @@ function BallTree(data::AbstractVecOrMat{T},
reorderbuffer_points = copy_svec(T, reorderbuffer, Val(dim))
end
BallTree(points, metric, leafsize = leafsize, storedata = storedata, reorder = reorder,
reorderbuffer = reorderbuffer_points)
parallel = parallel, parallel_size = parallel_size, reorderbuffer = reorderbuffer_points)
end

# Recursive function to build the tree.
Expand All @@ -110,16 +106,17 @@ function build_BallTree(index::Int,
low::Int,
high::Int,
tree_data::TreeData,
array_buffs::ArrayBuffers{N,T},
reorder::Bool) where {V <: AbstractVector, N, T}
reorder::Bool,
parallel::Val{false},
parallel_size::Int = 0) where {V <: AbstractVector, N, T}

n_points = high - low + 1 # Points left
if n_points <= tree_data.leafsize
if reorder
reorder_data!(data_reordered, data, index, indices, indices_reordered, tree_data)
end
# Create bounding sphere of points in leaf node by brute force
hyper_spheres[index] = create_bsphere(data, metric, indices, low, high, array_buffs)
hyper_spheres[index] = create_bsphere(data, metric, indices, low, high)
return
end

Expand All @@ -132,22 +129,79 @@ function build_BallTree(index::Int,

# Sort the data at the mid_idx boundary using the split_dim
# to compare
select_spec!(indices, mid_idx, low, high, data, split_dim)
select_spec!(indices, mid_idx, low, high, data, split_dim) # culprit? technically, low and high should be disjoint for different threads

build_BallTree(getleft(index), data, data_reordered, hyper_spheres, metric,
indices, indices_reordered, low, mid_idx - 1,
tree_data, array_buffs, reorder)
indices, indices_reordered, low, mid_idx - 1,
tree_data, reorder, parallel)

build_BallTree(getright(index), data, data_reordered, hyper_spheres, metric,
indices, indices_reordered, mid_idx, high,
tree_data, array_buffs, reorder)
indices, indices_reordered, mid_idx, high,
tree_data, reorder, parallel)

# Finally create bounding hyper sphere from the two children's hyper spheres
hyper_spheres[index] = create_bsphere(metric, hyper_spheres[getleft(index)],
hyper_spheres[getright(index)])
return
end

# Parallelized recursive function to build the tree.
function build_BallTree(index::Int,
data::Vector{V},
data_reordered::Vector{V},
hyper_spheres::Vector{HyperSphere{N,T}},
metric::Metric,
indices::Vector{Int},
indices_reordered::Vector{Int},
low::Int,
high::Int,
tree_data::TreeData,
reorder::Bool,
parallel::Val{true},
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using a Val and use a separate function like this feels a bit awkward. Couldn't one just look at parallel_size in the original build_BallTree function and then decide whether to call the parallel function or the serial one?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using type dispatch on the parallel variable is important, because the compiler is able to get rid of temporary allocations during sequential execution. I can isolate the recursive component of the function though, and only use the Val(true) dispatch for that. If we only use a regular if statement on a Bool, performance during sequential execution will take a hit compared to the status quo.

parallel_size::Int) where {V <: AbstractVector, N, T}

n_points = high - low + 1 # Points left
if n_points <= tree_data.leafsize
if reorder
reorder_data!(data_reordered, data, index, indices, indices_reordered, tree_data)
end
# Create bounding sphere of points in leaf node by brute force
hyper_spheres[index] = create_bsphere(data, metric, indices, low, high)
return
end

# Find split such that one of the sub trees has 2^p points
# and the left sub tree has more points
mid_idx = find_split(low, tree_data.leafsize, n_points)

# Brute force to find the dimension with the largest spread
split_dim = find_largest_spread(data, indices, low, high)

# Sort the data at the mid_idx boundary using the split_dim
# to compare
select_spec!(indices, mid_idx, low, high, data, split_dim) # culprit? technically, low and high should be disjoint for different threads

@sync begin
left_n_points = mid_idx - low
left_parallel = Val(left_n_points > parallel_size)
@spawn build_BallTree(getleft(index), data, data_reordered, hyper_spheres, metric,
indices, indices_reordered, low, mid_idx - 1,
tree_data, reorder, left_parallel, parallel_size)

right_n_points = high - mid_idx + 1
right_parallel = Val(right_n_points > parallel_size)
@spawn build_BallTree(getright(index), data, data_reordered, hyper_spheres, metric,
indices, indices_reordered, mid_idx, high,
tree_data, reorder, right_parallel, parallel_size)
end

# Finally create bounding hyper sphere from the two children's hyper spheres
hyper_spheres[index] = create_bsphere(metric, hyper_spheres[getleft(index)],
hyper_spheres[getright(index)],
array_buffs)
hyper_spheres[getright(index)])
return
end


function _knn(tree::BallTree,
point::AbstractVector,
best_idxs::AbstractVector{Int},
Expand Down
64 changes: 22 additions & 42 deletions src/hyperspheres.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ end

HyperSphere(center::SVector{N,T1}, r::T2) where {N, T1, T2} = HyperSphere(center, convert(T1, r))

Base.:(==)(A::HyperSphere, B::HyperSphere) = A.center == B.center && A.r == B.r

@inline function intersects(m::M,
s1::HyperSphere{N,T},
s2::HyperSphere{N,T}) where {T <: AbstractFloat, N, M <: Metric}
Expand All @@ -19,55 +21,22 @@ end
evaluate(m, s1.center, s2.center) + s1.r <= s2.r
end

@inline function interpolate(::M,
c1::V,
c2::V,
x,
d,
ab) where {V <: AbstractVector, M <: NormMetric}
alpha = x / d
@assert length(c1) == length(c2)
@inbounds for i in eachindex(ab.center)
ab.center[i] = (1 - alpha) .* c1[i] + alpha .* c2[i]
end
return ab.center, true
end

@inline function interpolate(::M,
c1::V,
::V,
::Any,
::Any,
::Any) where {V <: AbstractVector, M <: Metric}
return c1, false
end

function create_bsphere(data::AbstractVector{V}, metric::Metric, indices::Vector{Int}, low, high, ab) where {V}
n_dim = size(data, 1)
n_points = high - low + 1
# First find center of all points
fill!(ab.center, 0.0)
for i in low:high
for j in 1:length(ab.center)
ab.center[j] += data[indices[i]][j]
end
end
ab.center .*= 1 / n_points

# versions with no array buffer - still not allocating in sequential BallTree construction
using Statistics: mean
function create_bsphere(data::AbstractVector{V}, metric::Metric, indices::Vector{Int}, low, high) where {V}
# find center
center = mean(@views(data[indices[low:high]]))
# Then find r
r = zero(get_T(eltype(V)))
for i in low:high
r = max(r, evaluate(metric, data[indices[i]], ab.center))
r = max(r, evaluate(metric, data[indices[i]], center))
end
r += eps(get_T(eltype(V)))
return HyperSphere(SVector{length(V),eltype(V)}(ab.center), r)
return HyperSphere(SVector{length(V),eltype(V)}(center), r)
end

# Creates a bounding sphere from two other spheres
function create_bsphere(m::Metric,
s1::HyperSphere{N,T},
s2::HyperSphere{N,T},
ab) where {N, T <: AbstractFloat}
function create_bsphere(m::Metric, s1::HyperSphere{N,T}, s2::HyperSphere{N,T}) where {N, T <: AbstractFloat}
if encloses(m, s1, s2)
return HyperSphere(s2.center, s2.r)
elseif encloses(m, s2, s1)
Expand All @@ -79,7 +48,7 @@ function create_bsphere(m::Metric,
# neither s1 nor s2 contains the other)
dist = evaluate(m, s1.center, s2.center)
x = 0.5 * (s2.r - s1.r + dist)
center, is_exact_center = interpolate(m, s1.center, s2.center, x, dist, ab)
center, is_exact_center = interpolate(m, s1.center, s2.center, x, dist)
if is_exact_center
rad = 0.5 * (s2.r + s1.r + dist)
else
Expand All @@ -88,3 +57,14 @@ function create_bsphere(m::Metric,

return HyperSphere(SVector{N,T}(center), rad)
end

@inline function interpolate(::M, c1::V, c2::V, x, d) where {V <: AbstractVector, M <: NormMetric}
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why move this function?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had two versions locally, the previous one, and this one without the array buffer variable ab. It turns out that in the sequential code, the compiler is able to get rid of the allocations without explicitly pre-allocating an ArrayBuffer variable. In the parallel code, having an array buffer leads to race conditions, which is why I wrote this modification.

I can move it back to where it was in the file.

length(c1) == length(c2) || throw(DimensionMismatch("interpolate arguments have length $(length(c1)) and $(length(c2))"))
alpha = x / d
center = (1 - alpha) * c1 + alpha * c2
return center, true
end

@inline function interpolate(::M, c1::V, ::V, ::Any, ::Any) where {V <: AbstractVector, M <: Metric}
return c1, false
end
2 changes: 1 addition & 1 deletion src/inrange.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ function inrange(tree::NNTree,

idxs = [Vector{Int}() for _ in 1:length(points)]

for i in 1:length(points)
@threads for i in 1:length(points)
inrange_point!(tree, points[i], radius, sortres, idxs[i])
end
return idxs
Expand Down
8 changes: 4 additions & 4 deletions src/knn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,16 @@ end
Performs a lookup of the `k` nearest neigbours to the `points` from the data
in the `tree`. If `sortres = true` the result is sorted such that the results are
in the order of increasing distance to the point. `skip` is an optional predicate
to determine if a point that would be returned should be skipped based on its
to determine if a point that would be returned should be skipped based on its
index.
"""
function knn(tree::NNTree{V}, points::Vector{T}, k::Int, sortres=false, skip::F=always_false) where {V, T <: AbstractVector, F<:Function}
check_input(tree, points)
check_k(tree, k)
n_points = length(points)
dists = [Vector{get_T(eltype(V))}(undef, k) for _ in 1:n_points]
idxs = [Vector{Int}(undef, k) for _ in 1:n_points]
for i in 1:n_points
dists = [Vector{get_T(eltype(V))}(undef, k) for _ in 1:n_points]
idxs = [Vector{Int}(undef, k) for _ in 1:n_points]
@threads for i in 1:n_points
knn_point!(tree, points[i], sortres, dists[i], idxs[i], skip)
end
return idxs, dists
Expand Down
5 changes: 2 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@ using LinearAlgebra

using Distances: Distances, Metric, evaluate, PeriodicEuclidean
struct CustomMetric1 <: Metric end
Distances.evaluate(::CustomMetric1, a::AbstractVector, b::AbstractVector) = maximum(abs.(a .- b))
Distances.evaluate(::CustomMetric1, a::AbstractVector, b::AbstractVector) = maximum(abs, (a .- b))
function NearestNeighbors.interpolate(::CustomMetric1,
a::V,
b::V,
x,
d,
ab) where {V <: AbstractVector}
d) where {V <: AbstractVector}
idx = (abs.(b .- a) .>= d - x)
c = copy(Array(a))
c[idx] = (1 - x / d) * a[idx] + (x / d) * b[idx]
Expand Down
4 changes: 2 additions & 2 deletions test/test_inrange.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Does not test leafsize
@testset "inrange" begin
@testset "metric" for metric in [Euclidean()]
@testset "tree type" for TreeType in trees_with_brute
@testset "metric $Metric" for metric in [Euclidean()]
@testset "tree type $TreeType" for TreeType in trees_with_brute
function test(data)
tree = TreeType(data, metric; leafsize=2)
dosort = true
Expand Down
4 changes: 2 additions & 2 deletions test/test_knn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import Distances.evaluate

@testset "knn" begin
@testset "metric" for metric in [metrics; WeightedEuclidean(ones(2))]
@testset "tree type" for TreeType in trees_with_brute
@testset "metric $metric" for metric in [metrics; WeightedEuclidean(ones(2))]
@testset "tree type $TreeType" for TreeType in trees_with_brute
function test(data)
tree = TreeType(data, metric; leafsize=2)

Expand Down
Loading