diff --git a/README.md b/README.md index 32b5dd41..62e91405 100644 --- a/README.md +++ b/README.md @@ -62,6 +62,45 @@ assert_eq!( vec![(0f64, 0), (2f64, 1), (8f64, 2)] ); ``` + +## Periodic Boundary Conditions + +Kiddo supports periodic boundary conditions (PBCs) for float `KdTree` and `ImmutableKdTree` queries. Currently, periodic queries have a considerable performance penalty compared to non-periodic queries. + +Periodic queries take a `box_size` argument, where each entry is the period of one axis. + +```rust +use kiddo::{KdTree, SquaredEuclidean}; + +let mut tree: KdTree<f64, 2> = KdTree::new(); +tree.add(&[0.95, 0.50], 100); +tree.add(&[0.40, 0.50], 101); + +let query = [0.05, 0.50]; +let box_size = [1.0, 1.0]; + +let nearest = tree.nearest_one_periodic::<SquaredEuclidean>(&query, &box_size); +assert_eq!(nearest.item, 100); +assert!((nearest.distance - 0.01).abs() < f64::EPSILON); +``` + +Available periodic query methods: +* `nearest_one_periodic` +* `nearest_n_periodic` +* `within_periodic` +* `within_unsorted_periodic` +* `nearest_n_within_periodic` + +The same methods are available on `ImmutableKdTree`. Note that its `nearest_n_periodic` takes the quantity as a `NonZero<usize>` rather than a `usize`. + +Notes: +* `box_size[i]` must be strictly positive for every axis +* points are expected to be stored in a principal cell such as `[0, box_size[i])` +* queries should also be supplied in that same cell +* the current implementation runs one query per wrapped image of the query point (`3^K` images in total), so the cost grows with the dimension `K` as `3^K` + +See [examples/periodic-boundaries.rs](./examples/periodic-boundaries.rs) and [examples/immutable-periodic-boundaries.rs](./examples/immutable-periodic-boundaries.rs). + See the [examples documentation](https://github.com/sdd/kiddo/tree/master/examples) for some more detailed examples. 
## Optional Features diff --git a/examples/immutable-periodic-boundaries.rs b/examples/immutable-periodic-boundaries.rs new file mode 100644 index 00000000..614965c3 --- /dev/null +++ b/examples/immutable-periodic-boundaries.rs @@ -0,0 +1,37 @@ +use kiddo::immutable::float::kdtree::ImmutableKdTree; +use kiddo::SquaredEuclidean; + +fn main() { + let points = vec![[0.95, 0.50], [0.92, 0.55], [0.40, 0.50], [0.10, 0.10]]; + let tree: ImmutableKdTree = ImmutableKdTree::new_from_slice(&points); + + let query = [0.05, 0.50]; + let box_size = [1.0, 1.0]; + let radius = 0.03; + + let nearest = tree.nearest_one_periodic::(&query, &box_size); + println!("nearest_one_periodic -> {:?}", nearest); + + let nearest_n = tree.nearest_n_periodic::( + &query, + std::num::NonZero::new(2).unwrap(), + &box_size, + ); + println!("nearest_n_periodic -> {:?}", nearest_n); + + let within = tree.within_periodic::(&query, radius, &box_size); + println!("within_periodic -> {:?}", within); + + let within_unsorted = + tree.within_unsorted_periodic::(&query, radius, &box_size); + println!("within_unsorted_periodic -> {:?}", within_unsorted); + + let nearest_n_within = tree.nearest_n_within_periodic::( + &query, + radius, + std::num::NonZero::new(2).unwrap(), + true, + &box_size, + ); + println!("nearest_n_within_periodic -> {:?}", nearest_n_within); +} diff --git a/examples/periodic-boundaries.rs b/examples/periodic-boundaries.rs new file mode 100644 index 00000000..272216e2 --- /dev/null +++ b/examples/periodic-boundaries.rs @@ -0,0 +1,35 @@ +use kiddo::{KdTree, SquaredEuclidean}; + +fn main() { + let mut tree: KdTree = KdTree::new(); + tree.add(&[0.95, 0.50], 100); + tree.add(&[0.92, 0.55], 101); + tree.add(&[0.40, 0.50], 102); + tree.add(&[0.10, 0.10], 103); + + let query = [0.05, 0.50]; + let box_size = [1.0, 1.0]; + let radius = 0.03; + + let nearest = tree.nearest_one_periodic::(&query, &box_size); + println!("nearest_one_periodic -> {:?}", nearest); + + let nearest_n = 
tree.nearest_n_periodic::(&query, 2, &box_size); + println!("nearest_n_periodic -> {:?}", nearest_n); + + let within = tree.within_periodic::(&query, radius, &box_size); + println!("within_periodic -> {:?}", within); + + let within_unsorted = + tree.within_unsorted_periodic::(&query, radius, &box_size); + println!("within_unsorted_periodic -> {:?}", within_unsorted); + + let nearest_n_within = tree.nearest_n_within_periodic::( + &query, + radius, + std::num::NonZero::new(2).unwrap(), + true, + &box_size, + ); + println!("nearest_n_within_periodic -> {:?}", nearest_n_within); +} diff --git a/src/float/query/nearest_n.rs b/src/float/query/nearest_n.rs index 85981aa1..fa08fa1c 100644 --- a/src/float/query/nearest_n.rs +++ b/src/float/query/nearest_n.rs @@ -1,5 +1,6 @@ use az::{Az, Cast}; use std::collections::BinaryHeap; +use std::collections::HashMap; use std::ops::Rem; use crate::float::kdtree::{Axis, KdTree}; @@ -45,6 +46,119 @@ where tree.add(&[1.0, 2.0, 5.0], 100); tree.add(&[2.0, 3.0, 6.0], 101);" ); + + /// Finds the nearest `qty` elements to `query` with periodic boundary conditions. + /// + /// `box_size` gives the periodic box length for each axis. Query points are expected + /// to be wrapped into the same principal cell as the points stored in the tree. + /// + /// This first implementation checks all `3^K` wrapped query images, merges duplicate + /// items that arise from different images, and returns the best `qty` unique items. 
+ #[inline] + pub fn nearest_n_periodic( + &self, + query: &[A; K], + qty: usize, + box_size: &[A; K], + ) -> Vec> + where + D: DistanceMetric, + T: std::hash::Hash + Eq, + { + if qty == 0 { + return Vec::new(); + } + + box_size.iter().for_each(|axis_len| { + assert!( + *axis_len > A::zero(), + "periodic box sizes must be strictly positive" + ); + }); + + let mut wrapped_query = *query; + let mut best_by_item: HashMap = HashMap::new(); + + self.nearest_n_periodic_recurse::( + query, + qty, + box_size, + 0, + &mut wrapped_query, + &mut best_by_item, + ); + + let mut results: Vec<_> = best_by_item + .into_iter() + .map(|(item, distance)| NearestNeighbour { distance, item }) + .collect(); + + results.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap()); + results.truncate(qty); + results + } + + fn nearest_n_periodic_recurse( + &self, + query: &[A; K], + qty: usize, + box_size: &[A; K], + axis: usize, + wrapped_query: &mut [A; K], + best_by_item: &mut HashMap, + ) where + D: DistanceMetric, + T: std::hash::Hash + Eq, + { + if axis == K { + for candidate in self.nearest_n::(wrapped_query, qty) { + best_by_item + .entry(candidate.item) + .and_modify(|best_distance| { + if candidate.distance < *best_distance { + *best_distance = candidate.distance; + } + }) + .or_insert(candidate.distance); + } + return; + } + + let original = query[axis]; + let axis_len = box_size[axis]; + + wrapped_query[axis] = original - axis_len; + self.nearest_n_periodic_recurse::( + query, + qty, + box_size, + axis + 1, + wrapped_query, + best_by_item, + ); + + wrapped_query[axis] = original; + self.nearest_n_periodic_recurse::( + query, + qty, + box_size, + axis + 1, + wrapped_query, + best_by_item, + ); + + wrapped_query[axis] = original + axis_len; + self.nearest_n_periodic_recurse::( + query, + qty, + box_size, + axis + 1, + wrapped_query, + best_by_item, + ); + + wrapped_query[axis] = original; + } } #[cfg(feature = "rkyv")] @@ -97,6 +211,7 @@ where mod tests { use 
crate::float::distance::SquaredEuclidean; use crate::float::kdtree::{Axis, KdTree}; + use crate::nearest_neighbour::NearestNeighbour; use crate::traits::DistanceMetric; use rand::Rng; @@ -202,6 +317,65 @@ mod tests { } } + #[test] + fn can_query_nearest_n_item_with_periodic_boundaries() { + let mut tree: KdTree = KdTree::new(); + let content_to_add = [ + ([0.95f64, 0.50f64], 1), + ([0.92f64, 0.55f64], 2), + ([0.40f64, 0.50f64], 3), + ([0.10f64, 0.10f64], 4), + ]; + + for (point, item) in content_to_add { + tree.add(&point, item); + } + + let query_point = [0.05f64, 0.50f64]; + let box_size = [1.0f64, 1.0f64]; + + let result = tree.nearest_n_periodic::(&query_point, 2, &box_size); + + assert_eq!(result.len(), 2); + assert!((result[0].distance - 0.01f64).abs() < f64::EPSILON); + assert_eq!(result[0].item, 1); + assert!((result[1].distance - 0.0194f64).abs() < f64::EPSILON); + assert_eq!(result[1].item, 2); + } + + #[test] + fn can_query_nearest_n_item_with_periodic_boundaries_large_scale() { + const TREE_SIZE: usize = 10_000; + const NUM_QUERIES: usize = 200; + const N: usize = 7; + + let content_to_add: Vec<([f32; 3], u32)> = (0..TREE_SIZE) + .map(|_| rand::random::<([f32; 3], u32)>()) + .collect(); + + let mut tree: KdTree = KdTree::with_capacity(TREE_SIZE); + content_to_add + .iter() + .for_each(|(point, content)| tree.add(point, *content)); + + let box_size = [1.0f32, 1.0f32, 1.0f32]; + let query_points: Vec<[f32; 3]> = (0..NUM_QUERIES) + .map(|_| rand::random::<[f32; 3]>()) + .collect(); + + for query_point in query_points { + let expected = + linear_search_periodic(&content_to_add, N, &query_point, &box_size); + let result = tree.nearest_n_periodic::(&query_point, N, &box_size); + + assert_eq!(result.len(), expected.len()); + for (actual, expected) in result.iter().zip(expected.iter()) { + assert!((actual.distance - expected.distance).abs() < 1e-5); + assert_eq!(actual.item, expected.item); + } + } + } + fn linear_search( content: &[([A; K], u32)], qty: usize, 
@@ -222,4 +396,42 @@ mod tests { results } + + fn linear_search_periodic( + content: &[([A; K], u32)], + qty: usize, + query_point: &[A; K], + box_size: &[A; K], + ) -> Vec> { + let mut results = vec![]; + + for &(point, item) in content { + let distance = periodic_dist::(query_point, &point, box_size); + let candidate = NearestNeighbour { distance, item }; + + if results.len() < qty { + results.push(candidate); + results.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap()); + } else if distance < results[qty - 1].distance { + results[qty - 1] = candidate; + results.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap()); + } + } + + results + } + + fn periodic_dist( + query: &[A; K], + point: &[A; K], + box_size: &[A; K], + ) -> A { + (0..K) + .map(|axis| { + let diff = (query[axis] - point[axis]).abs(); + let wrapped_diff = diff.min(box_size[axis] - diff); + wrapped_diff * wrapped_diff + }) + .fold(A::zero(), std::ops::Add::add) + } } diff --git a/src/float/query/nearest_n_within.rs b/src/float/query/nearest_n_within.rs index c0d22ea2..40d3131a 100644 --- a/src/float/query/nearest_n_within.rs +++ b/src/float/query/nearest_n_within.rs @@ -1,6 +1,7 @@ use az::{Az, Cast}; use sorted_vec::SortedVec; use std::collections::BinaryHeap; +use std::collections::HashMap; use std::ops::Rem; use crate::float::kdtree::{Axis, KdTree}; @@ -51,6 +52,124 @@ let mut tree: KdTree = KdTree::new(); tree.add(&[1.0, 2.0, 5.0], 100); tree.add(&[2.0, 3.0, 6.0], 101);" ); + + /// Finds up to `max_items` elements within `dist` of `query` with periodic boundary conditions. + /// + /// `box_size` gives the periodic box length for each axis. Query points are expected + /// to be wrapped into the same principal cell as the points stored in the tree. + /// + /// This first implementation checks all `3^K` wrapped query images and merges duplicate + /// items that can arise from multiple images. 
+    #[inline]
+    pub fn nearest_n_within_periodic<D>(
+        &self,
+        query: &[A; K],
+        dist: A,
+        max_items: std::num::NonZero<usize>,
+        sorted: bool,
+        box_size: &[A; K],
+    ) -> Vec<NearestNeighbour<A, T>>
+    where
+        D: DistanceMetric<A, K>,
+        T: std::hash::Hash + Eq,
+    {
+        // Precondition: a zero or negative period would make the wrapped images meaningless.
+        box_size.iter().for_each(|axis_len| {
+            assert!(
+                *axis_len > A::zero(),
+                "periodic box sizes must be strictly positive"
+            );
+        });
+
+        // Scratch copy of the query that the recursion shifts by -L / 0 / +L per axis.
+        let mut wrapped_query = *query;
+        // Best (smallest) distance seen so far for each item across all images.
+        let mut best_by_item: HashMap<T, A> = HashMap::new();
+
+        self.nearest_n_within_periodic_recurse::<D>(
+            query,
+            dist,
+            max_items,
+            box_size,
+            0,
+            &mut wrapped_query,
+            &mut best_by_item,
+        );
+
+        let mut results: Vec<_> = best_by_item
+            .into_iter()
+            .map(|(item, distance)| NearestNeighbour { distance, item })
+            .collect();
+
+        // The map yields entries in arbitrary order, and the images together can
+        // contribute more than `max_items` candidates. Order by distance whenever
+        // truncation will drop entries, otherwise an unsorted call could discard
+        // nearer items in favour of farther ones.
+        if sorted || results.len() > max_items.get() {
+            results.sort();
+        }
+        results.truncate(max_items.get());
+        results
+    }
+
+    /// Recursively enumerates the wrapped query images, one axis at a time.
+    ///
+    /// For `axis < K` the query coordinate on that axis is shifted by `-box_size[axis]`,
+    /// left unchanged, and shifted by `+box_size[axis]`, recursing into the next axis each
+    /// time. Once every axis has been assigned (`axis == K`) a regular `nearest_n_within`
+    /// query is run for the image and its results are merged into `best_by_item`, keeping
+    /// the smallest distance seen for each item.
+    fn nearest_n_within_periodic_recurse<D>(
+        &self,
+        query: &[A; K],
+        dist: A,
+        max_items: std::num::NonZero<usize>,
+        box_size: &[A; K],
+        axis: usize,
+        wrapped_query: &mut [A; K],
+        best_by_item: &mut HashMap<T, A>,
+    ) where
+        D: DistanceMetric<A, K>,
+        T: std::hash::Hash + Eq,
+    {
+        if axis == K {
+            for candidate in self.nearest_n_within::<D>(wrapped_query, dist, max_items, false) {
+                best_by_item
+                    .entry(candidate.item)
+                    .and_modify(|best_distance| {
+                        if candidate.distance < *best_distance {
+                            *best_distance = candidate.distance;
+                        }
+                    })
+                    .or_insert(candidate.distance);
+            }
+            return;
+        }
+
+        let original = query[axis];
+        let axis_len = box_size[axis];
+
+        wrapped_query[axis] = original - axis_len;
+        self.nearest_n_within_periodic_recurse::<D>(
+            query,
+            dist,
+            max_items,
+            box_size,
+            axis + 1,
+            wrapped_query,
+            best_by_item,
+        );
+
+        wrapped_query[axis] = original;
+        self.nearest_n_within_periodic_recurse::<D>(
+            query,
+            dist,
+            max_items,
+            box_size,
+            axis + 1,
+            wrapped_query,
+            best_by_item,
+        );
+
+        wrapped_query[axis] = original + axis_len;
+        self.nearest_n_within_periodic_recurse::<D>(
+            query,
+            dist,
+            max_items,
+            box_size,
+            axis + 1,
+            wrapped_query,
+            best_by_item,
+
); + + wrapped_query[axis] = original; + } } #[cfg(feature = "rkyv")] @@ -103,6 +222,7 @@ where mod tests { use crate::float::distance::SquaredEuclidean; use crate::float::kdtree::{Axis, KdTree}; + use crate::nearest_neighbour::NearestNeighbour; use crate::traits::DistanceMetric; use rand::Rng; use std::cmp::Ordering; @@ -344,6 +464,114 @@ mod tests { } } + #[test] + fn can_query_nearest_n_items_within_periodic_boundaries() { + let mut tree: KdTree = KdTree::new(); + let content_to_add = [ + ([0.95f64, 0.50f64], 1), + ([0.92f64, 0.55f64], 2), + ([0.40f64, 0.50f64], 3), + ([0.10f64, 0.10f64], 4), + ]; + + for (point, item) in content_to_add { + tree.add(&point, item); + } + + let query_point = [0.05f64, 0.50f64]; + let box_size = [1.0f64, 1.0f64]; + let radius = 0.03f64; + let max_qty = NonZero::new(2).unwrap(); + + let result = tree.nearest_n_within_periodic::( + &query_point, + radius, + max_qty, + true, + &box_size, + ); + + assert_eq!(result.len(), 2); + assert!((result[0].distance - 0.01f64).abs() < f64::EPSILON); + assert_eq!(result[0].item, 1); + assert!((result[1].distance - 0.0194f64).abs() < f64::EPSILON); + assert_eq!(result[1].item, 2); + } + + #[test] + fn can_query_nearest_n_items_within_periodic_boundaries_unsorted() { + let mut tree: KdTree = KdTree::new(); + let content_to_add = [ + ([0.95f64, 0.50f64], 1), + ([0.92f64, 0.55f64], 2), + ([0.40f64, 0.50f64], 3), + ([0.10f64, 0.10f64], 4), + ]; + + for (point, item) in content_to_add { + tree.add(&point, item); + } + + let query_point = [0.05f64, 0.50f64]; + let box_size = [1.0f64, 1.0f64]; + let radius = 0.03f64; + let max_qty = NonZero::new(2).unwrap(); + + let mut result = tree.nearest_n_within_periodic::( + &query_point, + radius, + max_qty, + false, + &box_size, + ); + stabilize_neighbours(&mut result); + + assert_eq!(result.len(), 2); + assert!((result[0].distance - 0.01f64).abs() < f64::EPSILON); + assert_eq!(result[0].item, 1); + assert!((result[1].distance - 0.0194f64).abs() < f64::EPSILON); + 
assert_eq!(result[1].item, 2); + } + + #[test] + fn can_query_nearest_n_items_within_periodic_boundaries_large_scale() { + const TREE_SIZE: usize = 10_000; + const NUM_QUERIES: usize = 200; + const RADIUS: f32 = 0.05; + + let max_qty = NonZero::new(5).unwrap(); + let content_to_add: Vec<([f32; 3], u32)> = (0..TREE_SIZE) + .map(|_| rand::random::<([f32; 3], u32)>()) + .collect(); + + let mut tree: KdTree = KdTree::with_capacity(TREE_SIZE); + content_to_add + .iter() + .for_each(|(point, content)| tree.add(point, *content)); + + let box_size = [1.0f32, 1.0f32, 1.0f32]; + let query_points: Vec<[f32; 3]> = (0..NUM_QUERIES) + .map(|_| rand::random::<[f32; 3]>()) + .collect(); + + for query_point in query_points { + let expected = linear_search_periodic(&content_to_add, &query_point, RADIUS, max_qty, &box_size); + let result = tree.nearest_n_within_periodic::( + &query_point, + RADIUS, + max_qty, + true, + &box_size, + ); + + assert_eq!(result.len(), expected.len()); + for (actual, expected) in result.iter().zip(expected.iter()) { + assert!((actual.distance - expected.distance).abs() < 1e-5); + assert_eq!(actual.item, expected.item); + } + } + } + fn linear_search( content: &[([A; K], u32)], query_point: &[A; K], @@ -363,6 +591,30 @@ mod tests { matching_items } + fn linear_search_periodic( + content: &[([A; K], u32)], + query_point: &[A; K], + radius: A, + max_qty: NonZero, + box_size: &[A; K], + ) -> Vec> { + let mut matching_items = vec![]; + + for &(point, item) in content { + let dist = periodic_dist::(query_point, &point, box_size); + if dist < radius { + matching_items.push(NearestNeighbour { + distance: dist, + item, + }); + } + } + + stabilize_neighbours(&mut matching_items); + matching_items.truncate(max_qty.get()); + matching_items + } + fn stabilize_sort(matching_items: &mut [(A, u32)]) { matching_items.sort_unstable_by(|a, b| { let dist_cmp = a.0.partial_cmp(&b.0).unwrap(); @@ -373,4 +625,29 @@ mod tests { } }); } + + fn stabilize_neighbours(matching_items: 
&mut [NearestNeighbour]) { + matching_items.sort_unstable_by(|a, b| { + let dist_cmp = a.distance.partial_cmp(&b.distance).unwrap(); + if dist_cmp == Ordering::Equal { + a.item.cmp(&b.item) + } else { + dist_cmp + } + }); + } + + fn periodic_dist( + query: &[A; K], + point: &[A; K], + box_size: &[A; K], + ) -> A { + (0..K) + .map(|axis| { + let diff = (query[axis] - point[axis]).abs(); + let wrapped_diff = diff.min(box_size[axis] - diff); + wrapped_diff * wrapped_diff + }) + .fold(A::zero(), std::ops::Add::add) + } } diff --git a/src/float/query/nearest_one.rs b/src/float/query/nearest_one.rs index 3879bcf8..71b7aad1 100644 --- a/src/float/query/nearest_one.rs +++ b/src/float/query/nearest_one.rs @@ -58,6 +58,134 @@ where tree.add(&[1.0, 2.0, 5.0], 100); tree.add(&[2.0, 3.0, 6.0], 101);" ); + + /// Finds the nearest element to `query` with periodic boundary conditions. + /// + /// `box_size` gives the periodic box length for each axis. Query points are expected + /// to be wrapped into the same principal cell as the points stored in the tree. + /// + /// This first implementation checks all `3^K` wrapped query images and reuses the + /// existing nearest-neighbour search for each one. + /// + /// # Examples + /// + /// ```rust + /// use kiddo::KdTree; + /// use kiddo::SquaredEuclidean; + /// + /// let mut tree: KdTree = KdTree::new(); + /// tree.add(&[0.95, 0.5], 1); + /// tree.add(&[0.40, 0.5], 2); + /// + /// let nearest = tree.nearest_one_periodic::(&[0.05, 0.5], &[1.0, 1.0]); + /// + /// assert_eq!(nearest.item, 1); + /// assert!((nearest.distance - 0.01).abs() < f64::EPSILON); + /// ``` + #[inline] + pub fn nearest_one_periodic( + &self, + query: &[A; K], + box_size: &[A; K], + ) -> NearestNeighbour + where + D: DistanceMetric, + { + self.nearest_one_periodic_point::(query, box_size).0 + } + + /// Finds the nearest element to `query` with periodic boundary conditions and also + /// returns the coordinates of the nearest point stored in the tree. 
+ #[inline] + pub fn nearest_one_periodic_point( + &self, + query: &[A; K], + box_size: &[A; K], + ) -> (NearestNeighbour, [A; K]) + where + D: DistanceMetric, + { + box_size.iter().for_each(|axis_len| { + assert!( + *axis_len > A::zero(), + "periodic box sizes must be strictly positive" + ); + }); + + let mut wrapped_query = *query; + let mut best = NearestNeighbour { + distance: A::infinity(), + item: T::default(), + }; + let mut best_point = [A::zero(); K]; + + self.nearest_one_periodic_point_recurse::( + query, + box_size, + 0, + &mut wrapped_query, + &mut best, + &mut best_point, + ); + + (best, best_point) + } + + fn nearest_one_periodic_point_recurse( + &self, + query: &[A; K], + box_size: &[A; K], + axis: usize, + wrapped_query: &mut [A; K], + best: &mut NearestNeighbour, + best_point: &mut [A; K], + ) where + D: DistanceMetric, + { + if axis == K { + let (candidate, candidate_point) = self.nearest_one_point::(wrapped_query); + if candidate.distance < best.distance { + *best = candidate; + best_point.copy_from_slice(&candidate_point); + } + return; + } + + let original = query[axis]; + let axis_len = box_size[axis]; + + wrapped_query[axis] = original - axis_len; + self.nearest_one_periodic_point_recurse::( + query, + box_size, + axis + 1, + wrapped_query, + best, + best_point, + ); + + wrapped_query[axis] = original; + self.nearest_one_periodic_point_recurse::( + query, + box_size, + axis + 1, + wrapped_query, + best, + best_point, + ); + + wrapped_query[axis] = original + axis_len; + self.nearest_one_periodic_point_recurse::( + query, + box_size, + axis + 1, + wrapped_query, + best, + best_point, + ); + + wrapped_query[axis] = original; + } } #[cfg(feature = "rkyv")] @@ -110,7 +238,7 @@ where #[cfg(test)] mod tests { - use crate::float::distance::Manhattan; + use crate::float::distance::{Manhattan, SquaredEuclidean}; use crate::float::kdtree::{Axis, KdTree}; use crate::nearest_neighbour::NearestNeighbour; use crate::traits::DistanceMetric; @@ -202,6 
+330,68 @@ mod tests { } } + #[test] + fn can_query_nearest_one_item_with_periodic_boundaries() { + let mut tree: KdTree = KdTree::new(); + let content_to_add = [ + ([0.95f64, 0.50f64], 1), + ([0.40f64, 0.50f64], 2), + ([0.10f64, 0.10f64], 3), + ([0.75f64, 0.90f64], 4), + ]; + + for (point, item) in content_to_add { + tree.add(&point, item); + } + + let query_point = [0.05f64, 0.50f64]; + let box_size = [1.0f64, 1.0f64]; + + let expected = NearestNeighbour { + distance: 0.01f64, + item: 1, + }; + + let result = tree.nearest_one_periodic::(&query_point, &box_size); + assert!((result.distance - expected.distance).abs() < f64::EPSILON); + assert_eq!(result.item, expected.item); + + let (result, result_point) = + tree.nearest_one_periodic_point::(&query_point, &box_size); + assert!((result.distance - expected.distance).abs() < f64::EPSILON); + assert_eq!(result.item, expected.item); + assert_eq!(result_point, [0.95f64, 0.50f64]); + } + + #[test] + fn can_query_nearest_one_item_with_periodic_boundaries_large_scale() { + const TREE_SIZE: usize = 10_000; + const NUM_QUERIES: usize = 200; + + let content_to_add: Vec<([f32; 3], u32)> = (0..TREE_SIZE) + .map(|_| rand::random::<([f32; 3], u32)>()) + .collect(); + + let mut tree: KdTree = KdTree::with_capacity(TREE_SIZE); + content_to_add + .iter() + .for_each(|(point, content)| tree.add(point, *content)); + + let box_size = [1.0f32, 1.0f32, 1.0f32]; + let query_points: Vec<[f32; 3]> = (0..NUM_QUERIES) + .map(|_| rand::random::<[f32; 3]>()) + .collect(); + + for query_point in query_points { + let expected = + linear_search_periodic::(&content_to_add, &query_point, &box_size); + let result = tree.nearest_one_periodic::(&query_point, &box_size); + + assert!((result.distance - expected.distance).abs() < 1e-5); + assert_eq!(result.item, expected.item); + } + } + fn linear_search( content: &[([A; K], u32)], query_point: &[A; K], @@ -222,4 +412,46 @@ mod tests { item: best_item, } } + + fn linear_search_periodic( + content: &[([A; 
K], u32)], + query_point: &[A; K], + box_size: &[A; K], + ) -> NearestNeighbour + where + D: DistanceMetric, + { + let mut best = NearestNeighbour { + distance: A::infinity(), + item: u32::MAX, + }; + + for &(point, item) in content { + let distance = periodic_dist::(query_point, &point, box_size); + if distance < best.distance { + best = NearestNeighbour { distance, item }; + } + } + + best + } + + fn periodic_dist( + query: &[A; K], + point: &[A; K], + box_size: &[A; K], + ) -> A + where + D: DistanceMetric, + { + let wrapped: [A; K] = std::array::from_fn(|axis| { + let diff = (query[axis] - point[axis]).abs(); + diff.min(box_size[axis] - diff) + }); + + wrapped + .into_iter() + .map(|axis_dist| D::dist1(axis_dist, A::zero())) + .fold(A::zero(), std::ops::Add::add) + } } diff --git a/src/float/query/within.rs b/src/float/query/within.rs index a895c5ab..cd02163a 100644 --- a/src/float/query/within.rs +++ b/src/float/query/within.rs @@ -43,6 +43,28 @@ let mut tree: KdTree = KdTree::new(); tree.add(&[1.0, 2.0, 5.0], 100); tree.add(&[2.0, 3.0, 6.0], 101);" ); + + /// Finds all elements within `dist` of `query` with periodic boundary conditions. + /// + /// `box_size` gives the periodic box length for each axis. Query points are expected + /// to be wrapped into the same principal cell as the points stored in the tree. + /// + /// Results are returned sorted nearest-first. 
+ #[inline] + pub fn within_periodic( + &self, + query: &[A; K], + dist: A, + box_size: &[A; K], + ) -> Vec> + where + D: DistanceMetric, + T: std::hash::Hash + Eq, + { + let mut matching_items = self.within_unsorted_periodic::(query, dist, box_size); + matching_items.sort(); + matching_items + } } #[cfg(feature = "rkyv")] @@ -93,7 +115,7 @@ where #[cfg(test)] mod tests { - use crate::float::distance::Manhattan; + use crate::float::distance::{Manhattan, SquaredEuclidean}; use crate::float::kdtree::{Axis, KdTree}; use crate::nearest_neighbour::NearestNeighbour; use crate::traits::DistanceMetric; @@ -192,6 +214,33 @@ mod tests { } } + #[test] + fn can_query_items_within_periodic_boundaries() { + let mut tree: KdTree = KdTree::new(); + let content_to_add = [ + ([0.95f64, 0.50f64], 1), + ([0.92f64, 0.55f64], 2), + ([0.40f64, 0.50f64], 3), + ([0.10f64, 0.10f64], 4), + ]; + + for (point, item) in content_to_add { + tree.add(&point, item); + } + + let query_point = [0.05f64, 0.50f64]; + let box_size = [1.0f64, 1.0f64]; + let radius = 0.03f64; + + let result = tree.within_periodic::(&query_point, radius, &box_size); + + assert_eq!(result.len(), 2); + assert!((result[0].distance - 0.01f64).abs() < f64::EPSILON); + assert_eq!(result[0].item, 1); + assert!((result[1].distance - 0.0194f64).abs() < f64::EPSILON); + assert_eq!(result[1].item, 2); + } + fn linear_search( content: &[([A; K], u32)], query_point: &[A; K], diff --git a/src/float/query/within_unsorted.rs b/src/float/query/within_unsorted.rs index 5d2d776a..70fc34d9 100644 --- a/src/float/query/within_unsorted.rs +++ b/src/float/query/within_unsorted.rs @@ -1,4 +1,5 @@ use az::{Az, Cast}; +use std::collections::HashMap; use std::ops::Rem; use crate::float::kdtree::{Axis, KdTree}; @@ -45,6 +46,111 @@ let mut tree: KdTree = KdTree::new(); tree.add(&[1.0, 2.0, 5.0], 100); tree.add(&[2.0, 3.0, 6.0], 101);" ); + + /// Finds all elements within `dist` of `query` with periodic boundary conditions. 
+ /// + /// `box_size` gives the periodic box length for each axis. Query points are expected + /// to be wrapped into the same principal cell as the points stored in the tree. + /// + /// This first implementation checks all `3^K` wrapped query images and merges duplicate + /// items that can arise from multiple images. + #[inline] + pub fn within_unsorted_periodic( + &self, + query: &[A; K], + dist: A, + box_size: &[A; K], + ) -> Vec> + where + D: DistanceMetric, + T: std::hash::Hash + Eq, + { + box_size.iter().for_each(|axis_len| { + assert!( + *axis_len > A::zero(), + "periodic box sizes must be strictly positive" + ); + }); + + let mut wrapped_query = *query; + let mut best_by_item: HashMap = HashMap::new(); + + self.within_unsorted_periodic_recurse::( + query, + dist, + box_size, + 0, + &mut wrapped_query, + &mut best_by_item, + ); + + best_by_item + .into_iter() + .map(|(item, distance)| NearestNeighbour { distance, item }) + .collect() + } + + fn within_unsorted_periodic_recurse( + &self, + query: &[A; K], + dist: A, + box_size: &[A; K], + axis: usize, + wrapped_query: &mut [A; K], + best_by_item: &mut HashMap, + ) where + D: DistanceMetric, + T: std::hash::Hash + Eq, + { + if axis == K { + for candidate in self.within_unsorted::(wrapped_query, dist) { + best_by_item + .entry(candidate.item) + .and_modify(|best_distance| { + if candidate.distance < *best_distance { + *best_distance = candidate.distance; + } + }) + .or_insert(candidate.distance); + } + return; + } + + let original = query[axis]; + let axis_len = box_size[axis]; + + wrapped_query[axis] = original - axis_len; + self.within_unsorted_periodic_recurse::( + query, + dist, + box_size, + axis + 1, + wrapped_query, + best_by_item, + ); + + wrapped_query[axis] = original; + self.within_unsorted_periodic_recurse::( + query, + dist, + box_size, + axis + 1, + wrapped_query, + best_by_item, + ); + + wrapped_query[axis] = original + axis_len; + self.within_unsorted_periodic_recurse::( + query, + dist, + 
box_size, + axis + 1, + wrapped_query, + best_by_item, + ); + + wrapped_query[axis] = original; + } } #[cfg(feature = "rkyv")] @@ -97,6 +203,7 @@ where mod tests { use crate::float::distance::SquaredEuclidean; use crate::float::kdtree::{Axis, KdTree}; + use crate::nearest_neighbour::NearestNeighbour; use crate::traits::DistanceMetric; use rand::Rng; use std::cmp::Ordering; @@ -200,6 +307,67 @@ mod tests { } } + #[test] + fn can_query_items_unsorted_with_periodic_boundaries() { + let mut tree: KdTree = KdTree::new(); + let content_to_add = [ + ([0.95f64, 0.50f64], 1), + ([0.92f64, 0.55f64], 2), + ([0.40f64, 0.50f64], 3), + ([0.10f64, 0.10f64], 4), + ]; + + for (point, item) in content_to_add { + tree.add(&point, item); + } + + let query_point = [0.05f64, 0.50f64]; + let box_size = [1.0f64, 1.0f64]; + let radius = 0.03f64; + + let mut result: Vec<_> = tree + .within_unsorted_periodic::(&query_point, radius, &box_size); + stabilize_neighbours(&mut result); + + assert_eq!(result.len(), 2); + assert!((result[0].distance - 0.01f64).abs() < f64::EPSILON); + assert_eq!(result[0].item, 1); + assert!((result[1].distance - 0.0194f64).abs() < f64::EPSILON); + assert_eq!(result[1].item, 2); + } + + #[test] + fn can_query_items_unsorted_with_periodic_boundaries_large_scale() { + const TREE_SIZE: usize = 10_000; + const NUM_QUERIES: usize = 200; + const RADIUS: f32 = 0.05; + + let content_to_add: Vec<([f32; 3], u32)> = (0..TREE_SIZE) + .map(|_| rand::random::<([f32; 3], u32)>()) + .collect(); + + let mut tree: KdTree = KdTree::with_capacity(TREE_SIZE); + content_to_add + .iter() + .for_each(|(point, content)| tree.add(point, *content)); + + let box_size = [1.0f32, 1.0f32, 1.0f32]; + let query_points: Vec<[f32; 3]> = (0..NUM_QUERIES) + .map(|_| rand::random::<[f32; 3]>()) + .collect(); + + for query_point in query_points { + let expected = + linear_search_periodic(&content_to_add, &query_point, RADIUS, &box_size); + + let mut result: Vec<_> = tree + 
.within_unsorted_periodic::(&query_point, RADIUS, &box_size); + stabilize_neighbours(&mut result); + + assert_same_neighbour_set(&result, &expected, 1e-5); + } + } + fn linear_search( content: &[([A; K], u32)], query_point: &[A; K], @@ -229,4 +397,68 @@ mod tests { } }); } + + fn stabilize_neighbours(matching_items: &mut [NearestNeighbour]) { + matching_items.sort_unstable_by(|a, b| { + let dist_cmp = a.distance.partial_cmp(&b.distance).unwrap(); + if dist_cmp == Ordering::Equal { + a.item.cmp(&b.item) + } else { + dist_cmp + } + }); + } + + fn linear_search_periodic( + content: &[([A; K], u32)], + query_point: &[A; K], + radius: A, + box_size: &[A; K], + ) -> Vec> { + let mut matching_items = vec![]; + + for &(point, item) in content { + let distance = periodic_dist::(query_point, &point, box_size); + if distance < radius { + matching_items.push(NearestNeighbour { distance, item }); + } + } + + stabilize_neighbours(&mut matching_items); + matching_items + } + + fn periodic_dist( + query: &[A; K], + point: &[A; K], + box_size: &[A; K], + ) -> A { + (0..K) + .map(|axis| { + let diff = (query[axis] - point[axis]).abs(); + let wrapped_diff = diff.min(box_size[axis] - diff); + wrapped_diff * wrapped_diff + }) + .fold(A::zero(), std::ops::Add::add) + } + + fn assert_same_neighbour_set( + actual: &[NearestNeighbour], + expected: &[NearestNeighbour], + tolerance: f32, + ) { + assert_eq!(actual.len(), expected.len()); + + let expected_by_item: std::collections::HashMap<_, _> = expected + .iter() + .map(|entry| (entry.item, entry.distance)) + .collect(); + + for entry in actual { + let expected_distance = expected_by_item + .get(&entry.item) + .expect("missing expected periodic neighbour"); + assert!((entry.distance - *expected_distance).abs() < tolerance); + } + } } diff --git a/src/immutable/float/query/nearest_n.rs b/src/immutable/float/query/nearest_n.rs index 9c34868b..f20cb3bf 100644 --- a/src/immutable/float/query/nearest_n.rs +++ 
b/src/immutable/float/query/nearest_n.rs @@ -5,6 +5,7 @@ use crate::nearest_neighbour::NearestNeighbour; use crate::traits::Content; use crate::traits::DistanceMetric; use az::Cast; +use std::collections::HashMap; use std::num::NonZero; use crate::generate_immutable_nearest_n; @@ -49,6 +50,112 @@ where let tree: ImmutableKdTree = ImmutableKdTree::new_from_slice(&content);" ); + + /// Finds the nearest `qty` elements to `query` with periodic boundary conditions. + /// + /// `box_size` gives the periodic box length for each axis. Query points are expected + /// to be wrapped into the same principal cell as the points stored in the tree. + #[inline] + pub fn nearest_n_periodic( + &self, + query: &[A; K], + max_qty: NonZero, + box_size: &[A; K], + ) -> Vec> + where + D: DistanceMetric, + T: std::hash::Hash + Eq, + { + box_size.iter().for_each(|axis_len| { + assert!( + *axis_len > A::zero(), + "periodic box sizes must be strictly positive" + ); + }); + + let mut wrapped_query = *query; + let mut best_by_item: HashMap = HashMap::new(); + + self.nearest_n_periodic_recurse::( + query, + max_qty, + box_size, + 0, + &mut wrapped_query, + &mut best_by_item, + ); + + let mut results: Vec<_> = best_by_item + .into_iter() + .map(|(item, distance)| NearestNeighbour { distance, item }) + .collect(); + + results.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap()); + results.truncate(max_qty.get()); + results + } + + fn nearest_n_periodic_recurse( + &self, + query: &[A; K], + max_qty: NonZero, + box_size: &[A; K], + axis: usize, + wrapped_query: &mut [A; K], + best_by_item: &mut HashMap, + ) where + D: DistanceMetric, + T: std::hash::Hash + Eq, + { + if axis == K { + for candidate in self.nearest_n::(wrapped_query, max_qty) { + best_by_item + .entry(candidate.item) + .and_modify(|best_distance| { + if candidate.distance < *best_distance { + *best_distance = candidate.distance; + } + }) + .or_insert(candidate.distance); + } + return; + } + + let original = query[axis]; + let 
axis_len = box_size[axis]; + + wrapped_query[axis] = original - axis_len; + self.nearest_n_periodic_recurse::( + query, + max_qty, + box_size, + axis + 1, + wrapped_query, + best_by_item, + ); + + wrapped_query[axis] = original; + self.nearest_n_periodic_recurse::( + query, + max_qty, + box_size, + axis + 1, + wrapped_query, + best_by_item, + ); + + wrapped_query[axis] = original + axis_len; + self.nearest_n_periodic_recurse::( + query, + max_qty, + box_size, + axis + 1, + wrapped_query, + best_by_item, + ); + + wrapped_query[axis] = original; + } } #[cfg(feature = "rkyv")] @@ -102,6 +209,7 @@ mod tests { use crate::float::distance::SquaredEuclidean; use crate::float::kdtree::Axis; use crate::immutable::float::kdtree::ImmutableKdTree; + use crate::nearest_neighbour::NearestNeighbour; use crate::traits::DistanceMetric; use az::{Az, Cast}; use rand::Rng; @@ -296,6 +404,51 @@ mod tests { } } + #[test] + fn can_query_nearest_n_item_with_periodic_boundaries_f64() { + let content_to_add = [ + [0.95f64, 0.50f64], + [0.92f64, 0.55f64], + [0.40f64, 0.50f64], + [0.10f64, 0.10f64], + ]; + + let tree: ImmutableKdTree = ImmutableKdTree::new_from_slice(&content_to_add); + let query_point = [0.05f64, 0.50f64]; + let box_size = [1.0f64, 1.0f64]; + let max_qty = NonZero::new(2).unwrap(); + + let result = tree.nearest_n_periodic::(&query_point, max_qty, &box_size); + assert_eq!(result.len(), 2); + assert!((result[0].distance - 0.01f64).abs() < f64::EPSILON); + assert_eq!(result[0].item, 0); + assert!((result[1].distance - 0.0194f64).abs() < f64::EPSILON); + assert_eq!(result[1].item, 1); + } + + #[test] + fn can_query_nearest_n_item_with_periodic_boundaries_large_scale_f32() { + const TREE_SIZE: usize = 10_000; + const NUM_QUERIES: usize = 200; + + let max_qty = NonZero::new(5).unwrap(); + let content_to_add: Vec<[f32; 3]> = (0..TREE_SIZE).map(|_| rand::random::<[f32; 3]>()).collect(); + let tree: ImmutableKdTree = ImmutableKdTree::new_from_slice(&content_to_add); + let box_size = 
[1.0f32, 1.0f32, 1.0f32]; + let query_points: Vec<[f32; 3]> = (0..NUM_QUERIES).map(|_| rand::random::<[f32; 3]>()).collect(); + + for query_point in query_points.iter() { + let expected = linear_search_periodic(&content_to_add, max_qty.into(), query_point, &box_size); + let result = tree.nearest_n_periodic::(query_point, max_qty, &box_size); + + assert_eq!(result.len(), expected.len()); + for (actual, expected) in result.iter().zip(expected.iter()) { + assert!((actual.distance - expected.distance).abs() < 1e-5); + assert_eq!(actual.item, expected.item); + } + } + } + fn linear_search( content: &[[A; K]], qty: usize, @@ -319,4 +472,45 @@ mod tests { results } + + fn linear_search_periodic( + content: &[[A; K]], + qty: usize, + query_point: &[A; K], + box_size: &[A; K], + ) -> Vec> { + let mut results = vec![]; + + for (idx, point) in content.iter().enumerate() { + let dist = periodic_dist(query_point, point, box_size); + let candidate = NearestNeighbour { + distance: dist, + item: idx as u32, + }; + + if results.len() < qty { + results.push(candidate); + results.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap()); + } else if dist < results[qty - 1].distance { + results[qty - 1] = candidate; + results.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap()); + } + } + + results + } + + fn periodic_dist( + query: &[A; K], + point: &[A; K], + box_size: &[A; K], + ) -> A { + (0..K) + .map(|axis| { + let diff = (query[axis] - point[axis]).abs(); + let wrapped_diff = diff.min(box_size[axis] - diff); + wrapped_diff * wrapped_diff + }) + .fold(A::zero(), std::ops::Add::add) + } } diff --git a/src/immutable/float/query/nearest_n_within.rs b/src/immutable/float/query/nearest_n_within.rs index 7b3cfece..ec3d8a08 100644 --- a/src/immutable/float/query/nearest_n_within.rs +++ b/src/immutable/float/query/nearest_n_within.rs @@ -1,6 +1,7 @@ use az::Cast; use sorted_vec::SortedVec; use std::collections::BinaryHeap; +use std::collections::HashMap; use 
std::num::NonZero; use std::ops::Rem; @@ -54,6 +55,117 @@ where let tree: ImmutableKdTree = ImmutableKdTree::new_from_slice(&content);" ); + + #[inline] + pub fn nearest_n_within_periodic( + &self, + query: &[A; K], + dist: A, + max_items: NonZero, + sorted: bool, + box_size: &[A; K], + ) -> Vec> + where + D: DistanceMetric, + T: std::hash::Hash + Eq, + { + box_size.iter().for_each(|axis_len| { + assert!( + *axis_len > A::zero(), + "periodic box sizes must be strictly positive" + ); + }); + + let mut wrapped_query = *query; + let mut best_by_item: HashMap = HashMap::new(); + + self.nearest_n_within_periodic_recurse::( + query, + dist, + max_items, + box_size, + 0, + &mut wrapped_query, + &mut best_by_item, + ); + + let mut results: Vec<_> = best_by_item + .into_iter() + .map(|(item, distance)| NearestNeighbour { distance, item }) + .collect(); + + if sorted { + results.sort(); + } + results.truncate(max_items.get()); + results + } + + fn nearest_n_within_periodic_recurse( + &self, + query: &[A; K], + dist: A, + max_items: NonZero, + box_size: &[A; K], + axis: usize, + wrapped_query: &mut [A; K], + best_by_item: &mut HashMap, + ) where + D: DistanceMetric, + T: std::hash::Hash + Eq, + { + if axis == K { + for candidate in self.nearest_n_within::(wrapped_query, dist, max_items, false) { + best_by_item + .entry(candidate.item) + .and_modify(|best_distance| { + if candidate.distance < *best_distance { + *best_distance = candidate.distance; + } + }) + .or_insert(candidate.distance); + } + return; + } + + let original = query[axis]; + let axis_len = box_size[axis]; + + wrapped_query[axis] = original - axis_len; + self.nearest_n_within_periodic_recurse::( + query, + dist, + max_items, + box_size, + axis + 1, + wrapped_query, + best_by_item, + ); + + wrapped_query[axis] = original; + self.nearest_n_within_periodic_recurse::( + query, + dist, + max_items, + box_size, + axis + 1, + wrapped_query, + best_by_item, + ); + + wrapped_query[axis] = original + axis_len; + 
self.nearest_n_within_periodic_recurse::( + query, + dist, + max_items, + box_size, + axis + 1, + wrapped_query, + best_by_item, + ); + + wrapped_query[axis] = original; + } } #[cfg(feature = "rkyv")] @@ -111,6 +223,7 @@ mod tests { use crate::float::distance::SquaredEuclidean; use crate::float::kdtree::Axis; use crate::immutable::float::kdtree::ImmutableKdTree; + use crate::nearest_neighbour::NearestNeighbour; use crate::traits::DistanceMetric; use rand::Rng; use std::cmp::Ordering; @@ -222,6 +335,65 @@ mod tests { } } + #[test] + fn can_query_items_within_periodic_boundaries() { + let content_to_add = [ + [0.95f64, 0.50f64], + [0.92f64, 0.55f64], + [0.40f64, 0.50f64], + [0.10f64, 0.10f64], + ]; + + let tree: ImmutableKdTree = ImmutableKdTree::new_from_slice(&content_to_add); + let query_point = [0.05f64, 0.50f64]; + let box_size = [1.0f64, 1.0f64]; + let radius = 0.03f64; + let max_qty = NonZero::new(2).unwrap(); + + let result = tree.nearest_n_within_periodic::( + &query_point, + radius, + max_qty, + true, + &box_size, + ); + assert_eq!(result.len(), 2); + assert!((result[0].distance - 0.01f64).abs() < f64::EPSILON); + assert_eq!(result[0].item, 0); + assert!((result[1].distance - 0.0194f64).abs() < f64::EPSILON); + assert_eq!(result[1].item, 1); + } + + #[test] + fn can_query_items_within_periodic_boundaries_large_scale() { + const TREE_SIZE: usize = 10_000; + const NUM_QUERIES: usize = 200; + const RADIUS: f32 = 0.05; + + let max_qty = NonZero::new(5).unwrap(); + let content_to_add: Vec<[f32; 3]> = (0..TREE_SIZE).map(|_| rand::random::<[f32; 3]>()).collect(); + let tree: ImmutableKdTree = ImmutableKdTree::new_from_slice(&content_to_add); + let box_size = [1.0f32, 1.0f32, 1.0f32]; + let query_points: Vec<[f32; 3]> = (0..NUM_QUERIES).map(|_| rand::random::<[f32; 3]>()).collect(); + + for query_point in query_points.iter() { + let expected = linear_search_periodic(&content_to_add, query_point, RADIUS, max_qty, &box_size); + let result = 
tree.nearest_n_within_periodic::( + query_point, + RADIUS, + max_qty, + true, + &box_size, + ); + + assert_eq!(result.len(), expected.len()); + for (actual, expected) in result.iter().zip(expected.iter()) { + assert!((actual.distance - expected.distance).abs() < 1e-5); + assert_eq!(actual.item, expected.item); + } + } + } + fn linear_search( content: &[[A; K]], query_point: &[A; K], @@ -251,4 +423,42 @@ mod tests { } }); } + + fn linear_search_periodic( + content: &[[A; K]], + query_point: &[A; K], + radius: A, + max_qty: NonZero, + box_size: &[A; K], + ) -> Vec> { + let mut matching_items = vec![]; + + for (idx, point) in content.iter().enumerate() { + let dist = periodic_dist(query_point, point, box_size); + if dist < radius { + matching_items.push(NearestNeighbour { + distance: dist, + item: idx as u32, + }); + } + } + + matching_items.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap()); + matching_items.truncate(max_qty.get()); + matching_items + } + + fn periodic_dist( + query: &[A; K], + point: &[A; K], + box_size: &[A; K], + ) -> A { + (0..K) + .map(|axis| { + let diff = (query[axis] - point[axis]).abs(); + let wrapped_diff = diff.min(box_size[axis] - diff); + wrapped_diff * wrapped_diff + }) + .fold(A::zero(), std::ops::Add::add) + } } diff --git a/src/immutable/float/query/nearest_one.rs b/src/immutable/float/query/nearest_one.rs index 6e16eb01..803950b1 100644 --- a/src/immutable/float/query/nearest_one.rs +++ b/src/immutable/float/query/nearest_one.rs @@ -49,6 +49,79 @@ where let tree: ImmutableKdTree = ImmutableKdTree::new_from_slice(&content);" ); + + /// Finds the nearest element to `query` with periodic boundary conditions. + /// + /// `box_size` gives the periodic box length for each axis. Query points are expected + /// to be wrapped into the same principal cell as the points stored in the tree. + /// + /// This first implementation checks all `3^K` wrapped query images and reuses the + /// existing nearest-neighbour search for each one. 
+ #[inline] + pub fn nearest_one_periodic( + &self, + query: &[A; K], + box_size: &[A; K], + ) -> NearestNeighbour + where + D: DistanceMetric, + { + box_size.iter().for_each(|axis_len| { + assert!( + *axis_len > A::zero(), + "periodic box sizes must be strictly positive" + ); + }); + + let mut wrapped_query = *query; + let mut best = NearestNeighbour { + distance: A::infinity(), + item: T::default(), + }; + + self.nearest_one_periodic_recurse::( + query, + box_size, + 0, + &mut wrapped_query, + &mut best, + ); + + best + } + + fn nearest_one_periodic_recurse( + &self, + query: &[A; K], + box_size: &[A; K], + axis: usize, + wrapped_query: &mut [A; K], + best: &mut NearestNeighbour, + ) where + D: DistanceMetric, + { + if axis == K { + let candidate = self.nearest_one::(wrapped_query); + if candidate.distance < best.distance { + *best = candidate; + } + return; + } + + let original = query[axis]; + let axis_len = box_size[axis]; + + wrapped_query[axis] = original - axis_len; + self.nearest_one_periodic_recurse::(query, box_size, axis + 1, wrapped_query, best); + + wrapped_query[axis] = original; + self.nearest_one_periodic_recurse::(query, box_size, axis + 1, wrapped_query, best); + + wrapped_query[axis] = original + axis_len; + self.nearest_one_periodic_recurse::(query, box_size, axis + 1, wrapped_query, best); + + wrapped_query[axis] = original; + } } #[cfg(feature = "rkyv")] @@ -245,6 +318,43 @@ mod tests { } } + #[test] + fn can_query_nearest_one_item_with_periodic_boundaries_f64() { + let content_to_add = [ + [0.95f64, 0.50f64], + [0.92f64, 0.55f64], + [0.40f64, 0.50f64], + [0.10f64, 0.10f64], + ]; + + let tree: ImmutableKdTree = ImmutableKdTree::new_from_slice(&content_to_add); + let query_point = [0.05f64, 0.50f64]; + let box_size = [1.0f64, 1.0f64]; + + let result = tree.nearest_one_periodic::(&query_point, &box_size); + assert!((result.distance - 0.01f64).abs() < f64::EPSILON); + assert_eq!(result.item, 0); + } + + #[test] + fn 
can_query_nearest_one_item_with_periodic_boundaries_large_scale_f32() { + const TREE_SIZE: usize = 10_000; + const NUM_QUERIES: usize = 200; + + let content_to_add: Vec<[f32; 3]> = (0..TREE_SIZE).map(|_| rand::random::<[f32; 3]>()).collect(); + let tree: ImmutableKdTree = ImmutableKdTree::new_from_slice(&content_to_add); + let box_size = [1.0f32, 1.0f32, 1.0f32]; + let query_points: Vec<[f32; 3]> = (0..NUM_QUERIES).map(|_| rand::random::<[f32; 3]>()).collect(); + + for query_point in query_points.iter() { + let expected = linear_search_periodic(&content_to_add, query_point, &box_size); + let result = tree.nearest_one_periodic::(query_point, &box_size); + + assert!((result.distance - expected.distance).abs() < 1e-5); + assert_eq!(result.item, expected.item); + } + } + #[test] fn can_query_nearest_one_item_large_scale_f32() { let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(3); @@ -294,4 +404,41 @@ mod tests { item: best_item, } } + + fn linear_search_periodic( + content: &[[A; K]], + query_point: &[A; K], + box_size: &[A; K], + ) -> NearestNeighbour { + let mut best = NearestNeighbour { + distance: A::infinity(), + item: u32::MAX, + }; + + for (idx, point) in content.iter().enumerate() { + let distance = periodic_dist(query_point, point, box_size); + if distance < best.distance { + best = NearestNeighbour { + distance, + item: idx as u32, + }; + } + } + + best + } + + fn periodic_dist( + query: &[A; K], + point: &[A; K], + box_size: &[A; K], + ) -> A { + (0..K) + .map(|axis| { + let diff = (query[axis] - point[axis]).abs(); + let wrapped_diff = diff.min(box_size[axis] - diff); + wrapped_diff * wrapped_diff + }) + .fold(A::zero(), std::ops::Add::add) + } } diff --git a/src/immutable/float/query/within.rs b/src/immutable/float/query/within.rs index a04c9f7a..44bc53fb 100644 --- a/src/immutable/float/query/within.rs +++ b/src/immutable/float/query/within.rs @@ -46,6 +46,22 @@ where let tree: ImmutableKdTree = ImmutableKdTree::new_from_slice(&content);" ); + + 
#[inline] + pub fn within_periodic( + &self, + query: &[A; K], + dist: A, + box_size: &[A; K], + ) -> Vec> + where + D: DistanceMetric, + T: std::hash::Hash + Eq, + { + let mut matching_items = self.within_unsorted_periodic::(query, dist, box_size); + matching_items.sort(); + matching_items + } } #[cfg(feature = "rkyv")] @@ -96,7 +112,7 @@ where #[cfg(test)] mod tests { - use crate::float::distance::Manhattan; + use crate::float::distance::{Manhattan, SquaredEuclidean}; use crate::float::kdtree::Axis; use crate::immutable::float::kdtree::ImmutableKdTree; use crate::traits::DistanceMetric; @@ -197,6 +213,29 @@ mod tests { } } + #[test] + fn can_query_items_within_periodic_boundaries() { + let content_to_add = [ + [0.95f64, 0.50f64], + [0.92f64, 0.55f64], + [0.40f64, 0.50f64], + [0.10f64, 0.10f64], + ]; + + let tree: ImmutableKdTree = + ImmutableKdTree::new_from_slice(&content_to_add); + let query_point = [0.05f64, 0.50f64]; + let box_size = [1.0f64, 1.0f64]; + let radius = 0.03f64; + + let result = tree.within_periodic::(&query_point, radius, &box_size); + assert_eq!(result.len(), 2); + assert!((result[0].distance - 0.01f64).abs() < f64::EPSILON); + assert_eq!(result[0].item, 0); + assert!((result[1].distance - 0.0194f64).abs() < f64::EPSILON); + assert_eq!(result[1].item, 1); + } + fn linear_search( content: &[[A; K]], query_point: &[A; K], diff --git a/src/immutable/float/query/within_unsorted.rs b/src/immutable/float/query/within_unsorted.rs index 28310c5a..92e9078a 100644 --- a/src/immutable/float/query/within_unsorted.rs +++ b/src/immutable/float/query/within_unsorted.rs @@ -6,6 +6,7 @@ use crate::nearest_neighbour::NearestNeighbour; use crate::traits::Content; use crate::traits::DistanceMetric; use az::Cast; +use std::collections::HashMap; macro_rules! 
generate_immutable_float_within_unsorted { ($doctest_build_tree:tt) => { @@ -46,6 +47,105 @@ where let tree: ImmutableKdTree = ImmutableKdTree::new_from_slice(&content);" ); + + /// Finds all elements within `dist` of `query` with periodic boundary conditions. + #[inline] + pub fn within_unsorted_periodic( + &self, + query: &[A; K], + dist: A, + box_size: &[A; K], + ) -> Vec> + where + D: DistanceMetric, + T: std::hash::Hash + Eq, + { + box_size.iter().for_each(|axis_len| { + assert!( + *axis_len > A::zero(), + "periodic box sizes must be strictly positive" + ); + }); + + let mut wrapped_query = *query; + let mut best_by_item: HashMap = HashMap::new(); + + self.within_unsorted_periodic_recurse::( + query, + dist, + box_size, + 0, + &mut wrapped_query, + &mut best_by_item, + ); + + best_by_item + .into_iter() + .map(|(item, distance)| NearestNeighbour { distance, item }) + .collect() + } + + fn within_unsorted_periodic_recurse( + &self, + query: &[A; K], + dist: A, + box_size: &[A; K], + axis: usize, + wrapped_query: &mut [A; K], + best_by_item: &mut HashMap, + ) where + D: DistanceMetric, + T: std::hash::Hash + Eq, + { + if axis == K { + for candidate in self.within_unsorted::(wrapped_query, dist) { + best_by_item + .entry(candidate.item) + .and_modify(|best_distance| { + if candidate.distance < *best_distance { + *best_distance = candidate.distance; + } + }) + .or_insert(candidate.distance); + } + return; + } + + let original = query[axis]; + let axis_len = box_size[axis]; + + wrapped_query[axis] = original - axis_len; + self.within_unsorted_periodic_recurse::( + query, + dist, + box_size, + axis + 1, + wrapped_query, + best_by_item, + ); + + wrapped_query[axis] = original; + self.within_unsorted_periodic_recurse::( + query, + dist, + box_size, + axis + 1, + wrapped_query, + best_by_item, + ); + + wrapped_query[axis] = original + axis_len; + self.within_unsorted_periodic_recurse::( + query, + dist, + box_size, + axis + 1, + wrapped_query, + best_by_item, + ); + + 
wrapped_query[axis] = original; + } } #[cfg(feature = "rkyv")] @@ -99,6 +199,7 @@ mod tests { use crate::float::distance::SquaredEuclidean; use crate::float::kdtree::Axis; use crate::immutable::float::kdtree::ImmutableKdTree; + use crate::nearest_neighbour::NearestNeighbour; use crate::traits::DistanceMetric; use rand::Rng; use std::cmp::Ordering; @@ -195,6 +296,54 @@ mod tests { } } + #[test] + fn can_query_items_unsorted_with_periodic_boundaries() { + let content_to_add = [ + [0.95f64, 0.50f64], + [0.92f64, 0.55f64], + [0.40f64, 0.50f64], + [0.10f64, 0.10f64], + ]; + + let tree: ImmutableKdTree = ImmutableKdTree::new_from_slice(&content_to_add); + let query_point = [0.05f64, 0.50f64]; + let box_size = [1.0f64, 1.0f64]; + let radius = 0.03f64; + + let mut result = tree.within_unsorted_periodic::(&query_point, radius, &box_size); + stabilize_neighbours(&mut result); + + assert_eq!(result.len(), 2); + assert!((result[0].distance - 0.01f64).abs() < f64::EPSILON); + assert_eq!(result[0].item, 0); + assert!((result[1].distance - 0.0194f64).abs() < f64::EPSILON); + assert_eq!(result[1].item, 1); + } + + #[test] + fn can_query_items_unsorted_with_periodic_boundaries_large_scale() { + const TREE_SIZE: usize = 10_000; + const NUM_QUERIES: usize = 200; + const RADIUS: f32 = 0.05; + + let content_to_add: Vec<[f32; 3]> = (0..TREE_SIZE).map(|_| rand::random::<[f32; 3]>()).collect(); + let tree: ImmutableKdTree = ImmutableKdTree::new_from_slice(&content_to_add); + let box_size = [1.0f32, 1.0f32, 1.0f32]; + let query_points: Vec<[f32; 3]> = (0..NUM_QUERIES).map(|_| rand::random::<[f32; 3]>()).collect(); + + for query_point in query_points.iter() { + let expected = linear_search_periodic(&content_to_add, query_point, RADIUS, &box_size); + let mut result = tree.within_unsorted_periodic::(query_point, RADIUS, &box_size); + stabilize_neighbours(&mut result); + + assert_eq!(result.len(), expected.len()); + for (actual, expected) in result.iter().zip(expected.iter()) { + 
assert!((actual.distance - expected.distance).abs() < 1e-5); + assert_eq!(actual.item, expected.item); + } + } + } + fn linear_search( content: &[[A; K]], query_point: &[A; K], @@ -224,4 +373,51 @@ mod tests { } }); } + + fn stabilize_neighbours(matching_items: &mut [NearestNeighbour]) { + matching_items.sort_unstable_by(|a, b| { + let dist_cmp = a.distance.partial_cmp(&b.distance).unwrap(); + if dist_cmp == Ordering::Equal { + a.item.cmp(&b.item) + } else { + dist_cmp + } + }); + } + + fn linear_search_periodic( + content: &[[A; K]], + query_point: &[A; K], + radius: A, + box_size: &[A; K], + ) -> Vec> { + let mut matching_items = vec![]; + + for (idx, point) in content.iter().enumerate() { + let dist = periodic_dist(query_point, point, box_size); + if dist < radius { + matching_items.push(NearestNeighbour { + distance: dist, + item: idx as u32, + }); + } + } + + stabilize_neighbours(&mut matching_items); + matching_items + } + + fn periodic_dist( + query: &[A; K], + point: &[A; K], + box_size: &[A; K], + ) -> A { + (0..K) + .map(|axis| { + let diff = (query[axis] - point[axis]).abs(); + let wrapped_diff = diff.min(box_size[axis] - diff); + wrapped_diff * wrapped_diff + }) + .fold(A::zero(), std::ops::Add::add) + } }