Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor API so that keys are generic #6

Merged
merged 7 commits into from
Jun 14, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Update variable name ID to key
ammario committed Jun 14, 2024
commit ac9deaf635ca5d466b6a0c75a22afb2b4e5bc28a
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
@@ -130,18 +130,18 @@ $$

where:
* $n$ is the number of vectors in the graph
* $\text{size(id)}$ is the average size of the ID in bytes
* $\text{size(key)}$ is the average size of the key in bytes
* $M$ is the maximum number of neighbors each node can have
* $d$ is the dimensionality of the vectors
* $mem_{graph}$ is the memory used by the graph structure across all layers
* $mem_{base}$ is the memory used by the vectors themselves in the base or 0th layer

You can infer that:
* Connectivity ($M$) is very expensive if IDs are large
* If $d \cdot 4$ is far larger than $M \cdot \text{size(id)}$, you should expect linear memory usage spent on representing vector data
* If $d \cdot 4$ is far smaller than $M \cdot \text{size(id)}$, you should expect $n \cdot \log(n)$ memory usage spent on representing graph structure
* Connectivity ($M$) is very expensive if keys are large
* If $d \cdot 4$ is far larger than $M \cdot \text{size(key)}$, you should expect linear memory usage spent on representing vector data
* If $d \cdot 4$ is far smaller than $M \cdot \text{size(key)}$, you should expect $n \cdot \log(n)$ memory usage spent on representing graph structure

In the example of a graph with 256 dimensions, and $M = 16$, with 8 byte IDs, you would see that each vector takes:
In the example of a graph with 256 dimensions, and $M = 16$, with 8 byte keys, you would see that each vector takes:

