From af012d0ee4a50fc20fc24b8bf320893f52a9dbe6 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Mon, 1 Jul 2024 15:03:20 -0700 Subject: [PATCH 1/2] Patched re-ranking for JavaScript --- .../javascript/tests/typescript-tests/test.ts | 40 ++++++++++++++ korvus/python/tests/test.py | 54 +++++++++++++++++++ korvus/src/collection.rs | 2 +- korvus/src/lib.rs | 14 ++--- korvus/src/vector_search_query_builder.rs | 4 +- 5 files changed, 105 insertions(+), 9 deletions(-) diff --git a/korvus/javascript/tests/typescript-tests/test.ts b/korvus/javascript/tests/typescript-tests/test.ts index b9c8b04..f991840 100644 --- a/korvus/javascript/tests/typescript-tests/test.ts +++ b/korvus/javascript/tests/typescript-tests/test.ts @@ -164,6 +164,46 @@ it("can vector search with query builder", async () => { await collection.archive(); }); +it("can vector search with re-ranking", async () => { + let pipeline = korvus.newPipeline("1", { + title: { + semantic_search: { model: "intfloat/e5-small-v2", parameters: { prompt: "passage: " } }, + full_text_search: { configuration: "english" }, + }, + body: { + splitter: { model: "recursive_character" }, + semantic_search: { + model: "text-embedding-ada-002", + source: "openai", + }, + }, + }); + let collection = korvus.newCollection("test_j_c_cvswr_0") + await collection.add_pipeline(pipeline) + await collection.upsert_documents(generate_dummy_documents(5)) + let results = await collection.vector_search( + { + query: { + fields: { + title: { query: "Test document: 2", parameters: { prompt: "query: " }, full_text_filter: "test" }, + body: { query: "Test document: 2" }, + }, + filter: { id: { "$gt": 2 } }, + }, + rerank: { + model: "mixedbread-ai/mxbai-rerank-base-v1", + query: "Test query", + num_documents_to_rerank: 100 + }, + limit: 5, + }, + pipeline, + ); + let ids = results.map(r => r["document"]["id"]); + expect(ids).toEqual([4, 3, 3, 4]); + await collection.archive(); +}); + 
/////////////////////////////////////////////////// // Test rag /////////////////////////////////////// /////////////////////////////////////////////////// diff --git a/korvus/python/tests/test.py b/korvus/python/tests/test.py index 51a61d4..029477a 100644 --- a/korvus/python/tests/test.py +++ b/korvus/python/tests/test.py @@ -212,6 +212,60 @@ async def test_can_vector_search_with_query_builder(): await collection.archive() +@pytest.mark.asyncio +async def test_can_vector_search_with_rerank(): + pipeline = korvus.Pipeline( + "test_p_p_tcvswr_0", + { + "title": { + "semantic_search": { + "model": "intfloat/e5-small-v2", + "parameters": {"prompt": "passage: "}, + }, + "full_text_search": {"configuration": "english"}, + }, + "text": { + "splitter": {"model": "recursive_character"}, + "semantic_search": { + "model": "intfloat/e5-small-v2", + "parameters": {"prompt": "passage: "}, + }, + }, + }, + ) + collection = korvus.Collection("test_p_c_tcvs_3") + await collection.add_pipeline(pipeline) + await collection.upsert_documents(generate_dummy_documents(5)) + results = await collection.vector_search( + { + "query": { + "fields": { + "title": { + "query": "Test document: 2", + "parameters": {"prompt": "passage: "}, + "full_text_filter": "test", + }, + "text": { + "query": "Test document: 2", + "parameters": {"prompt": "passage: "}, + }, + }, + "filter": {"id": {"$gt": 2}}, + }, + "rerank": { + "model": "mixedbread-ai/mxbai-rerank-base-v1", + "query": "Test query", + "num_documents_to_rerank": 100, + }, + "limit": 5, + }, + pipeline, + ) + ids = [result["document"]["id"] for result in results] + assert ids == [3, 3, 4, 4] + await collection.archive() + + ################################################### ## Test RAG ####################################### ################################################### diff --git a/korvus/src/collection.rs b/korvus/src/collection.rs index d8c9c33..fbcd50d 100644 --- a/korvus/src/collection.rs +++ b/korvus/src/collection.rs @@ -345,7 
+345,7 @@ impl Collection { let pool = get_or_initialize_pool(&self.database_url).await?; let pipelines_table_name = format!("{}.pipelines", project_info.name); let exists: bool = sqlx::query_scalar(&query_builder!( - "SELECT EXISTS (SELECT id FROM %s WHERE name = $1 AND active = TRUE)", + "SELECT EXISTS (SELECT id FROM %s WHERE name = $1)", pipelines_table_name )) .bind(&pipeline.name) diff --git a/korvus/src/lib.rs b/korvus/src/lib.rs index 6f1f19b..d26de60 100644 --- a/korvus/src/lib.rs +++ b/korvus/src/lib.rs @@ -610,7 +610,7 @@ mod tests { #[tokio::test] async fn can_upsert_documents_and_enable_pipeline() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_cudaep_43"; + let collection_name = "test_r_c_cudaep_44"; let mut collection = Collection::new(collection_name, None)?; let pipeline_name = "0"; let mut pipeline = Pipeline::new( @@ -654,7 +654,7 @@ mod tests { #[tokio::test] async fn random_pipelines_documents_test() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_rpdt_3"; + let collection_name = "test_r_c_rpdt_4"; let mut collection = Collection::new(collection_name, None)?; let documents = generate_dummy_documents(6); collection @@ -818,7 +818,7 @@ mod tests { #[tokio::test] async fn pipeline_sync_status() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test_r_c_pss_6"; + let collection_name = "test_r_c_pss_7"; let mut collection = Collection::new(collection_name, None)?; let pipeline_name = "0"; let mut pipeline = Pipeline::new( @@ -1140,7 +1140,7 @@ mod tests { #[tokio::test] async fn can_search_with_remote_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test r_c_cswre_66"; + let collection_name = "test r_c_cswre_67"; let mut collection = Collection::new(collection_name, None)?; let documents = generate_dummy_documents(10); collection.upsert_documents(documents.clone(), 
None).await?; @@ -1314,7 +1314,7 @@ mod tests { #[tokio::test] async fn can_vector_search_with_remote_embeddings() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test r_c_cvswre_7"; + let collection_name = "test r_c_cvswre_8"; let mut collection = Collection::new(collection_name, None)?; let documents = generate_dummy_documents(10); collection.upsert_documents(documents.clone(), None).await?; @@ -1455,7 +1455,7 @@ mod tests { async fn can_vector_search_with_local_embeddings_and_specify_document_keys( ) -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test r_c_cvswleasdk_1"; + let collection_name = "test r_c_cvswleasdk_2"; let mut collection = Collection::new(collection_name, None)?; let documents = generate_dummy_documents(2); collection.upsert_documents(documents.clone(), None).await?; @@ -1556,7 +1556,7 @@ mod tests { #[tokio::test] async fn can_vector_search_with_local_embeddings_and_rerank() -> anyhow::Result<()> { internal_init_logger(None, None).ok(); - let collection_name = "test r_c_cvswlear_1"; + let collection_name = "test r_c_cvswlear_2"; let mut collection = Collection::new(collection_name, None)?; let documents = generate_dummy_documents(10); collection.upsert_documents(documents.clone(), None).await?; diff --git a/korvus/src/vector_search_query_builder.rs b/korvus/src/vector_search_query_builder.rs index 24de38a..35884f3 100644 --- a/korvus/src/vector_search_query_builder.rs +++ b/korvus/src/vector_search_query_builder.rs @@ -45,12 +45,14 @@ const fn default_num_documents_to_rerank() -> u64 { 10 } +#[serde_as] #[derive(Debug, Deserialize, Serialize, Clone)] #[serde(deny_unknown_fields)] struct ValidRerank { query: String, model: String, #[serde(default = "default_num_documents_to_rerank")] + #[serde_as(as = "FromInto<f64>")] num_documents_to_rerank: u64, parameters: Option<Json>, } @@ -61,7 +63,7 @@ const fn default_limit() -> u64 { #[serde_as] #[derive(Debug, Deserialize, Serialize, 
Clone)] -// #[serde(deny_unknown_fields)] +#[serde(deny_unknown_fields)] pub struct ValidQuery { query: ValidQueryActions, // Need this when coming from JavaScript as everything is an f64 from JS From 616ead83d94a7b9d79196a15cc0467edda1daf4f Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Mon, 1 Jul 2024 15:04:14 -0700 Subject: [PATCH 2/2] Bump version --- korvus/javascript/package.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/korvus/javascript/package.json b/korvus/javascript/package.json index af68f97..d9702ad 100644 --- a/korvus/javascript/package.json +++ b/korvus/javascript/package.json @@ -1,6 +1,6 @@ { "name": "korvus", - "version": "1.1.2", + "version": "1.1.3", "description": "Open Source Alternative for Building End-to-End Vector Search Applications without OpenAI & Pinecone", "keywords": [ "postgres",