Skip to content

Commit

Permalink
Update variable name ID to key
Browse files Browse the repository at this point in the history
  • Loading branch information
ammario committed Jun 14, 2024
1 parent f00e907 commit ac9deaf
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 82 deletions.
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -130,18 +130,18 @@ $$

where:
* $n$ is the number of vectors in the graph
* $\text{size(id)}$ is the average size of the ID in bytes
* $\text{size(key)}$ is the average size of the key in bytes
* $M$ is the maximum number of neighbors each node can have
* $d$ is the dimensionality of the vectors
* $mem_{graph}$ is the memory used by the graph structure across all layers
* $mem_{base}$ is the memory used by the vectors themselves in the base or 0th layer

You can infer that:
* Connectivity ($M$) is very expensive if IDs are large
* If $d \cdot 4$ is far larger than $M \cdot \text{size(id)}$, you should expect linear memory usage spent on representing vector data
* If $d \cdot 4$ is far smaller than $M \cdot \text{size(id)}$, you should expect $n \cdot \log(n)$ memory usage spent on representing graph structure
* Connectivity ($M$) is very expensive if keys are large
* If $d \cdot 4$ is far larger than $M \cdot \text{size(key)}$, you should expect linear memory usage spent on representing vector data
* If $d \cdot 4$ is far smaller than $M \cdot \text{size(key)}$, you should expect $n \cdot \log(n)$ memory usage spent on representing graph structure

In the example of a graph with 256 dimensions, and $M = 16$, with 8 byte IDs, you would see that each vector takes:
In the example of a graph with 256 dimensions, and $M = 16$, with 8 byte keys, you would see that each vector takes:

