Skip to content

Commit

Permalink
Adapt to mirage-flow 4 API (#70)
Browse files Browse the repository at this point in the history
* mirage: avoid an assert false, properly return an error

* provide close and shutdown in Awa_mirage

* simplify - a shutdown \`read_write is a close

* mirage: preserve half-closed connections, and deal with them properly

* mirage: avoid assertions

* address @reynir review - and use inject_state

* mirage: revise close and shutdown

first to the ssh teardown, then do the underlying flow teardown

* shutdown: don't shutdown the flow unless closed

If we are in `Read_closed we may still want to read channel-close and
when we are in `Write_closed we may still want to write channel-close.

* mirage: set closed earlier in close(); also remove TODO comment

* mirage: add comment about states and why errors may occur that we ignore (thanks to @dinosaure)

* minor tweaks

* shutdown: if in closed/error state, call close on the underlying flow nevertheless

---------

Co-authored-by: Reynir Björnsson <[email protected]>
Co-authored-by: Romain Calascibetta <[email protected]>
  • Loading branch information
3 people authored Feb 8, 2024
1 parent 389c1f3 commit 7c66137
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 42 deletions.
2 changes: 1 addition & 1 deletion awa-mirage.opam
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ depends: [
"lwt" {>= "5.3.0"}
"mirage-time" {>= "2.0.0"}
"duration" {>= "0.2.0"}
"mirage-flow" {>= "2.0.0"}
"mirage-flow" {>= "4.0.0"}
"mirage-clock" {>= "3.0.0"}
"logs"
]
Expand Down
20 changes: 20 additions & 0 deletions lib/client.ml
Original file line number Diff line number Diff line change
Expand Up @@ -523,3 +523,23 @@ let outgoing_data t ?(id = 0l) data =
let* c, frags = Channel.output_data c data in
let t' = { t with channels = Channel.update c t.channels } in
Ok (output_msgs t' frags)

let eof ?(id = 0l) t =
match
let* () = guard (established t) "not yet established" in
let* c = guard_some (Channel.lookup id t.channels) "no such channel" in
let msg = Ssh.Msg_channel_eof c.them.id in
Ok (output_msg t msg)
with
| Error _ -> t, None
| Ok (t, msg) -> t, Some msg

let close ?(id = 0l) t =
match
let* () = guard (established t) "not yet established" in
let* c = guard_some (Channel.lookup id t.channels) "no such channel" in
let msg = Ssh.Msg_channel_close c.them.id in
Ok (output_msg t msg)
with
| Error _ -> t, None
| Ok (t, msg) -> t, Some msg
4 changes: 4 additions & 0 deletions lib/client.mli
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,7 @@ val outgoing_request : t -> ?id:int32 -> ?want_reply:bool ->

val outgoing_data : t -> ?id:int32 -> Cstruct.t ->
(t * Cstruct.t list, string) result

val eof : ?id:int32 -> t -> t * Cstruct.t option

val close : ?id:int32 -> t -> t * Cstruct.t option
145 changes: 108 additions & 37 deletions mirage/awa_mirage.ml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ module Log = (val Logs.src_log src : Logs.LOG)

module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) = struct

module FLOW = F
module MCLOCK = M

type error = [ `Msg of string
Expand All @@ -22,22 +21,63 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) =
| #Mirage_flow.write_error as e -> Mirage_flow.pp_write_error ppf e
| #error as e -> pp_error ppf e

