Skip to content

Commit

Permalink
add index after adding data without vector schema for vector predicate
Browse files Browse the repository at this point in the history
  • Loading branch information
shivaji-kharse committed Dec 4, 2024
1 parent a184d8d commit dacfc78
Show file tree
Hide file tree
Showing 5 changed files with 213 additions and 10 deletions.
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ require (
github.com/IBM/sarama v1.43.3
github.com/Masterminds/semver/v3 v3.3.1
github.com/blevesearch/bleve/v2 v2.4.3
github.com/dgraph-io/badger/v4 v4.4.1-0.20241128130124-d13fdcc74a8d
github.com/dgraph-io/dgo/v240 v240.0.2-0.20241128130808-b8caf57545c5
github.com/dgraph-io/badger/v4 v4.5.0
github.com/dgraph-io/dgo/v240 v240.1.0
github.com/dgraph-io/gqlgen v0.13.2
github.com/dgraph-io/gqlparser/v2 v2.2.2
github.com/dgraph-io/graphql-transport-ws v0.0.0-20210511143556-2cef522f1f15
Expand Down
8 changes: 4 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,10 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgraph-io/badger/v4 v4.4.1-0.20241128130124-d13fdcc74a8d h1:LFusMx3rDeqG7/gxafAdYnmVk4U8aAiJWYpd8zp1scU=
github.com/dgraph-io/badger/v4 v4.4.1-0.20241128130124-d13fdcc74a8d/go.mod h1:ysgYmIeG8dS/E8kwxT7xHyc7MkmwNYLRoYnFbr7387A=
github.com/dgraph-io/dgo/v240 v240.0.2-0.20241128130808-b8caf57545c5 h1:/DTuYmMcQCq9jM5S6gQteKRZue0vI9W4P3Fs9SFeSGo=
github.com/dgraph-io/dgo/v240 v240.0.2-0.20241128130808-b8caf57545c5/go.mod h1:r8WASETKfodzKqThSAhhTNIzcEMychArKKlZXQufWuA=
github.com/dgraph-io/badger/v4 v4.5.0 h1:TeJE3I1pIWLBjYhIYCA1+uxrjWEoJXImFBMEBVSm16g=
github.com/dgraph-io/badger/v4 v4.5.0/go.mod h1:ysgYmIeG8dS/E8kwxT7xHyc7MkmwNYLRoYnFbr7387A=
github.com/dgraph-io/dgo/v240 v240.1.0 h1:xd8z9kEXDWOAblaLJ2HLg2tXD6ngMQwq3ehLUS7GKNg=
github.com/dgraph-io/dgo/v240 v240.1.0/go.mod h1:r8WASETKfodzKqThSAhhTNIzcEMychArKKlZXQufWuA=
github.com/dgraph-io/gqlgen v0.13.2 h1:TNhndk+eHKj5qE7BenKKSYdSIdOGhLqxR1rCiMso9KM=
github.com/dgraph-io/gqlgen v0.13.2/go.mod h1:iCOrOv9lngN7KAo+jMgvUPVDlYHdf7qDwsTkQby2Sis=
github.com/dgraph-io/gqlparser/v2 v2.1.1/go.mod h1:MYS4jppjyx8b9tuUtjV7jU1UFZK6P9fvO8TsIsQtRKU=
Expand Down
35 changes: 31 additions & 4 deletions posting/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,23 @@ func (txn *Txn) addIndexMutations(ctx context.Context, info *indexMutationInfo)
// Similarly, the current assumption is that we have at most one
// Vector Index, but this assumption may break later.
if info.op != pb.DirectedEdge_DEL &&
len(data) > 0 && data[0].Tid == types.VFloatID &&
len(info.factorySpecs) > 0 {
len(data) > 0 && len(info.factorySpecs) > 0 {
// retrieve vector from inUuid save as inVec
inVec := types.BytesAsFloatArray(data[0].Value.([]byte))

var inVec []float32
// Check if it is an add index mutation on the vector predicate.
// If yes, we will extract inVec from info.edge.Value, which we
// converted earlier. Otherwise, we will extract it from data.
if data[0].Tid.Enum().Type() == types.DefaultID.Enum().Type() {
info.edge.Value = info.val.Value.([]byte)
info.edge.ValueType = types.VFloatID.Enum()
if err := pl.addMutation(ctx, txn, info.edge); err != nil {
return []*pb.DirectedEdge{}, err
}
inVec = types.BytesAsFloatArray(info.val.Value.([]byte))
} else {
inVec = types.BytesAsFloatArray(data[0].Value.([]byte))
}
tc := hnsw.NewTxnCache(NewViTxn(txn), txn.StartTs)
indexer, err := info.factorySpecs[0].CreateIndex(attr)
if err != nil {
Expand Down Expand Up @@ -1393,6 +1406,20 @@ func rebuildTokIndex(ctx context.Context, rb *IndexRebuild) error {
Value: p.Value,
Tid: types.TypeID(p.ValType),
}
// If this is an addIndex mutation for the vector predicate,
// convert the default value type to the VFloat type.
if runForVectors && types.TypeID(p.ValType) != types.VFloatID {
sv, err := types.Convert(val, types.VFloatID)
if err != nil {
return err
}
b := types.ValueForType(types.BinaryID)
if err = types.Marshal(sv, &b); err != nil {
return err
}
val.Value = b.Value
val.Tid = types.VFloatID
}
edge.Lang = string(p.LangTag)

for {
Expand All @@ -1419,7 +1446,7 @@ func rebuildTokIndex(ctx context.Context, rb *IndexRebuild) error {
}
return edges, err
}
if len(factorySpecs) != 0 {
if runForVectors {
return builder.RunWithoutTemp(ctx)
}
return builder.Run(ctx)
Expand Down
56 changes: 56 additions & 0 deletions query/vector/vector_graphql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,15 @@
package query

import (
"context"
"encoding/json"
"fmt"
"math/rand"
"testing"

"github.com/dgraph-io/dgo/v240/protos/api"
"github.com/dgraph-io/dgraph/v24/dgraphapi"
"github.com/dgraph-io/dgraph/v24/x"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -256,3 +259,56 @@ func testVectorGraphQlMutationAndQuery(t *testing.T, hc *dgraphapi.HTTPClient) {
}
}
}

func TestVectorIndexDropPredicate(t *testing.T) {
gc, cleanup, err := dc.Client()
require.NoError(t, err)
defer cleanup()
require.NoError(t, gc.LoginIntoNamespace(context.Background(),
dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace))

hc, err := dc.HTTPClient()
require.NoError(t, err)
require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser,
dgraphapi.DefaultPassword, x.GalaxyNamespace))

require.NoError(t, hc.UpdateGQLSchema(fmt.Sprintf(graphQLVectorSchema, "euclidean")))

var vectors [][]float32
numProjects := 100
projects := generateProjects(numProjects)
for _, project := range projects {
vectors = append(vectors, project.TitleV)
addProject(t, hc, project)
}

schemaWithoutIndex := `type Project {
id: ID!
title: String! @search(by: [exact])
title_v: [Float!] @embedding
} `

require.NoError(t, hc.UpdateGQLSchema(schemaWithoutIndex))

op := &api.Operation{
DropAttr: "title_v",
}
require.NoError(t, gc.Alter(context.Background(), op))

numProjects = 100
projects = generateProjects(numProjects)
for _, project := range projects {
vectors = append(vectors, project.TitleV)
addProject(t, hc, project)
}

require.NoError(t, hc.UpdateGQLSchema(fmt.Sprintf(graphQLVectorSchema, "euclidean")))

// similar to query
for _, project := range projects {
similarProjects := queryProjectsSimilarByEmbedding(t, hc, project.TitleV, numProjects)
for _, similarVec := range similarProjects {
require.Contains(t, vectors, similarVec.TitleV)
}
}
}
120 changes: 120 additions & 0 deletions systest/vector/vector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -315,3 +315,123 @@ func TestVectorIndexOnVectorPredWithoutData(t *testing.T) {
_, err = gc.QueryMultipleVectorsUsingSimilarTo(vector, pred, 10)
require.NoError(t, err)
}

func TestVectorIndexDropPredicate(t *testing.T) {
conf := dgraphtest.NewClusterConfig().WithNumAlphas(1).WithNumZeros(1).WithReplicas(1).WithACL(time.Hour)
c, err := dgraphtest.NewLocalCluster(conf)

require.NoError(t, err)
defer func() { c.Cleanup(t.Failed()) }()
require.NoError(t, c.Start())

gc, cleanup, err := c.Client()
defer cleanup()
require.NoError(t, err)

require.NoError(t, gc.LoginIntoNamespace(context.Background(),
dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace))

hc, err := c.HTTPClient()
require.NoError(t, err)
require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser,
dgraphapi.DefaultPassword, x.GalaxyNamespace))

require.NoError(t, gc.SetupSchema(testSchema))
pred := "project_discription_v"
numVectors := 10

// add vectors
rdfs, vectors := dgraphapi.GenerateRandomVectors(0, numVectors, 100, pred)
mu := &api.Mutation{SetNquads: []byte(rdfs), CommitNow: true}
_, err = gc.Mutate(mu)
require.NoError(t, err)

require.NoError(t, gc.SetupSchema(testSchema))

for _, vect := range vectors {
similarVects, err := gc.QueryMultipleVectorsUsingSimilarTo(vect, pred, 2)
require.NoError(t, err)
require.LessOrEqual(t, 1, len(similarVects))
}

query := `{
vector(func: has(project_discription_v)) {
count(uid)
}
}`

result, err := gc.Query(query)
require.NoError(t, err)
require.JSONEq(t, fmt.Sprintf(`{"vector":[{"count":%v}]}`, numVectors), string(result.GetJson()))

// remove index from vector predicate
require.NoError(t, gc.SetupSchema(testSchemaWithoutIndex))

// drop predicate
op := &api.Operation{
DropAttr: pred,
}
require.NoError(t, gc.Alter(context.Background(), op))

// generate random vectors
rdfs, vectors = dgraphapi.GenerateRandomVectors(0, numVectors, 100, pred)
mu = &api.Mutation{SetNquads: []byte(rdfs), CommitNow: true}
_, err = gc.Mutate(mu)
require.NoError(t, err)

// add index back
require.NoError(t, gc.SetupSchema(testSchema))

result, err = gc.Query(query)
require.NoError(t, err)
require.JSONEq(t, fmt.Sprintf(`{"vector":[{"count":%v}]}`, numVectors), string(result.GetJson()))

for _, vect := range vectors {
similarVects, err := gc.QueryMultipleVectorsUsingSimilarTo(vect, pred, 2)
require.NoError(t, err)
require.Equal(t, 2, len(similarVects))
}
}

