Skip to content

Commit

Permalink
Merge pull request #5 from postgresml/silas-patch-js-rerank
Browse files Browse the repository at this point in the history
Patched re-ranking for JavaScript
  • Loading branch information
SilasMarvin authored Jul 1, 2024
2 parents 17a2f7f + 616ead8 commit acfed87
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 10 deletions.
2 changes: 1 addition & 1 deletion korvus/javascript/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "korvus",
"version": "1.1.2",
"version": "1.1.3",
"description": "Open Source Alternative for Building End-to-End Vector Search Applications without OpenAI & Pinecone",
"keywords": [
"postgres",
Expand Down
40 changes: 40 additions & 0 deletions korvus/javascript/tests/typescript-tests/test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,46 @@ it("can vector search with query builder", async () => {
await collection.archive();
});

it("can vector search with re-ranking", async () => {
let pipeline = korvus.newPipeline("1", {
title: {
semantic_search: { model: "intfloat/e5-small-v2", parameters: { prompt: "passage: " } },
full_text_search: { configuration: "english" },
},
body: {
splitter: { model: "recursive_character" },
semantic_search: {
model: "text-embedding-ada-002",
source: "openai",
},
},
});
let collection = korvus.newCollection("test_j_c_cvswr_0")
await collection.add_pipeline(pipeline)
await collection.upsert_documents(generate_dummy_documents(5))
let results = await collection.vector_search(
{
query: {
fields: {
title: { query: "Test document: 2", parameters: { prompt: "query: " }, full_text_filter: "test" },
body: { query: "Test document: 2" },
},
filter: { id: { "$gt": 2 } },
},
rerank: {
model: "mixedbread-ai/mxbai-rerank-base-v1",
query: "Test query",
num_documents_to_rerank: 100
},
limit: 5,
},
pipeline,
);
let ids = results.map(r => r["document"]["id"]);
expect(ids).toEqual([4, 3, 3, 4]);
await collection.archive();
});

///////////////////////////////////////////////////
// Test rag ///////////////////////////////////////
///////////////////////////////////////////////////
Expand Down
54 changes: 54 additions & 0 deletions korvus/python/tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,60 @@ async def test_can_vector_search_with_query_builder():
await collection.archive()


@pytest.mark.asyncio
async def test_can_vector_search_with_rerank():
pipeline = korvus.Pipeline(
"test_p_p_tcvswr_0",
{
"title": {
"semantic_search": {
"model": "intfloat/e5-small-v2",
"parameters": {"prompt": "passage: "},
},
"full_text_search": {"configuration": "english"},
},
"text": {
"splitter": {"model": "recursive_character"},
"semantic_search": {
"model": "intfloat/e5-small-v2",
"parameters": {"prompt": "passage: "},
},
},
},
)
collection = korvus.Collection("test_p_c_tcvs_3")
await collection.add_pipeline(pipeline)
await collection.upsert_documents(generate_dummy_documents(5))
results = await collection.vector_search(
{
"query": {
"fields": {
"title": {
"query": "Test document: 2",
"parameters": {"prompt": "passage: "},
"full_text_filter": "test",
},
"text": {
"query": "Test document: 2",
"parameters": {"prompt": "passage: "},
},
},
"filter": {"id": {"$gt": 2}},
},
"rerank": {
"model": "mixedbread-ai/mxbai-rerank-base-v1",
"query": "Test query",
"num_documents_to_rerank": 100,
},
"limit": 5,
},
pipeline,
)
ids = [result["document"]["id"] for result in results]
assert ids == [3, 3, 4, 4]
await collection.archive()


###################################################
## Test RAG #######################################
###################################################
Expand Down
2 changes: 1 addition & 1 deletion korvus/src/collection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ impl Collection {
let pool = get_or_initialize_pool(&self.database_url).await?;
let pipelines_table_name = format!("{}.pipelines", project_info.name);
let exists: bool = sqlx::query_scalar(&query_builder!(
"SELECT EXISTS (SELECT id FROM %s WHERE name = $1 AND active = TRUE)",
"SELECT EXISTS (SELECT id FROM %s WHERE name = $1)",
pipelines_table_name
))
.bind(&pipeline.name)
Expand Down
14 changes: 7 additions & 7 deletions korvus/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,7 @@ mod tests {
#[tokio::test]
async fn can_upsert_documents_and_enable_pipeline() -> anyhow::Result<()> {
internal_init_logger(None, None).ok();
let collection_name = "test_r_c_cudaep_43";
let collection_name = "test_r_c_cudaep_44";
let mut collection = Collection::new(collection_name, None)?;
let pipeline_name = "0";
let mut pipeline = Pipeline::new(
Expand Down Expand Up @@ -654,7 +654,7 @@ mod tests {
#[tokio::test]
async fn random_pipelines_documents_test() -> anyhow::Result<()> {
internal_init_logger(None, None).ok();
let collection_name = "test_r_c_rpdt_3";
let collection_name = "test_r_c_rpdt_4";
let mut collection = Collection::new(collection_name, None)?;
let documents = generate_dummy_documents(6);
collection
Expand Down Expand Up @@ -818,7 +818,7 @@ mod tests {
#[tokio::test]
async fn pipeline_sync_status() -> anyhow::Result<()> {
internal_init_logger(None, None).ok();
let collection_name = "test_r_c_pss_6";
let collection_name = "test_r_c_pss_7";
let mut collection = Collection::new(collection_name, None)?;
let pipeline_name = "0";
let mut pipeline = Pipeline::new(
Expand Down Expand Up @@ -1140,7 +1140,7 @@ mod tests {
#[tokio::test]
async fn can_search_with_remote_embeddings() -> anyhow::Result<()> {
internal_init_logger(None, None).ok();
let collection_name = "test r_c_cswre_66";
let collection_name = "test r_c_cswre_67";
let mut collection = Collection::new(collection_name, None)?;
let documents = generate_dummy_documents(10);
collection.upsert_documents(documents.clone(), None).await?;
Expand Down Expand Up @@ -1314,7 +1314,7 @@ mod tests {
#[tokio::test]
async fn can_vector_search_with_remote_embeddings() -> anyhow::Result<()> {
internal_init_logger(None, None).ok();
let collection_name = "test r_c_cvswre_7";
let collection_name = "test r_c_cvswre_8";
let mut collection = Collection::new(collection_name, None)?;
let documents = generate_dummy_documents(10);
collection.upsert_documents(documents.clone(), None).await?;
Expand Down Expand Up @@ -1455,7 +1455,7 @@ mod tests {
async fn can_vector_search_with_local_embeddings_and_specify_document_keys(
) -> anyhow::Result<()> {
internal_init_logger(None, None).ok();
let collection_name = "test r_c_cvswleasdk_1";
let collection_name = "test r_c_cvswleasdk_2";
let mut collection = Collection::new(collection_name, None)?;
let documents = generate_dummy_documents(2);
collection.upsert_documents(documents.clone(), None).await?;
Expand Down Expand Up @@ -1556,7 +1556,7 @@ mod tests {
#[tokio::test]
async fn can_vector_search_with_local_embeddings_and_rerank() -> anyhow::Result<()> {
internal_init_logger(None, None).ok();
let collection_name = "test r_c_cvswlear_1";
let collection_name = "test r_c_cvswlear_2";
let mut collection = Collection::new(collection_name, None)?;
let documents = generate_dummy_documents(10);
collection.upsert_documents(documents.clone(), None).await?;
Expand Down
4 changes: 3 additions & 1 deletion korvus/src/vector_search_query_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,14 @@ const fn default_num_documents_to_rerank() -> u64 {
10
}

#[serde_as]
#[derive(Debug, Deserialize, Serialize, Clone)]
#[serde(deny_unknown_fields)]
struct ValidRerank {
query: String,
model: String,
#[serde(default = "default_num_documents_to_rerank")]
#[serde_as(as = "FromInto<CustomU64Convertor>")]
num_documents_to_rerank: u64,
parameters: Option<Json>,
}
Expand All @@ -61,7 +63,7 @@ const fn default_limit() -> u64 {

#[serde_as]
#[derive(Debug, Deserialize, Serialize, Clone)]
// #[serde(deny_unknown_fields)]
#[serde(deny_unknown_fields)]
pub struct ValidQuery {
query: ValidQueryActions,
// Need this when coming from JavaScript as everything is an f64 from JS
Expand Down

0 comments on commit acfed87

Please sign in to comment.