Skip to content

Commit

Permalink
Implement single-row-mode for postgresql (#24).
Browse files Browse the repository at this point in the history
  • Loading branch information
paurkedal committed Nov 8, 2022
1 parent a8fab9e commit 60f2e01
Show file tree
Hide file tree
Showing 3 changed files with 191 additions and 67 deletions.
247 changes: 180 additions & 67 deletions caqti-driver-postgresql/lib/caqti_driver_postgresql.ml
Original file line number Diff line number Diff line change
Expand Up @@ -157,13 +157,15 @@ module Pg_ext = struct
let parse_uri uri =
pop_uri_param parse_notice_processing `Quiet "notice_processing" uri
|>? fun (notice_processing, uri) ->
pop_uri_param bool_of_string false "use_single_row_mode" uri
|>? fun (use_single_row_mode, uri) ->
let conninfo =
if Uri.host uri <> None then Uri.to_string ~pct_encoder uri else
let mkparam k v = k ^ " = '" ^ escaped_connvalue v ^ "'" in
let mkparams (k, vs) = List.map (mkparam k) vs in
String.concat " " (List.flatten (List.map mkparams (Uri.query uri)))
in
Ok (conninfo, notice_processing)
Ok (conninfo, notice_processing, use_single_row_mode)
end

let bool_oid = Pg.oid_of_ftype Pg.BOOL
Expand Down Expand Up @@ -343,6 +345,7 @@ type prepared = {
param_length: int;
param_types: Pg.oid array;
binary_params: bool array;
single_row_mode: bool;
}

module Connect_functor (System : Caqti_platform_unix.System_sig.S) = struct
Expand Down Expand Up @@ -408,7 +411,7 @@ module Connect_functor (System : Caqti_platform_unix.System_sig.S) = struct
let msg = Caqti_error.Msg "More than one response received." in
return (Error (Caqti_error.response_rejected ~uri ~query msg))

let check_query_result ~uri ~query result mult =
let check_query_result ~uri ~query ~row_mult ~single_row_mode result =
let reject msg =
let msg = Caqti_error.Msg msg in
Error (Caqti_error.response_rejected ~uri ~query msg)
Expand All @@ -419,12 +422,16 @@ module Connect_functor (System : Caqti_platform_unix.System_sig.S) = struct
in
(match result#status with
| Pg.Command_ok ->
(match Caqti_mult.expose mult with
(match Caqti_mult.expose row_mult with
| `Zero -> Ok ()
| (`One | `Zero_or_one | `Zero_or_more) ->
reject "Tuples expected for this query.")
| Pg.Tuples_ok ->
(match Caqti_mult.expose mult with
if single_row_mode then
if result#ntuples = 0 then Ok () else
reject "Tuples returned in single-row-mode."
else
(match Caqti_mult.expose row_mult with
| `Zero ->
if result#ntuples = 0 then Ok () else
reject "No tuples expected for this query."
Expand All @@ -448,10 +455,15 @@ module Connect_functor (System : Caqti_platform_unix.System_sig.S) = struct
| Pg.Copy_out | Pg.Copy_in | Pg.Copy_both ->
reject "Received unexpected copy response."
| Pg.Single_tuple ->
reject "Received unexpected single-tuple response.")
if not single_row_mode then
reject "Received unexpected single tuple response." else
if result#ntuples <> 1 then
reject "Expected a single row in single-row mode." else
Ok ())

let check_command_result ~uri ~query result =
check_query_result ~uri ~query result Caqti_mult.zero
check_query_result
~uri ~query ~row_mult:Caqti_mult.zero ~single_row_mode:false result
end

(* Driver Interface *)
Expand All @@ -465,6 +477,7 @@ module Connect_functor (System : Caqti_platform_unix.System_sig.S) = struct
val env : Caqti_driver_info.t -> string -> Caqti_query.t
val uri : Uri.t
val db : Pg.connection
val use_single_row_mode : bool
end) =
struct
open Connection_arg
Expand Down Expand Up @@ -533,17 +546,20 @@ module Connect_functor (System : Caqti_platform_unix.System_sig.S) = struct
| Error _ as r -> return r)

let send_oneshot_query
?params ?param_types ?binary_params
?params ?param_types ?binary_params ?(single_row_mode = false)
query =
retry_on_connection_error begin fun () ->
return @@ wrap_pg ~query begin fun () ->
db#send_query ?params ?param_types ?binary_params query;
if single_row_mode then db#set_single_row_mode;
db#consume_input
end
end

let send_prepared_query query_id prepared params =
let {query; param_types; binary_params; _} = prepared in
let {query; param_types; binary_params; single_row_mode; _} =
prepared
in
retry_on_connection_error begin fun () ->
begin
if Int_hashtbl.mem prepare_cache query_id then return (Ok ()) else
Expand All @@ -558,70 +574,149 @@ module Connect_functor (System : Caqti_platform_unix.System_sig.S) = struct
wrap_pg ~query begin fun () ->
db#send_query_prepared
~params ~binary_params (query_name_of_id query_id);
if single_row_mode then db#set_single_row_mode;
db#consume_input
end
end

