forked from paurkedal/ocaml-caqti
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
FIX paurkedal#26: Make the populate function use COPY mode when using…
… the postgresql driver
- Loading branch information
James Owen
committed
Oct 15, 2019
1 parent
a204499
commit 6964422
Showing
1 changed file
with
86 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -384,20 +384,23 @@ module Connect_functor (System : Caqti_driver_sig.System_unix) = struct | |
| exception Pg.Error msg -> return (pg_error msg) | ||
| socket -> Unix.wrap_fd aux (Obj.magic socket)) | ||
|
||
let get_result ~uri ~query db = | ||
let get_result ~uri ~query ?(ensure_single_result=true) db = | ||
get_next_result ~uri ~query db >>= | ||
(function | ||
| Ok None -> | ||
let msg = Caqti_error.Msg "No response received after send." in | ||
return (Error (Caqti_error.request_failed ~uri ~query msg)) | ||
| Ok (Some result) -> | ||
get_next_result ~uri ~query db >>= | ||
(function | ||
| Ok None -> return (Ok result) | ||
| Ok (Some _) -> | ||
let msg = Caqti_error.Msg "More than one response received." in | ||
return (Error (Caqti_error.response_rejected ~uri ~query msg)) | ||
| Error _ as r -> return r) | ||
if ensure_single_result then | ||
get_next_result ~uri ~query db >>= | ||
(function | ||
| Ok None -> return (Ok result) | ||
| Ok (Some _) -> | ||
let msg = Caqti_error.Msg "More than one response received." in | ||
return (Error (Caqti_error.response_rejected ~uri ~query msg)) | ||
| Error _ as r -> return r) | ||
else | ||
return (Ok result) | ||
| Error _ as r -> return r) | ||
|
||
let check_query_result ~uri ~query result mult = | ||
|
@@ -440,6 +443,10 @@ module Connect_functor (System : Caqti_driver_sig.System_unix) = struct | |
|
||
let check_command_result ~uri ~query result = | ||
check_query_result ~uri ~query result Caqti_mult.zero | ||
|
||
let get_and_check_result ~uri ~query db = | ||
get_result ~uri ~query db | ||
>|=? check_command_result ~uri ~query | ||
end | ||
|
||
(* Driver Interface *) | ||
|
@@ -508,12 +515,12 @@ module Connect_functor (System : Caqti_driver_sig.System_unix) = struct | |
return r | ||
| Error _ as r -> return r) | ||
|
||
let query_oneshot ?params ?binary_params query = | ||
let query_oneshot ?params ?binary_params ?ensure_single_result query = | ||
retry_on_connection_error begin fun () -> | ||
return (send_query ?params ?binary_params query) | ||
end >>= function | ||
| Error _ as r -> return r | ||
| Ok () -> Pg_io.get_result ~uri ~query db | ||
| Ok () -> Pg_io.get_result ~uri ~query ?ensure_single_result db | ||
|
||
let query_prepared query_id prepared params = | ||
let {query; binary_params; _} = prepared in | ||
|
@@ -675,6 +682,75 @@ module Connect_functor (System : Caqti_driver_sig.System_unix) = struct | |
let start () = exec start_req () >|=? fun () -> Ok (in_transaction := true) | ||
let commit () = in_transaction := false; exec commit_req () | ||
let rollback () = in_transaction := false; exec rollback_req () | ||
|
||
let populate ~table ~columns row_type data = | ||
let query = | ||
let copy_command = | ||
let columns_tuple = String.concat "," columns in | ||
let q = sprintf "COPY %s (%s) FROM STDIN WITH CSV" table columns_tuple in | ||
Caqti_request.exec ~oneshot:true Caqti_type.unit q | ||
in | ||
let templ = Caqti_request.query copy_command driver_info in | ||
Pg_ext.query_string templ | ||
in | ||
let param_length = Caqti_type.length row_type in | ||
let binary_params = Array.make param_length Pg.null in | ||
let copy_row row = | ||
( match encode_param ~uri binary_params row_type row 0 with | ||
| Ok n -> | ||
assert (n = param_length); | ||
return (Ok (String.concat "," (Array.to_list binary_params))) | ||
| Error _ as r -> | ||
return r | ||
This comment has been minimized.
Sorry, something went wrong.
This comment has been minimized.
Sorry, something went wrong.
leamingrad
Owner
|
||
) | ||
>>=? fun param_string -> | ||
match db#put_copy_data (param_string ^ "\n") with | ||
| Pg.Put_copy_error | ||
| Pg.Put_copy_not_queued -> | ||
return (Error ( | ||
Caqti_error.encode_rejected | ||
~typ:row_type | ||
~uri | ||
(Caqti_error.Msg "Unable to encode row for copy") | ||
)) | ||
| Pg.Put_copy_queued -> | ||
return (Ok ()) | ||
in | ||
let fail msg = | ||
return (Error (Caqti_error.request_failed ~uri ~query (Caqti_error.Msg msg))) | ||
in | ||
begin | ||
(* Send the copy command to start the transfer. | ||
* Skip checking that there is only a single result: while in copy mode we can repeatedly | ||
* get the latest result and it will always be Copy_in, so checking for a single result | ||
* would trigger an error. | ||
*) | ||
query_oneshot ~ensure_single_result:false query | ||
>>=? fun result -> | ||
(* We expect the Copy_in response only - turn other success responses into errors, and | ||
* delegate error handling. | ||
*) | ||
( match result#status with | ||
| Pg.Copy_in -> return (Ok ()) | ||
| Pg.Command_ok -> fail "Recieved Command_ok when expecting Copy_in" | ||
This comment has been minimized.
Sorry, something went wrong.
This comment has been minimized.
Sorry, something went wrong. |
||
| _ -> return (Pg_io.check_command_result ~uri ~query result) | ||
) | ||
>>=? fun () -> System.Stream.iter_s ~f:copy_row data | ||
>>=? fun () -> | ||
(* End the copy *) | ||
( match db#put_copy_end () with | ||
| Put_copy_error | ||
| Put_copy_not_queued | ||
-> | ||
fail "Unable to finalize copy" | ||
| Put_copy_queued | ||
-> | ||
return (Ok ()) | ||
) | ||
>>=? fun () -> | ||
(* After ending the copy, there will be a new result for the initial query. *) | ||
Pg_io.get_and_check_result ~uri ~query db | ||
end | ||
end | ||
|
||
let connect uri = | ||
|
@@ -701,7 +777,6 @@ module Connect_functor (System : Caqti_driver_sig.System_unix) = struct | |
let driver_info = driver_info | ||
include B | ||
include Caqti_connection.Make_convenience (System) (B) | ||
include Caqti_connection.Make_populate (System) (B) | ||
end in | ||
Ok (module Connection : CONNECTION)))) | ||
end | ||
|
This error branch looks like a pattern for a monad, no?