diff --git a/examples/eio/eio_connect_client.ml b/examples/eio/eio_connect_client.ml index 5ade820..d0e2f21 100644 --- a/examples/eio/eio_connect_client.ml +++ b/examples/eio/eio_connect_client.ml @@ -25,21 +25,24 @@ let proxy_handler _env ~sw ~headers flow ~on_eof response _response_body = We'll be boring and use HTTP/1.1 again. *) let connection = Httpun_eio.Client.create_connection ~sw flow in let exit_cond = Eio.Condition.create () in - let response_handler = - http_handler ~on_eof:(fun () -> - Stdlib.Format.eprintf "http eof@."; - Eio.Condition.broadcast exit_cond; - on_eof ()) - in - let request_body = - Httpun_eio.Client.request - ~flush_headers_immediately:true - ~error_handler:Httpun_examples.Client.error_handler - ~response_handler - connection - (Request.create ~headers `GET "/") - in - Body.Writer.close request_body + Eio.Fiber.fork ~sw (fun () -> + let response_handler = + http_handler ~on_eof:(fun () -> + Stdlib.Format.eprintf "http eof@."; + Eio.Condition.broadcast exit_cond; + on_eof ()) + in + let request_body = + Httpun_eio.Client.request + ~flush_headers_immediately:true + ~error_handler:Httpun_examples.Client.error_handler + ~response_handler + connection + (Request.create ~headers `GET "/") + in + Body.Writer.close request_body); + Eio.Condition.await_no_mutex exit_cond; + Httpun_eio.Client.shutdown connection |> Eio.Promise.await | _response -> Stdlib.exit 124 let main port proxy_host = @@ -61,10 +64,15 @@ let main port proxy_host = let connection = Httpun_eio.Client.create_connection ~sw socket in let exit_cond = Eio.Condition.create () in + Eio.Fiber.fork ~sw (fun ()-> let response_handler = - proxy_handler _env ~sw socket ~headers ~on_eof:(fun () -> - Stdlib.Format.eprintf "(connect) eof@."; - Eio.Condition.broadcast exit_cond) + fun response response_body -> + Eio.Fiber.fork ~sw @@ fun () -> + proxy_handler _env ~sw socket ~headers ~on_eof:(fun () -> + Stdlib.Format.eprintf "(connect) eof@."; + Eio.Condition.broadcast exit_cond) + response + response_body in let request_body = Httpun_eio.Client.request @@ -77,7 +85,7 @@ let main port proxy_host = Body.Writer.close request_body; Eio.Condition.await_no_mutex exit_cond; - Httpun_eio.Client.shutdown connection |> Eio.Promise.await)) + Httpun_eio.Client.shutdown connection |> Eio.Promise.await))) let () = let host = ref None in diff --git a/lib/respd.ml b/lib/respd.ml index 35c9c59..17389a2 100644 --- a/lib/respd.ml +++ b/lib/respd.ml @@ -104,12 +104,7 @@ let input_state t : Io_state.t = else if Body.Reader.is_read_scheduled response_body then Ready else Wait - | Upgraded _ -> - (* Upgraded is "Complete" because the descriptor doesn't wish to receive - * any more input. - * XXX(anmonteiro): not true for `CONNECT - *) - Wait + | Upgraded _ -> Wait | Closed -> Complete let output_state { request_body; state; writer; _ } : Io_state.t = diff --git a/lib/response_state.ml b/lib/response_state.ml index 3623991..80c3ba7 100644 --- a/lib/response_state.ml +++ b/lib/response_state.ml @@ -4,21 +4,23 @@ type t = | Streaming of Response.t * Body.Writer.t | Upgrade of Response.t * (unit -> unit) -let output_state t ~request_method ~writer : Io_state.t = - match request_method with - | `CONNECT -> Wait - | _ -> - match t with - | Fixed _ -> Complete +let output_state = + let response_sent_state = function + | `CONNECT -> Io_state.Wait + | _ -> Complete + in + fun t ~request_method ~writer : Io_state.t -> + match t with + | Upgrade _ -> Wait | Waiting -> if Serialize.Writer.is_closed writer then Complete else Wait + | Fixed _ -> response_sent_state request_method | Streaming(_, response_body) -> - if Serialize.Writer.is_closed writer then Complete + if Serialize.Writer.is_closed writer then response_sent_state request_method else if Body.Writer.requires_output response_body then Ready - else Complete - | Upgrade _ -> Wait + else response_sent_state request_method let flush_response_body t = match t with diff --git a/lib_test/test_client_connection.ml b/lib_test/test_client_connection.ml index 17e278c..f3e0a18 100644 --- a/lib_test/test_client_connection.ml +++ b/lib_test/test_client_connection.ml @@ -1866,6 +1866,41 @@ let test_read_response_before_shutdown () = connection_is_shutdown t; ;; +let test_client_connect () = + let writer_woken_up = ref false in + let reader_woken_up = ref false in + let request' = Request.create + ~headers:(Headers.of_list ["host", "example.com:80"]) + `CONNECT "/" + in + let t = create () in + let response = Response.create `OK in + let body = + request + t + request' + ~flush_headers_immediately:true + ~response_handler:(default_response_handler response) + ~error_handler:no_error_handler + in + write_request t request'; + writer_yielded t; + Body.Writer.close body; + reader_ready t; + read_response t response; + reader_yielded t; + yield_reader t (fun () -> reader_woken_up := true); + writer_yielded t; + yield_writer t (fun () -> writer_woken_up := true); + Alcotest.(check bool) "Reader hasn't woken up yet" false !reader_woken_up; + Alcotest.(check bool) "Writer hasn't woken up yet" false !writer_woken_up; + shutdown t; + Alcotest.(check bool) "Reader woken up" true !reader_woken_up; + Alcotest.(check bool) "Writer woken up" true !writer_woken_up; + connection_is_shutdown t; +;; + + let tests = [ "commit parse after every header line", `Quick, test_commit_parse_after_every_header ; "GET" , `Quick, test_get @@ -1914,4 +1949,5 @@ let tests = ; "shut down closes request body ", `Quick, test_read_response_before_shutdown ; "report exn during body read", `Quick, test_report_exn_during_body_read ; "read response after write eof", `Quick, test_can_read_response_after_write_eof + ; "Client support for CONNECT", `Quick, test_client_connect ] diff --git a/lib_test/test_server_connection.ml b/lib_test/test_server_connection.ml index d778749..22a0863 100644 --- a/lib_test/test_server_connection.ml +++ b/lib_test/test_server_connection.ml @@ -2438,6 +2438,24 @@ let test_write_response_after_read_eof () = connection_is_shutdown t; ;; +let test_connect_method () = + let upgraded = ref false in + let upgrade_handler reqd = + Reqd.respond_with_upgrade reqd Headers.empty (fun () -> + upgraded := true) + in + let t = create ~error_handler upgrade_handler in + read_request + t + (Request.create + ~headers:(Headers.of_list [ "host", "example.com:80" ]) + `CONNECT + "/"); + write_response ~msg:"Upgrade response written" t (Response.create `Switching_protocols); + Alcotest.(check bool) "Callback was called" true !upgraded; + reader_yielded t; +;; + let tests = [ "initial reader state" , `Quick, test_initial_reader_state ; "shutdown reader closed", `Quick, test_reader_is_closed_after_eof @@ -2520,4 +2538,5 @@ let tests = ; "can read more requests after write eof", `Quick, test_can_read_more_requests_after_write_eof ; "can read more requests after write eof (before response sent)", `Quick, test_can_read_more_requests_after_write_eof_before_send_response ; "write response after reader EOF", `Quick,test_write_response_after_read_eof + ; "CONNECT method", `Quick, test_connect_method ]