Skip to content

Commit

Permalink
FIX paurkedal#26: Make the populate function use COPY mode when using…
Browse files Browse the repository at this point in the history
… the postgresql driver
  • Loading branch information
James Owen committed Oct 21, 2019
1 parent a7af1f8 commit 3d5fb93
Show file tree
Hide file tree
Showing 2 changed files with 178 additions and 16 deletions.
157 changes: 146 additions & 11 deletions lib-driver/caqti_driver_postgresql.ml
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,36 @@ let rec encode_field
let typ = Caqti_type.field field_type in
Error (Caqti_error.encode_rejected ~uri ~typ msg))))

let rec copy_encode_field
: type a. uri: Uri.t -> a Caqti_type.field -> a -> (string, _) result =
fun ~uri field_type x ->
let escape_and_quote s =
"\"" ^ (String.concat "\"\"" (String.split_on_char '"' s)) ^ "\""
in
(match field_type with
| Caqti_type.Bool -> Ok (Pg_ext.string_of_bool x)
| Caqti_type.Int -> Ok (string_of_int x)
| Caqti_type.Int32 -> Ok (Int32.to_string x)
| Caqti_type.Int64 -> Ok (Int64.to_string x)
| Caqti_type.Float -> Ok (sprintf "%.17g" x)
| Caqti_type.String -> Ok (escape_and_quote x)
| Caqti_type.Octets -> Ok (escape_and_quote x)
| Caqti_type.Pdate -> Ok (iso8601_of_pdate x)
| Caqti_type.Ptime ->
Ok (Ptime.to_rfc3339 ~space:true ~tz_offset_s:0 ~frac_s:6 x)
| Caqti_type.Ptime_span ->
Ok (Pg_ext.string_of_ptime_span x)
| _ ->
(match Caqti_type.Field.coding driver_info field_type with
| None -> Error (Caqti_error.encode_missing ~uri ~field_type ())
| Some (Caqti_type.Field.Coding {rep; encode; _}) ->
(match encode x with
| Ok y -> copy_encode_field ~uri rep y
| Error msg ->
let msg = Caqti_error.Msg msg in
let typ = Caqti_type.field field_type in
Error (Caqti_error.encode_rejected ~uri ~typ msg))))

let rec decode_field
: type a. uri: Uri.t -> a Caqti_type.field -> string -> (a, _) result =
fun ~uri field_type s ->
Expand Down Expand Up @@ -280,6 +310,33 @@ let rec encode_param
let msg = Caqti_error.Msg msg in
Error (Caqti_error.encode_rejected ~uri ~typ:t msg)))

let rec copy_encode_param
: type a. uri: Uri.t -> string array ->
a Caqti_type.t -> a -> int -> (int, _) result =
fun ~uri params t x ->
(match t, x with
| Caqti_type.Unit, () -> fun i -> Ok i
| Caqti_type.Field ft, fv -> fun i ->
(match copy_encode_field ~uri ft fv with
| Ok s -> params.(i) <- s; Ok (i + 1)
| Error _ as r -> r)
| Caqti_type.Option t, None -> fun i -> Ok (i + Caqti_type.length t)
| Caqti_type.Option t, Some x -> copy_encode_param ~uri params t x
| Caqti_type.Tup2 (t0, t1), (x0, x1) ->
copy_encode_param ~uri params t0 x0 %>? copy_encode_param ~uri params t1 x1
| Caqti_type.Tup3 (t0, t1, t2), (x0, x1, x2) ->
copy_encode_param ~uri params t0 x0 %>? copy_encode_param ~uri params t1 x1 %>?
copy_encode_param ~uri params t2 x2
| Caqti_type.Tup4 (t0, t1, t2, t3), (x0, x1, x2, x3) ->
copy_encode_param ~uri params t0 x0 %>? copy_encode_param ~uri params t1 x1 %>?
copy_encode_param ~uri params t2 x2 %>? copy_encode_param ~uri params t3 x3
| Caqti_type.Custom {rep; encode; _}, x -> fun i ->
(match encode x with
| Ok y -> copy_encode_param ~uri params rep y i
| Error msg ->
let msg = Caqti_error.Msg msg in
Error (Caqti_error.encode_rejected ~uri ~typ:t msg)))

let rec decode_row'
: type b. uri: Uri.t -> Pg.result * int ->
b Caqti_type.t -> int -> (int * b, _) result =
Expand Down Expand Up @@ -384,20 +441,23 @@ module Connect_functor (System : Caqti_driver_sig.System_unix) = struct
| exception Pg.Error msg -> return (pg_error msg)
| socket -> Unix.wrap_fd aux (Obj.magic socket))

