From 7c66137bc764bb040c211a65b5a829cb351bdb2a Mon Sep 17 00:00:00 2001 From: Hannes Mehnert Date: Thu, 8 Feb 2024 12:45:29 +0100 Subject: [PATCH] Adapt to mirage-flow 4 API (#70) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * mirage: avoid an assert false, properly return an error * provide close and shutdown in Awa_mirage * simplify - a shutdown \`read_write is a close * mirage: preserve half-closed connections, and deal with them properly * mirage: avoid assertions * address @reynir review - and use inject_state * mirage: revise close and shutdown first to the ssh teardown, then do the underlying flow teardown * shutdown: don't shutdown the flow unless closed If we are in `Read_closed we may still want to read channel-close and when we are in `Write_closed we may still want to write channel-close. * mirage: set closed earlier in close(); also remove TODO comment * mirage: add comment about states and why errors may occur that we ignore (thanks to @dinosaure) * minor tweaks * shutdown: if in closed/error state, call close on the underlying flow nevertheless --------- Co-authored-by: Reynir Björnsson Co-authored-by: Romain Calascibetta --- awa-mirage.opam | 2 +- lib/client.ml | 20 ++++++ lib/client.mli | 4 ++ mirage/awa_mirage.ml | 145 +++++++++++++++++++++++++++++++----------- mirage/awa_mirage.mli | 6 +- 5 files changed, 135 insertions(+), 42 deletions(-) diff --git a/awa-mirage.opam b/awa-mirage.opam index 772db0a..eeecb67 100644 --- a/awa-mirage.opam +++ b/awa-mirage.opam @@ -22,7 +22,7 @@ depends: [ "lwt" {>= "5.3.0"} "mirage-time" {>= "2.0.0"} "duration" {>= "0.2.0"} - "mirage-flow" {>= "2.0.0"} + "mirage-flow" {>= "4.0.0"} "mirage-clock" {>= "3.0.0"} "logs" ] diff --git a/lib/client.ml b/lib/client.ml index ab2f44c..f8f3ce4 100644 --- a/lib/client.ml +++ b/lib/client.ml @@ -523,3 +523,23 @@ let outgoing_data t ?(id = 0l) data = let* c, frags = Channel.output_data c data in let t' = { t with channels = Channel.update c t.channels } in Ok (output_msgs t' frags) + +let eof ?(id = 0l) t = + match + let* () = guard (established t) "not yet established" in + let* c = guard_some (Channel.lookup id t.channels) "no such channel" in + let msg = Ssh.Msg_channel_eof c.them.id in + Ok (output_msg t msg) + with + | Error _ -> t, None + | Ok (t, msg) -> t, Some msg + +let close ?(id = 0l) t = + match + let* () = guard (established t) "not yet established" in + let* c = guard_some (Channel.lookup id t.channels) "no such channel" in + let msg = Ssh.Msg_channel_close c.them.id in + Ok (output_msg t msg) + with + | Error _ -> t, None + | Ok (t, msg) -> t, Some msg diff --git a/lib/client.mli b/lib/client.mli index 5e9ee5d..3cf73ef 100644 --- a/lib/client.mli +++ b/lib/client.mli @@ -38,3 +38,7 @@ val outgoing_request : t -> ?id:int32 -> ?want_reply:bool -> val outgoing_data : t -> ?id:int32 -> Cstruct.t -> (t * Cstruct.t list, string) result + +val eof : ?id:int32 -> t -> t * Cstruct.t option + +val close : ?id:int32 -> t -> t * Cstruct.t option diff --git a/mirage/awa_mirage.ml b/mirage/awa_mirage.ml index b0bad82..b5e44b5 100644 --- a/mirage/awa_mirage.ml +++ b/mirage/awa_mirage.ml @@ -5,7 +5,6 @@ module Log = (val Logs.src_log src : Logs.LOG) module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) = struct - module FLOW = F module MCLOCK = M type error = [ `Msg of string @@ -22,22 +21,63 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) = | #Mirage_flow.write_error as e -> Mirage_flow.pp_write_error ppf e | #error as e -> pp_error ppf e + (* this is the flow of a ssh-client. be aware that we're only using a single + channel. + + the state `Read_closed is set (a) when a TCP.read returned `Eof, + and (b) when the application did a shutdown `read (or `read_write). + the state `Write_closed is set (a) when a TCP.write returned `Closed, + and (b) when the application did a shutdown `write (or `read_write). + + If we're in `Write_closed, and do a shutdown `read, we'll end up in + `Closed, and attempt to (a) send a SSH_MSG_CHANNEL_CLOSE and (b) TCP.close. + This may fail, since on the TCP layer, the connection may have already be + half-closed (or fully closed) in the write direction. We ignore this error + from writev below in close. + *) type flow = { - flow : FLOW.flow ; - mutable state : [ `Active of Awa.Client.t | `Eof | `Error of error ] + flow : F.flow ; + mutable state : [ + | `Active of Awa.Client.t + | `Read_closed of Awa.Client.t + | `Write_closed of Awa.Client.t + | `Closed + | `Error of error ] } + let half_close state mode = + match state, mode with + | `Active ssh, `read -> `Read_closed ssh + | `Active ssh, `write -> `Write_closed ssh + | `Active _, `read_write -> `Closed + | `Read_closed ssh, `read -> `Read_closed ssh + | `Read_closed _, (`write | `read_write) -> `Closed + | `Write_closed ssh, `write -> `Write_closed ssh + | `Write_closed _, (`read | `read_write) -> `Closed + | (`Closed | `Error _) as e, (`read | `write | `read_write) -> e + + let inject_state ssh = function + | `Active _ -> `Active ssh + | `Read_closed _ -> `Read_closed ssh + | `Write_closed _ -> `Write_closed ssh + | (`Closed | `Error _) as e -> e + let write_flow t buf = - FLOW.write t.flow buf >>= function - | Ok () -> Lwt.return (Ok ()) + F.write t.flow buf >>= function + | Ok _ as o -> Lwt.return o + | Error `Closed -> + Log.warn (fun m -> m "error closed while writing"); + t.state <- half_close t.state `write; + Lwt.return (Error (`Write `Closed)) | Error w -> Log.warn (fun m -> m "error %a while writing" F.pp_write_error w); - t.state <- `Error (`Write w) ; Lwt.return (Error (`Write w)) + t.state <- `Error (`Write w); + Lwt.return (Error (`Write w)) let writev_flow t bufs = Lwt_list.fold_left_s (fun r d -> match r with - | Error e -> Lwt.return (Error e) + | Error _ as e -> Lwt.return e | Ok () -> write_flow t d) (Ok ()) bufs @@ -46,25 +86,27 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) = let read_react t = match t.state with - | `Eof | `Error _ -> Lwt.return (Error ()) - | `Active _ -> - FLOW.read t.flow >>= function + | `Read_closed _ | `Closed | `Error _ -> Lwt.return (Error ()) + | `Active _ | `Write_closed _ -> + F.read t.flow >>= function | Error e -> Log.warn (fun m -> m "error %a while reading" F.pp_error e); t.state <- `Error (`Read e); Lwt.return (Error ()) - | Ok `Eof -> t.state <- `Eof ; Lwt.return (Error ()) + | Ok `Eof -> + t.state <- half_close t.state `read; + Lwt.return (Error ()) | Ok (`Data data) -> match t.state with - | `Active ssh -> + | `Active ssh | `Write_closed ssh -> begin match Awa.Client.incoming ssh (now ()) data with | Error msg -> Log.warn (fun m -> m "error %s while processing data" msg); t.state <- `Error (`Msg msg); Lwt.return (Error ()) | Ok (ssh', out, events) -> - let state' = if List.mem `Disconnected events then `Eof else `Active ssh' in - t.state <- state'; + t.state <- + inject_state ssh' (if List.mem `Disconnected events then half_close t.state `read else t.state); writev_flow t out >>= fun _ -> Lwt.return (Ok events) end @@ -74,15 +116,14 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) = read_react t >>= function | Ok es -> begin match t.state, List.filter (function `Established _ -> true | _ -> false) es with - | `Eof, _ -> Lwt.return (Error (`Msg "disconnected")) + | (`Read_closed _ | `Closed), _ -> Lwt.return (Error (`Msg "disconnected")) | `Error e, _ -> Lwt.return (Error e) - | `Active _, [ `Established id ] -> Lwt.return (Ok id) - | `Active _, _ -> drain_handshake t + | (`Active _ | `Write_closed _), [ `Established id ] -> Lwt.return (Ok id) + | (`Active _ | `Write_closed _), _ -> drain_handshake t end | Error () -> match t.state with | `Error e -> Lwt.return (Error e) - | `Eof -> Lwt.return (Error (`Msg "disconnected")) - | `Active _ -> assert false + | `Closed | `Read_closed _ | `Active _ | `Write_closed _ -> Lwt.return (Error (`Msg "disconnected")) let rec read t = read_react t >>= function @@ -107,32 +148,57 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) = end | Error () -> match t.state with | `Error e -> Lwt.return (Error e) - | `Eof -> Lwt.return (Ok `Eof) - | `Active _ -> assert false + | `Closed | `Read_closed _ | `Active _ | `Write_closed _ -> Lwt.return (Ok `Eof) let close t = - (* TODO ssh session teardown (send some protocol messages) *) - FLOW.close t.flow >|= fun () -> - t.state <- `Eof + (match t.state with + | `Active ssh | `Read_closed ssh | `Write_closed ssh -> + let ssh, msg = Awa.Client.close ssh in + t.state <- inject_state ssh t.state; + t.state <- `Closed; + (* as outlined above, this may fail since the TCP flow may already be (half-)closed *) + writev_flow t (Option.to_list msg) >|= ignore + | `Error _ | `Closed -> Lwt.return_unit) >>= fun () -> + F.close t.flow + + let shutdown t mode = + match t.state with + | `Active ssh | `Read_closed ssh | `Write_closed ssh -> + let ssh, msg = + match t.state, mode with + | (`Active ssh | `Read_closed ssh), `write -> Awa.Client.eof ssh + | _, `read_write -> Awa.Client.close ssh + | _ -> ssh, None + in + t.state <- inject_state ssh (half_close t.state mode); + (* as outlined above, this may fail since the TCP flow may already be (half-)closed *) + writev_flow t (Option.to_list msg) >>= fun _ -> + (* we don't [FLOW.shutdown _ mode] because we still need to read/write + channel_eof/channel_close unless both directions are closed *) + (match t.state with + | `Closed -> F.close t.flow + | _ -> Lwt.return_unit) + | `Error _ | `Closed -> + F.close t.flow let writev t bufs = let open Lwt_result.Infix in match t.state with - | `Active ssh -> + | `Active ssh | `Read_closed ssh -> Lwt_list.fold_left_s (fun r data -> match r with | Error e -> Lwt.return (Error e) | Ok ssh -> match Awa.Client.outgoing_data ssh data with | Ok (ssh', datas) -> - t.state <- `Active ssh'; + t.state <- inject_state ssh' t.state; writev_flow t datas >|= fun () -> ssh' | Error msg -> t.state <- `Error (`Msg msg) ; Lwt.return (Error (`Msg msg))) (Ok ssh) bufs >|= fun _ -> () - | `Eof -> Lwt.return (Error `Closed) + | `Write_closed _ | `Closed -> Lwt.return (Error `Closed) | `Error e -> Lwt.return (Error (e :> write_error)) let write t buf = writev t [buf] @@ -146,12 +212,17 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) = } in writev_flow t msgs >>= fun () -> drain_handshake t >>= fun id -> - (* TODO that's a bit hardcoded... *) - let ssh = match t.state with `Active t -> t | _ -> assert false in - (match Awa.Client.outgoing_request ssh ~id req with - | Error msg -> t.state <- `Error (`Msg msg) ; Lwt.return (Error (`Msg msg)) - | Ok (ssh', data) -> t.state <- `Active ssh' ; write_flow t data) >|= fun () -> - t + match t.state with + | `Active ssh -> + (match Awa.Client.outgoing_request ssh ~id req with + | Error msg -> t.state <- `Error (`Msg msg) ; Lwt.return (Error (`Msg msg)) + | Ok (ssh', data) -> t.state <- `Active ssh' ; write_flow t data) >|= fun () -> + t + | `Read_closed _ -> Lwt.return (Error (`Msg "read closed")) + | `Write_closed _ -> Lwt.return (Error (`Msg "write closed")) + | `Closed -> Lwt.return (Error (`Msg "closed")) + | `Error e -> Lwt.return (Error e) + (* copy from awa_lwt.ml and unix references removed in favor to FLOW *) type nexus_msg = @@ -195,10 +266,10 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) = let send_msg flow server msg = wrapr (Awa.Server.output_msg server msg) >>= fun (server, msg_buf) -> - FLOW.write flow msg_buf >>= function + F.write flow msg_buf >>= function | Ok () -> Lwt.return server | Error w -> - Log.err (fun m -> m "error %a while writing" FLOW.pp_write_error w); + Log.err (fun m -> m "error %a while writing" F.pp_write_error w); Lwt.return server let rec send_msgs fd server = function @@ -209,9 +280,9 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) = | [] -> Lwt.return server let net_read flow = - FLOW.read flow >>= function + F.read flow >>= function | Error e -> - Log.err (fun m -> m "read error %a" FLOW.pp_error e); + Log.err (fun m -> m "read error %a" F.pp_error e); Lwt.return Net_eof | Ok `Eof -> Lwt.return Net_eof diff --git a/mirage/awa_mirage.mli b/mirage/awa_mirage.mli index 7ada427..51aef1d 100644 --- a/mirage/awa_mirage.mli +++ b/mirage/awa_mirage.mli @@ -3,8 +3,6 @@ (** SSH module given a flow *) module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) : sig - module FLOW : Mirage_flow.S - (** possible errors: incoming alert, processing failure, or a problem in the underlying flow. *) type error = [ `Msg of string @@ -24,7 +22,7 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) : sends the channel request. *) val client_of_flow : ?authenticator:Awa.Keys.authenticator -> user:string -> [ `Pubkey of Awa.Hostkey.priv | `Password of string ] -> - Awa.Ssh.channel_request -> FLOW.flow -> (flow, error) result Lwt.t + Awa.Ssh.channel_request -> F.flow -> (flow, error) result Lwt.t type t @@ -64,4 +62,4 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) : {b NOTE}: Even if the [ssh_channel_handler] is fulfilled, [spawn_server] continues to handle SSH channels. Only [stop] can really stop the internal SSH channels handler. *) -end with module FLOW = F +end