let fetch_one_result ~query () = Pg_io.get_one_result ~uri ~query db
let fetch_final_result ~query () = Pg_io.get_final_result ~uri ~query db

let fetch_single_row ~row_type ~query () =
Pg_io.get_one_result ~uri ~query db >>=? fun result ->
(match result#status with
| Pg.Single_tuple ->
return @@ Result.map Option.some @@
decode_row ~uri row_type (result, 0)
| Pg.Tuples_ok ->
Pg_io.get_next_result ~uri ~query db >|=?
(function
| None -> Ok None
| Some _ ->
let msg =
Caqti_error.Msg "Extra result after final single-row result." in
Error (Caqti_error.response_rejected ~uri ~query msg))
| _ ->
return @@ Result.map (fun () -> None) @@
Pg_io.check_query_result
~uri ~query ~row_mult:Caqti_mult.zero_or_more ~single_row_mode:true
result)

module Response = struct
type ('b, 'm) t = {row_type: 'b Caqti_type.t; result: Pg.result}

let returned_count {result; _} =
return (Ok result#ntuples)
type source =
| Complete of Pg.result
| Single_row

let affected_count {result; _} =
return (Ok (int_of_string result#cmd_tuples))
type ('b, 'm) t = {
row_type: 'b Caqti_type.t;
source: source;
query: string;
}

let exec _ = return (Ok ())
let returned_count = function
| {source = Complete result; _} ->
return (Ok result#ntuples)
| {source = Single_row; _} ->
return (Error `Unsupported)

let find {row_type; result} =
return (decode_row ~uri row_type (result, 0))
let affected_count = function
| {source = Complete result; _} ->
return (Ok (int_of_string result#cmd_tuples))
| {source = Single_row; _} ->
return (Error `Unsupported)

let find_opt {row_type; result} =
return begin
if result#ntuples = 0 then Ok None else
(match decode_row ~uri row_type (result, 0) with
| Ok y -> Ok (Some y)
| Error _ as r -> r)
end
let exec _ = return (Ok ())

let fold f {row_type; result} acc =
let n = result#ntuples in
let rec loop i acc =
if i = n then Ok acc else
(match decode_row ~uri row_type (result, i) with
| Ok y -> loop (i + 1) (f y acc)
| Error _ as r -> r) in
return (loop 0 acc)

let fold_s f {row_type; result} acc =
let n = result#ntuples in
let rec loop i acc =
if i = n then return (Ok acc) else
(match decode_row ~uri row_type (result, i) with
| Ok y -> f y acc >>=? loop (i + 1)
| Error _ as r -> return r) in
loop 0 acc

let iter_s f {row_type; result} =
let n = result#ntuples in
let rec loop i () =
if i = n then return (Ok ()) else
(match decode_row ~uri row_type (result, i) with
| Ok y -> f y >>=? loop (i + 1)
| Error _ as r -> return r) in
loop 0 ()

let to_stream {row_type; result} =
let n = result#ntuples in
let rec f i () =
if i = n then return Stream.Nil else
(match decode_row ~uri row_type (result, i) with
| Ok y -> return @@ Stream.Cons (y, f (i + 1))
| Error err -> return @@ Stream.Error err) in
f 0
let find = function
| {row_type; source = Complete result; _} ->
return (decode_row ~uri row_type (result, 0))
| {source = Single_row; _} ->
assert false

let find_opt = function
| {row_type; source = Complete result; _} ->
return begin
if result#ntuples = 0 then Ok None else
(match decode_row ~uri row_type (result, 0) with
| Ok y -> Ok (Some y)
| Error _ as r -> r)
end
| {source = Single_row; _} ->
assert false

let fold f = function
| {row_type; source = Complete result; _} ->
let n = result#ntuples in
let rec loop i acc =
if i = n then Ok acc else
(match decode_row ~uri row_type (result, i) with
| Ok y -> loop (i + 1) (f y acc)
| Error _ as r -> r)
in
fun acc -> return (loop 0 acc)
| {row_type; source = Single_row; query} ->
let rec loop acc =
fetch_single_row ~row_type ~query () >>=? function
| None -> return (Ok acc)
| Some y -> loop (f y acc)
in
loop

let fold_s f = function
| {row_type; source = Complete result; _} ->
let n = result#ntuples in
let rec loop i acc =
if i = n then return (Ok acc) else
(match decode_row ~uri row_type (result, i) with
| Ok y -> f y acc >>=? loop (i + 1)
| Error _ as r -> return r)
in
loop 0
| {row_type; source = Single_row; query} ->
let rec loop acc =
fetch_single_row ~row_type ~query () >>=? function
| None -> return (Ok acc)
| Some y -> f y acc >>=? loop
in
loop

let iter_s f = function
| {row_type; source = Complete result; _} ->
let n = result#ntuples in
let rec loop i =
if i = n then return (Ok ()) else
(match decode_row ~uri row_type (result, i) with
| Ok y -> f y >>=? fun () -> loop (i + 1)
| Error _ as r -> return r)
in
loop 0
| {row_type; source = Single_row; query} ->
let rec loop () =
fetch_single_row ~row_type ~query () >>=? function
| None -> return (Ok ())
| Some y -> f y >>=? fun () -> loop ()
in
loop ()

let to_stream = function
| {row_type; source = Complete result; _} ->
let n = result#ntuples in
let rec f i () =
if i = n then return Stream.Nil else
(match decode_row ~uri row_type (result, i) with
| Ok y -> return (Stream.Cons (y, f (i + 1)))
| Error err -> return (Stream.Error err))
in
f 0
| {row_type; source = Single_row; query} ->
let rec f () =
fetch_single_row ~row_type ~query () >|= function
| Ok None -> Stream.Nil
| Ok (Some y) -> Stream.Cons (y, f)
| Error err -> Stream.Error err
in
f
end

let type_oid_cache = Hashtbl.create 19
Expand All @@ -633,6 +728,11 @@ module Connect_functor (System : Caqti_platform_unix.System_sig.S) = struct
Log.debug ~src:Logging.request_log_src (fun f ->
f "Sending %a" pp_request_with_param (req, param)) >>= fun () ->

let single_row_mode =
use_single_row_mode
&& Caqti_mult.can_be_many (Caqti_request.row_mult req)
in

(* Prepare, if requested, and send the query. *)
let param_type = Caqti_request.param_type req in
(match Caqti_request.query_id req with
Expand All @@ -648,7 +748,8 @@ module Connect_functor (System : Caqti_platform_unix.System_sig.S) = struct
let params = Array.make param_length Pg.null in
(match Param_encoder.encode ~uri params param_type param with
| Ok () ->
send_oneshot_query ~params ~binary_params query >|=? fun () ->
send_oneshot_query ~params ~binary_params ~single_row_mode query
>|=? fun () ->
Ok query
| Error _ as r ->
return r)
Expand All @@ -664,7 +765,8 @@ module Connect_functor (System : Caqti_platform_unix.System_sig.S) = struct
init_param_types
~uri ~type_oid_cache param_types binary_params param_type
|>? fun () ->
Ok {query; param_length; param_types; binary_params}
Ok {query; param_length; param_types; binary_params;
single_row_mode}
end |> return >>=? fun prepared ->
let params = Array.make prepared.param_length Pg.null in
(match Param_encoder.encode ~uri params param_type param with
Expand All @@ -676,12 +778,19 @@ module Connect_functor (System : Caqti_platform_unix.System_sig.S) = struct
>>=? fun query ->

(* Fetch and process the result. *)
fetch_final_result ~query () >>=? fun result ->
let row_type = Caqti_request.row_type req in
let row_mult = Caqti_request.row_mult req in
(match Pg_io.check_query_result ~uri ~query result row_mult with
| Ok () -> f Response.{row_type; result}
| Error _ as r -> return r)
if single_row_mode then
f Response.{row_type; query; source = Single_row}
else begin
let row_mult = Caqti_request.row_mult req in
fetch_final_result ~query () >>=? fun result ->
(match
Pg_io.check_query_result
~uri ~query ~row_mult ~single_row_mode result
with
| Ok () -> f Response.{row_type; query; source = Complete result}
| Error _ as r -> return r)
end

let rec fetch_type_oids : type a. a Caqti_type.t -> _ = function
| Caqti_type.Unit -> return (Ok ())
Expand Down Expand Up @@ -741,7 +850,9 @@ module Connect_functor (System : Caqti_platform_unix.System_sig.S) = struct
send_oneshot_query query >>=? fun () ->
fetch_final_result ~query () >|=? fun result ->
Int_hashtbl.remove prepare_cache query_id;
Pg_io.check_query_result ~uri ~query result Caqti_mult.zero
Pg_io.check_query_result
~uri ~query ~row_mult:Caqti_mult.zero ~single_row_mode:false
result
end
else
return (Ok ())
Expand Down Expand Up @@ -852,7 +963,8 @@ module Connect_functor (System : Caqti_platform_unix.System_sig.S) = struct
end

let connect ?(env = no_env) ~tweaks_version:_ uri =
return (Pg_ext.parse_uri uri) >>=? fun (conninfo, notice_processing) ->
return (Pg_ext.parse_uri uri)
>>=? fun (conninfo, notice_processing, use_single_row_mode) ->
(match new Pg.connection ~conninfo () with
| exception Pg.Error err ->
let msg = extract_connect_error err in
Expand All @@ -878,6 +990,7 @@ module Connect_functor (System : Caqti_platform_unix.System_sig.S) = struct
let env = env
let uri = uri
let db = db
let use_single_row_mode = use_single_row_mode
end)
in
let module Connection = struct
Expand Down
8 changes: 8 additions & 0 deletions caqti/lib/caqti_mult.ml
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,11 @@ let expose = function
| One -> `One
| Zero_or_one -> `Zero_or_one
| Zero_or_more -> `Zero_or_more

let can_be_zero = function
| One -> false
| Zero | Zero_or_one | Zero_or_more -> true

let can_be_many = function
| Zero | One | Zero_or_one -> false
| Zero_or_more -> true
Loading

0 comments on commit 60f2e01

Please sign in to comment.