diff --git a/Cargo.lock b/Cargo.lock index c38745eed..c12e87bba 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -32,18 +32,18 @@ dependencies = [ [[package]] name = "aho-corasick" -version = "1.1.2" +version = "1.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2969dcb958b36655471fc61f7e416fa76033bdd4bfed0678d8fee1e2d07a1f0" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" dependencies = [ "memchr", ] [[package]] name = "anstream" -version = "0.6.12" +version = "0.6.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96b09b5178381e0874812a9b157f7fe84982617e48f71f4e3235482775e5b540" +checksum = "d96bd03f33fe50a863e394ee9718a706f988b9079b20c3784fb726e7678b62fb" dependencies = [ "anstyle", "anstyle-parse", @@ -93,6 +93,12 @@ version = "1.0.81" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0952808a6c2afd1aa8947271f3a60f1a6763c7b912d210184c5149b5cf147247" +[[package]] +name = "arc-swap" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b3d0060af21e8d11a926981cc00c6c1541aa91dd64b9f881985c3da1094425f" + [[package]] name = "argh" version = "0.1.12" @@ -112,7 +118,7 @@ dependencies = [ "argh_shared", "proc-macro2", "quote", - "syn 2.0.50", + "syn 2.0.53", ] [[package]] @@ -171,9 +177,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] name = "backtrace" -version = "0.3.69" +version = "0.3.71" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2089b7e3f35b9dd2d0ed921ead4f6d318c27680d4a5bd167b3ee120edb105837" +checksum = "26b05800d2e817c8b3b4b54abd461726265fa9789ae34330622f2db9ee696f9d" dependencies = [ "addr2line", "cc", @@ -190,12 +196,6 @@ version = "0.21.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" -[[package]] -name = "base64" -version = "0.22.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9475866fec1451be56a3c2400fd081ff546538961565ccb5b7142cbd22bc7a51" - [[package]] name = "base64ct" version = "1.6.0" @@ -229,7 +229,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn 2.0.50", + "syn 2.0.53", ] [[package]] @@ -240,9 +240,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.4.2" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed570934406eb16438a4e976b1b4500774099c13b8cb96eec99f620f05090ddf" +checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" [[package]] name = "block-buffer" @@ -255,9 +255,9 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.15.1" +version = "3.15.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c764d619ca78fccbf3069b37bd7af92577f044bb15236036662d79b6559f25b7" +checksum = "7ff69b9dd49fd426c69a0db9fc04dd934cdb6645ff000864d98f7e2af8830eaa" [[package]] name = "bytecount" @@ -267,9 +267,9 @@ checksum = "e1e5f035d16fc623ae5f74981db80a439803888314e3a555fd6f04acd51a3205" [[package]] name = "bytes" -version = "1.5.0" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" +checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9" dependencies = [ "serde", ] @@ -296,9 +296,9 @@ dependencies = [ [[package]] name = "cargo-platform" -version = "0.1.7" +version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "694c8807f2ae16faecc43dc17d74b3eb042482789fd0eb64b39a2e04e087053f" +checksum = "24b1f0365a6c6bb4020cd05806fd0d33c44d38046b8bd7f0e40814b9763cabfc" dependencies = [ "serde", ] @@ -326,10 +326,11 @@ dependencies = [ [[package]] name = "cc" -version = "1.0.86" +version = "1.0.90" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f9fa1897e4325be0d68d48df6aa1a71ac2ed4d27723887e7754192705350730" +checksum = "8cd6604a82acf3039f1144f54b8eb34e91ffba622051189e71b781822d5ee1f5" dependencies = [ + "jobserver", "libc", ] @@ -390,7 +391,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.50", + "syn 2.0.53", ] [[package]] @@ -446,9 +447,9 @@ dependencies = [ [[package]] name = "crossbeam-channel" -version = "0.5.11" +version = "0.5.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "176dc175b78f56c0f321911d9c8eb2b77a78a4860b9c19db83835fea1a46649b" +checksum = "ab3db02a9c5b5121e1e42fbdb1aeb65f5e02624cc58c43f2884c6ccac0b82f95" dependencies = [ "crossbeam-utils", ] @@ -502,7 +503,7 @@ checksum = "f46882e17999c6cc590af592290432be3bce0428cb0d5f8b6715e4dc7b383eb3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.50", + "syn 2.0.53", ] [[package]] @@ -575,7 +576,7 @@ checksum = "487585f4d0c6655fe74905e2504d8ad6908e4db67f744eb140876906c2f3175d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.50", + "syn 2.0.53", ] [[package]] @@ -628,8 +629,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ff3c058b07bdb5414da10bc8a2489715e31b0c3f4274a213c1a23831e9d94e91" dependencies = [ "ahash", - "base64 0.21.7", - "bitflags 2.4.2", + "base64", + "bitflags 2.5.0", "crc32c", "everscale-crypto", "everscale-types-proc", @@ -650,7 +651,16 @@ checksum = "323d8b61c76be2c16eb2d72d007f1542fdeb3760fdf2e2cae219fc0da3db0c09" dependencies = [ "proc-macro2", "quote", - "syn 2.0.50", + "syn 2.0.53", +] + +[[package]] +name = "exponential-backoff" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47f78d87d930eee4b5686a2ab032de499c72bd1e954b84262bb03492a0f932cd" +dependencies = [ + "rand", ] [[package]] @@ -661,9 +671,9 @@ checksum = "25cbce373ec4653f1a01a31e8a5e5ec0c622dc27ff9c4e6606eefef5cbbed4a5" [[package]] name = "fiat-crypto" -version = "0.2.6" +version = "0.2.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1676f435fc1dadde4d03e43f5d62b259e1ce5f40bd4ffb21db2b42ebe59c1382" +checksum = "c007b1ae3abe1cb6f85a16305acd418b7ca6343b953633fee2b76d8f108b830f" [[package]] name = "futures-core" @@ -679,7 +689,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.50", + "syn 2.0.53", ] [[package]] @@ -756,9 +766,9 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" [[package]] name = "hermit-abi" -version = "0.3.6" +version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd5256b483761cd23699d0da46cc6fd2ee3be420bbe6d020ae4a091e70b7e9fd" +checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" [[package]] name = "hex" @@ -778,11 +788,20 @@ version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" +[[package]] +name = "jobserver" +version = "0.1.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab46a6e9526ddef3ae7f787c06f0f2600639ba80ea3eade3d8e670a2230f51d6" +dependencies = [ + "libc", +] + [[package]] name = "js-sys" -version = "0.3.68" +version = "0.3.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "406cda4b368d531c842222cf9d2600a9a4acce8d29423695379c6868a143a9ee" +checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d" dependencies = [ "wasm-bindgen", ] @@ -807,12 +826,12 @@ checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" [[package]] name = "libloading" -version = "0.8.1" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c571b676ddfc9a8c12f1f3d3085a7b163966a8fd8098a90640953ce5f6170161" +checksum = "0c2a198fb6b0eada2a8df47933734e6d35d350665a33a3593d7164fa52c75c19" dependencies = [ "cfg-if", - "windows-sys 0.48.0", + "windows-targets 0.52.4", ] [[package]] @@ -834,9 +853,9 @@ dependencies = [ [[package]] name = "libz-sys" -version = "1.1.15" +version = "1.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "037731f5d3aaa87a5675e895b63ddff1a87624bc29f77004ea829809654e48f6" +checksum = "5e143b5e666b2695d28f6bca6497720813f699c9602dd7f5cac91008b8ada7f9" dependencies = [ "cc", "pkg-config", @@ -861,9 +880,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.20" +version = "0.4.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" +checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" [[package]] name = "lz4-sys" @@ -907,9 +926,9 @@ dependencies = [ [[package]] name = "mio" -version = "0.8.10" +version = "0.8.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f3d0b296e374a4e6f3c7b0a1f5a51d748a0d34c85e7dc48fc3fa9a87657fe09" +checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c" dependencies = [ "libc", "wasi", @@ -1067,15 +1086,15 @@ version = "3.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b8fcc794035347fb64beda2d3b462595dd2753e3f268d89c5aae77e8cf2c310" dependencies = [ - "base64 0.21.7", + "base64", "serde", ] [[package]] name = "pest" -version = "2.7.7" +version = "2.7.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "219c0dcc30b6a27553f9cc242972b67f75b60eb0db71f0b5462f38b058c41546" +checksum = "56f8023d0fb78c8e03784ea1c7f3fa36e68a723138990b8d5a47d916b651e7a8" dependencies = [ "memchr", "thiserror", @@ -1084,9 +1103,9 @@ dependencies = [ [[package]] name = "pest_derive" -version = "2.7.7" +version = "2.7.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22e1288dbd7786462961e69bfd4df7848c1e37e8b74303dbdab82c3a9cdd2809" +checksum = "b0d24f72393fd16ab6ac5738bc33cdb6a9aa73f8b902e8fe29cf4e67d7dd1026" dependencies = [ "pest", "pest_generator", @@ -1094,22 +1113,22 @@ dependencies = [ [[package]] name = "pest_generator" -version = "2.7.7" +version = "2.7.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1381c29a877c6d34b8c176e734f35d7f7f5b3adaefe940cb4d1bb7af94678e2e" +checksum = "fdc17e2a6c7d0a492f0158d7a4bd66cc17280308bbaff78d5bef566dca35ab80" dependencies = [ "pest", "pest_meta", "proc-macro2", "quote", - "syn 2.0.50", + "syn 2.0.53", ] [[package]] name = "pest_meta" -version = "2.7.7" +version = "2.7.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d0934d6907f148c22a3acbda520c7eed243ad7487a30f51f6ce52b58b7077a8a" +checksum = "934cd7631c050f4674352a6e835d5f6711ffbfb9345c2fc0107155ac495ae293" dependencies = [ "once_cell", "pest", @@ -1169,14 +1188,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a41cf62165e97c7f814d2221421dbb9afcbcdb0a88068e5ea206e19951c2cbb5" dependencies = [ "proc-macro2", - "syn 2.0.50", + "syn 2.0.53", ] [[package]] name = "proc-macro2" -version = "1.0.78" +version = "1.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2422ad645d89c99f8f3e6b88a9fdeca7fabeac836b1002371c4367c8f984aae" +checksum = "e835ff2298f5721608eb1a980ecaee1aef2c132bf95ecc026a11b7bf3c01c02e" dependencies = [ "unicode-ident", ] @@ -1187,7 +1206,7 @@ version = "0.9.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57206b407293d2bcd3af849ce869d52068623f19e1b5ff8e8778e3309439682b" dependencies = [ - "bitflags 2.4.2", + "bitflags 2.5.0", "memchr", "unicase", ] @@ -1299,7 +1318,7 @@ version = "11.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9d86a7c4638d42c44551f4791a20e687dbb4c3de1f33c43dd71e355cd429def1" dependencies = [ - "bitflags 2.4.2", + "bitflags 2.5.0", ] [[package]] @@ -1325,13 +1344,13 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.3" +version = "1.10.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b62dbe01f0b06f9d8dc7d49e05a0785f153b00b2c227856282f671e0318c9b15" +checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c" dependencies = [ "aho-corasick", "memchr", - "regex-automata 0.4.5", + "regex-automata 0.4.6", "regex-syntax 0.8.2", ] @@ -1346,9 +1365,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5bb987efffd3c6d0d8f5f89510bb458559eab11e4f869acb20bf845e016259cd" +checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea" dependencies = [ "aho-corasick", "memchr", @@ -1439,11 +1458,11 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.31" +version = "0.38.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ea3e1a662af26cd7a3ba09c0297a31af215563ecf42817c98df621387f4e949" +checksum = "65e04861e65f21776e67888bfbea442b3642beaa0138fdb1dd7a84a52dffdb89" dependencies = [ - "bitflags 2.4.2", + "bitflags 2.5.0", "errno", "libc", "linux-raw-sys", @@ -1535,7 +1554,7 @@ checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.50", + "syn 2.0.53", ] [[package]] @@ -1610,18 +1629,18 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.13.1" +version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7" +checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" [[package]] name = "socket2" -version = "0.5.5" +version = "0.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b5fac59a5cb5dd637972e5fca70daf0523c9067fcdc4842f053dae04a18f8e9" +checksum = "05ffd9c0a93b7543e062e759284fcf5f5e3b098501104bfbdde4d404db792871" dependencies = [ "libc", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] @@ -1671,9 +1690,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.50" +version = "2.0.53" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74f1bdc9872430ce9b75da68329d1c1746faf50ffac5f19e02b71e37ff881ffb" +checksum = "7383cd0e49fff4b6b90ca5670bfd3e9d6a733b3f90c686605aa7eec8c4996032" dependencies = [ "proc-macro2", "quote", @@ -1700,9 +1719,9 @@ checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417" [[package]] name = "tempfile" -version = "3.10.0" +version = "3.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a365e8cd18e44762ef95d87f284f4b5cd04107fec2ff3052bd6a3e6069669e67" +checksum = "85b77fafb263dd9d05cbeac119526425676db3784113aa9295c88498cbf8bff1" dependencies = [ "cfg-if", "fastrand", @@ -1727,7 +1746,7 @@ checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.50", + "syn 2.0.53", ] [[package]] @@ -1819,7 +1838,7 @@ dependencies = [ "proc-macro2", "quote", "rustc-hash", - "syn 2.0.50", + "syn 2.0.53", "tl-scheme", ] @@ -1861,7 +1880,7 @@ checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.50", + "syn 2.0.53", ] [[package]] @@ -1909,7 +1928,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.50", + "syn 2.0.53", ] [[package]] @@ -2039,13 +2058,15 @@ version = "0.0.1" dependencies = [ "ahash", "anyhow", + "arc-swap", "argh", - "base64 0.22.0", + "base64", "bytes", "castaway", "dashmap", "ed25519", "everscale-crypto", + "exponential-backoff", "futures-util", "hex", "moka", @@ -2104,6 +2125,7 @@ dependencies = [ "humantime", "rand", "serde", + "thiserror", "tokio", ] @@ -2160,9 +2182,9 @@ checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" [[package]] name = "uuid" -version = "1.7.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f00cc9702ca12d3c81455259621e676d0f7251cec66a21e98fe2e9a37db93b2a" +checksum = "a183cf7feeba97b4dd1c0d46788634f6221d87fa961b305bed08c851829efcc0" dependencies = [ "getrandom", ] @@ -2187,9 +2209,9 @@ checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" [[package]] name = "walkdir" -version = "2.4.0" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d71d857dc86794ca4c280d616f7da00d2dbfd8cd788846559a6813e6aa4b54ee" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" dependencies = [ "same-file", "winapi-util", @@ -2203,9 +2225,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.91" +version = "0.2.92" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1e124130aee3fb58c5bdd6b639a0509486b0338acaaae0c84a5124b0f588b7f" +checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8" dependencies = [ "cfg-if", "wasm-bindgen-macro", @@ -2213,24 +2235,24 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.91" +version = "0.2.92" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c9e7e1900c352b609c8488ad12639a311045f40a35491fb69ba8c12f758af70b" +checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da" dependencies = [ "bumpalo", "log", "once_cell", "proc-macro2", "quote", - "syn 2.0.50", + "syn 2.0.53", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-macro" -version = "0.2.91" +version = "0.2.92" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b30af9e2d358182b5c7449424f017eba305ed32a7010509ede96cdc4696c46ed" +checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -2238,28 +2260,28 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.91" +version = "0.2.92" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "642f325be6301eb8107a83d12a8ac6c1e1c54345a7ef1a9261962dfefda09e66" +checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.50", + "syn 2.0.53", "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.91" +version = "0.2.92" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f186bd2dcf04330886ce82d6f33dd75a7bfcf69ecf5763b89fcde53b6ac9838" +checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" [[package]] name = "web-sys" -version = "0.3.68" +version = "0.3.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96565907687f7aceb35bc5fc03770a8a0471d82e479f25832f54a0e3f4b28446" +checksum = "77afa9a11836342370f4817622a2f0f418b134426d91a82dfb48f532d2ec13ef" dependencies = [ "js-sys", "wasm-bindgen", @@ -2323,7 +2345,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-targets 0.52.0", + "windows-targets 0.52.4", ] [[package]] @@ -2343,17 +2365,17 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.52.0" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a18201040b24831fbb9e4eb208f8892e1f50a37feb53cc7ff887feb8f50e7cd" +checksum = "7dd37b7e5ab9018759f893a1952c9420d060016fc19a472b4bb20d1bdd694d1b" dependencies = [ - "windows_aarch64_gnullvm 0.52.0", - "windows_aarch64_msvc 0.52.0", - "windows_i686_gnu 0.52.0", - "windows_i686_msvc 0.52.0", - "windows_x86_64_gnu 0.52.0", - "windows_x86_64_gnullvm 0.52.0", - "windows_x86_64_msvc 0.52.0", + "windows_aarch64_gnullvm 0.52.4", + "windows_aarch64_msvc 0.52.4", + "windows_i686_gnu 0.52.4", + "windows_i686_msvc 0.52.4", + "windows_x86_64_gnu 0.52.4", + "windows_x86_64_gnullvm 0.52.4", + "windows_x86_64_msvc 0.52.4", ] [[package]] @@ -2364,9 +2386,9 @@ checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" [[package]] name = "windows_aarch64_gnullvm" -version = "0.52.0" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea" +checksum = "bcf46cf4c365c6f2d1cc93ce535f2c8b244591df96ceee75d8e83deb70a9cac9" [[package]] name = "windows_aarch64_msvc" @@ -2376,9 +2398,9 @@ checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" [[package]] name = "windows_aarch64_msvc" -version = "0.52.0" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef" +checksum = "da9f259dd3bcf6990b55bffd094c4f7235817ba4ceebde8e6d11cd0c5633b675" [[package]] name = "windows_i686_gnu" @@ -2388,9 +2410,9 @@ checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" [[package]] name = "windows_i686_gnu" -version = "0.52.0" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313" +checksum = "b474d8268f99e0995f25b9f095bc7434632601028cf86590aea5c8a5cb7801d3" [[package]] name = "windows_i686_msvc" @@ -2400,9 +2422,9 @@ checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" [[package]] name = "windows_i686_msvc" -version = "0.52.0" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a" +checksum = "1515e9a29e5bed743cb4415a9ecf5dfca648ce85ee42e15873c3cd8610ff8e02" [[package]] name = "windows_x86_64_gnu" @@ -2412,9 +2434,9 @@ checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" [[package]] name = "windows_x86_64_gnu" -version = "0.52.0" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd" +checksum = "5eee091590e89cc02ad514ffe3ead9eb6b660aedca2183455434b93546371a03" [[package]] name = "windows_x86_64_gnullvm" @@ -2424,9 +2446,9 @@ checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" [[package]] name = "windows_x86_64_gnullvm" -version = "0.52.0" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e" +checksum = "77ca79f2451b49fa9e2af39f0747fe999fcda4f5e241b2898624dca97a1f2177" [[package]] name = "windows_x86_64_msvc" @@ -2436,9 +2458,9 @@ checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "windows_x86_64_msvc" -version = "0.52.0" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" +checksum = "32b752e52a2da0ddfbdbcc6fceadfeede4c939ed16d13e648833a61dfb611ed8" [[package]] name = "x509-parser" @@ -2483,7 +2505,7 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.50", + "syn 2.0.53", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 188941283..50ee60439 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -72,7 +72,6 @@ ref_option_ref = "warn" rest_pat_in_fully_bound_structs = "warn" same_functions_in_if_condition = "warn" semicolon_if_nothing_returned = "warn" -single_match_else = "warn" string_add_assign = "warn" string_add = "warn" string_lit_as_bytes = "warn" diff --git a/network/Cargo.toml b/network/Cargo.toml index 4c7cb72c0..e72ae9221 100644 --- a/network/Cargo.toml +++ b/network/Cargo.toml @@ -13,12 +13,14 @@ path = "examples/network_node.rs" # crates.io deps ahash = "0.8" anyhow = "1.0" -base64 = "0.22" +arc-swap = "1.6" +base64 = "0.21" bytes = { version = "1.0", features = ["serde"] } castaway = "0.2" dashmap = "5.4" ed25519 = { version = "2.0", features = ["alloc", "pkcs8"] } everscale-crypto = { version = "0.2", features = ["tl-proto"] } +exponential-backoff = "1" futures-util = { version = "0.3", features = ["sink"] } hex = "0.4" moka = { version = "0.12", features = ["sync"] } diff --git a/network/examples/network_node.rs b/network/examples/network_node.rs index aa2a3523f..ee6c6c51c 100644 --- a/network/examples/network_node.rs +++ b/network/examples/network_node.rs @@ -224,11 +224,11 @@ impl Node { fn new(key: ed25519::SecretKey, address: Address, config: NodeConfig) -> Result { let keypair = everscale_crypto::ed25519::KeyPair::from(&key); - let (dht_client, dht) = DhtService::builder(keypair.public_key.into()) + let (dht_tasks, dht_service) = DhtService::builder(keypair.public_key.into()) .with_config(config.dht) .build(); - let router = Router::builder().route(dht).build(); + let router = Router::builder().route(dht_service.clone()).build(); let network = Network::builder() .with_config(config.network) @@ -236,7 +236,8 @@ impl Node { .with_service_name("test-service") .build(address, router)?; - let dht = dht_client.build(network.clone()); + dht_tasks.spawn(&network); + let dht = dht_service.make_client(network.clone()); Ok(Self { network, dht }) } diff --git a/network/src/dht/mod.rs b/network/src/dht/mod.rs index ceba799bb..56d2396cf 100644 --- a/network/src/dht/mod.rs +++ b/network/src/dht/mod.rs @@ -12,51 +12,27 @@ use tokio::task::JoinHandle; use tycho_util::realloc_box_enum; use tycho_util::time::{now_sec, shifted_interval}; -use self::query::{Query, StoreValue}; -use self::routing::{RoutingTable, RoutingTableSource}; +use self::query::{Query, QueryCache, StoreValue}; +use self::routing::HandlesRoutingTable; use self::storage::Storage; use crate::network::{Network, WeakNetwork}; use crate::proto::dht::{ rpc, NodeInfoResponse, NodeResponse, PeerValue, PeerValueKey, PeerValueKeyName, PeerValueKeyRef, PeerValueRef, Value, ValueRef, ValueResponseRaw, }; -use crate::types::{ - Address, PeerAffinity, PeerId, PeerInfo, Request, Response, Service, ServiceRequest, -}; +use crate::types::{PeerId, PeerInfo, Request, Response, Service, ServiceRequest}; use crate::util::{NetworkExt, Routable}; pub use self::config::DhtConfig; +pub use self::peer_resolver::{PeerResolver, PeerResolverBuilder, PeerResolverHandle}; pub use self::storage::{OverlayValueMerger, StorageError}; mod config; +mod peer_resolver; mod query; mod routing; mod storage; -pub struct DhtClientBuilder { - inner: Arc, - disable_background_tasks: bool, -} - -impl DhtClientBuilder { - pub fn disable_background_tasks(mut self) -> Self { - self.disable_background_tasks = true; - self - } - - pub fn build(self, network: Network) -> DhtClient { - if !self.disable_background_tasks { - self.inner - .start_background_tasks(Network::downgrade(&network)); - } - - DhtClient { - inner: self.inner, - network, - } - } -} - #[derive(Clone)] pub struct DhtClient { inner: Arc, @@ -64,13 +40,18 @@ pub struct DhtClient { } impl DhtClient { + #[inline] pub fn network(&self) -> &Network { &self.network } + #[inline] + pub fn service(&self) -> &DhtService { + DhtService::wrap(&self.inner) + } + pub fn add_peer(&self, peer: Arc) -> Result { - self.inner - .add_peer_info(&self.network, peer, RoutingTableSource::Trusted) + self.inner.add_peer_info(&self.network, peer) } pub async fn get_node_info(&self, peer_id: &PeerId) -> Result { @@ -243,6 +224,17 @@ impl<'a> std::ops::DerefMut for DhtQueryWithDataBuilder<'a> { } } +pub struct DhtServiceBackgroundTasks { + inner: Arc, +} + +impl DhtServiceBackgroundTasks { + pub fn spawn(self, network: &Network) { + self.inner + .start_background_tasks(Network::downgrade(network)); + } +} + pub struct DhtServiceBuilder { local_id: PeerId, config: Option, @@ -260,7 +252,7 @@ impl DhtServiceBuilder { self } - pub fn build(self) -> (DhtClientBuilder, DhtService) { + pub fn build(self) -> (DhtServiceBackgroundTasks, DhtService) { let config = self.config.unwrap_or_default(); let storage = { @@ -283,26 +275,33 @@ impl DhtServiceBuilder { let inner = Arc::new(DhtInner { local_id: self.local_id, - routing_table: Mutex::new(RoutingTable::new(self.local_id)), + routing_table: Mutex::new(HandlesRoutingTable::new(self.local_id)), storage, local_peer_info: Mutex::new(None), config, announced_peers, + find_value_queries: Default::default(), }); - let client_builder = DhtClientBuilder { + let background_tasks = DhtServiceBackgroundTasks { inner: inner.clone(), - disable_background_tasks: false, }; - (client_builder, DhtService(inner)) + (background_tasks, DhtService(inner)) } } #[derive(Clone)] +#[repr(transparent)] pub struct DhtService(Arc); impl DhtService { + #[inline] + fn wrap(inner: &Arc) -> &Self { + // SAFETY: `DhtService` has the same memory layout as `Arc`. + unsafe { &*(inner as *const Arc).cast::() } + } + pub fn builder(local_id: PeerId) -> DhtServiceBuilder { DhtServiceBuilder { local_id, @@ -310,6 +309,17 @@ impl DhtService { overlay_merger: None, } } + + pub fn make_client(&self, network: Network) -> DhtClient { + DhtClient { + inner: self.0.clone(), + network, + } + } + + pub fn make_peer_resolver(&self) -> PeerResolverBuilder { + PeerResolver::builder(self.clone()) + } } impl Service for DhtService { @@ -415,11 +425,12 @@ impl Routable for DhtService { struct DhtInner { local_id: PeerId, - routing_table: Mutex, + routing_table: Mutex, storage: Storage, local_peer_info: Mutex>, config: DhtConfig, announced_peers: broadcast::Sender>, + find_value_queries: QueryCache>>, } impl DhtInner { @@ -469,8 +480,7 @@ impl DhtInner { this.refresh_local_peer_info(&network); } Action::AnnounceLocalPeerInfo => { - // Always refresh peer info before announcing - this.refresh_local_peer_info(&network); + // Peer info is always refreshed before announcing refresh_peer_info_interval.reset(); if let Err(e) = this.announce_local_peer_info(&network).await { @@ -492,9 +502,7 @@ impl DhtInner { } Action::AddPeer(peer_info) => { tracing::info!(peer_id = %peer_info.id, "received peer info"); - if let Err(e) = - this.add_peer_info(&network, peer_info, RoutingTableSource::Untrusted) - { + if let Err(e) = this.add_peer_info(&network, peer_info) { tracing::error!("failed to add peer to the routing table: {e:?}"); } } @@ -511,12 +519,18 @@ impl DhtInner { #[tracing::instrument(level = "debug", skip_all, fields(local_id = %self.local_id))] async fn announce_local_peer_info(&self, network: &Network) -> Result<()> { - let data = tl_proto::serialize(&[network.local_addr().into()] as &[Address]); + let now = now_sec(); + let data = { + let peer_info = self.make_local_peer_info(network, now); + let data = tl_proto::serialize(&peer_info); + *self.local_peer_info.lock().unwrap() = Some(peer_info); + data + }; let mut value = self.make_unsigned_peer_value( PeerValueKeyName::NodeInfo, &data, - now_sec() + self.config.max_peer_info_ttl.as_secs() as u32, + now + self.config.max_peer_info_ttl.as_secs() as u32, ); let signature = network.sign_tl(&value); value.signature = &signature; @@ -603,28 +617,28 @@ impl DhtInner { peer.clone(), self.config.max_k, &self.config.max_peer_info_ttl, - RoutingTableSource::Trusted, + |peer_info| network.known_peers().insert(peer_info, false).ok(), ); - if is_new { - network.known_peers().insert(peer, PeerAffinity::Allowed); - count += 1; - } + count += is_new as usize; } tracing::debug!(count, "found new peers"); } async fn find_value(&self, network: &Network, key_hash: &[u8; 32]) -> Option> { - // TODO: deduplicate shared futures - let query = Query::new( - network.clone(), - &self.routing_table.lock().unwrap(), - key_hash, - self.config.max_k, - ); + self.find_value_queries + .run(key_hash, || { + let query = Query::new( + network.clone(), + &self.routing_table.lock().unwrap(), + key_hash, + self.config.max_k, + ); - // NOTE: expression is intentionally split to drop the routing table guard - query.find_value().await + // NOTE: expression is intentionally split to drop the routing table guard + Box::pin(query.find_value()) + }) + .await } async fn store_value( @@ -659,12 +673,7 @@ impl DhtInner { Ok(()) } - fn add_peer_info( - &self, - network: &Network, - peer_info: Arc, - source: RoutingTableSource, - ) -> Result { + fn add_peer_info(&self, network: &Network, peer_info: Arc) -> Result { anyhow::ensure!(peer_info.is_valid(now_sec()), "invalid peer info"); if peer_info.id == self.local_id { @@ -672,18 +681,12 @@ impl DhtInner { } let mut routing_table = self.routing_table.lock().unwrap(); - let is_new = routing_table.add( + Ok(routing_table.add( peer_info.clone(), self.config.max_k, &self.config.max_peer_info_ttl, - source, - ); - if is_new { - network - .known_peers() - .insert(peer_info, PeerAffinity::Allowed); - } - Ok(is_new) + |peer_info| network.known_peers().insert(peer_info, false).ok(), + )) } fn make_unsigned_peer_value<'a>( diff --git a/network/src/dht/peer_resolver.rs b/network/src/dht/peer_resolver.rs new file mode 100644 index 000000000..d7d927dc5 --- /dev/null +++ b/network/src/dht/peer_resolver.rs @@ -0,0 +1,462 @@ +use std::mem::ManuallyDrop; +use std::sync::atomic::{AtomicU32, Ordering}; +use std::sync::{Arc, Mutex, Weak}; +use std::time::Duration; + +use exponential_backoff::Backoff; +use tokio::sync::{Notify, Semaphore}; +use tycho_util::futures::JoinTask; +use tycho_util::time::now_sec; +use tycho_util::FastDashMap; + +use crate::dht::DhtService; +use crate::network::{KnownPeerHandle, KnownPeersError, Network, PeerBannedError, WeakNetwork}; +use crate::proto::dht; +use crate::types::{PeerId, PeerInfo}; + +pub struct PeerResolverBuilder { + inner: PeerResolverConfig, + dht_service: DhtService, +} + +impl PeerResolverBuilder { + /// Minimal time-to-live for the resolved peer info. + /// + /// Default: 600 seconds. + pub fn with_min_ttl_sec(mut self, ttl_sec: u32) -> Self { + self.inner.min_ttl_sec = ttl_sec; + self + } + + /// Time before the expiration when the peer info should be updated. + /// + /// Default: 1200 seconds. + pub fn with_update_before_sec(mut self, update_before_sec: u32) -> Self { + self.inner.update_before_sec = update_before_sec; + self + } + + /// Number of fast retries before switching to the stale retry interval. + /// + /// Default: 10. + pub fn with_fast_retry_count(mut self, fast_retry_count: u32) -> Self { + self.inner.fast_retry_count = fast_retry_count; + self + } + + /// Minimal interval between the fast retries. + /// + /// Default: 1 second. + pub fn with_min_retry_interval(mut self, min_retry_interval: Duration) -> Self { + self.inner.min_retry_interval = min_retry_interval; + self + } + + /// Maximal interval between the fast retries. + /// + /// Default: 120 seconds. + pub fn with_max_retry_interval(mut self, max_retry_interval: Duration) -> Self { + self.inner.max_retry_interval = max_retry_interval; + self + } + + /// Interval between the stale retries. + /// + /// Default: 600 seconds. + pub fn with_stale_retry_interval(mut self, stale_retry_interval: Duration) -> Self { + self.inner.stale_retry_interval = stale_retry_interval; + self + } + + pub fn build(self, network: &Network) -> PeerResolver { + let semaphore = Semaphore::new(self.inner.max_parallel_resolve_requests); + + PeerResolver { + inner: Arc::new(PeerResolverInner { + weak_network: Network::downgrade(network), + dht_service: self.dht_service, + config: Default::default(), + tasks: Default::default(), + semaphore, + }), + } + } +} + +struct PeerResolverConfig { + max_parallel_resolve_requests: usize, + min_ttl_sec: u32, + update_before_sec: u32, + fast_retry_count: u32, + min_retry_interval: Duration, + max_retry_interval: Duration, + stale_retry_interval: Duration, +} + +impl Default for PeerResolverConfig { + fn default() -> Self { + Self { + max_parallel_resolve_requests: 100, + min_ttl_sec: 600, + update_before_sec: 1200, + fast_retry_count: 10, + min_retry_interval: Duration::from_secs(1), + max_retry_interval: Duration::from_secs(120), + stale_retry_interval: Duration::from_secs(600), + } + } +} + +#[derive(Clone)] +pub struct PeerResolver { + inner: Arc, +} + +impl PeerResolver { + pub(crate) fn builder(dht_service: DhtService) -> PeerResolverBuilder { + PeerResolverBuilder { + inner: Default::default(), + dht_service, + } + } + + // TODO: Use affinity flag to increase the handle affinity. + pub fn insert(&self, peer_id: &PeerId, _with_affinity: bool) -> PeerResolverHandle { + use dashmap::mapref::entry::Entry; + + match self.inner.tasks.entry(*peer_id) { + Entry::Vacant(entry) => { + let handle = self.inner.make_resolver_handle(peer_id); + entry.insert(Arc::downgrade(&handle.inner)); + handle + } + Entry::Occupied(mut entry) => match entry.get().upgrade() { + Some(inner) => PeerResolverHandle { + inner: ManuallyDrop::new(inner), + }, + None => { + let handle = self.inner.make_resolver_handle(peer_id); + entry.insert(Arc::downgrade(&handle.inner)); + handle + } + }, + } + } +} + +struct PeerResolverInner { + weak_network: WeakNetwork, + dht_service: DhtService, + config: PeerResolverConfig, + tasks: FastDashMap>, + semaphore: Semaphore, +} + +impl PeerResolverInner { + fn make_resolver_handle(self: &Arc, peer_id: &PeerId) -> PeerResolverHandle { + let handle = match self.weak_network.upgrade() { + Some(handle) => handle.known_peers().make_handle(peer_id, false), + None => { + return PeerResolverHandle::new_noop(peer_id); + } + }; + let next_update_at = handle + .as_ref() + .map(|handle| self.compute_update_at(&handle.peer_info())); + + let data = Arc::new(PeerResolverHandleData::new(peer_id, handle)); + + PeerResolverHandle::new( + JoinTask::new(self.clone().run_task(data.clone(), next_update_at)), + data, + self, + ) + } + + async fn run_task( + self: Arc, + data: Arc, + mut next_update_at: Option, + ) { + tracing::trace!(peer_id = %data.peer_id, "peer resolver task started"); + + // TODO: Select between the loop body and `KnownPeers` update event. + loop { + // Wait if needed. + if let Some(update_at) = next_update_at { + let update_at = std::time::UNIX_EPOCH + Duration::from_secs(update_at as u64); + let now = std::time::SystemTime::now(); + if let Ok(remaining) = update_at.duration_since(now) { + tokio::time::sleep(remaining).await; + } + } + + // Start resolving peer. + match self.resolve_peer(&data).await { + Some((network, peer_info)) => { + let mut handle = data.handle.lock().unwrap(); + + let peer_info_guard; + let peer_info = match &*handle { + // TODO: Force write into known peers to keep the handle in it? + Some(handle) => match handle.update_peer_info(&peer_info) { + Ok(()) => peer_info.as_ref(), + Err(KnownPeersError::OutdatedInfo) => { + peer_info_guard = handle.peer_info(); + peer_info_guard.as_ref() + } + // TODO: Allow resuming task after ban? + Err(KnownPeersError::PeerBanned(PeerBannedError)) => break, + }, + None => match network + .known_peers() + .insert_allow_outdated(peer_info, false) + { + Ok(new_handle) => { + peer_info_guard = handle.insert(new_handle).peer_info(); + data.mark_resolved(); + peer_info_guard.as_ref() + } + // TODO: Allow resuming task after ban? + Err(PeerBannedError) => break, + }, + }; + + next_update_at = Some(self.compute_update_at(peer_info)); + } + None => break, + } + } + + tracing::trace!(peer_id = %data.peer_id, "peer resolver task finished"); + } + + /// Returns a verified peer info with the strong reference to the network. + /// Or `None` if network no longer exists. + async fn resolve_peer( + &self, + data: &PeerResolverHandleData, + ) -> Option<(Network, Arc)> { + struct Iter<'a> { + backoff: Option>, + data: &'a PeerResolverHandleData, + stale_retry_interval: &'a Duration, + } + + impl Iterator for Iter<'_> { + type Item = Duration; + + fn next(&mut self) -> Option { + Some(loop { + match self.backoff.as_mut() { + // Get next duration from the backoff iterator. + Some(backoff) => match backoff.next() { + // Use it for the first attempts. + Some(duration) => break duration, + // Set `is_stale` flag on last attempt and continue wih only + // the `stale_retry_interval` for all subsequent iterations. + None => { + self.data.set_stale(true); + self.backoff = None; + } + }, + // Use `stale_retry_interval` after the max retry count is reached. + None => break *self.stale_retry_interval, + } + }) + } + } + + let backoff = Backoff::new( + self.config.fast_retry_count, + self.config.min_retry_interval, + Some(self.config.max_retry_interval), + ); + let mut iter = Iter { + backoff: Some(backoff.iter()), + data, + stale_retry_interval: &self.config.stale_retry_interval, + }; + + // "Fast" path + let mut attempts = 0usize; + loop { + attempts += 1; + let is_stale = attempts > self.config.fast_retry_count as usize; + + // NOTE: Acquire network ref only during the operation. + { + let network = self.weak_network.upgrade()?; + let dht_client = self.dht_service.make_client(network.clone()); + + let res = { + let _permit = self.semaphore.acquire().await.unwrap(); + dht_client + .entry(dht::PeerValueKeyName::NodeInfo) + .find_value::(&data.peer_id) + .await + }; + + let now = now_sec(); + match res { + // TODO: Should we move signature check into the `spawn_blocking`? + Ok(peer_info) if peer_info.id == data.peer_id && peer_info.is_valid(now) => { + return Some((network, Arc::new(peer_info))); + } + Ok(_) => { + tracing::trace!( + peer_id = %data.peer_id, + attempts, + is_stale, + "received an invalid peer info", + ); + } + Err(e) => { + tracing::trace!( + peer_id = %data.peer_id, + attempts, + is_stale, + "failed to resolve a peer info: {e:?}", + ); + } + } + } + + let interval = iter.next().expect("retries iterator must be infinite"); + tokio::time::sleep(interval).await; + } + } + + fn compute_update_at(&self, peer_info: &PeerInfo) -> u32 { + let real_ttl = peer_info + .expires_at + .saturating_sub(self.config.update_before_sec) + .saturating_sub(peer_info.created_at); + + let adjusted_ttl = std::cmp::max(real_ttl, self.config.min_ttl_sec); + peer_info.created_at.saturating_add(adjusted_ttl) + } +} + +#[derive(Clone)] +#[repr(transparent)] +pub struct PeerResolverHandle { + inner: ManuallyDrop>, +} + +impl PeerResolverHandle { + fn new( + task: JoinTask<()>, + data: Arc, + resolver: &Arc, + ) -> Self { + Self { + inner: ManuallyDrop::new(Arc::new(PeerResolverHandleInner { + _task: Some(task), + data, + resolver: Arc::downgrade(resolver), + })), + } + } + + pub fn new_noop(peer_id: &PeerId) -> Self { + Self { + inner: ManuallyDrop::new(Arc::new(PeerResolverHandleInner { + _task: None, + data: Arc::new(PeerResolverHandleData::new(peer_id, None)), + resolver: Weak::new(), + })), + } + } + + pub fn load_handle(&self) -> Option { + self.inner.data.handle.lock().unwrap().clone() + } + + pub fn is_stale(&self) -> bool { + self.inner.data.is_stale() + } + + pub fn is_resolved(&self) -> bool { + self.inner.data.is_resolved() + } + + pub async fn wait_resolved(&self) -> KnownPeerHandle { + loop { + let resolved = self.inner.data.notify_resolved.notified(); + if let Some(load_handle) = self.load_handle() { + break load_handle; + } + resolved.await; + } + } +} + +impl Drop for PeerResolverHandle { + fn drop(&mut self) { + // SAFETY: inner value is dropped only once + let inner = unsafe { ManuallyDrop::take(&mut self.inner) }; + + // Remove this entry from the resolver if it was the last strong reference. + if let Some(inner) = Arc::into_inner(inner) { + // NOTE: At this point an `Arc` was dropped, so the `Weak` in the resolver + // addresses only the remaining references. + + if let Some(resolver) = inner.resolver.upgrade() { + resolver + .tasks + .remove_if(&inner.data.peer_id, |_, value| value.strong_count() == 0); + } + } + } +} + +struct PeerResolverHandleInner { + _task: Option>, + data: Arc, + resolver: Weak, +} + +struct PeerResolverHandleData { + peer_id: PeerId, + handle: Mutex>, + flags: AtomicU32, + notify_resolved: Notify, +} + +impl PeerResolverHandleData { + fn new(peer_id: &PeerId, handle: Option) -> Self { + let flags = AtomicU32::new(if handle.is_some() { RESOLVED_FLAG } else { 0 }); + + Self { + peer_id: *peer_id, + handle: Mutex::new(handle), + flags, + notify_resolved: Notify::new(), + } + } + + fn mark_resolved(&self) { + self.flags.fetch_or(RESOLVED_FLAG, Ordering::Release); + self.notify_resolved.notify_waiters(); + } + + fn is_resolved(&self) -> bool { + self.flags.load(Ordering::Acquire) & RESOLVED_FLAG != 0 + } + + fn set_stale(&self, stale: bool) { + if stale { + self.flags.fetch_or(STALE_FLAG, Ordering::Release); + } else { + self.flags.fetch_and(!STALE_FLAG, Ordering::Release); + } + } + + fn is_stale(&self) -> bool { + self.flags.load(Ordering::Acquire) & STALE_FLAG != 0 + } +} + +const STALE_FLAG: u32 = 0b1; +const RESOLVED_FLAG: u32 = 0b10; diff --git a/network/src/dht/query.rs b/network/src/dht/query.rs index 5355ba6d1..c5c01bf00 100644 --- a/network/src/dht/query.rs +++ b/network/src/dht/query.rs @@ -8,36 +8,122 @@ use bytes::Bytes; use futures_util::stream::FuturesUnordered; use futures_util::{Future, StreamExt}; use tokio::sync::Semaphore; +use tycho_util::futures::{JoinTask, Shared, WeakShared}; use tycho_util::time::now_sec; -use tycho_util::{FastHashMap, FastHashSet}; +use tycho_util::{FastDashMap, FastHashMap, FastHashSet}; -use crate::dht::routing::{RoutingTable, RoutingTableSource}; +use crate::dht::routing::{HandlesRoutingTable, SimpleRoutingTable}; use crate::network::Network; use crate::proto::dht::{rpc, NodeResponse, Value, ValueRef, ValueResponse}; use crate::types::{PeerId, PeerInfo, Request}; use crate::util::NetworkExt; +pub struct QueryCache { + cache: FastDashMap<[u8; 32], WeakSpawnedFut>, +} + +impl QueryCache { + pub async fn run(&self, target_id: &[u8; 32], f: F) -> R + where + R: Clone + Send + 'static, + F: FnOnce() -> Fut, + Fut: Future + Send + 'static, + { + use dashmap::mapref::entry::Entry; + + let fut = match self.cache.entry(*target_id) { + Entry::Vacant(entry) => { + let fut = Shared::new(JoinTask::new(f())); + if let Some(weak) = fut.downgrade() { + entry.insert(weak); + } + fut + } + Entry::Occupied(mut entry) => { + if let Some(fut) = entry.get().upgrade() { + fut + } else { + let fut = Shared::new(JoinTask::new(f())); + match fut.downgrade() { + Some(weak) => entry.insert(weak), + None => entry.remove(), + }; + fut + } + } + }; + + fn on_drop(_key: &[u8; 32], value: &WeakSpawnedFut) -> bool { + value.strong_count() == 0 + } + + let (output, is_last) = { + struct Guard<'a, R> { + target_id: &'a [u8; 32], + cache: &'a FastDashMap<[u8; 32], WeakSpawnedFut>, + fut: Option>>, + } + + impl Drop for Guard<'_, R> { + fn drop(&mut self) { + // Remove value from cache if we consumed the last future instance + if self.fut.take().map(Shared::consume).unwrap_or_default() { + self.cache.remove_if(self.target_id, on_drop); + } + } + } + + // Wrap future into guard to remove it from cache event it was cancelled + let mut guard = Guard { + target_id, + cache: &self.cache, + fut: None, + }; + let fut = guard.fut.insert(fut); + + // Await future. + // If `Shared` future is not polled to `Complete` state, + // the guard will try to consume it and remove from cache + // if it was the last instance. + fut.await + }; + + // TODO: add ttl and force others to make a request for a fresh data + if is_last { + // Remove value from cache if we consumed the last future instance + self.cache.remove_if(target_id, on_drop); + } + + output + } +} + +impl Default for QueryCache { + fn default() -> Self { + Self { + cache: Default::default(), + } + } +} + +type WeakSpawnedFut = WeakShared>; + pub struct Query { network: Network, - candidates: RoutingTable, + candidates: SimpleRoutingTable, max_k: usize, } impl Query { pub fn new( network: Network, - routing_table: &RoutingTable, + routing_table: &HandlesRoutingTable, target_id: &[u8; 32], max_k: usize, ) -> Self { - let mut candidates = RoutingTable::new(PeerId(*target_id)); + let mut candidates = SimpleRoutingTable::new(PeerId(*target_id)); routing_table.visit_closest(target_id, max_k, |node| { - candidates.add( - node.clone(), - max_k, - &Duration::MAX, - RoutingTableSource::Trusted, - ); + candidates.add(node.load_peer_info(), max_k, &Duration::MAX, Some); }); Self { @@ -217,8 +303,7 @@ impl Query { // Insert a new entry if visited.insert(node.id) { - self.candidates - .add(node, max_k, &Duration::MAX, RoutingTableSource::Trusted); + self.candidates.add(node, max_k, &Duration::MAX, Some); has_new = true; } } @@ -244,8 +329,7 @@ impl Query { // Insert a new entry hash_map::Entry::Vacant(entry) => { let node = entry.insert(node).clone(); - self.candidates - .add(node, max_k, &Duration::MAX, RoutingTableSource::Trusted); + self.candidates.add(node, max_k, &Duration::MAX, Some); has_new = true; } // Try to replace an old entry @@ -299,7 +383,7 @@ pub struct StoreValue { impl StoreValue<()> { pub fn new( network: Network, - routing_table: &RoutingTable, + routing_table: &HandlesRoutingTable, value: ValueRef<'_>, max_k: usize, local_peer_info: Option<&PeerInfo>, @@ -321,7 +405,7 @@ impl StoreValue<()> { routing_table.visit_closest(&key_hash, max_k, |node| { futures.push(Self::visit( network.clone(), - node.clone(), + node.load_peer_info(), request_body.clone(), semaphore.clone(), )); diff --git a/network/src/dht/routing.rs b/network/src/dht/routing.rs index 9278b6db5..d78c66b06 100644 --- a/network/src/dht/routing.rs +++ b/network/src/dht/routing.rs @@ -5,20 +5,18 @@ use std::time::{Duration, Instant}; use tycho_util::time::now_sec; use crate::dht::{xor_distance, MAX_XOR_DISTANCE}; +use crate::network::KnownPeerHandle; use crate::types::{PeerId, PeerInfo}; -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub(crate) enum RoutingTableSource { - Untrusted, - Trusted, -} +pub(crate) type SimpleRoutingTable = RoutingTable>; +pub(crate) type HandlesRoutingTable = RoutingTable; -pub(crate) struct RoutingTable { +pub(crate) struct RoutingTable { pub local_id: PeerId, - pub buckets: BTreeMap, + pub buckets: BTreeMap>, } -impl RoutingTable { +impl RoutingTable { pub fn new(local_id: PeerId) -> Self { Self { local_id, @@ -35,14 +33,13 @@ impl RoutingTable { pub fn len(&self) -> usize { self.buckets.values().map(|bucket| bucket.nodes.len()).sum() } +} - pub fn add( - &mut self, - peer: Arc, - max_k: usize, - node_ttl: &Duration, - source: RoutingTableSource, - ) -> bool { +impl RoutingTable { + pub fn add(&mut self, peer: Arc, max_k: usize, node_ttl: &Duration, f: F) -> bool + where + F: FnOnce(Arc) -> Option, + { let distance = xor_distance(&self.local_id, &peer.id); if distance == 0 { return false; @@ -51,7 +48,7 @@ impl RoutingTable { self.buckets .entry(distance) .or_insert_with(|| Bucket::with_capacity(max_k)) - .insert(peer, max_k, node_ttl, source) + .insert(peer, max_k, node_ttl, f) } pub fn closest(&self, key: &[u8; 32], count: usize) -> Vec> { @@ -72,7 +69,7 @@ impl RoutingTable { if let Some(bucket) = self.buckets.get(&i) { for node in bucket.nodes.iter().take(remaining) { - result.push(node.data.clone()); + result.push(node.data.load_peer_info()); } } } @@ -82,7 +79,7 @@ impl RoutingTable { pub fn visit_closest(&self, key: &[u8; 32], count: usize, mut f: F) where - F: FnMut(&Arc), + F: FnMut(&T), { if count == 0 { return; @@ -109,53 +106,20 @@ impl RoutingTable { } } -pub(crate) struct Bucket { - nodes: VecDeque, +pub(crate) struct Bucket { + nodes: VecDeque>, } -impl Bucket { +impl Bucket { fn with_capacity(capacity: usize) -> Self { Self { nodes: VecDeque::with_capacity(capacity), } } - fn insert( - &mut self, - node: Arc, - max_k: usize, - timeout: &Duration, - source: RoutingTableSource, - ) -> bool { - if let Some(index) = self - .nodes - .iter_mut() - .position(|item| item.data.id == node.id) - { - if source == RoutingTableSource::Untrusted { - let slot = &mut self.nodes[index]; - // Do nothing if node info was not updated (by created_at field) - if node.created_at <= slot.data.created_at { - return false; - } - } - - self.nodes.remove(index); - } else if self.nodes.len() >= max_k { - if matches!(self.nodes.front(), Some(node) if node.is_expired(now_sec(), timeout)) { - self.nodes.pop_front(); - } else { - return false; - } - } - - self.nodes.push_back(Node::new(node)); - true - } - pub fn retain_nodes(&mut self, f: F) where - F: FnMut(&Node) -> bool, + F: FnMut(&Node) -> bool, { self.nodes.retain(f); } @@ -165,21 +129,107 @@ impl Bucket { } } -pub(crate) struct Node { - pub data: Arc, +impl Bucket { + fn insert( + &mut self, + peer_info: Arc, + max_k: usize, + timeout: &Duration, + f: F, + ) -> bool + where + F: FnOnce(Arc) -> Option, + { + let data = 'data: { + if let Some(index) = self + .nodes + .iter_mut() + .position(|item| item.data.as_peer_info().id == peer_info.id) + { + if let Some(data) = f(peer_info) { + // Found node info with the same id, update it + self.nodes.remove(index); + break 'data data; + } + } else if self.nodes.len() >= max_k { + if matches!(self.nodes.front(), Some(node) if node.is_expired(now_sec(), timeout)) { + if let Some(data) = f(peer_info) { + // Found an expired node, replace it + self.nodes.pop_front(); + break 'data data; + } + } + } else if let Some(data) = f(peer_info) { + // Found an empty slot, insert the new node + break 'data data; + } + + // No action was taken + return false; + }; + + self.nodes.push_back(Node::new(data)); + true + } +} + +pub(crate) struct Node { + pub data: T, pub last_updated_at: Instant, } -impl Node { - fn new(data: Arc) -> Self { +impl Node { + fn new(data: T) -> Self { Self { data, last_updated_at: Instant::now(), } } +} +impl Node { pub fn is_expired(&self, at: u32, timeout: &Duration) -> bool { - self.data.is_expired(at) || &self.last_updated_at.elapsed() >= timeout + self.data.as_peer_info().is_expired(at) || &self.last_updated_at.elapsed() >= timeout + } +} + +pub(crate) trait AsPeerInfo { + type Guard<'a>: std::ops::Deref> + where + Self: 'a; + + fn as_peer_info(&self) -> Self::Guard<'_>; + + fn load_peer_info(&self) -> Arc; +} + +impl AsPeerInfo for Arc { + type Guard<'a> = &'a Arc; + + #[inline] + fn as_peer_info(&self) -> Self::Guard<'_> { + self + } + + #[inline] + fn load_peer_info(&self) -> Arc { + self.clone() + } +} + +impl AsPeerInfo for KnownPeerHandle { + type Guard<'a> = arc_swap::Guard, arc_swap::DefaultStrategy> + where + Self: 'a; + + #[inline] + fn as_peer_info(&self) -> Self::Guard<'_> { + self.peer_info() + } + + #[inline] + fn load_peer_info(&self) -> Arc { + KnownPeerHandle::load_peer_info(self) } } @@ -188,50 +238,26 @@ mod tests { use std::str::FromStr; use super::*; + use crate::util::make_peer_info_stub; const MAX_K: usize = 20; - fn make_node(id: PeerId) -> Arc { - Arc::new(PeerInfo { - id, - address_list: Default::default(), - created_at: 0, - expires_at: u32::MAX, - signature: Box::new([0; 64]), - }) - } - #[test] fn buckets_are_sets() { let mut table = RoutingTable::new(rand::random()); let peer = rand::random(); - assert!(table.add( - make_node(peer), - MAX_K, - &Duration::MAX, - RoutingTableSource::Trusted - )); - assert!(table.add( - make_node(peer), - MAX_K, - &Duration::MAX, - RoutingTableSource::Trusted - )); // returns true because the node was updated + assert!(table.add(make_peer_info_stub(peer), MAX_K, &Duration::MAX, Some)); + assert!(table.add(make_peer_info_stub(peer), MAX_K, &Duration::MAX, Some)); // returns true because the node was updated assert_eq!(table.len(), 1); } #[test] - fn sould_not_add_seld() { + fn should_not_add_self() { let local_id = rand::random(); let mut table = RoutingTable::new(local_id); - assert!(!table.add( - make_node(local_id), - MAX_K, - &Duration::MAX, - RoutingTableSource::Trusted - )); + assert!(!table.add(make_peer_info_stub(local_id), MAX_K, &Duration::MAX, Some)); assert!(table.is_empty()); } @@ -242,19 +268,9 @@ mod tests { let mut bucket = Bucket::with_capacity(k); for _ in 0..k { - assert!(bucket.insert( - make_node(rand::random()), - k, - &timeout, - RoutingTableSource::Trusted - )); + assert!(bucket.insert(make_peer_info_stub(rand::random()), k, &timeout, Some)); } - assert!(!bucket.insert( - make_node(rand::random()), - k, - &timeout, - RoutingTableSource::Trusted - )); + assert!(!bucket.insert(make_peer_info_stub(rand::random()), k, &timeout, Some)); } #[test] @@ -373,12 +389,7 @@ mod tests { let mut table = RoutingTable::new(local_id); for id in ids { - table.add( - make_node(id), - MAX_K, - &Duration::MAX, - RoutingTableSource::Trusted, - ); + table.add(make_peer_info_stub(id), MAX_K, &Duration::MAX, Some); } { diff --git a/network/src/lib.rs b/network/src/lib.rs index 2f01ec87a..f701fc434 100644 --- a/network/src/lib.rs +++ b/network/src/lib.rs @@ -1,17 +1,19 @@ pub use self::overlay::{ OverlayConfig, OverlayId, OverlayService, OverlayServiceBackgroundTasks, OverlayServiceBuilder, - PrivateOverlay, PrivateOverlayBuilder, PrivateOverlayEntries, PrivateOverlayEntriesReadGuard, - PrivateOverlayEntriesWriteGuard, PublicOverlay, PublicOverlayBuilder, PublicOverlayEntries, - PublicOverlayEntriesReadGuard, + PrivateOverlay, PrivateOverlayBuilder, PrivateOverlayEntries, PrivateOverlayEntriesEvent, + PrivateOverlayEntriesReadGuard, PrivateOverlayEntriesWriteGuard, PublicOverlay, + PublicOverlayBuilder, PublicOverlayEntries, PublicOverlayEntriesReadGuard, }; pub use self::util::{check_peer_signature, NetworkExt, Routable, Router, RouterBuilder}; pub use dht::{ - xor_distance, DhtClient, DhtClientBuilder, DhtConfig, DhtQueryBuilder, DhtQueryWithDataBuilder, - DhtService, DhtServiceBuilder, FindValueError, OverlayValueMerger, StorageError, + xor_distance, DhtClient, DhtConfig, DhtQueryBuilder, DhtQueryWithDataBuilder, DhtService, + DhtServiceBackgroundTasks, DhtServiceBuilder, FindValueError, OverlayValueMerger, PeerResolver, + PeerResolverBuilder, PeerResolverHandle, StorageError, }; pub use network::{ - ActivePeers, Connection, KnownPeer, KnownPeers, Network, NetworkBuilder, NetworkConfig, Peer, - QuicConfig, RecvStream, SendStream, WeakActivePeers, WeakNetwork, + ActivePeers, Connection, KnownPeerHandle, KnownPeers, KnownPeersError, Network, NetworkBuilder, + NetworkConfig, Peer, PeerBannedError, QuicConfig, RecvStream, SendStream, WeakActivePeers, + WeakKnownPeerHandle, WeakNetwork, }; pub use types::{ service_datagram_fn, service_message_fn, service_query_fn, Address, BoxCloneService, @@ -49,20 +51,16 @@ mod tests { let keypair = everscale_crypto::ed25519::KeyPair::generate(&mut rand::thread_rng()); let peer_id: PeerId = keypair.public_key.into(); - let private_overlay = PrivateOverlay::builder(rand::random()) - .build(service_message_fn(|_| futures_util::future::ready(()))); - - let public_overlay = PublicOverlay::builder(rand::random()) - .build(service_message_fn(|_| futures_util::future::ready(()))); + let (dht_tasks, dht_service) = DhtService::builder(peer_id).build(); let (overlay_tasks, overlay_service) = OverlayService::builder(peer_id) - .with_private_overlay(&private_overlay) - .with_public_overlay(&public_overlay) + .with_dht_service(dht_service.clone()) .build(); - let (dht_client, dht) = DhtService::builder(peer_id).build(); - - let router = Router::builder().route(dht).route(overlay_service).build(); + let router = Router::builder() + .route(dht_service.clone()) + .route(overlay_service.clone()) + .build(); let network = Network::builder() .with_random_private_key() @@ -70,7 +68,19 @@ mod tests { .build((Ipv4Addr::LOCALHOST, 0), router) .unwrap(); - let _dht_client = dht_client.build(network.clone()); - overlay_tasks.spawn(network); + dht_tasks.spawn(&network); + overlay_tasks.spawn(&network); + + let peer_resolver = dht_service.make_peer_resolver().build(&network); + + let private_overlay = PrivateOverlay::builder(rand::random()) + .with_peer_resolver(peer_resolver) + .build(service_message_fn(|_| futures_util::future::ready(()))); + + let public_overlay = PublicOverlay::builder(rand::random()) + .build(service_message_fn(|_| futures_util::future::ready(()))); + + overlay_service.add_private_overlay(&private_overlay); + overlay_service.add_public_overlay(&public_overlay); } } diff --git a/network/src/network/config.rs b/network/src/network/config.rs index 1ef9dbdc4..904e1b91b 100644 --- a/network/src/network/config.rs +++ b/network/src/network/config.rs @@ -50,6 +50,9 @@ pub struct NetworkConfig { /// Default: 1 minute. #[serde(with = "serde_helpers::humantime")] pub shutdown_idle_timeout: Duration, + + /// Default: no. + pub enable_0rtt: bool, } impl Default for NetworkConfig { @@ -66,6 +69,7 @@ impl Default for NetworkConfig { max_concurrent_connections: None, active_peers_event_channel_capacity: 128, shutdown_idle_timeout: Duration::from_secs(60), + enable_0rtt: false, } } } @@ -135,9 +139,9 @@ pub(crate) struct EndpointConfig { pub client_cert: rustls::Certificate, pub pkcs8_der: rustls::PrivateKey, pub quinn_server_config: quinn::ServerConfig, - pub quinn_client_config: quinn::ClientConfig, pub transport_config: Arc, pub quinn_endpoint_config: quinn::EndpointConfig, + pub enable_early_data: bool, } impl EndpointConfig { @@ -148,15 +152,20 @@ impl EndpointConfig { } } - pub fn make_client_config_for_peer_id(&self, peer_id: PeerId) -> Result { - let client_config = rustls::ClientConfig::builder() - .with_safe_defaults() + pub fn make_client_config_for_peer_id(&self, peer_id: &PeerId) -> Result { + let mut client_config = rustls::ClientConfig::builder() + .with_cipher_suites(DEFAULT_CIPHER_SUITES) + .with_kx_groups(DEFAULT_KX_GROUPS) + .with_protocol_versions(DEFAULT_PROTOCOL_VERSIONS) + .unwrap() .with_custom_certificate_verifier(Arc::new(CertVerifierWithPeerId::new( self.service_name.clone(), peer_id, ))) .with_client_auth_cert(vec![self.client_cert.clone()], self.pkcs8_der.clone())?; + client_config.enable_early_data = self.enable_early_data; + let mut client = quinn::ClientConfig::new(Arc::new(client_config)); client.transport_config(self.transport_config.clone()); Ok(client) @@ -170,10 +179,16 @@ pub(crate) struct EndpointConfigBuilder { #[derive(Default)] struct EndpointConfigBuilderFields { + enable_0rtt: bool, transport_config: Option, } impl EndpointConfigBuilder { + pub fn with_0rtt_enabled(mut self, enable_0rtt: bool) -> Self { + self.optional_fields.enable_0rtt = enable_0rtt; + self + } + pub fn with_transport_config(mut self, transport_config: quinn::TransportConfig) -> Self { self.optional_fields.transport_config = Some(transport_config); self @@ -221,12 +236,6 @@ impl EndpointConfigBuilder { generate_cert(&keypair, &service_name).context("Failed to generate a certificate")?; let cert_verifier = Arc::new(CertVerifier::from(service_name.clone())); - let quinn_client_config = make_client_config( - cert.clone(), - pkcs8_der.clone(), - cert_verifier.clone(), - transport_config.clone(), - )?; let quinn_server_config = make_server_config( &service_name, @@ -234,6 +243,7 @@ impl EndpointConfigBuilder { cert.clone(), cert_verifier, transport_config.clone(), + self.optional_fields.enable_0rtt, )?; let peer_id = peer_id_from_certificate(&cert)?; @@ -244,35 +254,20 @@ impl EndpointConfigBuilder { client_cert: cert, pkcs8_der, quinn_server_config, - quinn_client_config, transport_config, quinn_endpoint_config, + enable_early_data: self.optional_fields.enable_0rtt, }) } } -fn make_client_config( - cert: rustls::Certificate, - pkcs8_der: rustls::PrivateKey, - cert_verifier: Arc, - transport_config: Arc, -) -> Result { - let client_config = rustls::ClientConfig::builder() - .with_safe_defaults() - .with_custom_certificate_verifier(cert_verifier) - .with_client_auth_cert(vec![cert], pkcs8_der)?; - - let mut client = quinn::ClientConfig::new(Arc::new(client_config)); - client.transport_config(transport_config); - Ok(client) -} - fn make_server_config( service_name: &str, pkcs8_der: rustls::PrivateKey, cert: rustls::Certificate, cert_verifier: Arc, transport_config: Arc, + enable_0rtt: bool, ) -> Result { let mut server_cert_resolver = rustls::server::ResolvesServerCertUsingSni::new(); @@ -280,11 +275,21 @@ fn make_server_config( let certified_key = rustls::sign::CertifiedKey::new(vec![cert], key); server_cert_resolver.add(service_name, certified_key)?; - let server_crypto = rustls::ServerConfig::builder() - .with_safe_defaults() + let mut server_crypto = rustls::ServerConfig::builder() + .with_cipher_suites(DEFAULT_CIPHER_SUITES) + .with_kx_groups(DEFAULT_KX_GROUPS) + .with_protocol_versions(DEFAULT_PROTOCOL_VERSIONS) + .unwrap() .with_client_cert_verifier(cert_verifier) .with_cert_resolver(Arc::new(server_cert_resolver)); + if enable_0rtt { + server_crypto.max_early_data_size = u32::MAX; + + // TODO: Should we enable this? + // server_crypto.send_half_rtt_data = true; + } + let mut server = quinn::ServerConfig::with_crypto(Arc::new(server_crypto)); server.transport = transport_config; Ok(server) @@ -302,3 +307,11 @@ fn compute_reset_key(private_key: &[u8; 32]) -> ring::hmac::Key { ring::hmac::Key::new(ring::hmac::HMAC_SHA256, &reset_key) } + +static DEFAULT_CIPHER_SUITES: &[rustls::SupportedCipherSuite] = &[ + rustls::cipher_suite::TLS13_AES_256_GCM_SHA384, + rustls::cipher_suite::TLS13_AES_128_GCM_SHA256, + rustls::cipher_suite::TLS13_CHACHA20_POLY1305_SHA256, +]; +static DEFAULT_KX_GROUPS: &[&rustls::SupportedKxGroup] = &[&rustls::kx_group::X25519]; +static DEFAULT_PROTOCOL_VERSIONS: &[&rustls::SupportedProtocolVersion] = &[&rustls::version::TLS13]; diff --git a/network/src/network/connection.rs b/network/src/network/connection.rs index d5dd1e34b..cd5354de2 100644 --- a/network/src/network/connection.rs +++ b/network/src/network/connection.rs @@ -19,14 +19,18 @@ pub struct Connection { impl Connection { pub fn new(inner: quinn::Connection, origin: Direction) -> Result { let peer_id = extract_peer_id(&inner)?; - Ok(Self { + Ok(Self::with_peer_id(inner, origin, peer_id)) + } + + pub fn with_peer_id(inner: quinn::Connection, origin: Direction, peer_id: PeerId) -> Self { + Self { request_meta: Arc::new(InboundRequestMeta { peer_id, origin, remote_address: inner.remote_address(), }), inner, - }) + } } pub fn request_meta(&self) -> &Arc { @@ -177,7 +181,7 @@ impl tokio::io::AsyncRead for RecvStream { } } -fn extract_peer_id(connection: &quinn::Connection) -> Result { +pub(crate) fn extract_peer_id(connection: &quinn::Connection) -> Result { let certificate = connection .peer_identity() .and_then(|identity| identity.downcast::>().ok()) @@ -186,3 +190,31 @@ fn extract_peer_id(connection: &quinn::Connection) -> Result { peer_id_from_certificate(&certificate).map_err(Into::into) } + +pub(crate) fn parse_peer_identity(identity: Box) -> Result { + let certificate = identity + .downcast::>() + .ok() + .and_then(|certificates| certificates.into_iter().next()) + .context("No certificate found in the connection")?; + + peer_id_from_certificate(&certificate).map_err(Into::into) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::str::FromStr; + + #[test] + fn parse_cert() { + let peer_id = + PeerId::from_str("7a6e86f44d5bd83093ba658fadccffbeb2878bb2e30db7b92e237e21eef77e07") + .unwrap(); + + let certificate = rustls::Certificate(b"0\x81\xd80\x81\x8b\xa0\x03\x02\x01\x02\x02\x15\0\xa0\xd7\x8e\xa8\xf2\xfe\xd8\x10\xeb3\x90\x19br\x91S`\x01\xe0)0\x05\x06\x03+ep0\00 \x17\r750101000000Z\x18\x0f40960101000000Z0\00*0\x05\x06\x03+ep\x03!\0zn\x86\xf4M[\xd80\x93\xbae\x8f\xad\xcc\xff\xbe\xb2\x87\x8b\xb2\xe3\r\xb7\xb9.#~!\xee\xf7~\x07\xa3\x140\x120\x10\x06\x03U\x1d\x11\x04\t0\x07\x82\x05tycho0\x05\x06\x03+ep\x03A\0\xe3s-\xaf\xbd\xac\x81\xbc\x82\x8a\x83\xf8\xa3\xe3\xcb\x118\xa8g\xef_M\x99*\x7f\xed\x1bQ=\x9f\xf1\xc4%q\xa9g\xfa\x0f\x12R\x84LH\xff\x99\xa7bH\xfc\xbdb\xbcY\xc5C\x11\xc5\x91\x8dn#\xe2\x9b\x05".to_vec()); + let parsed_peer_id = peer_id_from_certificate(&certificate).unwrap(); + + assert_eq!(peer_id, parsed_peer_id); + } +} diff --git a/network/src/network/connection_manager.rs b/network/src/network/connection_manager.rs index cfbe2cd75..ec92eff87 100644 --- a/network/src/network/connection_manager.rs +++ b/network/src/network/connection_manager.rs @@ -1,16 +1,19 @@ +use std::collections::hash_map; +use std::mem::ManuallyDrop; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::{Arc, Weak}; use std::time::{Duration, Instant}; use ahash::HashMap; use anyhow::Result; +use arc_swap::{ArcSwap, AsRaw}; use tokio::sync::{broadcast, mpsc, oneshot}; -use tokio::task::JoinSet; +use tokio::task::{AbortHandle, JoinSet}; use tycho_util::{FastDashMap, FastHashMap}; use crate::network::config::NetworkConfig; use crate::network::connection::Connection; -use crate::network::endpoint::{Connecting, Endpoint}; +use crate::network::endpoint::{Connecting, Endpoint, Into0RttResult}; use crate::network::request_handler::InboundRequestHandler; use crate::network::wire::handshake; use crate::types::{ @@ -20,7 +23,7 @@ use crate::types::{ #[derive(Debug)] pub(crate) enum ConnectionManagerRequest { - Connect(Address, Option, oneshot::Sender>), + Connect(Address, PeerId, CallbackTx), Shutdown(oneshot::Sender<()>), } @@ -30,10 +33,12 @@ pub(crate) struct ConnectionManager { mailbox: mpsc::Receiver, + pending_connection_callbacks: FastHashMap, + pending_partial_connections: JoinSet>, pending_connections: JoinSet, connection_handlers: JoinSet<()>, - pending_dials: FastHashMap>>, + pending_dials: FastHashMap, dial_backoff_states: HashMap, active_peers: ActivePeers, @@ -42,6 +47,9 @@ pub(crate) struct ConnectionManager { service: BoxCloneService, } +type CallbackTx = oneshot::Sender>>; +type CallbackRx = oneshot::Receiver>>; + impl Drop for ConnectionManager { fn drop(&mut self) { self.endpoint.close(); @@ -61,6 +69,8 @@ impl ConnectionManager { config, endpoint, mailbox, + pending_connection_callbacks: Default::default(), + pending_partial_connections: Default::default(), pending_connections: Default::default(), connection_handlers: Default::default(), pending_dials: Default::default(), @@ -92,7 +102,7 @@ impl ConnectionManager { match request { ConnectionManagerRequest::Connect(address, peer_id, callback) => { - self.handle_connect_request(address, peer_id, callback); + self.handle_connect_request(address, &peer_id, callback); } ConnectionManagerRequest::Shutdown(oneshot) => { shutdown_notifier = Some(oneshot); @@ -106,8 +116,21 @@ impl ConnectionManager { } } Some(connecting_output) = self.pending_connections.join_next() => { + match connecting_output { + Ok(connecting) => self.handle_connecting_result(connecting), + Err(e) => { + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()); + } + continue; + } + } + } + Some(partial_connection) = self.pending_partial_connections.join_next() => { // NOTE: unwrap here is to propagate panic from the spawned future - self.handle_connecting_result(connecting_output.unwrap()); + if let Some(PartialConnection { connection, timeout_at }) = partial_connection.unwrap() { + self.handle_incoming_impl(connection, None, timeout_at); + } } Some(connection_handler_output) = self.connection_handlers.join_next() => { // NOTE: unwrap here is to propagate panic from the spawned future @@ -127,6 +150,7 @@ impl ConnectionManager { async fn shutdown(mut self) { self.endpoint.close(); + self.pending_partial_connections.shutdown().await; self.pending_connections.shutdown().await; while self.connection_handlers.join_next().await.is_some() {} @@ -179,23 +203,25 @@ impl ConnectionManager { .known_peers .0 .iter() - .filter(|item| { - let KnownPeer { - peer_info, - affinity, - } = item.value(); - - *affinity == PeerAffinity::High - && &peer_info.id != self.endpoint.peer_id() + .filter_map(|item| { + let value = match item.value() { + KnownPeerState::Stored(item) => item.upgrade()?, + KnownPeerState::Banned => return None, + }; + let peer_info = value.peer_info.load(); + let affinity = value.compute_affinity(); + + (affinity == PeerAffinity::High + && peer_info.id != self.endpoint.peer_id() && !self.active_peers.contains(&peer_info.id) && !self.pending_dials.contains_key(&peer_info.id) && self .dial_backoff_states .get(&peer_info.id) - .map_or(true, |state| now > state.next_attempt_at) + .map_or(true, |state| now > state.next_attempt_at)) + .then(|| arc_swap::Guard::into_inner(peer_info)) }) .take(outstanding_connections_limit) - .map(|item| item.value().peer_info.clone()) .collect::>(); for peer_info in outstanding_connections { @@ -207,93 +233,234 @@ impl ConnectionManager { .expect("address list must have at least one item"); let (tx, rx) = oneshot::channel(); - self.dial_peer(address, Some(peer_info.id), tx); + self.dial_peer(address, &peer_info.id, tx); self.pending_dials.insert(peer_info.id, rx); } } - fn handle_connect_request( - &mut self, - address: Address, - peer_id: Option, - callback: oneshot::Sender>, - ) { + fn handle_connect_request(&mut self, address: Address, peer_id: &PeerId, callback: CallbackTx) { self.dial_peer(address, peer_id, callback); } fn handle_incoming(&mut self, connecting: Connecting) { - async fn handle_incoming_task( - connecting: Connecting, - config: Arc, - active_peers: ActivePeers, - known_peers: KnownPeers, - ) -> ConnectingOutput { - let fut = async { - let connection = connecting.await?; - - match known_peers.get_affinity(connection.peer_id()) { - Some(PeerAffinity::High | PeerAffinity::Allowed) => {} - Some(PeerAffinity::Never) => { - anyhow::bail!( - "rejecting connection from peer {} due to PeerAffinity::Never", - connection.peer_id(), - ); - } - _ => { - if let Some(limit) = config.max_concurrent_connections { - anyhow::ensure!( - active_peers.len() < limit, - "rejecting connection from peer {} dut too many concurrent connections", - connection.peer_id(), + let remote_addr = connecting.remote_address(); + tracing::trace!( + local_id = %self.endpoint.peer_id(), + %remote_addr, + "received an incoming connection", + ); + + // Split incoming connection into 0.5-RTT and 1-RTT parts. + match connecting.into_0rtt() { + Into0RttResult::Established(connection, accepted) => { + let timeout_at = Instant::now() + self.config.connect_timeout; + self.handle_incoming_impl(connection, Some(accepted), timeout_at); + } + Into0RttResult::WithoutIdentity(partial_connection) => { + tracing::debug!("connection identity is not available yet"); + + let timeout_at = Instant::now() + self.config.connect_timeout; + self.pending_partial_connections.spawn(async move { + match tokio::time::timeout_at(timeout_at.into(), partial_connection).await { + Ok(Ok(connection)) => Some(PartialConnection { + connection, + timeout_at, + }), + Ok(Err(e)) => { + tracing::warn!( + %remote_addr, + "failed to establish an incoming connection: {e:?}", + ); + None + } + Err(_) => { + tracing::warn!( + %remote_addr, + "incoming connection timed out", ); + None } } - } + }); + } + Into0RttResult::InvalidConnection(e) => { + // TODO: Lower log level to trace/debug? + tracing::warn!(%remote_addr, "invalid incoming connection: {e:?}"); + } + Into0RttResult::Unavailable(_) => unreachable!( + "BUG: For incoming connections, a 0.5-RTT connection must \ + always be successfully constructed." + ), + }; + } - handshake(connection).await + fn handle_incoming_impl( + &mut self, + connection: Connection, + accepted: Option, + timeout_at: Instant, + ) { + async fn handle_incoming_task( + seqno: u32, + connection: ConnectionClosedOnDrop, + accepted: Option, + timeout_at: Instant, + ) -> ConnectingOutput { + let target_peer_id = *connection.peer_id(); + let target_address = connection.remote_address().into(); + let fut = async { + if let Some(accepted) = accepted { + // NOTE: `bool` output of this future is meaningless for servers. + accepted.await; + } + handshake(&connection).await }; - let connecting_result = tokio::time::timeout(config.connect_timeout, fut) + let connecting_result = tokio::time::timeout_at(timeout_at.into(), fut) .await .map_err(Into::into) - .and_then(std::convert::identity); + .and_then(std::convert::identity) + .map_err(Arc::new) + .map(|_| connection.disarm()); ConnectingOutput { - connecting_result, - callback: None, - target_address: None, - target_peer_id: None, + seqno, + drop_result: true, + connecting_result: ManuallyDrop::new(connecting_result), + target_address, + target_peer_id, + } + } + + let remote_addr = connection.remote_address(); + + // Check if the peer is allowed before doing anything else. + match self.known_peers.get_affinity(connection.peer_id()) { + Some(PeerAffinity::High | PeerAffinity::Allowed) => {} + Some(PeerAffinity::Never) => { + // TODO: Lower log level to trace/debug? + tracing::warn!( + %remote_addr, + peer_id = %connection.peer_id(), + "rejecting connection due to PeerAffinity::Never", + ); + return; + } + _ => { + if matches!( + self.config.max_concurrent_connections, + Some(limit) if self.active_peers.len() >= limit + ) { + // TODO: Lower log level to trace/debug? + tracing::warn!( + %remote_addr, + peer_id = %connection.peer_id(), + "rejecting connection due too many concurrent connections", + ); + return; + } } } - tracing::trace!("received new incoming connection"); + let entry = match self.pending_connection_callbacks.entry(remote_addr.into()) { + hash_map::Entry::Vacant(entry) => Some(entry.insert(PendingConnectionCallbacks { + last_seqno: 0, + origin: Direction::Inbound, + callbacks: Default::default(), + abort_handle: None, + })), + hash_map::Entry::Occupied(entry) => { + let entry = entry.into_mut(); + + // Check if the incoming connection is a simultaneous dial. + if simultaneous_dial_tie_breaking( + self.endpoint.peer_id(), + connection.peer_id(), + entry.origin, + Direction::Inbound, + ) { + // New connection wins the tie, abort the old one and spawn a new task. + tracing::debug!( + %remote_addr, + peer_id = %connection.peer_id(), + "cancelling old connection to mitigate simultaneous dial", + ); + + entry.origin = Direction::Inbound; + entry.last_seqno += 1; + if let Some(handle) = entry.abort_handle.take() { + handle.abort(); + } + Some(entry) + } else { + // Old connection wins the tie, gracefully close the new one. + tracing::debug!( + %remote_addr, + peer_id = %connection.peer_id(), + "cancelling new connection to mitigate simultaneous dial", + ); + + connection.close(); + None + } + } + }; - self.pending_connections.spawn(handle_incoming_task( - connecting, - self.config.clone(), - self.active_peers.clone(), - self.known_peers.clone(), - )); + if let Some(entry) = entry { + entry.abort_handle = Some(self.pending_connections.spawn(handle_incoming_task( + entry.last_seqno, + ConnectionClosedOnDrop::new(connection), + accepted, + timeout_at, + ))); + } } - fn handle_connecting_result(&mut self, res: ConnectingOutput) { - match res.connecting_result { + fn handle_connecting_result(&mut self, mut res: ConnectingOutput) { + // Check seqno first to drop outdated results. + { + let entry = self + .pending_connection_callbacks + .get(&res.target_address) + .expect("Connection tasks must be tracked"); + + if entry.last_seqno != res.seqno { + tracing::debug!( + target_address = %res.target_address, + target_peer_id = ?res.target_peer_id, + "connection result is outdated" + ); + return; + } + } + + let callbacks = self + .pending_connection_callbacks + .remove(&res.target_address) + .expect("Connection tasks must be tracked") + .callbacks; + + res.drop_result = false; + // SAFETY: `drop_result` is set to `false`. + match unsafe { ManuallyDrop::take(&mut res.connecting_result) } { Ok(connection) => { let peer_id = *connection.peer_id(); tracing::debug!(%peer_id, "new connection"); self.add_peer(connection); - if let Some(callback) = res.callback { + + for callback in callbacks { _ = callback.send(Ok(peer_id)); } } Err(e) => { tracing::debug!( - target_address = ?res.target_address, + target_address = %res.target_address, target_peer_id = ?res.target_peer_id, "connection failed: {e:?}" ); - if let Some(callback) = res.callback { - _ = callback.send(Err(e)); + + for callback in callbacks { + _ = callback.send(Err(e.clone())); } } } @@ -312,57 +479,170 @@ impl ConnectionManager { } #[tracing::instrument(level = "trace", skip_all, fields(peer_id = ?peer_id, address = %address))] - fn dial_peer( - &mut self, - address: Address, - peer_id: Option, - callback: oneshot::Sender>, - ) { + fn dial_peer(&mut self, address: Address, peer_id: &PeerId, callback: CallbackTx) { async fn dial_peer_task( + seqno: u32, connecting: Result, address: Address, - peer_id: Option, - callback: oneshot::Sender>, + peer_id: PeerId, config: Arc, ) -> ConnectingOutput { let fut = async { - let connection = connecting?.await?; - handshake(connection).await + let connection = ConnectionClosedOnDrop::new(connecting?.await?); + handshake(&connection).await?; + Ok(connection) }; let connecting_result = tokio::time::timeout(config.connect_timeout, fut) .await .map_err(Into::into) - .and_then(std::convert::identity); + .and_then(std::convert::identity) + .map_err(Arc::new) + .map(ConnectionClosedOnDrop::disarm); ConnectingOutput { - connecting_result, - callback: Some(callback), - target_address: Some(address), + seqno, + drop_result: true, + connecting_result: ManuallyDrop::new(connecting_result), + target_address: address, target_peer_id: peer_id, } } - let target_address = address.clone(); - let connecting = match peer_id { - None => self.endpoint.connect(address), - Some(peer_id) => self.endpoint.connect_with_expected_id(address, peer_id), + tracing::info!( + local_id = %self.endpoint.peer_id(), + %peer_id, + remote_addr = %address, + "connecting to peer", + ); + + let entry = match self.pending_connection_callbacks.entry(address.clone()) { + hash_map::Entry::Vacant(entry) => Some(entry.insert(PendingConnectionCallbacks { + last_seqno: 0, + origin: Direction::Outbound, + callbacks: vec![callback], + abort_handle: None, + })), + hash_map::Entry::Occupied(entry) => { + let entry = entry.into_mut(); + + // Add the callback to the existing entry. + entry.callbacks.push(callback); + + // Check if the outgoing connection is a simultaneous dial. + if simultaneous_dial_tie_breaking( + self.endpoint.peer_id(), + peer_id, + entry.origin, + Direction::Outbound, + ) { + // New connection wins the tie, abort the old one and spawn a new task. + tracing::debug!( + remote_addr = %address, + %peer_id, + "cancelling old connection to mitigate simultaneous dial", + ); + + entry.origin = Direction::Outbound; + entry.last_seqno += 1; + if let Some(handle) = entry.abort_handle.take() { + handle.abort(); + } + Some(entry) + } else { + // Old connection wins the tie, gracefully close the new one. + tracing::debug!( + remote_addr = %address, + %peer_id, + "cancelling new connection to mitigate simultaneous dial", + ); + None + } + } }; - self.pending_connections.spawn(dial_peer_task( - connecting, - target_address, - peer_id, - callback, - self.config.clone(), - )); + + if let Some(entry) = entry { + let target_address = address.clone(); + let connecting = self + .endpoint + .connect_with_expected_id(address.clone(), peer_id); + self.pending_connections.spawn(dial_peer_task( + entry.last_seqno, + connecting, + target_address, + *peer_id, + self.config.clone(), + )); + } } } +struct PendingConnectionCallbacks { + last_seqno: u32, + origin: Direction, + callbacks: Vec, + abort_handle: Option, +} + +struct PartialConnection { + connection: Connection, + timeout_at: Instant, +} + struct ConnectingOutput { - connecting_result: Result, - callback: Option>>, - target_address: Option
, - target_peer_id: Option, + seqno: u32, + drop_result: bool, + connecting_result: ManuallyDrop>>, + target_address: Address, + target_peer_id: PeerId, +} + +impl Drop for ConnectingOutput { + fn drop(&mut self) { + if self.drop_result { + // SAFETY: `drop_result` is set to `true` only when the result is not used. + unsafe { ManuallyDrop::drop(&mut self.connecting_result) }; + } + } +} + +struct ConnectionClosedOnDrop { + connection: ManuallyDrop, + close_on_drop: bool, +} + +impl ConnectionClosedOnDrop { + fn new(connection: Connection) -> Self { + Self { + connection: ManuallyDrop::new(connection), + close_on_drop: true, + } + } + + fn disarm(mut self) -> Connection { + self.close_on_drop = false; + // SAFETY: `drop` will not be called. + unsafe { ManuallyDrop::take(&mut self.connection) } + } +} + +impl std::ops::Deref for ConnectionClosedOnDrop { + type Target = Connection; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.connection + } +} + +impl Drop for ConnectionClosedOnDrop { + fn drop(&mut self) { + if self.close_on_drop { + // SAFETY: `disarm` was not called. + let connection = unsafe { ManuallyDrop::take(&mut self.connection) }; + connection.close(); + } + } } #[derive(Debug)] @@ -480,21 +760,21 @@ impl ActivePeersInner { fn add(&self, local_id: &PeerId, new_connection: Connection) -> Option { use dashmap::mapref::entry::Entry; - let remote_id = new_connection.peer_id(); - match self.connections.entry(*remote_id) { + let peer_id = new_connection.peer_id(); + match self.connections.entry(*peer_id) { Entry::Occupied(mut entry) => { if simultaneous_dial_tie_breaking( local_id, - remote_id, + peer_id, entry.get().origin(), new_connection.origin(), ) { - tracing::debug!(%remote_id, "closing old connection to mitigate simultaneous dial"); + tracing::debug!(%peer_id, "closing old connection to mitigate simultaneous dial"); let old_connection = entry.insert(new_connection.clone()); old_connection.close(); - self.send_event(PeerEvent::LostPeer(*remote_id, DisconnectReason::Requested)); + self.send_event(PeerEvent::LostPeer(*peer_id, DisconnectReason::Requested)); } else { - tracing::debug!(%remote_id, "closing new connection to mitigate simultaneous dial"); + tracing::debug!(%peer_id, "closing new connection to mitigate simultaneous dial"); new_connection.close(); return None; } @@ -505,7 +785,7 @@ impl ActivePeersInner { } } - self.send_event(PeerEvent::NewPeer(*remote_id)); + self.send_event(PeerEvent::NewPeer(*peer_id)); Some(new_connection) } @@ -547,7 +827,7 @@ impl ActivePeersInner { fn simultaneous_dial_tie_breaking( local_id: &PeerId, - remote_id: &PeerId, + peer_id: &PeerId, old_origin: Direction, new_origin: Direction, ) -> bool { @@ -555,13 +835,14 @@ fn simultaneous_dial_tie_breaking( (Direction::Inbound, Direction::Inbound) | (Direction::Outbound, Direction::Outbound) => { true } - (Direction::Inbound, Direction::Outbound) => remote_id < local_id, - (Direction::Outbound, Direction::Inbound) => local_id < remote_id, + (Direction::Inbound, Direction::Outbound) => peer_id < local_id, + (Direction::Outbound, Direction::Inbound) => local_id < peer_id, } } #[derive(Default, Clone)] -pub struct KnownPeers(Arc>); +#[repr(transparent)] +pub struct KnownPeers(Arc>); impl KnownPeers { pub fn new() -> Self { @@ -572,49 +853,629 @@ impl KnownPeers { self.0.contains_key(peer_id) } - pub fn get(&self, peer_id: &PeerId) -> Option { - self.0.get(peer_id).map(|item| item.value().clone()) + pub fn is_banned(&self, peer_id: &PeerId) -> bool { + self.0 + .get(peer_id) + .and_then(|item| { + Some(match item.value() { + KnownPeerState::Stored(item) => item.upgrade()?.is_banned(), + KnownPeerState::Banned => true, + }) + }) + .unwrap_or_default() + } + + pub fn get(&self, peer_id: &PeerId) -> Option> { + self.0.get(peer_id).and_then(|item| match item.value() { + KnownPeerState::Stored(item) => { + let inner = item.upgrade()?; + Some(inner.peer_info.load_full()) + } + KnownPeerState::Banned => None, + }) } pub fn get_affinity(&self, peer_id: &PeerId) -> Option { - self.0.get(peer_id).map(|item| item.value().affinity) + self.0 + .get(peer_id) + .and_then(|item| item.value().compute_affinity()) + } + + pub fn remove(&self, peer_id: &PeerId) { + self.0.remove(peer_id); } - pub fn insert(&self, peer_info: Arc, affinity: PeerAffinity) -> Option { - match self.0.entry(peer_info.id) { + pub fn ban(&self, peer_id: &PeerId) { + match self.0.entry(*peer_id) { dashmap::mapref::entry::Entry::Vacant(entry) => { - entry.insert(KnownPeer { - peer_info, - affinity, - }); - None + entry.insert(KnownPeerState::Banned); } - dashmap::mapref::entry::Entry::Occupied(entry) => { - if entry.get().peer_info.created_at >= peer_info.created_at { + dashmap::mapref::entry::Entry::Occupied(mut entry) => match entry.get_mut() { + KnownPeerState::Banned => {} + KnownPeerState::Stored(item) => match item.upgrade() { + Some(item) => item.affinity.store(AFFINITY_BANNED, Ordering::Release), + None => *entry.get_mut() = KnownPeerState::Banned, + }, + }, + } + } + + pub fn make_handle(&self, peer_id: &PeerId, with_affinity: bool) -> Option { + let inner = match self.0.get(peer_id)?.value() { + KnownPeerState::Stored(item) => { + let inner = item.upgrade()?; + if with_affinity && !inner.increase_affinity() { return None; } + inner + } + KnownPeerState::Banned => return None, + }; + + Some(KnownPeerHandle::from_inner(inner, with_affinity)) + } + + /// Inserts a new handle only if the provided info is not outdated + /// and the peer is not banned. + pub fn insert( + &self, + peer_info: Arc, + with_affinity: bool, + ) -> Result { + // TODO: add capacity limit for entries without affinity + let inner = match self.0.entry(peer_info.id) { + dashmap::mapref::entry::Entry::Vacant(entry) => { + let inner = KnownPeerInner::new(peer_info, with_affinity, &self.0); + entry.insert(KnownPeerState::Stored(Arc::downgrade(&inner))); + inner + } + dashmap::mapref::entry::Entry::Occupied(mut entry) => match entry.get_mut() { + KnownPeerState::Banned => return Err(KnownPeersError::from(PeerBannedError)), + KnownPeerState::Stored(item) => match item.upgrade() { + Some(inner) => match inner.try_update_peer_info(&peer_info, with_affinity)? { + true => inner, + false => return Err(KnownPeersError::OutdatedInfo), + }, + None => { + let inner = KnownPeerInner::new(peer_info, with_affinity, &self.0); + *item = Arc::downgrade(&inner); + inner + } + }, + }, + }; + + Ok(KnownPeerHandle::from_inner(inner, with_affinity)) + } + + /// Same as [`KnownPeers::insert`], but ignores outdated info. + pub fn insert_allow_outdated( + &self, + peer_info: Arc, + with_affinity: bool, + ) -> Result { + // TODO: add capacity limit for entries without affinity + let inner = match self.0.entry(peer_info.id) { + dashmap::mapref::entry::Entry::Vacant(entry) => { + let inner = KnownPeerInner::new(peer_info, with_affinity, &self.0); + entry.insert(KnownPeerState::Stored(Arc::downgrade(&inner))); + inner + } + dashmap::mapref::entry::Entry::Occupied(mut entry) => match entry.get_mut() { + KnownPeerState::Banned => return Err(PeerBannedError), + KnownPeerState::Stored(item) => match item.upgrade() { + Some(inner) => { + // NOTE: Outdated info is ignored here. + inner.try_update_peer_info(&peer_info, with_affinity)?; + inner + } + None => { + let inner = KnownPeerInner::new(peer_info, with_affinity, &self.0); + *item = Arc::downgrade(&inner); + inner + } + }, + }, + }; + + Ok(KnownPeerHandle::from_inner(inner, with_affinity)) + } +} + +enum KnownPeerState { + Stored(Weak), + Banned, +} + +impl KnownPeerState { + fn compute_affinity(&self) -> Option { + Some(match self { + Self::Stored(weak) => weak.upgrade()?.compute_affinity(), + Self::Banned => PeerAffinity::Never, + }) + } +} + +#[derive(Clone)] +#[repr(transparent)] +pub struct KnownPeerHandle(KnownPeerHandleState); + +impl KnownPeerHandle { + fn from_inner(inner: Arc, with_affinity: bool) -> Self { + KnownPeerHandle(if with_affinity { + KnownPeerHandleState::WithAffinity(ManuallyDrop::new(Arc::new( + KnownPeerHandleWithAffinity { inner }, + ))) + } else { + KnownPeerHandleState::Simple(ManuallyDrop::new(inner)) + }) + } + + pub fn peer_info(&self) -> arc_swap::Guard, arc_swap::DefaultStrategy> { + self.inner().peer_info.load() + } + + pub fn load_peer_info(&self) -> Arc { + arc_swap::Guard::into_inner(self.peer_info()) + } + + pub fn is_banned(&self) -> bool { + self.inner().is_banned() + } + + pub fn max_affinity(&self) -> PeerAffinity { + self.inner().compute_affinity() + } + + pub fn update_peer_info(&self, peer_info: &Arc) -> Result<(), KnownPeersError> { + match self.inner().try_update_peer_info(peer_info, false) { + Ok(true) => Ok(()), + Ok(false) => Err(KnownPeersError::OutdatedInfo), + Err(e) => Err(KnownPeersError::PeerBanned(e)), + } + } + + pub fn ban(&self) -> bool { + let inner = self.inner(); + inner.affinity.swap(AFFINITY_BANNED, Ordering::AcqRel) != AFFINITY_BANNED + } + + pub fn increase_affinity(&mut self) -> bool { + match &mut self.0 { + KnownPeerHandleState::Simple(inner) => { + // NOTE: Handle will be updated even if the peer is banned. + inner.increase_affinity(); + + // SAFETY: Inner value was not dropped. + let inner = unsafe { ManuallyDrop::take(inner) }; + + // Replace the old state with the new one, ensuring that the old state + // is not dropped (because we took the value out of it). + let prev_state = std::mem::replace( + &mut self.0, + KnownPeerHandleState::WithAffinity(ManuallyDrop::new(Arc::new( + KnownPeerHandleWithAffinity { inner }, + ))), + ); - let affinity = match affinity { - PeerAffinity::High | PeerAffinity::Never => affinity, - PeerAffinity::Allowed => entry.get().affinity, + // Forget the old state to avoid dropping it. + #[allow(clippy::mem_forget)] + std::mem::forget(prev_state); + + true + } + KnownPeerHandleState::WithAffinity(_) => false, + } + } + + pub fn decrease_affinity(&mut self) -> bool { + match &mut self.0 { + KnownPeerHandleState::Simple(_) => false, + KnownPeerHandleState::WithAffinity(inner) => { + // NOTE: Handle will be updated even if the peer is banned. + inner.inner.decrease_affinity(); + + // SAFETY: Inner value was not dropped. + let inner = unsafe { ManuallyDrop::take(inner) }; + + // Get `KnownPeerInner` out of the wrapper. + let inner = match Arc::try_unwrap(inner) { + Ok(KnownPeerHandleWithAffinity { inner }) => inner, + Err(inner) => inner.inner.clone(), }; - let (_, old) = entry.replace_entry(KnownPeer { - peer_info, - affinity, - }); - Some(old) + // Replace the old state with the new one, ensuring that the old state + // is not dropped (because we took the value out of it). + let prev_state = std::mem::replace( + &mut self.0, + KnownPeerHandleState::Simple(ManuallyDrop::new(inner)), + ); + + // Forget the old state to avoid dropping it. + #[allow(clippy::mem_forget)] + std::mem::forget(prev_state); + + true + } + } + } + + pub fn downgrade(&self) -> WeakKnownPeerHandle { + WeakKnownPeerHandle(match &self.0 { + KnownPeerHandleState::Simple(data) => { + WeakKnownPeerHandleState::Simple(Arc::downgrade(data)) + } + KnownPeerHandleState::WithAffinity(data) => { + WeakKnownPeerHandleState::WithAffinity(Arc::downgrade(data)) + } + }) + } + + fn inner(&self) -> &KnownPeerInner { + match &self.0 { + KnownPeerHandleState::Simple(data) => data.as_ref(), + KnownPeerHandleState::WithAffinity(data) => data.inner.as_ref(), + } + } +} + +#[derive(Clone)] +enum KnownPeerHandleState { + Simple(ManuallyDrop>), + WithAffinity(ManuallyDrop>), +} + +impl Drop for KnownPeerHandleState { + fn drop(&mut self) { + let inner; + let is_banned; + match self { + KnownPeerHandleState::Simple(data) => { + // SAFETY: inner value is dropped only once + inner = unsafe { ManuallyDrop::take(data) }; + is_banned = inner.is_banned(); + } + KnownPeerHandleState::WithAffinity(data) => { + // SAFETY: inner value is dropped only once + match Arc::into_inner(unsafe { ManuallyDrop::take(data) }) { + Some(data) => { + inner = data.inner; + is_banned = !inner.decrease_affinity() || inner.is_banned(); + } + None => return, + } + } + }; + + if is_banned { + // Don't remove banned peers from the known peers cache + return; + } + + if let Some(inner) = Arc::into_inner(inner) { + // If the last reference is dropped, remove the peer from the known peers cache + if let Some(peers) = inner.weak_known_peers.upgrade() { + peers.remove(&inner.peer_info.load().id); + } + } + } +} + +#[derive(Clone, PartialEq, Eq)] +#[repr(transparent)] +pub struct WeakKnownPeerHandle(WeakKnownPeerHandleState); + +impl WeakKnownPeerHandle { + pub fn upgrade(&self) -> Option { + Some(KnownPeerHandle(match &self.0 { + WeakKnownPeerHandleState::Simple(weak) => { + KnownPeerHandleState::Simple(ManuallyDrop::new(weak.upgrade()?)) + } + WeakKnownPeerHandleState::WithAffinity(weak) => { + KnownPeerHandleState::WithAffinity(ManuallyDrop::new(weak.upgrade()?)) + } + })) + } +} + +#[derive(Clone)] +enum WeakKnownPeerHandleState { + Simple(Weak), + WithAffinity(Weak), +} + +impl Eq for WeakKnownPeerHandleState {} +impl PartialEq for WeakKnownPeerHandleState { + #[inline] + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Self::Simple(left), Self::Simple(right)) => Weak::ptr_eq(left, right), + (Self::WithAffinity(left), Self::WithAffinity(right)) => Weak::ptr_eq(left, right), + _ => false, + } + } +} + +struct KnownPeerHandleWithAffinity { + inner: Arc, +} + +struct KnownPeerInner { + peer_info: ArcSwap, + affinity: AtomicUsize, + weak_known_peers: Weak>, +} + +impl KnownPeerInner { + fn new( + peer_info: Arc, + with_affinity: bool, + known_peers: &Arc>, + ) -> Arc { + Arc::new(Self { + peer_info: ArcSwap::from(peer_info), + affinity: AtomicUsize::new(if with_affinity { 1 } else { 0 }), + weak_known_peers: Arc::downgrade(known_peers), + }) + } + + fn is_banned(&self) -> bool { + self.affinity.load(Ordering::Acquire) == AFFINITY_BANNED + } + + fn compute_affinity(&self) -> PeerAffinity { + match self.affinity.load(Ordering::Acquire) { + 0 => PeerAffinity::Allowed, + AFFINITY_BANNED => PeerAffinity::Never, + _ => PeerAffinity::High, + } + } + + fn increase_affinity(&self) -> bool { + let mut current = self.affinity.load(Ordering::Acquire); + while current != AFFINITY_BANNED { + debug_assert_ne!(current, AFFINITY_BANNED - 1); + match self.affinity.compare_exchange_weak( + current, + current + 1, + Ordering::Release, + Ordering::Acquire, + ) { + Ok(_) => return true, + Err(affinity) => current = affinity, } } + + false + } + + fn decrease_affinity(&self) -> bool { + let mut current = self.affinity.load(Ordering::Acquire); + while current != AFFINITY_BANNED { + debug_assert_ne!(current, 0); + match self.affinity.compare_exchange_weak( + current, + current - 1, + Ordering::Release, + Ordering::Acquire, + ) { + Ok(_) => return true, + Err(affinity) => current = affinity, + } + } + + false } - pub fn remove(&self, peer_id: &PeerId) -> Option { - self.0.remove(peer_id).map(|(_, value)| value) + fn try_update_peer_info( + &self, + peer_info: &Arc, + with_affinity: bool, + ) -> Result { + struct AffinityGuard<'a> { + inner: &'a KnownPeerInner, + decrease_on_drop: bool, + } + + impl AffinityGuard<'_> { + fn increase_affinity_or_check_ban(&mut self, with_affinity: bool) -> bool { + let with_affinity = with_affinity && !self.decrease_on_drop; + let is_banned = if with_affinity { + !self.inner.increase_affinity() + } else { + self.inner.is_banned() + }; + + if !is_banned && with_affinity { + self.decrease_on_drop = true; + } + + is_banned + } + } + + impl Drop for AffinityGuard<'_> { + fn drop(&mut self) { + if self.decrease_on_drop { + self.inner.decrease_affinity(); + } + } + } + + // Create a guard to restore the peer affinity in case of an error + let mut guard = AffinityGuard { + inner: self, + decrease_on_drop: false, + }; + + let mut cur = self.peer_info.load(); + let updated = loop { + if guard.increase_affinity_or_check_ban(with_affinity) { + // Do nothing for banned peers + return Err(PeerBannedError); + } + + match cur.created_at.cmp(&peer_info.created_at) { + // Do nothing for the same creation time + // TODO: is `created_at` equality enough? + std::cmp::Ordering::Equal => break true, + // Try to update peer info + std::cmp::Ordering::Less => { + let prev = self.peer_info.compare_and_swap(&*cur, peer_info.clone()); + if std::ptr::eq(cur.as_raw(), prev.as_raw()) { + break true; + } else { + cur = prev; + } + } + // Allow an outdated data + std::cmp::Ordering::Greater => break false, + } + }; + + guard.decrease_on_drop = false; + Ok(updated) } } -#[derive(Debug, Clone)] -pub struct KnownPeer { - pub peer_info: Arc, - pub affinity: PeerAffinity, +const AFFINITY_BANNED: usize = usize::MAX; + +#[derive(Debug, thiserror::Error)] +pub enum KnownPeersError { + #[error(transparent)] + PeerBanned(#[from] PeerBannedError), + #[error("provided peer info is outdated")] + OutdatedInfo, +} + +#[derive(Debug, Copy, Clone, thiserror::Error)] +#[error("peer is banned")] +pub struct PeerBannedError; + +#[cfg(test)] +mod tests { + use super::*; + use crate::util::make_peer_info_stub; + + #[test] + fn remove_from_cache_on_drop_works() { + let peers = KnownPeers::new(); + + let peer_info = make_peer_info_stub(rand::random()); + let handle = peers.insert(peer_info.clone(), false).unwrap(); + assert!(peers.contains(&peer_info.id)); + assert!(!peers.is_banned(&peer_info.id)); + assert_eq!(peers.get(&peer_info.id), Some(peer_info.clone())); + assert_eq!( + peers.get_affinity(&peer_info.id), + Some(PeerAffinity::Allowed) + ); + + assert_eq!(handle.peer_info().as_ref(), peer_info.as_ref()); + assert_eq!(handle.max_affinity(), PeerAffinity::Allowed); + + let other_handle = peers.insert(peer_info.clone(), false).unwrap(); + assert!(peers.contains(&peer_info.id)); + assert!(!peers.is_banned(&peer_info.id)); + assert_eq!(peers.get(&peer_info.id), Some(peer_info.clone())); + assert_eq!( + peers.get_affinity(&peer_info.id), + Some(PeerAffinity::Allowed) + ); + + assert_eq!(other_handle.peer_info().as_ref(), peer_info.as_ref()); + assert_eq!(other_handle.max_affinity(), PeerAffinity::Allowed); + + drop(other_handle); + assert!(peers.contains(&peer_info.id)); + assert!(!peers.is_banned(&peer_info.id)); + assert_eq!(peers.get(&peer_info.id), Some(peer_info.clone())); + assert_eq!( + peers.get_affinity(&peer_info.id), + Some(PeerAffinity::Allowed) + ); + + drop(handle); + assert!(!peers.contains(&peer_info.id)); + assert!(!peers.is_banned(&peer_info.id)); + assert_eq!(peers.get(&peer_info.id), None); + assert_eq!(peers.get_affinity(&peer_info.id), None); + + peers.insert(peer_info.clone(), false).unwrap(); + } + + #[test] + fn with_affinity_after_simple() { + let peers = KnownPeers::new(); + + let peer_info = make_peer_info_stub(rand::random()); + let handle_simple = peers.insert(peer_info.clone(), false).unwrap(); + assert!(peers.contains(&peer_info.id)); + assert_eq!( + peers.get_affinity(&peer_info.id), + Some(PeerAffinity::Allowed) + ); + assert_eq!(handle_simple.max_affinity(), PeerAffinity::Allowed); + + let handle_with_affinity = peers.insert(peer_info.clone(), true).unwrap(); + assert!(peers.contains(&peer_info.id)); + assert_eq!(peers.get_affinity(&peer_info.id), Some(PeerAffinity::High)); + assert_eq!(handle_with_affinity.max_affinity(), PeerAffinity::High); + assert_eq!(handle_simple.max_affinity(), PeerAffinity::High); + + drop(handle_with_affinity); + assert!(peers.contains(&peer_info.id)); + assert_eq!(handle_simple.max_affinity(), PeerAffinity::Allowed); + assert_eq!( + peers.get_affinity(&peer_info.id), + Some(PeerAffinity::Allowed) + ); + + drop(handle_simple); + assert!(!peers.contains(&peer_info.id)); + assert_eq!(peers.get_affinity(&peer_info.id), None); + } + + #[test] + fn with_affinity_before_simple() { + let peers = KnownPeers::new(); + + let peer_info = make_peer_info_stub(rand::random()); + let handle_with_affinity = peers.insert(peer_info.clone(), true).unwrap(); + assert!(peers.contains(&peer_info.id)); + assert_eq!(peers.get_affinity(&peer_info.id), Some(PeerAffinity::High)); + assert_eq!(handle_with_affinity.max_affinity(), PeerAffinity::High); + + let handle_simple = peers.insert(peer_info.clone(), false).unwrap(); + assert!(peers.contains(&peer_info.id)); + assert_eq!(peers.get_affinity(&peer_info.id), Some(PeerAffinity::High)); + assert_eq!(handle_with_affinity.max_affinity(), PeerAffinity::High); + assert_eq!(handle_simple.max_affinity(), PeerAffinity::High); + + drop(handle_simple); + assert!(peers.contains(&peer_info.id)); + assert_eq!(handle_with_affinity.max_affinity(), PeerAffinity::High); + assert_eq!(peers.get_affinity(&peer_info.id), Some(PeerAffinity::High)); + + drop(handle_with_affinity); + assert!(!peers.contains(&peer_info.id)); + assert_eq!(peers.get_affinity(&peer_info.id), None); + } + + #[test] + fn ban_while_handle_exists() { + let peers = KnownPeers::new(); + + let peer_info = make_peer_info_stub(rand::random()); + let handle = peers.insert(peer_info.clone(), false).unwrap(); + assert!(peers.contains(&peer_info.id)); + assert_eq!(handle.max_affinity(), PeerAffinity::Allowed); + + peers.ban(&peer_info.id); + assert!(peers.contains(&peer_info.id)); + assert!(peers.is_banned(&peer_info.id)); + assert_eq!(handle.max_affinity(), PeerAffinity::Never); + assert_eq!(peers.get(&peer_info.id), Some(peer_info.clone())); + assert_eq!(peers.get_affinity(&peer_info.id), Some(PeerAffinity::Never)); + } } diff --git a/network/src/network/crypto.rs b/network/src/network/crypto.rs index 602d1d62b..c4409a40d 100644 --- a/network/src/network/crypto.rs +++ b/network/src/network/crypto.rs @@ -50,10 +50,10 @@ pub(crate) struct CertVerifierWithPeerId { } impl CertVerifierWithPeerId { - pub fn new(service_name: String, peer_id: PeerId) -> Self { + pub fn new(service_name: String, peer_id: &PeerId) -> Self { Self { inner: CertVerifier::from(service_name), - peer_id, + peer_id: *peer_id, } } } diff --git a/network/src/network/endpoint.rs b/network/src/network/endpoint.rs index 64ef1f4cc..bee27e0c0 100644 --- a/network/src/network/endpoint.rs +++ b/network/src/network/endpoint.rs @@ -8,7 +8,7 @@ use std::time::Duration; use anyhow::Result; use crate::network::config::EndpointConfig; -use crate::network::connection::Connection; +use crate::network::connection::{parse_peer_identity, Connection}; use crate::types::{Address, Direction, PeerId}; pub(crate) struct Endpoint { @@ -72,16 +72,11 @@ impl Endpoint { } } - /// Connect to a remote endpoint using the endpoint configuration. - pub fn connect(&self, address: Address) -> Result { - self.connect_with_client_config(self.config.quinn_client_config.clone(), address) - } - /// Connect to a remote endpoint expecting it to have the provided peer id. pub fn connect_with_expected_id( &self, address: Address, - peer_id: PeerId, + peer_id: &PeerId, ) -> Result { let config = self.config.make_client_config_for_peer_id(peer_id)?; self.connect_with_client_config(config, address) @@ -152,6 +147,33 @@ impl Connecting { origin: Direction::Outbound, } } + + pub fn remote_address(&self) -> SocketAddr { + self.inner.remote_address() + } + + pub fn into_0rtt(self) -> Into0RttResult { + match self.inner.into_0rtt() { + Ok((c, accepted)) => match c.peer_identity() { + Some(identity) => match parse_peer_identity(identity) { + Ok(peer_id) => Into0RttResult::Established( + Connection::with_peer_id(c, self.origin, peer_id), + accepted, + ), + Err(e) => Into0RttResult::InvalidConnection(e), + }, + None => Into0RttResult::WithoutIdentity(ConnectingFallback { + inner: Some(c), + accepted, + origin: self.origin, + }), + }, + Err(inner) => Into0RttResult::Unavailable(Self { + inner, + origin: self.origin, + }), + } + } } impl Future for Connecting { @@ -165,3 +187,40 @@ impl Future for Connecting { }) } } + +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub(crate) struct ConnectingFallback { + inner: Option, + accepted: quinn::ZeroRttAccepted, + origin: Direction, +} + +impl Drop for ConnectingFallback { + fn drop(&mut self) { + if let Some(inner) = self.inner.take() { + inner.close(0u8.into(), b"cancelled"); + } + } +} + +impl Future for ConnectingFallback { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::new(&mut self.accepted).poll(cx).map(|_| { + Connection::new( + self.inner + .take() + .expect("future must not be polled after completion"), + self.origin, + ) + }) + } +} + +pub(crate) enum Into0RttResult { + Established(Connection, quinn::ZeroRttAccepted), + WithoutIdentity(ConnectingFallback), + InvalidConnection(anyhow::Error), + Unavailable(#[allow(unused)] Connecting), +} diff --git a/network/src/network/mod.rs b/network/src/network/mod.rs index f0274cfee..eb89b5af5 100644 --- a/network/src/network/mod.rs +++ b/network/src/network/mod.rs @@ -3,7 +3,6 @@ use std::sync::{Arc, Weak}; use anyhow::Result; use everscale_crypto::ed25519; -use rand::Rng; use tokio::sync::{broadcast, mpsc, oneshot}; use self::config::EndpointConfig; @@ -15,7 +14,10 @@ use crate::types::{ pub use self::config::{NetworkConfig, QuicConfig}; pub use self::connection::{Connection, RecvStream, SendStream}; -pub use self::connection_manager::{ActivePeers, KnownPeer, KnownPeers, WeakActivePeers}; +pub use self::connection_manager::{ + ActivePeers, KnownPeerHandle, KnownPeers, KnownPeersError, PeerBannedError, WeakActivePeers, + WeakKnownPeerHandle, +}; pub use self::peer::Peer; mod config; @@ -64,7 +66,7 @@ impl NetworkBuilder<(T1, ())> { } pub fn with_random_private_key(self) -> NetworkBuilder<(T1, [u8; 32])> { - self.with_private_key(rand::thread_rng().gen()) + self.with_private_key(rand::random()) } } @@ -85,6 +87,7 @@ impl NetworkBuilder { let endpoint_config = EndpointConfig::builder() .with_service_name(service_name) .with_private_key(private_key) + .with_0rtt_enabled(config.enable_0rtt) .with_transport_config(quic_config.make_transport_config()) .build()?; @@ -198,18 +201,11 @@ impl Network { Ok(active_peers.subscribe()) } - pub async fn connect(&self, addr: T) -> Result + pub async fn connect(&self, addr: T, peer_id: &PeerId) -> Result where T: Into
, { - self.0.connect(addr.into(), None).await - } - - pub async fn connect_with_peer_id(&self, addr: T, peer_id: &PeerId) -> Result - where - T: Into
, - { - self.0.connect(addr.into(), Some(peer_id)).await + self.0.connect(addr.into(), peer_id).await } pub fn disconnect(&self, peer_id: &PeerId) -> Result<()> { @@ -259,17 +255,19 @@ impl NetworkInner { &self.known_peers } - async fn connect(&self, addr: Address, peer_id: Option<&PeerId>) -> Result { + async fn connect(&self, addr: Address, peer_id: &PeerId) -> Result { + #[derive(thiserror::Error, Debug)] + #[error(transparent)] + struct ConnectionError(Arc); + let (tx, rx) = oneshot::channel(); self.connection_manager_handle - .send(ConnectionManagerRequest::Connect( - addr, - peer_id.copied(), - tx, - )) + .send(ConnectionManagerRequest::Connect(addr, *peer_id, tx)) .await .map_err(|_e| NetworkShutdownError)?; - rx.await? + + let res = rx.await?; + res.map_err(|e| anyhow::Error::new(ConnectionError(e))) } fn disconnect(&self, peer_id: &PeerId) -> Result<()> { @@ -309,7 +307,8 @@ mod tests { use tracing_test::traced_test; use super::*; - use crate::types::{service_query_fn, BoxCloneService}; + use crate::types::{service_query_fn, BoxCloneService, PeerInfo, Request}; + use crate::util::NetworkExt; fn echo_service() -> BoxCloneService { let handle = |request: ServiceRequest| async move { @@ -323,32 +322,87 @@ mod tests { service_query_fn(handle).boxed_clone() } - #[tokio::test] - #[traced_test] - async fn connection_manager_works() -> anyhow::Result<()> { - let peer1 = Network::builder() + fn make_network(service_name: &str) -> Result { + Network::builder() + .with_config(NetworkConfig { + enable_0rtt: true, + ..Default::default() + }) .with_random_private_key() - .with_service_name("tycho") - .build("127.0.0.1:0", echo_service())?; + .with_service_name(service_name) + .build("127.0.0.1:0", echo_service()) + } - let peer2 = Network::builder() - .with_random_private_key() - .with_service_name("tycho") - .build("127.0.0.1:0", echo_service())?; + fn make_peer_info(network: &Network) -> Arc { + Arc::new(PeerInfo { + id: *network.peer_id(), + address_list: vec![network.local_addr().into()].into_boxed_slice(), + created_at: 0, + expires_at: u32::MAX, + signature: Box::new([0; 64]), + }) + } - let peer3 = Network::builder() - .with_random_private_key() - .with_service_name("not-tycho") - .build("127.0.0.1:0", echo_service())?; + #[traced_test] + #[tokio::test] + async fn connection_manager_works() -> Result<()> { + let peer1 = make_network("tycho")?; + let peer2 = make_network("tycho")?; + let peer3 = make_network("not-tycho")?; + + assert!(peer1 + .connect(peer2.local_addr(), peer2.peer_id()) + .await + .is_ok()); + assert!(peer2 + .connect(peer1.local_addr(), peer1.peer_id()) + .await + .is_ok()); + + assert!(peer1 + .connect(peer3.local_addr(), peer3.peer_id()) + .await + .is_err()); + assert!(peer2 + .connect(peer3.local_addr(), peer3.peer_id()) + .await + .is_err()); + + assert!(peer3 + .connect(peer1.local_addr(), peer1.peer_id()) + .await + .is_err()); + assert!(peer3 + .connect(peer2.local_addr(), peer2.peer_id()) + .await + .is_err()); + + Ok(()) + } + + #[traced_test] + #[tokio::test] + async fn simultaneous_queries() -> Result<()> { + tracing_subscriber::fmt::try_init().ok(); - assert!(peer1.connect(peer2.local_addr()).await.is_ok()); - assert!(peer2.connect(peer1.local_addr()).await.is_ok()); + for _ in 0..10 { + let peer1 = make_network("tycho")?; + let peer2 = make_network("tycho")?; - assert!(peer1.connect(peer3.local_addr()).await.is_err()); - assert!(peer2.connect(peer3.local_addr()).await.is_err()); + let _peer1_peer2_handle = peer1.known_peers().insert(make_peer_info(&peer2), false)?; + let _peer2_peer1_handle = peer2.known_peers().insert(make_peer_info(&peer1), false)?; - assert!(peer3.connect(peer1.local_addr()).await.is_err()); - assert!(peer3.connect(peer2.local_addr()).await.is_err()); + let req = Request { + version: Default::default(), + body: "hello".into(), + }; + let peer1_fut = std::pin::pin!(peer1.query(peer2.peer_id(), req.clone())); + let peer2_fut = std::pin::pin!(peer2.query(peer1.peer_id(), req.clone())); + + let (res1, res2) = futures_util::future::join(peer1_fut, peer2_fut).await; + assert_eq!(res1?.body, req.body); + assert_eq!(res2?.body, req.body); + } Ok(()) } diff --git a/network/src/network/wire.rs b/network/src/network/wire.rs index 3b08eb789..c1953e9fa 100644 --- a/network/src/network/wire.rs +++ b/network/src/network/wire.rs @@ -18,7 +18,7 @@ pub(crate) fn make_codec(config: &NetworkConfig) -> LengthDelimitedCodec { builder.length_field_length(4).big_endian().new_codec() } -pub(crate) async fn handshake(connection: Connection) -> Result { +pub(crate) async fn handshake(connection: &Connection) -> Result<()> { match connection.origin() { Direction::Inbound => { let mut send_stream = connection.open_uni().await?; @@ -30,7 +30,7 @@ pub(crate) async fn handshake(connection: Connection) -> Result { recv_version(&mut recv_stream).await?; } } - Ok(connection) + Ok(()) } pub(crate) async fn send_request( diff --git a/network/src/overlay/config.rs b/network/src/overlay/config.rs index ad5f7b58a..f1094ee32 100644 --- a/network/src/overlay/config.rs +++ b/network/src/overlay/config.rs @@ -6,12 +6,6 @@ use tycho_util::serde_helpers; #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(default)] pub struct OverlayConfig { - /// Maximum time to live for public overlay peer entries. - /// - /// Default: 1 hour. - #[serde(with = "serde_helpers::humantime")] - pub max_public_entry_tll: Duration, - /// A period of exchanging public overlay peers. /// /// Default: 3 minutes. @@ -32,7 +26,6 @@ pub struct OverlayConfig { impl Default for OverlayConfig { fn default() -> Self { Self { - max_public_entry_tll: Duration::from_secs(3600), public_overlay_peer_exchange_period: Duration::from_secs(3 * 60), public_overlay_peer_exchange_max_jitter: Duration::from_secs(30), exchange_public_entries_batch: 20, diff --git a/network/src/overlay/mod.rs b/network/src/overlay/mod.rs index 03d09d9f5..78f0ccb5d 100644 --- a/network/src/overlay/mod.rs +++ b/network/src/overlay/mod.rs @@ -1,3 +1,5 @@ +use std::collections::hash_map; +use std::future::Future; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll, Waker}; @@ -7,11 +9,12 @@ use bytes::Buf; use futures_util::{Stream, StreamExt}; use tl_proto::{TlError, TlRead}; use tokio::sync::Notify; -use tokio::task::JoinHandle; +use tokio::task::{AbortHandle, JoinSet}; use tycho_util::futures::BoxFutureOrNoop; use tycho_util::time::{now_sec, shifted_interval}; use tycho_util::{FastDashMap, FastHashMap, FastHashSet}; +use crate::dht::DhtService; use crate::network::{Network, WeakNetwork}; use crate::proto::overlay::{rpc, PublicEntriesResponse, PublicEntry, PublicEntryToSign}; use crate::types::{PeerId, Request, Response, Service, ServiceRequest}; @@ -20,8 +23,8 @@ use crate::util::{NetworkExt, Routable}; pub use self::config::OverlayConfig; pub use self::overlay_id::OverlayId; pub use self::private_overlay::{ - PrivateOverlay, PrivateOverlayBuilder, PrivateOverlayEntries, PrivateOverlayEntriesReadGuard, - PrivateOverlayEntriesWriteGuard, + PrivateOverlay, PrivateOverlayBuilder, PrivateOverlayEntries, PrivateOverlayEntriesEvent, + PrivateOverlayEntriesReadGuard, PrivateOverlayEntriesWriteGuard, }; pub use self::public_overlay::{ PublicOverlay, PublicOverlayBuilder, PublicOverlayEntries, PublicOverlayEntriesReadGuard, @@ -34,20 +37,20 @@ mod public_overlay; pub struct OverlayServiceBackgroundTasks { inner: Arc, + dht: Option, } impl OverlayServiceBackgroundTasks { - pub fn spawn(self, network: Network) { + pub fn spawn(self, network: &Network) { self.inner - .start_background_tasks(Network::downgrade(&network)); + .start_background_tasks(Network::downgrade(network), self.dht); } } pub struct OverlayServiceBuilder { local_id: PeerId, config: Option, - private_overlays: FastDashMap, - public_overlays: FastDashMap, + dht: Option, } impl OverlayServiceBuilder { @@ -56,41 +59,8 @@ impl OverlayServiceBuilder { self } - pub fn with_private_overlay(self, overlay: &PrivateOverlay) -> Self { - assert!( - !self.public_overlays.contains_key(overlay.overlay_id()), - "public overlay with id {} already exists", - overlay.overlay_id() - ); - - let prev = self - .private_overlays - .insert(*overlay.overlay_id(), overlay.clone()); - if let Some(prev) = prev { - panic!( - "private overlay with id {} already exists", - prev.overlay_id() - ); - } - self - } - - pub fn with_public_overlay(self, overlay: &PublicOverlay) -> Self { - assert!( - !self.private_overlays.contains_key(overlay.overlay_id()), - "private overlay with id {} already exists", - overlay.overlay_id() - ); - - let prev = self - .public_overlays - .insert(*overlay.overlay_id(), overlay.clone()); - if let Some(prev) = prev { - panic!( - "public overlay with id {} already exists", - prev.overlay_id() - ); - } + pub fn with_dht_service(mut self, dht: DhtService) -> Self { + self.dht = Some(dht); self } @@ -100,13 +70,15 @@ impl OverlayServiceBuilder { let inner = Arc::new(OverlayServiceInner { local_id: self.local_id, config, - private_overlays: self.private_overlays, - public_overlays: self.public_overlays, + private_overlays: Default::default(), + public_overlays: Default::default(), public_overlays_changed: Arc::new(Notify::new()), + private_overlays_changed: Arc::new(Notify::new()), }); let background_tasks = OverlayServiceBackgroundTasks { inner: inner.clone(), + dht: self.dht, }; (background_tasks, OverlayService(inner)) @@ -121,17 +93,24 @@ impl OverlayService { OverlayServiceBuilder { local_id, config: None, - private_overlays: Default::default(), - public_overlays: Default::default(), + dht: None, } } - pub fn try_add_private_overlay(&self, overlay: &PrivateOverlay) -> bool { - self.0.try_add_private_overlay(overlay) + pub fn add_private_overlay(&self, overlay: &PrivateOverlay) -> bool { + self.0.add_private_overlay(overlay) } - pub fn try_add_public_overlay(&self, overlay: &PublicOverlay) -> bool { - self.0.try_add_public_overlay(overlay) + pub fn remove_private_overlay(&self, overlay_id: &OverlayId) -> bool { + self.0.remove_private_overlay(overlay_id) + } + + pub fn add_public_overlay(&self, overlay: &PublicOverlay) -> bool { + self.0.add_public_overlay(overlay) + } + + pub fn remove_public_overlay(&self, overlay_id: &OverlayId) -> bool { + self.0.remove_public_overlay(overlay_id) } } @@ -273,24 +252,23 @@ struct OverlayServiceInner { public_overlays: FastDashMap, private_overlays: FastDashMap, public_overlays_changed: Arc, + private_overlays_changed: Arc, } impl OverlayServiceInner { - fn start_background_tasks(self: &Arc, network: WeakNetwork) { + fn start_background_tasks(self: &Arc, network: WeakNetwork, _dht: Option) { + // TODO: Store public overlay entries in the DHT. + enum Action<'a> { - UpdatePublicOverlaysList { - exchange_state: &'a mut ExchangeState, - }, - ExchangePublicEntries { - exchange_state: &'a mut ExchangeState, + UpdatePublicOverlaysList(&'a mut PublicOverlaysState), + ExchangePublicOverlayEntries { overlay_id: OverlayId, + exchange: &'a mut OverlayTaskSet, }, } - #[derive(Default)] - struct ExchangeState { - stream: PublicOverlayActionsStream, - futures: FastHashMap>>, + struct PublicOverlaysState { + exchange: OverlayTaskSet, } let public_overlays_notify = self.public_overlays_changed.clone(); @@ -299,31 +277,32 @@ impl OverlayServiceInner { tokio::spawn(async move { tracing::debug!("background overlay loop started"); - let mut exchange_state = None::; let mut public_overlays_changed = Box::pin(public_overlays_notify.notified()); + let mut public_overlays_state = None::; + loop { - let action = match &mut exchange_state { - // Initial update - None => Action::UpdatePublicOverlaysList { - exchange_state: exchange_state.get_or_insert_with(Default::default), - }, + let action = match &mut public_overlays_state { + // Initial update for public overlays list + None => Action::UpdatePublicOverlaysList(public_overlays_state.insert( + PublicOverlaysState { + exchange: OverlayTaskSet::new("exchange public overlay peers"), + }, + )), // Default actions - Some(exchange_state) => { + Some(public_overlays_state) => { tokio::select! { _ = &mut public_overlays_changed => { public_overlays_changed = Box::pin(public_overlays_notify.notified()); - Action::UpdatePublicOverlaysList { - exchange_state - } + Action::UpdatePublicOverlaysList(public_overlays_state) }, - overlay_id = exchange_state.stream.next() => match overlay_id { - Some(id) => Action::ExchangePublicEntries { - exchange_state, - overlay_id: id + overlay_id = public_overlays_state.exchange.next() => match overlay_id { + Some(id) => Action::ExchangePublicOverlayEntries { + overlay_id: id, + exchange: &mut public_overlays_state.exchange, }, None => continue, - } + }, } } }; @@ -333,40 +312,22 @@ impl OverlayServiceInner { }; match action { - Action::UpdatePublicOverlaysList { exchange_state } => exchange_state.stream.rebuild( - this.public_overlays.iter().map(|item| *item.key()), - |_| { + Action::UpdatePublicOverlaysList(PublicOverlaysState { exchange }) => { + let iter = this.public_overlays.iter().map(|item| *item.key()); + exchange.rebuild(iter.clone(), |_| { shifted_interval( this.config.public_overlay_peer_exchange_period, this.config.public_overlay_peer_exchange_max_jitter, ) - }, - |overlay_id| { - if let Some(fut) = exchange_state.futures.remove(overlay_id).flatten() { - tracing::debug!(%overlay_id, "cancelling exchange public entries task"); - fut.abort(); - } - }, - ), - Action::ExchangePublicEntries { exchange_state, overlay_id } => { - let fut_entry = exchange_state.futures.entry(overlay_id).or_default(); - - // Wait for the previous exchange to finish. - if let Some(fut) = fut_entry.take() { - if let Err(e) = fut.await { - if e.is_panic() { - std::panic::resume_unwind(e.into_panic()); - } - } - } - - // Spawn a new exchange - *fut_entry = Some(tokio::spawn(async move { - let res = this.exchange_public_entries(&network, &overlay_id).await; - if let Err(e) = res { - tracing::error!(%overlay_id, "failed to exchange public entries: {e:?}"); - }; - })); + }); + } + Action::ExchangePublicOverlayEntries { + exchange: exchange_state, + overlay_id, + } => { + exchange_state.spawn(&overlay_id, move || async move { + this.exchange_public_entries(&network, &overlay_id).await + }); } } } @@ -385,10 +346,15 @@ impl OverlayServiceInner { network: &Network, overlay_id: &OverlayId, ) -> Result<()> { - let Some(overlay) = self.public_overlays.get(overlay_id) else { - anyhow::bail!("overlay not found"); + let overlay = if let Some(overlay) = self.public_overlays.get(overlay_id) { + overlay.value().clone() + } else { + tracing::debug!(%overlay_id, "overlay not found"); + return Ok(()); }; + overlay.remove_invalid_entries(now_sec()); + let n = std::cmp::max(self.config.exchange_public_entries_batch, 1); let mut entries = Vec::with_capacity(n); @@ -408,12 +374,12 @@ impl OverlayServiceInner { // TODO: search for target in known peers. This is a stub which will not work. let peer_id = match iter.next() { - Some(entry) => entry.peer_id, + Some(item) => item.entry.peer_id, None => anyhow::bail!("empty overlay, no peers to exchange entries with"), }; // Add additional random entries to the response - entries.extend(iter.cloned()); + entries.extend(iter.map(|item| item.entry.clone())); // Use this peer id for the request peer_id @@ -439,7 +405,7 @@ impl OverlayServiceInner { count = entries.len(), "received public entries" ); - overlay.add_untrusted_entries(&entries); + overlay.add_untrusted_entries(&entries, now_sec()); } PublicEntriesResponse::OverlayNotFound => { tracing::debug!(%peer_id, "overlay not found"); @@ -468,7 +434,7 @@ impl OverlayServiceInner { } } - pub fn try_add_private_overlay(&self, overlay: &PrivateOverlay) -> bool { + fn add_private_overlay(&self, overlay: &PrivateOverlay) -> bool { use dashmap::mapref::entry::Entry; if self.public_overlays.contains_key(overlay.overlay_id()) { @@ -477,13 +443,22 @@ impl OverlayServiceInner { match self.private_overlays.entry(*overlay.overlay_id()) { Entry::Vacant(entry) => { entry.insert(overlay.clone()); + self.private_overlays_changed.notify_waiters(); true } Entry::Occupied(_) => false, } } - pub fn try_add_public_overlay(&self, overlay: &PublicOverlay) -> bool { + fn remove_private_overlay(&self, overlay_id: &OverlayId) -> bool { + let removed = self.private_overlays.remove(overlay_id).is_some(); + if removed { + self.private_overlays_changed.notify_waiters(); + } + removed + } + + fn add_public_overlay(&self, overlay: &PublicOverlay) -> bool { use dashmap::mapref::entry::Entry; if self.private_overlays.contains_key(overlay.overlay_id()) { @@ -499,6 +474,14 @@ impl OverlayServiceInner { } } + fn remove_public_overlay(&self, overlay_id: &OverlayId) -> bool { + let removed = self.public_overlays.remove(overlay_id).is_some(); + if removed { + self.public_overlays_changed.notify_waiters(); + } + removed + } + fn handle_exchange_public_entries( &self, req: &rpc::ExchangeRandomPublicEntries, @@ -513,7 +496,7 @@ impl OverlayServiceInner { }; // Add proposed entries to the overlay - overlay.add_untrusted_entries(&req.entries); + overlay.add_untrusted_entries(&req.entries, now_sec()); // Collect proposed entries to exclude from the response let requested_ids = req @@ -529,9 +512,9 @@ impl OverlayServiceInner { let n = self.config.exchange_public_entries_batch; entries .choose_multiple(&mut rand::thread_rng(), n + requested_ids.len()) - .filter_map(|entry| { - let is_new = !requested_ids.contains(&entry.peer_id); - is_new.then(|| entry.clone()) + .filter_map(|item| { + let is_new = !requested_ids.contains(&item.entry.peer_id); + is_new.then(|| item.entry.clone()) }) .take(n) .collect::>() @@ -541,13 +524,123 @@ impl OverlayServiceInner { } } +struct OverlayTaskSet { + name: &'static str, + stream: OverlayActionsStream, + handles: FastHashMap, + join_set: JoinSet, +} + +impl OverlayTaskSet { + fn new(name: &'static str) -> Self { + Self { + name, + stream: Default::default(), + handles: Default::default(), + join_set: Default::default(), + } + } + + async fn next(&mut self) -> Option { + use futures_util::future::{select, Either}; + + loop { + // Wait until the next interval or completed task + let res = { + let next = std::pin::pin!(self.stream.next()); + let joined = std::pin::pin!(self.join_set.join_next()); + match select(next, joined).await { + // Handle interval events first + Either::Left((id, _)) => return id, + // Handled task completion otherwise + Either::Right((joined, fut)) => match joined { + Some(res) => res, + None => return fut.await, + }, + } + }; + + // If some task was joined + match res { + // Task was completed successfully + Ok(overlay_id) => { + return if matches!(self.handles.remove(&overlay_id), Some((_, true))) { + // Reset interval and execute task immediately + self.stream.reset_interval(&overlay_id); + Some(overlay_id) + } else { + None + }; + } + // Propagate task panic + Err(e) if e.is_panic() => { + tracing::error!(task = self.name, "task panicked"); + std::panic::resume_unwind(e.into_panic()); + } + // Task cancelled, loop once more with the next task + Err(_) => continue, + } + } + } + + fn rebuild(&mut self, iter: I, f: F) + where + I: Iterator, + for<'a> F: FnMut(&'a OverlayId) -> tokio::time::Interval, + { + self.stream.rebuild(iter, f, |overlay_id| { + if let Some((handle, _)) = self.handles.remove(overlay_id) { + tracing::debug!(task = self.name, %overlay_id, "task cancelled"); + handle.abort(); + } + }); + } + + fn spawn(&mut self, overlay_id: &OverlayId, f: F) + where + F: FnOnce() -> Fut, + Fut: Future> + Send + 'static, + { + match self.handles.entry(*overlay_id) { + hash_map::Entry::Vacant(entry) => { + let fut = { + let fut = f(); + let task = self.name; + let overlay_id = *overlay_id; + async move { + if let Err(e) = fut.await { + tracing::error!(task, %overlay_id, "task failed: {e:?}"); + } + overlay_id + } + }; + entry.insert((self.join_set.spawn(fut), false)); + } + hash_map::Entry::Occupied(mut entry) => { + tracing::warn!( + task = self.name, + %overlay_id, + "task is running longer than expected", + ); + entry.get_mut().1 = true; + } + } + } +} + #[derive(Default)] -struct PublicOverlayActionsStream { +struct OverlayActionsStream { intervals: Vec<(tokio::time::Interval, OverlayId)>, waker: Option, } -impl PublicOverlayActionsStream { +impl OverlayActionsStream { + fn reset_interval(&mut self, overlay_id: &OverlayId) { + if let Some((interval, _)) = self.intervals.iter_mut().find(|(_, id)| id == overlay_id) { + interval.reset(); + } + } + fn rebuild, A, R>( &mut self, iter: I, @@ -576,7 +669,7 @@ impl PublicOverlayActionsStream { } } -impl Stream for PublicOverlayActionsStream { +impl Stream for OverlayActionsStream { type Item = OverlayId; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { diff --git a/network/src/overlay/private_overlay.rs b/network/src/overlay/private_overlay.rs index 25d4040e8..0b3cc84b0 100644 --- a/network/src/overlay/private_overlay.rs +++ b/network/src/overlay/private_overlay.rs @@ -7,9 +7,11 @@ use bytes::{Bytes, BytesMut}; use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard}; use rand::seq::SliceRandom; use rand::Rng; +use tokio::sync::broadcast; use tycho_util::futures::BoxFutureOrNoop; use tycho_util::{FastHashMap, FastHashSet}; +use crate::dht::{PeerResolver, PeerResolverHandle}; use crate::network::Network; use crate::overlay::OverlayId; use crate::proto::overlay::rpc; @@ -19,6 +21,8 @@ use crate::util::NetworkExt; pub struct PrivateOverlayBuilder { overlay_id: OverlayId, entries: FastHashSet, + entry_events_channel_size: usize, + peer_resolver: Option, } impl PrivateOverlayBuilder { @@ -32,6 +36,22 @@ impl PrivateOverlayBuilder { self } + /// The capacity of entries set events. + /// + /// Default: 100. + pub fn with_entry_events_channel_size(mut self, entry_events_channel_size: usize) -> Self { + self.entry_events_channel_size = entry_events_channel_size; + self + } + + /// Whether to resolve peers with the provided resolver. + /// + /// Does not resolve peers by default. + pub fn with_peer_resolver(mut self, peer_resolver: PeerResolver) -> Self { + self.peer_resolver = Some(peer_resolver); + self + } + pub fn build(self, service: S) -> PrivateOverlay where S: Send + Sync + 'static, @@ -44,6 +64,8 @@ impl PrivateOverlayBuilder { let mut entries = PrivateOverlayEntries { peer_id_to_index: Default::default(), data: Default::default(), + events_tx: broadcast::channel(self.entry_events_channel_size).0, + peer_resolver: self.peer_resolver, }; for peer_id in self.entries { entries.insert(&peer_id); @@ -70,6 +92,8 @@ impl PrivateOverlay { PrivateOverlayBuilder { overlay_id, entries: Default::default(), + entry_events_channel_size: 100, + peer_resolver: None, } } @@ -150,14 +174,30 @@ struct Inner { request_prefix: Box<[u8]>, } +// NOTE: `#[derive(Default)]` is missing to prevent construction outside the +// crate. pub struct PrivateOverlayEntries { peer_id_to_index: FastHashMap, - data: Vec, + data: Vec, + events_tx: broadcast::Sender, + peer_resolver: Option, } impl PrivateOverlayEntries { + /// Subscribes to the set updates. + pub fn subscribe(&self) -> broadcast::Receiver { + self.events_tx.subscribe() + } + + /// Returns an iterator over the entry ids. + /// + /// The order is not random, but is not defined. + pub fn iter(&self) -> std::slice::Iter<'_, PrivateOverlayEntryData> { + self.data.iter() + } + /// Returns one random peer, or `None` if set is empty. - pub fn choose(&self, rng: &mut R) -> Option<&PeerId> + pub fn choose(&self, rng: &mut R) -> Option<&PrivateOverlayEntryData> where R: Rng + ?Sized, { @@ -170,18 +210,41 @@ impl PrivateOverlayEntries { &self, rng: &mut R, n: usize, - ) -> rand::seq::SliceChooseIter<'_, [PeerId], PeerId> + ) -> rand::seq::SliceChooseIter<'_, [PrivateOverlayEntryData], PrivateOverlayEntryData> where R: Rng + ?Sized, { self.data.choose_multiple(rng, n) } + /// Clears the set, removing all entries. + pub fn clear(&mut self) { + self.peer_id_to_index.clear(); + self.data.clear(); + } + + /// Returns `true` if the set contains no elements. + pub fn is_empty(&self) -> bool { + self.data.is_empty() + } + + /// Returns the number of elements in the set, also referred to as its 'length'. + pub fn len(&self) -> usize { + self.data.len() + } + /// Returns true if the set contains the specified peer id. pub fn contains(&self, peer_id: &PeerId) -> bool { self.peer_id_to_index.contains_key(peer_id) } + /// Returns the peer resolver handle for the specified peer id, if it exists. + pub fn get_handle(&self, peer_id: &PeerId) -> Option<&PeerResolverHandle> { + self.peer_id_to_index + .get(peer_id) + .map(|&index| &self.data[index].resolver_handle) + } + /// Adds a peer id to the set. /// /// Returns whether the value was newly inserted. @@ -190,7 +253,20 @@ impl PrivateOverlayEntries { // No entry for the peer_id, insert a new one hash_map::Entry::Vacant(entry) => { entry.insert(self.data.len()); - self.data.push(*peer_id); + + let handle = self.peer_resolver.as_ref().map_or_else( + || PeerResolverHandle::new_noop(peer_id), + |resolver| resolver.insert(peer_id, true), + ); + + self.data.push(PrivateOverlayEntryData { + peer_id: *peer_id, + resolver_handle: handle, + }); + + _ = self + .events_tx + .send(PrivateOverlayEntriesEvent::Added(*peer_id)); true } // Entry for the peer_id exists, do nothing @@ -202,22 +278,36 @@ impl PrivateOverlayEntries { /// /// Returns whether the value was present in the set. pub fn remove(&mut self, peer_id: &PeerId) -> bool { - let Some(index) = self.peer_id_to_index.remove(peer_id) else { + let Some(link) = self.peer_id_to_index.remove(peer_id) else { return false; }; // Remove the entry from the data vector - self.data.swap_remove(index); + self.data.swap_remove(link); + self.fix_data_index(link); - // Update the swapped entry's index - let entry = self - .peer_id_to_index - .get_mut(&self.data[index]) - .expect("inconsistent state"); - *entry = index; + _ = self + .events_tx + .send(PrivateOverlayEntriesEvent::Removed(*peer_id)); true } + + fn fix_data_index(&mut self, index: usize) { + if index < self.data.len() { + let link = self + .peer_id_to_index + .get_mut(&self.data[index].peer_id) + .expect("inconsistent data state"); + *link = index; + } + } +} + +#[derive(Clone)] +pub struct PrivateOverlayEntryData { + pub peer_id: PeerId, + pub resolver_handle: PeerResolverHandle, } pub struct PrivateOverlayEntriesWriteGuard<'a> { @@ -252,3 +342,82 @@ impl std::ops::Deref for PrivateOverlayEntriesReadGuard<'_> { &self.entries } } + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum PrivateOverlayEntriesEvent { + /// A new entry was inserted. + Added(PeerId), + /// An existing entry was removed. + Removed(PeerId), +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn entries_container_is_set() { + let mut entries = PrivateOverlayEntries { + peer_id_to_index: Default::default(), + data: Default::default(), + peer_resolver: None, + events_tx: broadcast::channel(100).0, + }; + assert!(entries.is_empty()); + assert_eq!(entries.len(), 0); + + let peer_id = rand::random(); + assert!(entries.insert(&peer_id)); + + assert!(!entries.is_empty()); + assert_eq!(entries.len(), 1); + + assert!(!entries.insert(&peer_id)); + assert_eq!(entries.len(), 1); + + entries.clear(); + assert!(entries.is_empty()); + assert_eq!(entries.len(), 0); + } + + #[test] + fn remove_from_entries_container() { + let (events_tx, mut events_rx) = broadcast::channel(100); + + let mut entries = PrivateOverlayEntries { + peer_id_to_index: Default::default(), + data: Default::default(), + peer_resolver: None, + events_tx, + }; + + let peer_ids = std::array::from_fn::(|_| rand::random()); + for (i, peer_id) in peer_ids.iter().enumerate() { + assert!(entries.insert(peer_id)); + assert_eq!(entries.len(), i + 1); + assert_eq!(entries.data.len(), i + 1); + assert_eq!( + events_rx.try_recv().unwrap(), + PrivateOverlayEntriesEvent::Added(*peer_id) + ); + } + + for peer_id in &peer_ids { + assert!(entries.remove(peer_id)); + assert_eq!( + events_rx.try_recv().unwrap(), + PrivateOverlayEntriesEvent::Removed(*peer_id) + ); + + assert!(entries.data.iter().all(|entry| entry.peer_id != peer_id)); + for (index, entry) in entries.data.iter().enumerate() { + assert_eq!(entries.peer_id_to_index[&entry.peer_id], index); + } + } + + assert!(entries.is_empty()); + + assert!(!entries.remove(&rand::random())); + assert!(events_rx.try_recv().is_err()); + } +} diff --git a/network/src/overlay/public_overlay.rs b/network/src/overlay/public_overlay.rs index 0ead8edd7..5a7de4221 100644 --- a/network/src/overlay/public_overlay.rs +++ b/network/src/overlay/public_overlay.rs @@ -2,6 +2,7 @@ use std::borrow::Borrow; use std::collections::hash_map; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; +use std::time::Duration; use anyhow::Result; use bytes::{Bytes, BytesMut}; @@ -11,6 +12,7 @@ use rand::Rng; use tycho_util::futures::BoxFutureOrNoop; use tycho_util::{FastDashSet, FastHashMap}; +use crate::dht::{PeerResolver, PeerResolverHandle}; use crate::network::Network; use crate::overlay::OverlayId; use crate::proto::overlay::{rpc, PublicEntry, PublicEntryToSign}; @@ -20,7 +22,9 @@ use crate::util::NetworkExt; pub struct PublicOverlayBuilder { overlay_id: OverlayId, min_capacity: usize, + entry_ttl: Duration, banned_peer_ids: FastDashSet, + peer_resolver: Option, } impl PublicOverlayBuilder { @@ -34,6 +38,15 @@ impl PublicOverlayBuilder { self } + /// Time-to-live for each entry in the overlay. + /// + /// Default: 1 hour. + pub fn with_entry_ttl(mut self, entry_ttl: Duration) -> Self { + self.entry_ttl = entry_ttl; + self + } + + /// Banned peers that will not be ignored by the overlay. pub fn with_banned_peers(mut self, banned_peers: I) -> Self where I: IntoIterator, @@ -44,6 +57,14 @@ impl PublicOverlayBuilder { self } + /// Whether to resolve peers with the provided resolver. + /// + /// Does not resolve peers by default. + pub fn with_peer_resolver(mut self, peer_resolver: PeerResolver) -> Self { + self.peer_resolver = Some(peer_resolver); + self + } + pub fn build(self, service: S) -> PublicOverlay where S: Send + Sync + 'static, @@ -53,10 +74,13 @@ impl PublicOverlayBuilder { overlay_id: self.overlay_id.as_bytes(), }); + let entry_ttl_sec = self.entry_ttl.as_secs().try_into().unwrap_or(u32::MAX); + PublicOverlay { inner: Arc::new(Inner { overlay_id: self.overlay_id, min_capacity: self.min_capacity, + entry_ttl_sec, entries: RwLock::new(Default::default()), entry_count: AtomicUsize::new(0), banned_peer_ids: self.banned_peer_ids, @@ -78,7 +102,9 @@ impl PublicOverlay { PublicOverlayBuilder { overlay_id, min_capacity: 100, + entry_ttl: Duration::from_secs(3600), banned_peer_ids: Default::default(), + peer_resolver: None, } } @@ -148,7 +174,7 @@ impl PublicOverlay { /// Adds the given entries to the overlay. /// /// NOTE: Will deadlock if called while `PublicOverlayEntriesReadGuard` is held. - pub(crate) fn add_untrusted_entries(&self, entries: &[Arc]) { + pub(crate) fn add_untrusted_entries(&self, entries: &[Arc], now: u32) { if entries.is_empty() { return; } @@ -179,11 +205,19 @@ impl PublicOverlay { // Prepare validation state let mut is_valid = vec![false; entries.len()]; - let mut valid_count = 0; + let mut has_valid = false; // First pass: verify all entries for (entry, is_valid) in std::iter::zip(entries, is_valid.iter_mut()) { + if entry.is_expired(now, this.entry_ttl_sec) + || self.inner.banned_peer_ids.contains(&entry.peer_id) + { + // Skip expired or banned peers early + continue; + } + let Some(pubkey) = entry.peer_id.as_public_key() else { + // Skip entries with invalid public keys continue; }; @@ -195,15 +229,14 @@ impl PublicOverlay { }, &entry.signature, ) { + // Skip entries with invalid signatures continue; } + // NOTE: check all entries, even if we have more than `to_add`. + // We might need them if some are duplicates af known entries. *is_valid = true; - valid_count += 1; - - if valid_count >= to_add { - break; - } + has_valid = true; } // Second pass: insert all valid entries (if any) @@ -211,22 +244,39 @@ impl PublicOverlay { // NOTE: two passes are necessary because public key parsing and // signature verification can be expensive and we want to avoid // holding the lock for too long. - if valid_count > 0 { + let mut added = 0; + if has_valid { let mut stored = this.entries.write(); for (entry, is_valid) in std::iter::zip(entries, is_valid) { - if is_valid { - stored.insert(entry); + if !is_valid { + continue; + } + + added += stored.insert(entry) as usize; + if added >= to_add { + break; } } } // Rollback entries that were not valid and not inserted - if valid_count < to_add { + if added < to_add { this.entry_count - .fetch_sub(to_add - valid_count, Ordering::Release); + .fetch_sub(to_add - added, Ordering::Release); } } + /// Removes all expired and banned entries from the overlay. + pub(crate) fn remove_invalid_entries(&self, now: u32) { + let this = self.inner.as_ref(); + + let mut entries = this.entries.write(); + entries.retain(|item| { + !item.entry.is_expired(now, this.entry_ttl_sec) + && !this.banned_peer_ids.contains(&item.entry.peer_id) + }); + } + fn prepend_prefix_to_body(&self, body: &mut Bytes) { let this = self.inner.as_ref(); @@ -249,6 +299,7 @@ impl std::fmt::Debug for PublicOverlay { struct Inner { overlay_id: OverlayId, min_capacity: usize, + entry_ttl_sec: u32, entries: RwLock, entry_count: AtomicUsize, banned_peer_ids: FastDashSet, @@ -259,13 +310,14 @@ struct Inner { #[derive(Default)] pub struct PublicOverlayEntries { peer_id_to_index: FastHashMap, - data: Vec>, + data: Vec, + peer_resolver: Option, } impl PublicOverlayEntries { /// Returns a reference to one random element of the slice, /// or `None` if the slice is empty. - pub fn choose(&self, rng: &mut R) -> Option<&Arc> + pub fn choose(&self, rng: &mut R) -> Option<&PublicOverlayEntryData> where R: Rng + ?Sized, { @@ -278,7 +330,7 @@ impl PublicOverlayEntries { &self, rng: &mut R, n: usize, - ) -> rand::seq::SliceChooseIter<'_, [Arc], Arc> + ) -> rand::seq::SliceChooseIter<'_, [PublicOverlayEntryData], PublicOverlayEntryData> where R: Rng + ?Sized, { @@ -290,46 +342,57 @@ impl PublicOverlayEntries { // No entry for the peer_id, insert a new one hash_map::Entry::Vacant(entry) => { entry.insert(self.data.len()); - self.data.push(Arc::new(item.clone())); + + let resolver_handle = self.peer_resolver.as_ref().map_or_else( + || PeerResolverHandle::new_noop(&item.peer_id), + |resolver| resolver.insert(&item.peer_id, false), + ); + + self.data.push(PublicOverlayEntryData { + entry: Arc::new(item.clone()), + resolver_handle, + }); + true } // Entry for the peer_id exists, update it if the new item is newer hash_map::Entry::Occupied(entry) => { let index = *entry.get(); let existing = &mut self.data[index]; - if existing.created_at >= item.created_at { + if existing.entry.created_at >= item.created_at { return false; } // Try to reuse the existing Arc if possible - match Arc::get_mut(existing) { + match Arc::get_mut(&mut existing.entry) { Some(existing) => existing.clone_from(item), - None => self.data[index] = Arc::new(item.clone()), + None => self.data[index].entry = Arc::new(item.clone()), } true } } } - fn remove(&mut self, peer_id: &PeerId) -> bool { - let Some(index) = self.peer_id_to_index.remove(peer_id) else { - return false; - }; - - // Remove the entry from the data vector - self.data.swap_remove(index); - - // Update the swapped entry's index - let entry = self - .peer_id_to_index - .get_mut(&self.data[index].peer_id) - .expect("inconsistent state"); - *entry = index; - - true + fn retain(&mut self, mut f: F) + where + F: FnMut(&PublicOverlayEntryData) -> bool, + { + self.data.retain(|item| { + let keep = f(item); + if !keep { + self.peer_id_to_index.remove(&item.entry.peer_id); + } + keep + }); } } +#[derive(Clone)] +pub struct PublicOverlayEntryData { + pub entry: Arc, + pub resolver_handle: PeerResolverHandle, +} + pub struct PublicOverlayEntriesReadGuard<'a> { entries: RwLockReadGuard<'a, PublicOverlayEntries>, } @@ -413,10 +476,10 @@ mod tests { let overlay = make_overlay_with_min_capacity(10); let entries = generate_public_entries(&overlay, now, 10); - overlay.add_untrusted_entries(&entries[..5]); + overlay.add_untrusted_entries(&entries[..5], now); assert_eq!(count_entries(&overlay), 5); - overlay.add_untrusted_entries(&entries[5..]); + overlay.add_untrusted_entries(&entries[5..], now); assert_eq!(count_entries(&overlay), 10); } @@ -424,7 +487,7 @@ mod tests { { let overlay = make_overlay_with_min_capacity(10); let entries = generate_public_entries(&overlay, now, 10); - overlay.add_untrusted_entries(&entries); + overlay.add_untrusted_entries(&entries, now); assert_eq!(count_entries(&overlay), 10); } @@ -432,7 +495,7 @@ mod tests { { let overlay = make_overlay_with_min_capacity(10); let entries = generate_public_entries(&overlay, now, 20); - overlay.add_untrusted_entries(&entries); + overlay.add_untrusted_entries(&entries, now); assert_eq!(count_entries(&overlay), 10); } @@ -440,7 +503,7 @@ mod tests { { let overlay = make_overlay_with_min_capacity(0); let entries = generate_public_entries(&overlay, now, 10); - overlay.add_untrusted_entries(&entries); + overlay.add_untrusted_entries(&entries, now); assert_eq!(count_entries(&overlay), 0); } @@ -450,7 +513,7 @@ mod tests { let entries = (0..10) .map(|_| generate_invalid_public_entry(now)) .collect::>(); - overlay.add_untrusted_entries(&entries); + overlay.add_untrusted_entries(&entries, now); assert_eq!(count_entries(&overlay), 0); } @@ -469,7 +532,7 @@ mod tests { generate_invalid_public_entry(now), generate_public_entry(&overlay, now), ]; - overlay.add_untrusted_entries(&entries); + overlay.add_untrusted_entries(&entries, now); assert_eq!(count_entries(&overlay), 5); } @@ -488,7 +551,7 @@ mod tests { generate_public_entry(&overlay, now), generate_public_entry(&overlay, now), ]; - overlay.add_untrusted_entries(&entries); + overlay.add_untrusted_entries(&entries, now); assert_eq!(count_entries(&overlay), 3); } } @@ -504,7 +567,7 @@ mod tests { for entries in entries.chunks_exact(7 * 3) { s.spawn(|| { for entries in entries.chunks_exact(7) { - overlay.add_untrusted_entries(entries); + overlay.add_untrusted_entries(entries, now); } }); } diff --git a/network/src/proto/overlay.rs b/network/src/proto/overlay.rs index 2ba6a52ab..1cef28a81 100644 --- a/network/src/proto/overlay.rs +++ b/network/src/proto/overlay.rs @@ -30,6 +30,14 @@ pub struct PublicEntry { pub signature: Box<[u8; 64]>, } +impl PublicEntry { + pub fn is_expired(&self, at: u32, ttl_sec: u32) -> bool { + const CLOCK_THRESHOLD: u32 = 1; + + self.created_at > at + CLOCK_THRESHOLD || self.created_at.saturating_add(ttl_sec) < at + } +} + /// A list of public overlay entries. #[derive(Debug, Clone, Hash, PartialEq, Eq, TlRead, TlWrite)] #[tl(boxed, scheme = "proto.tl")] diff --git a/network/src/types/peer_id.rs b/network/src/types/peer_id.rs index 60cb49e09..17147fc0f 100644 --- a/network/src/types/peer_id.rs +++ b/network/src/types/peer_id.rs @@ -107,6 +107,20 @@ impl From for PeerId { } } +impl PartialEq<&PeerId> for PeerId { + #[inline] + fn eq(&self, other: &&PeerId) -> bool { + self == *other + } +} + +impl PartialEq for &PeerId { + #[inline] + fn eq(&self, other: &PeerId) -> bool { + *self == other + } +} + impl std::ops::BitXor for PeerId { type Output = PeerId; diff --git a/network/src/types/request.rs b/network/src/types/request.rs index 8c337b975..ec8ea4712 100644 --- a/network/src/types/request.rs +++ b/network/src/types/request.rs @@ -51,7 +51,7 @@ impl<'de> Deserialize<'de> for Version { } } -#[derive(Serialize, Deserialize)] +#[derive(Clone, Serialize, Deserialize)] pub struct Request { pub version: Version, #[serde(with = "serde_body")] diff --git a/network/src/util/mod.rs b/network/src/util/mod.rs index c868a54aa..2d8b8f49d 100644 --- a/network/src/util/mod.rs +++ b/network/src/util/mod.rs @@ -1,11 +1,17 @@ pub use self::router::{Routable, Router, RouterBuilder}; pub use self::traits::NetworkExt; +#[cfg(test)] +pub use self::test::make_peer_info_stub; + use crate::types::PeerId; mod router; mod traits; +#[cfg(test)] +mod test; + pub(crate) mod tl; #[macro_export] diff --git a/network/src/util/test.rs b/network/src/util/test.rs new file mode 100644 index 000000000..2959a47f5 --- /dev/null +++ b/network/src/util/test.rs @@ -0,0 +1,13 @@ +use std::sync::Arc; + +use crate::types::{PeerId, PeerInfo}; + +pub fn make_peer_info_stub(id: PeerId) -> Arc { + Arc::new(PeerInfo { + id, + address_list: Default::default(), + created_at: 0, + expires_at: u32::MAX, + signature: Box::new([0; 64]), + }) +} diff --git a/network/src/util/traits.rs b/network/src/util/traits.rs index fb0a06384..925f1e393 100644 --- a/network/src/util/traits.rs +++ b/network/src/util/traits.rs @@ -2,7 +2,7 @@ use std::future::Future; use anyhow::Result; -use crate::network::{KnownPeer, Network, Peer}; +use crate::network::{Network, Peer}; use crate::types::{PeerEvent, PeerId, Request, Response}; pub trait NetworkExt { @@ -45,7 +45,7 @@ where match network.known_peers().get(peer_id) { // Initiate a connection of it is a known peer - Some(KnownPeer { peer_info, .. }) => { + Some(peer_info) => { // TODO: try multiple addresses let address = peer_info .iter_addresses() @@ -53,7 +53,7 @@ where .cloned() .expect("address list must have at least one item"); - network.connect_with_peer_id(address, peer_id).await?; + network.connect(address, peer_id).await?; } // Error otherwise None => anyhow::bail!("trying to interact with an unknown peer: {peer_id}"), @@ -61,7 +61,7 @@ where loop { match peer_events.recv().await { - Ok(PeerEvent::NewPeer(new_peer_id)) if &new_peer_id == peer_id => { + Ok(PeerEvent::NewPeer(new_peer_id)) if new_peer_id == peer_id => { if let Some(peer) = network.peer(peer_id) { return f.call(&peer, request).await; } diff --git a/network/tests/dht.rs b/network/tests/dht.rs index a3235cfef..01c61fa46 100644 --- a/network/tests/dht.rs +++ b/network/tests/dht.rs @@ -23,9 +23,9 @@ impl Node { fn new(key: &ed25519::SecretKey) -> Self { let keypair = ed25519::KeyPair::from(key); - let (dht_client, dht) = DhtService::builder(keypair.public_key.into()).build(); + let (dht_tasks, dht_service) = DhtService::builder(keypair.public_key.into()).build(); - let router = Router::builder().route(dht).build(); + let router = Router::builder().route(dht_service.clone()).build(); let network = Network::builder() .with_private_key(key.to_bytes()) @@ -33,7 +33,9 @@ impl Node { .build((Ipv4Addr::LOCALHOST, 0), router) .unwrap(); - let dht = dht_client.build(network.clone()); + dht_tasks.spawn(&network); + + let dht = dht_service.make_client(network.clone()); Self { network, dht } } diff --git a/network/tests/overlay.rs b/network/tests/overlay.rs index c376e35ec..f604f5296 100644 --- a/network/tests/overlay.rs +++ b/network/tests/overlay.rs @@ -9,7 +9,7 @@ use std::net::Ipv4Addr; use std::sync::Arc; use tl_proto::{TlRead, TlWrite}; use tycho_network::{ - Address, Network, OverlayId, OverlayService, PeerAffinity, PeerId, PeerInfo, PrivateOverlay, + Address, KnownPeerHandle, Network, OverlayId, OverlayService, PeerId, PeerInfo, PrivateOverlay, Request, Response, Router, Service, ServiceRequest, }; use tycho_util::time::now_sec; @@ -17,6 +17,7 @@ use tycho_util::time::now_sec; struct Node { network: Network, private_overlay: PrivateOverlay, + known_peer_handles: Vec, } impl Node { @@ -24,13 +25,9 @@ impl Node { let keypair = ed25519::KeyPair::from(key); let local_id = PeerId::from(keypair.public_key); - let private_overlay = PrivateOverlay::builder(PRIVATE_OVERLAY_ID).build(PingPongService); - - let (overlay_tasks, overlay_service) = OverlayService::builder(local_id) - .with_private_overlay(&private_overlay) - .build(); + let (overlay_tasks, overlay_service) = OverlayService::builder(local_id).build(); - let router = Router::builder().route(overlay_service).build(); + let router = Router::builder().route(overlay_service.clone()).build(); let network = Network::builder() .with_private_key(key.to_bytes()) @@ -38,11 +35,15 @@ impl Node { .build((Ipv4Addr::LOCALHOST, 0), router) .unwrap(); - overlay_tasks.spawn(network.clone()); + overlay_tasks.spawn(&network); + + let private_overlay = PrivateOverlay::builder(PRIVATE_OVERLAY_ID).build(PingPongService); + overlay_service.add_private_overlay(&private_overlay); Self { network, private_overlay, + known_peer_handles: Vec::new(), } } @@ -80,19 +81,26 @@ fn make_network(node_count: usize) -> Vec { .map(|_| ed25519::SecretKey::generate(&mut rand::thread_rng())) .collect::>(); - let nodes = keys.iter().map(Node::new).collect::>(); + let mut nodes = keys.iter().map(Node::new).collect::>(); let bootstrap_info = std::iter::zip(&keys, &nodes) .map(|(key, node)| Arc::new(Node::make_peer_info(key, node.network.local_addr().into()))) .collect::>(); - for node in &nodes { + for node in &mut nodes { let mut private_overlay_entries = node.private_overlay.write_entries(); for info in &bootstrap_info { - node.network - .known_peers() - .insert(info.clone(), PeerAffinity::Allowed); + if info.id == node.network.peer_id() { + continue; + } + + node.known_peer_handles.push( + node.network + .known_peers() + .insert(info.clone(), false) + .unwrap(), + ); private_overlay_entries.insert(&info.id); } diff --git a/util/Cargo.toml b/util/Cargo.toml index 3e24d5ec7..f2b0a7fa5 100644 --- a/util/Cargo.toml +++ b/util/Cargo.toml @@ -14,9 +14,16 @@ hex = "0.4" humantime = "2" rand = "0.8" serde = { version = "1.0", features = ["derive"] } -tokio = { version = "1", default-features = false, features = ["time"] } +thiserror = "1.0" +tokio = { version = "1", default-features = false, features = ["time", "sync", "rt"] } -# local deps +[dev-dependencies] +tokio = { version = "1", default-features = false, features = [ + "time", + "sync", + "rt-multi-thread", + "macros", +] } [lints] workspace = true diff --git a/util/src/futures.rs b/util/src/futures/box_future_or_noop.rs similarity index 100% rename from util/src/futures.rs rename to util/src/futures/box_future_or_noop.rs diff --git a/util/src/futures/join_task.rs b/util/src/futures/join_task.rs new file mode 100644 index 000000000..c6223375b --- /dev/null +++ b/util/src/futures/join_task.rs @@ -0,0 +1,58 @@ +use std::pin::Pin; +use std::task::{Context, Poll}; + +use futures_util::{Future, FutureExt}; +use tokio::task::JoinHandle; + +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub struct JoinTask { + handle: JoinHandle, + completed: bool, +} + +impl JoinTask { + #[inline] + pub fn new(f: F) -> Self + where + F: Future + Send + 'static, + T: Send + 'static, + { + Self { + handle: tokio::spawn(f), + completed: false, + } + } + + pub fn is_finished(&self) -> bool { + self.handle.is_finished() + } +} + +impl Drop for JoinTask { + fn drop(&mut self) { + if !self.completed { + self.handle.abort(); + } + } +} + +impl Future for JoinTask { + type Output = T; + + #[inline] + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let res = futures_util::ready!(self.handle.poll_unpin(cx)); + match res { + Ok(value) => { + self.completed = true; + Poll::Ready(value) + } + Err(e) => { + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()); + } + unreachable!() + } + } + } +} diff --git a/util/src/futures/shared.rs b/util/src/futures/shared.rs new file mode 100644 index 000000000..1970d76d0 --- /dev/null +++ b/util/src/futures/shared.rs @@ -0,0 +1,303 @@ +use std::cell::UnsafeCell; +use std::future::Future; +use std::pin::Pin; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::{Arc, Weak}; +use std::task::{Context, Poll}; + +use futures_util::future::BoxFuture; +use tokio::sync::{AcquireError, OwnedSemaphorePermit, Semaphore, TryAcquireError}; + +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub struct Shared { + inner: Option>>, + permit_fut: Option>>, + permit: Option, +} + +impl Clone for Shared { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + permit_fut: None, + permit: None, + } + } +} + +impl Shared { + pub fn new(future: Fut) -> Self { + let semaphore = Arc::new(Semaphore::new(1)); + let inner = Arc::new(Inner { + state: AtomicUsize::new(POLLING), + future_or_output: UnsafeCell::new(FutureOrOutput::Future(future)), + semaphore, + }); + + Self { + inner: Some(inner), + permit_fut: None, + permit: None, + } + } + + pub fn downgrade(&self) -> Option> { + self.inner + .as_ref() + .map(|inner| WeakShared(Arc::downgrade(inner))) + } + + /// Drops the future, returning whether it was the last instance. + pub fn consume(mut self) -> bool { + self.inner + .take() + .map(|inner| Arc::into_inner(inner).is_some()) + .unwrap_or_default() + } +} + +impl Future for Shared +where + Fut: Future, + Fut::Output: Clone, +{ + type Output = (Fut::Output, bool); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = &mut *self; + + let inner = this + .inner + .take() + .expect("Shared future polled again after completion"); + + // Fast path for when the wrapped future has already completed + if inner.state.load(Ordering::Acquire) == COMPLETE { + // Safety: We're in the COMPLETE state + return unsafe { Poll::Ready(inner.take_or_clone_output()) }; + } + + if this.permit.is_none() { + this.permit = Some('permit: { + // Poll semaphore future + let permit_fut = if let Some(fut) = this.permit_fut.as_mut() { + fut + } else { + // Avoid allocations completely if we can grab a permit immediately + match Arc::clone(&inner.semaphore).try_acquire_owned() { + Ok(permit) => break 'permit permit, + Err(TryAcquireError::NoPermits) => {} + // NOTE: We don't expect the semaphore to be closed + Err(TryAcquireError::Closed) => unreachable!(), + } + + let next_fut = Arc::clone(&inner.semaphore).acquire_owned(); + this.permit_fut.get_or_insert(Box::pin(next_fut)) + }; + + // Acquire a permit to poll the inner future + match Pin::new(permit_fut).poll(cx) { + Poll::Pending => { + this.inner = Some(inner); + return Poll::Pending; + } + Poll::Ready(Ok(permit)) => { + // Reset the permit future as we don't need it anymore + this.permit_fut = None; + permit + } + // NOTE: We don't expect the semaphore to be closed + Poll::Ready(Err(_e)) => unreachable!(), + } + }); + } + + match inner.state.load(Ordering::Acquire) { + COMPLETE => { + // SAFETY: We're in the COMPLETE state + return unsafe { Poll::Ready(inner.take_or_clone_output()) }; + } + POISONED => panic!("inner future panicked during poll"), + _ => {} + } + + // Create poison guard + struct Reset<'a> { + state: &'a AtomicUsize, + did_not_panic: bool, + } + + impl Drop for Reset<'_> { + fn drop(&mut self) { + if !self.did_not_panic { + self.state.store(POISONED, Ordering::Release); + } + } + } + + let mut reset = Reset { + state: &inner.state, + did_not_panic: false, + }; + + let output = { + // SAFETY: We are now a sole owner of the permit to poll the inner future + let future = unsafe { + match &mut *inner.future_or_output.get() { + FutureOrOutput::Future(fut) => Pin::new_unchecked(fut), + FutureOrOutput::Output(_) => unreachable!(), + } + }; + + let poll_result = future.poll(cx); + reset.did_not_panic = true; + + match poll_result { + Poll::Pending => { + drop(reset); // Make borrow checker happy + this.inner = Some(inner); + return Poll::Pending; + } + Poll::Ready(output) => output, + } + }; + + unsafe { + *inner.future_or_output.get() = FutureOrOutput::Output(output); + } + + inner.state.store(COMPLETE, Ordering::Release); + + drop(reset); // Make borrow checker happy + + // Reset permits + self.permit_fut = None; + self.permit = None; + + // SAFETY: We're in the COMPLETE state + unsafe { Poll::Ready(inner.take_or_clone_output()) } + } +} + +#[repr(transparent)] +pub struct WeakShared(Weak>); + +impl WeakShared { + pub fn upgrade(&self) -> Option> { + self.0.upgrade().map(|inner| Shared { + inner: Some(inner), + permit_fut: None, + permit: None, + }) + } + + pub fn strong_count(&self) -> usize { + self.0.strong_count() + } +} + +struct Inner { + state: AtomicUsize, + future_or_output: UnsafeCell>, + semaphore: Arc, +} + +impl Inner +where + Fut: Future, + Fut::Output: Clone, +{ + /// Safety: callers must first ensure that `inner.state` + /// is `COMPLETE` + unsafe fn take_or_clone_output(self: Arc) -> (Fut::Output, bool) { + match Arc::try_unwrap(self) { + Ok(inner) => match inner.future_or_output.into_inner() { + FutureOrOutput::Output(item) => (item, true), + FutureOrOutput::Future(_) => unreachable!(), + }, + Err(inner) => match &*inner.future_or_output.get() { + FutureOrOutput::Output(item) => (item.clone(), false), + FutureOrOutput::Future(_) => unreachable!(), + }, + } + } +} + +unsafe impl Send for Inner +where + Fut: Future + Send, + Fut::Output: Send + Sync, +{ +} + +unsafe impl Sync for Inner +where + Fut: Future + Send, + Fut::Output: Send + Sync, +{ +} + +enum FutureOrOutput { + Future(Fut), + Output(Fut::Output), +} + +const POLLING: usize = 0; +const COMPLETE: usize = 2; +const POISONED: usize = 3; + +#[cfg(test)] +mod tests { + //! Addresses the original `Shared` futures issue: + //! https://github.com/rust-lang/futures-rs/issues/2706 + + use futures_util::FutureExt; + + use super::*; + + async fn yield_now() { + /// Yield implementation + struct YieldNow { + yielded: bool, + } + + impl Future for YieldNow { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + if self.yielded { + return Poll::Ready(()); + } + + self.yielded = true; + cx.waker().wake_by_ref(); + Poll::Pending + } + } + + YieldNow { yielded: false }.await + } + + #[tokio::test(flavor = "multi_thread")] + async fn must_not_hang_up() { + for _ in 0..200 { + for _ in 0..1000 { + test_fut().await; + } + } + println!(); + } + + async fn test_fut() { + let f1 = Shared::new(yield_now()); + let f2 = f1.clone(); + let x1 = tokio::spawn(async move { + f1.now_or_never(); + }); + let x2 = tokio::spawn(async move { + f2.await; + }); + x1.await.ok(); + x2.await.ok(); + } +} diff --git a/util/src/lib.rs b/util/src/lib.rs index 7106dbf0d..0fb010594 100644 --- a/util/src/lib.rs +++ b/util/src/lib.rs @@ -1,10 +1,30 @@ use std::collections::HashMap; use std::collections::HashSet; -pub mod futures; pub mod serde_helpers; pub mod time; +pub mod futures { + pub use self::box_future_or_noop::BoxFutureOrNoop; + pub use self::join_task::JoinTask; + pub use self::shared::{Shared, WeakShared}; + + mod box_future_or_noop; + mod join_task; + mod shared; +} + +pub mod sync { + pub use self::priority_semaphore::{AcquireError, PrioritySemaphore, TryAcquireError}; + + mod priority_semaphore; +} + +mod util { + pub(crate) mod linked_list; + pub(crate) mod wake_list; +} + pub type FastDashMap = dashmap::DashMap; pub type FastDashSet = dashmap::DashSet; pub type FastHashMap = HashMap; diff --git a/util/src/sync/priority_semaphore.rs b/util/src/sync/priority_semaphore.rs new file mode 100644 index 000000000..b4722bc17 --- /dev/null +++ b/util/src/sync/priority_semaphore.rs @@ -0,0 +1,631 @@ +//! See . + +use std::cell::UnsafeCell; +use std::marker::PhantomPinned; +use std::pin::Pin; +use std::ptr::NonNull; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::{Arc, Mutex, MutexGuard}; +use std::task::{Context, Poll, Waker}; + +use futures_util::Future; + +use crate::util::linked_list::{Link, LinkedList, Pointers}; +use crate::util::wake_list::WakeList; + +pub struct PrioritySemaphore { + waiters: Mutex, + permits: AtomicUsize, +} + +impl PrioritySemaphore { + const MAX_PERMITS: usize = usize::MAX >> 3; + const CLOSED: usize = 1; + const PERMIT_SHIFT: usize = 1; + + pub fn new(permits: usize) -> Self { + assert!( + permits <= Self::MAX_PERMITS, + "a semaphore may not have more than MAX_PERMITS permits ({})", + Self::MAX_PERMITS + ); + + Self { + permits: AtomicUsize::new(permits << Self::PERMIT_SHIFT), + waiters: Mutex::new(Waitlist { + ordinary_queue: LinkedList::new(), + priority_queue: LinkedList::new(), + closed: false, + }), + } + } + + pub const fn const_new(permits: usize) -> Self { + assert!(permits <= Self::MAX_PERMITS); + + Self { + permits: AtomicUsize::new(permits << Self::PERMIT_SHIFT), + waiters: Mutex::new(Waitlist { + ordinary_queue: LinkedList::new(), + priority_queue: LinkedList::new(), + closed: false, + }), + } + } + + pub fn available_permits(&self) -> usize { + self.permits.load(Ordering::Acquire) >> Self::PERMIT_SHIFT + } + + pub fn close(&self) { + fn clear_queue(queue: &mut LinkedList::Target>) { + while let Some(mut waiter) = queue.pop_back() { + let waker = unsafe { (*waiter.as_mut().waker.get()).take() }; + if let Some(waker) = waker { + waker.wake(); + } + } + } + + let mut waiters = self.waiters.lock().unwrap(); + + self.permits.fetch_or(Self::CLOSED, Ordering::Release); + waiters.closed = true; + + clear_queue(&mut waiters.ordinary_queue); + clear_queue(&mut waiters.priority_queue); + } + + pub fn is_closed(&self) -> bool { + self.permits.load(Ordering::Acquire) & Self::CLOSED == Self::CLOSED + } + + pub fn try_acquire(&self) -> Result, TryAcquireError> { + self.try_acquire_impl(1).map(|()| SemaphorePermit { + semaphore: self, + permits: 1, + }) + } + + pub fn try_acquire_owned(self: Arc) -> Result { + self.try_acquire_impl(1).map(|()| OwnedSemaphorePermit { + semaphore: self, + permits: 1, + }) + } + + pub async fn acquire(&self, priority: bool) -> Result, AcquireError> { + match self.acquire_impl(1, priority).await { + Ok(()) => Ok(SemaphorePermit { + semaphore: self, + permits: 1, + }), + Err(e) => Err(e), + } + } + + pub async fn acquire_owned( + self: Arc, + priority: bool, + ) -> Result { + match self.acquire_impl(1, priority).await { + Ok(()) => Ok(OwnedSemaphorePermit { + semaphore: self, + permits: 1, + }), + Err(e) => Err(e), + } + } + + pub fn add_permits(&self, n: usize) { + if n == 0 { + return; + } + + // Assign permits to the wait queue + self.add_permits_locked(n, self.waiters.lock().unwrap()); + } + + fn try_acquire_impl(&self, num_permits: usize) -> Result<(), TryAcquireError> { + assert!( + num_permits <= Self::MAX_PERMITS, + "a semaphore may not have more than MAX_PERMITS permits ({})", + Self::MAX_PERMITS + ); + + let num_permits = num_permits << Self::PERMIT_SHIFT; + let mut curr = self.permits.load(Ordering::Acquire); + loop { + // Has the semaphore closed? + if curr & Self::CLOSED == Self::CLOSED { + return Err(TryAcquireError::Closed); + } + + // Are there enough permits remaining? + if curr < num_permits { + return Err(TryAcquireError::NoPermits); + } + + let next = curr - num_permits; + + match self + .permits + .compare_exchange(curr, next, Ordering::AcqRel, Ordering::Acquire) + { + Ok(_) => return Ok(()), + Err(actual) => curr = actual, + } + } + } + + fn acquire_impl(&self, num_permits: usize, priority: bool) -> Acquire<'_> { + Acquire::new(self, num_permits, priority) + } + + fn add_permits_locked(&self, mut rem: usize, waiters: MutexGuard<'_, Waitlist>) { + let mut wakers = WakeList::new(); + let mut lock = Some(waiters); + let mut is_empty = false; + while rem > 0 { + let mut waiters = lock.take().unwrap_or_else(|| self.waiters.lock().unwrap()); + + { + let waiters = &mut *waiters; + 'inner: while wakers.can_push() { + // Was the waiter assigned enough permits to wake it? + let queue = 'queue: { + for queue in [&mut waiters.priority_queue, &mut waiters.ordinary_queue] { + if let Some(waiter) = queue.last() { + if !waiter.assign_permits(&mut rem) { + continue; + } + break 'queue queue; + } + } + + is_empty = true; + // If we assigned permits to all the waiters in the queue, and there are + // still permits left over, assign them back to the semaphore. + break 'inner; + }; + + let mut waiter = queue.pop_back().unwrap(); + if let Some(waker) = unsafe { (*waiter.as_mut().waker.get()).take() } { + wakers.push(waker); + } + } + } + + if rem > 0 && is_empty { + let permits = rem; + assert!( + permits <= Self::MAX_PERMITS, + "cannot add more than MAX_PERMITS permits ({})", + Self::MAX_PERMITS + ); + let prev = self + .permits + .fetch_add(rem << Self::PERMIT_SHIFT, Ordering::Release); + let prev = prev >> Self::PERMIT_SHIFT; + assert!( + prev + permits <= Self::MAX_PERMITS, + "number of added permits ({}) would overflow MAX_PERMITS ({})", + rem, + Self::MAX_PERMITS + ); + + rem = 0; + } + + drop(waiters); // release the lock + + wakers.wake_all(); + } + + assert_eq!(rem, 0); + } + + fn poll_acquire( + &self, + cx: &mut Context<'_>, + num_permits: usize, + node: Pin<&mut Waiter>, + queued: bool, + priority: bool, + ) -> Poll> { + let mut acquired = 0; + + let needed = if queued { + node.state.load(Ordering::Acquire) << Self::PERMIT_SHIFT + } else { + num_permits << Self::PERMIT_SHIFT + }; + + let mut lock = None; + // First, try to take the requested number of permits from the + // semaphore. + let mut curr = self.permits.load(Ordering::Acquire); + let mut waiters = loop { + // Has the semaphore closed? + if curr & Self::CLOSED > 0 { + return Poll::Ready(Err(AcquireError(()))); + } + + let mut remaining = 0; + let total = curr + .checked_add(acquired) + .expect("number of permits must not overflow"); + let (next, acq) = if total >= needed { + let next = curr - (needed - acquired); + (next, needed >> Self::PERMIT_SHIFT) + } else { + remaining = (needed - acquired) - curr; + (0, curr >> Self::PERMIT_SHIFT) + }; + + if remaining > 0 && lock.is_none() { + // No permits were immediately available, so this permit will + // (probably) need to wait. We'll need to acquire a lock on the + // wait queue before continuing. We need to do this _before_ the + // CAS that sets the new value of the semaphore's `permits` + // counter. Otherwise, if we subtract the permits and then + // acquire the lock, we might miss additional permits being + // added while waiting for the lock. + lock = Some(self.waiters.lock().unwrap()); + } + + match self + .permits + .compare_exchange(curr, next, Ordering::AcqRel, Ordering::Acquire) + { + Ok(_) => { + acquired += acq; + if remaining == 0 { + if !queued { + return Poll::Ready(Ok(())); + } else if lock.is_none() { + break self.waiters.lock().unwrap(); + } + } + break lock.expect("lock must be acquired before waiting"); + } + Err(actual) => curr = actual, + } + }; + + if waiters.closed { + return Poll::Ready(Err(AcquireError(()))); + } + + if node.assign_permits(&mut acquired) { + self.add_permits_locked(acquired, waiters); + return Poll::Ready(Ok(())); + } + + assert_eq!(acquired, 0); + let mut old_waker = None; + + // Otherwise, register the waker & enqueue the node. + { + // SAFETY: the wait list is locked, so we may modify the waker. + let waker = unsafe { &mut *node.waker.get() }; + + // Do we need to register the new waker? + if waker + .as_ref() + .map_or(true, |waker| !waker.will_wake(cx.waker())) + { + old_waker = std::mem::replace(waker, Some(cx.waker().clone())); + } + } + + // If the waiter is not already in the wait queue, enqueue it. + if !queued { + let node = unsafe { + let node = Pin::into_inner_unchecked(node) as *mut _; + NonNull::new_unchecked(node) + }; + + waiters.queue_mut(priority).push_front(node); + } + drop(waiters); + drop(old_waker); + + Poll::Pending + } +} + +#[must_use] +#[clippy::has_significant_drop] +pub struct SemaphorePermit<'a> { + semaphore: &'a PrioritySemaphore, + permits: u32, +} + +impl Drop for SemaphorePermit<'_> { + fn drop(&mut self) { + self.semaphore.add_permits(self.permits as usize); + } +} + +#[must_use] +#[clippy::has_significant_drop] +pub struct OwnedSemaphorePermit { + semaphore: Arc, + permits: u32, +} + +impl Drop for OwnedSemaphorePermit { + fn drop(&mut self) { + self.semaphore.add_permits(self.permits as usize); + } +} + +struct Acquire<'a> { + node: Waiter, + semaphore: &'a PrioritySemaphore, + num_permits: usize, + queued: bool, + priority: bool, +} + +impl<'a> Acquire<'a> { + fn new(semaphore: &'a PrioritySemaphore, num_permits: usize, priority: bool) -> Self { + Self { + node: Waiter::new(num_permits), + semaphore, + num_permits, + queued: false, + priority, + } + } + + fn project( + self: Pin<&mut Self>, + ) -> (Pin<&mut Waiter>, &PrioritySemaphore, usize, &mut bool, bool) { + fn is_unpin() {} + unsafe { + // SAFETY: all fields other than `node` are `Unpin` + + is_unpin::<&PrioritySemaphore>(); + is_unpin::<&mut bool>(); + is_unpin::(); + + let this = self.get_unchecked_mut(); + ( + Pin::new_unchecked(&mut this.node), + this.semaphore, + this.num_permits, + &mut this.queued, + this.priority, + ) + } + } +} + +impl Drop for Acquire<'_> { + fn drop(&mut self) { + if !self.queued { + return; + } + + let mut waiters = self.semaphore.waiters.lock().unwrap(); + + let node = NonNull::from(&mut self.node); + // SAFETY: we have locked the wait list. + unsafe { waiters.queue_mut(self.priority).remove(node) }; + + let acquired_permits = self.num_permits - self.node.state.load(Ordering::Acquire); + if acquired_permits > 0 { + self.semaphore.add_permits_locked(acquired_permits, waiters); + } + } +} + +// SAFETY: the `Acquire` future is not `Sync` automatically because it contains +// a `Waiter`, which, in turn, contains an `UnsafeCell`. However, the +// `UnsafeCell` is only accessed when the future is borrowed mutably (either in +// `poll` or in `drop`). Therefore, it is safe (although not particularly +// _useful_) for the future to be borrowed immutably across threads. +unsafe impl Sync for Acquire<'_> {} + +impl Future for Acquire<'_> { + type Output = Result<(), AcquireError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let (node, semaphore, needed, queued, priority) = self.project(); + + match semaphore.poll_acquire(cx, needed, node, *queued, priority) { + Poll::Pending => { + *queued = true; + Poll::Pending + } + Poll::Ready(r) => { + r?; + *queued = false; + Poll::Ready(Ok(())) + } + } + } +} + +#[derive(Debug, thiserror::Error)] +#[error("semaphore closed")] +pub struct AcquireError(()); + +#[derive(Debug, PartialEq, Eq, thiserror::Error)] +pub enum TryAcquireError { + /// The semaphore has been [closed] and cannot issue new permits. + /// + /// [closed]: crate::sync::PrioritySemaphore::close + #[error("semaphore closed")] + Closed, + + /// The semaphore has no available permits. + #[error("no permits available")] + NoPermits, +} + +struct Waitlist { + ordinary_queue: LinkedList::Target>, + priority_queue: LinkedList::Target>, + closed: bool, +} + +impl Waitlist { + fn queue_mut(&mut self, priority: bool) -> &mut LinkedList::Target> { + if priority { + &mut self.priority_queue + } else { + &mut self.ordinary_queue + } + } +} + +struct Waiter { + state: AtomicUsize, + waker: UnsafeCell>, + pointers: Pointers, + _pin: PhantomPinned, +} + +impl Waiter { + fn new(num_permits: usize) -> Self { + Waiter { + state: AtomicUsize::new(num_permits), + waker: UnsafeCell::new(None), + pointers: Pointers::new(), + _pin: PhantomPinned, + } + } + + /// Assign permits to the waiter. + /// + /// Returns `true` if the waiter should be removed from the queue + fn assign_permits(&self, n: &mut usize) -> bool { + let mut curr = self.state.load(Ordering::Acquire); + loop { + let assign = std::cmp::min(curr, *n); + let next = curr - assign; + match self + .state + .compare_exchange(curr, next, Ordering::AcqRel, Ordering::Acquire) + { + Ok(_) => { + *n -= assign; + return next == 0; + } + Err(actual) => curr = actual, + } + } + } + + unsafe fn addr_of_pointers(target: NonNull) -> NonNull> { + let target = target.as_ptr(); + let field = std::ptr::addr_of_mut!((*target).pointers); + NonNull::new_unchecked(field) + } +} + +unsafe impl Link for Waiter { + type Handle = NonNull; + type Target = Self; + + #[inline] + fn as_raw(handle: &Self::Handle) -> NonNull { + *handle + } + + #[inline] + unsafe fn from_raw(ptr: NonNull) -> Self::Handle { + ptr + } + + #[inline] + unsafe fn pointers(target: NonNull) -> NonNull> { + Self::addr_of_pointers(target) + } +} + +#[cfg(test)] +mod tests { + use std::sync::atomic::{AtomicBool, Ordering}; + use std::sync::Arc; + use std::time::Duration; + + use super::*; + + #[tokio::test(flavor = "multi_thread")] + async fn priority_semaphore_works() { + let permits = Arc::new(PrioritySemaphore::new(1)); + + let flag = Arc::new(AtomicBool::new(false)); + + tokio::spawn({ + let permits = permits.clone(); + async move { + println!("BACKGROUND BEFORE"); + let _guard = permits.acquire(false).await.unwrap(); + println!("BACKGROUND AFTER"); + tokio::time::sleep(Duration::from_millis(100)).await; + println!("BACKGROUND FINISH"); + } + }); + + tokio::time::sleep(Duration::from_micros(10)).await; + + // Spawn an ordinary task that acquires a permit. + let ordinary_task = tokio::spawn({ + let permits = permits.clone(); + let flag = flag.clone(); + async move { + println!("ORDINARY BEFORE"); + let _guard = permits.acquire(false).await.unwrap(); + println!("ORDINARY AFTER"); + // Flag must be fired by the priority task after the permit is acquired. + assert!(flag.load(Ordering::Acquire)); + } + }); + + tokio::time::sleep(Duration::from_micros(10)).await; + + let priority_task = tokio::spawn({ + let permits = permits; + let flag = flag.clone(); + async move { + println!("PRIORITY BEFORE"); + let _guard = permits.acquire(true).await.unwrap(); + println!("PRIORITY"); + flag.store(true, Ordering::Release); + } + }); + + ordinary_task.await.unwrap(); + priority_task.await.unwrap(); + } + + #[tokio::test(flavor = "multi_thread")] + async fn priority_semaphore_is_fair() { + let permits = Arc::new(PrioritySemaphore::new(10)); + + let flag = AtomicBool::new(false); + tokio::join!( + non_cooperative_task(permits, &flag), + poor_little_task(&flag), + ); + } + + async fn non_cooperative_task(permits: Arc, flag: &AtomicBool) { + while !flag.load(Ordering::Acquire) { + let _permit = permits.acquire(false).await.unwrap(); + + // NOTE: This yield is necessary to allow the other task to run. + tokio::task::yield_now().await; + } + } + + async fn poor_little_task(flag: &AtomicBool) { + tokio::time::sleep(Duration::from_secs(1)).await; + flag.store(true, Ordering::Release); + } +} diff --git a/util/src/util/linked_list.rs b/util/src/util/linked_list.rs new file mode 100644 index 000000000..db8c2e5b0 --- /dev/null +++ b/util/src/util/linked_list.rs @@ -0,0 +1,206 @@ +//! See . + +use std::cell::UnsafeCell; +use std::marker::{PhantomData, PhantomPinned}; +use std::mem::ManuallyDrop; +use std::ptr::NonNull; + +pub(crate) struct LinkedList { + /// Linked list head + head: Option>, + + /// Linked list tail + tail: Option>, + + /// Node type marker. + _marker: PhantomData<*const L>, +} + +unsafe impl Send for LinkedList where L::Target: Send {} +unsafe impl Sync for LinkedList where L::Target: Sync {} + +impl LinkedList { + pub const fn new() -> LinkedList { + LinkedList { + head: None, + tail: None, + _marker: PhantomData, + } + } +} + +impl LinkedList { + /// Adds an element first in the list. + pub fn push_front(&mut self, val: L::Handle) { + let val = ManuallyDrop::new(val); + let ptr = L::as_raw(&val); + assert_ne!(self.head, Some(ptr)); + unsafe { + L::pointers(ptr).as_mut().set_next(self.head); + L::pointers(ptr).as_mut().set_prev(None); + + if let Some(head) = self.head { + L::pointers(head).as_mut().set_prev(Some(ptr)); + } + + self.head = Some(ptr); + + if self.tail.is_none() { + self.tail = Some(ptr); + } + } + } + + /// Removes the last element from a list and returns it, or None if it is + /// empty. + pub fn pop_back(&mut self) -> Option { + unsafe { + let last = self.tail?; + self.tail = L::pointers(last).as_ref().get_prev(); + + if let Some(prev) = L::pointers(last).as_ref().get_prev() { + L::pointers(prev).as_mut().set_next(None); + } else { + self.head = None; + } + + L::pointers(last).as_mut().set_prev(None); + L::pointers(last).as_mut().set_next(None); + + Some(L::from_raw(last)) + } + } + + /// Removes the specified node from the list + /// + /// # Safety + /// + /// The caller **must** ensure that exactly one of the following is true: + /// - `node` is currently contained by `self`, + /// - `node` is not contained by any list, + /// - `node` is currently contained by some other `GuardedLinkedList` **and** + /// the caller has an exclusive access to that list. This condition is + /// used by the linked list in `sync::Notify`. + pub unsafe fn remove(&mut self, node: NonNull) -> Option { + if let Some(prev) = L::pointers(node).as_ref().get_prev() { + debug_assert_eq!(L::pointers(prev).as_ref().get_next(), Some(node)); + L::pointers(prev) + .as_mut() + .set_next(L::pointers(node).as_ref().get_next()); + } else { + if self.head != Some(node) { + return None; + } + + self.head = L::pointers(node).as_ref().get_next(); + } + + if let Some(next) = L::pointers(node).as_ref().get_next() { + debug_assert_eq!(L::pointers(next).as_ref().get_prev(), Some(node)); + L::pointers(next) + .as_mut() + .set_prev(L::pointers(node).as_ref().get_prev()); + } else { + // This might be the last item in the list + if self.tail != Some(node) { + return None; + } + + self.tail = L::pointers(node).as_ref().get_prev(); + } + + L::pointers(node).as_mut().set_next(None); + L::pointers(node).as_mut().set_prev(None); + + Some(L::from_raw(node)) + } + + pub(crate) fn last(&self) -> Option<&L::Target> { + let tail = self.tail.as_ref()?; + unsafe { Some(&*tail.as_ptr()) } + } +} + +impl Default for LinkedList { + fn default() -> Self { + Self::new() + } +} + +/// # Safety +/// +/// Implementations must guarantee that `Target` types are pinned in memory. +pub(crate) unsafe trait Link { + type Handle; + type Target; + + #[allow(clippy::wrong_self_convention)] + fn as_raw(handle: &Self::Handle) -> NonNull; + + unsafe fn from_raw(ptr: NonNull) -> Self::Handle; + + unsafe fn pointers(target: NonNull) -> NonNull>; +} + +pub(crate) struct Pointers { + inner: UnsafeCell>, +} + +impl Pointers { + /// Create a new set of empty pointers + pub(crate) fn new() -> Pointers { + Pointers { + inner: UnsafeCell::new(PointersInner { + _prev: None, + _next: None, + _pin: PhantomPinned, + }), + } + } + + pub(crate) fn get_prev(&self) -> Option> { + // SAFETY: prev is the first field in PointersInner, which is #[repr(C)]. + unsafe { + let inner = self.inner.get(); + let prev = inner as *const Option>; + std::ptr::read(prev) + } + } + pub(crate) fn get_next(&self) -> Option> { + // SAFETY: next is the second field in PointersInner, which is #[repr(C)]. + unsafe { + let inner = self.inner.get(); + let prev = inner as *const Option>; + let next = prev.add(1); + std::ptr::read(next) + } + } + + fn set_prev(&mut self, value: Option>) { + // SAFETY: prev is the first field in PointersInner, which is #[repr(C)]. + unsafe { + let inner = self.inner.get(); + let prev = inner.cast::>>(); + std::ptr::write(prev, value); + } + } + fn set_next(&mut self, value: Option>) { + // SAFETY: next is the second field in PointersInner, which is #[repr(C)]. + unsafe { + let inner = self.inner.get(); + let prev = inner.cast::>>(); + let next = prev.add(1); + std::ptr::write(next, value); + } + } +} + +#[repr(C)] +struct PointersInner { + _prev: Option>, + _next: Option>, + _pin: PhantomPinned, +} + +unsafe impl Send for Pointers {} +unsafe impl Sync for Pointers {} diff --git a/util/src/util/wake_list.rs b/util/src/util/wake_list.rs new file mode 100644 index 000000000..cbbe331c4 --- /dev/null +++ b/util/src/util/wake_list.rs @@ -0,0 +1,53 @@ +//! See . + +use std::mem::MaybeUninit; +use std::task::Waker; + +pub(crate) struct WakeList { + inner: [MaybeUninit; NUM_WAKERS], + curr: usize, +} + +impl WakeList { + pub fn new() -> Self { + const UNINIT_WAKER: MaybeUninit = MaybeUninit::uninit(); + + Self { + inner: [UNINIT_WAKER; NUM_WAKERS], + curr: 0, + } + } + + pub fn can_push(&self) -> bool { + self.curr < NUM_WAKERS + } + + pub fn push(&mut self, val: Waker) { + debug_assert!(self.can_push()); + + self.inner[self.curr] = MaybeUninit::new(val); + self.curr += 1; + } + + pub fn wake_all(&mut self) { + assert!(self.curr <= NUM_WAKERS); + while self.curr > 0 { + self.curr -= 1; + // SAFETY: The first `curr` elements of `WakeList` are initialized, so by decrementing + // `curr`, we can take ownership of the last item. + let waker = unsafe { std::ptr::read(self.inner[self.curr].as_mut_ptr()) }; + waker.wake(); + } + } +} + +impl Drop for WakeList { + fn drop(&mut self) { + let slice = + std::ptr::slice_from_raw_parts_mut(self.inner.as_mut_ptr().cast::(), self.curr); + // SAFETY: The first `curr` elements are initialized, so we can drop them. + unsafe { std::ptr::drop_in_place(slice) }; + } +} + +const NUM_WAKERS: usize = 32;