From 3d5fb93e883d0e2aa80e8fe5c2890abf154c52f4 Mon Sep 17 00:00:00 2001 From: James Owen Date: Tue, 15 Oct 2019 10:53:17 +0200 Subject: [PATCH] FIX #26: Make the populate function use COPY mode when using the postgresql driver --- lib-driver/caqti_driver_postgresql.ml | 157 ++++++++++++++++++++++++-- tests/test_sql.ml | 37 +++++- 2 files changed, 178 insertions(+), 16 deletions(-) diff --git a/lib-driver/caqti_driver_postgresql.ml b/lib-driver/caqti_driver_postgresql.ml index c132ff17..dbf8f123 100644 --- a/lib-driver/caqti_driver_postgresql.ml +++ b/lib-driver/caqti_driver_postgresql.ml @@ -201,6 +201,36 @@ let rec encode_field let typ = Caqti_type.field field_type in Error (Caqti_error.encode_rejected ~uri ~typ msg)))) +let rec copy_encode_field (* Encodes one field value [x] of type [field_type] as the text of a single CSV cell for the COPY input assembled by populate. Strings and octets are double-quoted via escape_and_quote; booleans, integers, floats, dates, times and time spans are emitted unquoted. Returns Error when no coding exists for the field type or when its encode function rejects the value. *) + : type a. uri: Uri.t -> a Caqti_type.field -> a -> (string, _) result = + fun ~uri field_type x -> + let escape_and_quote s = (* CSV quoting: wrap the text in double quotes and double every embedded double quote *) + "\"" ^ (String.concat "\"\"" (String.split_on_char '"' s)) ^ "\"" + in + (match field_type with + | Caqti_type.Bool -> Ok (Pg_ext.string_of_bool x) + | Caqti_type.Int -> Ok (string_of_int x) + | Caqti_type.Int32 -> Ok (Int32.to_string x) + | Caqti_type.Int64 -> Ok (Int64.to_string x) + | Caqti_type.Float -> Ok (sprintf "%.17g" x) + | Caqti_type.String -> Ok (escape_and_quote x) + | Caqti_type.Octets -> Ok (escape_and_quote x) (* NOTE(review): octets get the same quoting as text; presumably arbitrary binary data (e.g. NUL bytes) needs bytea hex or escape encoding to survive COPY - confirm *) + | Caqti_type.Pdate -> Ok (iso8601_of_pdate x) + | Caqti_type.Ptime -> + Ok (Ptime.to_rfc3339 ~space:true ~tz_offset_s:0 ~frac_s:6 x) + | Caqti_type.Ptime_span -> + Ok (Pg_ext.string_of_ptime_span x) + | _ -> (* any other field type: look up the coding registered for this driver, convert the value, then encode its representation recursively *) + (match Caqti_type.Field.coding driver_info field_type with + | None -> Error (Caqti_error.encode_missing ~uri ~field_type ()) + | Some (Caqti_type.Field.Coding {rep; encode; _}) -> + (match encode x with + | Ok y -> copy_encode_field ~uri rep y + | Error msg -> + let msg = Caqti_error.Msg msg in + let typ = Caqti_type.field field_type in + Error (Caqti_error.encode_rejected ~uri ~typ msg)))) + let rec decode_field : type a.
uri: Uri.t -> a Caqti_type.field -> string -> (a, _) result = fun ~uri field_type s -> @@ -280,6 +310,33 @@ let rec encode_param let msg = Caqti_error.Msg msg in Error (Caqti_error.encode_rejected ~uri ~typ:t msg))) +let rec copy_encode_param (* Encodes the value [x] of row type [t] into consecutive slots of [params], starting at the index given as the last argument; on success returns the index following the last slot consumed. Errors from copy_encode_field are passed through and a failing Custom encode is reported as encode_rejected. Tuple components are chained with %>?, presumably left to right with early exit on Error - confirm. *) + : type a. uri: Uri.t -> string array -> + a Caqti_type.t -> a -> int -> (int, _) result = + fun ~uri params t x -> + (match t, x with + | Caqti_type.Unit, () -> fun i -> Ok i + | Caqti_type.Field ft, fv -> fun i -> + (match copy_encode_field ~uri ft fv with + | Ok s -> params.(i) <- s; Ok (i + 1) + | Error _ as r -> r) + | Caqti_type.Option t, None -> fun i -> Ok (i + Caqti_type.length t) (* absent value: skip its slots, leaving whatever the caller initialised them with *) + | Caqti_type.Option t, Some x -> copy_encode_param ~uri params t x + | Caqti_type.Tup2 (t0, t1), (x0, x1) -> + copy_encode_param ~uri params t0 x0 %>? copy_encode_param ~uri params t1 x1 + | Caqti_type.Tup3 (t0, t1, t2), (x0, x1, x2) -> + copy_encode_param ~uri params t0 x0 %>? copy_encode_param ~uri params t1 x1 %>? + copy_encode_param ~uri params t2 x2 + | Caqti_type.Tup4 (t0, t1, t2, t3), (x0, x1, x2, x3) -> + copy_encode_param ~uri params t0 x0 %>? copy_encode_param ~uri params t1 x1 %>? + copy_encode_param ~uri params t2 x2 %>? copy_encode_param ~uri params t3 x3 + | Caqti_type.Custom {rep; encode; _}, x -> fun i -> + (match encode x with + | Ok y -> copy_encode_param ~uri params rep y i + | Error msg -> + let msg = Caqti_error.Msg msg in + Error (Caqti_error.encode_rejected ~uri ~typ:t msg))) + let rec decode_row' : type b. uri: Uri.t -> Pg.result * int -> b Caqti_type.t -> int -> (int * b, _) result = @@ -384,20 +441,23 @@ module Connect_functor (System : Caqti_driver_sig.System_unix) = struct | exception Pg.Error msg -> return (pg_error msg) | socket -> Unix.wrap_fd aux (Obj.magic socket)) - let get_result ~uri ~query db = + let get_result ~uri ~query ?(ensure_single_result=true) db = (* ensure_single_result defaults to true, which keeps the check that no second result follows the first; pass false to return the first result without that check *) get_next_result ~uri ~query db >>= (function | Ok None -> let msg = Caqti_error.Msg "No response received after send."
in return (Error (Caqti_error.request_failed ~uri ~query msg)) | Ok (Some result) -> - get_next_result ~uri ~query db >>= - (function - | Ok None -> return (Ok result) - | Ok (Some _) -> - let msg = Caqti_error.Msg "More than one response received." in - return (Error (Caqti_error.response_rejected ~uri ~query msg)) - | Error _ as r -> return r) + if ensure_single_result then + get_next_result ~uri ~query db >>= + (function + | Ok None -> return (Ok result) + | Ok (Some _) -> + let msg = Caqti_error.Msg "More than one response received." in + return (Error (Caqti_error.response_rejected ~uri ~query msg)) + | Error _ as r -> return r) + else (* the caller opted out of the single-result check *) + return (Ok result) | Error _ as r -> return r) let check_query_result ~uri ~query result mult = @@ -440,6 +500,10 @@ module Connect_functor (System : Caqti_driver_sig.System_unix) = struct let check_command_result ~uri ~query result = check_query_result ~uri ~query result Caqti_mult.zero + + let get_and_check_result ~uri ~query db = (* Fetches the result with get_result, leaving the single-result check at its default, and hands a successful result to check_command_result *) + get_result ~uri ~query db + >|=? check_command_result ~uri ~query end (* Driver Interface *) @@ -508,12 +572,12 @@ module Connect_functor (System : Caqti_driver_sig.System_unix) = struct return r | Error _ as r -> return r) - let query_oneshot ?params ?binary_params query = + let query_oneshot ?params ?binary_params ?ensure_single_result query = (* ensure_single_result is forwarded unchanged to Pg_io.get_result below *) retry_on_connection_error begin fun () -> return (send_query ?params ?binary_params query) end >>= function | Error _ as r -> return r - | Ok () -> Pg_io.get_result ~uri ~query db + | Ok () -> Pg_io.get_result ~uri ~query ?ensure_single_result db let query_prepared query_id prepared params = let {query; binary_params; _} = prepared in @@ -675,6 +739,78 @@ module Connect_functor (System : Caqti_driver_sig.System_unix) = struct let start () = exec start_req () >|=?
fun () -> Ok (in_transaction := true) let commit () = in_transaction := false; exec commit_req () let rollback () = in_transaction := false; exec rollback_req () + + let populate ~table ~columns row_type data = (* Loads the rows of the stream [data] into [table] through COPY ... FROM STDIN WITH CSV; this replaces the Make_populate include that the same patch removes from the connection module. Steps: send the COPY command, encode and send every row as one comma-separated line ending in a newline, end the copy, then fetch and check the final command result. *) + let query = + let copy_command = + let columns_tuple = String.concat "," columns in + let q = sprintf "COPY %s (%s) FROM STDIN WITH CSV" table columns_tuple in (* NOTE(review): table and column names are spliced into the SQL text without quoting or escaping - presumably callers pass trusted identifiers; confirm *) + Caqti_request.exec ~oneshot:true Caqti_type.unit q + in + let templ = Caqti_request.query copy_command driver_info in + Pg_ext.query_string templ + in + let param_length = Caqti_type.length row_type in + let fail msg = + return (Error (Caqti_error.request_failed ~uri ~query (Caqti_error.Msg msg))) + in + let put_copy_data data = + let pg_error msg = + Error (Caqti_error.request_failed ~uri ~query (Pg_msg msg)) + in + let rec loop fd = + match db#put_copy_data data with + | Pg.Put_copy_error -> + fail "Unable to put copy data" + | Pg.Put_copy_queued -> + return (Ok ()) + | Pg.Put_copy_not_queued -> (* not queued: poll until the socket is writable, then retry with the same data *) + Unix.poll ~write:true fd >>= fun _ -> loop fd + in + (match db#socket with + | exception Pg.Error msg -> return (pg_error msg) + | socket -> Unix.wrap_fd loop (Obj.magic socket)) + in + let copy_row row = + let params = Array.make param_length Pg.null in (* every slot starts as Pg.null and slots skipped for None by copy_encode_param keep that value. NOTE(review): presumably Pg.null is an empty string, so an absent value becomes an empty unquoted CSV cell, which COPY in CSV mode reads as NULL - confirm *) + (match copy_encode_param ~uri params row_type row 0 with + | Ok n -> + begin + assert (n = param_length); + return (Ok (String.concat "," (Array.to_list params))) + end + | Error _ as r -> + return r) + >>=? fun param_string -> put_copy_data (param_string ^ "\n") + in + begin + (* Send the copy command to start the transfer. + * Skip checking that there is only a single result: while in copy mode we can repeatedly + * get the latest result and it will always be Copy_in, so checking for a single result + * would trigger an error. + *) + query_oneshot ~ensure_single_result:false query + >>=? fun result -> + (* We expect the Copy_in response only - turn other success responses into errors, and + * delegate error handling.
+ *) + (match result#status with + | Pg.Copy_in -> return (Ok ()) + | Pg.Command_ok -> fail "Received Command_ok when expecting Copy_in" + | _ -> return (Pg_io.check_command_result ~uri ~query result)) + >>=? fun () -> System.Stream.iter_s ~f:copy_row data (* NOTE(review): if a row fails to encode or to send, the chain appears to stop here without calling put_copy_end, which may leave the connection in copy mode - confirm *) + >>=? fun () -> + (* End the copy *) + (match db#put_copy_end () with + | Put_copy_error + | Put_copy_not_queued -> (* NOTE(review): unlike in put_copy_data, a not-queued answer is reported as a failure here rather than retried after polling - confirm this is intended *) + fail "Unable to finalize copy" + | Put_copy_queued -> + return (Ok ())) + >>=? fun () -> + (* After ending the copy, there will be a new result for the initial query. *) + Pg_io.get_and_check_result ~uri ~query db + end end let connect uri = @@ -701,7 +837,6 @@ module Connect_functor (System : Caqti_driver_sig.System_unix) = struct let driver_info = driver_info include B include Caqti_connection.Make_convenience (System) (B) - include Caqti_connection.Make_populate (System) (B) end in Ok (module Connection : CONNECTION)))) end diff --git a/tests/test_sql.ml b/tests/test_sql.ml index e1fe3de1..9d7bc257 100644 --- a/tests/test_sql.ml +++ b/tests/test_sql.ml @@ -59,12 +59,22 @@ module Q = struct "CREATE TABLE test_sql \ (id INTEGER PRIMARY KEY, i INTEGER NOT NULL, s TEXT NOT NULL)" | _ -> failwith "Unimplemented." + let create_tmp_nullable = (unit -->! unit) @@ function (* creates the test_sql table with an s column that has no NOT NULL constraint, so NULL values can be stored in it *) + | `Mysql | `Pgsql -> + "CREATE TEMPORARY TABLE test_sql \ + (id SERIAL NOT NULL, i INTEGER NOT NULL, s TEXT)" + | `Sqlite -> + "CREATE TABLE test_sql \ + (id INTEGER PRIMARY KEY, i INTEGER NOT NULL, s TEXT)" + | _ -> failwith "Unimplemented." let drop_tmp = (unit -->!
unit) @@ function | _ -> "DROP TABLE test_sql" let insert_into_tmp = Caqti_request.exec (tup2 int string) "INSERT INTO test_sql (i, s) VALUES (?, ?)" let select_from_tmp = Caqti_request.collect unit (tup2 int string) "SELECT i, s FROM test_sql" + let select_from_tmp_nullable = Caqti_request.collect unit (tup2 int (option string)) (* same query text as select_from_tmp, but s is decoded as an optional string *) + "SELECT i, s FROM test_sql" let select_from_tmp_where_i_lt = Caqti_request.collect int (tup2 int string) "SELECT i, s FROM test_sql WHERE i < ?" @@ -261,20 +271,37 @@ struct let test_stream_both_ways (module Db : Caqti_sys.CONNECTION) = let assert_stream_both_ways expected = let input_stream = Caqti_sys.Stream.of_list expected in - Db.exec Q.create_tmp () >>= Sys.or_fail >>= fun () -> + Db.exec Q.create_tmp_nullable () >>= Sys.or_fail >>= fun () -> Db.populate ~table:"test_sql" ~columns:["i"; "s"] - Caqti_type.(tup2 int string) + Caqti_type.(tup2 int (option string)) input_stream >|= Caqti_error.uncongested >>= Sys.or_fail >>= fun () -> - Db.collect_list Q.select_from_tmp () >>= Sys.or_fail >>= fun actual -> + Db.collect_list Q.select_from_tmp_nullable () >>= Sys.or_fail >>= fun actual -> + if actual <> expected then (* print both lists to stderr before the assertion below fails, so the mismatch is visible *) + ( + let repr a = + a + |> List.map (fun (i, s) -> + let repr_s = match s with Some s -> "Some \"" ^ s ^ "\"" | None -> "None" in + "(" ^ (string_of_int i) ^ "," ^ repr_s ^ ")") + |> String.concat "; " + |> (fun s -> "[" ^ s ^ "]") + in + eprintf "Expected: %s\nActual: %s\n" (repr expected) (repr actual)); assert (actual = expected); Db.exec Q.drop_tmp () in assert_stream_both_ways [] >>= Sys.or_fail >>= fun () -> - assert_stream_both_ways [(1, "one")] >>= Sys.or_fail >>= fun () -> - assert_stream_both_ways [(1, "one"); (2, "two")] >>= Sys.or_fail + assert_stream_both_ways [(1, Some "one")] >>= Sys.or_fail >>= fun () -> + assert_stream_both_ways [(1, Some "one"); (2, Some "two")] >>= Sys.or_fail >>= fun () -> + assert_stream_both_ways (* rows chosen to exercise the CSV encoding: embedded double quotes, a comma and a newline, an absent value, and an empty string *) + [ (1, Some "bad1\"\"") + ; (2, Some "bad2,\"\n") + ; (3, None) + ; (4, Some "") + ] >>=
Sys.or_fail let run (module Db : Caqti_sys.CONNECTION) = test_expr (module Db) >>= fun () ->