Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor API so that keys are generic #6

Merged
merged 7 commits into from
Jun 14, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
WIP: graph.go compiles!
ammario committed May 31, 2024
commit 4875617857535f53328b02e06e11189044c15477
106 changes: 62 additions & 44 deletions graph.go
Original file line number Diff line number Diff line change
@@ -14,10 +14,15 @@ import (

type Vector = []float32

// Node is a node in the graph.
type Node[K cmp.Ordered] struct {
ID K
Vec Vector
}

// layerNode is a node in a layer of the graph.
type layerNode[K cmp.Ordered] struct {
id K
vec Vector
Node[K]

// neighbors is map of neighbor IDs to neighbor nodes.
// It is a map and not a slice to allow for efficient deletes, esp.
@@ -32,7 +37,7 @@ func (n *layerNode[K]) addNeighbor(newNode *layerNode[K], m int, dist DistanceFu
n.neighbors = make(map[K]*layerNode[K], m)
}

n.neighbors[newNode.id] = newNode
n.neighbors[newNode.ID] = newNode
if len(n.neighbors) <= m {
return
}
@@ -43,7 +48,7 @@ func (n *layerNode[K]) addNeighbor(newNode *layerNode[K], m int, dist DistanceFu
worst *layerNode[K]
)
for _, neighbor := range n.neighbors {
d := dist(neighbor.vec, n.vec)
d := dist(neighbor.Vec, n.Vec)
// d > worstDist may always be false if the distance function
// returns NaN, e.g., when the embeddings are zero.
if d > worstDist || worst == nil {
@@ -52,9 +57,9 @@ func (n *layerNode[K]) addNeighbor(newNode *layerNode[K], m int, dist DistanceFu
}
}

delete(n.neighbors, worst.id)
delete(n.neighbors, worst.ID)
// Delete backlink from the worst neighbor.
delete(worst.neighbors, n.id)
delete(worst.neighbors, n.ID)
worst.replenish(m)
}

@@ -83,7 +88,7 @@ func (n *layerNode[K]) search(
candidates.Push(
searchCandidate[K]{
node: n,
dist: distance(n.vec, target),
dist: distance(n.Vec, target),
},
)
var (
@@ -94,7 +99,7 @@ func (n *layerNode[K]) search(

// Begin with the entry node in the result set.
result.Push(candidates.Min())
visited[n.id] = true
visited[n.ID] = true

for candidates.Len() > 0 {
var (
@@ -113,7 +118,7 @@ func (n *layerNode[K]) search(
}
visited[neighborID] = true

dist := distance(neighbor.vec, target)
dist := distance(neighbor.Vec, target)
improved = improved || dist < result.Min().dist
if result.Len() < k {
result.Push(searchCandidate[K]{node: neighbor, dist: dist})
@@ -168,7 +173,7 @@ func (n *layerNode[K]) replenish(m int) {
// to neighbors.
func (n *layerNode[K]) isolate(m int) {
for _, neighbor := range n.neighbors {
delete(neighbor.neighbors, n.id)
delete(neighbor.neighbors, n.ID)
neighbor.replenish(m)
}
}
@@ -179,7 +184,7 @@ type layer[K cmp.Ordered] struct {
// property of the graph.
//
// nodes is exported for interop with encoding/gob.
nodes map[string]*layerNode[K]
nodes map[K]*layerNode[K]
}

// entry returns the entry node of the layer.
@@ -237,8 +242,8 @@ func defaultRand() *rand.Rand {

// NewGraph returns a new graph with default parameters, roughly designed for
// storing OpenAI embeddings.
func NewGraph[K cmp.Ordered, V Embeddable[K]]() *Graph[K, V] {
return &Graph[K, V]{
func NewGraph[K cmp.Ordered]() *Graph[K] {
return &Graph[K]{
M: 16,
Ml: 0.25,
Distance: CosineDistance,
@@ -307,38 +312,48 @@ func (g *Graph[T]) Dims() int {
if len(g.layers) == 0 {
return 0
}
return len(g.layers[0].entry().Point.Embedding())
return len(g.layers[0].entry().Vec)
}

func ptr[T any](v T) *T {
return &v
}

// Add inserts nodes into the graph.
// If another node with the same ID exists, it is replaced.
func (g *Graph[T]) Add(nodes ...T) {
for _, n := range nodes {
g.assertDims(n.Embedding())
func (g *Graph[K]) Add(nodes ...Node[K]) {
for _, node := range nodes {
id := node.ID
vec := node.Vec

g.assertDims(vec)
insertLevel := g.randomLevel()
// Create layers that don't exist yet.
for insertLevel >= len(g.layers) {
g.layers = append(g.layers, &layer[T]{})
g.layers = append(g.layers, &layer[K]{})
}

if insertLevel < 0 {
panic("invalid level")
}

var elevator string
var elevator *K

preLen := g.Len()

// Insert node at each layer, beginning with the highest.
for i := len(g.layers) - 1; i >= 0; i-- {
layer := g.layers[i]
newNode := &layerNode[T]{
vec: n,
newNode := &layerNode[K]{
Node: Node[K]{
ID: id,
Vec: vec,
},
}

// Insert the new node into the layer.
if layer.entry() == nil {
layer.Nodes = map[string]*layerNode[T]{n.ID(): newNode}
layer.nodes = map[K]*layerNode[K]{id: newNode}
continue
}

@@ -348,30 +363,30 @@ func (g *Graph[T]) Add(nodes ...T) {

// On subsequent layers, we use the elevator node to enter the graph
// at the best point.
if elevator != "" {
searchPoint = layer.Nodes[elevator]
if elevator != nil {
searchPoint = layer.nodes[*elevator]
}

if g.Distance == nil {
panic("(*Graph).Distance must be set")
}

neighborhood := searchPoint.search(g.M, g.EfSearch, n.Embedding(), g.Distance)
neighborhood := searchPoint.search(g.M, g.EfSearch, vec, g.Distance)
if len(neighborhood) == 0 {
// This should never happen because the searchPoint itself
// should be in the result set.
panic("no nodes found")
}

// Re-set the elevator node for the next layer.
elevator = neighborhood[0].node.Point.ID()
elevator = ptr(neighborhood[0].node.ID)

if insertLevel >= i {
if _, ok := layer.Nodes[n.ID()]; ok {
g.Delete(n.ID())
if _, ok := layer.nodes[id]; ok {
g.Delete(id)
}
// Insert the new node into the layer.
layer.Nodes[n.ID()] = newNode
layer.nodes[id] = newNode
for _, node := range neighborhood {
// Create a bi-directional edge between the new node and the best node.
node.node.addNeighbor(newNode, g.M, g.Distance)
@@ -388,7 +403,7 @@ func (g *Graph[T]) Add(nodes ...T) {
}

// Search finds the k nearest neighbors from the target node.
func (h *Graph[T]) Search(near Vector, k int) []T {
func (h *Graph[K]) Search(near Vector, k int) []Node[K] {
h.assertDims(near)
if len(h.layers) == 0 {
return nil
@@ -397,27 +412,27 @@ func (h *Graph[T]) Search(near Vector, k int) []T {
var (
efSearch = h.EfSearch

elevator string
elevator *K
)

for layer := len(h.layers) - 1; layer >= 0; layer-- {
searchPoint := h.layers[layer].entry()
if elevator != "" {
searchPoint = h.layers[layer].Nodes[elevator]
if elevator != nil {
searchPoint = h.layers[layer].nodes[*elevator]
}

// Descending hierarchies
if layer > 0 {
nodes := searchPoint.search(1, efSearch, near, h.Distance)
elevator = nodes[0].node.Point.ID()
elevator = ptr(nodes[0].node.ID)
continue
}

nodes := searchPoint.search(k, efSearch, near, h.Distance)
out := make([]T, 0, len(nodes))
out := make([]Node[K], 0, len(nodes))

for _, node := range nodes {
out = append(out, node.node.Point.(T))
out = append(out, node.node.Node)
}

return out
@@ -437,31 +452,34 @@ func (h *Graph[T]) Len() int {
// Delete removes a node from the graph by ID.
// It tries to preserve the clustering properties of the graph by
// replenishing connectivity in the affected neighborhoods.
func (h *Graph[T]) Delete(id string) bool {
func (h *Graph[K]) Delete(id K) bool {
if len(h.layers) == 0 {
return false
}

var deleted bool
for _, layer := range h.layers {
node, ok := layer.Nodes[id]
node, ok := layer.nodes[id]
if !ok {
continue
}
delete(layer.Nodes, id)
delete(layer.nodes, id)
node.isolate(h.M)
deleted = true
}

return deleted
}

// Lookup returns the node with the given ID.
func (h *Graph[T]) Lookup(id string) (T, bool) {
var zero T
// Lookup returns the vector with the given ID.
func (h *Graph[K]) Lookup(id K) (Vector, bool) {
if len(h.layers) == 0 {
return zero, false
return nil, false
}

return h.layers[0].Nodes[id].Point.(T), true
node, ok := h.layers[0].nodes[id]
if !ok {
return nil, false
}
return node.Vec, ok
}

Unchanged files with check annotations Beta

// methods for analyzing it. It offers no compatibility guarantee
// as the methods of measuring the graph's health with change
// with the implementation.
type Analyzer[T Embeddable] struct {

Check failure on line 7 in analyzer.go

GitHub Actions / test

undefined: Embeddable
Graph *Graph[T]

Check failure on line 8 in analyzer.go

GitHub Actions / test

T does not satisfy cmp.Ordered
}
func (a *Analyzer[T]) Height() int {
func (a *Analyzer[T]) Connectivity() []float64 {
var layerConnectivity []float64
for _, layer := range a.Graph.layers {
if len(layer.Nodes) == 0 {

Check failure on line 20 in analyzer.go

GitHub Actions / test

layer.Nodes undefined (type *layer[T] has no field or method Nodes, but does have nodes)
continue
}
var sum float64
for _, node := range layer.Nodes {

Check failure on line 25 in analyzer.go

GitHub Actions / test

layer.Nodes undefined (type *layer[T] has no field or method Nodes, but does have nodes)
sum += float64(len(node.neighbors))
}
layerConnectivity = append(layerConnectivity, sum/float64(len(layer.Nodes)))

Check failure on line 29 in analyzer.go

GitHub Actions / test

layer.Nodes undefined (type *layer[T] has no field or method Nodes, but does have nodes)
}
return layerConnectivity
func (a *Analyzer[T]) Topography() []int {
var topography []int
for _, layer := range a.Graph.layers {
topography = append(topography, len(layer.Nodes))

Check failure on line 39 in analyzer.go

GitHub Actions / test

layer.Nodes undefined (type *layer[T] has no field or method Nodes, but does have nodes)
}
return topography
}
return fmt.Errorf("encode number of layers: %w", err)
}
for _, layer := range h.layers {
_, err = binaryWrite(w, len(layer.Nodes))

Check failure on line 137 in encode.go

GitHub Actions / test

layer.Nodes undefined (type *layer[T] has no field or method Nodes, but does have nodes)
if err != nil {
return fmt.Errorf("encode number of nodes: %w", err)
}
// changes to a file upon calls to Save. It is more convenient
// but less powerful than calling Graph.Export and Graph.Import
// directly.
type SavedGraph[T Embeddable] struct {

Check failure on line 256 in encode.go

GitHub Actions / test

undefined: Embeddable
*Graph[T]

Check failure on line 257 in encode.go

GitHub Actions / test

T does not satisfy cmp.Ordered
Path string
}
//
// It does not hold open a file descriptor, so SavedGraph can be forgotten
// without ever calling Save.
func LoadSavedGraph[T Embeddable](path string) (*SavedGraph[T], error) {

Check failure on line 268 in encode.go

GitHub Actions / test

undefined: Embeddable
f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0o600)
if err != nil {
return nil, err