@@ -24,7 +24,9 @@ use once_cell::sync::Lazy;
24
24
use prost:: Message ;
25
25
use std:: collections:: HashSet ;
26
26
use std:: pin:: Pin ;
27
+ use std:: str:: FromStr ;
27
28
use std:: sync:: Arc ;
29
+ use tonic:: metadata:: MetadataValue ;
28
30
use tonic:: transport:: Server ;
29
31
use tonic:: transport:: { Certificate , Identity , ServerTlsConfig } ;
30
32
use tonic:: { Request , Response , Status , Streaming } ;
@@ -52,7 +54,7 @@ use arrow_flight::utils::batches_to_flight_data;
52
54
use arrow_flight:: {
53
55
flight_service_server:: FlightService , flight_service_server:: FlightServiceServer , Action ,
54
56
FlightData , FlightDescriptor , FlightEndpoint , FlightInfo , HandshakeRequest , HandshakeResponse ,
55
- IpcMessage , Location , SchemaAsIpc , Ticket ,
57
+ IpcMessage , SchemaAsIpc , Ticket ,
56
58
} ;
57
59
use arrow_ipc:: writer:: IpcWriteOptions ;
58
60
use arrow_schema:: { ArrowError , DataType , Field , Schema } ;
@@ -184,7 +186,15 @@ impl FlightSqlService for FlightSqlServiceImpl {
184
186
} ;
185
187
let result = Ok ( result) ;
186
188
let output = futures:: stream:: iter ( vec ! [ result] ) ;
187
- return Ok ( Response :: new ( Box :: pin ( output) ) ) ;
189
+
190
+ let token = format ! ( "Bearer {}" , FAKE_TOKEN ) ;
191
+ let mut response: Response < Pin < Box < dyn Stream < Item = _ > + Send > > > =
192
+ Response :: new ( Box :: pin ( output) ) ;
193
+ response. metadata_mut ( ) . append (
194
+ "authorization" ,
195
+ MetadataValue :: from_str ( token. as_str ( ) ) . unwrap ( ) ,
196
+ ) ;
197
+ return Ok ( response) ;
188
198
}
189
199
190
200
async fn do_get_fallback (
@@ -235,21 +245,20 @@ impl FlightSqlService for FlightSqlServiceImpl {
235
245
self . check_token ( & request) ?;
236
246
let handle = std:: str:: from_utf8 ( & cmd. prepared_statement_handle )
237
247
. map_err ( |e| status ! ( "Unable to parse handle" , e) ) ?;
248
+
238
249
let batch = Self :: fake_result ( ) . map_err ( |e| status ! ( "Could not fake a result" , e) ) ?;
239
250
let schema = ( * batch. schema ( ) ) . clone ( ) ;
240
251
let num_rows = batch. num_rows ( ) ;
241
252
let num_bytes = batch. get_array_memory_size ( ) ;
242
- let loc = Location {
243
- uri : "grpc+tcp://127.0.0.1" . to_string ( ) ,
244
- } ;
253
+
245
254
let fetch = FetchResults {
246
255
handle : handle. to_string ( ) ,
247
256
} ;
248
257
let buf = fetch. as_any ( ) . encode_to_vec ( ) . into ( ) ;
249
258
let ticket = Ticket { ticket : buf } ;
250
259
let endpoint = FlightEndpoint {
251
260
ticket : Some ( ticket) ,
252
- location : vec ! [ loc ] ,
261
+ location : vec ! [ ] ,
253
262
expiration_time : None ,
254
263
app_metadata : vec ! [ ] . into ( ) ,
255
264
} ;
@@ -662,9 +671,7 @@ impl FlightSqlService for FlightSqlServiceImpl {
662
671
_query : ActionClosePreparedStatementRequest ,
663
672
_request : Request < Action > ,
664
673
) -> Result < ( ) , Status > {
665
- Err ( Status :: unimplemented (
666
- "Implement do_action_close_prepared_statement" ,
667
- ) )
674
+ Ok ( ( ) )
668
675
}
669
676
670
677
async fn do_action_create_prepared_substrait_plan (
@@ -725,9 +732,8 @@ impl FlightSqlService for FlightSqlServiceImpl {
725
732
/// This example shows how to run a FlightSql server
726
733
#[ tokio:: main]
727
734
async fn main ( ) -> Result < ( ) , Box < dyn std:: error:: Error > > {
728
- let addr = "0.0.0.0:50051" . parse ( ) ?;
729
-
730
- let svc = FlightServiceServer :: new ( FlightSqlServiceImpl { } ) ;
735
+ let addr_str = "0.0.0.0:50051" ;
736
+ let addr = addr_str. parse ( ) ?;
731
737
732
738
println ! ( "Listening on {:?}" , addr) ;
733
739
@@ -736,6 +742,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
736
742
let key = std:: fs:: read_to_string ( "arrow-flight/examples/data/server.key" ) ?;
737
743
let client_ca = std:: fs:: read_to_string ( "arrow-flight/examples/data/client_ca.pem" ) ?;
738
744
745
+ let svc = FlightServiceServer :: new ( FlightSqlServiceImpl { } ) ;
739
746
let tls_config = ServerTlsConfig :: new ( )
740
747
. identity ( Identity :: from_pem ( & cert, & key) )
741
748
. client_ca_root ( Certificate :: from_pem ( & client_ca) ) ;
@@ -746,6 +753,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
746
753
. serve ( addr)
747
754
. await ?;
748
755
} else {
756
+ let svc = FlightServiceServer :: new ( FlightSqlServiceImpl { } ) ;
757
+
749
758
Server :: builder ( ) . add_service ( svc) . serve ( addr) . await ?;
750
759
}
751
760
@@ -999,15 +1008,6 @@ mod tests {
999
1008
. to_string( )
1000
1009
. contains( "Invalid credentials" ) ) ;
1001
1010
1002
- // forget to set_token
1003
- client. handshake ( "admin" , "password" ) . await . unwrap ( ) ;
1004
- assert ! ( client
1005
- . prepare( "select 1;" . to_string( ) , None )
1006
- . await
1007
- . unwrap_err( )
1008
- . to_string( )
1009
- . contains( "No authorization header" ) ) ;
1010
-
1011
1011
// Invalid Tokens
1012
1012
client. handshake ( "admin" , "password" ) . await . unwrap ( ) ;
1013
1013
client. set_token ( "wrong token" . to_string ( ) ) ;
@@ -1017,6 +1017,12 @@ mod tests {
1017
1017
. unwrap_err( )
1018
1018
. to_string( )
1019
1019
. contains( "invalid token" ) ) ;
1020
+
1021
+ client. clear_token ( ) ;
1022
+
1023
+ // Successful call (token is automatically set by handshake)
1024
+ client. handshake ( "admin" , "password" ) . await . unwrap ( ) ;
1025
+ client. prepare ( "select 1;" . to_string ( ) , None ) . await . unwrap ( ) ;
1020
1026
} )
1021
1027
. await
1022
1028
}
0 commit comments