let get_result ~uri ~query db =
let get_result ~uri ~query ?(ensure_single_result=true) db =
get_next_result ~uri ~query db >>=
(function
| Ok None ->
let msg = Caqti_error.Msg "No response received after send." in
return (Error (Caqti_error.request_failed ~uri ~query msg))
| Ok (Some result) ->
get_next_result ~uri ~query db >>=
(function
| Ok None -> return (Ok result)
| Ok (Some _) ->
let msg = Caqti_error.Msg "More than one response received." in
return (Error (Caqti_error.response_rejected ~uri ~query msg))
| Error _ as r -> return r)
if ensure_single_result then
get_next_result ~uri ~query db >>=
(function
| Ok None -> return (Ok result)
| Ok (Some _) ->
let msg = Caqti_error.Msg "More than one response received." in
return (Error (Caqti_error.response_rejected ~uri ~query msg))
| Error _ as r -> return r)
else
return (Ok result)
| Error _ as r -> return r)

let check_query_result ~uri ~query result mult =
Expand Down Expand Up @@ -440,6 +500,10 @@ module Connect_functor (System : Caqti_driver_sig.System_unix) = struct

let check_command_result ~uri ~query result =
check_query_result ~uri ~query result Caqti_mult.zero

let get_and_check_result ~uri ~query db =
get_result ~uri ~query db
>|=? check_command_result ~uri ~query
end

(* Driver Interface *)
Expand Down Expand Up @@ -508,12 +572,12 @@ module Connect_functor (System : Caqti_driver_sig.System_unix) = struct
return r
| Error _ as r -> return r)

let query_oneshot ?params ?binary_params query =
let query_oneshot ?params ?binary_params ?ensure_single_result query =
retry_on_connection_error begin fun () ->
return (send_query ?params ?binary_params query)
end >>= function
| Error _ as r -> return r
| Ok () -> Pg_io.get_result ~uri ~query db
| Ok () -> Pg_io.get_result ~uri ~query ?ensure_single_result db

let query_prepared query_id prepared params =
let {query; binary_params; _} = prepared in
Expand Down Expand Up @@ -675,6 +739,78 @@ module Connect_functor (System : Caqti_driver_sig.System_unix) = struct
let start () = exec start_req () >|=? fun () -> Ok (in_transaction := true)
let commit () = in_transaction := false; exec commit_req ()
let rollback () = in_transaction := false; exec rollback_req ()

let populate ~table ~columns row_type data =
let query =
let copy_command =
let columns_tuple = String.concat "," columns in
let q = sprintf "COPY %s (%s) FROM STDIN WITH CSV" table columns_tuple in
Caqti_request.exec ~oneshot:true Caqti_type.unit q
in
let templ = Caqti_request.query copy_command driver_info in
Pg_ext.query_string templ
in
let param_length = Caqti_type.length row_type in
let fail msg =
return (Error (Caqti_error.request_failed ~uri ~query (Caqti_error.Msg msg)))
in
let put_copy_data data =
let pg_error msg =
Error (Caqti_error.request_failed ~uri ~query (Pg_msg msg))
in
let rec loop fd =
match db#put_copy_data data with
| Pg.Put_copy_error ->
fail "Unable to put copy data"
| Pg.Put_copy_queued ->
return (Ok ())
| Pg.Put_copy_not_queued ->
Unix.poll ~write:true fd >>= fun _ -> loop fd
in
(match db#socket with
| exception Pg.Error msg -> return (pg_error msg)
| socket -> Unix.wrap_fd loop (Obj.magic socket))
in
let copy_row row =
let params = Array.make param_length Pg.null in
(match copy_encode_param ~uri params row_type row 0 with
| Ok n ->
begin
assert (n = param_length);
return (Ok (String.concat "," (Array.to_list params)))
end
| Error _ as r ->
return r)
>>=? fun param_string -> put_copy_data (param_string ^ "\n")
in
begin
(* Send the copy command to start the transfer.
* Skip checking that there is only a single result: while in copy mode we can repeatedly
* get the latest result and it will always be Copy_in, so checking for a single result
* would trigger an error.
*)
query_oneshot ~ensure_single_result:false query
>>=? fun result ->
(* We expect the Copy_in response only - turn other success responses into errors, and
* delegate error handling.
*)
(match result#status with
| Pg.Copy_in -> return (Ok ())
| Pg.Command_ok -> fail "Received Command_ok when expecting Copy_in"
| _ -> return (Pg_io.check_command_result ~uri ~query result))
>>=? fun () -> System.Stream.iter_s ~f:copy_row data
>>=? fun () ->
(* End the copy *)
(match db#put_copy_end () with
| Put_copy_error
| Put_copy_not_queued ->
fail "Unable to finalize copy"
| Put_copy_queued ->
return (Ok ()))
>>=? fun () ->
(* After ending the copy, there will be a new result for the initial query. *)
Pg_io.get_and_check_result ~uri ~query db
end
end

