Skip to content

Commit

Permalink
add index after adding data without vector schema for vector predicate
Browse files Browse the repository at this point in the history
  • Loading branch information
shivaji-kharse authored and mangalaman93 committed Dec 6, 2024
1 parent 889008f commit 0cdd0d7
Show file tree
Hide file tree
Showing 3 changed files with 207 additions and 4 deletions.
35 changes: 31 additions & 4 deletions posting/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,23 @@ func (txn *Txn) addIndexMutations(ctx context.Context, info *indexMutationInfo)
// Similarly, the current assumption is that we have at most one
// Vector Index, but this assumption may break later.
if info.op != pb.DirectedEdge_DEL &&
len(data) > 0 && data[0].Tid == types.VFloatID &&
len(info.factorySpecs) > 0 {
len(data) > 0 && len(info.factorySpecs) > 0 {
// retrieve vector from inUuid save as inVec
inVec := types.BytesAsFloatArray(data[0].Value.([]byte))

var inVec []float32
// Check if it is an add index mutation on the vector predicate.
// If yes, we will extract inVec from info.edge.Value, which we
// converted earlier. Otherwise, we will extract it from data.
if data[0].Tid.Enum().Type() == types.DefaultID.Enum().Type() {
info.edge.Value = info.val.Value.([]byte)
info.edge.ValueType = types.VFloatID.Enum()
if err := pl.addMutation(ctx, txn, info.edge); err != nil {
return []*pb.DirectedEdge{}, err
}
inVec = types.BytesAsFloatArray(info.val.Value.([]byte))
} else {
inVec = types.BytesAsFloatArray(data[0].Value.([]byte))
}
tc := hnsw.NewTxnCache(NewViTxn(txn), txn.StartTs)
indexer, err := info.factorySpecs[0].CreateIndex(attr)
if err != nil {
Expand Down Expand Up @@ -1393,6 +1406,20 @@ func rebuildTokIndex(ctx context.Context, rb *IndexRebuild) error {
Value: p.Value,
Tid: types.TypeID(p.ValType),
}
// If this is an addIndex mutation for the vector predicate,
// convert the default value type to the VFloat type.
if runForVectors && types.TypeID(p.ValType) != types.VFloatID {
sv, err := types.Convert(val, types.VFloatID)
if err != nil {
return err
}
b := types.ValueForType(types.BinaryID)
if err = types.Marshal(sv, &b); err != nil {
return err
}
val.Value = b.Value
val.Tid = types.VFloatID
}
edge.Lang = string(p.LangTag)

for {
Expand All @@ -1419,7 +1446,7 @@ func rebuildTokIndex(ctx context.Context, rb *IndexRebuild) error {
}
return edges, err
}
if len(factorySpecs) != 0 {
if runForVectors {
return builder.RunWithoutTemp(ctx)
}
return builder.Run(ctx)
Expand Down
56 changes: 56 additions & 0 deletions query/vector/vector_graphql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,15 @@
package query

import (
"context"
"encoding/json"
"fmt"
"math/rand"
"testing"

"github.com/dgraph-io/dgo/v240/protos/api"
"github.com/dgraph-io/dgraph/v24/dgraphapi"
"github.com/dgraph-io/dgraph/v24/x"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -256,3 +259,56 @@ func testVectorGraphQlMutationAndQuery(t *testing.T, hc *dgraphapi.HTTPClient) {
}
}
}

// TestVectorIndexDropPredicate exercises re-adding the vector index through the
// GraphQL schema after the title_v predicate has been dropped and repopulated.
//
// Steps: apply the indexed schema, add a batch of projects, switch to a schema
// whose embedding carries no search directive, drop title_v, add a second batch
// of projects, re-apply the indexed schema, and finally check that every
// similarity-query result carries a vector that was inserted by this test.
func TestVectorIndexDropPredicate(t *testing.T) {
	const numProjects = 100
	ctx := context.Background()

	gc, cleanup, err := dc.Client()
	require.NoError(t, err)
	defer cleanup()
	err = gc.LoginIntoNamespace(ctx, dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)
	require.NoError(t, err)

	hc, err := dc.HTTPClient()
	require.NoError(t, err)
	err = hc.LoginIntoNamespace(dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)
	require.NoError(t, err)

	// The same formatted schema is applied at the start and again at the end.
	indexedSchema := fmt.Sprintf(graphQLVectorSchema, "euclidean")
	require.NoError(t, hc.UpdateGQLSchema(indexedSchema))

	// vectors collects the embeddings of both batches; the final assertion
	// checks query results against this combined set.
	var vectors [][]float32
	projects := generateProjects(numProjects)
	for _, p := range projects {
		vectors = append(vectors, p.TitleV)
		addProject(t, hc, p)
	}

	schemaWithoutIndex := `type Project {
id: ID!
title: String! @search(by: [exact])
title_v: [Float!] @embedding
} `

	require.NoError(t, hc.UpdateGQLSchema(schemaWithoutIndex))

	dropTitleV := &api.Operation{DropAttr: "title_v"}
	require.NoError(t, gc.Alter(ctx, dropTitleV))

	// Second batch, written while the schema has no search directive on title_v.
	projects = generateProjects(numProjects)
	for _, p := range projects {
		vectors = append(vectors, p.TitleV)
		addProject(t, hc, p)
	}

	require.NoError(t, hc.UpdateGQLSchema(indexedSchema))

	// Run a similarity query for each project of the second batch.
	for _, p := range projects {
		matches := queryProjectsSimilarByEmbedding(t, hc, p.TitleV, numProjects)
		for _, m := range matches {
			require.Contains(t, vectors, m.TitleV)
		}
	}
}
120 changes: 120 additions & 0 deletions systest/vector/vector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -315,3 +315,123 @@ func TestVectorIndexOnVectorPredWithoutData(t *testing.T) {
_, err = gc.QueryMultipleVectorsUsingSimilarTo(vector, pred, 10)
require.NoError(t, err)
}

// TestVectorIndexDropPredicate checks that the vector index can be added back
// to a predicate after the predicate was dropped and repopulated while the
// schema without the index was in effect.
//
// Steps: start a local cluster, apply testSchema, insert vectors and verify
// similarity search and the predicate count, apply testSchemaWithoutIndex,
// drop the predicate, insert a fresh set of vectors, re-apply testSchema, and
// verify the count and similarity search again.
func TestVectorIndexDropPredicate(t *testing.T) {
	conf := dgraphtest.NewClusterConfig().WithNumAlphas(1).WithNumZeros(1).WithReplicas(1).WithACL(time.Hour)
	c, err := dgraphtest.NewLocalCluster(conf)
	require.NoError(t, err)
	defer func() { c.Cleanup(t.Failed()) }()
	require.NoError(t, c.Start())

	gc, cleanup, err := c.Client()
	// Check the error before registering cleanup. If Client fails, cleanup may
	// be nil, and a deferred nil function panics when it runs during the
	// unwinding triggered by require, hiding the real failure.
	require.NoError(t, err)
	defer cleanup()

	require.NoError(t, gc.LoginIntoNamespace(context.Background(),
		dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace))

	hc, err := c.HTTPClient()
	require.NoError(t, err)
	require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser,
		dgraphapi.DefaultPassword, x.GalaxyNamespace))

	require.NoError(t, gc.SetupSchema(testSchema))
	pred := "project_discription_v"
	numVectors := 10

	// add vectors
	rdfs, vectors := dgraphapi.GenerateRandomVectors(0, numVectors, 100, pred)
	mu := &api.Mutation{SetNquads: []byte(rdfs), CommitNow: true}
	_, err = gc.Mutate(mu)
	require.NoError(t, err)

	require.NoError(t, gc.SetupSchema(testSchema))

	// Every inserted vector must yield at least one similarity match.
	for _, vect := range vectors {
		similarVects, err := gc.QueryMultipleVectorsUsingSimilarTo(vect, pred, 2)
		require.NoError(t, err)
		require.LessOrEqual(t, 1, len(similarVects))
	}

	query := `{
vector(func: has(project_discription_v)) {
count(uid)
}
}`

	result, err := gc.Query(query)
	require.NoError(t, err)
	require.JSONEq(t, fmt.Sprintf(`{"vector":[{"count":%v}]}`, numVectors), string(result.GetJson()))

	// remove index from vector predicate
	require.NoError(t, gc.SetupSchema(testSchemaWithoutIndex))

	// drop predicate
	op := &api.Operation{
		DropAttr: pred,
	}
	require.NoError(t, gc.Alter(context.Background(), op))

	// generate random vectors
	rdfs, vectors = dgraphapi.GenerateRandomVectors(0, numVectors, 100, pred)
	mu = &api.Mutation{SetNquads: []byte(rdfs), CommitNow: true}
	_, err = gc.Mutate(mu)
	require.NoError(t, err)

	// add index back
	require.NoError(t, gc.SetupSchema(testSchema))

	// The count must equal numVectors again: only the vectors inserted after
	// the drop should be present.
	result, err = gc.Query(query)
	require.NoError(t, err)
	require.JSONEq(t, fmt.Sprintf(`{"vector":[{"count":%v}]}`, numVectors), string(result.GetJson()))

	// With the index back in place, each vector must return exactly the two
	// requested matches.
	for _, vect := range vectors {
		similarVects, err := gc.QueryMultipleVectorsUsingSimilarTo(vect, pred, 2)
		require.NoError(t, err)
		require.Equal(t, 2, len(similarVects))
	}
}

// TestVectorIndexWithoutSchema checks that a vector index can be built over
// data that was written before any schema was applied to the predicate.
//
// Steps: start a local cluster, insert vectors without first setting a schema,
// apply testSchema, then verify similarity search and the predicate count.
func TestVectorIndexWithoutSchema(t *testing.T) {
	conf := dgraphtest.NewClusterConfig().WithNumAlphas(1).WithNumZeros(1).WithReplicas(1).WithACL(time.Hour)
	c, err := dgraphtest.NewLocalCluster(conf)
	require.NoError(t, err)
	defer func() { c.Cleanup(t.Failed()) }()
	require.NoError(t, c.Start())

	gc, cleanup, err := c.Client()
	// Check the error before registering cleanup. If Client fails, cleanup may
	// be nil, and a deferred nil function panics when it runs during the
	// unwinding triggered by require, hiding the real failure.
	require.NoError(t, err)
	defer cleanup()

	require.NoError(t, gc.LoginIntoNamespace(context.Background(),
		dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace))

	pred := "project_discription_v"
	numVectors := 10

	// add vectors
	rdfs, vectors := dgraphapi.GenerateRandomVectors(0, numVectors, 100, pred)
	mu := &api.Mutation{SetNquads: []byte(rdfs), CommitNow: true}
	_, err = gc.Mutate(mu)
	require.NoError(t, err)

	// Apply the schema only after the data has been written.
	require.NoError(t, gc.SetupSchema(testSchema))

	// Each vector must return exactly the two requested matches.
	for _, vect := range vectors {
		similarVects, err := gc.QueryMultipleVectorsUsingSimilarTo(vect, pred, 2)
		require.NoError(t, err)
		require.Equal(t, 2, len(similarVects))
	}

	query := `{
vector(func: has(project_discription_v)) {
count(uid)
}
}`

	result, err := gc.Query(query)
	require.NoError(t, err)
	require.JSONEq(t, fmt.Sprintf(`{"vector":[{"count":%v}]}`, numVectors), string(result.GetJson()))
}

0 comments on commit 0cdd0d7

Please sign in to comment.