diff --git a/go.mod b/go.mod index 4bf6354a0fb..dc5f8189219 100644 --- a/go.mod +++ b/go.mod @@ -10,8 +10,8 @@ require ( github.com/IBM/sarama v1.43.3 github.com/Masterminds/semver/v3 v3.3.1 github.com/blevesearch/bleve/v2 v2.4.3 - github.com/dgraph-io/badger/v4 v4.4.1-0.20241128130124-d13fdcc74a8d - github.com/dgraph-io/dgo/v240 v240.0.2-0.20241128130808-b8caf57545c5 + github.com/dgraph-io/badger/v4 v4.5.0 + github.com/dgraph-io/dgo/v240 v240.1.0 github.com/dgraph-io/gqlgen v0.13.2 github.com/dgraph-io/gqlparser/v2 v2.2.2 github.com/dgraph-io/graphql-transport-ws v0.0.0-20210511143556-2cef522f1f15 diff --git a/go.sum b/go.sum index 15b0b80a0f1..570a13990f4 100644 --- a/go.sum +++ b/go.sum @@ -134,10 +134,10 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dgraph-io/badger/v4 v4.4.1-0.20241128130124-d13fdcc74a8d h1:LFusMx3rDeqG7/gxafAdYnmVk4U8aAiJWYpd8zp1scU= -github.com/dgraph-io/badger/v4 v4.4.1-0.20241128130124-d13fdcc74a8d/go.mod h1:ysgYmIeG8dS/E8kwxT7xHyc7MkmwNYLRoYnFbr7387A= -github.com/dgraph-io/dgo/v240 v240.0.2-0.20241128130808-b8caf57545c5 h1:/DTuYmMcQCq9jM5S6gQteKRZue0vI9W4P3Fs9SFeSGo= -github.com/dgraph-io/dgo/v240 v240.0.2-0.20241128130808-b8caf57545c5/go.mod h1:r8WASETKfodzKqThSAhhTNIzcEMychArKKlZXQufWuA= +github.com/dgraph-io/badger/v4 v4.5.0 h1:TeJE3I1pIWLBjYhIYCA1+uxrjWEoJXImFBMEBVSm16g= +github.com/dgraph-io/badger/v4 v4.5.0/go.mod h1:ysgYmIeG8dS/E8kwxT7xHyc7MkmwNYLRoYnFbr7387A= +github.com/dgraph-io/dgo/v240 v240.1.0 h1:xd8z9kEXDWOAblaLJ2HLg2tXD6ngMQwq3ehLUS7GKNg= +github.com/dgraph-io/dgo/v240 v240.1.0/go.mod h1:r8WASETKfodzKqThSAhhTNIzcEMychArKKlZXQufWuA= github.com/dgraph-io/gqlgen 
v0.13.2 h1:TNhndk+eHKj5qE7BenKKSYdSIdOGhLqxR1rCiMso9KM= github.com/dgraph-io/gqlgen v0.13.2/go.mod h1:iCOrOv9lngN7KAo+jMgvUPVDlYHdf7qDwsTkQby2Sis= github.com/dgraph-io/gqlparser/v2 v2.1.1/go.mod h1:MYS4jppjyx8b9tuUtjV7jU1UFZK6P9fvO8TsIsQtRKU= diff --git a/posting/index.go b/posting/index.go index 8e7ffa5472a..e0ea946a5d0 100644 --- a/posting/index.go +++ b/posting/index.go @@ -168,10 +168,18 @@ func (txn *Txn) addIndexMutations(ctx context.Context, info *indexMutationInfo) // Similarly, the current assumption is that we have at most one // Vector Index, but this assumption may break later. if info.op != pb.DirectedEdge_DEL && - len(data) > 0 && data[0].Tid == types.VFloatID && - len(info.factorySpecs) > 0 { + len(data) > 0 && len(info.factorySpecs) > 0 { // retrieve vector from inUuid save as inVec - inVec := types.BytesAsFloatArray(data[0].Value.([]byte)) + + var inVec []float32 + // Check if it is an add index mutation on the vector predicate. + // If yes, we will extract inVec from info.val.Value, which we + // converted earlier. Otherwise, we will extract it from data. + if data[0].Tid.Enum().Type() == types.DefaultID.Enum().Type() { + inVec = types.BytesAsFloatArray(info.val.Value.([]byte)) + } else { + inVec = types.BytesAsFloatArray(data[0].Value.([]byte)) + } tc := hnsw.NewTxnCache(NewViTxn(txn), txn.StartTs) indexer, err := info.factorySpecs[0].CreateIndex(attr) if err != nil { @@ -1392,6 +1400,19 @@ func rebuildTokIndex(ctx context.Context, rb *IndexRebuild) error { val := types.Val{ Value: p.Value, Tid: types.TypeID(p.ValType), + } // If this is an addIndex mutation for the vector predicate, + // convert the default value type to the VFloat type. 
+ if runForVectors && types.TypeID(p.ValType) != types.VFloatID { + sv, err := types.Convert(val, types.VFloatID) + if err != nil { + return err + } + b := types.ValueForType(types.BinaryID) + if err = types.Marshal(sv, &b); err != nil { + return err + } + val.Value = b.Value + val.Tid = types.VFloatID } edge.Lang = string(p.LangTag) @@ -1419,7 +1440,7 @@ func rebuildTokIndex(ctx context.Context, rb *IndexRebuild) error { } return edges, err } - if len(factorySpecs) != 0 { + if runForVectors { return builder.RunWithoutTemp(ctx) } return builder.Run(ctx) diff --git a/query/vector/vector_graphql_test.go b/query/vector/vector_graphql_test.go index 4f496ea1db1..73241ff0b70 100644 --- a/query/vector/vector_graphql_test.go +++ b/query/vector/vector_graphql_test.go @@ -19,12 +19,15 @@ package query import ( + "context" "encoding/json" "fmt" "math/rand" "testing" + "github.com/dgraph-io/dgo/v240/protos/api" "github.com/dgraph-io/dgraph/v24/dgraphapi" + "github.com/dgraph-io/dgraph/v24/x" "github.com/stretchr/testify/require" ) @@ -256,3 +259,56 @@ func testVectorGraphQlMutationAndQuery(t *testing.T, hc *dgraphapi.HTTPClient) { } } } + +func TestVectorIndexDropPredicate(t *testing.T) { + gc, cleanup, err := dc.Client() + require.NoError(t, err) + defer cleanup() + require.NoError(t, gc.LoginIntoNamespace(context.Background(), + dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) + + hc, err := dc.HTTPClient() + require.NoError(t, err) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) + + require.NoError(t, hc.UpdateGQLSchema(fmt.Sprintf(graphQLVectorSchema, "euclidean"))) + + var vectors [][]float32 + numProjects := 100 + projects := generateProjects(numProjects) + for _, project := range projects { + vectors = append(vectors, project.TitleV) + addProject(t, hc, project) + } + + schemaWithoutIndex := `type Project { + id: ID! + title: String! @search(by: [exact]) + title_v: [Float!] 
@embedding + } ` + + require.NoError(t, hc.UpdateGQLSchema(schemaWithoutIndex)) + + op := &api.Operation{ + DropAttr: "title_v", + } + require.NoError(t, gc.Alter(context.Background(), op)) + + numProjects = 100 + projects = generateProjects(numProjects) + for _, project := range projects { + vectors = append(vectors, project.TitleV) + addProject(t, hc, project) + } + + require.NoError(t, hc.UpdateGQLSchema(fmt.Sprintf(graphQLVectorSchema, "euclidean"))) + + // similar to query + for _, project := range projects { + similarProjects := queryProjectsSimilarByEmbedding(t, hc, project.TitleV, numProjects) + for _, similarVec := range similarProjects { + require.Contains(t, vectors, similarVec.TitleV) + } + } +} diff --git a/systest/vector/vector_test.go b/systest/vector/vector_test.go index 87e7b661a87..edd5be54889 100644 --- a/systest/vector/vector_test.go +++ b/systest/vector/vector_test.go @@ -315,3 +315,123 @@ func TestVectorIndexOnVectorPredWithoutData(t *testing.T) { _, err = gc.QueryMultipleVectorsUsingSimilarTo(vector, pred, 10) require.NoError(t, err) } + +func TestVectorIndexDropPredicate(t *testing.T) { + conf := dgraphtest.NewClusterConfig().WithNumAlphas(1).WithNumZeros(1).WithReplicas(1).WithACL(time.Hour) + c, err := dgraphtest.NewLocalCluster(conf) + + require.NoError(t, err) + defer func() { c.Cleanup(t.Failed()) }() + require.NoError(t, c.Start()) + + gc, cleanup, err := c.Client() + defer cleanup() + require.NoError(t, err) + + require.NoError(t, gc.LoginIntoNamespace(context.Background(), + dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) + + hc, err := c.HTTPClient() + require.NoError(t, err) + require.NoError(t, hc.LoginIntoNamespace(dgraphapi.DefaultUser, + dgraphapi.DefaultPassword, x.GalaxyNamespace)) + + require.NoError(t, gc.SetupSchema(testSchema)) + pred := "project_discription_v" + numVectors := 10 + + // add vectors + rdfs, vectors := dgraphapi.GenerateRandomVectors(0, numVectors, 100, pred) + mu := 
&api.Mutation{SetNquads: []byte(rdfs), CommitNow: true} + _, err = gc.Mutate(mu) + require.NoError(t, err) + + require.NoError(t, gc.SetupSchema(testSchema)) + + for _, vect := range vectors { + similarVects, err := gc.QueryMultipleVectorsUsingSimilarTo(vect, pred, 2) + require.NoError(t, err) + require.LessOrEqual(t, 1, len(similarVects)) + } + + query := `{ + vector(func: has(project_discription_v)) { + count(uid) + } + }` + + result, err := gc.Query(query) + require.NoError(t, err) + require.JSONEq(t, fmt.Sprintf(`{"vector":[{"count":%v}]}`, numVectors), string(result.GetJson())) + + // remove index from vector predicate + require.NoError(t, gc.SetupSchema(testSchemaWithoutIndex)) + + // drop predicate + op := &api.Operation{ + DropAttr: pred, + } + require.NoError(t, gc.Alter(context.Background(), op)) + + // generate random vectors + rdfs, vectors = dgraphapi.GenerateRandomVectors(0, numVectors, 100, pred) + mu = &api.Mutation{SetNquads: []byte(rdfs), CommitNow: true} + _, err = gc.Mutate(mu) + require.NoError(t, err) + + // add index back + require.NoError(t, gc.SetupSchema(testSchema)) + + result, err = gc.Query(query) + require.NoError(t, err) + require.JSONEq(t, fmt.Sprintf(`{"vector":[{"count":%v}]}`, numVectors), string(result.GetJson())) + + for _, vect := range vectors { + similarVects, err := gc.QueryMultipleVectorsUsingSimilarTo(vect, pred, 2) + require.NoError(t, err) + require.Equal(t, 2, len(similarVects)) + } +} + +func TestVectorIndexWithoutSchema(t *testing.T) { + conf := dgraphtest.NewClusterConfig().WithNumAlphas(1).WithNumZeros(1).WithReplicas(1).WithACL(time.Hour) + c, err := dgraphtest.NewLocalCluster(conf) + + require.NoError(t, err) + defer func() { c.Cleanup(t.Failed()) }() + require.NoError(t, c.Start()) + + gc, cleanup, err := c.Client() + defer cleanup() + require.NoError(t, err) + + require.NoError(t, gc.LoginIntoNamespace(context.Background(), + dgraphapi.DefaultUser, dgraphapi.DefaultPassword, x.GalaxyNamespace)) + + pred := 
"project_discription_v" + numVectors := 10 + + // add vectors + rdfs, vectors := dgraphapi.GenerateRandomVectors(0, numVectors, 100, pred) + mu := &api.Mutation{SetNquads: []byte(rdfs), CommitNow: true} + _, err = gc.Mutate(mu) + require.NoError(t, err) + + require.NoError(t, gc.SetupSchema(testSchema)) + + for _, vect := range vectors { + similarVects, err := gc.QueryMultipleVectorsUsingSimilarTo(vect, pred, 2) + require.NoError(t, err) + require.Equal(t, 2, len(similarVects)) + } + + query := `{ + vector(func: has(project_discription_v)) { + count(uid) + } + }` + + result, err := gc.Query(query) + require.NoError(t, err) + require.JSONEq(t, fmt.Sprintf(`{"vector":[{"count":%v}]}`, numVectors), string(result.GetJson())) +}