Skip to content

Enhanced Vector Store Capabilities with Full-Text/Hybrid Search and Reranking #1227

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,22 @@ public class SearchRequest {

private Filter.Expression filterExpression;

/**
* Default value for search request is to use the vector search.
*/
private boolean vectorSearch = true;

/**
* Enables the full text search mode. If combined with the vector search, the hybrid
* search is done.
*/
private boolean fullTextSearch = false;

/**
* Enables the reranking of the results.
*/
private boolean reRank = false;

private SearchRequest(String query) {
this.query = query;
}
Expand Down Expand Up @@ -230,6 +246,36 @@ public SearchRequest withFilterExpression(String textExpression) {
return this;
}

/**
* Set the vector search mode.
* @param vectorSearch
* @return this.builder
*/
public SearchRequest withVectorSearch(boolean vectorSearch) {
this.vectorSearch = vectorSearch;
return this;
}

/**
* Set the full text search mode.
* @param fullTextSearch
* @return this.builder
*/
public SearchRequest withFullTextSearch(boolean fullTextSearch) {
this.fullTextSearch = fullTextSearch;
return this;
}

/**
* Set the rerank mode.
* @param rerank
* @return this.builder
*/
public SearchRequest withRerank(boolean rerank) {
this.reRank = rerank;
return this;
}

public String getQuery() {
return query;
}
Expand All @@ -250,10 +296,23 @@ public boolean hasFilterExpression() {
return this.filterExpression != null;
}

public boolean isVectorSearch() {
return this.vectorSearch;
}

public boolean isFullTextSearch() {
return this.fullTextSearch;
}

public boolean isReRank() {
return this.reRank;
}

@Override
public String toString() {
return "SearchRequest{" + "query='" + query + '\'' + ", topK=" + topK + ", similarityThreshold="
+ similarityThreshold + ", filterExpression=" + filterExpression + '}';
+ similarityThreshold + ", filterExpression=" + filterExpression + ", isVectorSearch=" + vectorSearch
+ ", isFullTextSearch=" + fullTextSearch + ", isRerank=" + reRank + '}';
}

@Override
Expand All @@ -264,12 +323,13 @@ public boolean equals(Object o) {
return false;
SearchRequest that = (SearchRequest) o;
return topK == that.topK && Double.compare(that.similarityThreshold, similarityThreshold) == 0
&& Objects.equals(query, that.query) && Objects.equals(filterExpression, that.filterExpression);
&& Objects.equals(query, that.query) && Objects.equals(filterExpression, that.filterExpression)
&& vectorSearch == that.vectorSearch && fullTextSearch == that.fullTextSearch && reRank == that.reRank;
}

@Override
public int hashCode() {
return Objects.hash(query, topK, similarityThreshold, filterExpression);
return Objects.hash(query, topK, similarityThreshold, filterExpression, vectorSearch, fullTextSearch);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,71 @@ default List<Document> similaritySearch(String query) {
return this.similaritySearch(SearchRequest.query(query));
}

/**
* Retrieves documents by query full text content and metadata filters to retrieve
* exactly the number of nearest-neighbor results that match the request criteria.
* @param request Search request for set search parameters, such as the query text,
* topK, similarity threshold and metadata filter expressions.
* @return a list of {@link Document} objects representing the retrieved documents
* that match the search criteria.
* @throws UnsupportedOperationException if the method is not supported by the current
* implementation. Subclasses should override this method to provide a specific
* implementation.
*/
default List<Document> fullTextSearch(SearchRequest request) {
throw new UnsupportedOperationException("The [" + this.getClass() + "] doesn't support full text search!");
}

/**
* Retrieves documents by query full text content using the default
* {@link SearchRequest}'s' search criteria.
* @param query Text to use for full text search.
* @return a list of {@link Document} objects representing the retrieved documents
* that match the search criteria.
*/
default List<Document> fullTextSearch(String query) {
return this.fullTextSearch(SearchRequest.query(query));
}

/**
* Performs a hybrid search by combining semantic and keyword-based search techniques
* to retrieve a list of relevant documents based on the provided
* {@link SearchRequest}.
* <p>
* This method is intended to retrieve documents that match the query both
* semantically (using vector embeddings) and via keyword matching. The hybrid
* approach aims to enhance retrieval accuracy by leveraging the strengths of both
* search methods.
* </p>
* @param request the {@link SearchRequest} object containing the query and search
* parameters.
* @return a list of {@link Document} objects representing the retrieved documents
* that match the search criteria.
* @throws UnsupportedOperationException if the method is not supported by the current
* implementation. Subclasses should override this method to provide a specific
* implementation.
*/
default List<Document> hybridSearch(SearchRequest request) {
throw new UnsupportedOperationException(
"The [" + this.getClass() + "] doesn't support hybrid (vector + text) search!");
}

/**
* Performs a hybrid search by combining semantic and keyword-based search techniques
* to retrieve a list of relevant documents based on the provided
* {@link SearchRequest}.
* <p>
* This method is intended to retrieve documents that match the query both
* semantically (using vector embeddings) and via keyword matching. The hybrid
* approach aims to enhance retrieval accuracy by leveraging the strengths of both
* search methods.
* </p>
* @param query Text to use for embedding similarity comparison.
* @return a list of {@link Document} objects representing the retrieved documents
* that match the search criteria.
*/
default List<Document> hybridSearch(String query) {
return this.hybridSearch(SearchRequest.query(query));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ public void createFrom() {
assertThat(newRequest.getTopK()).isEqualTo(originalRequest.getTopK());
assertThat(newRequest.getFilterExpression()).isEqualTo(originalRequest.getFilterExpression());
assertThat(newRequest.getSimilarityThreshold()).isEqualTo(originalRequest.getSimilarityThreshold());
assertThat(newRequest.isVectorSearch() == originalRequest.isVectorSearch());
}

@Test
Expand Down Expand Up @@ -135,10 +136,20 @@ public void withFilterExpression() {

}

@Test()
public void withHybridSearchWithRerank() {

var request = SearchRequest.query("Test").withFullTextSearch(true).withRerank(true);
assertThat(request.isVectorSearch()).isTrue();
assertThat(request.isFullTextSearch()).isTrue();
assertThat(request.isReRank()).isTrue();
}

private void checkDefaults(SearchRequest request) {
assertThat(request.getFilterExpression()).isNull();
assertThat(request.getSimilarityThreshold()).isEqualTo(SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL);
assertThat(request.getTopK()).isEqualTo(SearchRequest.DEFAULT_TOP_K);
assertThat(request.isVectorSearch()).isTrue();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,7 @@
import com.azure.search.documents.indexes.models.VectorSearch;
import com.azure.search.documents.indexes.models.VectorSearchAlgorithmMetric;
import com.azure.search.documents.indexes.models.VectorSearchProfile;
import com.azure.search.documents.models.IndexDocumentsResult;
import com.azure.search.documents.models.IndexingResult;
import com.azure.search.documents.models.SearchOptions;
import com.azure.search.documents.models.VectorSearchOptions;
import com.azure.search.documents.models.VectorizedQuery;
import com.azure.search.documents.models.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
Expand Down Expand Up @@ -64,6 +60,7 @@
* @author Xiangyang Yu
* @author Christian Tzolov
* @author Josh Long
* @author Alessio Bertazzo
*/
public class AzureVectorStore implements VectorStore, InitializingBean {

Expand Down Expand Up @@ -91,6 +88,8 @@ public class AzureVectorStore implements VectorStore, InitializingBean {

private static final String METADATA_FIELD_PREFIX = "meta_";

private static final String SEMANTIC_SEARCH_CONFIG_NAME = "default";

private final SearchIndexClient searchIndexClient;

private final EmbeddingModel embeddingModel;
Expand Down Expand Up @@ -281,24 +280,84 @@ public List<Document> similaritySearch(String query) {
public List<Document> similaritySearch(SearchRequest request) {

Assert.notNull(request, "The search request must not be null.");
Assert.isTrue(request.isVectorSearch() && !request.isFullTextSearch(),
"The search request must be a vector search.");

return this.search(request);
}

@Override
public List<Document> fullTextSearch(String query) {
return this.search(SearchRequest.query(query)
.withVectorSearch(false)
.withFullTextSearch(true)
.withTopK(this.defaultTopK)
.withSimilarityThreshold(this.defaultSimilarityThreshold));
}

@Override
public List<Document> fullTextSearch(SearchRequest request) {

Assert.notNull(request, "The search request must not be null.");
Assert.isTrue(!request.isVectorSearch() && request.isFullTextSearch(),
"The search request must be a full text search.");

return this.search(request);
}

@Override
public List<Document> hybridSearch(String query) {
return this.hybridSearch(SearchRequest.query(query)
.withVectorSearch(true)
.withFullTextSearch(true)
.withTopK(this.defaultTopK)
.withSimilarityThreshold(this.defaultSimilarityThreshold));
}

@Override
public List<Document> hybridSearch(SearchRequest request) {

Assert.notNull(request, "The search request must not be null.");
Assert.isTrue(request.isVectorSearch() && request.isFullTextSearch(),
"The search request must be a hybrid (vector + full text) search.");

return this.search(request);
}

private List<Document> search(SearchRequest request) {

var searchOptions = new SearchOptions().setTop(request.getTopK());

if (request.isVectorSearch()) {
var searchEmbedding = embeddingModel.embed(request.getQuery());

var searchEmbedding = embeddingModel.embed(request.getQuery());
final var vectorQuery = new VectorizedQuery(EmbeddingUtils.toList(searchEmbedding))
.setKNearestNeighborsCount(request.getTopK())
// Set the fields to compare the vector against. This is a comma-delimited
// list of field names.
.setFields(EMBEDDING_FIELD_NAME);

final var vectorQuery = new VectorizedQuery(EmbeddingUtils.toList(searchEmbedding))
.setKNearestNeighborsCount(request.getTopK())
// Set the fields to compare the vector against. This is a comma-delimited
// list of field names.
.setFields(EMBEDDING_FIELD_NAME);
searchOptions.setVectorSearchOptions(new VectorSearchOptions().setQueries(vectorQuery));
}

String searchText = null;
if (request.isFullTextSearch()) {
searchText = request.getQuery();

var searchOptions = new SearchOptions()
.setVectorSearchOptions(new VectorSearchOptions().setQueries(vectorQuery));
if (request.isReRank()) {
searchOptions
.setSemanticSearchOptions(
new SemanticSearchOptions().setSemanticConfigurationName(SEMANTIC_SEARCH_CONFIG_NAME))
.setQueryType(QueryType.SEMANTIC);
}
}

if (request.hasFilterExpression()) {
String oDataFilter = this.filterExpressionConverter.convertExpression(request.getFilterExpression());
searchOptions.setFilter(oDataFilter);
}

final var searchResults = searchClient.search(null, searchOptions, Context.NONE);
final var searchResults = searchClient.search(searchText, searchOptions, Context.NONE);

return searchResults.stream()
.filter(result -> result.getScore() >= request.getSimilarityThreshold())
Expand Down
Loading