-
Notifications
You must be signed in to change notification settings - Fork 25
feat: add periodic boundary condition queries for float trees #297
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
5736b54
31002da
b07a909
17fed87
f064893
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,37 @@ | ||
| use kiddo::immutable::float::kdtree::ImmutableKdTree; | ||
| use kiddo::SquaredEuclidean; | ||
|
|
||
| fn main() { | ||
| let points = vec![[0.95, 0.50], [0.92, 0.55], [0.40, 0.50], [0.10, 0.10]]; | ||
| let tree: ImmutableKdTree<f64, u32, 2, 8> = ImmutableKdTree::new_from_slice(&points); | ||
|
|
||
| let query = [0.05, 0.50]; | ||
| let box_size = [1.0, 1.0]; | ||
| let radius = 0.03; | ||
|
|
||
| let nearest = tree.nearest_one_periodic::<SquaredEuclidean>(&query, &box_size); | ||
| println!("nearest_one_periodic -> {:?}", nearest); | ||
|
|
||
| let nearest_n = tree.nearest_n_periodic::<SquaredEuclidean>( | ||
| &query, | ||
| std::num::NonZero::new(2).unwrap(), | ||
| &box_size, | ||
| ); | ||
| println!("nearest_n_periodic -> {:?}", nearest_n); | ||
|
|
||
| let within = tree.within_periodic::<SquaredEuclidean>(&query, radius, &box_size); | ||
| println!("within_periodic -> {:?}", within); | ||
|
|
||
| let within_unsorted = | ||
| tree.within_unsorted_periodic::<SquaredEuclidean>(&query, radius, &box_size); | ||
| println!("within_unsorted_periodic -> {:?}", within_unsorted); | ||
|
|
||
| let nearest_n_within = tree.nearest_n_within_periodic::<SquaredEuclidean>( | ||
| &query, | ||
| radius, | ||
| std::num::NonZero::new(2).unwrap(), | ||
| true, | ||
| &box_size, | ||
| ); | ||
| println!("nearest_n_within_periodic -> {:?}", nearest_n_within); | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,35 @@ | ||
| use kiddo::{KdTree, SquaredEuclidean}; | ||
|
|
||
| fn main() { | ||
| let mut tree: KdTree<f64, 2> = KdTree::new(); | ||
| tree.add(&[0.95, 0.50], 100); | ||
| tree.add(&[0.92, 0.55], 101); | ||
| tree.add(&[0.40, 0.50], 102); | ||
| tree.add(&[0.10, 0.10], 103); | ||
|
|
||
| let query = [0.05, 0.50]; | ||
| let box_size = [1.0, 1.0]; | ||
| let radius = 0.03; | ||
|
|
||
| let nearest = tree.nearest_one_periodic::<SquaredEuclidean>(&query, &box_size); | ||
| println!("nearest_one_periodic -> {:?}", nearest); | ||
|
|
||
| let nearest_n = tree.nearest_n_periodic::<SquaredEuclidean>(&query, 2, &box_size); | ||
| println!("nearest_n_periodic -> {:?}", nearest_n); | ||
|
|
||
| let within = tree.within_periodic::<SquaredEuclidean>(&query, radius, &box_size); | ||
| println!("within_periodic -> {:?}", within); | ||
|
|
||
| let within_unsorted = | ||
| tree.within_unsorted_periodic::<SquaredEuclidean>(&query, radius, &box_size); | ||
| println!("within_unsorted_periodic -> {:?}", within_unsorted); | ||
|
|
||
| let nearest_n_within = tree.nearest_n_within_periodic::<SquaredEuclidean>( | ||
| &query, | ||
| radius, | ||
| std::num::NonZero::new(2).unwrap(), | ||
| true, | ||
| &box_size, | ||
| ); | ||
| println!("nearest_n_within_periodic -> {:?}", nearest_n_within); | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,6 @@ | ||
| use az::{Az, Cast}; | ||
| use std::collections::BinaryHeap; | ||
| use std::collections::HashMap; | ||
| use std::ops::Rem; | ||
|
|
||
| use crate::float::kdtree::{Axis, KdTree}; | ||
|
|
@@ -45,6 +46,119 @@ | |
| tree.add(&[1.0, 2.0, 5.0], 100); | ||
| tree.add(&[2.0, 3.0, 6.0], 101);" | ||
| ); | ||
|
|
||
| /// Finds the nearest `qty` elements to `query` with periodic boundary conditions. | ||
| /// | ||
| /// `box_size` gives the periodic box length for each axis. Query points are expected | ||
| /// to be wrapped into the same principal cell as the points stored in the tree. | ||
| /// | ||
| /// This first implementation checks all `3^K` wrapped query images, merges duplicate | ||
| /// items that arise from different images, and returns the best `qty` unique items. | ||
| #[inline] | ||
| pub fn nearest_n_periodic<D>( | ||
| &self, | ||
| query: &[A; K], | ||
| qty: usize, | ||
| box_size: &[A; K], | ||
| ) -> Vec<NearestNeighbour<A, T>> | ||
| where | ||
| D: DistanceMetric<A, K>, | ||
| T: std::hash::Hash + Eq, | ||
| { | ||
| if qty == 0 { | ||
| return Vec::new(); | ||
| } | ||
|
|
||
| box_size.iter().for_each(|axis_len| { | ||
| assert!( | ||
| *axis_len > A::zero(), | ||
| "periodic box sizes must be strictly positive" | ||
| ); | ||
| }); | ||
|
|
||
| let mut wrapped_query = *query; | ||
| let mut best_by_item: HashMap<T, A> = HashMap::new(); | ||
|
|
||
| self.nearest_n_periodic_recurse::<D>( | ||
| query, | ||
| qty, | ||
| box_size, | ||
| 0, | ||
| &mut wrapped_query, | ||
| &mut best_by_item, | ||
| ); | ||
|
|
||
| let mut results: Vec<_> = best_by_item | ||
| .into_iter() | ||
| .map(|(item, distance)| NearestNeighbour { distance, item }) | ||
| .collect(); | ||
|
|
||
| results.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap()); | ||
| results.truncate(qty); | ||
| results | ||
| } | ||
|
|
||
| fn nearest_n_periodic_recurse<D>( | ||
| &self, | ||
| query: &[A; K], | ||
| qty: usize, | ||
| box_size: &[A; K], | ||
| axis: usize, | ||
| wrapped_query: &mut [A; K], | ||
| best_by_item: &mut HashMap<T, A>, | ||
| ) where | ||
| D: DistanceMetric<A, K>, | ||
| T: std::hash::Hash + Eq, | ||
| { | ||
| if axis == K { | ||
| for candidate in self.nearest_n::<D>(wrapped_query, qty) { | ||
| best_by_item | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I suspect that this HashMap-based approach is causing some performance issues here. An alternative we could try here would be:
results.sort_unstable_by_key(|a|a.item); // EDIT: this closure needs to use sort_by, not sort_by_key, and to cmp by distance when a.item == b.item
let mut from_idx = 1;
let mut to_idx = 1;
let curr_item = result[0].item;
while(from_idx < result.len()) {
if result[to_idx].item != curr_item {
result[from_idx] = result[to_idx];
curr_item = result[to_idx].item;
to_idx += 1;
}
from_idx += 1;
}
results.truncate(to_idx);

Then try the same as the previous comment:

results.select_nth_unstable_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap());
results.truncate(qty);
// this step may not be needed if select_nth_unstable_by
// has the side-effect of sorting the closer side
results.sort_unstable_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap());
results

I don't know for certain, but my hunch is that this would be faster than the HashMap-based approach. Certainly worth a try. |
||
| .entry(candidate.item) | ||
| .and_modify(|best_distance| { | ||
| if candidate.distance < *best_distance { | ||
| *best_distance = candidate.distance; | ||
| } | ||
| }) | ||
| .or_insert(candidate.distance); | ||
| } | ||
| return; | ||
| } | ||
|
|
||
| let original = query[axis]; | ||
| let axis_len = box_size[axis]; | ||
|
|
||
| wrapped_query[axis] = original - axis_len; | ||
| self.nearest_n_periodic_recurse::<D>( | ||
| query, | ||
| qty, | ||
| box_size, | ||
| axis + 1, | ||
| wrapped_query, | ||
| best_by_item, | ||
| ); | ||
|
|
||
| wrapped_query[axis] = original; | ||
| self.nearest_n_periodic_recurse::<D>( | ||
| query, | ||
| qty, | ||
| box_size, | ||
| axis + 1, | ||
| wrapped_query, | ||
| best_by_item, | ||
| ); | ||
|
|
||
| wrapped_query[axis] = original + axis_len; | ||
| self.nearest_n_periodic_recurse::<D>( | ||
| query, | ||
| qty, | ||
| box_size, | ||
| axis + 1, | ||
| wrapped_query, | ||
| best_by_item, | ||
| ); | ||
|
|
||
| wrapped_query[axis] = original; | ||
| } | ||
| } | ||
|
|
||
| #[cfg(feature = "rkyv")] | ||
|
|
@@ -97,6 +211,7 @@ | |
| mod tests { | ||
| use crate::float::distance::SquaredEuclidean; | ||
| use crate::float::kdtree::{Axis, KdTree}; | ||
| use crate::nearest_neighbour::NearestNeighbour; | ||
| use crate::traits::DistanceMetric; | ||
| use rand::Rng; | ||
|
|
||
|
|
@@ -202,6 +317,65 @@ | |
| } | ||
| } | ||
|
|
||
| #[test] | ||
| fn can_query_nearest_n_item_with_periodic_boundaries() { | ||
| let mut tree: KdTree<f64, u32, 2, 8, u32> = KdTree::new(); | ||
| let content_to_add = [ | ||
| ([0.95f64, 0.50f64], 1), | ||
| ([0.92f64, 0.55f64], 2), | ||
| ([0.40f64, 0.50f64], 3), | ||
| ([0.10f64, 0.10f64], 4), | ||
| ]; | ||
|
|
||
| for (point, item) in content_to_add { | ||
| tree.add(&point, item); | ||
| } | ||
|
|
||
| let query_point = [0.05f64, 0.50f64]; | ||
| let box_size = [1.0f64, 1.0f64]; | ||
|
|
||
| let result = tree.nearest_n_periodic::<SquaredEuclidean>(&query_point, 2, &box_size); | ||
|
|
||
| assert_eq!(result.len(), 2); | ||
| assert!((result[0].distance - 0.01f64).abs() < f64::EPSILON); | ||
| assert_eq!(result[0].item, 1); | ||
| assert!((result[1].distance - 0.0194f64).abs() < f64::EPSILON); | ||
| assert_eq!(result[1].item, 2); | ||
| } | ||
|
|
||
| #[test] | ||
| fn can_query_nearest_n_item_with_periodic_boundaries_large_scale() { | ||
| const TREE_SIZE: usize = 10_000; | ||
| const NUM_QUERIES: usize = 200; | ||
| const N: usize = 7; | ||
|
|
||
| let content_to_add: Vec<([f32; 3], u32)> = (0..TREE_SIZE) | ||
| .map(|_| rand::random::<([f32; 3], u32)>()) | ||
| .collect(); | ||
|
|
||
| let mut tree: KdTree<f32, u32, 3, 32, u32> = KdTree::with_capacity(TREE_SIZE); | ||
| content_to_add | ||
| .iter() | ||
| .for_each(|(point, content)| tree.add(point, *content)); | ||
|
|
||
| let box_size = [1.0f32, 1.0f32, 1.0f32]; | ||
| let query_points: Vec<[f32; 3]> = (0..NUM_QUERIES) | ||
| .map(|_| rand::random::<[f32; 3]>()) | ||
| .collect(); | ||
|
|
||
| for query_point in query_points { | ||
| let expected = | ||
| linear_search_periodic(&content_to_add, N, &query_point, &box_size); | ||
| let result = tree.nearest_n_periodic::<SquaredEuclidean>(&query_point, N, &box_size); | ||
|
|
||
| assert_eq!(result.len(), expected.len()); | ||
| for (actual, expected) in result.iter().zip(expected.iter()) { | ||
| assert!((actual.distance - expected.distance).abs() < 1e-5); | ||
| assert_eq!(actual.item, expected.item); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| fn linear_search<A: Axis, const K: usize>( | ||
| content: &[([A; K], u32)], | ||
| qty: usize, | ||
|
|
@@ -222,4 +396,42 @@ | |
|
|
||
| results | ||
| } | ||
|
|
||
| fn linear_search_periodic<A: Axis, const K: usize>( | ||
| content: &[([A; K], u32)], | ||
| qty: usize, | ||
| query_point: &[A; K], | ||
| box_size: &[A; K], | ||
| ) -> Vec<NearestNeighbour<A, u32>> { | ||
| let mut results = vec![]; | ||
|
|
||
| for &(point, item) in content { | ||
| let distance = periodic_dist::<A, K>(query_point, &point, box_size); | ||
| let candidate = NearestNeighbour { distance, item }; | ||
|
|
||
| if results.len() < qty { | ||
| results.push(candidate); | ||
| results.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap()); | ||
| } else if distance < results[qty - 1].distance { | ||
| results[qty - 1] = candidate; | ||
| results.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap()); | ||
| } | ||
| } | ||
|
|
||
| results | ||
| } | ||
|
|
||
| fn periodic_dist<A: Axis, const K: usize>( | ||
| query: &[A; K], | ||
| point: &[A; K], | ||
| box_size: &[A; K], | ||
| ) -> A { | ||
| (0..K) | ||
| .map(|axis| { | ||
| let diff = (query[axis] - point[axis]).abs(); | ||
| let wrapped_diff = diff.min(box_size[axis] - diff); | ||
| wrapped_diff * wrapped_diff | ||
| }) | ||
| .fold(A::zero(), std::ops::Add::add) | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can get better performance here by only sorting the results as far as `qty`. We don't care if the items beyond `qty` are sorted - only that their `distance` is > the closest `qty` items. Also, the `unstable` sort variants are more appropriate here, as we don't care about preserving the order of items with the same distance from the raw unsorted results.

You can do this by using a variant of `select_nth_unstable`. This ensures that the item in position n in the result is in its final sorted position.

Older versions of `select_nth_unstable` had the side-effect that the "closer" side of the result array was also sorted after `select_nth_unstable` exited. I'm not 100% sure that this is still the case with the new ipnsort algorithm that Rust's standard library uses, so we may need to sort the "closer" side again after applying `select_nth_unstable`.

Something like this should be an improvement: