Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Various optimizations #280

Merged
merged 6 commits into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bench/bench_engine.ml
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ let test_receive_data cipher =
Staged.stage @@ fun () ->
match Miragevpn.handle established_client (`Data pkt) with
| Ok _ -> ()
| Error _ -> assert false
| Error err -> Format.kasprintf failwith "%a" Miragevpn.pp_error err
in
Test.make ~name:"decode data" staged

Expand Down
153 changes: 82 additions & 71 deletions src/engine.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1108,11 +1108,11 @@ let[@coverage off] pp_error ppf = function
actual
| `Msg msg -> Fmt.string ppf msg

let unpad block_size cs =
let l = String.length cs in
let amount = String.get_uint8 cs (pred l) in
let unpad block_size cs off =
let l = String.length cs - off in
let amount = String.get_uint8 cs (off + pred l) in
let len = l - amount in
if len >= 0 && amount <= block_size then Ok (String.sub cs 0 len)
if len >= 0 && amount <= block_size then Ok (String.sub cs off len)
else Error (`Msg "bad padding")

let out ?add_timestamp prefix_len (ctx : keys) hmac_algorithm compress rng data
Expand All @@ -1121,10 +1121,18 @@ let out ?add_timestamp prefix_len (ctx : keys) hmac_algorithm compress rng data
the ~add_timestamp argument is only used in static key mode
*)
let set_replay_id dest off = Bytes.set_int32_be dest off ctx.my_replay_id in
let aead (type key)
(authenticate_encrypt_tag :
key:key -> nonce:string -> ?adata:string -> string -> string * string)
(my_key : key) my_implicit_iv =
let aead (type key) tag_size
(authenticate_encrypt_into :
key:key ->
nonce:string ->
?adata:string ->
string ->
src_off:int ->
bytes ->
dst_off:int ->
tag_off:int ->
int ->
unit) (my_key : key) my_implicit_iv =
let nonce, replay_id =
let b = Bytes.create (Packet.id_len + String.length my_implicit_iv) in
set_replay_id b 0;
Expand All @@ -1135,21 +1143,16 @@ let out ?add_timestamp prefix_len (ctx : keys) hmac_algorithm compress rng data
if compress then (* 0xFA is "no compression" *)
"\xfa" ^ data else data
in
let enc, tag =
authenticate_encrypt_tag ~key:my_key ~nonce ~adata:replay_id data
in
let b =
Bytes.create
(prefix_len + String.length replay_id + String.length tag
+ String.length enc)
(prefix_len + String.length replay_id + tag_size + String.length data)
in
Bytes.blit_string replay_id 0 b prefix_len (String.length replay_id);
Bytes.blit_string tag 0 b
(prefix_len + String.length replay_id)
(String.length tag);
Bytes.blit_string enc 0 b
(prefix_len + String.length replay_id + String.length tag)
(String.length enc);
set_replay_id b prefix_len;
authenticate_encrypt_into ~key:my_key ~nonce ~adata:replay_id data
~src_off:0 b
~dst_off:(prefix_len + String.length replay_id + tag_size)
~tag_off:(prefix_len + String.length replay_id)
(String.length data);
b
in
( { ctx with my_replay_id = Int32.succ ctx.my_replay_id },
Expand All @@ -1162,6 +1165,7 @@ let out ?add_timestamp prefix_len (ctx : keys) hmac_algorithm compress rng data
- timestamp only used in static key mode (32bit, seconds since unix epoch)
*)
let open Mirage_crypto in
let module H = (val Digestif.module_of_hash' hmac_algorithm) in
let hdr_len = 4 + if Option.is_some add_timestamp then 4 else 0 in
let data =
let unpad_len = hdr_len + Bool.to_int compress + String.length data in
Expand All @@ -1181,31 +1185,29 @@ let out ?add_timestamp prefix_len (ctx : keys) hmac_algorithm compress rng data
Bytes.fill b unpad_len pad_len (char_of_int pad_len);
Bytes.unsafe_to_string b
in
(* FIXME: rng_into *)
let iv = rng AES.CBC.block_size in
let enc = AES.CBC.encrypt ~key:my_key ~iv data in
let hmac =
let module H = (val Digestif.module_of_hash' hmac_algorithm) in
H.(to_raw_string (hmacv_string ~key:my_hmac [ iv; enc ]))
in
let b =
Bytes.create
(prefix_len + String.length hmac + String.length iv
+ String.length enc)
(prefix_len + H.digest_size + String.length iv + String.length data)
in
Bytes.blit_string hmac 0 b prefix_len (String.length hmac);
Bytes.blit_string iv 0 b
(prefix_len + String.length hmac)
(String.length iv);
Bytes.blit_string enc 0 b
(prefix_len + String.length hmac + String.length iv)
(String.length enc);
Bytes.blit_string iv 0 b (prefix_len + H.digest_size) (String.length iv);
AES.CBC.encrypt_into ~key:my_key ~iv data ~src_off:0 b
~dst_off:(prefix_len + H.digest_size + String.length iv)
(String.length data);
let hmac =
H.hmac_bytes ~key:my_hmac ~off:(prefix_len + H.digest_size) b
in
(* H.get_into_bytes hmac ~off:prefix_len b; *)
Bytes.blit_string (H.to_raw_string hmac) 0 b prefix_len H.digest_size;
b
| AES_GCM { my_key; my_implicit_iv; _ } ->
aead Mirage_crypto.AES.GCM.authenticate_encrypt_tag my_key
my_implicit_iv
aead Mirage_crypto.AES.GCM.tag_size
Mirage_crypto.AES.GCM.authenticate_encrypt_into my_key my_implicit_iv
| CHACHA20_POLY1305 { my_key; my_implicit_iv; _ } ->
aead Mirage_crypto.Chacha20.authenticate_encrypt_tag my_key
my_implicit_iv )
aead Mirage_crypto.Chacha20.tag_size
Mirage_crypto.Chacha20.authenticate_encrypt_into my_key my_implicit_iv
)

let data_out ?add_timestamp (ctx : keys) hmac_algorithm compress protocol rng
key data =
Expand Down Expand Up @@ -1370,23 +1372,23 @@ let incoming_data ?(add_timestamp = false) err (ctx : keys) hmac_algorithm
*)
let open Mirage_crypto in
let module H = (val Digestif.module_of_hash' hmac_algorithm) in
let hmac, data =
( String.sub data 0 H.digest_size,
String.sub data H.digest_size (String.length data - H.digest_size)
)
in
let computed_hmac =
H.(to_raw_string (hmac_string ~key:their_hmac data))
let hmac, off =
(H.of_raw_string (String.sub data 0 H.digest_size), H.digest_size)
in
let computed_hmac = H.(hmac_string ~off ~key:their_hmac data) in
let* () =
guard (String.equal hmac computed_hmac) (err hmac computed_hmac)
guard
(H.equal hmac computed_hmac)
(err (H.to_raw_string hmac) (H.to_raw_string computed_hmac))
in
let iv, data =
( String.sub data 0 AES.CBC.block_size,
String.sub data AES.CBC.block_size
(String.length data - AES.CBC.block_size) )
let iv, off =
(String.sub data off AES.CBC.block_size, off + AES.CBC.block_size)
in
let dec = AES.CBC.decrypt ~key:their_key ~iv data in
let l = String.length data - off in
let dec = Bytes.create l in
AES.CBC.decrypt_into ~key:their_key ~iv data ~src_off:off dec ~dst_off:0
l;
let dec = Bytes.unsafe_to_string dec in
(* dec is: uint32 replay packet id followed by (lzo-compressed) data and padding *)
let hdr_len = Packet.id_len + if add_timestamp then 4 else 0 in
let* () =
Expand All @@ -1398,54 +1400,57 @@ let incoming_data ?(add_timestamp = false) err (ctx : keys) hmac_algorithm
Log.debug (fun m ->
m "received replay packet id is %lu" (String.get_int32_be dec 0));
(* TODO validate ts if provided (avoid replay) *)
unpad AES.CBC.block_size
(String.sub dec hdr_len (String.length dec - hdr_len))
unpad AES.CBC.block_size dec hdr_len
| AES_GCM { their_key; their_implicit_iv; _ } ->
let tag_len = Mirage_crypto.AES.GCM.tag_size in
let* () =
guard
(String.length data >= Packet.id_len + tag_len)
(`Payload_too_short (Packet.id_len + tag_len, String.length data))
in
let replay_id, tag, payload =
let replay_id, tag_off, off =
( String.sub data 0 Packet.id_len,
String.sub data Packet.id_len tag_len,
String.sub data (Packet.id_len + tag_len)
(String.length data - Packet.id_len - tag_len) )
Packet.id_len,
Packet.id_len + tag_len )
in
let nonce = replay_id ^ their_implicit_iv in
let plain =
Mirage_crypto.AES.GCM.authenticate_decrypt_tag ~key:their_key ~nonce
~adata:replay_id ~tag payload
let plain = Bytes.create (String.length data - off) in
let valid =
Mirage_crypto.AES.GCM.authenticate_decrypt_into ~key:their_key ~nonce
~adata:replay_id data ~src_off:off ~tag_off plain ~dst_off:0
(String.length data - off)
in
(* TODO validate replay packet id and ordering *)
Log.debug (fun m ->
m "received replay packet id is %lu"
(String.get_int32_be replay_id 0));
Option.to_result ~none:(`Msg "AEAD decrypt failed") plain
if valid then Ok (Bytes.unsafe_to_string plain)
else Error (`Msg "AEAD decrypt failed")
| CHACHA20_POLY1305 { their_key; their_implicit_iv; _ } ->
let tag_len = Mirage_crypto.Chacha20.tag_size in
let* () =
guard
(String.length data >= Packet.id_len + tag_len)
(`Payload_too_short (Packet.id_len + tag_len, String.length data))
in
let replay_id, tag, payload =
let replay_id, tag_off, off =
( String.sub data 0 Packet.id_len,
String.sub data Packet.id_len tag_len,
String.sub data (Packet.id_len + tag_len)
(String.length data - Packet.id_len - tag_len) )
Packet.id_len,
Packet.id_len + tag_len )
in
let nonce = replay_id ^ their_implicit_iv in
let plain =
Mirage_crypto.Chacha20.authenticate_decrypt_tag ~key:their_key ~nonce
~adata:replay_id ~tag payload
let plain = Bytes.create (String.length data - off) in
let valid =
Mirage_crypto.Chacha20.authenticate_decrypt_into ~key:their_key ~nonce
~adata:replay_id data ~src_off:off ~tag_off plain ~dst_off:0
(String.length data - off)
in
(* TODO validate replay packet id and ordering *)
Log.debug (fun m ->
m "received replay packet id is %lu"
(String.get_int32_be replay_id 0));
Option.to_result ~none:(`Msg "AEAD decrypt failed") plain
if valid then Ok (Bytes.unsafe_to_string plain)
else Error (`Msg "AEAD decrypt failed")
in
let+ data' =
if compress then
Expand Down Expand Up @@ -1876,7 +1881,8 @@ let incoming state control_crypto buf =
if linger = "" then Ok (state, out, payloads, act_opt)
else multi linger (state, out, payloads, act_opt)
in
let r = multi (state.linger ^ buf) (state, [], [], None) in
let buf = if state.linger = "" then buf else state.linger ^ buf in
let r = multi buf (state, [], [], None) in
let+ s', out, payloads, act_opt = udp_ignore r in
Log.debug (fun m -> m "out state is %a" State.pp s');
Log.debug (fun m ->
Expand Down Expand Up @@ -2246,7 +2252,12 @@ let handle_static_client t s keys ev =
| Error `Tcp_partial ->
(* we don't need to check protocol as [`Tcp_partial] is only ever returned for tcp *)
Ok ({ t with linger }, acc)
| Ok (cs, linger) ->
| Ok (poff, plen) ->
let cs, linger =
( String.sub linger poff plen,
String.sub linger (poff + plen)
(String.length linger - poff - plen) )
in
let bad_mac computed rcv = `Bad_mac (t, computed, rcv, cs) in
let* d =
incoming_data ~add_timestamp bad_mac keys hmac_algorithm
Expand Down
17 changes: 10 additions & 7 deletions src/packet.ml
Original file line number Diff line number Diff line change
Expand Up @@ -174,18 +174,21 @@ let decode_protocol proto buf =
let* () = guard (String.length buf >= 2) `Tcp_partial in
let plen = String.get_uint16_be buf 0 in
let+ () = guard (String.length buf - 2 >= plen) `Tcp_partial in
( String.sub buf 2 plen,
String.sub buf (plen + 2) (String.length buf - plen - 2) )
| `Udp -> Ok (buf, "")
(2, plen)
| `Udp -> Ok (0, String.length buf)

let decode_key_op proto buf =
let open Result.Syntax in
let* buf, linger = decode_protocol proto buf in
let* () = guard (String.length buf >= 1) `Partial in
let opkey = String.get_uint8 buf 0 in
let* poff, plen = decode_protocol proto buf in
let* () = guard (plen >= 1) `Partial in
let opkey = String.get_uint8 buf poff in
let op, key = (opkey lsr 3, opkey land 0x07) in
let+ op = int_to_operation op in
(op, key, String.sub buf 1 (String.length buf - 1), linger)
let buf, linger =
( String.sub buf (poff + 1) (plen - 1),
String.sub buf (poff + plen) (String.length buf - poff - plen) )
in
(op, key, buf, linger)

let operation = function
| `Ack _ -> Ack
Expand Down
Loading