Skip to content

Query: Adds support for weighted RRF in Hybrid Search #45328

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions sdk/cosmos/azure-cosmos/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
* Added API to allow customers to wrap/extend `CosmosAsyncContainer` - [PR 43724](https://github.com/Azure/azure-sdk-for-java/pull/43724) and [PR 45087](https://github.com/Azure/azure-sdk-for-java/pull/45087)
* Added Per-Partition Automatic Failover which enables failover for writes at per-partition level for Single-Write Multi-Region accounts. - [PR 44099](https://github.com/Azure/azure-sdk-for-java/pull/44099)
* Added Beta public API to allow defining the consistency behavior for read / query / change feed operations independent of the chosen account-level consistency level. **NOTE: This API is still in preview mode and can only be used when using DIRECT connection mode.** - See [PR 45161](https://github.com/Azure/azure-sdk-for-java/pull/45161)
* Added Weighted RRF for Hybrid and Full Text Search queries - [PR 45328](https://github.com/Azure/azure-sdk-for-java/pull/45328)

#### Bugs Fixed
* Fixed the fail back flow where not all partitions were failing back to original first preferred region for Per-Partition Circuit Breaker. - [PR 44099](https://github.com/Azure/azure-sdk-for-java/pull/44099)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ public static final class Properties {
// Hybrid Search Query
public static final String GLOBAL_STATISTICS_QUERY = "globalStatisticsQuery";
public static final String COMPONENT_QUERY_INFOS = "componentQueryInfos";
public static final String COMPONENT_WEIGHTS = "componentWeights";
public static final String PROJECTION_QUERY_INFO = "projectionQueryInfo";
public static final String SKIP = "skip";
public static final String TAKE = "take";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,20 +159,23 @@ private Flux<HybridSearchQueryResult<Document>> hybridSearch(List<FeedRangeEpkIm
// Retrieve and format rewritten query infos using the global statistics.
Flux<QueryInfo> rewrittenQueryInfos = retrieveRewrittenQueryInfos(hybridSearchQueryInfo.getComponentQueryInfoList());

// Retrieve component weights used to sort component queries and compute correct ranks later
List<ComponentWeight> componentWeights = retrieveComponentWeights(hybridSearchQueryInfo.getComponentWeights(), hybridSearchQueryInfo.getComponentQueryInfoList());

// Run component queries, and retrieve component query results.
Flux<Document> componentQueryResults = getComponentQueryResults(targetFeedRanges, initialPageSize, collection, rewrittenQueryInfos);

// Coalesce the results on unique _rids, and sort them by _rid
Mono<List<HybridSearchQueryResult<Document>>> coalescedAndSortedResults = coalesceAndSortResults(componentQueryResults);

// Compose component scores matrix, where each tuple is (score, index)
Mono<List<List<ScoreTuple>>> componentScoresList = retrieveComponentScores(coalescedAndSortedResults);
Mono<List<List<ScoreTuple>>> componentScoresList = retrieveComponentScores(coalescedAndSortedResults, componentWeights);

// Compute Ranks
Mono<List<List<Integer>>> ranks = computeRanks(componentScoresList);

// Compute the RRF scores
return computeRRFScores(ranks, coalescedAndSortedResults);
return computeRRFScores(ranks, coalescedAndSortedResults, componentWeights);
}

@Override
Expand Down Expand Up @@ -293,18 +296,19 @@ public Flux<FeedResponse<Document>> apply(Flux<HybridSearchQueryResult<Document>
}
}

private static Flux<HybridSearchQueryResult<Document>> computeRRFScores(Mono<List<List<Integer>>> ranks, Mono<List<HybridSearchQueryResult<Document>>> coalescedAndSortedResults) {
private static Flux<HybridSearchQueryResult<Document>> computeRRFScores(Mono<List<List<Integer>>> ranks,
Mono<List<HybridSearchQueryResult<Document>>> coalescedAndSortedResults,
List<ComponentWeight> componentWeights) {
return ranks.zipWith(coalescedAndSortedResults)
.map(tuple -> {
List<List<Integer>> ranksInternal = tuple.getT1();
List<HybridSearchQueryResult<Document>> results = tuple.getT2();

for (int index = 0; index < results.size(); ++index) {
double rrfScore = 0.0;
for (List<Integer> integers : ranksInternal) {
rrfScore += 1.0 / (RRF_CONSTANT + integers.get(index));
for (int componentIndex = 0; componentIndex < ranksInternal.size(); ++componentIndex) {
rrfScore += componentWeights.get(componentIndex).getWeight() / (RRF_CONSTANT + ranksInternal.get(componentIndex).get(index));
}

results.get(index).setScore(rrfScore);
}
// Sort on the RRF scores to build the final result
Expand All @@ -329,7 +333,7 @@ private static Mono<List<List<Integer>>> computeRanks(Mono<List<List<ScoreTuple>
int rank = 1; // ranks are 1 based
for (int index = 0; index < componentScores.get(componentIndex).size(); index++) {
// Identical scores should have the same rank
if ((index > 0) && (componentScores.get(componentIndex).get(index).getScore() < componentScores.get(componentIndex).get(index - 1).getScore())) {
if ((index > 0) && (componentScores.get(componentIndex).get(index).getScore() != componentScores.get(componentIndex).get(index - 1).getScore())) {
rank += 1;
}
int rankIndex = componentScores.get(componentIndex).get(index).getIndex();
Expand All @@ -340,7 +344,7 @@ private static Mono<List<List<Integer>>> computeRanks(Mono<List<List<ScoreTuple>
});
}

private static Mono<List<List<ScoreTuple>>> retrieveComponentScores(Mono<List<HybridSearchQueryResult<Document>>> coalescedAndSortedResults) {
private static Mono<List<List<ScoreTuple>>> retrieveComponentScores(Mono<List<HybridSearchQueryResult<Document>>> coalescedAndSortedResults, List<ComponentWeight> componentWeights) {
return coalescedAndSortedResults.map(results -> {
List<List<ScoreTuple>> componentScoresInternal = new ArrayList<>();
for (int i = 0; i < results.get(0).getComponentScores().size(); i++) {
Expand All @@ -361,9 +365,14 @@ private static Mono<List<List<ScoreTuple>>> retrieveComponentScores(Mono<List<Hy
componentScoresInternal.get(j).add(scoreTuple);
}
}
//Sort scores in descending order
for (List<ScoreTuple> scoreTuples : componentScoresInternal) {
scoreTuples.sort(Comparator.comparing(ScoreTuple::getScore, Comparator.reverseOrder()));
// //Sort scores in descending order
// for (List<ScoreTuple> scoreTuples : componentScoresInternal) {
// scoreTuples.sort(Comparator.comparing(ScoreTuple::getScore, Comparator.reverseOrder()));
// }
for (int i = 0; i < componentScoresInternal.size(); i++) {
final int componentIndex = i;
componentScoresInternal.get(i).sort((x,y) ->
componentWeights.get(componentIndex).getComparator().compare(x.getScore(), y.getScore()));
}
return componentScoresInternal;
});
Expand Down Expand Up @@ -492,6 +501,38 @@ private GlobalFullTextSearchQueryStatistics aggregateStatistics(List<GlobalFullT
return aggregatedStats;
}

/**
 * Builds one {@link ComponentWeight} per component query, pairing each component's weight
 * with the sort order taken from the first entry of that component query's order-by list.
 *
 * <p>When {@code componentWeightList} is {@code null} or empty, every component receives the
 * default weight of {@code 1.0}. When weights are supplied, entries beyond the number of
 * component queries are ignored.
 *
 * @param componentWeightList the weights from the hybrid search query info; may be {@code null} or empty
 * @param componentQueryInfos the component query infos; one {@link ComponentWeight} is produced for each
 * @return the component weights, index-aligned with {@code componentQueryInfos}
 * @throws IllegalStateException if weights are supplied but there are fewer weights than
 *         component queries, or if a supplied weight is {@code null}
 */
private List<ComponentWeight> retrieveComponentWeights(List<Double> componentWeightList, List<QueryInfo> componentQueryInfos) {
    final boolean useDefaultComponentWeight = componentWeightList == null || componentWeightList.isEmpty();
    final int componentCount = componentQueryInfos.size();

    // Fail fast with a descriptive message instead of an IndexOutOfBoundsException mid-loop.
    if (!useDefaultComponentWeight && componentWeightList.size() < componentCount) {
        throw new IllegalStateException(
            "Expected at least " + componentCount + " component weights but received " + componentWeightList.size());
    }

    List<ComponentWeight> componentWeights = new ArrayList<>(componentCount);
    for (int i = 0; i < componentCount; i++) {
        QueryInfo queryInfo = componentQueryInfos.get(i);

        // Both branches are boxed on purpose: mixing double and Double here would auto-unbox
        // and turn a null weight into an unexplained NullPointerException.
        Double componentWeight = useDefaultComponentWeight ? Double.valueOf(1.0) : componentWeightList.get(i);
        if (componentWeight == null) {
            throw new IllegalStateException("Component weight at index " + i + " is null");
        }

        componentWeights.add(new ComponentWeight(componentWeight, queryInfo.getOrderBy().get(0)));
    }
    return componentWeights;
}

/**
 * Pairs the weight of a single component query with a comparator over that component's
 * scores. The comparator follows the natural {@code double} ordering when the sort order is
 * {@code SortOrder.Ascending} and the negated ordering for any other sort order.
 */
private static class ComponentWeight {
    private final Double weight;
    private final Comparator<Double> comparator;

    /**
     * @param weight the weight of the component
     * @param sortOrder the sort order that decides the direction of the score comparator
     */
    public ComponentWeight(Double weight, SortOrder sortOrder) {
        this.weight = weight;

        if (sortOrder == SortOrder.Ascending) {
            this.comparator = (left, right) -> Double.compare(left, right);
        } else {
            // Anything other than Ascending flips the sign of the comparison.
            this.comparator = (left, right) -> -Double.compare(left, right);
        }
    }

    /**
     * @return the weight of the component
     */
    public Double getWeight() {
        return this.weight;
    }

    /**
     * @return the comparator used to order this component's scores
     */
    public Comparator<Double> getComparator() {
        return this.comparator;
    }
}

public static class ScoreTuple {
private final Double score;
private final Integer index;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,6 @@ public enum QueryFeature {
DCount,
NonStreamingOrderBy,
HybridSearch,
CountIf
CountIf,
WeightedRankFusion
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ class QueryPlanRetriever {
QueryFeature.DCount.name() + ", " +
QueryFeature.NonValueAggregate.name() + ", " +
QueryFeature.NonStreamingOrderBy.name() + ", " +
QueryFeature.HybridSearch.name();
QueryFeature.HybridSearch.name() + ", " +
QueryFeature.WeightedRankFusion.name();

private static final String OLD_SUPPORTED_QUERY_FEATURES = QueryFeature.Aggregate.name() + ", " +
QueryFeature.CompositeAggregate.name() + ", " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ public class HybridSearchQueryInfo extends JsonSerializable {
private String globalStatisticsQuery;
@JsonProperty(Constants.Properties.COMPONENT_QUERY_INFOS)
private List<QueryInfo> componentQueryInfoList;
@JsonProperty(Constants.Properties.COMPONENT_WEIGHTS)
private List<Double> componentWeights;
@JsonProperty(Constants.Properties.PROJECTION_QUERY_INFO)
private QueryInfo projectionQueryInfo;
@JsonProperty(Constants.Properties.SKIP)
Expand Down Expand Up @@ -61,6 +63,15 @@ public List<QueryInfo> getComponentQueryInfoList() {
return componentQueryInfoList != null ? this.componentQueryInfoList : (this.componentQueryInfoList = super.getList(Constants.Properties.COMPONENT_QUERY_INFOS, QueryInfo.class));
}

/**
 * Gets the component weights for hybrid search.
 * <p>
 * If the field is already populated it is returned as is; otherwise the list is read from
 * the {@code componentWeights} property, stored in the field, and returned.
 *
 * @return componentWeights
 */
public List<Double> getComponentWeights() {
    if (this.componentWeights == null) {
        // Not populated yet: read the property and cache the result.
        this.componentWeights = super.getList(Constants.Properties.COMPONENT_WEIGHTS, Double.class);
    }
    return this.componentWeights;
}

/**
* Gets the projectionQueryInfo for hybrid search
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,16 @@

import com.azure.cosmos.implementation.Constants;
import com.azure.cosmos.implementation.JsonSerializable;
import com.azure.cosmos.util.Beta;

/**
* Represents cosmos full text index of the IndexingPolicy in the Azure Cosmos DB database service.
*/
@Beta(value = Beta.SinceVersion.V4_65_0, warningText = Beta.PREVIEW_SUBJECT_TO_CHANGE_WARNING)
public final class CosmosFullTextIndex {
private final JsonSerializable jsonSerializable;

/**
* Constructor
*/
@Beta(value = Beta.SinceVersion.V4_65_0, warningText = Beta.PREVIEW_SUBJECT_TO_CHANGE_WARNING)
public CosmosFullTextIndex() {
this.jsonSerializable = new JsonSerializable();
}
Expand All @@ -26,15 +23,13 @@ public CosmosFullTextIndex() {
* Gets path.
* @return the path.
*/
@Beta(value = Beta.SinceVersion.V4_65_0, warningText = Beta.PREVIEW_SUBJECT_TO_CHANGE_WARNING)
public String getPath() { return this.jsonSerializable.getString(Constants.Properties.PATH); }

/**
* Sets the path.
* @param path the path.
* @return the CosmosFullTextIndex.
*/
@Beta(value = Beta.SinceVersion.V4_65_0, warningText = Beta.PREVIEW_SUBJECT_TO_CHANGE_WARNING)
public CosmosFullTextIndex setPath(String path) {
this.jsonSerializable.set(
Constants.Properties.PATH,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,11 @@

import com.azure.cosmos.implementation.Constants;
import com.azure.cosmos.implementation.apachecommons.lang.StringUtils;
import com.azure.cosmos.util.Beta;
import com.fasterxml.jackson.annotation.JsonProperty;

/**
* Path settings within {@link CosmosFullTextPolicy}
*/
@Beta(value = Beta.SinceVersion.V4_65_0, warningText = Beta.PREVIEW_SUBJECT_TO_CHANGE_WARNING)
public final class CosmosFullTextPath {
@JsonProperty(Constants.Properties.PATH)
private String path;
Expand All @@ -21,15 +19,13 @@ public final class CosmosFullTextPath {
/**
* Constructor
*/
@Beta(value = Beta.SinceVersion.V4_65_0, warningText = Beta.PREVIEW_SUBJECT_TO_CHANGE_WARNING)
public CosmosFullTextPath() {}

/**
* Gets the path for the cosmosFullText.
*
* @return path
*/
@Beta(value = Beta.SinceVersion.V4_65_0, warningText = Beta.PREVIEW_SUBJECT_TO_CHANGE_WARNING)
public String getPath() {
return path;
}
Expand All @@ -40,7 +36,6 @@ public String getPath() {
* @param path the path for the cosmosFullText.
* @return CosmosFullTextPath
*/
@Beta(value = Beta.SinceVersion.V4_65_0, warningText = Beta.PREVIEW_SUBJECT_TO_CHANGE_WARNING)
public CosmosFullTextPath setPath(String path) {
if (StringUtils.isEmpty(path)) {
throw new NullPointerException("Full text search path is either null or empty");
Expand All @@ -58,7 +53,6 @@ public CosmosFullTextPath setPath(String path) {
* Gets the language for the cosmosFullText path.
* @return language
*/
@Beta(value = Beta.SinceVersion.V4_65_0, warningText = Beta.PREVIEW_SUBJECT_TO_CHANGE_WARNING)
public String getLanguage() {
return language;
}
Expand All @@ -68,7 +62,6 @@ public String getLanguage() {
* @param language the language for the cosmosFullText path.
* @return CosmosFullTextPath
*/
@Beta(value = Beta.SinceVersion.V4_65_0, warningText = Beta.PREVIEW_SUBJECT_TO_CHANGE_WARNING)
public CosmosFullTextPath setLanguage(String language) {
this.language = language;
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
package com.azure.cosmos.models;

import com.azure.cosmos.implementation.Constants;
import com.azure.cosmos.util.Beta;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;

Expand All @@ -13,7 +12,6 @@
/**
* Full Text Search Policy
*/
@Beta(value = Beta.SinceVersion.V4_65_0, warningText = Beta.PREVIEW_SUBJECT_TO_CHANGE_WARNING)
@JsonInclude(JsonInclude.Include.NON_NULL)
public final class CosmosFullTextPolicy {
@JsonProperty(Constants.Properties.DEFAULT_LANGUAGE)
Expand All @@ -24,7 +22,6 @@ public final class CosmosFullTextPolicy {
/**
* Constructor
*/
@Beta(value = Beta.SinceVersion.V4_65_0, warningText = Beta.PREVIEW_SUBJECT_TO_CHANGE_WARNING)
public CosmosFullTextPolicy() {
}

Expand All @@ -33,7 +30,6 @@ public CosmosFullTextPolicy() {
*
* @return the default language for cosmosFullText.
*/
@Beta(value = Beta.SinceVersion.V4_65_0, warningText = Beta.PREVIEW_SUBJECT_TO_CHANGE_WARNING)
public String getDefaultLanguage() {
return defaultLanguage;
}
Expand All @@ -43,7 +39,6 @@ public String getDefaultLanguage() {
* @param defaultLanguage the default language for cosmosFullText.
* @return CosmosFullTextPolicy
*/
@Beta(value = Beta.SinceVersion.V4_65_0, warningText = Beta.PREVIEW_SUBJECT_TO_CHANGE_WARNING)
public CosmosFullTextPolicy setDefaultLanguage(String defaultLanguage) {
this.defaultLanguage = defaultLanguage;
return this;
Expand All @@ -53,7 +48,6 @@ public CosmosFullTextPolicy setDefaultLanguage(String defaultLanguage) {
* Gets the paths for cosmosFulltext.
* @return the paths for cosmosFulltext.
*/
@Beta(value = Beta.SinceVersion.V4_65_0, warningText = Beta.PREVIEW_SUBJECT_TO_CHANGE_WARNING)
public List<CosmosFullTextPath> getPaths() {
return paths;
}
Expand All @@ -63,7 +57,6 @@ public List<CosmosFullTextPath> getPaths() {
* @param paths the paths for cosmosFulltext.
* @return CosmosFullTextPolicy
*/
@Beta(value = Beta.SinceVersion.V4_65_0, warningText = Beta.PREVIEW_SUBJECT_TO_CHANGE_WARNING)
public CosmosFullTextPolicy setPaths(List<CosmosFullTextPath> paths) {
for (CosmosFullTextPath cosmosFullTextPath : paths) {
if (cosmosFullTextPath.getLanguage().isEmpty()) {
Expand Down