From 4d90e6ece9eda52225220d92a507836898a10ed2 Mon Sep 17 00:00:00 2001 From: Fred Carle Date: Wed, 19 Feb 2025 17:13:36 -0500 Subject: [PATCH] feat: Add cosine similarity query (#3464) ## Relevant issue(s) Resolves #3349 ## Description This PR adds the possibility to calculate the cosine similarity between a vector field and a given vector. To achieve this we added the `_similarity` system field which takes a target field (part of the parent object) and a vector as parameters. ```gql query { User{ _similarity(pointsList: {vector: [1, 2, 0]}) } } ``` Note that the added code to mapper and planner is more of a "bolt on" addition given the current state of that part of the code base. A refactor is expected in the future. Future work will allow giving a `content` parameter instead of the `vector` if the target field has embedding generation configured. This will enable out-of-the-box RAG queries. --- client/request/consts.go | 42 +- client/request/similarity.go | 26 + internal/planner/errors.go | 9 + internal/planner/mapper/mapper.go | 20 + internal/planner/mapper/similarity.go | 26 + internal/planner/operations.go | 1 + internal/planner/planner.go | 9 + internal/planner/scan.go | 2 + internal/planner/select.go | 39 +- internal/planner/similarity.go | 204 ++++++++ internal/request/graphql/parser/query.go | 26 + internal/request/graphql/parser/request.go | 6 + internal/request/graphql/schema/generate.go | 59 +++ .../request/graphql/schema/types/types.go | 2 + .../explain/execute/with_similarity_test.go | 74 +++ .../query/simple/with_similarity_test.go | 477 ++++++++++++++++++ tests/integration/schema/default_fields.go | 10 + tests/integration/schema/similarity_test.go | 322 ++++++++++++ 18 files changed, 1320 insertions(+), 34 deletions(-) create mode 100644 client/request/similarity.go create mode 100644 internal/planner/mapper/similarity.go create mode 100644 internal/planner/similarity.go create mode 100644 tests/integration/explain/execute/with_similarity_test.go create 
mode 100644 tests/integration/query/simple/with_similarity_test.go create mode 100644 tests/integration/schema/similarity_test.go diff --git a/client/request/consts.go b/client/request/consts.go index ce31c95133..3b6edeae61 100644 --- a/client/request/consts.go +++ b/client/request/consts.go @@ -39,16 +39,17 @@ const ( DocIDArgName = "docID" - AverageFieldName = "_avg" - CountFieldName = "_count" - DocIDFieldName = "_docID" - GroupFieldName = "_group" - DeletedFieldName = "_deleted" - SumFieldName = "_sum" - VersionFieldName = "_version" - MaxFieldName = "_max" - MinFieldName = "_min" - AliasFieldName = "_alias" + AverageFieldName = "_avg" + CountFieldName = "_count" + DocIDFieldName = "_docID" + GroupFieldName = "_group" + DeletedFieldName = "_deleted" + SumFieldName = "_sum" + VersionFieldName = "_version" + MaxFieldName = "_max" + MinFieldName = "_min" + AliasFieldName = "_alias" + SimilarityFieldName = "_similarity" // New generated document id from a backed up document, // which might have a different _docID originally. @@ -104,16 +105,17 @@ var ( } ReservedFields = map[string]struct{}{ - TypeNameFieldName: {}, - VersionFieldName: {}, - GroupFieldName: {}, - CountFieldName: {}, - SumFieldName: {}, - AverageFieldName: {}, - DocIDFieldName: {}, - DeletedFieldName: {}, - MaxFieldName: {}, - MinFieldName: {}, + TypeNameFieldName: {}, + VersionFieldName: {}, + GroupFieldName: {}, + CountFieldName: {}, + SumFieldName: {}, + AverageFieldName: {}, + DocIDFieldName: {}, + DeletedFieldName: {}, + MaxFieldName: {}, + MinFieldName: {}, + SimilarityFieldName: {}, } Aggregates = map[string]struct{}{ diff --git a/client/request/similarity.go b/client/request/similarity.go new file mode 100644 index 0000000000..04d2698dd4 --- /dev/null +++ b/client/request/similarity.go @@ -0,0 +1,26 @@ +// Copyright 2025 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. 
+// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package request + +// Similarity is a functional field that defines the +// parameters to calculate the cosine similarity between two vectors. +type Similarity struct { + Field + // Vector contains the vector to compare the target field to. + // + // It will be of type Int, Float32 or Float64. It must be the same type and length as Target. + Vector any + + // Target is the field in the host object that we will compare the vector to. + // + // It must be a field of type Int, Float32 or Float64. It must be the same type and length as Vector. + Target string +} diff --git a/internal/planner/errors.go b/internal/planner/errors.go index 295196d4d5..671e1db524 100644 --- a/internal/planner/errors.go +++ b/internal/planner/errors.go @@ -35,6 +35,7 @@ var ( ErrUnknownRelationType = errors.New("failed sub selection, unknown relation type") ErrUnknownExplainRequestType = errors.New("can not explain request of unknown type") ErrUpsertMultipleDocuments = errors.New("cannot upsert multiple matching documents") + ErrMismatchLengthOnSimilarity = errors.New("source and vector must be of the same length") ) func NewErrUnknownDependency(name string) error { @@ -52,3 +53,11 @@ func NewErrFailedToCollectExecExplainInfo(inner error) error { func NewErrSubTypeInit(inner error) error { return errors.Wrap(errSubTypeInit, inner) } + +func NewErrMismatchLengthOnSimilarity(source, vector int) error { + return errors.WithStack( + ErrMismatchLengthOnSimilarity, + errors.NewKV("Source", source), + errors.NewKV("Vector", vector), + ) +} diff --git a/internal/planner/mapper/mapper.go b/internal/planner/mapper/mapper.go index 826c29ffbc..fc80e6b245 100644 --- a/internal/planner/mapper/mapper.go +++ b/internal/planner/mapper/mapper.go @@ -803,6 +803,26 @@ func 
getRequestables( Key: getRenderKey(&f.Field), }) + mapping.Add(index, f.Name) + case *request.Similarity: + index := mapping.GetNextIndex() + fields = append(fields, &Similarity{ + Field: Field{ + Index: index, + Name: f.Name, + }, + Vector: f.Vector, + SimilarityTarget: Targetable{ + Field: Field{ + Index: mapping.FirstIndexOfName(f.Target), + Name: f.Target, + }, + }, + }) + mapping.RenderKeys = append(mapping.RenderKeys, core.RenderKey{ + Index: index, + Key: getRenderKey(&f.Field), + }) mapping.Add(index, f.Name) default: return nil, nil, client.NewErrUnhandledType("field", field) diff --git a/internal/planner/mapper/similarity.go b/internal/planner/mapper/similarity.go new file mode 100644 index 0000000000..21fd6758ff --- /dev/null +++ b/internal/planner/mapper/similarity.go @@ -0,0 +1,26 @@ +// Copyright 2025 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package mapper + +import "github.com/sourcenetwork/defradb/internal/core" + +// Similarity represents a cosine similarity operation definition. +type Similarity struct { + Field + // The mapping of this similarity field's parent/host. + *core.DocumentMapping + + // The targeted field for the cosine similarity + SimilarityTarget Targetable + + // The vector to compare the target field to. 
+ Vector any +} diff --git a/internal/planner/operations.go b/internal/planner/operations.go index 73fe1450bb..4bbca72844 100644 --- a/internal/planner/operations.go +++ b/internal/planner/operations.go @@ -37,6 +37,7 @@ var ( _ planNode = (*valuesNode)(nil) _ planNode = (*viewNode)(nil) _ planNode = (*lensNode)(nil) + _ planNode = (*similarityNode)(nil) _ MultiNode = (*parallelNode)(nil) _ MultiNode = (*topLevelNode)(nil) diff --git a/internal/planner/planner.go b/internal/planner/planner.go index 8d5c7cf052..f9e0a6c765 100644 --- a/internal/planner/planner.go +++ b/internal/planner/planner.go @@ -231,6 +231,8 @@ func (p *Planner) expandSelectTopNodePlan(plan *selectTopNode, parentPlan *selec p.expandLimitPlan(plan, parentPlan) } + p.expandSimilarityPlans(plan) + return nil } @@ -249,6 +251,13 @@ func (p *Planner) expandAggregatePlans(plan *selectTopNode) { } } +func (p *Planner) expandSimilarityPlans(plan *selectTopNode) { + for _, sim := range plan.similarity { + sim.SetPlan(plan.planNode) + plan.planNode = sim + } +} + func (p *Planner) expandMultiNode(multiNode MultiNode, parentPlan *selectTopNode) error { for _, child := range multiNode.Children() { if err := p.expandPlan(child, parentPlan); err != nil { diff --git a/internal/planner/scan.go b/internal/planner/scan.go index 161a28f2f4..816074b9f4 100644 --- a/internal/planner/scan.go +++ b/internal/planner/scan.go @@ -121,6 +121,8 @@ func (n *scanNode) initFields(fields []mapper.Requestable) error { n.tryAddFieldWithName(target.Field.Name) } } + case *mapper.Similarity: + n.tryAddFieldWithName(requestable.SimilarityTarget.Name) } } return nil diff --git a/internal/planner/select.go b/internal/planner/select.go index 8d41d1d251..c2419c8d7e 100644 --- a/internal/planner/select.go +++ b/internal/planner/select.go @@ -61,6 +61,10 @@ type selectTopNode struct { // selectNode is used pre-wiring of the plan (before expansion and all). 
selectNode *selectNode + // This is added temporarily until Planner is refactored + // https://github.com/sourcenetwork/defradb/issues/3467 + similarity []*similarityNode + // plan is the top of the plan graph (the wired and finalized plan graph). planNode planNode } @@ -236,14 +240,14 @@ func (n *selectNode) Explain(explainType request.ExplainType) (map[string]any, e // creating scanNodes, typeIndexJoinNodes, and splitting // the necessary filters. Its designed to work with the // planner.Select construction call. -func (n *selectNode) initSource() ([]aggregateNode, error) { +func (n *selectNode) initSource() ([]aggregateNode, []*similarityNode, error) { if n.selectReq.CollectionName == "" { n.selectReq.CollectionName = n.selectReq.Name } sourcePlan, err := n.planner.getSource(n.selectReq) if err != nil { - return nil, err + return nil, nil, err } n.source = sourcePlan.plan n.origSource = sourcePlan.plan @@ -264,7 +268,7 @@ func (n *selectNode) initSource() ([]aggregateNode, error) { if n.selectReq.Cid.HasValue() { c, err := cid.Decode(n.selectReq.Cid.Value()) if err != nil { - return nil, err + return nil, nil, err } // This exists because the fetcher interface demands a []Prefixes, yet the versioned @@ -293,9 +297,9 @@ func (n *selectNode) initSource() ([]aggregateNode, error) { } } - aggregates, err := n.initFields(n.selectReq) + aggregates, similarity, err := n.initFields(n.selectReq) if err != nil { - return nil, err + return nil, nil, err } if isScanNode { @@ -303,7 +307,7 @@ func (n *selectNode) initSource() ([]aggregateNode, error) { origScan.initFetcher(n.selectReq.Cid) } - return aggregates, nil + return aggregates, similarity, nil } func findIndexByFilteringField(scanNode *scanNode) immutable.Option[client.IndexDescription] { @@ -354,8 +358,9 @@ func findIndexByFieldName(col client.Collection, fieldName string) immutable.Opt return immutable.None[client.IndexDescription]() } -func (n *selectNode) initFields(selectReq *mapper.Select) ([]aggregateNode, 
error) { +func (n *selectNode) initFields(selectReq *mapper.Select) ([]aggregateNode, []*similarityNode, error) { aggregates := []aggregateNode{} + similarity := []*similarityNode{} // loop over the sub type // at the moment, we're only testing a single sub selection for _, field := range selectReq.Fields { @@ -381,7 +386,7 @@ func (n *selectNode) initFields(selectReq *mapper.Select) ([]aggregateNode, erro } if aggregateError != nil { - return nil, aggregateError + return nil, nil, aggregateError } if plan != nil { @@ -408,11 +413,11 @@ func (n *selectNode) initFields(selectReq *mapper.Select) ([]aggregateNode, erro commitPlan := n.planner.DAGScan(commitSlct) if err := n.addSubPlan(f.Index, commitPlan); err != nil { - return nil, err + return nil, nil, err } } else if f.Name == request.GroupFieldName { if selectReq.GroupBy == nil { - return nil, ErrGroupOutsideOfGroupBy + return nil, nil, ErrGroupOutsideOfGroupBy } n.groupSelects = append(n.groupSelects, f) } else if f.Name == request.LinksFieldName && @@ -427,13 +432,17 @@ func (n *selectNode) initFields(selectReq *mapper.Select) ([]aggregateNode, erro // a traditional join here err := n.addTypeIndexJoin(f) if err != nil { - return nil, err + return nil, nil, err } } + case *mapper.Similarity: + var simFilter *mapper.Filter + selectReq.Filter, simFilter = filter.SplitByFields(selectReq.Filter, f.Field) + similarity = append(similarity, n.planner.Similarity(f, simFilter)) } } - return aggregates, nil + return aggregates, similarity, nil } func (n *selectNode) addTypeIndexJoin(subSelect *mapper.Select) error { @@ -482,7 +491,7 @@ func (p *Planner) SelectFromSource( s.collection = col } - aggregates, err := s.initFields(selectReq) + aggregates, similarity, err := s.initFields(selectReq) if err != nil { return nil, err } @@ -508,6 +517,7 @@ func (p *Planner) SelectFromSource( order: orderPlan, group: groupPlan, aggregates: aggregates, + similarity: similarity, docMapper: docMapper{selectReq.DocumentMapping}, } return 
top, nil @@ -526,7 +536,7 @@ func (p *Planner) Select(selectReq *mapper.Select) (planNode, error) { orderBy := selectReq.OrderBy groupBy := selectReq.GroupBy - aggregates, err := s.initSource() + aggregates, similarity, err := s.initSource() if err != nil { return nil, err } @@ -552,6 +562,7 @@ func (p *Planner) Select(selectReq *mapper.Select) (planNode, error) { order: orderPlan, group: groupPlan, aggregates: aggregates, + similarity: similarity, docMapper: docMapper{selectReq.DocumentMapping}, } return top, nil diff --git a/internal/planner/similarity.go b/internal/planner/similarity.go new file mode 100644 index 0000000000..33b23bfc1b --- /dev/null +++ b/internal/planner/similarity.go @@ -0,0 +1,204 @@ +// Copyright 2025 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package planner + +import ( + "github.com/sourcenetwork/defradb/client/request" + "github.com/sourcenetwork/defradb/internal/keys" + "github.com/sourcenetwork/defradb/internal/planner/mapper" +) + +type similarityNode struct { + documentIterator + docMapper + + p *Planner + plan planNode + + virtualFieldIndex int + target mapper.Targetable + vector any + execInfo similarityExecInfo + simFilter *mapper.Filter +} + +type similarityExecInfo struct { + // Total number of times similarityNode was executed. 
+ iterations uint64 +} + +func (p *Planner) Similarity( + field *mapper.Similarity, + filter *mapper.Filter, +) *similarityNode { + return &similarityNode{ + p: p, + virtualFieldIndex: field.Index, + vector: field.Vector, + target: field.SimilarityTarget, + simFilter: filter, + docMapper: docMapper{field.DocumentMapping}, + } +} + +func (n *similarityNode) Kind() string { + return "similarityNode" +} + +func (n *similarityNode) Init() error { + return n.plan.Init() +} + +func (n *similarityNode) Start() error { return n.plan.Start() } + +func (n *similarityNode) Prefixes(prefixes []keys.Walkable) { n.plan.Prefixes(prefixes) } + +func (n *similarityNode) Close() error { return n.plan.Close() } + +func (n *similarityNode) Source() planNode { return n.plan } + +func (n *similarityNode) simpleExplain() (map[string]any, error) { + simpleExplainMap := map[string]any{} + + simpleExplainMap["vector"] = n.vector + simpleExplainMap["target"] = n.target.Field.Name + + return map[string]any{ + sourcesLabel: simpleExplainMap, + }, nil +} + +// Explain method returns a map containing all attributes of this node that +// are to be explained, subscribes / opts-in this node to be an explainablePlanNode. 
+func (n *similarityNode) Explain(explainType request.ExplainType) (map[string]any, error) { + switch explainType { + case request.SimpleExplain: + return n.simpleExplain() + + case request.ExecuteExplain: + return map[string]any{ + "iterations": n.execInfo.iterations, + }, nil + + default: + return nil, ErrUnknownExplainRequestType + } +} + +func (n *similarityNode) Next() (bool, error) { + for { + n.execInfo.iterations++ + + hasNext, err := n.plan.Next() + if err != nil || !hasNext { + return hasNext, err + } + + n.currentValue = n.plan.Value() + + similarity := float64(0) + + child := n.currentValue.Fields[n.target.Index] + switch childCollection := child.(type) { + case []int64: + vector := convertArray[int64](n.vector) + result, err := cosineSimilarity(childCollection, vector) + if err != nil { + return false, err + } + similarity = float64(result) + case []float32: + vector := convertArray[float32](n.vector) + result, err := cosineSimilarity(childCollection, vector) + if err != nil { + return false, err + } + similarity = float64(result) + case []float64: + vector := convertArray[float64](n.vector) + result, err := cosineSimilarity(childCollection, vector) + if err != nil { + return false, err + } + similarity = float64(result) + } + + n.currentValue.Fields[n.virtualFieldIndex] = similarity + + passes, err := mapper.RunFilter(n.currentValue, n.simFilter) + if err != nil { + return false, err + } + if !passes { + continue + } + return true, nil + } +} + +func (n *similarityNode) SetPlan(p planNode) { n.plan = p } + +func cosineSimilarity[T number]( + source []T, + vector []T, +) (T, error) { + var value T + if len(source) != len(vector) { + return value, NewErrMismatchLengthOnSimilarity(len(source), len(vector)) + } + for i := range source { + value += vector[i] * source[i] + } + return value, nil +} + +func convertArray[T int64 | float32 | float64](val any) []T { + switch typedVal := val.(type) { + case []any: + newArr := make([]T, len(typedVal)) + for i, v 
:= range typedVal { + newArr[i] = convertToType[T](v) + } + return newArr + } + return nil +} + +func convertToType[T int64 | float32 | float64](val any) T { + switch v := val.(type) { + case int64: + return T(v) + case float64: + return T(v) + case float32: + return T(v) + case int8: + return T(v) + case int16: + return T(v) + case int32: + return T(v) + case int: + return T(v) + case uint8: + return T(v) + case uint16: + return T(v) + case uint32: + return T(v) + case uint64: + return T(v) + case uint: + return T(v) + } + var t T + return t +} diff --git a/internal/request/graphql/parser/query.go b/internal/request/graphql/parser/query.go index a15cff0e1f..334389f5d7 100644 --- a/internal/request/graphql/parser/query.go +++ b/internal/request/graphql/parser/query.go @@ -16,6 +16,7 @@ import ( "github.com/sourcenetwork/immutable" "github.com/sourcenetwork/defradb/client/request" + "github.com/sourcenetwork/defradb/internal/request/graphql/schema/types" ) // parseQueryOperationDefinition parses the individual GraphQL @@ -223,6 +224,31 @@ func parseAggregate( }, nil } +func parseSimilarity( + exe *gql.ExecutionContext, + parent *gql.Object, + field *ast.Field, +) (*request.Similarity, error) { + fieldDef := gql.GetFieldDef(exe.Schema, parent, field.Name.Value) + arguments := gql.GetArgumentValues(fieldDef.Args, field.Arguments, exe.VariableValues) + var target string + var vector any + for _, argument := range field.Arguments { + target = argument.Name.Value + v := arguments[target].(map[string]any) + vector = v[types.SimilarityArgVector] + } + + return &request.Similarity{ + Field: request.Field{ + Name: field.Name.Value, + Alias: getFieldAlias(field), + }, + Target: target, + Vector: vector, + }, nil +} + func parseAggregateTarget( hostName string, arguments map[string]any, diff --git a/internal/request/graphql/parser/request.go b/internal/request/graphql/parser/request.go index 7545957ae3..72e84f3594 100644 --- a/internal/request/graphql/parser/request.go +++ 
b/internal/request/graphql/parser/request.go @@ -207,6 +207,12 @@ func parseSelectFields( return nil, err } selection = s + } else if node.Name.Value == request.SimilarityFieldName { + s, err := parseSimilarity(exe, parent, node) + if err != nil { + return nil, err + } + selection = s } else if node.SelectionSet == nil { // regular field selection = parseField(node) } else { // sub type with extra fields diff --git a/internal/request/graphql/schema/generate.go b/internal/request/graphql/schema/generate.go index 09a268ba56..85c2bfff73 100644 --- a/internal/request/graphql/schema/generate.go +++ b/internal/request/graphql/schema/generate.go @@ -145,6 +145,16 @@ func (g *Generator) generate(ctx context.Context, collections []client.Collectio if err := g.genAggregateFields(); err != nil { return nil, err } + + // resolve types + if err := g.manager.ResolveTypes(); err != nil { + return nil, err + } + + if err := g.genVectorOpsFields(); err != nil { + return nil, err + } + // resolve types if err := g.manager.ResolveTypes(); err != nil { return nil, err @@ -846,6 +856,40 @@ func (g *Generator) genAverageFieldConfig(obj *gql.Object) (gql.Field, error) { return field, nil } +func (g *Generator) genSimilarityFieldConfig(obj *gql.Object) (gql.Field, error) { + field := gql.Field{ + Name: request.SimilarityFieldName, + Description: "Returns the cosine similarity between the specified field and the provided vector.", + Type: gql.Float, + Args: gql.FieldConfigArgument{}, + } + + for _, objectField := range obj.Fields() { + listType, isList := objectField.Type.(*gql.List) + if !isList || !isNumericArray(listType) { + continue + } + + inputObject := gql.NewInputObject(gql.InputObjectConfig{ + Name: genSimilaritySelectorName(obj.Name(), objectField.Name), + Description: objectField.Description, + Fields: gql.InputObjectConfigFieldMap{ + schemaTypes.SimilarityArgVector: &gql.InputObjectFieldConfig{ + Type: gql.NewNonNull(gql.NewList(listType.OfType)), + Description: "A vector of 
the same type as the field to compute the cosine similarity with.", + }, + }, + }) + err := g.appendIfNotExists(inputObject) + if err != nil { + return gql.Field{}, err + } + field.Args[objectField.Name] = schemaTypes.NewArgConfig(inputObject, objectField.Description) + } + + return field, nil +} + func (g *Generator) getNumericFields(obj *gql.Object) map[string]gql.Type { fieldTypes := map[string]gql.Type{} for _, field := range obj.Fields() { @@ -914,6 +958,10 @@ func genNumericInlineArraySelectorName(hostName string, fieldName string) string return fmt.Sprintf("%s__%s__%s", hostName, fieldName, "NumericSelector") } +func genSimilaritySelectorName(hostName string, fieldName string) string { + return fmt.Sprintf("%s__%s__%s", hostName, fieldName, "SimilaritySelector") +} + func (g *Generator) genCountBaseArgInputs(obj *gql.Object) *gql.InputObject { countableObject := gql.NewInputObject(gql.InputObjectConfig{ Name: genObjectCountName(obj.Name()), @@ -1360,6 +1408,17 @@ func (g *Generator) genTypeQueryableFieldList( return field } +func (g *Generator) genVectorOpsFields() error { + for _, t := range g.typeDefs { + similarityField, err := g.genSimilarityFieldConfig(t) + if err != nil { + return err + } + t.AddFieldConfig(similarityField.Name, &similarityField) + } + return nil +} + func (g *Generator) appendIfNotExists(obj gql.Type) error { if _, typeExists := g.manager.schema.TypeMap()[obj.Name()]; !typeExists { err := g.manager.schema.AppendType(obj) diff --git a/internal/request/graphql/schema/types/types.go b/internal/request/graphql/schema/types/types.go index 4fd5bb4efc..c98d3b1c08 100644 --- a/internal/request/graphql/schema/types/types.go +++ b/internal/request/graphql/schema/types/types.go @@ -71,6 +71,8 @@ const ( FieldOrderASC = "ASC" FieldOrderDESC = "DESC" + + SimilarityArgVector = "vector" ) // OrderingEnum is an enum for the Ordering argument. 
diff --git a/tests/integration/explain/execute/with_similarity_test.go b/tests/integration/explain/execute/with_similarity_test.go new file mode 100644 index 0000000000..b6bd62c550 --- /dev/null +++ b/tests/integration/explain/execute/with_similarity_test.go @@ -0,0 +1,74 @@ +// Copyright 2025 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package test_explain_execute + +import ( + "testing" + + testUtils "github.com/sourcenetwork/defradb/tests/integration" + explainUtils "github.com/sourcenetwork/defradb/tests/integration/explain" +) + +func TestExecuteExplainRequest_WithSimilarity(t *testing.T) { + test := testUtils.TestCase{ + Description: "Explain (execute) request with similarity.", + Actions: []any{ + testUtils.SchemaUpdate{ + Schema: `type User { + name: String + pointsList: [Float64!] 
+ }`, + }, + testUtils.CreateDoc{ + DocMap: map[string]any{ + "name": "John", + "pointsList": []float64{2, 4, 1}, + }, + }, + testUtils.ExplainRequest{ + Request: `query @explain(type: execute) { + User { + name + _similarity(pointsList: {vector: [1, 2, 0]}) + } + }`, + ExpectedFullGraph: dataMap{ + "explain": dataMap{ + "executionSuccess": true, + "sizeOfResult": 1, + "planExecutions": uint64(2), + "operationNode": []dataMap{ + { + "selectTopNode": dataMap{ + "similarityNode": dataMap{ + "iterations": uint64(2), + "selectNode": dataMap{ + "iterations": uint64(2), + "filterMatches": uint64(1), + "scanNode": dataMap{ + "iterations": uint64(2), + "docFetches": uint64(1), + "fieldFetches": uint64(2), + "indexFetches": uint64(0), + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + + explainUtils.ExecuteTestCase(t, test) +} diff --git a/tests/integration/query/simple/with_similarity_test.go b/tests/integration/query/simple/with_similarity_test.go new file mode 100644 index 0000000000..46b955f62c --- /dev/null +++ b/tests/integration/query/simple/with_similarity_test.go @@ -0,0 +1,477 @@ +// Copyright 2025 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package simple + +import ( + "testing" + + testUtils "github.com/sourcenetwork/defradb/tests/integration" +) + +func TestQuerySimple_WithSimilarityOnQuery_ShouldError(t *testing.T) { + test := testUtils.TestCase{ + Description: "Simple query, similarity on undefined object", + Actions: []any{ + testUtils.SchemaUpdate{ + Schema: `type User { + name: String + vector: [Int!] 
+ }`, + }, + testUtils.Request{ + Request: `query { + _similarity + }`, + ExpectedError: "Cannot query field \"_similarity\" on type \"Query\".", + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} + +func TestQuerySimple_WithSimilarityOnUndefinedField_ShouldError(t *testing.T) { + test := testUtils.TestCase{ + Description: "Simple query, similarity on undefined field", + Actions: []any{ + testUtils.SchemaUpdate{ + Schema: `type User { + name: String + }`, + }, + testUtils.Request{ + Request: `query { + User{ + _similarity(pointsList: {vector: [1, 2, 3]}) + } + }`, + ExpectedError: "Unknown argument \"pointsList\" on field \"_similarity\" of type \"User\".", + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} + +func TestQuerySimple_WithSimilarityAndWrongVectorValueType_ShouldError(t *testing.T) { + test := testUtils.TestCase{ + Description: "Simple query, similarity on empty", + Actions: []any{ + testUtils.SchemaUpdate{ + Schema: `type User { + name: String + pointsList: [Int!] + }`, + }, + testUtils.Request{ + Request: `query { + User{ + _similarity(pointsList: {vector: [1.1, 1.2, 0.9]}) + } + }`, + ExpectedError: "Argument \"pointsList\" has invalid value {vector: [1.1, 1.2, 0.9]}.\nIn field " + + "\"vector\": In element #1: Expected type \"Int\", found 1.1.\nIn field \"vector\": In element #1: " + + "Expected type \"Int\", found 1.2.\nIn field \"vector\": In element #1: Expected type \"Int\", found 0.9.", + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} + +func TestQuerySimple_WithSimilarityAndWrongFieldType_ShouldError(t *testing.T) { + test := testUtils.TestCase{ + Description: "Simple query, similarity on empty", + Actions: []any{ + testUtils.SchemaUpdate{ + Schema: `type User { + name: String + pets: [String!] + }`, + }, + testUtils.Request{ + Request: `query { + User{ + _similarity(pets: {vector: [1.1, 1.2, 0.9]}) + } + }`, + // Not found on _similarity because it's not a supported type. 
+ ExpectedError: "Unknown argument \"pets\" on field \"_similarity\" of type \"User\".", + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} + +func TestQuerySimple_WithSimilarityOnEmptyCollection_ShouldSucceed(t *testing.T) { + test := testUtils.TestCase{ + Description: "Simple query, similarity on empty", + Actions: []any{ + testUtils.SchemaUpdate{ + Schema: `type User { + name: String + pointsList: [Int!] + }`, + }, + testUtils.Request{ + Request: `query { + User{ + _similarity(pointsList: {vector: [1, 2, 3]}) + } + }`, + Results: map[string]any{ + "User": []map[string]any{}, + }, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} + +func TestQuerySimple_WithIntSimilarity_ShouldSucceed(t *testing.T) { + test := testUtils.TestCase{ + Description: "Simple query, similarity on empty", + Actions: []any{ + testUtils.SchemaUpdate{ + Schema: `type User { + name: String + pointsList: [Int!] + }`, + }, + testUtils.CreateDoc{ + DocMap: map[string]any{ + "name": "John", + "pointsList": []int64{2, 4, 1}, + }, + }, + testUtils.Request{ + Request: `query { + User{ + name + _similarity(pointsList: {vector: [1, 2, 0]}) + } + }`, + Results: map[string]any{ + "User": []map[string]any{ + { + "name": "John", + "_similarity": float64(10), + }, + }, + }, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} + +func TestQuerySimple_WithIntSimilarityDifferentVectorLength_ShouldError(t *testing.T) { + test := testUtils.TestCase{ + Description: "Simple query, similarity on empty", + Actions: []any{ + testUtils.SchemaUpdate{ + Schema: `type User { + name: String + pointsList: [Int!] + }`, + }, + testUtils.CreateDoc{ + DocMap: map[string]any{ + "name": "John", + "pointsList": []int64{2, 4, 1}, + }, + }, + testUtils.Request{ + Request: `query { + User{ + name + _similarity(pointsList: {vector: [1, 2, 0, 1]}) + } + }`, + ExpectedError: "source and vector must be of the same length. 
Source: 3, Vector: 4", + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} + +func TestQuerySimple_WithFloat32Similarity_ShouldSucceed(t *testing.T) { + test := testUtils.TestCase{ + Description: "Simple query, similarity on empty", + Actions: []any{ + testUtils.SchemaUpdate{ + Schema: `type User { + name: String + pointsList: [Float32!] + }`, + }, + testUtils.CreateDoc{ + DocMap: map[string]any{ + "name": "John", + "pointsList": []float32{2, 4, 1}, + }, + }, + testUtils.Request{ + Request: `query { + User{ + name + _similarity(pointsList: {vector: [1, 2, 0]}) + } + }`, + Results: map[string]any{ + "User": []map[string]any{ + { + "name": "John", + "_similarity": float64(10), + }, + }, + }, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} + +func TestQuerySimple_WithFloat64Similarity_ShouldSucceed(t *testing.T) { + test := testUtils.TestCase{ + Description: "Simple query, similarity on empty", + Actions: []any{ + testUtils.SchemaUpdate{ + Schema: `type User { + name: String + pointsList: [Float64!] + }`, + }, + testUtils.CreateDoc{ + DocMap: map[string]any{ + "name": "John", + "pointsList": []float64{2, 4, 1}, + }, + }, + testUtils.Request{ + Request: `query { + User{ + name + _similarity(pointsList: {vector: [1, 2, 0]}) + } + }`, + Results: map[string]any{ + "User": []map[string]any{ + { + "name": "John", + "_similarity": float64(10), + }, + }, + }, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} + +func TestQuerySimple_WithJSONDocCreationSimilarity_ShouldSucceed(t *testing.T) { + test := testUtils.TestCase{ + Description: "Simple query, similarity on empty", + Actions: []any{ + testUtils.SchemaUpdate{ + Schema: `type User { + name: String + pointsList: [Float64!] 
+ }`, + }, + testUtils.CreateDoc{ + Doc: `{ + "name": "John", + "pointsList": [2, 4, 1] + }`, + }, + testUtils.Request{ + Request: `query { + User{ + name + _similarity(pointsList: {vector: [1, 2, 0]}) + } + }`, + Results: map[string]any{ + "User": []map[string]any{ + { + "name": "John", + "_similarity": float64(10), + }, + }, + }, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} + +func TestQuerySimple_WithSimilarityAndFilteringOnSimilarityResult_ShouldSucceed(t *testing.T) { + test := testUtils.TestCase{ + Description: "Simple query, similarity on empty", + Actions: []any{ + testUtils.SchemaUpdate{ + Schema: `type User { + name: String + pointsList: [Int!] + }`, + }, + testUtils.CreateDoc{ + DocMap: map[string]any{ + "name": "John", + "pointsList": []int64{2, 4, 1}, + }, + }, + testUtils.CreateDoc{ + DocMap: map[string]any{ + "name": "Bob", + "pointsList": []int64{1, 1, 1}, + }, + }, + testUtils.CreateDoc{ + DocMap: map[string]any{ + "name": "Alice", + "pointsList": []int64{4, 5, 3}, + }, + }, + testUtils.Request{ + Request: `query { + User(filter: {_alias: {sim: {_lt: 11}}}){ + name + sim: _similarity(pointsList: {vector: [1, 2, 0]}) + } + }`, + Results: map[string]any{ + "User": []map[string]any{ + { + "name": "Bob", + "sim": float64(3), + }, + { + "name": "John", + "sim": float64(10), + }, + }, + }, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} + +func TestQuerySimple_WithTwoSimilarityAndFilteringOnSecond_ShouldSucceed(t *testing.T) { + test := testUtils.TestCase{ + Description: "Simple query, similarity on empty", + Actions: []any{ + testUtils.SchemaUpdate{ + Schema: `type User { + name: String + pointsList: [Int!] 
+ }`, + }, + testUtils.CreateDoc{ + DocMap: map[string]any{ + "name": "John", + "pointsList": []int64{2, 4, 1}, + }, + }, + testUtils.CreateDoc{ + DocMap: map[string]any{ + "name": "Bob", + "pointsList": []int64{1, 1, 1}, + }, + }, + testUtils.CreateDoc{ + DocMap: map[string]any{ + "name": "Alice", + "pointsList": []int64{4, 5, 3}, + }, + }, + testUtils.Request{ + Request: `query { + User(filter: {_alias: {sim2: {_gt: 20}}}){ + name + sim: _similarity(pointsList: {vector: [1, 2, 0]}) + sim2: _similarity(pointsList: {vector: [2, 3, 0]}) + } + }`, + Results: map[string]any{ + "User": []map[string]any{ + { + "name": "Alice", + "sim": float64(14), + "sim2": float64(23), + }, + }, + }, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} + +// This test documents a bug where having two aliases in a logical _or operator +// returns no results even though the test below should return Alice and Bob. +// https://github.com/sourcenetwork/defradb/issues/3468 +func TestQuerySimple_WithTwoSimilarityAndFilteringOnBoth_ShouldSucceed(t *testing.T) { + test := testUtils.TestCase{ + Description: "Simple query, similarity on empty", + Actions: []any{ + testUtils.SchemaUpdate{ + Schema: `type User { + name: String + pointsList: [Int!]
+ }`, + }, + testUtils.CreateDoc{ + DocMap: map[string]any{ + "name": "John", + "pointsList": []int64{2, 4, 1}, + }, + }, + testUtils.CreateDoc{ + DocMap: map[string]any{ + "name": "Bob", + "pointsList": []int64{1, 1, 1}, + }, + }, + testUtils.CreateDoc{ + DocMap: map[string]any{ + "name": "Alice", + "pointsList": []int64{4, 5, 3}, + }, + }, + testUtils.Request{ + Request: `query { + User(filter: {_or: [{_alias: {sim2: {_gt: 20}}}, {_alias: {sim: {_lt: 10}}}]}){ + name + sim: _similarity(pointsList: {vector: [1, 2, 0]}) + sim2: _similarity(pointsList: {vector: [2, 3, 0]}) + } + }`, + Results: map[string]any{ + "User": []map[string]any{}, + }, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} diff --git a/tests/integration/schema/default_fields.go b/tests/integration/schema/default_fields.go index 1f71f6bc2f..9254805eae 100644 --- a/tests/integration/schema/default_fields.go +++ b/tests/integration/schema/default_fields.go @@ -66,6 +66,7 @@ var DefaultFields = concat( versionField, groupField, deletedField, + similarityField, }, aggregateFields, ) @@ -75,6 +76,7 @@ var DefaultFields = concat( var DefaultViewObjFields = concat( fields{ groupField, + similarityField, }, aggregateFields, ) @@ -149,6 +151,14 @@ var aggregateFields = fields{ }, } +var similarityField = Field{ + "name": "_similarity", + "type": map[string]any{ + "kind": "SCALAR", + "name": "Float", + }, +} + var cidArg = Field{ "name": "cid", "type": map[string]any{ diff --git a/tests/integration/schema/similarity_test.go b/tests/integration/schema/similarity_test.go new file mode 100644 index 0000000000..17a93d9a37 --- /dev/null +++ b/tests/integration/schema/similarity_test.go @@ -0,0 +1,322 @@ +// Copyright 2025 Democratized Data Foundation +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. 
+// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package schema + +import ( + "testing" + + testUtils "github.com/sourcenetwork/defradb/tests/integration" +) + +func TestSchemaInstrospection_SimilarityCapableFieldIntArray(t *testing.T) { + test := testUtils.TestCase{ + Actions: []any{ + testUtils.SchemaUpdate{ + Schema: ` + type Users { + someVector: [Int!] + } + `, + }, + testUtils.IntrospectionRequest{ + Request: ` + query { + __type (name: "Users") { + name + fields { + name + args { + name + type { + name + inputFields { + name + type { + name + kind + ofType { + name + kind + ofType { + name + kind + ofType { + name + kind + } + } + } + } + } + } + } + } + } + } + `, + ContainsData: map[string]any{ + "__type": map[string]any{ + "name": "Users", + "fields": []any{ + map[string]any{ + "name": "_similarity", + "args": []any{ + map[string]any{ + "name": "someVector", + "type": map[string]any{ + "inputFields": []any{ + map[string]any{ + "name": "vector", + "type": map[string]any{ + "kind": "NON_NULL", + "name": nil, + "ofType": map[string]any{ + "kind": "LIST", + "name": nil, + "ofType": map[string]any{ + "name": nil, + "kind": "NON_NULL", + "ofType": map[string]any{ + "name": "Int", + "kind": "SCALAR", + }, + }, + }, + }, + }, + }, + "name": "Users__someVector__SimilaritySelector", + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} + +func TestSchemaInstrospection_SimilarityCapableFieldFloat32Array(t *testing.T) { + test := testUtils.TestCase{ + Actions: []any{ + testUtils.SchemaUpdate{ + Schema: ` + type Users { + someVector: [Float32!] 
+ } + `, + }, + testUtils.IntrospectionRequest{ + Request: ` + query { + __type (name: "Users") { + name + fields { + name + args { + name + type { + name + inputFields { + name + type { + name + kind + ofType { + name + kind + ofType { + name + kind + ofType { + name + kind + } + } + } + } + } + } + } + } + } + } + `, + ContainsData: map[string]any{ + "__type": map[string]any{ + "name": "Users", + "fields": []any{ + map[string]any{ + "name": "_similarity", + "args": []any{ + map[string]any{ + "name": "someVector", + "type": map[string]any{ + "inputFields": []any{ + map[string]any{ + "name": "vector", + "type": map[string]any{ + "kind": "NON_NULL", + "name": nil, + "ofType": map[string]any{ + "kind": "LIST", + "name": nil, + "ofType": map[string]any{ + "name": nil, + "kind": "NON_NULL", + "ofType": map[string]any{ + "name": "Float32", + "kind": "SCALAR", + }, + }, + }, + }, + }, + }, + "name": "Users__someVector__SimilaritySelector", + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +} + +func TestSchemaInstrospection_SimilarityCapableFieldsIntArrayAndFloat32Array(t *testing.T) { + test := testUtils.TestCase{ + Actions: []any{ + testUtils.SchemaUpdate{ + Schema: ` + type Users { + someVectorInt: [Int!] + someVectorFloat32: [Float32!] 
+ someOtherNumber: Int + } + `, + }, + testUtils.IntrospectionRequest{ + Request: ` + query { + __type (name: "Users") { + name + fields { + name + args { + name + type { + name + inputFields { + name + type { + name + kind + ofType { + name + kind + ofType { + name + kind + ofType { + name + kind + } + } + } + } + } + } + } + } + } + } + `, + ContainsData: map[string]any{ + "__type": map[string]any{ + "name": "Users", + "fields": []any{ + map[string]any{ + "name": "_similarity", + "args": []any{ + map[string]any{ + "name": "someVectorFloat32", + "type": map[string]any{ + "inputFields": []any{ + map[string]any{ + "name": "vector", + "type": map[string]any{ + "kind": "NON_NULL", + "name": nil, + "ofType": map[string]any{ + "kind": "LIST", + "name": nil, + "ofType": map[string]any{ + "name": nil, + "kind": "NON_NULL", + "ofType": map[string]any{ + "name": "Float32", + "kind": "SCALAR", + }, + }, + }, + }, + }, + }, + "name": "Users__someVectorFloat32__SimilaritySelector", + }, + }, + map[string]any{ + "name": "someVectorInt", + "type": map[string]any{ + "inputFields": []any{ + map[string]any{ + "name": "vector", + "type": map[string]any{ + "kind": "NON_NULL", + "name": nil, + "ofType": map[string]any{ + "kind": "LIST", + "name": nil, + "ofType": map[string]any{ + "name": nil, + "kind": "NON_NULL", + "ofType": map[string]any{ + "name": "Int", + "kind": "SCALAR", + }, + }, + }, + }, + }, + }, + "name": "Users__someVectorInt__SimilaritySelector", + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + + testUtils.ExecuteTestCase(t, test) +}