* $256 \cdot 4 = 1024$ data bytes
* $16 \cdot 8 = 128$ metadata bytes
16 changes: 8 additions & 8 deletions encode.go
Original file line number Diff line number Diff line change
@@ -156,7 +156,7 @@ func (h *Graph[K]) Export(w io.Writer) error {
return fmt.Errorf("encode number of nodes: %w", err)
}
for _, node := range layer.nodes {
_, err = multiBinaryWrite(w, node.ID, node.Vec, len(node.neighbors))
_, err = multiBinaryWrite(w, node.Key, node.Value, len(node.neighbors))
if err != nil {
return fmt.Errorf("encode node data: %w", err)
}
@@ -218,10 +218,10 @@ func (h *Graph[K]) Import(r io.Reader) error {

nodes := make(map[K]*layerNode[K], nNodes)
for j := 0; j < nNodes; j++ {
var id K
var key K
var vec Vector
var nNeighbors int
_, err = multiBinaryRead(r, &id, &vec, &nNeighbors)
_, err = multiBinaryRead(r, &key, &vec, &nNeighbors)
if err != nil {
return fmt.Errorf("decoding node %d: %w", j, err)
}
@@ -238,21 +238,21 @@ func (h *Graph[K]) Import(r io.Reader) error {

node := &layerNode[K]{
Node: Node[K]{
ID: id,
Vec: vec,
Key: key,
Value: vec,
},
neighbors: make(map[K]*layerNode[K]),
}

nodes[id] = node
nodes[key] = node
for _, neighbor := range neighbors {
node.neighbors[neighbor] = nil
}
}
// Fill in neighbor pointers
for _, node := range nodes {
for id := range node.neighbors {
node.neighbors[id] = nodes[id]
for key := range node.neighbors {
node.neighbors[key] = nodes[key]
}
}
h.layers[i] = &layer[K]{nodes: nodes}
10 changes: 5 additions & 5 deletions encode_test.go
Original file line number Diff line number Diff line change
@@ -53,17 +53,17 @@ func verifyGraphNodes[K cmp.Ordered](t *testing.T, g *Graph[K]) {
for _, layer := range g.layers {
for _, node := range layer.nodes {
for neighborKey, neighbor := range node.neighbors {
_, ok := layer.nodes[neighbor.ID]
_, ok := layer.nodes[neighbor.Key]
if !ok {
t.Errorf(
"node %v has neighbor %v, but neighbor does not exist",
node.ID, neighbor.ID,
node.Key, neighbor.Key,
)
}

if neighborKey != neighbor.ID {
t.Errorf("node %v has neighbor %v, but neighbor key is %v", node.ID,
neighbor.ID,
if neighborKey != neighbor.Key {
t.Errorf("node %v has neighbor %v, but neighbor key is %v", node.Key,
neighbor.Key,
neighborKey,
)
}
2 changes: 1 addition & 1 deletion example/readme/main.go
Original file line number Diff line number Diff line change
@@ -18,5 +18,5 @@ func main() {
[]float32{0.5, 0.5, 0.5},
1,
)
fmt.Printf("best friend: %v\n", neighbors[0].Vec)
fmt.Printf("best friend: %v\n", neighbors[0].Value)
}
75 changes: 38 additions & 37 deletions graph.go
Original file line number Diff line number Diff line change
@@ -16,19 +16,19 @@ type Vector = []float32

// Node is a node in the graph.
type Node[K cmp.Ordered] struct {
ID K
Vec Vector
Key K
Value Vector
}

func MakeNode[K cmp.Ordered](id K, vec Vector) Node[K] {
return Node[K]{ID: id, Vec: vec}
func MakeNode[K cmp.Ordered](key K, vec Vector) Node[K] {
return Node[K]{Key: key, Value: vec}
}

// layerNode is a node in a layer of the graph.
type layerNode[K cmp.Ordered] struct {
Node[K]

// neighbors is map of neighbor IDs to neighbor nodes.
// neighbors is map of neighbor keys to neighbor nodes.
// It is a map and not a slice to allow for efficient deletes, esp.
// when M is high.
neighbors map[K]*layerNode[K]
@@ -41,7 +41,7 @@ func (n *layerNode[K]) addNeighbor(newNode *layerNode[K], m int, dist DistanceFu
n.neighbors = make(map[K]*layerNode[K], m)
}

n.neighbors[newNode.ID] = newNode
n.neighbors[newNode.Key] = newNode
if len(n.neighbors) <= m {
return
}
@@ -52,7 +52,7 @@ func (n *layerNode[K]) addNeighbor(newNode *layerNode[K], m int, dist DistanceFu
worst *layerNode[K]
)
for _, neighbor := range n.neighbors {
d := dist(neighbor.Vec, n.Vec)
d := dist(neighbor.Value, n.Value)
// d > worstDist may always be false if the distance function
// returns NaN, e.g., when the embeddings are zero.
if d > worstDist || worst == nil {
@@ -61,9 +61,9 @@ func (n *layerNode[K]) addNeighbor(newNode *layerNode[K], m int, dist DistanceFu
}
}

delete(n.neighbors, worst.ID)
delete(n.neighbors, worst.Key)
// Delete backlink from the worst neighbor.
delete(worst.neighbors, n.ID)
delete(worst.neighbors, n.Key)
worst.replenish(m)
}

@@ -92,7 +92,7 @@ func (n *layerNode[K]) search(
candidates.Push(
searchCandidate[K]{
node: n,
dist: distance(n.Vec, target),
dist: distance(n.Value, target),
},
)
var (
@@ -103,7 +103,7 @@ func (n *layerNode[K]) search(

// Begin with the entry node in the result set.
result.Push(candidates.Min())
visited[n.ID] = true
visited[n.Key] = true

for candidates.Len() > 0 {
var (
@@ -113,16 +113,16 @@ func (n *layerNode[K]) search(

// We iterate the map in a sorted, deterministic fashion for
// tests.
neighborIDs := maps.Keys(current.neighbors)
slices.Sort(neighborIDs)
for _, neighborID := range neighborIDs {
neighborKeys := maps.Keys(current.neighbors)
slices.Sort(neighborKeys)
for _, neighborID := range neighborKeys {
neighbor := current.neighbors[neighborID]
if visited[neighborID] {
continue
}
visited[neighborID] = true

dist := distance(neighbor.Vec, target)
dist := distance(neighbor.Value, target)
improved = improved || dist < result.Min().dist
if result.Len() < k {
result.Push(searchCandidate[K]{node: neighbor, dist: dist})
@@ -157,8 +157,8 @@ func (n *layerNode[K]) replenish(m int) {
// This is a naive implementation that could be improved by
// using a priority queue to find the best candidates.
for _, neighbor := range n.neighbors {
for id, candidate := range neighbor.neighbors {
if _, ok := n.neighbors[id]; ok {
for key, candidate := range neighbor.neighbors {
if _, ok := n.neighbors[key]; ok {
// do not add duplicates
continue
}
@@ -177,7 +177,7 @@ func (n *layerNode[K]) replenish(m int) {
// to neighbors.
func (n *layerNode[K]) isolate(m int) {
for _, neighbor := range n.neighbors {
delete(neighbor.neighbors, n.ID)
delete(neighbor.neighbors, n.Key)
neighbor.replenish(m)
}
}
@@ -214,6 +214,7 @@ func (l *layer[K]) size() int {

// Graph is a Hierarchical Navigable Small World graph.
// All public parameters must be set before adding nodes to the graph.
// K is cmp.Ordered instead of comparable so that they can be sorted.
type Graph[K cmp.Ordered] struct {
// Distance is the distance function used to compare embeddings.
Distance DistanceFunc
@@ -316,7 +317,7 @@ func (g *Graph[T]) Dims() int {
if len(g.layers) == 0 {
return 0
}
return len(g.layers[0].entry().Vec)
return len(g.layers[0].entry().Value)
}

func ptr[T any](v T) *T {
@@ -327,8 +328,8 @@ func ptr[T any](v T) *T {
// If another node with the same ID exists, it is replaced.
func (g *Graph[K]) Add(nodes ...Node[K]) {
for _, node := range nodes {
id := node.ID
vec := node.Vec
key := node.Key
vec := node.Value

g.assertDims(vec)
insertLevel := g.randomLevel()
@@ -350,14 +351,14 @@ func (g *Graph[K]) Add(nodes ...Node[K]) {
layer := g.layers[i]
newNode := &layerNode[K]{
Node: Node[K]{
ID: id,
Vec: vec,
Key: key,
Value: vec,
},
}

// Insert the new node into the layer.
if layer.entry() == nil {
layer.nodes = map[K]*layerNode[K]{id: newNode}
layer.nodes = map[K]*layerNode[K]{key: newNode}
continue
}

@@ -383,14 +384,14 @@ func (g *Graph[K]) Add(nodes ...Node[K]) {
}

// Re-set the elevator node for the next layer.
elevator = ptr(neighborhood[0].node.ID)
elevator = ptr(neighborhood[0].node.Key)

if insertLevel >= i {
if _, ok := layer.nodes[id]; ok {
g.Delete(id)
if _, ok := layer.nodes[key]; ok {
g.Delete(key)
}
// Insert the new node into the layer.
layer.nodes[id] = newNode
layer.nodes[key] = newNode
for _, node := range neighborhood {
// Create a bi-directional edge between the new node and the best node.
node.node.addNeighbor(newNode, g.M, g.Distance)
@@ -428,7 +429,7 @@ func (h *Graph[K]) Search(near Vector, k int) []Node[K] {
// Descending hierarchies
if layer > 0 {
nodes := searchPoint.search(1, efSearch, near, h.Distance)
elevator = ptr(nodes[0].node.ID)
elevator = ptr(nodes[0].node.Key)
continue
}

@@ -453,37 +454,37 @@ func (h *Graph[T]) Len() int {
return h.layers[0].size()
}

// Delete removes a node from the graph by ID.
// Delete removes a node from the graph by key.
// It tries to preserve the clustering properties of the graph by
// replenishing connectivity in the affected neighborhoods.
func (h *Graph[K]) Delete(id K) bool {
func (h *Graph[K]) Delete(key K) bool {
if len(h.layers) == 0 {
return false
}

var deleted bool
for _, layer := range h.layers {
node, ok := layer.nodes[id]
node, ok := layer.nodes[key]
if !ok {
continue
}
delete(layer.nodes, id)
delete(layer.nodes, key)
node.isolate(h.M)
deleted = true
}

return deleted
}

// Lookup returns the vector with the given ID.
func (h *Graph[K]) Lookup(id K) (Vector, bool) {
// Lookup returns the vector with the given key.
func (h *Graph[K]) Lookup(key K) (Vector, bool) {
if len(h.layers) == 0 {
return nil, false
}

node, ok := h.layers[0].nodes[id]
node, ok := h.layers[0].nodes[key]
if !ok {
return nil, false
}
return node.Vec, ok
return node.Value, ok
}
52 changes: 26 additions & 26 deletions graph_test.go
Original file line number Diff line number Diff line change
@@ -22,38 +22,38 @@ func Test_maxLevel(t *testing.T) {
func Test_layerNode_search(t *testing.T) {
entry := &layerNode[int]{
Node: Node[int]{
Vec: Vector{0},
ID: 0,
Value: Vector{0},
Key: 0,
},
neighbors: map[int]*layerNode[int]{
1: {
Node: Node[int]{
Vec: Vector{1},
ID: 1,
Value: Vector{1},
Key: 1,
},
},
2: {
Node: Node[int]{
Vec: Vector{2},
ID: 2,
Value: Vector{2},
Key: 2,
},
},
3: {
Node: Node[int]{
Vec: Vector{3},
ID: 3,
Value: Vector{3},
Key: 3,
},
neighbors: map[int]*layerNode[int]{
4: {
Node: Node[int]{
Vec: Vector{4},
ID: 5,
Value: Vector{4},
Key: 5,
},
},
5: {
Node: Node[int]{
Vec: Vector{5},
ID: 5,
Value: Vector{5},
Key: 5,
},
},
},
@@ -63,8 +63,8 @@ func Test_layerNode_search(t *testing.T) {

best := entry.search(2, 4, []float32{4}, EuclideanDistance)

require.Equal(t, 5, best[0].node.ID)
require.Equal(t, 3, best[1].node.ID)
require.Equal(t, 5, best[0].node.Key)
require.Equal(t, 3, best[1].node.Key)
require.Len(t, best, 2)
}

@@ -86,8 +86,8 @@ func TestGraph_AddSearch(t *testing.T) {
for i := 0; i < 128; i++ {
g.Add(
Node[int]{
ID: i,
Vec: Vector{float32(i)},
Key: i,
Value: Vector{float32(i)},
},
)
}
@@ -131,8 +131,8 @@ func TestGraph_AddDelete(t *testing.T) {
g := newTestGraph[int]()
for i := 0; i < 128; i++ {
g.Add(Node[int]{
ID: i,
Vec: Vector{float32(i)},
Key: i,
Value: Vector{float32(i)},
})
}

@@ -176,8 +176,8 @@ func Benchmark_HSNW(b *testing.B) {
g.Distance = EuclideanDistance
for i := 0; i < size; i++ {
g.Add(Node[int]{
ID: i,
Vec: Vector{float32(i)},
Key: i,
Value: Vector{float32(i)},
})
}
b.ResetTimer()
@@ -210,8 +210,8 @@ func Benchmark_HNSW_1536(b *testing.B) {
points := make([]Node[int], size)
for i := 0; i < size; i++ {
points[i] = Node[int]{
ID: i,
Vec: Vector(randFloats(1536)),
Key: i,
Value: Vector(randFloats(1536)),
}
g.Add(points[i])
}
@@ -220,7 +220,7 @@ func Benchmark_HNSW_1536(b *testing.B) {
b.Run("Search", func(b *testing.B) {
for i := 0; i < b.N; i++ {
g.Search(
points[i%size].Vec,
points[i%size].Value,
4,
)
}
@@ -230,9 +230,9 @@ func Benchmark_HNSW_1536(b *testing.B) {
func TestGraph_DefaultCosine(t *testing.T) {
g := NewGraph[int]()
g.Add(
Node[int]{ID: 1, Vec: Vector{1, 1}},
Node[int]{ID: 2, Vec: Vector{0, 1}},
Node[int]{ID: 3, Vec: Vector{1, -1}},
Node[int]{Key: 1, Value: Vector{1, 1}},
Node[int]{Key: 2, Value: Vector{0, 1}},
Node[int]{Key: 3, Value: Vector{1, -1}},
)

neighbors := g.Search(