Skip to content

Commit 2568b3d

Browse files
added gpu_bindings.go to enable transferring index to GPU
In addition added a test for the bindings which show how to use the functionality.
1 parent 4faa3a5 commit 2568b3d

File tree

2 files changed

+90
-0
lines changed

2 files changed

+90
-0
lines changed

gpu_bindings.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package faiss
2+
3+
/*
4+
#include <faiss/c_api/gpu/StandardGpuResources_c.h>
5+
#include <faiss/c_api/gpu/GpuAutoTune_c.h>
6+
*/
7+
import "C"
8+
import (
9+
"errors"
10+
)
11+
12+
func TransferToGpu(index Index) (Index, error) {
13+
var gpuResources *C.FaissStandardGpuResources
14+
var gpuIndex *C.FaissGpuIndex
15+
c := C.faiss_StandardGpuResources_new(&gpuResources)
16+
if c != 0 {
17+
return nil, errors.New("error on init gpu")
18+
}
19+
20+
exitCode := C.faiss_index_cpu_to_gpu(gpuResources, 0, index.cPtr(), &gpuIndex)
21+
if exitCode != 0 {
22+
return nil, errors.New("error gpu blabla")
23+
}
24+
25+
return &faissIndex{idx: gpuIndex}, nil
26+
}

gpu_bindings_test.go

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
package faiss
2+
3+
import (
4+
"fmt"
5+
"github.com/stretchr/testify/require"
6+
"log"
7+
"testing"
8+
)
9+
10+
func TestFlatIndexOnGpu(t *testing.T) {
11+
index, err := NewIndexFlatL2(1)
12+
if err != nil {
13+
log.Fatal(err)
14+
}
15+
16+
idx, err := TransferToGpu(index)
17+
if err != nil {
18+
log.Fatal(err)
19+
}
20+
vectorsToAdd := []float32{1,2,3,4,5}
21+
err = idx.Add(vectorsToAdd)
22+
if err != nil {
23+
fmt.Println(err.Error())
24+
}
25+
distances, resultIds, err := idx.Search(vectorsToAdd, 5)
26+
fmt.Println(distances, resultIds, err)
27+
for i := range vectorsToAdd {
28+
require.Equal(t, int64(i), resultIds[len(vectorsToAdd)*i])
29+
require.Equal(t, float32(0), distances[len(vectorsToAdd)*i])
30+
}
31+
}
32+
33+
func TestIndexIDMapOnGPU(t *testing.T) {
34+
index, err := NewIndexFlatL2(1)
35+
if err != nil {
36+
log.Fatal(err)
37+
}
38+
39+
indexMap, err := NewIndexIDMap(index)
40+
if err != nil {
41+
fmt.Println(err.Error())
42+
}
43+
idx, err := TransferToGpu(indexMap)
44+
if err != nil {
45+
log.Fatal(err)
46+
}
47+
vectorsToAdd := []float32{1,2,3,4,5}
48+
ids := make([]int64, len(vectorsToAdd))
49+
for i := 0; i < len(vectorsToAdd); i++ {
50+
ids[i] = int64(i)
51+
}
52+
53+
err = idx.AddWithIDs(vectorsToAdd, ids)
54+
if err != nil {
55+
fmt.Println(err.Error())
56+
}
57+
distances, resultIds, err := idx.Search(vectorsToAdd, 5)
58+
fmt.Println(idx.D(), idx.Ntotal())
59+
fmt.Println(distances, resultIds, err)
60+
for i := range vectorsToAdd {
61+
require.Equal(t, ids[i], resultIds[len(vectorsToAdd)*i])
62+
require.Equal(t, float32(0), distances[len(vectorsToAdd)*i])
63+
}
64+
}

0 commit comments

Comments
 (0)