Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

format with runic #200

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions benchmark/benchmarkdatafreetree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using Mmap
runtimes = []
runtimesreordered = []

function create_tree(n, reorder=false)
function create_tree(n, reorder = false)
filename = tempname()
d = 10
data = Mmap.mmap(filename, Matrix{Float32}, (d, n))
Expand All @@ -22,7 +22,7 @@ end

function knnbench(tree, data, n, N)
ind = rand(1:n, N)
knn(tree, data[:,ind], 3)[2]
knn(tree, data[:, ind], 3)[2]
end

function bench()
Expand All @@ -34,10 +34,10 @@ function bench()
tr, datar, filenamer = create_tree(n, true)

bm = @benchmark knnbench(t, data, n, 1000)
push!(runtimes, mean(bm.samples.elapsed_times) / 1e9)
push!(runtimes, mean(bm.samples.elapsed_times) / 1.0e9)

bmr = @benchmark knnbench(tr, datar, n, 1000)
push!(runtimesreordered, mean(bmr.samples.elapsed_times) / 1e9)
push!(runtimesreordered, mean(bmr.samples.elapsed_times) / 1.0e9)

rm(filename)
rm(filenamer)
Expand Down
8 changes: 5 additions & 3 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ for n_points in (EXTENSIVE_BENCHMARK ? (10^3, 10^5) : 10^5)
data = rand(StableRNG(123), dim, n_points)
for leafsize in (EXTENSIVE_BENCHMARK ? (1, 10) : 10)
for reorder in (true, false)
for (tree_type, SUITE_name) in ((KDTree, "kd tree"),
(BallTree, "ball tree"))
for (tree_type, SUITE_name) in (
(KDTree, "kd tree"),
(BallTree, "ball tree"),
)
tree = tree_type(data; leafsize = leafsize, reorder = reorder)
SUITE["build tree"]["$(tree_type) $dim × $n_points, ls = $leafsize"] = @benchmarkable $(tree_type)($data; leafsize = $leafsize, reorder = $reorder)
for input_size in (1, 1000)
Expand All @@ -27,7 +29,7 @@ for n_points in (EXTENSIVE_BENCHMARK ? (10^3, 10^5) : 10^5)
end
perc = 0.01
V = π^(dim / 2) / gamma(dim / 2 + 1) * (1 / 2)^dim
r = (V * perc * gamma(dim / 2 + 1))^(1/dim)
r = (V * perc * gamma(dim / 2 + 1))^(1 / dim)
r_formatted = @sprintf("%3.2e", r)
SUITE["inrange"]["$(tree_type) $dim × $n_points, ls = $leafsize, input_size = $input_size, r = $r_formatted"] = @benchmarkable inrange($tree, $input_data, $r)
end
Expand Down
41 changes: 25 additions & 16 deletions src/NearestNeighbors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,37 +11,46 @@ export knn, knn!, nn, inrange, inrange!,inrangecount # TODOs? , allpairs, distma
export injectdata

export Euclidean,
Cityblock,
Minkowski,
Chebyshev,
Hamming,
WeightedEuclidean,
WeightedCityblock,
WeightedMinkowski
Cityblock,
Minkowski,
Chebyshev,
Hamming,
WeightedEuclidean,
WeightedCityblock,
WeightedMinkowski

abstract type NNTree{V <: AbstractVector,P <: PreMetric} end
abstract type NNTree{V <: AbstractVector, P <: PreMetric} end

const NonweightedMinowskiMetric = Union{Euclidean,Chebyshev,Cityblock,Minkowski}
const WeightedMinowskiMetric = Union{WeightedEuclidean,WeightedCityblock,WeightedMinkowski}
const NonweightedMinowskiMetric = Union{Euclidean, Chebyshev, Cityblock, Minkowski}
const WeightedMinowskiMetric = Union{WeightedEuclidean, WeightedCityblock, WeightedMinkowski}
const MinkowskiMetric = Union{NonweightedMinowskiMetric, WeightedMinowskiMetric}
function check_input(::NNTree{V1}, ::AbstractVector{V2}) where {V1, V2 <: AbstractVector}
if length(V1) != length(V2)
throw(ArgumentError(
"dimension of input points:$(length(V2)) and tree data:$(length(V1)) must agree"))
throw(
ArgumentError(
"dimension of input points:$(length(V2)) and tree data:$(length(V1)) must agree",
),
)
end
end

function check_input(::NNTree{V1}, point::AbstractVector{T}) where {V1, T <: Number}
if length(V1) != length(point)
throw(ArgumentError(
"dimension of input points:$(length(point)) and tree data:$(length(V1)) must agree"))
throw(
ArgumentError(
"dimension of input points:$(length(point)) and tree data:$(length(V1)) must agree",
),
)
end
end

