From 0ece57234a1918e22cb4e4b2c4bb0136d8e97088 Mon Sep 17 00:00:00 2001 From: Xin Liu Date: Fri, 8 Mar 2024 11:25:35 +0900 Subject: [PATCH] feat!: new `score_threshold` arg for `search_points` method Signed-off-by: Xin Liu --- examples/src/main.rs | 58 ++++++++----- src/lib.rs | 197 ++++++++++++++++++++++++++++++------------- 2 files changed, 176 insertions(+), 79 deletions(-) diff --git a/examples/src/main.rs b/examples/src/main.rs index 9ed4d68..db9cc89 100644 --- a/examples/src/main.rs +++ b/examples/src/main.rs @@ -1,5 +1,5 @@ -use serde_json::{json}; use qdrant::*; +use serde_json::json; #[tokio::main(flavor = "current_thread")] async fn main() -> Result<(), Box> { @@ -9,47 +9,67 @@ async fn main() -> Result<(), Box> { println!("Create collection result is {:?}", r); let mut points = Vec::::new(); - points.push(Point{ - id: PointId::Num(1), vector: vec!(0.05, 0.61, 0.76, 0.74), payload: json!({"city": "Berlin"}).as_object().map(|m| m.to_owned()) + points.push(Point { + id: PointId::Num(1), + vector: vec![0.05, 0.61, 0.76, 0.74], + payload: json!({"city": "Berlin"}).as_object().map(|m| m.to_owned()), }); - points.push(Point{ - id: PointId::Num(2), vector: vec!(0.19, 0.81, 0.75, 0.11), payload: json!({"city": "London"}).as_object().map(|m| m.to_owned()) + points.push(Point { + id: PointId::Num(2), + vector: vec![0.19, 0.81, 0.75, 0.11], + payload: json!({"city": "London"}).as_object().map(|m| m.to_owned()), }); - points.push(Point{ - id: PointId::Num(3), vector: vec!(0.36, 0.55, 0.47, 0.94), payload: json!({"city": "Moscow"}).as_object().map(|m| m.to_owned()) + points.push(Point { + id: PointId::Num(3), + vector: vec![0.36, 0.55, 0.47, 0.94], + payload: json!({"city": "Moscow"}).as_object().map(|m| m.to_owned()), }); - points.push(Point{ - id: PointId::Num(4), vector: vec!(0.18, 0.01, 0.85, 0.80), payload: json!({"city": "New York"}).as_object().map(|m| m.to_owned()) + points.push(Point { + id: PointId::Num(4), + vector: vec![0.18, 0.01, 0.85, 0.80], + payload: 
json!({"city": "New York"}) + .as_object() + .map(|m| m.to_owned()), }); - points.push(Point{ - id: PointId::Num(5), vector: vec!(0.24, 0.18, 0.22, 0.44), payload: json!({"city": "Beijing"}).as_object().map(|m| m.to_owned()) + points.push(Point { + id: PointId::Num(5), + vector: vec![0.24, 0.18, 0.22, 0.44], + payload: json!({"city": "Beijing"}).as_object().map(|m| m.to_owned()), }); - points.push(Point{ - id: PointId::Num(6), vector: vec!(0.35, 0.08, 0.11, 0.44), payload: json!({"city": "Mumbai"}).as_object().map(|m| m.to_owned()) + points.push(Point { + id: PointId::Num(6), + vector: vec![0.35, 0.08, 0.11, 0.44], + payload: json!({"city": "Mumbai"}).as_object().map(|m| m.to_owned()), }); let r = client.upsert_points("my_test", points).await; println!("Upsert points result is {:?}", r); - println!("The collection size is {}", client.collection_info("my_test").await); + println!( + "The collection size is {}", + client.collection_info("my_test").await + ); let p = client.get_point("my_test", 2).await; println!("The second point is {:?}", p); - let ps = client.get_points("my_test", vec!(1, 2, 3, 4, 5, 6)).await; + let ps = client.get_points("my_test", vec![1, 2, 3, 4, 5, 6]).await; println!("The 1-6 points are {:?}", ps); let q = vec![0.2, 0.1, 0.9, 0.7]; - let r = client.search_points("my_test", q, 2).await; + let r = client.search_points("my_test", q, 2, None).await; println!("Search result points are {:?}", r); - let r = client.delete_points("my_test", vec!(1, 4)).await; + let r = client.delete_points("my_test", vec![1, 4]).await; println!("Delete points result is {:?}", r); - println!("The collection size is {}", client.collection_info("my_test").await); + println!( + "The collection size is {}", + client.collection_info("my_test").await + ); let q = vec![0.2, 0.1, 0.9, 0.7]; - let r = client.search_points("my_test", q, 2).await; + let r = client.search_points("my_test", q, 2, None).await; println!("Search result points are {:?}", r); Ok(()) } diff --git 
a/src/lib.rs b/src/lib.rs index 7f6ac71..13b3bcc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,7 @@ use anyhow::{anyhow, Error}; use serde::{Deserialize, Serialize}; -use serde_json::{Map, Value}; use serde_json::json; +use serde_json::{Map, Value}; #[derive(Debug, Serialize, Deserialize)] #[serde(untagged)] @@ -62,7 +62,12 @@ impl Qdrant { /// Shortcut functions pub async fn collection_info(&self, collection_name: &str) -> u64 { let v = self.collection_info_api(collection_name).await.unwrap(); - v.get("result").unwrap().get("points_count").unwrap().as_u64().unwrap() + v.get("result") + .unwrap() + .get("points_count") + .unwrap() + .as_u64() + .unwrap() } pub async fn create_collection(&self, collection_name: &str, size: u32) -> Result<(), Error> { @@ -76,26 +81,45 @@ impl Qdrant { self.create_collection_api(collection_name, &params).await } - pub async fn upsert_points(&self, collection_name: &str, points: Vec) -> Result<(), Error> { + pub async fn upsert_points( + &self, + collection_name: &str, + points: Vec, + ) -> Result<(), Error> { let params = json!({ "points": points, }); self.upsert_points_api(collection_name, &params).await } - pub async fn search_points(&self, collection_name: &str, point: Vec, limit: u64) -> Vec { + pub async fn search_points( + &self, + collection_name: &str, + point: Vec, + limit: u64, + score_threshold: Option, + ) -> Vec { + let score_threshold = match score_threshold { + Some(v) => v, + None => 0.0, + }; + let params = json!({ "vector": point, "limit": limit, "with_payload": true, "with_vector": true, + "score_threshold": score_threshold, }); - let v = self.search_points_api(collection_name, &params).await.unwrap(); - let rs : &Vec = v.get("result").unwrap().as_array().unwrap(); - let mut sps : Vec = Vec::::new(); + let v = self + .search_points_api(collection_name, &params) + .await + .unwrap(); + let rs: &Vec = v.get("result").unwrap().as_array().unwrap(); + let mut sps: Vec = Vec::::new(); for r in rs { - let sp : ScoredPoint = 
serde_json::from_value(r.clone()).unwrap(); + let sp: ScoredPoint = serde_json::from_value(r.clone()).unwrap(); sps.push(sp); } sps @@ -109,10 +133,10 @@ impl Qdrant { }); let v = self.get_points_api(collection_name, &params).await.unwrap(); - let rs : &Vec = v.get("result").unwrap().as_array().unwrap(); - let mut ps : Vec = Vec::::new(); + let rs: &Vec = v.get("result").unwrap().as_array().unwrap(); + let mut ps: Vec = Vec::::new(); for r in rs { - let p : Point = serde_json::from_value(r.clone()).unwrap(); + let p: Point = serde_json::from_value(r.clone()).unwrap(); ps.push(p); } ps @@ -133,61 +157,81 @@ impl Qdrant { /// REST API functions pub async fn collection_info_api(&self, collection_name: &str) -> Result { - let url = format!( - "{}/collections/{}", - self.url_base, - collection_name, - ); + let url = format!("{}/collections/{}", self.url_base, collection_name,); let client = reqwest::Client::new(); - let ci = client.get(&url).header("Content-Type", "application/json").send().await?.json().await?; + let ci = client + .get(&url) + .header("Content-Type", "application/json") + .send() + .await? 
+ .json() + .await?; Ok(ci) } - - pub async fn create_collection_api(&self, collection_name: &str, params: &Value) -> Result<(), Error> { - let url = format!( - "{}/collections/{}", - self.url_base, - collection_name, - ); + pub async fn create_collection_api( + &self, + collection_name: &str, + params: &Value, + ) -> Result<(), Error> { + let url = format!("{}/collections/{}", self.url_base, collection_name,); let body = serde_json::to_vec(params).unwrap_or_default(); let client = reqwest::Client::new(); - let res = client.put(&url).header("Content-Type", "application/json").body(body).send().await?; + let res = client + .put(&url) + .header("Content-Type", "application/json") + .body(body) + .send() + .await?; if res.status().is_success() { Ok(()) } else { - Err(anyhow!("Failed to create collection: {}", res.status().as_str())) + Err(anyhow!( + "Failed to create collection: {}", + res.status().as_str() + )) } } pub async fn delete_collection_api(&self, collection_name: &str) -> Result<(), Error> { - let url = format!( - "{}/collections/{}", - self.url_base, - collection_name, - ); + let url = format!("{}/collections/{}", self.url_base, collection_name,); let client = reqwest::Client::new(); - let res = client.delete(&url).header("Content-Type", "application/json").send().await?; + let res = client + .delete(&url) + .header("Content-Type", "application/json") + .send() + .await?; if res.status().is_success() { Ok(()) } else { - Err(anyhow!("Failed to delete collection: {}", res.status().as_str())) + Err(anyhow!( + "Failed to delete collection: {}", + res.status().as_str() + )) } } - pub async fn upsert_points_api(&self, collection_name: &str, params: &Value) -> Result<(), Error> { + pub async fn upsert_points_api( + &self, + collection_name: &str, + params: &Value, + ) -> Result<(), Error> { let url = format!( "{}/collections/{}/points?wait=true", - self.url_base, - collection_name, + self.url_base, collection_name, ); let body = 
serde_json::to_vec(params).unwrap_or_default(); let client = reqwest::Client::new(); - let res = client.put(&url).header("Content-Type", "application/json").body(body).send().await?; + let res = client + .put(&url) + .header("Content-Type", "application/json") + .body(body) + .send() + .await?; if res.status().is_success() { let v = res.json::().await?; let status = v.get("status").unwrap().as_str().unwrap(); @@ -197,65 +241,98 @@ impl Qdrant { Err(anyhow!("Failed to upsert points. Status = {}", status)) } } else { - Err(anyhow!("Failed to upsert points: {}", res.status().as_str())) + Err(anyhow!( + "Failed to upsert points: {}", + res.status().as_str() + )) } } - - pub async fn search_points_api(&self, collection_name: &str, params: &Value) -> Result { + pub async fn search_points_api( + &self, + collection_name: &str, + params: &Value, + ) -> Result { let url = format!( "{}/collections/{}/points/search", - self.url_base, - collection_name, + self.url_base, collection_name, ); let body = serde_json::to_vec(params).unwrap_or_default(); let client = reqwest::Client::new(); - let json = client.post(&url).header("Content-Type", "application/json").body(body).send().await?.json().await?; + let json = client + .post(&url) + .header("Content-Type", "application/json") + .body(body) + .send() + .await? 
+ .json() + .await?; Ok(json) } - pub async fn get_points_api(&self, collection_name: &str, params: &Value) -> Result { - let url = format!( - "{}/collections/{}/points", - self.url_base, - collection_name, - ); + pub async fn get_points_api( + &self, + collection_name: &str, + params: &Value, + ) -> Result { + let url = format!("{}/collections/{}/points", self.url_base, collection_name,); let body = serde_json::to_vec(params).unwrap_or_default(); let client = reqwest::Client::new(); - let json = client.post(&url).header("Content-Type", "application/json").body(body).send().await?.json().await?; + let json = client + .post(&url) + .header("Content-Type", "application/json") + .body(body) + .send() + .await? + .json() + .await?; Ok(json) } pub async fn get_point_api(&self, collection_name: &str, id: u64) -> Result { let url = format!( "{}/collections/{}/points/{}", - self.url_base, - collection_name, - id, + self.url_base, collection_name, id, ); let client = reqwest::Client::new(); - let json = client.get(&url).header("Content-Type", "application/json").send().await?.json().await?; + let json = client + .get(&url) + .header("Content-Type", "application/json") + .send() + .await? 
+ .json() + .await?; Ok(json) } - pub async fn delete_points_api(&self, collection_name: &str, params: &Value) -> Result<(), Error> { + pub async fn delete_points_api( + &self, + collection_name: &str, + params: &Value, + ) -> Result<(), Error> { let url = format!( "{}/collections/{}/points/delete?wait=true", - self.url_base, - collection_name, + self.url_base, collection_name, ); let body = serde_json::to_vec(params).unwrap_or_default(); let client = reqwest::Client::new(); - let res = client.post(&url).header("Content-Type", "application/json").body(body).send().await?; + let res = client + .post(&url) + .header("Content-Type", "application/json") + .body(body) + .send() + .await?; if res.status().is_success() { Ok(()) } else { - Err(anyhow!("Failed to delete points: {}", res.status().as_str())) + Err(anyhow!( + "Failed to delete points: {}", + res.status().as_str() + )) } } - }