diff --git a/graph.go b/graph.go index 053e9ac..b350c4a 100644 --- a/graph.go +++ b/graph.go @@ -64,7 +64,7 @@ func (n *layerNode[K]) addNeighbor(newNode *layerNode[K], m int, dist DistanceFu delete(n.neighbors, worst.Key) // Delete backlink from the worst neighbor. delete(worst.neighbors, n.Key) - worst.replenish(m) + worst.replenish(m, dist) } type searchCandidate[K cmp.Ordered] struct { @@ -148,7 +148,7 @@ func (n *layerNode[K]) search( return result.Slice() } -func (n *layerNode[K]) replenish(m int) { +func (n *layerNode[K]) replenish(m int, dist DistanceFunc) { if len(n.neighbors) >= m { return } @@ -165,7 +165,7 @@ func (n *layerNode[K]) replenish(m int) { if candidate == n { continue } - n.addNeighbor(candidate, m, CosineDistance) + n.addNeighbor(candidate, m, dist) if len(n.neighbors) >= m { return } @@ -175,13 +175,13 @@ func (n *layerNode[K]) replenish(m int) { // isolates remove the node from the graph by removing all connections // to neighbors. -func (n *layerNode[K]) isolate(m int) { +func (n *layerNode[K]) isolate(m int, dist DistanceFunc) { for _, neighbor := range n.neighbors { delete(neighbor.neighbors, n.Key) } for _, neighbor := range n.neighbors { - neighbor.replenish(m) + neighbor.replenish(m, dist) } } @@ -501,7 +501,7 @@ func (h *Graph[K]) Delete(key K) bool { if len(layer.nodes) == 0 { deleteLayer[i] = struct{}{} } - node.isolate(h.M) + node.isolate(h.M, h.Distance) deleted = true } diff --git a/graph_test.go b/graph_test.go index df795e6..bcce34f 100644 --- a/graph_test.go +++ b/graph_test.go @@ -113,13 +113,15 @@ func TestGraph_AddSearch(t *testing.T) { ) require.Len(t, nearest, 4) + // The two closest are 64 and 65 (distance 0.5 each). + // The next two are 63 and 66 (distance 1.5 each). require.EqualValues( t, []Node[int]{ {64, Vector{64}}, {65, Vector{65}}, - {62, Vector{62}}, {63, Vector{63}}, + {66, Vector{66}}, }, nearest, ) @@ -259,3 +261,22 @@ func TestGraph_RemoveAllNodes(t *testing.T) { g.Add(MakeNode(1, vec)) } } + +func TestGraph_DeleteReplenishUsesGraphDistance(t *testing.T) { + // replenish() previously hardcoded CosineDistance. After deleting a + // node from a EuclideanDistance graph, replenish must use the correct + // distance function or the topology becomes corrupted. + g := newTestGraph[int]() // uses EuclideanDistance + for i := 0; i < 20; i++ { + g.Add(Node[int]{Key: i, Value: Vector{float32(i)}}) + } + + // Delete a node in the middle to trigger replenish. + g.Delete(10) + + // Search should still find the correct nearest neighbor. + results := g.Search(Vector{9.5}, 1) + require.Len(t, results, 1) + // Must be 9 or 11 (both distance 0.5 from 9.5). + require.Contains(t, []int{9, 11}, results[0].Key) +} diff --git a/heap/heap.go b/heap/heap.go index 7a5052a..c919c3c 100644 --- a/heap/heap.go +++ b/heap/heap.go @@ -70,8 +70,9 @@ func (h *Heap[T]) Pop() T { return heap.Pop(&h.inner).(T) } +// PopLast removes and returns the maximum element from the heap. func (h *Heap[T]) PopLast() T { - return h.Remove(h.Len() - 1) + return h.Remove(h.maxIndex()) } // Remove removes and returns the element at index i from the heap. @@ -85,9 +86,22 @@ func (h *Heap[T]) Min() T { return h.inner.data[0] } +// maxIndex returns the index of the maximum element by scanning leaf nodes. +// In a min-heap the max is always a leaf (indices n/2 .. n-1). +func (h *Heap[T]) maxIndex() int { + n := h.inner.Len() + best := n / 2 + for i := best + 1; i < n; i++ { + if h.inner.data[best].Less(h.inner.data[i]) { + best = i + } + } + return best +} + // Max returns the maximum element in the heap. func (h *Heap[T]) Max() T { - return h.inner.data[h.inner.Len()-1] + return h.inner.data[h.maxIndex()] } func (h *Heap[T]) Slice() []T { diff --git a/heap/heap_test.go b/heap/heap_test.go index 265723e..fd2a9c3 100644 --- a/heap/heap_test.go +++ b/heap/heap_test.go @@ -32,3 +32,20 @@ func TestHeap(t *testing.T) { t.Errorf("Heap did not return sorted elements: %+v", inOrder) } } + +func TestHeap_MaxAndPopLast(t *testing.T) { + h := Heap[Int]{} + values := []Int{5, 1, 9, 3, 7, 2, 8, 4, 6} + for _, v := range values { + h.Push(v) + } + + require.Equal(t, Int(9), h.Max(), "Max should return the largest element") + require.Equal(t, Int(1), h.Min(), "Min should return the smallest element") + + // PopLast should remove and return the maximum. + popped := h.PopLast() + require.Equal(t, Int(9), popped) + require.Equal(t, Int(8), h.Max(), "Max should be 8 after removing 9") + require.Equal(t, 8, h.Len()) +}