(* this is the flow of a ssh-client. be aware that we're only using a single
channel.
the state `Read_closed is set (a) when a TCP.read returned `Eof,
and (b) when the application did a shutdown `read (or `read_write).
the state `Write_closed is set (a) when a TCP.write returned `Closed,
and (b) when the application did a shutdown `write (or `read_write).
If we're in `Write_closed, and do a shutdown `read, we'll end up in
`Closed, and attempt to (a) send a SSH_MSG_CHANNEL_CLOSE and (b) TCP.close.
This may fail, since on the TCP layer, the connection may have already be
half-closed (or fully closed) in the write direction. We ignore this error
from writev below in close.
*)
type flow = {
flow : FLOW.flow ;
mutable state : [ `Active of Awa.Client.t | `Eof | `Error of error ]
flow : F.flow ;
mutable state : [
| `Active of Awa.Client.t
| `Read_closed of Awa.Client.t
| `Write_closed of Awa.Client.t
| `Closed
| `Error of error ]
}

let half_close state mode =
match state, mode with
| `Active ssh, `read -> `Read_closed ssh
| `Active ssh, `write -> `Write_closed ssh
| `Active _, `read_write -> `Closed
| `Read_closed ssh, `read -> `Read_closed ssh
| `Read_closed _, (`write | `read_write) -> `Closed
| `Write_closed ssh, `write -> `Write_closed ssh
| `Write_closed _, (`read | `read_write) -> `Closed
| (`Closed | `Error _) as e, (`read | `write | `read_write) -> e

let inject_state ssh = function
| `Active _ -> `Active ssh
| `Read_closed _ -> `Read_closed ssh
| `Write_closed _ -> `Write_closed ssh
| (`Closed | `Error _) as e -> e

let write_flow t buf =
FLOW.write t.flow buf >>= function
| Ok () -> Lwt.return (Ok ())
F.write t.flow buf >>= function
| Ok _ as o -> Lwt.return o
| Error `Closed ->
Log.warn (fun m -> m "error closed while writing");
t.state <- half_close t.state `write;
Lwt.return (Error (`Write `Closed))
| Error w ->
Log.warn (fun m -> m "error %a while writing" F.pp_write_error w);
t.state <- `Error (`Write w) ; Lwt.return (Error (`Write w))
t.state <- `Error (`Write w);
Lwt.return (Error (`Write w))

let writev_flow t bufs =
Lwt_list.fold_left_s (fun r d ->
match r with
| Error e -> Lwt.return (Error e)
| Error _ as e -> Lwt.return e
| Ok () -> write_flow t d)
(Ok ()) bufs

Expand All @@ -46,25 +86,27 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) =

let read_react t =
match t.state with
| `Eof | `Error _ -> Lwt.return (Error ())
| `Active _ ->
FLOW.read t.flow >>= function
| `Read_closed _ | `Closed | `Error _ -> Lwt.return (Error ())
| `Active _ | `Write_closed _ ->
F.read t.flow >>= function
| Error e ->
Log.warn (fun m -> m "error %a while reading" F.pp_error e);
t.state <- `Error (`Read e);
Lwt.return (Error ())
| Ok `Eof -> t.state <- `Eof ; Lwt.return (Error ())
| Ok `Eof ->
t.state <- half_close t.state `read;
Lwt.return (Error ())
| Ok (`Data data) ->
match t.state with
| `Active ssh ->
| `Active ssh | `Write_closed ssh ->
begin match Awa.Client.incoming ssh (now ()) data with
| Error msg ->
Log.warn (fun m -> m "error %s while processing data" msg);
t.state <- `Error (`Msg msg);
Lwt.return (Error ())
| Ok (ssh', out, events) ->
let state' = if List.mem `Disconnected events then `Eof else `Active ssh' in
t.state <- state';
t.state <-
inject_state ssh' (if List.mem `Disconnected events then half_close t.state `read else t.state);
writev_flow t out >>= fun _ ->
Lwt.return (Ok events)
end
Expand All @@ -74,15 +116,14 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) =
read_react t >>= function
| Ok es ->
begin match t.state, List.filter (function `Established _ -> true | _ -> false) es with
| `Eof, _ -> Lwt.return (Error (`Msg "disconnected"))
| (`Read_closed _ | `Closed), _ -> Lwt.return (Error (`Msg "disconnected"))
| `Error e, _ -> Lwt.return (Error e)
| `Active _, [ `Established id ] -> Lwt.return (Ok id)
| `Active _, _ -> drain_handshake t
| (`Active _ | `Write_closed _), [ `Established id ] -> Lwt.return (Ok id)
| (`Active _ | `Write_closed _), _ -> drain_handshake t
end
| Error () -> match t.state with
| `Error e -> Lwt.return (Error e)
| `Eof -> Lwt.return (Error (`Msg "disconnected"))
| `Active _ -> assert false
| `Closed | `Read_closed _ | `Active _ | `Write_closed _ -> Lwt.return (Error (`Msg "disconnected"))

let rec read t =
read_react t >>= function
Expand All @@ -107,32 +148,57 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) =
end
| Error () -> match t.state with
| `Error e -> Lwt.return (Error e)
| `Eof -> Lwt.return (Ok `Eof)
| `Active _ -> assert false
| `Closed | `Read_closed _ | `Active _ | `Write_closed _ -> Lwt.return (Ok `Eof)

let close t =
(* TODO ssh session teardown (send some protocol messages) *)
FLOW.close t.flow >|= fun () ->
t.state <- `Eof
(match t.state with
| `Active ssh | `Read_closed ssh | `Write_closed ssh ->
let ssh, msg = Awa.Client.close ssh in
t.state <- inject_state ssh t.state;
t.state <- `Closed;
(* as outlined above, this may fail since the TCP flow may already be (half-)closed *)
writev_flow t (Option.to_list msg) >|= ignore
| `Error _ | `Closed -> Lwt.return_unit) >>= fun () ->
F.close t.flow

let shutdown t mode =
match t.state with
| `Active ssh | `Read_closed ssh | `Write_closed ssh ->
let ssh, msg =
match t.state, mode with
| (`Active ssh | `Read_closed ssh), `write -> Awa.Client.eof ssh
| _, `read_write -> Awa.Client.close ssh
| _ -> ssh, None
in
t.state <- inject_state ssh (half_close t.state mode);
(* as outlined above, this may fail since the TCP flow may already be (half-)closed *)
writev_flow t (Option.to_list msg) >>= fun _ ->
(* we don't [FLOW.shutdown _ mode] because we still need to read/write
channel_eof/channel_close unless both directions are closed *)
(match t.state with
| `Closed -> F.close t.flow
| _ -> Lwt.return_unit)
| `Error _ | `Closed ->
F.close t.flow

let writev t bufs =
let open Lwt_result.Infix in
match t.state with
| `Active ssh ->
| `Active ssh | `Read_closed ssh ->
Lwt_list.fold_left_s (fun r data ->
match r with
| Error e -> Lwt.return (Error e)
| Ok ssh ->
match Awa.Client.outgoing_data ssh data with
| Ok (ssh', datas) ->
t.state <- `Active ssh';
t.state <- inject_state ssh' t.state;
writev_flow t datas >|= fun () ->
ssh'
| Error msg ->
t.state <- `Error (`Msg msg) ;
Lwt.return (Error (`Msg msg)))
(Ok ssh) bufs >|= fun _ -> ()
| `Eof -> Lwt.return (Error `Closed)
| `Write_closed _ | `Closed -> Lwt.return (Error `Closed)
| `Error e -> Lwt.return (Error (e :> write_error))

let write t buf = writev t [buf]
Expand All @@ -146,12 +212,17 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) =
} in
writev_flow t msgs >>= fun () ->
drain_handshake t >>= fun id ->
(* TODO that's a bit hardcoded... *)
let ssh = match t.state with `Active t -> t | _ -> assert false in
(match Awa.Client.outgoing_request ssh ~id req with
| Error msg -> t.state <- `Error (`Msg msg) ; Lwt.return (Error (`Msg msg))
| Ok (ssh', data) -> t.state <- `Active ssh' ; write_flow t data) >|= fun () ->
t
match t.state with
| `Active ssh ->
(match Awa.Client.outgoing_request ssh ~id req with
| Error msg -> t.state <- `Error (`Msg msg) ; Lwt.return (Error (`Msg msg))
| Ok (ssh', data) -> t.state <- `Active ssh' ; write_flow t data) >|= fun () ->
t
| `Read_closed _ -> Lwt.return (Error (`Msg "read closed"))
| `Write_closed _ -> Lwt.return (Error (`Msg "write closed"))
| `Closed -> Lwt.return (Error (`Msg "closed"))
| `Error e -> Lwt.return (Error e)


(* copy from awa_lwt.ml and unix references removed in favor to FLOW *)
type nexus_msg =
Expand Down Expand Up @@ -195,10 +266,10 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) =
let send_msg flow server msg =
wrapr (Awa.Server.output_msg server msg)
>>= fun (server, msg_buf) ->
FLOW.write flow msg_buf >>= function
F.write flow msg_buf >>= function
| Ok () -> Lwt.return server
| Error w ->
Log.err (fun m -> m "error %a while writing" FLOW.pp_write_error w);
Log.err (fun m -> m "error %a while writing" F.pp_write_error w);
Lwt.return server

let rec send_msgs fd server = function
Expand All @@ -209,9 +280,9 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) =
| [] -> Lwt.return server

let net_read flow =
FLOW.read flow >>= function
F.read flow >>= function
| Error e ->
Log.err (fun m -> m "read error %a" FLOW.pp_error e);
Log.err (fun m -> m "read error %a" F.pp_error e);
Lwt.return Net_eof
| Ok `Eof ->
Lwt.return Net_eof
Expand Down
6 changes: 2 additions & 4 deletions mirage/awa_mirage.mli
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
(** SSH module given a flow *)
module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) : sig

module FLOW : Mirage_flow.S

(** possible errors: incoming alert, processing failure, or a
problem in the underlying flow. *)
type error = [ `Msg of string
Expand All @@ -24,7 +22,7 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) :
sends the channel request. *)
val client_of_flow : ?authenticator:Awa.Keys.authenticator -> user:string ->
[ `Pubkey of Awa.Hostkey.priv | `Password of string ] ->
Awa.Ssh.channel_request -> FLOW.flow -> (flow, error) result Lwt.t
Awa.Ssh.channel_request -> F.flow -> (flow, error) result Lwt.t

type t

Expand Down Expand Up @@ -64,4 +62,4 @@ module Make (F : Mirage_flow.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) :
{b NOTE}: Even if the [ssh_channel_handler] is fulfilled, [spawn_server]
continues to handle SSH channels. Only [stop] can really stop the internal
SSH channels handler. *)
end with module FLOW = F
end

0 comments on commit 7c66137

Please sign in to comment.