let connect uri =
Expand All @@ -701,7 +837,6 @@ module Connect_functor (System : Caqti_driver_sig.System_unix) = struct
let driver_info = driver_info
include B
include Caqti_connection.Make_convenience (System) (B)
include Caqti_connection.Make_populate (System) (B)
end in
Ok (module Connection : CONNECTION))))
end
Expand Down
37 changes: 32 additions & 5 deletions tests/test_sql.ml
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,22 @@ module Q = struct
"CREATE TABLE test_sql \
(id INTEGER PRIMARY KEY, i INTEGER NOT NULL, s TEXT NOT NULL)"
| _ -> failwith "Unimplemented."
let create_tmp_nullable = (unit -->! unit) @@ function
| `Mysql | `Pgsql ->
"CREATE TEMPORARY TABLE test_sql \
(id SERIAL NOT NULL, i INTEGER NOT NULL, s TEXT)"
| `Sqlite ->
"CREATE TABLE test_sql \
(id INTEGER PRIMARY KEY, i INTEGER NOT NULL, s TEXT)"
| _ -> failwith "Unimplemented."
let drop_tmp = (unit -->! unit) @@ function
| _ -> "DROP TABLE test_sql"
let insert_into_tmp = Caqti_request.exec (tup2 int string)
"INSERT INTO test_sql (i, s) VALUES (?, ?)"
let select_from_tmp = Caqti_request.collect unit (tup2 int string)
"SELECT i, s FROM test_sql"
let select_from_tmp_nullable = Caqti_request.collect unit (tup2 int (option string))
"SELECT i, s FROM test_sql"
let select_from_tmp_where_i_lt = Caqti_request.collect int (tup2 int string)
"SELECT i, s FROM test_sql WHERE i < ?"

Expand Down Expand Up @@ -261,20 +271,37 @@ struct
let test_stream_both_ways (module Db : Caqti_sys.CONNECTION) =
let assert_stream_both_ways expected =
let input_stream = Caqti_sys.Stream.of_list expected in
Db.exec Q.create_tmp () >>= Sys.or_fail >>= fun () ->
Db.exec Q.create_tmp_nullable () >>= Sys.or_fail >>= fun () ->
Db.populate
~table:"test_sql"
~columns:["i"; "s"]
Caqti_type.(tup2 int string)
Caqti_type.(tup2 int (option string))
input_stream
>|= Caqti_error.uncongested >>= Sys.or_fail >>= fun () ->
Db.collect_list Q.select_from_tmp () >>= Sys.or_fail >>= fun actual ->
Db.collect_list Q.select_from_tmp_nullable () >>= Sys.or_fail >>= fun actual ->
if actual <> expected then
(
let repr a =
a
|> List.map (fun (i, s) ->
let repr_s = match s with Some s -> "Some \"" ^ s ^ "\"" | None -> "None" in
"(" ^ (string_of_int i) ^ "," ^ repr_s ^ ")")
|> String.concat "; "
|> (fun s -> "[" ^ s ^ "]")
in
eprintf "Expected: %s\nActual: %s\n" (repr expected) (repr actual));
assert (actual = expected);
Db.exec Q.drop_tmp ()
in
assert_stream_both_ways [] >>= Sys.or_fail >>= fun () ->
assert_stream_both_ways [(1, "one")] >>= Sys.or_fail >>= fun () ->
assert_stream_both_ways [(1, "one"); (2, "two")] >>= Sys.or_fail
assert_stream_both_ways [(1, Some "one")] >>= Sys.or_fail >>= fun () ->
assert_stream_both_ways [(1, Some "one"); (2, Some "two")] >>= Sys.or_fail >>= fun () ->
assert_stream_both_ways
[ (1, Some "bad1\"\"")
; (2, Some "bad2,\"\n")
; (3, None)
; (4, Some "")
] >>= Sys.or_fail

let run (module Db : Caqti_sys.CONNECTION) =
test_expr (module Db) >>= fun () ->
Expand Down

0 comments on commit 3d5fb93

Please sign in to comment.