func TestVectorIndexWithoutSchema(t *testing.T) {
conf := dgraphtest.NewClusterConfig().WithNumAlphas(1).WithNumZeros(1).WithReplicas(1).WithACL(time.Hour)
c, err := dgraphtest.NewLocalCluster(conf)

require.NoError(t, err)
defer func() { c.Cleanup(t.Failed()) }()
require.NoError(t, c.Start())

gc, cleanup, err := c.Client()
defer cleanup()
require.NoError(t, err)

require.NoError(t, gc.LoginIntoNamespace(context.Background(),
dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace))

pred := "project_discription_v"
numVectors := 10

// add vectors
rdfs, vectors := dgraphapi.GenerateRandomVectors(0, numVectors, 100, pred)
mu := &api.Mutation{SetNquads: []byte(rdfs), CommitNow: true}
_, err = gc.Mutate(mu)
require.NoError(t, err)

require.NoError(t, gc.SetupSchema(testSchema))

for _, vect := range vectors {
similarVects, err := gc.QueryMultipleVectorsUsingSimilarTo(vect, pred, 2)
require.NoError(t, err)
require.Equal(t, 2, len(similarVects))
}

query := `{
vector(func: has(project_discription_v)) {
count(uid)
}
}`

result, err := gc.Query(query)
require.NoError(t, err)
require.JSONEq(t, fmt.Sprintf(`{"vector":[{"count":%v}]}`, numVectors), string(result.GetJson()))
}

0 comments on commit dacfc78

Please sign in to comment.