@@ -143,14 +143,19 @@ impl WorkerServer {
143
143
let collection = scan
144
144
. collection
145
145
. ok_or ( Status :: invalid_argument ( "Invalid Collection" ) ) ?;
146
+
146
147
let collection_uuid = CollectionUuid :: from_str ( & collection. id )
147
- . map_err ( |err| Status :: invalid_argument ( err. to_string ( ) ) ) ?;
148
+ . map_err ( |_| Status :: invalid_argument ( "Invalid Collection UUID" ) ) ?;
149
+
148
150
let vector_uuid = SegmentUuid :: from_str ( & scan. knn_id )
149
- . map_err ( |err| Status :: invalid_argument ( err. to_string ( ) ) ) ?;
151
+ . map_err ( |_| Status :: invalid_argument ( "Invalid UUID for Vector segment" ) ) ?;
152
+
150
153
let metadata_uuid = SegmentUuid :: from_str ( & scan. metadata_id )
151
- . map_err ( |err| Status :: invalid_argument ( err. to_string ( ) ) ) ?;
154
+ . map_err ( |_| Status :: invalid_argument ( "Invalid UUID for Metadata segment" ) ) ?;
155
+
152
156
let record_uuid = SegmentUuid :: from_str ( & scan. record_id )
153
- . map_err ( |err| Status :: invalid_argument ( err. to_string ( ) ) ) ?;
157
+ . map_err ( |_| Status :: invalid_argument ( "Invalid UUID for Record segment" ) ) ?;
158
+
154
159
Ok ( (
155
160
FetchLogOperator {
156
161
log_client : self . log . clone ( ) ,
@@ -213,13 +218,17 @@ impl WorkerServer {
213
218
let scan = get_inner
214
219
. scan
215
220
. ok_or ( Status :: invalid_argument ( "Invalid Scan Operator" ) ) ?;
221
+
216
222
let ( fetch_log_operator, fetch_segment_operator) = self . decompose_proto_scan ( scan) ?;
223
+
217
224
let filter = get_inner
218
225
. filter
219
226
. ok_or ( Status :: invalid_argument ( "Invalid Filter Operator" ) ) ?;
227
+
220
228
let limit = get_inner
221
229
. limit
222
230
. ok_or ( Status :: invalid_argument ( "Invalid Scan Operator" ) ) ?;
231
+
223
232
let projection = get_inner
224
233
. projection
225
234
. ok_or ( Status :: invalid_argument ( "Invalid Projection Operator" ) ) ?;
@@ -248,17 +257,28 @@ impl WorkerServer {
248
257
) -> Result < Response < KnnBatchResult > , Status > {
249
258
let dispatcher = self . clone_dispatcher ( ) ?;
250
259
let system = self . clone_system ( ) ?;
260
+
251
261
let knn_inner = knn. into_inner ( ) ;
262
+
252
263
let scan = knn_inner
253
264
. scan
254
265
. ok_or ( Status :: invalid_argument ( "Invalid Scan Operator" ) ) ?;
266
+
255
267
let ( fetch_log_operator, fetch_segment_operator) = self . decompose_proto_scan ( scan) ?;
268
+
256
269
let filter = knn_inner
257
270
. filter
258
271
. ok_or ( Status :: invalid_argument ( "Invalid Filter Operator" ) ) ?;
272
+
259
273
let knn = knn_inner
260
274
. knn
261
- . ok_or ( Status :: invalid_argument ( "Invalid Scan Operator" ) ) ?;
275
+ . ok_or ( Status :: invalid_argument ( "Invalid Knn Operator" ) ) ?;
276
+
277
+ let projection = knn_inner
278
+ . projection
279
+ . ok_or ( Status :: invalid_argument ( "Invalid Projection Operator" ) ) ?;
280
+ let knn_projection = KnnProjectionOperator :: try_from ( projection)
281
+ . map_err ( |e| Status :: invalid_argument ( format ! ( "Invalid Projection Operator: {}" , e) ) ) ?;
262
282
263
283
if knn. embeddings . is_empty ( ) {
264
284
return Ok ( Response :: new ( to_proto_knn_batch_result ( Vec :: new ( ) ) ?) ) ;
@@ -290,11 +310,6 @@ impl WorkerServer {
290
310
}
291
311
} ;
292
312
293
- let projection = knn_inner
294
- . projection
295
- . ok_or ( Status :: invalid_argument ( "Invalid Projection Operator" ) ) ?;
296
- let knn_projection = KnnProjectionOperator :: try_from ( projection) ?;
297
-
298
313
let knn_orchestrator_futures = from_proto_knn ( knn) ?
299
314
. into_iter ( )
300
315
. map ( |knn| {
@@ -589,4 +604,180 @@ mod tests {
589
604
assert ! ( response. is_err( ) ) ;
590
605
assert_eq ! ( response. unwrap_err( ) . code( ) , tonic:: Code :: InvalidArgument ) ;
591
606
}
607
+
608
+ fn gen_knn_request ( mut scan_operator : Option < chroma_proto:: ScanOperator > ) -> chroma_proto:: KnnPlan {
609
+ if scan_operator. is_none ( ) {
610
+ scan_operator = Some ( scan ( ) ) ;
611
+ }
612
+ chroma_proto:: KnnPlan {
613
+ scan : scan_operator,
614
+ filter : Some ( chroma_proto:: FilterOperator {
615
+ ids : None ,
616
+ r#where : None ,
617
+ where_document : None ,
618
+ } ) ,
619
+ knn : Some ( chroma_proto:: KnnOperator {
620
+ embeddings : vec ! [ ] ,
621
+ fetch : 0 ,
622
+ } ) ,
623
+ projection : Some ( chroma_proto:: KnnProjectionOperator {
624
+ projection : Some ( chroma_proto:: ProjectionOperator {
625
+ document : false ,
626
+ embedding : false ,
627
+ metadata : false ,
628
+ } ) ,
629
+ distance : false ,
630
+ } ) ,
631
+ }
632
+ }
633
+
634
+ #[ tokio:: test]
635
+ async fn validate_knn_plan_empty_embeddings ( ) {
636
+ let mut executor = QueryExecutorClient :: connect ( run_server ( ) ) . await . unwrap ( ) ;
637
+ let response = executor. knn ( gen_knn_request ( None ) ) . await ;
638
+ assert ! ( response. is_ok( ) ) ;
639
+ assert_eq ! ( response. unwrap( ) . into_inner( ) . results. len( ) , 0 ) ;
640
+ }
641
+
642
+ #[ tokio:: test]
643
+ async fn validate_knn_plan_filter ( ) {
644
+ let mut executor = QueryExecutorClient :: connect ( run_server ( ) ) . await . unwrap ( ) ;
645
+ let mut request = gen_knn_request ( None ) ;
646
+ request. filter = None ;
647
+ let response = executor. knn ( request) . await ;
648
+ let err = response. unwrap_err ( ) ;
649
+ assert_eq ! ( err. code( ) , tonic:: Code :: InvalidArgument ) ;
650
+ assert ! (
651
+ err. message( ) . to_lowercase( ) . contains( "filter operator" ) ,
652
+ "{}" ,
653
+ err. message( )
654
+ ) ;
655
+ }
656
+
657
+ #[ tokio:: test]
658
+ async fn validate_knn_plan_knn ( ) {
659
+ let mut executor = QueryExecutorClient :: connect ( run_server ( ) ) . await . unwrap ( ) ;
660
+ let mut request = gen_knn_request ( None ) ;
661
+ request. knn = None ;
662
+ let response = executor. knn ( request) . await ;
663
+ assert ! ( response. is_err( ) ) ;
664
+ let err = response. unwrap_err ( ) ;
665
+ assert_eq ! ( err. code( ) , tonic:: Code :: InvalidArgument ) ;
666
+ assert ! (
667
+ err. message( ) . to_lowercase( ) . contains( "knn operator" ) ,
668
+ "{}" ,
669
+ err. message( )
670
+ ) ;
671
+ }
672
+
673
+ #[ tokio:: test]
674
+ async fn validate_knn_plan_projection ( ) {
675
+ let mut executor = QueryExecutorClient :: connect ( run_server ( ) ) . await . unwrap ( ) ;
676
+ let mut request = gen_knn_request ( None ) ;
677
+ request. projection = None ;
678
+ let response = executor. knn ( request) . await ;
679
+ let err = response. unwrap_err ( ) ;
680
+ assert_eq ! ( err. code( ) , tonic:: Code :: InvalidArgument ) ;
681
+ assert ! (
682
+ err. message( ) . to_lowercase( ) . contains( "projection operator" ) ,
683
+ "{}" ,
684
+ err. message( )
685
+ ) ;
686
+
687
+ let mut request = gen_knn_request ( None ) ;
688
+ request. projection = Some ( chroma_proto:: KnnProjectionOperator {
689
+ projection : None ,
690
+ distance : false ,
691
+ } ) ;
692
+ let response = executor. knn ( request) . await ;
693
+ let err = response. unwrap_err ( ) ;
694
+ assert_eq ! ( err. code( ) , tonic:: Code :: InvalidArgument ) ;
695
+ assert ! (
696
+ err. message( )
697
+ . to_lowercase( )
698
+ . contains( "projection operator: " ) ,
699
+ "{}" ,
700
+ err. message( )
701
+ ) ;
702
+ }
703
+
704
+ #[ tokio:: test]
705
+ async fn validate_knn_plan_scan ( ) {
706
+ let mut executor = QueryExecutorClient :: connect ( run_server ( ) ) . await . unwrap ( ) ;
707
+ let mut request = gen_knn_request ( None ) ;
708
+ request. scan = None ;
709
+ let response = executor. knn ( request) . await ;
710
+ let err = response. unwrap_err ( ) ;
711
+ assert_eq ! ( err. code( ) , tonic:: Code :: InvalidArgument ) ;
712
+ assert ! (
713
+ err. message( ) . to_lowercase( ) . contains( "scan operator" ) ,
714
+ "{}" ,
715
+ err. message( )
716
+ ) ;
717
+ }
718
+
719
+ #[ tokio:: test]
720
+ async fn validate_knn_plan_scan_collection ( ) {
721
+ let mut executor = QueryExecutorClient :: connect ( run_server ( ) ) . await . unwrap ( ) ;
722
+ let mut scan = scan ( ) ;
723
+ scan. collection . as_mut ( ) . unwrap ( ) . id = "Invalid-Collection-ID" . to_string ( ) ;
724
+ let response = executor. knn ( gen_knn_request ( Some ( scan) ) ) . await ;
725
+ assert ! ( response. is_err( ) ) ;
726
+ let err = response. unwrap_err ( ) ;
727
+ assert_eq ! ( err. code( ) , tonic:: Code :: InvalidArgument ) ;
728
+ assert ! (
729
+ err. message( ) . to_lowercase( ) . contains( "collection uuid" ) ,
730
+ "{}" ,
731
+ err. message( )
732
+ ) ;
733
+ }
734
+
735
+ #[ tokio:: test]
736
+ async fn validate_knn_plan_scan_vector ( ) {
737
+ let mut executor = QueryExecutorClient :: connect ( run_server ( ) ) . await . unwrap ( ) ;
738
+ // invalid vector uuid
739
+ let mut scan_operator = scan ( ) ;
740
+ scan_operator. knn_id = "invalid_segment_id" . to_string ( ) ;
741
+ let response = executor. knn ( gen_knn_request ( Some ( scan_operator) ) ) . await ;
742
+ assert ! ( response. is_err( ) ) ;
743
+ let err = response. unwrap_err ( ) ;
744
+ assert_eq ! ( err. code( ) , tonic:: Code :: InvalidArgument ) ;
745
+ assert ! (
746
+ err. message( ) . to_lowercase( ) . contains( "vector" ) ,
747
+ "{}" ,
748
+ err. message( )
749
+ ) ;
750
+ }
751
+
752
+ #[ tokio:: test]
753
+ async fn validate_knn_plan_scan_record ( ) {
754
+ let mut executor = QueryExecutorClient :: connect ( run_server ( ) ) . await . unwrap ( ) ;
755
+ let mut scan_operator = scan ( ) ;
756
+ scan_operator. record_id = "invalid_record_id" . to_string ( ) ;
757
+ let response = executor. knn ( gen_knn_request ( Some ( scan_operator) ) ) . await ;
758
+ assert ! ( response. is_err( ) ) ;
759
+ let err = response. unwrap_err ( ) ;
760
+ assert_eq ! ( err. code( ) , tonic:: Code :: InvalidArgument ) ;
761
+ assert ! (
762
+ err. message( ) . to_lowercase( ) . contains( "record" ) ,
763
+ "{}" ,
764
+ err. message( )
765
+ ) ;
766
+ }
767
+
768
+ #[ tokio:: test]
769
+ async fn validate_knn_plan_scan_metadata ( ) {
770
+ let mut executor = QueryExecutorClient :: connect ( run_server ( ) ) . await . unwrap ( ) ;
771
+ let mut scan_operator = scan ( ) ;
772
+ scan_operator. metadata_id = "invalid_metadata_id" . to_string ( ) ;
773
+ let response = executor. knn ( gen_knn_request ( Some ( scan_operator) ) ) . await ;
774
+ assert ! ( response. is_err( ) ) ;
775
+ let err = response. unwrap_err ( ) ;
776
+ assert_eq ! ( err. code( ) , tonic:: Code :: InvalidArgument ) ;
777
+ assert ! (
778
+ err. message( ) . to_lowercase( ) . contains( "metadata" ) ,
779
+ "{}" ,
780
+ err. message( )
781
+ ) ;
782
+ }
592
783
}
0 commit comments