function check_input(::NNTree{V1}, m::AbstractMatrix) where {V1}
if length(V1) != size(m, 1)
throw(ArgumentError(
"dimension of input points:$(size(m, 1)) and tree data:$(length(V1)) must agree"))
throw(
ArgumentError(
"dimension of input points:$(size(m, 1)) and tree data:$(length(V1)) must agree",
),
)
end
end

Expand Down
153 changes: 90 additions & 63 deletions src/ball_tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
# which radius are determined from the given metric.
# The tree uses the triangle inequality to prune the search space
# when finding the neighbors to a point,
struct BallTree{V <: AbstractVector,N,T,M <: Metric} <: NNTree{V,M}
struct BallTree{V <: AbstractVector, N, T, M <: Metric} <: NNTree{V, M}
data::Vector{V}
hyper_spheres::Vector{HyperSphere{N,T}} # Each hyper sphere bounds its children
hyper_spheres::Vector{HyperSphere{N, T}} # Each hyper sphere bounds its children
indices::Vector{Int} # Translates from tree index -> point index
metric::M # Metric used for tree
tree_data::TreeData # Some constants needed
Expand All @@ -18,12 +18,14 @@ end

Creates a `BallTree` from the data using the given `metric` and `leafsize`.
"""
function BallTree(data::AbstractVector{V},
metric::Metric = Euclidean();
leafsize::Int = 25,
reorder::Bool = true,
storedata::Bool = true,
reorderbuffer::Vector{V} = Vector{V}()) where {V <: AbstractArray}
function BallTree(
data::AbstractVector{V},
metric::Metric = Euclidean();
leafsize::Int = 25,
reorder::Bool = true,
storedata::Bool = true,
reorderbuffer::Vector{V} = Vector{V}(),
) where {V <: AbstractArray}
reorder = !isempty(reorderbuffer) || (storedata ? reorder : false)

tree_data = TreeData(data, leafsize)
Expand All @@ -32,7 +34,7 @@ function BallTree(data::AbstractVector{V},
indices = collect(1:n_p)

# Bottom up creation of hyper spheres so need spheres even for leafs)
hyper_spheres = Vector{HyperSphere{length(V),eltype(V)}}(undef, tree_data.n_internal_nodes + tree_data.n_leafs)
hyper_spheres = Vector{HyperSphere{length(V), eltype(V)}}(undef, tree_data.n_internal_nodes + tree_data.n_leafs)

indices_reordered = Vector{Int}()
data_reordered = Vector{V}()
Expand All @@ -49,53 +51,64 @@ function BallTree(data::AbstractVector{V},
if metric isa Distances.UnionMetrics
p = parameters(metric)
if p !== nothing && length(p) != length(V)
throw(ArgumentError(
"dimension of input points:$(length(V)) and metric parameter:$(length(p)) must agree"))
throw(
ArgumentError(
"dimension of input points:$(length(V)) and metric parameter:$(length(p)) must agree",
),
)
end
end

if n_p > 0
# Call the recursive BallTree builder
build_BallTree(1, data, data_reordered, hyper_spheres, metric, indices, indices_reordered,
1:length(data), tree_data, reorder)
build_BallTree(
1, data, data_reordered, hyper_spheres, metric, indices, indices_reordered,
1:length(data), tree_data, reorder,
)
end

if reorder
data = data_reordered
indices = indices_reordered
data = data_reordered
indices = indices_reordered
end

BallTree(storedata ? data : similar(data, 0), hyper_spheres, indices, metric, tree_data, reorder)
end

function BallTree(data::AbstractVecOrMat{T},
metric::Metric = Euclidean();
leafsize::Int = 25,
storedata::Bool = true,
reorder::Bool = true,
reorderbuffer::Matrix{T} = Matrix{T}(undef, 0, 0)) where {T <: AbstractFloat}
function BallTree(
data::AbstractVecOrMat{T},
metric::Metric = Euclidean();
leafsize::Int = 25,
storedata::Bool = true,
reorder::Bool = true,
reorderbuffer::Matrix{T} = Matrix{T}(undef, 0, 0),
) where {T <: AbstractFloat}
dim = size(data, 1)
points = copy_svec(T, data, Val(dim))
if isempty(reorderbuffer)
reorderbuffer_points = Vector{SVector{dim,T}}()
reorderbuffer_points = Vector{SVector{dim, T}}()
else
reorderbuffer_points = copy_svec(T, reorderbuffer, Val(dim))
end
BallTree(points, metric; leafsize, storedata, reorder,
reorderbuffer = reorderbuffer_points)
BallTree(
points, metric; leafsize, storedata, reorder,
reorderbuffer = reorderbuffer_points,
)
end

# Recursive function to build the tree.
function build_BallTree(index::Int,
data::AbstractVector{V},
data_reordered::Vector{V},
hyper_spheres::Vector{HyperSphere{N,T}},
metric::Metric,
indices::Vector{Int},
indices_reordered::Vector{Int},
range::UnitRange{Int},
tree_data::TreeData,
reorder::Bool) where {V <: AbstractVector, N, T}
function build_BallTree(
index::Int,
data::AbstractVector{V},
data_reordered::Vector{V},
hyper_spheres::Vector{HyperSphere{N, T}},
metric::Metric,
indices::Vector{Int},
indices_reordered::Vector{Int},
range::UnitRange{Int},
tree_data::TreeData,
reorder::Bool,
) where {V <: AbstractVector, N, T}

n_points = length(range) # Points left
if n_points <= tree_data.leafsize
Expand All @@ -118,35 +131,45 @@ function build_BallTree(index::Int,
# to compare
select_spec!(indices, mid_idx, first(range), last(range), data, split_dim)

build_BallTree(getleft(index), data, data_reordered, hyper_spheres, metric,
indices, indices_reordered, first(range):mid_idx - 1,
tree_data, reorder)
build_BallTree(
getleft(index), data, data_reordered, hyper_spheres, metric,
indices, indices_reordered, first(range):(mid_idx - 1),
tree_data, reorder,
)

build_BallTree(getright(index), data, data_reordered, hyper_spheres, metric,
indices, indices_reordered, mid_idx:last(range),
tree_data, reorder)
build_BallTree(
getright(index), data, data_reordered, hyper_spheres, metric,
indices, indices_reordered, mid_idx:last(range),
tree_data, reorder,
)

# Finally create bounding hyper sphere from the two children's hyper spheres
hyper_spheres[index] = create_bsphere(metric, hyper_spheres[getleft(index)],
hyper_spheres[getright(index)])
hyper_spheres[index] = create_bsphere(
metric, hyper_spheres[getleft(index)],
hyper_spheres[getright(index)],
)
end

function _knn(tree::BallTree,
point::AbstractVector,
best_idxs::AbstractVector{<:Integer},
best_dists::AbstractVector,
skip::F) where {F}
function _knn(
tree::BallTree,
point::AbstractVector,
best_idxs::AbstractVector{<:Integer},
best_dists::AbstractVector,
skip::F,
) where {F}
knn_kernel!(tree, 1, point, best_idxs, best_dists, skip)
return
end


function knn_kernel!(tree::BallTree{V},
index::Int,
point::AbstractArray,
best_idxs::AbstractVector{<:Integer},
best_dists::AbstractVector,
skip::F) where {V, F}
function knn_kernel!(
tree::BallTree{V},
index::Int,
point::AbstractArray,
best_idxs::AbstractVector{<:Integer},
best_dists::AbstractVector,
skip::F,
) where {V, F}
if isleaf(tree.tree_data.n_internal_nodes, index)
add_points_knn!(best_dists, best_idxs, tree, index, point, true, skip)
return
Expand Down Expand Up @@ -174,19 +197,23 @@ function knn_kernel!(tree::BallTree{V},
return
end

function _inrange(tree::BallTree{V},
point::AbstractVector,
radius::Number,
idx_in_ball::Union{Nothing, Vector{<:Integer}}) where {V}
function _inrange(
tree::BallTree{V},
point::AbstractVector,
radius::Number,
idx_in_ball::Union{Nothing, Vector{<:Integer}},
) where {V}
ball = HyperSphere(convert(V, point), convert(eltype(V), radius)) # The "query ball"
return inrange_kernel!(tree, 1, point, ball, idx_in_ball) # Call the recursive range finder
end

function inrange_kernel!(tree::BallTree,
index::Int,
point::AbstractVector,
query_ball::HyperSphere,
idx_in_ball::Union{Nothing, Vector{<:Integer}})
function inrange_kernel!(
tree::BallTree,
index::Int,
point::AbstractVector,
query_ball::HyperSphere,
idx_in_ball::Union{Nothing, Vector{<:Integer}},
)

if index > length(tree.hyper_spheres)
return 0
Expand Down Expand Up @@ -215,7 +242,7 @@ function inrange_kernel!(tree::BallTree,
count += addall(tree, index, idx_in_ball)
else
# Recursively call the left and right sub tree.
count += inrange_kernel!(tree, getleft(index), point, query_ball, idx_in_ball)
count += inrange_kernel!(tree, getleft(index), point, query_ball, idx_in_ball)
count += inrange_kernel!(tree, getright(index), point, query_ball, idx_in_ball)
end
return count
Expand Down
Loading
Loading