Skip to content

Commit

Permalink
feat: Add cosine similarity query (#3464)
Browse files Browse the repository at this point in the history
## Relevant issue(s)

Resolves #3349 

## Description

This PR adds the possibility to calculate the cosine similarity between
a vector field and a given vector.

To achieve this we added the `_similarity` system field which take a
target field (part of the parent object) and vector as parameter.
```gql
query {
  User{
    _similarity(pointsList: {vector: [1, 2, 0]})
  }
}
```

Note that the added code to mapper and planner is more of a "bolt on"
addition given the current state of that part of the code base. A
refactor is expected in the future.

Future work will allow giving a `content` parameter instead of the
`vector` if the target field has embedding generation configured. This
will enable out-of-the-box RAG queries.
  • Loading branch information
fredcarle authored Feb 19, 2025
1 parent 7107187 commit 4d90e6e
Show file tree
Hide file tree
Showing 18 changed files with 1,320 additions and 34 deletions.
42 changes: 22 additions & 20 deletions client/request/consts.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,17 @@ const (

DocIDArgName = "docID"

AverageFieldName = "_avg"
CountFieldName = "_count"
DocIDFieldName = "_docID"
GroupFieldName = "_group"
DeletedFieldName = "_deleted"
SumFieldName = "_sum"
VersionFieldName = "_version"
MaxFieldName = "_max"
MinFieldName = "_min"
AliasFieldName = "_alias"
AverageFieldName = "_avg"
CountFieldName = "_count"
DocIDFieldName = "_docID"
GroupFieldName = "_group"
DeletedFieldName = "_deleted"
SumFieldName = "_sum"
VersionFieldName = "_version"
MaxFieldName = "_max"
MinFieldName = "_min"
AliasFieldName = "_alias"
SimilarityFieldName = "_similarity"

// New generated document id from a backed up document,
// which might have a different _docID originally.
Expand Down Expand Up @@ -104,16 +105,17 @@ var (
}

ReservedFields = map[string]struct{}{
TypeNameFieldName: {},
VersionFieldName: {},
GroupFieldName: {},
CountFieldName: {},
SumFieldName: {},
AverageFieldName: {},
DocIDFieldName: {},
DeletedFieldName: {},
MaxFieldName: {},
MinFieldName: {},
TypeNameFieldName: {},
VersionFieldName: {},
GroupFieldName: {},
CountFieldName: {},
SumFieldName: {},
AverageFieldName: {},
DocIDFieldName: {},
DeletedFieldName: {},
MaxFieldName: {},
MinFieldName: {},
SimilarityFieldName: {},
}

Aggregates = map[string]struct{}{
Expand Down
26 changes: 26 additions & 0 deletions client/request/similarity.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright 2025 Democratized Data Foundation
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.

package request

// Similarity is a functional field that defines the
// parameters to calculate the cosine similarity between two vectors.
type Similarity struct {
Field
// Vector contains the vector to compare the target field to.
//
// It will be of type Int, Float32 or Float64. It must be the same type and length as Target.
Vector any

// Target is the field in the host object that we will compare the the vector to.
//
// It must be a field of type Int, Float32 or Float64. It must be the same type and length as Vector.
Target string
}
9 changes: 9 additions & 0 deletions internal/planner/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ var (
ErrUnknownRelationType = errors.New("failed sub selection, unknown relation type")
ErrUnknownExplainRequestType = errors.New("can not explain request of unknown type")
ErrUpsertMultipleDocuments = errors.New("cannot upsert multiple matching documents")
ErrMismatchLengthOnSimilarity = errors.New("source and vector must be of the same length")
)

func NewErrUnknownDependency(name string) error {
Expand All @@ -52,3 +53,11 @@ func NewErrFailedToCollectExecExplainInfo(inner error) error {
func NewErrSubTypeInit(inner error) error {
return errors.Wrap(errSubTypeInit, inner)
}

func NewErrMismatchLengthOnSimilarity(source, vector int) error {
return errors.WithStack(
ErrMismatchLengthOnSimilarity,
errors.NewKV("Source", source),
errors.NewKV("Vector", vector),
)
}
20 changes: 20 additions & 0 deletions internal/planner/mapper/mapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -803,6 +803,26 @@ func getRequestables(
Key: getRenderKey(&f.Field),
})

mapping.Add(index, f.Name)
case *request.Similarity:
index := mapping.GetNextIndex()
fields = append(fields, &Similarity{
Field: Field{
Index: index,
Name: f.Name,
},
Vector: f.Vector,
SimilarityTarget: Targetable{
Field: Field{
Index: mapping.FirstIndexOfName(f.Target),
Name: f.Target,
},
},
})
mapping.RenderKeys = append(mapping.RenderKeys, core.RenderKey{
Index: index,
Key: getRenderKey(&f.Field),
})
mapping.Add(index, f.Name)
default:
return nil, nil, client.NewErrUnhandledType("field", field)
Expand Down
26 changes: 26 additions & 0 deletions internal/planner/mapper/similarity.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright 2025 Democratized Data Foundation
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.

package mapper

import "github.com/sourcenetwork/defradb/internal/core"

// Similarity represents an cosine similarity operation definition.
type Similarity struct {
Field
// The mapping of this aggregate's parent/host.
*core.DocumentMapping

// The targetted field for the cosine similarity
SimilarityTarget Targetable

// The vector to compare the target field to.
Vector any
}
1 change: 1 addition & 0 deletions internal/planner/operations.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ var (
_ planNode = (*valuesNode)(nil)
_ planNode = (*viewNode)(nil)
_ planNode = (*lensNode)(nil)
_ planNode = (*similarityNode)(nil)

_ MultiNode = (*parallelNode)(nil)
_ MultiNode = (*topLevelNode)(nil)
Expand Down
9 changes: 9 additions & 0 deletions internal/planner/planner.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,8 @@ func (p *Planner) expandSelectTopNodePlan(plan *selectTopNode, parentPlan *selec
p.expandLimitPlan(plan, parentPlan)
}

p.expandSimilarityPlans(plan)

return nil
}

Expand All @@ -249,6 +251,13 @@ func (p *Planner) expandAggregatePlans(plan *selectTopNode) {
}
}

func (p *Planner) expandSimilarityPlans(plan *selectTopNode) {
for _, sim := range plan.similarity {
sim.SetPlan(plan.planNode)
plan.planNode = sim
}
}

func (p *Planner) expandMultiNode(multiNode MultiNode, parentPlan *selectTopNode) error {
for _, child := range multiNode.Children() {
if err := p.expandPlan(child, parentPlan); err != nil {
Expand Down
2 changes: 2 additions & 0 deletions internal/planner/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ func (n *scanNode) initFields(fields []mapper.Requestable) error {
n.tryAddFieldWithName(target.Field.Name)
}
}
case *mapper.Similarity:
n.tryAddFieldWithName(requestable.SimilarityTarget.Name)
}
}
return nil
Expand Down
39 changes: 25 additions & 14 deletions internal/planner/select.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ type selectTopNode struct {
// selectNode is used pre-wiring of the plan (before expansion and all).
selectNode *selectNode

// This is added temporarity until Planner is refactored
// https://github.com/sourcenetwork/defradb/issues/3467
similarity []*similarityNode

// plan is the top of the plan graph (the wired and finalized plan graph).
planNode planNode
}
Expand Down Expand Up @@ -236,14 +240,14 @@ func (n *selectNode) Explain(explainType request.ExplainType) (map[string]any, e
// creating scanNodes, typeIndexJoinNodes, and splitting
// the necessary filters. Its designed to work with the
// planner.Select construction call.
func (n *selectNode) initSource() ([]aggregateNode, error) {
func (n *selectNode) initSource() ([]aggregateNode, []*similarityNode, error) {
if n.selectReq.CollectionName == "" {
n.selectReq.CollectionName = n.selectReq.Name
}

sourcePlan, err := n.planner.getSource(n.selectReq)
if err != nil {
return nil, err
return nil, nil, err
}
n.source = sourcePlan.plan
n.origSource = sourcePlan.plan
Expand All @@ -264,7 +268,7 @@ func (n *selectNode) initSource() ([]aggregateNode, error) {
if n.selectReq.Cid.HasValue() {
c, err := cid.Decode(n.selectReq.Cid.Value())
if err != nil {
return nil, err
return nil, nil, err
}

// This exists because the fetcher interface demands a []Prefixes, yet the versioned
Expand Down Expand Up @@ -293,17 +297,17 @@ func (n *selectNode) initSource() ([]aggregateNode, error) {
}
}

aggregates, err := n.initFields(n.selectReq)
aggregates, similarity, err := n.initFields(n.selectReq)
if err != nil {
return nil, err
return nil, nil, err
}

if isScanNode {
origScan.index = findIndexByFilteringField(origScan)
origScan.initFetcher(n.selectReq.Cid)
}

return aggregates, nil
return aggregates, similarity, nil
}

func findIndexByFilteringField(scanNode *scanNode) immutable.Option[client.IndexDescription] {
Expand Down Expand Up @@ -354,8 +358,9 @@ func findIndexByFieldName(col client.Collection, fieldName string) immutable.Opt
return immutable.None[client.IndexDescription]()
}

func (n *selectNode) initFields(selectReq *mapper.Select) ([]aggregateNode, error) {
func (n *selectNode) initFields(selectReq *mapper.Select) ([]aggregateNode, []*similarityNode, error) {
aggregates := []aggregateNode{}
similarity := []*similarityNode{}
// loop over the sub type
// at the moment, we're only testing a single sub selection
for _, field := range selectReq.Fields {
Expand All @@ -381,7 +386,7 @@ func (n *selectNode) initFields(selectReq *mapper.Select) ([]aggregateNode, erro
}

if aggregateError != nil {
return nil, aggregateError
return nil, nil, aggregateError
}

if plan != nil {
Expand All @@ -408,11 +413,11 @@ func (n *selectNode) initFields(selectReq *mapper.Select) ([]aggregateNode, erro
commitPlan := n.planner.DAGScan(commitSlct)

if err := n.addSubPlan(f.Index, commitPlan); err != nil {
return nil, err
return nil, nil, err
}
} else if f.Name == request.GroupFieldName {
if selectReq.GroupBy == nil {
return nil, ErrGroupOutsideOfGroupBy
return nil, nil, ErrGroupOutsideOfGroupBy
}
n.groupSelects = append(n.groupSelects, f)
} else if f.Name == request.LinksFieldName &&
Expand All @@ -427,13 +432,17 @@ func (n *selectNode) initFields(selectReq *mapper.Select) ([]aggregateNode, erro
// a traditional join here
err := n.addTypeIndexJoin(f)
if err != nil {
return nil, err
return nil, nil, err
}
}
case *mapper.Similarity:
var simFilter *mapper.Filter
selectReq.Filter, simFilter = filter.SplitByFields(selectReq.Filter, f.Field)
similarity = append(similarity, n.planner.Similarity(f, simFilter))
}
}

return aggregates, nil
return aggregates, similarity, nil
}

func (n *selectNode) addTypeIndexJoin(subSelect *mapper.Select) error {
Expand Down Expand Up @@ -482,7 +491,7 @@ func (p *Planner) SelectFromSource(
s.collection = col
}

aggregates, err := s.initFields(selectReq)
aggregates, similarity, err := s.initFields(selectReq)
if err != nil {
return nil, err
}
Expand All @@ -508,6 +517,7 @@ func (p *Planner) SelectFromSource(
order: orderPlan,
group: groupPlan,
aggregates: aggregates,
similarity: similarity,
docMapper: docMapper{selectReq.DocumentMapping},
}
return top, nil
Expand All @@ -526,7 +536,7 @@ func (p *Planner) Select(selectReq *mapper.Select) (planNode, error) {
orderBy := selectReq.OrderBy
groupBy := selectReq.GroupBy

aggregates, err := s.initSource()
aggregates, similarity, err := s.initSource()
if err != nil {
return nil, err
}
Expand All @@ -552,6 +562,7 @@ func (p *Planner) Select(selectReq *mapper.Select) (planNode, error) {
order: orderPlan,
group: groupPlan,
aggregates: aggregates,
similarity: similarity,
docMapper: docMapper{selectReq.DocumentMapping},
}
return top, nil
Expand Down
Loading

0 comments on commit 4d90e6e

Please sign in to comment.