* $256 \cdot 4 = 1024$ data bytes
* $16 \cdot 8 = 128$ metadata bytes
Expand Down
16 changes: 8 additions & 8 deletions encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ func (h *Graph[K]) Export(w io.Writer) error {
return fmt.Errorf("encode number of nodes: %w", err)
}
for _, node := range layer.nodes {
_, err = multiBinaryWrite(w, node.ID, node.Vec, len(node.neighbors))
_, err = multiBinaryWrite(w, node.Key, node.Value, len(node.neighbors))
if err != nil {
return fmt.Errorf("encode node data: %w", err)
}
Expand Down Expand Up @@ -218,10 +218,10 @@ func (h *Graph[K]) Import(r io.Reader) error {

nodes := make(map[K]*layerNode[K], nNodes)
for j := 0; j < nNodes; j++ {
var id K
var key K
var vec Vector
var nNeighbors int
_, err = multiBinaryRead(r, &id, &vec, &nNeighbors)
_, err = multiBinaryRead(r, &key, &vec, &nNeighbors)
if err != nil {
return fmt.Errorf("decoding node %d: %w", j, err)
}
Expand All @@ -238,21 +238,21 @@ func (h *Graph[K]) Import(r io.Reader) error {

node := &layerNode[K]{
Node: Node[K]{
ID: id,
Vec: vec,
Key: key,
Value: vec,
},
neighbors: make(map[K]*layerNode[K]),
}

nodes[id] = node
nodes[key] = node
for _, neighbor := range neighbors {
node.neighbors[neighbor] = nil
}
}
// Fill in neighbor pointers
for _, node := range nodes {
for id := range node.neighbors {
node.neighbors[id] = nodes[id]
for key := range node.neighbors {
node.neighbors[key] = nodes[key]
}
}
h.layers[i] = &layer[K]{nodes: nodes}
Expand Down
10 changes: 5 additions & 5 deletions encode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,17 @@ func verifyGraphNodes[K cmp.Ordered](t *testing.T, g *Graph[K]) {
for _, layer := range g.layers {
for _, node := range layer.nodes {
for neighborKey, neighbor := range node.neighbors {
_, ok := layer.nodes[neighbor.ID]
_, ok := layer.nodes[neighbor.Key]
if !ok {
t.Errorf(
"node %v has neighbor %v, but neighbor does not exist",
node.ID, neighbor.ID,
node.Key, neighbor.Key,
)
}

if neighborKey != neighbor.ID {
t.Errorf("node %v has neighbor %v, but neighbor key is %v", node.ID,
neighbor.ID,
if neighborKey != neighbor.Key {
t.Errorf("node %v has neighbor %v, but neighbor key is %v", node.Key,
neighbor.Key,
neighborKey,
)
}
Expand Down
2 changes: 1 addition & 1 deletion example/readme/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@ func main() {
[]float32{0.5, 0.5, 0.5},
1,
)
fmt.Printf("best friend: %v\n", neighbors[0].Vec)
fmt.Printf("best friend: %v\n", neighbors[0].Value)
}
75 changes: 38 additions & 37 deletions graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,19 @@ type Vector = []float32

// Node is a node in the graph.
type Node[K cmp.Ordered] struct {
ID K
Vec Vector
Key K
Value Vector
}

func MakeNode[K cmp.Ordered](id K, vec Vector) Node[K] {
return Node[K]{ID: id, Vec: vec}
func MakeNode[K cmp.Ordered](key K, vec Vector) Node[K] {
return Node[K]{Key: key, Value: vec}
}

// layerNode is a node in a layer of the graph.
type layerNode[K cmp.Ordered] struct {
Node[K]

// neighbors is map of neighbor IDs to neighbor nodes.
// neighbors is map of neighbor keys to neighbor nodes.
// It is a map and not a slice to allow for efficient deletes, esp.
// when M is high.
neighbors map[K]*layerNode[K]
Expand All @@ -41,7 +41,7 @@ func (n *layerNode[K]) addNeighbor(newNode *layerNode[K], m int, dist DistanceFu
n.neighbors = make(map[K]*layerNode[K], m)
}

n.neighbors[newNode.ID] = newNode
n.neighbors[newNode.Key] = newNode
if len(n.neighbors) <= m {
return
}
Expand All @@ -52,7 +52,7 @@ func (n *layerNode[K]) addNeighbor(newNode *layerNode[K], m int, dist DistanceFu
worst *layerNode[K]
)
for _, neighbor := range n.neighbors {
d := dist(neighbor.Vec, n.Vec)
d := dist(neighbor.Value, n.Value)
// d > worstDist may always be false if the distance function
// returns NaN, e.g., when the embeddings are zero.
if d > worstDist || worst == nil {
Expand All @@ -61,9 +61,9 @@ func (n *layerNode[K]) addNeighbor(newNode *layerNode[K], m int, dist DistanceFu
}
}

delete(n.neighbors, worst.ID)
delete(n.neighbors, worst.Key)
// Delete backlink from the worst neighbor.
delete(worst.neighbors, n.ID)
delete(worst.neighbors, n.Key)
worst.replenish(m)
}

Expand Down Expand Up @@ -92,7 +92,7 @@ func (n *layerNode[K]) search(
candidates.Push(
searchCandidate[K]{
node: n,
dist: distance(n.Vec, target),
dist: distance(n.Value, target),
},
)
var (
Expand All @@ -103,7 +103,7 @@ func (n *layerNode[K]) search(

// Begin with the entry node in the result set.
result.Push(candidates.Min())
visited[n.ID] = true
visited[n.Key] = true

for candidates.Len() > 0 {
var (
Expand All @@ -113,16 +113,16 @@ func (n *layerNode[K]) search(

// We iterate the map in a sorted, deterministic fashion for
// tests.
neighborIDs := maps.Keys(current.neighbors)
slices.Sort(neighborIDs)
for _, neighborID := range neighborIDs {
neighborKeys := maps.Keys(current.neighbors)
slices.Sort(neighborKeys)
for _, neighborID := range neighborKeys {
neighbor := current.neighbors[neighborID]
if visited[neighborID] {
continue
}
visited[neighborID] = true

dist := distance(neighbor.Vec, target)
dist := distance(neighbor.Value, target)
improved = improved || dist < result.Min().dist
if result.Len() < k {
result.Push(searchCandidate[K]{node: neighbor, dist: dist})
Expand Down Expand Up @@ -157,8 +157,8 @@ func (n *layerNode[K]) replenish(m int) {
// This is a naive implementation that could be improved by
// using a priority queue to find the best candidates.
for _, neighbor := range n.neighbors {
for id, candidate := range neighbor.neighbors {
if _, ok := n.neighbors[id]; ok {
for key, candidate := range neighbor.neighbors {
if _, ok := n.neighbors[key]; ok {
// do not add duplicates
continue
}
Expand All @@ -177,7 +177,7 @@ func (n *layerNode[K]) replenish(m int) {
// to neighbors.
func (n *layerNode[K]) isolate(m int) {
for _, neighbor := range n.neighbors {
delete(neighbor.neighbors, n.ID)
delete(neighbor.neighbors, n.Key)
neighbor.replenish(m)
}
}
Expand Down Expand Up @@ -214,6 +214,7 @@ func (l *layer[K]) size() int {

// Graph is a Hierarchical Navigable Small World graph.
// All public parameters must be set before adding nodes to the graph.
// K is cmp.Ordered instead of of comparable so that they can be sorted.
type Graph[K cmp.Ordered] struct {
// Distance is the distance function used to compare embeddings.
Distance DistanceFunc
Expand Down Expand Up @@ -316,7 +317,7 @@ func (g *Graph[T]) Dims() int {
if len(g.layers) == 0 {
return 0
}
return len(g.layers[0].entry().Vec)
return len(g.layers[0].entry().Value)
}

func ptr[T any](v T) *T {
Expand All @@ -327,8 +328,8 @@ func ptr[T any](v T) *T {
// If another node with the same ID exists, it is replaced.
func (g *Graph[K]) Add(nodes ...Node[K]) {
for _, node := range nodes {
id := node.ID
vec := node.Vec
key := node.Key
vec := node.Value

g.assertDims(vec)
insertLevel := g.randomLevel()
Expand All @@ -350,14 +351,14 @@ func (g *Graph[K]) Add(nodes ...Node[K]) {
layer := g.layers[i]
newNode := &layerNode[K]{
Node: Node[K]{
ID: id,
Vec: vec,
Key: key,
Value: vec,
},
}

// Insert the new node into the layer.
if layer.entry() == nil {
layer.nodes = map[K]*layerNode[K]{id: newNode}
layer.nodes = map[K]*layerNode[K]{key: newNode}
continue
}

Expand All @@ -383,14 +384,14 @@ func (g *Graph[K]) Add(nodes ...Node[K]) {
}

// Re-set the elevator node for the next layer.
elevator = ptr(neighborhood[0].node.ID)
elevator = ptr(neighborhood[0].node.Key)

if insertLevel >= i {
if _, ok := layer.nodes[id]; ok {
g.Delete(id)
if _, ok := layer.nodes[key]; ok {
g.Delete(key)
}
// Insert the new node into the layer.
layer.nodes[id] = newNode
layer.nodes[key] = newNode
for _, node := range neighborhood {
// Create a bi-directional edge between the new node and the best node.
node.node.addNeighbor(newNode, g.M, g.Distance)
Expand Down Expand Up @@ -428,7 +429,7 @@ func (h *Graph[K]) Search(near Vector, k int) []Node[K] {
// Descending hierarchies
if layer > 0 {
nodes := searchPoint.search(1, efSearch, near, h.Distance)
elevator = ptr(nodes[0].node.ID)
elevator = ptr(nodes[0].node.Key)
continue
}

Expand All @@ -453,37 +454,37 @@ func (h *Graph[T]) Len() int {
return h.layers[0].size()
}

// Delete removes a node from the graph by ID.
// Delete removes a node from the graph by key.
// It tries to preserve the clustering properties of the graph by
// replenishing connectivity in the affected neighborhoods.
func (h *Graph[K]) Delete(id K) bool {
func (h *Graph[K]) Delete(key K) bool {
if len(h.layers) == 0 {
return false
}

var deleted bool
for _, layer := range h.layers {
node, ok := layer.nodes[id]
node, ok := layer.nodes[key]
if !ok {
continue
}
delete(layer.nodes, id)
delete(layer.nodes, key)
node.isolate(h.M)
deleted = true
}

return deleted
}

// Lookup returns the vector with the given ID.
func (h *Graph[K]) Lookup(id K) (Vector, bool) {
// Lookup returns the vector with the given key.
func (h *Graph[K]) Lookup(key K) (Vector, bool) {
if len(h.layers) == 0 {
return nil, false
}

node, ok := h.layers[0].nodes[id]
node, ok := h.layers[0].nodes[key]
if !ok {
return nil, false
}
return node.Vec, ok
return node.Value, ok
}
Loading

0 comments on commit ac9deaf

Please sign in to comment.