Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,45 @@ assert_eq!(
vec![(0f64, 0), (2f64, 1), (8f64, 2)]
);
```

## Periodic Boundary Conditions

Kiddo supports periodic boundary conditions (PBCs) for float `KdTree` and `ImmutableKdTree` queries. Currently, periodic queries have a considerable performance penalty compared to non-periodic queries.

Periodic queries take a `box_size` argument, where each entry is the period of one axis.

```rust
use kiddo::{KdTree, SquaredEuclidean};

let mut tree: KdTree<f64, 2> = KdTree::new();
tree.add(&[0.95, 0.50], 100);
tree.add(&[0.40, 0.50], 101);

let query = [0.05, 0.50];
let box_size = [1.0, 1.0];

let nearest = tree.nearest_one_periodic::<SquaredEuclidean>(&query, &box_size);
assert_eq!(nearest.item, 100);
assert!((nearest.distance - 0.01).abs() < f64::EPSILON);
```

Available periodic query methods:
* `nearest_one_periodic`
* `nearest_n_periodic`
* `within_periodic`
* `within_unsorted_periodic`
* `nearest_n_within_periodic`

The same API is available on `ImmutableKdTree`.

Notes:
* `box_size[i]` must be strictly positive for every axis
* points are expected to be stored in a principal cell such as `[0, box_size[i])`
* queries should also be supplied in that same cell
* the current implementation runs one query for every wrapped image of the query point (`3^K` images for a `K`-dimensional tree), so the cost grows exponentially with the number of dimensions

See [examples/periodic-boundaries.rs](./examples/periodic-boundaries.rs) and [examples/immutable-periodic-boundaries.rs](./examples/immutable-periodic-boundaries.rs).

See the [examples documentation](https://github.com/sdd/kiddo/tree/master/examples) for some more detailed examples.

## Optional Features
Expand Down
37 changes: 37 additions & 0 deletions examples/immutable-periodic-boundaries.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
use kiddo::immutable::float::kdtree::ImmutableKdTree;
use kiddo::SquaredEuclidean;

/// Runs each periodic query available on `ImmutableKdTree` against a small
/// 2-D data set and prints the results.
fn main() {
    // Sample points; the tree is built once from the whole slice.
    let points = vec![[0.95, 0.50], [0.92, 0.55], [0.40, 0.50], [0.10, 0.10]];
    let tree: ImmutableKdTree<f64, u32, 2, 8> = ImmutableKdTree::new_from_slice(&points);

    // Query location, per-axis period and search radius shared by all queries below.
    let probe = [0.05, 0.50];
    let periods = [1.0, 1.0];
    let max_dist = 0.03;

    // Single closest item.
    let closest = tree.nearest_one_periodic::<SquaredEuclidean>(&probe, &periods);
    println!("nearest_one_periodic -> {:?}", closest);

    // Two closest items.
    let two = std::num::NonZero::new(2).unwrap();
    let closest_two = tree.nearest_n_periodic::<SquaredEuclidean>(&probe, two, &periods);
    println!("nearest_n_periodic -> {:?}", closest_two);

    // Everything inside the radius.
    let inside = tree.within_periodic::<SquaredEuclidean>(&probe, max_dist, &periods);
    println!("within_periodic -> {:?}", inside);

    // Same radius search through the unsorted variant.
    let inside_unsorted =
        tree.within_unsorted_periodic::<SquaredEuclidean>(&probe, max_dist, &periods);
    println!("within_unsorted_periodic -> {:?}", inside_unsorted);

    // At most two items inside the radius.
    let limit = std::num::NonZero::new(2).unwrap();
    let closest_inside = tree
        .nearest_n_within_periodic::<SquaredEuclidean>(&probe, max_dist, limit, true, &periods);
    println!("nearest_n_within_periodic -> {:?}", closest_inside);
}
35 changes: 35 additions & 0 deletions examples/periodic-boundaries.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
use kiddo::{KdTree, SquaredEuclidean};

/// Runs each periodic query available on `KdTree` against a small 2-D data set
/// and prints the results.
fn main() {
    let mut tree: KdTree<f64, 2> = KdTree::new();

    // Insert the sample points in the same order as their item ids.
    for (point, item) in [
        ([0.95, 0.50], 100),
        ([0.92, 0.55], 101),
        ([0.40, 0.50], 102),
        ([0.10, 0.10], 103),
    ] {
        tree.add(&point, item);
    }

    // Query location, per-axis period and search radius shared by all queries below.
    let probe = [0.05, 0.50];
    let periods = [1.0, 1.0];
    let max_dist = 0.03;

    // Single closest item.
    let closest = tree.nearest_one_periodic::<SquaredEuclidean>(&probe, &periods);
    println!("nearest_one_periodic -> {:?}", closest);

    // Two closest items.
    let closest_two = tree.nearest_n_periodic::<SquaredEuclidean>(&probe, 2, &periods);
    println!("nearest_n_periodic -> {:?}", closest_two);

    // Everything inside the radius.
    let inside = tree.within_periodic::<SquaredEuclidean>(&probe, max_dist, &periods);
    println!("within_periodic -> {:?}", inside);

    // Same radius search through the unsorted variant.
    let inside_unsorted =
        tree.within_unsorted_periodic::<SquaredEuclidean>(&probe, max_dist, &periods);
    println!("within_unsorted_periodic -> {:?}", inside_unsorted);

    // At most two items inside the radius.
    let limit = std::num::NonZero::new(2).unwrap();
    let closest_inside = tree
        .nearest_n_within_periodic::<SquaredEuclidean>(&probe, max_dist, limit, true, &periods);
    println!("nearest_n_within_periodic -> {:?}", closest_inside);
}
212 changes: 212 additions & 0 deletions src/float/query/nearest_n.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use az::{Az, Cast};
use std::collections::BinaryHeap;
use std::collections::HashMap;
use std::ops::Rem;

use crate::float::kdtree::{Axis, KdTree};
Expand Down Expand Up @@ -45,6 +46,119 @@
tree.add(&[1.0, 2.0, 5.0], 100);
tree.add(&[2.0, 3.0, 6.0], 101);"
);

/// Finds the nearest `qty` elements to `query` with periodic boundary conditions.
///
/// `box_size` gives the periodic box length for each axis. Query points are expected
/// to be wrapped into the same principal cell as the points stored in the tree.
///
/// This first implementation checks all `3^K` wrapped query images, merges duplicate
/// items that arise from different images, and returns the best `qty` unique items.
///
/// Each item appears at most once in the result, paired with the smallest distance
/// found for it across all images. The result is sorted by ascending distance and
/// holds at most `qty` entries. An empty `Vec` is returned when `qty` is `0`.
///
/// # Panics
///
/// Panics if any entry of `box_size` is not strictly positive, or if two result
/// distances cannot be ordered by `partial_cmp` (for example a NaN distance).
#[inline]
pub fn nearest_n_periodic<D>(
&self,
query: &[A; K],
qty: usize,
box_size: &[A; K],
) -> Vec<NearestNeighbour<A, T>>
where
D: DistanceMetric<A, K>,
T: std::hash::Hash + Eq,
{
// Nothing was asked for: skip validation and the queries entirely.
if qty == 0 {
return Vec::new();
}

// Reject non-positive periods before any query is run.
box_size.iter().for_each(|axis_len| {
assert!(
*axis_len > A::zero(),
"periodic box sizes must be strictly positive"
);
});

// Scratch copy of the query that the recursion overwrites axis by axis.
let mut wrapped_query = *query;
// Smallest distance seen so far for each item, across all wrapped images.
let mut best_by_item: HashMap<T, A> = HashMap::new();

self.nearest_n_periodic_recurse::<D>(
query,
qty,
box_size,
0,
&mut wrapped_query,
&mut best_by_item,
);

// Turn the per-item minima into the public result type.
let mut results: Vec<_> = best_by_item
.into_iter()
.map(|(item, distance)| NearestNeighbour { distance, item })
.collect();

// `unwrap` panics when `partial_cmp` returns `None`, i.e. for incomparable distances.
results.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap());

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can get better performance here by only sorting the results as far as qty. We don't care if the items beyond qty are sorted - only that their distance is > the closest quantity items. Also the unstable sort variants are more appropriate here as we don't care about preserving order of items with the same distance from the raw unsorted results.

You can do this by using a variant of select_nth_unstable. This ensures that the item in position n in the result is in it's final sorted position.

Older versions of select_nth_unstable had the side-effect that the "closer" side of the result array was also sorted after select_nth_unstable exited. I'm not 100% sure that this is still the case with the new ipnsort algorithm that Rust's standard library uses, so we may need to sort the "closer" side again after applying select_nth_unstable.

Something like this should be an improvement:

results.select_nth_unstable_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap());

results.truncate(qty);

// this step may not be needed if select_nth_unstable_by
// has the side-effect of sorting the closer side
results.sort_unstable_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap());

results

results.truncate(qty);
results
}

/// Enumerates the wrapped images of `query` one axis at a time and collects results.
///
/// For the axis at index `axis`, the coordinate in `wrapped_query` is set in turn to
/// `query[axis] - box_size[axis]`, `query[axis]` and `query[axis] + box_size[axis]`,
/// recursing into the next axis after each assignment. When `axis == K` every
/// coordinate has been chosen, so `nearest_n` is run on `wrapped_query` and each
/// returned candidate is merged into `best_by_item`, which keeps the smallest
/// distance recorded for every item.
///
/// `wrapped_query[axis]` is reset to `query[axis]` before the function returns.
fn nearest_n_periodic_recurse<D>(
&self,
query: &[A; K],
qty: usize,
box_size: &[A; K],
axis: usize,
wrapped_query: &mut [A; K],
best_by_item: &mut HashMap<T, A>,
) where
D: DistanceMetric<A, K>,
T: std::hash::Hash + Eq,
{
// Base case: all K coordinates are fixed, so query this image and merge its results.
if axis == K {
for candidate in self.nearest_n::<D>(wrapped_query, qty) {
best_by_item

@sdd sdd Apr 20, 2026

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect that this HashMap based approach is causing some performance issues here.

An alternative we could try here would be:

  • perform all 3^K queries but keep them as-is, ie simple lists of NearestNeighbour objects
  • concatenate the lists and then sort unstable by item and then dist, so we get all results from repeated items adjacent to each other, with multiple results for the same item in ascending distance order
  • walk across the vec, copying later items on top of earlier items to eliminate the more distant images for the same item. Something like this (I've not tested this, or even tried to compile it!):
results.sort_unstable_by_key(|a|a.item); // EDIT: this closure needs to use sort_by, not sort_by_key, and to cmp by distance when a.item == b.item

let mut from_idx = 1;
let mut to_idx = 1;
let curr_item = result[0].item;
while(from_idx < result.len()) {
    if result[to_idx].item != curr_item {
        result[from_idx] = result[to_idx];
        curr_item = result[to_idx].item;
        to_idx += 1;
    }
    from_idx += 1;
}
results.truncate(to_idx);

Then try the same as the previous comment:

results.select_nth_unstable_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap());

results.truncate(qty);

// this step may not be needed if select_nth_unstable_by
// has the side-effect of sorting the closer side
results.sort_unstable_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap());

results

I don't know for certain but my hunch is that this would be faster than the hashmap-based approach. Certainly worth a try.

.entry(candidate.item)
.and_modify(|best_distance| {
if candidate.distance < *best_distance {
*best_distance = candidate.distance;
}
})
.or_insert(candidate.distance);
}
return;
}

// Unshifted coordinate and period for the axis handled at this depth.
let original = query[axis];
let axis_len = box_size[axis];

// Image shifted down by one period along this axis.
wrapped_query[axis] = original - axis_len;
self.nearest_n_periodic_recurse::<D>(
query,
qty,
box_size,
axis + 1,
wrapped_query,
best_by_item,
);

// Unshifted image along this axis.
wrapped_query[axis] = original;
self.nearest_n_periodic_recurse::<D>(
query,
qty,
box_size,
axis + 1,
wrapped_query,
best_by_item,
);

// Image shifted up by one period along this axis.
wrapped_query[axis] = original + axis_len;
self.nearest_n_periodic_recurse::<D>(
query,
qty,
box_size,
axis + 1,
wrapped_query,
best_by_item,
);

// Leave the scratch buffer as it was found for the caller's next iteration.
wrapped_query[axis] = original;
}
}

#[cfg(feature = "rkyv")]
Expand Down Expand Up @@ -97,6 +211,7 @@
mod tests {
use crate::float::distance::SquaredEuclidean;
use crate::float::kdtree::{Axis, KdTree};
use crate::nearest_neighbour::NearestNeighbour;
use crate::traits::DistanceMetric;
use rand::Rng;

Expand Down Expand Up @@ -202,6 +317,65 @@
}
}

/// Checks that `nearest_n_periodic` finds the two closest items when both are
/// only close to the query across the periodic boundary of the first axis.
#[test]
fn can_query_nearest_n_item_with_periodic_boundaries() {
    let mut tree: KdTree<f64, u32, 2, 8, u32> = KdTree::new();

    [
        ([0.95f64, 0.50f64], 1),
        ([0.92f64, 0.55f64], 2),
        ([0.40f64, 0.50f64], 3),
        ([0.10f64, 0.10f64], 4),
    ]
    .into_iter()
    .for_each(|(point, item)| tree.add(&point, item));

    let probe = [0.05f64, 0.50f64];
    let periods = [1.0f64, 1.0f64];

    let neighbours = tree.nearest_n_periodic::<SquaredEuclidean>(&probe, 2, &periods);

    // Expected (squared distance, item) pairs, closest first.
    let expected = [(0.01f64, 1), (0.0194f64, 2)];

    assert_eq!(neighbours.len(), 2);
    for (index, (distance, item)) in expected.into_iter().enumerate() {
        assert!((neighbours[index].distance - distance).abs() < f64::EPSILON);
        assert_eq!(neighbours[index].item, item);
    }
}

/// Compares `nearest_n_periodic` against the brute-force `linear_search_periodic`
/// for randomly generated 3-D points and query locations in a unit periodic box.
#[test]
fn can_query_nearest_n_item_with_periodic_boundaries_large_scale() {
const TREE_SIZE: usize = 10_000;
const NUM_QUERIES: usize = 200;
const N: usize = 7;

// NOTE(review): item ids are random `u32` values, so two points can share an id.
// `nearest_n_periodic` keeps one entry per item while `linear_search_periodic`
// does not dedupe, so a collision among the nearest results could make this
// comparison fail - confirm the risk is acceptable or use unique ids.
let content_to_add: Vec<([f32; 3], u32)> = (0..TREE_SIZE)
.map(|_| rand::random::<([f32; 3], u32)>())
.collect();

let mut tree: KdTree<f32, u32, 3, 32, u32> = KdTree::with_capacity(TREE_SIZE);
content_to_add
.iter()
.for_each(|(point, content)| tree.add(point, *content));

// Every axis has period 1.
let box_size = [1.0f32, 1.0f32, 1.0f32];
let query_points: Vec<[f32; 3]> = (0..NUM_QUERIES)
.map(|_| rand::random::<[f32; 3]>())
.collect();

Check warning on line 364 in src/float/query/nearest_n.rs

View workflow job for this annotation

GitHub Actions / Formatting

Diff in /home/runner/work/kiddo/kiddo/src/float/query/nearest_n.rs

for query_point in query_points {
let expected =
linear_search_periodic(&content_to_add, N, &query_point, &box_size);
let result = tree.nearest_n_periodic::<SquaredEuclidean>(&query_point, N, &box_size);

// Results must agree position by position: same item, distance within tolerance.
assert_eq!(result.len(), expected.len());
for (actual, expected) in result.iter().zip(expected.iter()) {
assert!((actual.distance - expected.distance).abs() < 1e-5);
assert_eq!(actual.item, expected.item);
}
}
}

fn linear_search<A: Axis, const K: usize>(
content: &[([A; K], u32)],
qty: usize,
Expand All @@ -222,4 +396,42 @@

results
}

/// Brute-force reference for periodic nearest-neighbour queries.
///
/// Scans every `(point, item)` pair in `content`, measures its wrapped squared
/// distance to `query_point` with `periodic_dist`, and keeps the `qty` closest
/// candidates sorted by ascending distance.
///
/// Returns an empty `Vec` when `qty` is `0`, matching `nearest_n_periodic`.
/// Without that guard the `qty - 1` index below would underflow.
///
/// # Panics
///
/// Panics if two distances cannot be ordered by `partial_cmp` (for example NaN).
fn linear_search_periodic<A: Axis, const K: usize>(
    content: &[([A; K], u32)],
    qty: usize,
    query_point: &[A; K],
    box_size: &[A; K],
) -> Vec<NearestNeighbour<A, u32>> {
    if qty == 0 {
        return Vec::new();
    }

    let mut results = vec![];

    for &(point, item) in content {
        let distance = periodic_dist::<A, K>(query_point, &point, box_size);
        let candidate = NearestNeighbour { distance, item };

        if results.len() < qty {
            // Still filling up: accept every candidate.
            results.push(candidate);
            results.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap());
        } else if distance < results[qty - 1].distance {
            // Full: replace the current worst entry, then restore the ordering.
            results[qty - 1] = candidate;
            results.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap());
        }
    }

    results
}

/// Squared Euclidean distance between `query` and `point` in a periodic box.
///
/// On each axis the separation is the smaller of the direct absolute difference
/// and `box_size[axis]` minus that difference; the squares are summed over all axes.
fn periodic_dist<A: Axis, const K: usize>(
    query: &[A; K],
    point: &[A; K],
    box_size: &[A; K],
) -> A {
    let mut total = A::zero();

    for axis in 0..K {
        let direct = (query[axis] - point[axis]).abs();
        let around = box_size[axis] - direct;
        let shortest = direct.min(around);
        total = total + shortest * shortest;
    }

    total
}
}
Loading
Loading