Skip to content

Commit 337fe73

Browse files
[PERF] Validate KNN Projection Input Parameters Before Applying knn_filter (#3205)
## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - Ensure input parameters for projections are validated before applying `knn_filter` to avoid unnecessary processing - Provide precise error message when scan operator inputs are invalid - Add unit tests to verify KNN queries and error scenarios - New functionality - ... ## Test plan *How are these changes tested?* Added new unit tests - [x] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust ## Documentation Changes *Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs repository](https://github.com/chroma-core/docs)?*
1 parent 8597bc7 commit 337fe73

File tree

1 file changed

+201
-10
lines changed

1 file changed

+201
-10
lines changed

rust/worker/src/server.rs

Lines changed: 201 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -143,14 +143,19 @@ impl WorkerServer {
143143
let collection = scan
144144
.collection
145145
.ok_or(Status::invalid_argument("Invalid Collection"))?;
146+
146147
let collection_uuid = CollectionUuid::from_str(&collection.id)
147-
.map_err(|err| Status::invalid_argument(err.to_string()))?;
148+
.map_err(|_| Status::invalid_argument("Invalid Collection UUID"))?;
149+
148150
let vector_uuid = SegmentUuid::from_str(&scan.knn_id)
149-
.map_err(|err| Status::invalid_argument(err.to_string()))?;
151+
.map_err(|_| Status::invalid_argument("Invalid UUID for Vector segment"))?;
152+
150153
let metadata_uuid = SegmentUuid::from_str(&scan.metadata_id)
151-
.map_err(|err| Status::invalid_argument(err.to_string()))?;
154+
.map_err(|_| Status::invalid_argument("Invalid UUID for Metadata segment"))?;
155+
152156
let record_uuid = SegmentUuid::from_str(&scan.record_id)
153-
.map_err(|err| Status::invalid_argument(err.to_string()))?;
157+
.map_err(|_| Status::invalid_argument("Invalid UUID for Record segment"))?;
158+
154159
Ok((
155160
FetchLogOperator {
156161
log_client: self.log.clone(),
@@ -213,13 +218,17 @@ impl WorkerServer {
213218
let scan = get_inner
214219
.scan
215220
.ok_or(Status::invalid_argument("Invalid Scan Operator"))?;
221+
216222
let (fetch_log_operator, fetch_segment_operator) = self.decompose_proto_scan(scan)?;
223+
217224
let filter = get_inner
218225
.filter
219226
.ok_or(Status::invalid_argument("Invalid Filter Operator"))?;
227+
220228
let limit = get_inner
221229
.limit
222230
.ok_or(Status::invalid_argument("Invalid Scan Operator"))?;
231+
223232
let projection = get_inner
224233
.projection
225234
.ok_or(Status::invalid_argument("Invalid Projection Operator"))?;
@@ -248,17 +257,28 @@ impl WorkerServer {
248257
) -> Result<Response<KnnBatchResult>, Status> {
249258
let dispatcher = self.clone_dispatcher()?;
250259
let system = self.clone_system()?;
260+
251261
let knn_inner = knn.into_inner();
262+
252263
let scan = knn_inner
253264
.scan
254265
.ok_or(Status::invalid_argument("Invalid Scan Operator"))?;
266+
255267
let (fetch_log_operator, fetch_segment_operator) = self.decompose_proto_scan(scan)?;
268+
256269
let filter = knn_inner
257270
.filter
258271
.ok_or(Status::invalid_argument("Invalid Filter Operator"))?;
272+
259273
let knn = knn_inner
260274
.knn
261-
.ok_or(Status::invalid_argument("Invalid Scan Operator"))?;
275+
.ok_or(Status::invalid_argument("Invalid Knn Operator"))?;
276+
277+
let projection = knn_inner
278+
.projection
279+
.ok_or(Status::invalid_argument("Invalid Projection Operator"))?;
280+
let knn_projection = KnnProjectionOperator::try_from(projection)
281+
.map_err(|e| Status::invalid_argument(format!("Invalid Projection Operator: {}", e)))?;
262282

263283
if knn.embeddings.is_empty() {
264284
return Ok(Response::new(to_proto_knn_batch_result(Vec::new())?));
@@ -290,11 +310,6 @@ impl WorkerServer {
290310
}
291311
};
292312

293-
let projection = knn_inner
294-
.projection
295-
.ok_or(Status::invalid_argument("Invalid Projection Operator"))?;
296-
let knn_projection = KnnProjectionOperator::try_from(projection)?;
297-
298313
let knn_orchestrator_futures = from_proto_knn(knn)?
299314
.into_iter()
300315
.map(|knn| {
@@ -589,4 +604,180 @@ mod tests {
589604
assert!(response.is_err());
590605
assert_eq!(response.unwrap_err().code(), tonic::Code::InvalidArgument);
591606
}
607+
608+
fn gen_knn_request(mut scan_operator: Option<chroma_proto::ScanOperator>) -> chroma_proto::KnnPlan {
609+
if scan_operator.is_none() {
610+
scan_operator = Some(scan());
611+
}
612+
chroma_proto::KnnPlan {
613+
scan: scan_operator,
614+
filter: Some(chroma_proto::FilterOperator {
615+
ids: None,
616+
r#where: None,
617+
where_document: None,
618+
}),
619+
knn: Some(chroma_proto::KnnOperator {
620+
embeddings: vec![],
621+
fetch: 0,
622+
}),
623+
projection: Some(chroma_proto::KnnProjectionOperator {
624+
projection: Some(chroma_proto::ProjectionOperator {
625+
document: false,
626+
embedding: false,
627+
metadata: false,
628+
}),
629+
distance: false,
630+
}),
631+
}
632+
}
633+
634+
#[tokio::test]
635+
async fn validate_knn_plan_empty_embeddings() {
636+
let mut executor = QueryExecutorClient::connect(run_server()).await.unwrap();
637+
let response = executor.knn(gen_knn_request(None)).await;
638+
assert!(response.is_ok());
639+
assert_eq!(response.unwrap().into_inner().results.len(), 0);
640+
}
641+
642+
#[tokio::test]
643+
async fn validate_knn_plan_filter() {
644+
let mut executor = QueryExecutorClient::connect(run_server()).await.unwrap();
645+
let mut request = gen_knn_request(None);
646+
request.filter = None;
647+
let response = executor.knn(request).await;
648+
let err = response.unwrap_err();
649+
assert_eq!(err.code(), tonic::Code::InvalidArgument);
650+
assert!(
651+
err.message().to_lowercase().contains("filter operator"),
652+
"{}",
653+
err.message()
654+
);
655+
}
656+
657+
#[tokio::test]
658+
async fn validate_knn_plan_knn() {
659+
let mut executor = QueryExecutorClient::connect(run_server()).await.unwrap();
660+
let mut request = gen_knn_request(None);
661+
request.knn = None;
662+
let response = executor.knn(request).await;
663+
assert!(response.is_err());
664+
let err = response.unwrap_err();
665+
assert_eq!(err.code(), tonic::Code::InvalidArgument);
666+
assert!(
667+
err.message().to_lowercase().contains("knn operator"),
668+
"{}",
669+
err.message()
670+
);
671+
}
672+
673+
#[tokio::test]
674+
async fn validate_knn_plan_projection() {
675+
let mut executor = QueryExecutorClient::connect(run_server()).await.unwrap();
676+
let mut request = gen_knn_request(None);
677+
request.projection = None;
678+
let response = executor.knn(request).await;
679+
let err = response.unwrap_err();
680+
assert_eq!(err.code(), tonic::Code::InvalidArgument);
681+
assert!(
682+
err.message().to_lowercase().contains("projection operator"),
683+
"{}",
684+
err.message()
685+
);
686+
687+
let mut request = gen_knn_request(None);
688+
request.projection = Some(chroma_proto::KnnProjectionOperator {
689+
projection: None,
690+
distance: false,
691+
});
692+
let response = executor.knn(request).await;
693+
let err = response.unwrap_err();
694+
assert_eq!(err.code(), tonic::Code::InvalidArgument);
695+
assert!(
696+
err.message()
697+
.to_lowercase()
698+
.contains("projection operator: "),
699+
"{}",
700+
err.message()
701+
);
702+
}
703+
704+
#[tokio::test]
705+
async fn validate_knn_plan_scan() {
706+
let mut executor = QueryExecutorClient::connect(run_server()).await.unwrap();
707+
let mut request = gen_knn_request(None);
708+
request.scan = None;
709+
let response = executor.knn(request).await;
710+
let err = response.unwrap_err();
711+
assert_eq!(err.code(), tonic::Code::InvalidArgument);
712+
assert!(
713+
err.message().to_lowercase().contains("scan operator"),
714+
"{}",
715+
err.message()
716+
);
717+
}
718+
719+
#[tokio::test]
720+
async fn validate_knn_plan_scan_collection() {
721+
let mut executor = QueryExecutorClient::connect(run_server()).await.unwrap();
722+
let mut scan = scan();
723+
scan.collection.as_mut().unwrap().id = "Invalid-Collection-ID".to_string();
724+
let response = executor.knn(gen_knn_request(Some(scan))).await;
725+
assert!(response.is_err());
726+
let err = response.unwrap_err();
727+
assert_eq!(err.code(), tonic::Code::InvalidArgument);
728+
assert!(
729+
err.message().to_lowercase().contains("collection uuid"),
730+
"{}",
731+
err.message()
732+
);
733+
}
734+
735+
#[tokio::test]
736+
async fn validate_knn_plan_scan_vector() {
737+
let mut executor = QueryExecutorClient::connect(run_server()).await.unwrap();
738+
// invalid vector uuid
739+
let mut scan_operator = scan();
740+
scan_operator.knn_id = "invalid_segment_id".to_string();
741+
let response = executor.knn(gen_knn_request(Some(scan_operator))).await;
742+
assert!(response.is_err());
743+
let err = response.unwrap_err();
744+
assert_eq!(err.code(), tonic::Code::InvalidArgument);
745+
assert!(
746+
err.message().to_lowercase().contains("vector"),
747+
"{}",
748+
err.message()
749+
);
750+
}
751+
752+
#[tokio::test]
753+
async fn validate_knn_plan_scan_record() {
754+
let mut executor = QueryExecutorClient::connect(run_server()).await.unwrap();
755+
let mut scan_operator = scan();
756+
scan_operator.record_id = "invalid_record_id".to_string();
757+
let response = executor.knn(gen_knn_request(Some(scan_operator))).await;
758+
assert!(response.is_err());
759+
let err = response.unwrap_err();
760+
assert_eq!(err.code(), tonic::Code::InvalidArgument);
761+
assert!(
762+
err.message().to_lowercase().contains("record"),
763+
"{}",
764+
err.message()
765+
);
766+
}
767+
768+
#[tokio::test]
769+
async fn validate_knn_plan_scan_metadata() {
770+
let mut executor = QueryExecutorClient::connect(run_server()).await.unwrap();
771+
let mut scan_operator = scan();
772+
scan_operator.metadata_id = "invalid_metadata_id".to_string();
773+
let response = executor.knn(gen_knn_request(Some(scan_operator))).await;
774+
assert!(response.is_err());
775+
let err = response.unwrap_err();
776+
assert_eq!(err.code(), tonic::Code::InvalidArgument);
777+
assert!(
778+
err.message().to_lowercase().contains("metadata"),
779+
"{}",
780+
err.message()
781+
);
782+
}
592783
}

0 commit comments

Comments
 (0)