diff --git a/.github/workflows/cargo.yml b/.github/workflows/cargo.yml index 75481de04..8502d6eb0 100644 --- a/.github/workflows/cargo.yml +++ b/.github/workflows/cargo.yml @@ -12,7 +12,7 @@ jobs: build: strategy: matrix: - os: [ ubuntu-18.04, macos-10.15 ] + os: [ ubuntu-18.04, macos-11 ] profile: [ release, debug ] name: build-${{ matrix.os }}-${{ matrix.profile }} runs-on: ${{ matrix.os }} @@ -54,7 +54,7 @@ jobs: bench-check: strategy: matrix: - os: [ ubuntu-18.04, macos-10.15 ] + os: [ ubuntu-18.04, macos-11 ] name: build-${{ matrix.os }}-bench runs-on: ${{ matrix.os }} steps: @@ -85,7 +85,7 @@ jobs: check: strategy: matrix: - os: [ ubuntu-20.04, macos-10.15 ] + os: [ ubuntu-20.04, macos-11 ] name: check-${{ matrix.os }} runs-on: ${{ matrix.os }} steps: diff --git a/Cargo.lock b/Cargo.lock index b84ffb056..6cd65cf3d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -18,14 +18,23 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" [[package]] -name = "ahash" -version = "0.6.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "796540673305a66d127804eef19ad696f1f204b8c1025aaca4958c17eab32877" +name = "admin" +version = "0.1.0" dependencies = [ - "getrandom", - "once_cell", - "version_check", + "common", + "config", + "crossbeam-channel", + "entrystore", + "libc", + "logger", + "net", + "protocol-admin", + "protocol-common", + "queues", + "rustcommon-metrics 0.1.1 (git+https://github.com/twitter/rustcommon)", + "session", + "slab", + "waker", ] [[package]] @@ -59,9 +68,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.57" +version = "1.0.59" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08f9b8508dccb7687a1d6c4ce66b2b0ecef467c94667de27d8d7fe1f8d2a9cdc" +checksum = "c91f1f46651137be86f3a2b9a8359f9ab421d04d941c62b5982e1ca21113adf9" [[package]] name = "arrayref" @@ -81,12 +90,6 @@ version = "0.7.2" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "8da52d66c7071e2e3fa2a1e5c6d088fec47b593032b254f5e980de8ea54454d6" -[[package]] -name = "ascii" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbf56136a5198c7b01a49e3afcbef6cf84597273d298f54432926024107b0109" - [[package]] name = "async-stream" version = "0.3.3" @@ -110,9 +113,9 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.56" +version = "0.1.57" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96cf8829f67d2eab0b2dfa42c5d0ef737e0724e4a82b01b3e292456202b19716" +checksum = "76464446b8bc32758d7e88ee1a804d9914cd9b1cb264c029899680b0be29826f" dependencies = [ "proc-macro2", "quote", @@ -138,9 +141,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] name = "backtrace" -version = "0.3.65" +version = "0.3.66" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11a17d453482a265fd5f8479f2a3f405566e6ca627837aaddb85af8b1ab8ef61" +checksum = "cab84319d616cfb654d03394f38ab7e6f0919e181b1b57e1fd15e7fb4077d9a7" dependencies = [ "addr2line", "cc", @@ -273,12 +276,6 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c4872d67bab6358e59559027aa3b9157c53d9358c51423c17554809a8858e0f8" -[[package]] -name = "cache-padded" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1db59621ec70f09c5e9b597b220c7a2b43611f4710dc03ceb8748637775692c" - [[package]] name = "cast" version = "0.2.7" @@ -329,12 +326,6 @@ dependencies = [ "winapi 0.3.9", ] -[[package]] -name = "chunked_transfer" -version = "1.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fff857943da45f546682664a79488be82e69e43c1a7a2307679ab9afb3a66d2e" - [[package]] name = "clang-sys" version = "1.3.3" @@ -376,6 +367,7 @@ version = "0.1.0" dependencies = [ "boring", "macros", + 
"net", "rustcommon-logger", "rustcommon-metrics 0.1.1 (git+https://github.com/twitter/rustcommon?rev=fc9c565)", "rustcommon-time 0.0.12 (git+https://github.com/twitter/rustcommon)", @@ -384,7 +376,7 @@ dependencies = [ [[package]] name = "config" -version = "0.1.0" +version = "0.1.1" dependencies = [ "common", "log", @@ -579,27 +571,6 @@ dependencies = [ "seg", ] -[[package]] -name = "errno" -version = "0.2.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f639046355ee4f37944e44f60642c6f3a7efa3cf6b78c78a0d989a8ce6c396a1" -dependencies = [ - "errno-dragonfly", - "libc", - "winapi 0.3.9", -] - -[[package]] -name = "errno-dragonfly" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa68f1b12764fab894d2755d2518754e71b4fd80ecfb822714a1206c2aab39bf" -dependencies = [ - "cc", - "libc", -] - [[package]] name = "fastrand" version = "1.7.0" @@ -611,9 +582,9 @@ dependencies = [ [[package]] name = "fixedbitset" -version = "0.4.1" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "279fb028e20b3c4c320317955b77c5e0c9701f05a1d309905d6fc702cdc5053e" +checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" [[package]] name = "fnv" @@ -648,16 +619,6 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "aa9a19cbb55df58761df49b23516a86d432839add4af60fc256da840f66ed35b" -[[package]] -name = "form_urlencoded" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5fc25a87fa4fd2094bffb06925852034d90a17f0d1e05197d4956d3555752191" -dependencies = [ - "matches", - "percent-encoding", -] - [[package]] name = "fuchsia-zircon" version = "0.3.3" @@ -736,9 +697,9 @@ dependencies = [ [[package]] name = "gimli" -version = "0.26.1" +version = "0.26.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"78cc372d058dcf6d5ecd98510e7fbc9e5aec4d21de70f65fea8fecebcd881bd4" +checksum = "22030e2c5a68ec659fde1e949a745124b48e6fa8b045b7ed5bd1fe4ccc5c4e5d" [[package]] name = "glob" @@ -773,9 +734,9 @@ checksum = "eabb4a44450da02c90444cf74558da904edde8fb4e9035a9a6a4e15445af0bd7" [[package]] name = "hashbrown" -version = "0.11.2" +version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab5ef0d4909ef3724cc8cce6ccc8572c5c817592e9285f5464f8e86f8bd3726e" +checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" [[package]] name = "heck" @@ -831,9 +792,9 @@ checksum = "c4a1e36c821dbe04574f602848a19f742f4fb3c98d40449f11bcad18d6b17421" [[package]] name = "hyper" -version = "0.14.19" +version = "0.14.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42dc3c131584288d375f2d07f822b0cb012d8c6fb899a5b9fdb3cb7eb9b6004f" +checksum = "02c929dc5c39e335a03c405292728118860721b10190d98c2a0f0efd5baafbac" dependencies = [ "bytes 1.1.0", "futures-channel", @@ -865,22 +826,11 @@ dependencies = [ "tokio-io-timeout", ] -[[package]] -name = "idna" -version = "0.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "418a0a6fab821475f634efe3ccc45c013f742efe03d853e8d3355d5cb850ecf8" -dependencies = [ - "matches", - "unicode-bidi", - "unicode-normalization", -] - [[package]] name = "indexmap" -version = "1.8.2" +version = "1.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6012d540c5baa3589337a98ce73408de9b5a25ec9fc2c6fd6be8f0d39e0ca5a" +checksum = "10a35a97730320ffe8e2d410b5d3b69279b98d2c14bdb8b70ea89ecf7888d41e" dependencies = [ "autocfg", "hashbrown", @@ -936,9 +886,9 @@ dependencies = [ [[package]] name = "jsonwebtoken" -version = "8.1.0" +version = "8.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc9051c17f81bae79440afa041b3a278e1de71bfb96d32454b477fd4703ccb6f" +checksum = 
"1aa4b4af834c6cfd35d8763d359661b90f2e45d8f750a0849156c7f4671af09c" dependencies = [ "base64", "pem", @@ -985,9 +935,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.126" +version = "0.2.132" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "349d5a591cd28b49e1d1037471617a32ddcda5731b99419008085f72d5a53836" +checksum = "8371e4e5341c3a96db127eb2465ac681ced4c433e01dd0e938adbef26ba93ba5" [[package]] name = "libloading" @@ -1065,12 +1015,6 @@ dependencies = [ "syn", ] -[[package]] -name = "matches" -version = "0.1.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a3e378b66a060d48947b590737b30a1be76706c8dd7b8ba0f2fe3989c68a853f" - [[package]] name = "memchr" version = "2.5.0" @@ -1140,9 +1084,9 @@ dependencies = [ [[package]] name = "mio" -version = "0.8.3" +version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "713d550d9b44d89174e066b7a6217ae06234c10cb47819a88290d2b353c31799" +checksum = "57ee1c23c7c63b0c9250c339ffdc69255f110b298b901b9f6c82547b7b87caaf" dependencies = [ "libc", "log", @@ -1205,6 +1149,7 @@ dependencies = [ "libc", "logger", "momento", + "net", "protocol-admin", "protocol-memcache", "protocol-resp", @@ -1226,6 +1171,18 @@ version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a" +[[package]] +name = "net" +version = "0.1.0" +dependencies = [ + "boring", + "boring-sys", + "foreign-types-shared", + "libc", + "mio 0.8.4", + "rustcommon-metrics 0.1.1 (git+https://github.com/twitter/rustcommon)", +] + [[package]] name = "net2" version = "0.2.37" @@ -1309,9 +1266,9 @@ dependencies = [ [[package]] name = "object" -version = "0.28.4" +version = "0.29.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e42c982f2d955fac81dd7e1d0e1426a7d702acd9c98d19ab01083a6a0328c424" +checksum = 
"21158b2c33aa6d4561f1c0a6ea283ca92bc54802a93b263e910746d679a7eb53" dependencies = [ "memchr", ] @@ -1365,9 +1322,9 @@ checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099" [[package]] name = "pem" -version = "1.0.2" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e9a3b09a20e374558580a4914d3b7d89bd61b954a5a5e1dcbea98753addb1947" +checksum = "03c64931a1a212348ec4f3b4362585eca7159d0d09cbdf4a7f74f02173596fd4" dependencies = [ "base64", ] @@ -1390,18 +1347,18 @@ dependencies = [ [[package]] name = "pin-project" -version = "1.0.10" +version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "58ad3879ad3baf4e44784bc6a718a8698867bb991f8ce24d1bcbe2cfb4c3a75e" +checksum = "78203e83c48cffbe01e4a2d35d566ca4de445d79a85372fc64e378bfc812a260" dependencies = [ "pin-project-internal", ] [[package]] name = "pin-project-internal" -version = "1.0.10" +version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "744b6f092ba29c3650faf274db506afd39944f48420f6c86b17cfe0ee1cb36bb" +checksum = "710faf75e1b33345361201d36d04e98ac1ed8909151a017ed384700836104c74" dependencies = [ "proc-macro2", "quote", @@ -1436,7 +1393,7 @@ dependencies = [ [[package]] name = "pingserver" -version = "0.1.0" +version = "0.2.0" dependencies = [ "backtrace", "clap", @@ -1565,7 +1522,7 @@ dependencies = [ "criterion", "logger", "protocol-common", - "session", + "rustcommon-metrics 0.1.1 (git+https://github.com/twitter/rustcommon)", "storage-types", ] @@ -1573,17 +1530,17 @@ dependencies = [ name = "protocol-common" version = "0.0.1" dependencies = [ + "bytes 1.1.0", "common", "config", "criterion", "logger", - "session", "storage-types", ] [[package]] name = "protocol-memcache" -version = "0.1.0" +version = "0.2.0" dependencies = [ "common", "criterion", @@ -1591,12 +1548,11 @@ dependencies = [ "nom 5.1.2", "protocol-common", "rustcommon-metrics 0.1.1 
(git+https://github.com/twitter/rustcommon)", - "session", ] [[package]] name = "protocol-ping" -version = "0.0.1" +version = "0.0.2" dependencies = [ "common", "config", @@ -1604,58 +1560,58 @@ dependencies = [ "logger", "protocol-common", "rustcommon-metrics 0.1.1 (git+https://github.com/twitter/rustcommon)", - "session", "storage-types", ] [[package]] name = "protocol-resp" -version = "0.1.0" +version = "0.2.0" dependencies = [ + "common", "nom 5.1.2", "protocol-common", - "session", + "rustcommon-metrics 0.1.1 (git+https://github.com/twitter/rustcommon)", ] [[package]] name = "protocol-thrift" -version = "0.0.1" +version = "0.0.2" dependencies = [ + "common", "logger", "protocol-common", "rustcommon-metrics 0.1.1 (git+https://github.com/twitter/rustcommon)", - "session", ] [[package]] name = "proxy" -version = "0.0.1" +version = "0.0.2" dependencies = [ + "admin", "common", "config", "crossbeam-channel", - "libc", + "entrystore", "logger", - "mio 0.8.3", - "mpmc", + "net", "protocol-admin", "protocol-common", "queues", "rustcommon-metrics 0.1.1 (git+https://github.com/twitter/rustcommon)", - "rustcommon-time 0.0.12 (git+https://github.com/twitter/rustcommon)", "session", "slab", - "tiny_http", + "waker", ] [[package]] name = "queues" -version = "0.2.0" +version = "0.3.0" dependencies = [ "crossbeam-queue", - "mio 0.8.3", + "net", "rand", "rand_chacha", + "waker", ] [[package]] @@ -1786,15 +1742,6 @@ dependencies = [ "winapi 0.3.9", ] -[[package]] -name = "rtrb" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "318256ac02f7e11a48a10339ba5dca8bd7eb17496abf384e8ea909bb2ae5275f" -dependencies = [ - "cache-padded", -] - [[package]] name = "rustc-demangle" version = "0.1.21" @@ -1877,7 +1824,7 @@ name = "rustcommon-logger" version = "0.1.0" source = "git+https://github.com/twitter/rustcommon#eb44fc27260fbd64fa41b8744f1831be7d1db945" dependencies = [ - "ahash 0.7.6", + "ahash", "log", "mpmc", "rustcommon-metrics 0.1.1 
(git+https://github.com/twitter/rustcommon)", @@ -2051,7 +1998,7 @@ dependencies = [ name = "seg" version = "0.1.0" dependencies = [ - "ahash 0.7.6", + "ahash", "common", "criterion", "datapool", @@ -2067,7 +2014,7 @@ dependencies = [ [[package]] name = "segcache" -version = "0.1.0" +version = "0.2.0" dependencies = [ "backtrace", "clap", @@ -2130,54 +2077,34 @@ dependencies = [ [[package]] name = "server" -version = "0.1.0" +version = "0.2.0" dependencies = [ - "ahash 0.6.3", - "backtrace", - "bytes 1.1.0", + "admin", "common", "config", - "criterion", "crossbeam-channel", "entrystore", - "libc", "logger", - "mio 0.8.3", + "net", "protocol-admin", "protocol-common", "queues", - "rand", - "rtrb", "rustcommon-metrics 0.1.1 (git+https://github.com/twitter/rustcommon)", - "serde", - "serde_json", "session", "slab", - "strum", - "strum_macros", - "sysconf", - "thiserror", - "tiny_http", + "waker", ] [[package]] name = "session" -version = "0.0.2" +version = "0.1.0" dependencies = [ - "common", - "config", - "logger", - "mio 0.8.3", - "rand", - "rtrb", + "bytes 1.1.0", + "log", + "net", + "protocol-common", "rustcommon-metrics 0.1.1 (git+https://github.com/twitter/rustcommon)", - "serde", - "serde_json", - "slab", - "strum", - "strum_macros", - "sysconf", - "thiserror", + "rustcommon-time 0.0.12 (git+https://github.com/twitter/rustcommon)", ] [[package]] @@ -2257,24 +2184,6 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ea5119cdb4c55b55d432abb513a0429384878c15dde60cc77b1c99de1a95a6a" -[[package]] -name = "strum" -version = "0.20.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7318c509b5ba57f18533982607f24070a55d353e90d4cae30c467cdb2ad5ac5c" - -[[package]] -name = "strum_macros" -version = "0.20.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee8bc6b87a5112aeeab1f4a9f7ab634fe6cbefc4850006df31267f4cfb9e3149" -dependencies = [ - "heck", - "proc-macro2", - 
"quote", - "syn", -] - [[package]] name = "subtle" version = "2.4.1" @@ -2292,18 +2201,6 @@ dependencies = [ "unicode-ident", ] -[[package]] -name = "sysconf" -version = "0.3.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59e93f5d45535f49b6a05ef7ac2f0f795d28de494cf53a512751602c9849bea3" -dependencies = [ - "errno", - "kernel32-sys", - "libc", - "winapi 0.2.8", -] - [[package]] name = "tempfile" version = "3.3.0" @@ -2390,19 +2287,6 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42657b1a6f4d817cda8e7a0ace261fe0cc946cf3a80314390b22cc61ae080792" -[[package]] -name = "tiny_http" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0d6ef4e10d23c1efb862eecad25c5054429a71958b4eeef85eb5e7170b477ca" -dependencies = [ - "ascii", - "chunked_transfer", - "log", - "time 0.3.9", - "url", -] - [[package]] name = "tinytemplate" version = "1.2.1" @@ -2413,31 +2297,17 @@ dependencies = [ "serde_json", ] -[[package]] -name = "tinyvec" -version = "1.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87cc5ceb3875bb20c2890005a4e226a4651264a5c75edb2421b52861a0a0cb50" -dependencies = [ - "tinyvec_macros", -] - -[[package]] -name = "tinyvec_macros" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c" - [[package]] name = "tokio" -version = "1.19.2" +version = "1.20.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c51a52ed6686dd62c320f9b89299e9dfb46f730c7a48e635c19f21d116cb1439" +checksum = "7a8325f63a7d4774dd041e363b2409ed1c5cbbd0f867795e661df066b2b0a581" dependencies = [ + "autocfg", "bytes 1.1.0", "libc", "memchr", - "mio 0.8.3", + "mio 0.8.4", "num_cpus", "once_cell", "parking_lot", @@ -2576,9 +2446,9 @@ dependencies = [ [[package]] name = "tower" -version = "0.4.12" +version = 
"0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a89fd63ad6adf737582df5db40d286574513c69a11dac5214dc3b5603d6713e" +checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" dependencies = [ "futures-core", "futures-util", @@ -2602,9 +2472,9 @@ checksum = "343bc9466d3fe6b0f960ef45960509f84480bf4fd96f92901afe7ff3df9d3a62" [[package]] name = "tower-service" -version = "0.3.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "360dfd1d6d30e05fda32ace2c8c70e9c0a9da713275777f5a4dbb8a1893930c6" +checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" [[package]] name = "tracing" @@ -2621,9 +2491,9 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.21" +version = "0.1.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc6b8ad3567499f98a1db7a752b07a7c8c7c7c34c332ec00effb2b0027974b7c" +checksum = "11c75893af559bc8e10716548bdef5cb2b983f8e637db9d0e15126b61b484ee2" dependencies = [ "proc-macro2", "quote", @@ -2632,9 +2502,9 @@ dependencies = [ [[package]] name = "tracing-core" -version = "0.1.27" +version = "0.1.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7709595b8878a4965ce5e87ebf880a7d39c9afc6837721b21a5a816a8117d921" +checksum = "7b7358be39f2f274f322d2aaed611acc57f382e8eb1e5b48cb9ae30933495ce7" dependencies = [ "once_cell", ] @@ -2661,27 +2531,12 @@ version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dcf81ac59edc17cc8697ff311e8f5ef2d99fcbd9817b34cec66f90b6c3dfd987" -[[package]] -name = "unicode-bidi" -version = "0.3.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "099b7128301d285f79ddd55b9a83d5e6b9e97c92e0ea0daebee7263e932de992" - [[package]] name = "unicode-ident" version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"d22af068fba1eb5edcb4aea19d382b2a3deb4c8f9d475c589b6ada9e0fd493ee" -[[package]] -name = "unicode-normalization" -version = "0.1.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d54590932941a9e9266f0832deed84ebe1bf2e4c9e4a3554d393d18f5e854bf9" -dependencies = [ - "tinyvec", -] - [[package]] name = "unicode-segmentation" version = "1.9.0" @@ -2700,18 +2555,6 @@ version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" -[[package]] -name = "url" -version = "2.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a507c383b2d33b5fc35d1861e77e6b383d158b2da5e14fe51b83dfedf6fd578c" -dependencies = [ - "form_urlencoded", - "idna", - "matches", - "percent-encoding", -] - [[package]] name = "vec_map" version = "0.8.2" @@ -2724,6 +2567,14 @@ version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +[[package]] +name = "waker" +version = "0.1.0" +dependencies = [ + "libc", + "mio 0.8.4", +] + [[package]] name = "walkdir" version = "2.3.2" diff --git a/Cargo.toml b/Cargo.toml index 730230ca4..055a7d7d7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,11 +2,14 @@ members = [ "src/common", "src/config", + "src/core/admin", "src/core/proxy", "src/core/server", + "src/core/waker", "src/entrystore", "src/logger", "src/macros", + "src/net", "src/protocol/admin", "src/protocol/common", "src/protocol/memcache", diff --git a/config/pingserver.toml b/config/pingserver.toml index a9a0873dc..339e4e2d4 100644 --- a/config/pingserver.toml +++ b/config/pingserver.toml @@ -18,6 +18,8 @@ nevent = 1024 timeout = 100 # epoll max events returned nevent = 1024 +# the number of worker threads to use +threads = 1 # NOTE: not currently implemented [time] diff --git a/src/common/Cargo.toml b/src/common/Cargo.toml index d18dbabd2..a5b3a4f73 
100644 --- a/src/common/Cargo.toml +++ b/src/common/Cargo.toml @@ -13,6 +13,7 @@ license = "Apache-2.0" [dependencies] boring = "2.0.0" serde = { version = "1.0.117", features = ["derive"] } +net = { path = "../net" } macros = { path = "../macros" } [dependencies.rustcommon-metrics] diff --git a/src/common/src/ssl.rs b/src/common/src/ssl.rs index 3205b9fc1..4f1054720 100644 --- a/src/common/src/ssl.rs +++ b/src/common/src/ssl.rs @@ -3,7 +3,8 @@ // http://www.apache.org/licenses/LICENSE-2.0 pub use boring::ssl::*; -use boring::x509::X509; + +use net::TlsTcpAcceptor; use std::io::{Error, ErrorKind}; pub trait TlsConfig { @@ -16,12 +17,12 @@ pub trait TlsConfig { fn ca_file(&self) -> Option; } -/// Create an `SslContext` from the given `TlsConfig`. Returns an error if there -/// was any issues during initialization. Otherwise, returns a `SslContext` -/// wrapped in an option, where the `None` variant indicates that TLS should not -/// be used. -pub fn ssl_context(config: &dyn TlsConfig) -> Result, std::io::Error> { - let mut builder = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls_server())?; +/// Create an `TlsTcpAcceptor` from the given `TlsConfig`. Returns an error if +/// there were any issues during initialization. Otherwise, returns a +/// `TlsTcpAcceptor` wrapped in an option, where the `None` variant indicates +/// that TLS should not be used. 
+pub fn tls_acceptor(config: &dyn TlsConfig) -> Result, std::io::Error> { + let mut builder = TlsTcpAcceptor::mozilla_intermediate_v5()?; // we use xor here to check if we have an under-specified tls configuration if config.private_key().is_some() @@ -34,9 +35,7 @@ pub fn ssl_context(config: &dyn TlsConfig) -> Result, std::io // // NOTE: this is required, so we return `Ok(None)` if it is not specified if let Some(f) = config.private_key() { - builder - .set_private_key_file(f, SslFiletype::PEM) - .map_err(|_| Error::new(ErrorKind::Other, "bad private key"))?; + builder = builder.private_key_file(f); } else { return Ok(None); } @@ -46,53 +45,16 @@ pub fn ssl_context(config: &dyn TlsConfig) -> Result, std::io // NOTE: this is optional, so we do not return `Ok(None)` when it has not // been specified if let Some(f) = config.ca_file() { - builder - .set_ca_file(f) - .map_err(|_| Error::new(ErrorKind::Other, "bad ca file"))?; + builder = builder.ca_file(f); } - match (config.certificate_chain(), config.certificate()) { - (Some(chain), Some(cert)) => { - // assume we have the leaf in a standalone file, and the - // intermediates + root in another file - - // first load the leaf - builder - .set_certificate_file(cert, SslFiletype::PEM) - .map_err(|_| Error::new(ErrorKind::Other, "bad certificate file"))?; - - // append the rest of the chain - let pem = std::fs::read(chain) - .map_err(|_| Error::new(ErrorKind::Other, "failed to read certificate chain"))?; - let chain = X509::stack_from_pem(&pem) - .map_err(|_| Error::new(ErrorKind::Other, "bad certificate chain"))?; - for cert in chain { - builder - .add_extra_chain_cert(cert) - .map_err(|_| Error::new(ErrorKind::Other, "bad certificate in chain"))?; - } - } - (Some(chain), None) => { - // assume we have a complete chain: leaf + intermediates + root in - // one file + if let Some(f) = config.certificate() { + builder = builder.certificate_file(f); + } - // load the entire chain - builder - 
.set_certificate_chain_file(chain) - .map_err(|_| Error::new(ErrorKind::Other, "bad certificate chain"))?; - } - (None, Some(cert)) => { - // this will just load the leaf certificate from the file - builder - .set_certificate_file(cert, SslFiletype::PEM) - .map_err(|_| Error::new(ErrorKind::Other, "bad certificate file"))?; - } - (None, None) => { - // if we have neither a chain nor a leaf cert to load, we return no - // ssl context - return Ok(None); - } + if let Some(f) = config.certificate_chain() { + builder = builder.certificate_chain_file(f); } - Ok(Some(builder.build().into_context())) + Ok(Some(builder.build()?)) } diff --git a/src/config/Cargo.toml b/src/config/Cargo.toml index 30e0b66c3..39372789f 100644 --- a/src/config/Cargo.toml +++ b/src/config/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "config" -version = "0.1.0" +version = "0.1.1" authors = ["Brian Martin "] edition = "2018" description = "component configurations for Pelikan" diff --git a/src/config/src/momento_proxy.rs b/src/config/src/momento_proxy.rs index 12593f121..f1004ba24 100644 --- a/src/config/src/momento_proxy.rs +++ b/src/config/src/momento_proxy.rs @@ -21,7 +21,7 @@ impl Default for Protocol { } // struct definitions -#[derive(Clone, Serialize, Deserialize, Debug)] +#[derive(Clone, Serialize, Default, Deserialize, Debug)] pub struct MomentoProxyConfig { // application modules #[serde(default)] @@ -128,16 +128,3 @@ impl KlogConfig for MomentoProxyConfig { &self.klog } } - -// trait implementations -impl Default for MomentoProxyConfig { - fn default() -> Self { - Self { - admin: Default::default(), - proxy: Default::default(), - cache: Default::default(), - debug: Default::default(), - klog: Default::default(), - } - } -} diff --git a/src/config/src/proxy.rs b/src/config/src/proxy.rs index 8a3cdf96c..9b6c61116 100644 --- a/src/config/src/proxy.rs +++ b/src/config/src/proxy.rs @@ -167,18 +167,18 @@ impl Backend { self.zk_endpoint.as_ref(), ) { let mut ret = Vec::new(); - if let 
Ok(server) = ZooKeeper::connect(&server, Duration::from_secs(15), ExitWatcher) { - if let Ok(children) = server.get_children(&path, true) { + if let Ok(server) = ZooKeeper::connect(server, Duration::from_secs(15), ExitWatcher) { + if let Ok(children) = server.get_children(path, true) { for child in children { let data = server .get_data(&format!("{}/{}", path, child), true) .map(|v| { std::str::from_utf8(&v.0) .map_err(|_| { - return std::io::Error::new( + std::io::Error::new( std::io::ErrorKind::Other, "bad data in zknode", - ); + ) }) .unwrap() .to_owned() diff --git a/src/core/admin/Cargo.toml b/src/core/admin/Cargo.toml new file mode 100644 index 000000000..45a28e9d6 --- /dev/null +++ b/src/core/admin/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "admin" +version = "0.1.0" +edition = "2021" +authors = ["Brian Martin "] +homepage = "https://pelikan.io" +repository = "https://github.com/twitter/pelikan" +license = "Apache-2.0" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +common = { path = "../../common" } +config = { path = "../../config" } +crossbeam-channel = "0.5.0" +entrystore = { path = "../../entrystore" } +libc = "0.2.132" +logger = { path = "../../logger" } +net = { path = "../../net" } +protocol-admin = { path = "../../protocol/admin" } +protocol-common = { path = "../../protocol/common" } +queues = { path = "../../queues" } +rustcommon-metrics = { git = "https://github.com/twitter/rustcommon" } +session = { path = "../../session" } +slab = "0.4.2" +waker = { path = "../waker" } diff --git a/src/core/admin/src/lib.rs b/src/core/admin/src/lib.rs new file mode 100644 index 000000000..d2d8c0508 --- /dev/null +++ b/src/core/admin/src/lib.rs @@ -0,0 +1,494 @@ +// Copyright 2021 Twitter, Inc. 
+// Licensed under the Apache License, Version 2.0 +// http://www.apache.org/licenses/LICENSE-2.0 + +use ::net::event::{Event, Source}; +use ::net::*; +use common::signal::Signal; +use common::ssl::tls_acceptor; +use config::{AdminConfig, TlsConfig}; +use crossbeam_channel::Receiver; +use logger::*; +use protocol_admin::*; +use queues::Queues; +use rustcommon_metrics::*; +use session::{Buf, ServerSession, Session}; +use slab::Slab; +use std::collections::VecDeque; +use std::io::{Error, ErrorKind, Result}; +use std::sync::Arc; +use std::time::Duration; +use waker::Waker; + +counter!(ADMIN_REQUEST_PARSE); +counter!(ADMIN_RESPONSE_COMPOSE); +counter!(ADMIN_EVENT_ERROR); +counter!(ADMIN_EVENT_WRITE); +counter!(ADMIN_EVENT_READ); +counter!(ADMIN_EVENT_LOOP); +counter!(ADMIN_EVENT_TOTAL); + +counter!(RU_UTIME); +counter!(RU_STIME); +gauge!(RU_MAXRSS); +gauge!(RU_IXRSS); +gauge!(RU_IDRSS); +gauge!(RU_ISRSS); +counter!(RU_MINFLT); +counter!(RU_MAJFLT); +counter!(RU_NSWAP); +counter!(RU_INBLOCK); +counter!(RU_OUBLOCK); +counter!(RU_MSGSND); +counter!(RU_MSGRCV); +counter!(RU_NSIGNALS); +counter!(RU_NVCSW); +counter!(RU_NIVCSW); + +counter!( + ADMIN_SESSION_ACCEPT, + "total number of attempts to accept a session" +); +counter!( + ADMIN_SESSION_ACCEPT_EX, + "number of times accept resulted in an exception, ignoring attempts that would block" +); +counter!( + ADMIN_SESSION_ACCEPT_OK, + "number of times a session was accepted successfully" +); + +counter!( + ADMIN_SESSION_CLOSE, + "total number of times a session was closed" +); + +gauge!(ADMIN_SESSION_CURR, "current number of admin sessions"); + +// consts + +const LISTENER_TOKEN: Token = Token(usize::MAX - 1); +const WAKER_TOKEN: Token = Token(usize::MAX); + +const KB: u64 = 1024; // one kilobyte in bytes +const S: u64 = 1_000_000_000; // one second in nanoseconds +const US: u64 = 1_000; // one microsecond in nanoseconds + +// helper functions + +fn map_err(e: std::io::Error) -> Result<()> { + match e.kind() { + 
ErrorKind::WouldBlock => Ok(()), + _ => Err(e), + } +} + +pub struct Admin { + /// A backlog of tokens that need to be handled + backlog: VecDeque, + /// The actual network listener for the ASCII Admin Endpoint + listener: ::net::Listener, + /// The drain handle for the logger + log_drain: Box, + /// The maximum number of events to process per call to poll + nevent: usize, + /// The actual poll instantance + poll: Poll, + /// The sessions which have been opened + sessions: Slab>, + /// A queue for receiving signals from the parent thread + signal_queue_rx: Receiver, + /// A set of queues for sending signals to sibling threads + signal_queue_tx: Queues, + /// The timeout for each call to poll + timeout: Duration, + /// The version of the service + version: String, + /// The waker for this thread + waker: Arc, +} + +pub struct AdminBuilder { + backlog: VecDeque, + listener: ::net::Listener, + nevent: usize, + poll: Poll, + sessions: Slab>, + timeout: Duration, + version: String, + waker: Arc, +} + +impl AdminBuilder { + pub fn new(config: &T) -> Result { + let tls_config = config.tls(); + let config = config.admin(); + + let addr = config.socket_addr().map_err(|e| { + error!("{}", e); + std::io::Error::new(std::io::ErrorKind::Other, "Bad listen address") + })?; + + let tcp_listener = TcpListener::bind(addr)?; + + let mut listener = match (config.use_tls(), tls_acceptor(tls_config)?) 
{ + (true, Some(tls_acceptor)) => ::net::Listener::from((tcp_listener, tls_acceptor)), + _ => ::net::Listener::from(tcp_listener), + }; + + let poll = Poll::new()?; + listener.register(poll.registry(), LISTENER_TOKEN, Interest::READABLE)?; + + let waker = Arc::new(Waker::from( + ::net::Waker::new(poll.registry(), WAKER_TOKEN).unwrap(), + )); + + let nevent = config.nevent(); + let timeout = Duration::from_millis(config.timeout() as u64); + + let sessions = Slab::new(); + + let version = "unknown".to_string(); + + let backlog = VecDeque::new(); + + Ok(Self { + backlog, + listener, + nevent, + poll, + sessions, + timeout, + version, + waker, + }) + } + + pub fn version(&mut self, version: &str) { + self.version = version.to_string(); + } + + pub fn waker(&self) -> Arc { + self.waker.clone() + } + + pub fn build( + self, + log_drain: Box, + signal_queue_rx: Receiver, + signal_queue_tx: Queues, + ) -> Admin { + Admin { + backlog: self.backlog, + listener: self.listener, + log_drain, + nevent: self.nevent, + poll: self.poll, + sessions: self.sessions, + signal_queue_rx, + signal_queue_tx, + timeout: self.timeout, + version: self.version, + waker: self.waker, + } + } +} + +fn get_rusage() { + let mut rusage = libc::rusage { + ru_utime: libc::timeval { + tv_sec: 0, + tv_usec: 0, + }, + ru_stime: libc::timeval { + tv_sec: 0, + tv_usec: 0, + }, + ru_maxrss: 0, + ru_ixrss: 0, + ru_idrss: 0, + ru_isrss: 0, + ru_minflt: 0, + ru_majflt: 0, + ru_nswap: 0, + ru_inblock: 0, + ru_oublock: 0, + ru_msgsnd: 0, + ru_msgrcv: 0, + ru_nsignals: 0, + ru_nvcsw: 0, + ru_nivcsw: 0, + }; + + if unsafe { libc::getrusage(libc::RUSAGE_SELF, &mut rusage) } == 0 { + RU_UTIME.set(rusage.ru_utime.tv_sec as u64 * S + rusage.ru_utime.tv_usec as u64 * US); + RU_STIME.set(rusage.ru_stime.tv_sec as u64 * S + rusage.ru_stime.tv_usec as u64 * US); + RU_MAXRSS.set(rusage.ru_maxrss * KB as i64); + RU_IXRSS.set(rusage.ru_ixrss * KB as i64); + RU_IDRSS.set(rusage.ru_idrss * KB as i64); + 
RU_ISRSS.set(rusage.ru_isrss * KB as i64); + RU_MINFLT.set(rusage.ru_minflt as u64); + RU_MAJFLT.set(rusage.ru_majflt as u64); + RU_NSWAP.set(rusage.ru_nswap as u64); + RU_INBLOCK.set(rusage.ru_inblock as u64); + RU_OUBLOCK.set(rusage.ru_oublock as u64); + RU_MSGSND.set(rusage.ru_msgsnd as u64); + RU_MSGRCV.set(rusage.ru_msgrcv as u64); + RU_NSIGNALS.set(rusage.ru_nsignals as u64); + RU_NVCSW.set(rusage.ru_nvcsw as u64); + RU_NIVCSW.set(rusage.ru_nivcsw as u64); + } +} + +impl Admin { + /// Call accept one time + fn accept(&mut self) { + ADMIN_SESSION_ACCEPT.increment(); + + match self + .listener + .accept() + .map(|v| ServerSession::new(Session::from(v), AdminRequestParser::default())) + { + Ok(mut session) => { + let s = self.sessions.vacant_entry(); + let interest = session.interest(); + if session + .register(self.poll.registry(), Token(s.key()), interest) + .is_ok() + { + ADMIN_SESSION_ACCEPT_OK.increment(); + ADMIN_SESSION_CURR.increment(); + + s.insert(session); + } else { + // failed to register + ADMIN_SESSION_ACCEPT_EX.increment(); + } + + self.backlog.push_back(LISTENER_TOKEN); + let _ = self.waker.wake(); + } + Err(e) => { + if e.kind() != ErrorKind::WouldBlock { + ADMIN_SESSION_ACCEPT_EX.increment(); + self.backlog.push_back(LISTENER_TOKEN); + let _ = self.waker.wake(); + } + } + } + } + + fn read(&mut self, token: Token) -> Result<()> { + let session = self + .sessions + .get_mut(token.0) + .ok_or_else(|| Error::new(ErrorKind::Other, "non-existant session"))?; + + // fill the session + match session.fill() { + Ok(0) => Err(Error::new(ErrorKind::Other, "client hangup")), + r => r, + }?; + + match session.receive() { + Ok(request) => { + ADMIN_REQUEST_PARSE.increment(); + + // do some request handling + match request { + AdminRequest::FlushAll => { + let _ = self.signal_queue_tx.try_send_all(Signal::FlushAll); + session.send(AdminResponse::Ok)?; + } + AdminRequest::Quit => { + return Err(Error::new(ErrorKind::Other, "should hangup")); + } + 
AdminRequest::Stats => { + session.send(AdminResponse::Stats)?; + } + AdminRequest::Version => { + session.send(AdminResponse::version(self.version.clone()))?; + } + } + + ADMIN_RESPONSE_COMPOSE.increment(); + + match session.flush() { + Ok(_) => Ok(()), + Err(e) => map_err(e), + }?; + + if session.write_pending() > 0 || session.remaining() > 0 { + let interest = session.interest(); + if session + .reregister(self.poll.registry(), token, interest) + .is_err() + { + return Err(Error::new(ErrorKind::Other, "failed to reregister")); + } + } + Ok(()) + } + Err(e) => match e.kind() { + ErrorKind::WouldBlock => Ok(()), + _ => Err(e), + }, + } + } + + fn write(&mut self, token: Token) -> Result<()> { + let session = self + .sessions + .get_mut(token.0) + .ok_or_else(|| Error::new(ErrorKind::Other, "non-existant session"))?; + + match session.flush() { + Ok(_) => Ok(()), + Err(e) => match e.kind() { + ErrorKind::WouldBlock => Ok(()), + _ => Err(e), + }, + } + } + + /// Closes the session with the given token + fn close(&mut self, token: Token) { + if self.sessions.contains(token.0) { + ADMIN_SESSION_CLOSE.increment(); + ADMIN_SESSION_CURR.decrement(); + + let mut session = self.sessions.remove(token.0); + let _ = session.flush(); + } + } + + fn handshake(&mut self, token: Token) -> Result<()> { + let session = self + .sessions + .get_mut(token.0) + .ok_or_else(|| Error::new(ErrorKind::Other, "non-existant session"))?; + + match session.do_handshake() { + Ok(()) => { + if session.remaining() > 0 { + let interest = session.interest(); + session.reregister(self.poll.registry(), token, interest)?; + Ok(()) + } else { + Ok(()) + } + } + Err(e) => Err(e), + } + } + + /// handle a single session event + fn session_event(&mut self, event: &Event) { + let token = event.token(); + + if event.is_error() { + ADMIN_EVENT_ERROR.increment(); + + self.close(token); + return; + } + + if event.is_writable() { + ADMIN_EVENT_WRITE.increment(); + + if self.write(token).is_err() { + 
self.close(token); + return; + } + } + + if event.is_readable() { + ADMIN_EVENT_READ.increment(); + + if self.read(token).is_err() { + self.close(token); + return; + } + } + + match self.handshake(token) { + Ok(_) => {} + Err(e) => match e.kind() { + ErrorKind::WouldBlock => {} + _ => { + self.close(token); + } + }, + } + } + + pub fn run(&mut self) { + info!( + "running admin on: {}", + self.listener + .local_addr() + .map(|v| format!("{v}")) + .unwrap_or_else(|_| "unknown address".to_string()) + ); + + let mut events = Events::with_capacity(self.nevent); + + loop { + ADMIN_EVENT_LOOP.increment(); + + get_rusage(); + + if self.poll.poll(&mut events, Some(self.timeout)).is_err() { + error!("Error polling"); + } + + ADMIN_EVENT_TOTAL.add(events.iter().count() as _); + + // handle all events + for event in events.iter() { + match event.token() { + LISTENER_TOKEN => { + self.accept(); + } + WAKER_TOKEN => { + self.waker.reset(); + let tokens: Vec = self.backlog.drain(..).collect(); + for token in tokens { + if token == LISTENER_TOKEN { + self.accept(); + } + } + } + _ => { + self.session_event(event); + } + } + } + + // handle all signals + while let Ok(signal) = self.signal_queue_rx.try_recv() { + match signal { + Signal::FlushAll => {} + Signal::Shutdown => { + // if a shutdown is received from any + // thread, we will broadcast it to all + // sibling threads and stop our event loop + info!("shutting down"); + let _ = self.signal_queue_tx.try_send_all(Signal::Shutdown); + if self.signal_queue_tx.wake().is_err() { + fatal!("error waking threads for shutdown"); + } + let _ = self.log_drain.flush(); + return; + } + } + } + + // flush pending log entries to log destinations + let _ = self.log_drain.flush(); + } + } +} + +common::metrics::test_no_duplicates!(); diff --git a/src/core/proxy/Cargo.toml b/src/core/proxy/Cargo.toml index dcf22fb2a..59aed859b 100644 --- a/src/core/proxy/Cargo.toml +++ b/src/core/proxy/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "proxy" 
-version = "0.0.1" +version = "0.0.2" edition = "2021" authors = ["Brian Martin "] homepage = "https://pelikan.io" @@ -10,19 +10,17 @@ license = "Apache-2.0" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +admin = { path = "../admin" } common = { path = "../../common" } config = { path = "../../config" } crossbeam-channel = "0.5.0" -libc = "0.2.83" +entrystore = { path = "../../entrystore" } logger = { path = "../../logger" } -mio = { version = "0.8.0", features = ["os-poll", "net"] } -mpmc = "*" +net = { path = "../../net" } protocol-admin = { path = "../../protocol/admin" } protocol-common = { path = "../../protocol/common" } -rustcommon-metrics = { git = "https://github.com/twitter/rustcommon" } -rustcommon-time = { git = "https://github.com/twitter/rustcommon" } queues = { path = "../../queues" } +rustcommon-metrics = { git = "https://github.com/twitter/rustcommon" } session = { path = "../../session" } -slab = "*" -tiny_http = "0.11.0" - +slab = "0.4.2" +waker = { path = "../waker" } diff --git a/src/core/proxy/src/admin.rs b/src/core/proxy/src/admin.rs deleted file mode 100644 index 59e8c1862..000000000 --- a/src/core/proxy/src/admin.rs +++ /dev/null @@ -1,758 +0,0 @@ -// Copyright 2021 Twitter, Inc. -// Licensed under the Apache License, Version 2.0 -// http://www.apache.org/licenses/LICENSE-2.0 - -//! The admin thread, which handles admin requests to return stats, get version -//! info, etc. 
- -use crate::event_loop::EventLoop; -use crate::poll::{Poll, LISTENER_TOKEN, WAKER_TOKEN}; -use crate::QUEUE_RETRIES; -use crate::TCP_ACCEPT_EX; -use crate::*; -use common::signal::Signal; -use common::ssl::{HandshakeError, MidHandshakeSslStream, Ssl, SslContext, SslStream}; -use config::*; -use core::time::Duration; -use crossbeam_channel::Receiver; -use logger::Drain; -use mio::event::Event; -use mio::{Events, Token, Waker}; -use protocol_admin::*; -use queues::Queues; -use rustcommon_metrics::*; -use session::{Session, TcpStream}; -use std::io::{BufRead, Error, ErrorKind, Write}; -use std::net::SocketAddr; -use std::sync::Arc; -use tiny_http::{Method, Request, Response}; - -counter!(ADMIN_REQUEST_PARSE); -counter!(ADMIN_RESPONSE_COMPOSE); -counter!(ADMIN_EVENT_ERROR); -counter!(ADMIN_EVENT_WRITE); -counter!(ADMIN_EVENT_READ); -counter!(ADMIN_EVENT_LOOP); -counter!(ADMIN_EVENT_TOTAL); -counter!(RU_UTIME); -counter!(RU_STIME); -gauge!(RU_MAXRSS); -gauge!(RU_IXRSS); -gauge!(RU_IDRSS); -gauge!(RU_ISRSS); -counter!(RU_MINFLT); -counter!(RU_MAJFLT); -counter!(RU_NSWAP); -counter!(RU_INBLOCK); -counter!(RU_OUBLOCK); -counter!(RU_MSGSND); -counter!(RU_MSGRCV); -counter!(RU_NSIGNALS); -counter!(RU_NVCSW); -counter!(RU_NIVCSW); - -const KB: u64 = 1024; // one kilobyte in bytes -const S: u64 = 1_000_000_000; // one second in nanoseconds -const US: u64 = 1_000; // one microsecond in nanoseconds - -pub static PERCENTILES: &[(&str, f64)] = &[ - ("p25", 25.0), - ("p50", 50.0), - ("p75", 75.0), - ("p90", 90.0), - ("p99", 99.0), - ("p999", 99.9), - ("p9999", 99.99), -]; - -pub struct AdminBuilder { - addr: SocketAddr, - nevent: usize, - poll: Poll, - timeout: Duration, - parser: AdminRequestParser, - log_drain: Box, - http_server: Option, - version: String, -} - -impl AdminBuilder { - /// Creates a new `Admin` event loop. 
- pub fn new( - config: &T, - mut log_drain: Box, - ) -> Result { - let config = config.admin(); - - let addr = config.socket_addr().map_err(|e| { - error!("{}", e); - error!("bad admin listen address"); - let _ = log_drain.flush(); - Error::new(ErrorKind::Other, "bad listen address") - })?; - let mut poll = Poll::new().map_err(|e| { - error!("{}", e); - error!("failed to create epoll instance"); - let _ = log_drain.flush(); - Error::new(ErrorKind::Other, "failed to create epoll instance") - })?; - poll.bind(addr, &Tls::default()).map_err(|e| { - error!("{}", e); - error!("failed to bind admin tcp listener"); - let _ = log_drain.flush(); - Error::new( - ErrorKind::Other, - format!( - "failed to bind listener on: {}:{}", - config.host(), - config.port() - ), - ) - })?; - - let timeout = std::time::Duration::from_millis(config.timeout() as u64); - - let nevent = config.nevent(); - - let http_server = if config.http_enabled() { - let addr = config.http_socket_addr().map_err(|e| { - error!("{}", e); - error!("bad admin http listen address"); - let _ = log_drain.flush(); - Error::new(ErrorKind::Other, "bad listen address") - })?; - let server = tiny_http::Server::http(addr).map_err(|e| { - error!("{}", e); - error!("could not start admin http server"); - let _ = log_drain.flush(); - Error::new(ErrorKind::Other, "failed to create http server") - })?; - Some(server) - } else { - None - }; - - Ok(Self { - addr, - timeout, - nevent, - poll, - parser: AdminRequestParser::new(), - log_drain, - http_server, - version: "unknown".to_string(), - }) - } - - pub fn waker(&self) -> Arc { - self.poll.waker() - } - - /// Triggers a flush of the log - pub fn log_flush(&mut self) -> Result<()> { - self.log_drain.flush() - } - - /// Set the reported version number - pub fn version(&mut self, version: &str) { - self.version = version.to_string(); - } - - pub fn build( - self, - signal_queue_tx: Queues, - signal_queue_rx: Receiver, - ) -> Admin { - Admin { - addr: self.addr, - nevent: 
self.nevent, - poll: self.poll, - timeout: self.timeout, - parser: self.parser, - log_drain: self.log_drain, - http_server: self.http_server, - signal_queue_tx, - signal_queue_rx, - version: self.version, - } - } -} - -pub struct Admin { - addr: SocketAddr, - nevent: usize, - poll: Poll, - timeout: Duration, - parser: AdminRequestParser, - log_drain: Box, - /// optional http server - http_server: Option, - /// used to send signals to all sibling threads - signal_queue_tx: Queues, - /// used to receive signals from the parent thread - signal_queue_rx: Receiver, - /// version number to report - version: String, -} - -impl Drop for Admin { - fn drop(&mut self) { - let _ = self.log_drain.flush(); - } -} - -impl Admin { - /// Adds a new fully established TLS session - fn add_established_tls_session(&mut self, stream: SslStream) { - let session = Session::tls_with_capacity( - stream, - crate::DEFAULT_BUFFER_SIZE, - crate::ADMIN_MAX_BUFFER_SIZE, - ); - if self.poll.add_session(session).is_err() { - TCP_ACCEPT_EX.increment(); - } - } - - /// Adds a new TLS session that requires further handshaking - fn add_handshaking_tls_session(&mut self, stream: MidHandshakeSslStream) { - let session = Session::handshaking_with_capacity( - stream, - crate::DEFAULT_BUFFER_SIZE, - crate::ADMIN_MAX_BUFFER_SIZE, - ); - trace!("accepted new session: {:?}", session.peer_addr()); - if self.poll.add_session(session).is_err() { - TCP_ACCEPT_EX.increment(); - } - } - - /// Adds a new plain (non-TLS) session - fn add_plain_session(&mut self, stream: TcpStream) { - let session = Session::plain_with_capacity( - stream, - crate::DEFAULT_BUFFER_SIZE, - crate::ADMIN_MAX_BUFFER_SIZE, - ); - trace!("accepted new session: {:?}", session.peer_addr()); - if self.poll.add_session(session).is_err() { - TCP_ACCEPT_EX.increment(); - } - } - - /// Repeatedly call accept on the listener - fn do_accept(&mut self) { - while self.poll.accept().is_ok() {} - } - - /// This is a handler for the stats commands on the 
legacy admin port. It - /// responses using the Memcached `stats` command response format, each stat - /// appears on its own line with a CR+LF used as end of line symbol. The - /// stats appear in sorted order. - /// - /// ```text - /// STAT get 0 - /// STAT get_cardinality_p25 0 - /// STAT get_cardinality_p50 0 - /// STAT get_cardinality_p75 0 - /// STAT get_cardinality_p90 0 - /// STAT get_cardinality_p99 0 - /// STAT get_cardinality_p999 0 - /// STAT get_cardinality_p9999 0 - /// STAT get_ex 0 - /// STAT get_key 0 - /// STAT get_key_hit 0 - /// STAT get_key_miss 0 - /// ``` - fn handle_stats_request(session: &mut Session) { - ADMIN_REQUEST_PARSE.increment(); - let mut data = Vec::new(); - for metric in &rustcommon_metrics::metrics() { - let any = match metric.as_any() { - Some(any) => any, - None => { - continue; - } - }; - - if let Some(counter) = any.downcast_ref::() { - data.push(format!("STAT {} {}\r\n", metric.name(), counter.value())); - } else if let Some(gauge) = any.downcast_ref::() { - data.push(format!("STAT {} {}\r\n", metric.name(), gauge.value())); - } else if let Some(heatmap) = any.downcast_ref::() { - for (label, value) in PERCENTILES { - let percentile = heatmap.percentile(*value).unwrap_or(0); - data.push(format!( - "STAT {}_{} {}\r\n", - metric.name(), - label, - percentile - )); - } - } - } - - data.sort(); - for line in data { - let _ = session.write(line.as_bytes()); - } - let _ = session.write(b"END\r\n"); - session.finalize_response(); - ADMIN_RESPONSE_COMPOSE.increment(); - } - - fn handle_version_request(session: &mut Session, version: &str) { - let _ = session.write(format!("VERSION {}\r\n", version).as_bytes()); - session.finalize_response(); - ADMIN_RESPONSE_COMPOSE.increment(); - } - - /// Handle an event on an existing session - fn handle_session_event(&mut self, event: &Event) { - let token = event.token(); - trace!("got event for admin session: {}", token.0); - - // handle error events first - if event.is_error() { - 
ADMIN_EVENT_ERROR.increment(); - self.handle_error(token); - } - - // handle handshaking - if let Ok(session) = self.poll.get_mut_session(token) { - if session.session.is_handshaking() { - if let Err(e) = session.session.do_handshake() { - if e.kind() == ErrorKind::WouldBlock { - // the session is still handshaking - return; - } else { - // some error occured while handshaking - let _ = self.poll.close_session(token); - } - } - } - } - - // handle write events before read events to reduce write - // buffer growth if there is also a readable event - if event.is_writable() { - ADMIN_EVENT_WRITE.increment(); - self.do_write(token); - } - - // read events are handled last - if event.is_readable() { - ADMIN_EVENT_READ.increment(); - let _ = self.do_read(token); - }; - } - - /// A "human-readable" exposition format which outputs one stat per line, - /// with a LF used as the end of line symbol. - /// - /// ```text - /// get: 0 - /// get_cardinality_p25: 0 - /// get_cardinality_p50: 0 - /// get_cardinality_p75: 0 - /// get_cardinality_p90: 0 - /// get_cardinality_p9999: 0 - /// get_cardinality_p999: 0 - /// get_cardinality_p99: 0 - /// get_ex: 0 - /// get_key: 0 - /// get_key_hit: 0 - /// get_key_miss: 0 - /// ``` - fn human_stats(&self) -> String { - let mut data = Vec::new(); - - for metric in &rustcommon_metrics::metrics() { - let any = match metric.as_any() { - Some(any) => any, - None => { - continue; - } - }; - - if let Some(counter) = any.downcast_ref::() { - data.push(format!("{}: {}", metric.name(), counter.value())); - } else if let Some(gauge) = any.downcast_ref::() { - data.push(format!("{}: {}", metric.name(), gauge.value())); - } else if let Some(heatmap) = any.downcast_ref::() { - for (label, value) in PERCENTILES { - let percentile = heatmap.percentile(*value).unwrap_or(0); - data.push(format!("{}_{}: {}", metric.name(), label, percentile)); - } - } - } - - data.sort(); - data.join("\n") + "\n" - } - - /// JSON stats output which follows the conventions 
found in Finagle and - /// TwitterServer libraries. Percentiles are appended to the metric name, - /// eg: `request_latency_p999` for the 99.9th percentile. For more details - /// about the Finagle / TwitterServer format see: - /// https://twitter.github.io/twitter-server/Features.html#metrics - /// - /// ```text - /// {"get": 0,"get_cardinality_p25": 0,"get_cardinality_p50": 0, ... } - /// ``` - fn json_stats(&self) -> String { - let head = "{".to_owned(); - - let mut data = Vec::new(); - - for metric in &rustcommon_metrics::metrics() { - let any = match metric.as_any() { - Some(any) => any, - None => { - continue; - } - }; - - if let Some(counter) = any.downcast_ref::() { - data.push(format!("\"{}\": {}", metric.name(), counter.value())); - } else if let Some(gauge) = any.downcast_ref::() { - data.push(format!("\"{}\": {}", metric.name(), gauge.value())); - } else if let Some(heatmap) = any.downcast_ref::() { - for (label, value) in PERCENTILES { - let percentile = heatmap.percentile(*value).unwrap_or(0); - data.push(format!("\"{}_{}\": {}", metric.name(), label, percentile)); - } - } - } - - data.sort(); - let body = data.join(","); - let mut content = head; - content += &body; - content += "}"; - content - } - - /// Prometheus / OpenTelemetry compatible stats output. Each stat is - /// annotated with a type. 
Percentiles use the label 'percentile' to - /// indicate which percentile corresponds to the value: - /// - /// ```text - /// # TYPE get counter - /// get 0 - /// # TYPE get_cardinality gauge - /// get_cardinality{percentile="p25"} 0 - /// # TYPE get_cardinality gauge - /// get_cardinality{percentile="p50"} 0 - /// # TYPE get_cardinality gauge - /// get_cardinality{percentile="p75"} 0 - /// # TYPE get_cardinality gauge - /// get_cardinality{percentile="p90"} 0 - /// # TYPE get_cardinality gauge - /// get_cardinality{percentile="p99"} 0 - /// # TYPE get_cardinality gauge - /// get_cardinality{percentile="p999"} 0 - /// # TYPE get_cardinality gauge - /// get_cardinality{percentile="p9999"} 0 - /// # TYPE get_ex counter - /// get_ex 0 - /// # TYPE get_key counter - /// get_key 0 - /// # TYPE get_key_hit counter - /// get_key_hit 0 - /// # TYPE get_key_miss counter - /// get_key_miss 0 - /// ``` - fn prometheus_stats(&self) -> String { - let mut data = Vec::new(); - - for metric in &rustcommon_metrics::metrics() { - let any = match metric.as_any() { - Some(any) => any, - None => { - continue; - } - }; - - if let Some(counter) = any.downcast_ref::() { - data.push(format!( - "# TYPE {} counter\n{} {}", - metric.name(), - metric.name(), - counter.value() - )); - } else if let Some(gauge) = any.downcast_ref::() { - data.push(format!( - "# TYPE {} gauge\n{} {}", - metric.name(), - metric.name(), - gauge.value() - )); - } else if let Some(heatmap) = any.downcast_ref::() { - for (label, value) in PERCENTILES { - let percentile = heatmap.percentile(*value).unwrap_or(0); - data.push(format!( - "# TYPE {} gauge\n{}{{percentile=\"{}\"}} {}", - metric.name(), - metric.name(), - label, - percentile - )); - } - } - } - data.sort(); - let mut content = data.join("\n"); - content += "\n"; - let parts: Vec<&str> = content.split('/').collect(); - parts.join("_") - } - - /// Handle a HTTP request - fn handle_http_request(&self, request: Request) { - let url = request.url(); - let parts: 
Vec<&str> = url.split('?').collect(); - let url = parts[0]; - match url { - // Prometheus/OpenTelemetry expect the `/metrics` URI will return - // stats in the Prometheus format - "/metrics" => match request.method() { - Method::Get => { - let _ = request.respond(Response::from_string(self.prometheus_stats())); - } - _ => { - let _ = request.respond(Response::empty(400)); - } - }, - // we export Finagle/TwitterServer format stats on a few endpoints - // for maximum compatibility with various internal conventions - "/metrics.json" | "/vars.json" | "/admin/metrics.json" => match request.method() { - Method::Get => { - let _ = request.respond(Response::from_string(self.json_stats())); - } - _ => { - let _ = request.respond(Response::empty(400)); - } - }, - // human-readable stats are exported on the `/vars` endpoint based - // on internal conventions - "/vars" => match request.method() { - Method::Get => { - let _ = request.respond(Response::from_string(self.human_stats())); - } - _ => { - let _ = request.respond(Response::empty(400)); - } - }, - _ => { - let _ = request.respond(Response::empty(404)); - } - } - } - - /// Runs the `Admin` in a loop, accepting new sessions for the admin - /// listener and handling events on existing sessions. 
- pub fn run(&mut self) { - info!("running admin on: {}", self.addr); - - let mut events = Events::with_capacity(self.nevent); - - // run in a loop, accepting new sessions and events on existing sessions - loop { - ADMIN_EVENT_LOOP.increment(); - - if self.poll.poll(&mut events, self.timeout).is_err() { - error!("Error polling"); - } - - ADMIN_EVENT_TOTAL.add(events.iter().count() as _); - - // handle all events - for event in events.iter() { - match event.token() { - LISTENER_TOKEN => { - self.do_accept(); - } - WAKER_TOKEN => { - // check if we have received signals from any sibling - // thread - while let Ok(signal) = self.signal_queue_rx.try_recv() { - match signal { - Signal::FlushAll => {} - Signal::Shutdown => { - // if a shutdown is received from any - // thread, we will broadcast it to all - // sibling threads and stop our event loop - info!("shutting down"); - let _ = self.signal_queue_tx.try_send_all(Signal::Shutdown); - if self.signal_queue_tx.wake().is_err() { - fatal!("error waking threads for shutdown"); - } - let _ = self.log_drain.flush(); - return; - } - } - } - } - _ => { - self.handle_session_event(event); - } - } - } - - // handle all http requests if the http server is enabled - if let Some(ref server) = self.http_server { - while let Ok(Some(request)) = server.try_recv() { - self.handle_http_request(request); - } - } - - // handle all signals - while let Ok(signal) = self.signal_queue_rx.try_recv() { - match signal { - Signal::FlushAll => {} - Signal::Shutdown => { - // if a shutdown is received from any - // thread, we will broadcast it to all - // sibling threads and stop our event loop - info!("shutting down"); - let _ = self.signal_queue_tx.try_send_all(Signal::Shutdown); - if self.signal_queue_tx.wake().is_err() { - fatal!("error waking threads for shutdown"); - } - let _ = self.log_drain.flush(); - return; - } - } - } - - // get updated usage - self.get_rusage(); - - // flush pending log entries to log destinations - let _ = 
self.log_drain.flush(); - } - } - - // TODO(bmartin): move this into a common module, should be shared with - // other backends - pub fn get_rusage(&self) { - let mut rusage = libc::rusage { - ru_utime: libc::timeval { - tv_sec: 0, - tv_usec: 0, - }, - ru_stime: libc::timeval { - tv_sec: 0, - tv_usec: 0, - }, - ru_maxrss: 0, - ru_ixrss: 0, - ru_idrss: 0, - ru_isrss: 0, - ru_minflt: 0, - ru_majflt: 0, - ru_nswap: 0, - ru_inblock: 0, - ru_oublock: 0, - ru_msgsnd: 0, - ru_msgrcv: 0, - ru_nsignals: 0, - ru_nvcsw: 0, - ru_nivcsw: 0, - }; - - if unsafe { libc::getrusage(libc::RUSAGE_SELF, &mut rusage) } == 0 { - RU_UTIME.set(rusage.ru_utime.tv_sec as u64 * S + rusage.ru_utime.tv_usec as u64 * US); - RU_STIME.set(rusage.ru_stime.tv_sec as u64 * S + rusage.ru_stime.tv_usec as u64 * US); - RU_MAXRSS.set(rusage.ru_maxrss * KB as i64); - RU_IXRSS.set(rusage.ru_ixrss * KB as i64); - RU_IDRSS.set(rusage.ru_idrss * KB as i64); - RU_ISRSS.set(rusage.ru_isrss * KB as i64); - RU_MINFLT.set(rusage.ru_minflt as u64); - RU_MAJFLT.set(rusage.ru_majflt as u64); - RU_NSWAP.set(rusage.ru_nswap as u64); - RU_INBLOCK.set(rusage.ru_inblock as u64); - RU_OUBLOCK.set(rusage.ru_oublock as u64); - RU_MSGSND.set(rusage.ru_msgsnd as u64); - RU_MSGRCV.set(rusage.ru_msgrcv as u64); - RU_NSIGNALS.set(rusage.ru_nsignals as u64); - RU_NVCSW.set(rusage.ru_nvcsw as u64); - RU_NIVCSW.set(rusage.ru_nivcsw as u64); - } - } -} - -impl EventLoop for Admin { - fn handle_data(&mut self, token: Token) -> Result<()> { - trace!("handling request for admin session: {}", token.0); - if let Ok(session) = self.poll.get_mut_session(token) { - loop { - if session.session.write_capacity() == 0 { - // if the write buffer is over-full, skip processing - break; - } - match self.parser.parse(session.session.buffer()) { - Ok(parsed_request) => { - let consumed = parsed_request.consumed(); - let request = parsed_request.into_inner(); - session.session.consume(consumed); - - match request { - AdminRequest::FlushAll => { - for _ 
in 0..QUEUE_RETRIES { - if self.signal_queue_tx.try_send_all(Signal::FlushAll).is_ok() { - warn!("sending flush_all signal"); - break; - } - } - for _ in 0..QUEUE_RETRIES { - if self.signal_queue_tx.wake().is_ok() { - break; - } - } - - let _ = session.session.write(b"OK\r\n"); - session.session.finalize_response(); - ADMIN_RESPONSE_COMPOSE.increment(); - } - AdminRequest::Stats => { - Self::handle_stats_request(&mut session.session); - } - AdminRequest::Quit => { - let _ = self.poll.close_session(token); - return Ok(()); - } - AdminRequest::Version => { - Self::handle_version_request(&mut session.session, &self.version); - } - } - } - Err(ParseError::Incomplete) => { - break; - } - Err(_) => { - self.handle_error(token); - return Err(std::io::Error::new( - std::io::ErrorKind::Other, - "bad request", - )); - } - } - } - } else { - // no session for the token - trace!( - "attempted to handle data for non-existent session: {}", - token.0 - ); - return Ok(()); - } - self.poll.reregister(token); - Ok(()) - } - - fn poll(&mut self) -> &mut Poll { - &mut self.poll - } -} diff --git a/src/core/proxy/src/backend.rs b/src/core/proxy/src/backend.rs index d54c89011..2972b4e76 100644 --- a/src/core/proxy/src/backend.rs +++ b/src/core/proxy/src/backend.rs @@ -2,305 +2,307 @@ // Licensed under the Apache License, Version 2.0 // http://www.apache.org/licenses/LICENSE-2.0 +use super::map_result; use crate::*; -use common::signal::Signal; -use config::proxy::BackendConfig; -use core::marker::PhantomData; -use core::time::Duration; -use mio::Waker; -use poll::*; -use protocol_common::*; -use queues::Queues; -use queues::TrackedItem; -use session::Session; -use std::sync::Arc; - -use rustcommon_metrics::*; - -const KB: usize = 1024; - -const SESSION_BUFFER_MIN: usize = 16 * KB; -const SESSION_BUFFER_MAX: usize = 1024 * KB; - -counter!(BACKEND_EVENT_ERROR); -counter!(BACKEND_EVENT_READ); -counter!(BACKEND_EVENT_WRITE); +use session::ClientSession; +use std::collections::HashMap; +use 
std::collections::VecDeque; + +heatmap!( + BACKEND_EVENT_DEPTH, + 100_000, + "distribution of the number of events received per iteration of the event loop" +); +counter!(BACKEND_EVENT_ERROR, "the number of error events received"); +counter!( + BACKEND_EVENT_LOOP, + "the number of times the event loop has run" +); counter!( BACKEND_EVENT_MAX_REACHED, "the number of times the maximum number of events was returned" ); -heatmap!(BACKEND_EVENT_DEPTH, 100_000); - -pub const QUEUE_RETRIES: usize = 3; +counter!(BACKEND_EVENT_READ, "the number of read events received"); +counter!(BACKEND_EVENT_TOTAL, "the total number of events received"); +counter!(BACKEND_EVENT_WRITE, "the number of write events received"); pub struct BackendWorkerBuilder { - poll: Poll, - parser: Parser, free_queue: VecDeque, nevent: usize, + parser: Parser, + poll: Poll, + sessions: Slab>, timeout: Duration, - _request: PhantomData, - _response: PhantomData, + waker: Arc, } -impl BackendWorkerBuilder { +impl BackendWorkerBuilder +where + Parser: Clone + Parse, + Request: Compose, +{ pub fn new(config: &T, parser: Parser) -> Result { let config = config.backend(); - let mut poll = Poll::new()?; - - let server_endpoints = config.socket_addrs()?; - - let mut free_queue = VecDeque::with_capacity(server_endpoints.len() * config.poolsize()); - - for addr in server_endpoints { - for _ in 0..config.poolsize() { - let connection = std::net::TcpStream::connect(addr).expect("failed to connect"); - connection - .set_nonblocking(true) - .expect("failed to set non-blocking"); - let connection = TcpStream::from_std(connection); - let session = Session::plain_with_capacity( - session::TcpStream::try_from(connection).expect("failed to convert"), - SESSION_BUFFER_MIN, - SESSION_BUFFER_MAX, - ); - if let Ok(token) = poll.add_session(session) { - println!("new backend connection with token: {}", token.0); - free_queue.push_back(token); - } - } + let poll = Poll::new()?; + + let waker = Arc::new(Waker::from( + 
::net::Waker::new(poll.registry(), WAKER_TOKEN).unwrap(), + )); + + let nevent = config.nevent(); + let timeout = Duration::from_millis(config.timeout() as u64); + + let mut sessions = Slab::new(); + let mut free_queue = VecDeque::new(); + + for endpoint in config.socket_addrs()? { + let stream = TcpStream::connect(endpoint)?; + let mut session = ClientSession::new(Session::from(stream), parser.clone()); + let s = sessions.vacant_entry(); + let interest = session.interest(); + session + .register(poll.registry(), Token(s.key()), interest) + .expect("failed to register"); + free_queue.push_back(Token(s.key())); + s.insert(session); } Ok(Self { - poll, free_queue, + nevent, parser, - nevent: config.nevent(), - timeout: Duration::from_millis(config.timeout() as u64), - _request: PhantomData, - _response: PhantomData, + poll, + sessions, + timeout, + waker, }) } + pub fn waker(&self) -> Arc { - self.poll.waker() + self.waker.clone() } pub fn build( self, + data_queue: Queues<(Request, Response, Token), (Request, Token)>, signal_queue: Queues<(), Signal>, - queues: Queues, TokenWrapper>, ) -> BackendWorker { BackendWorker { - poll: self.poll, + backlog: VecDeque::new(), + data_queue, free_queue: self.free_queue, - signal_queue, - queues, - parser: self.parser, nevent: self.nevent, + parser: self.parser, + pending: HashMap::new(), + poll: self.poll, + sessions: self.sessions, + signal_queue, timeout: self.timeout, + waker: self.waker, } } } pub struct BackendWorker { - poll: Poll, - queues: Queues, TokenWrapper>, + backlog: VecDeque<(Request, Token)>, + data_queue: Queues<(Request, Response, Token), (Request, Token)>, free_queue: VecDeque, - signal_queue: Queues<(), Signal>, - parser: Parser, nevent: usize, + parser: Parser, + pending: HashMap, + poll: Poll, + sessions: Slab>, + signal_queue: Queues<(), Signal>, timeout: Duration, + waker: Arc, } impl BackendWorker where + Parser: Parse + Clone, Request: Compose, - Parser: Parse, { - 
#[allow(clippy::match_single_binding)] - pub fn run(mut self) { - let mut events = Events::with_capacity(self.nevent); - let mut requests = Vec::with_capacity(self.nevent); - loop { - let _ = self.poll.poll(&mut events, self.timeout); - for event in &events { - match event.token() { - WAKER_TOKEN => { - self.handle_waker(&mut requests); - if !requests.is_empty() { - let _ = self.poll.waker().wake(); - } - } - _ => { - self.handle_event(event); - } + /// Return the `Session` to the `Listener` to handle flush/close + fn close(&mut self, token: Token) { + if self.sessions.contains(token.0) { + let mut session = self.sessions.remove(token.0); + let _ = session.flush(); + } + } + + /// Handle up to one response for a session + fn read(&mut self, token: Token) -> Result<()> { + let session = self + .sessions + .get_mut(token.0) + .ok_or_else(|| Error::new(ErrorKind::Other, "non-existant session"))?; + + // fill the session + map_result(session.fill())?; + + // process up to one request + match session.receive() { + Ok((request, response)) => { + if let Some(fe_token) = self.pending.remove(&token) { + self.free_queue.push_back(token); + self.data_queue + .try_send_to(0, (request, response, fe_token)) + .map_err(|_| Error::new(ErrorKind::Other, "data queue is full")) + } else { + panic!("corrupted state"); } } - let count = events.iter().count(); - if count == self.nevent { - BACKEND_EVENT_MAX_REACHED.increment(); - } else { - BACKEND_EVENT_DEPTH.increment( - common::time::Instant::>::now(), - count as _, - 1, - ); - } - let _ = self.queues.wake(); + Err(e) => map_err(e), } } - fn handle_event(&mut self, event: &Event) { - let token = event.token(); + /// Handle write by flushing the session + fn write(&mut self, token: Token) -> Result<()> { + let session = self + .sessions + .get_mut(token.0) + .ok_or_else(|| Error::new(ErrorKind::Other, "non-existant session"))?; - // handle error events first - if event.is_error() { - BACKEND_EVENT_ERROR.increment(); - 
self.handle_error(token); + match session.flush() { + Ok(_) => Ok(()), + Err(e) => map_err(e), } + } - // handle write events before read events to reduce write buffer - // growth if there is also a readable event - if event.is_writable() { - BACKEND_EVENT_WRITE.increment(); - self.do_write(token); - } + /// Run the worker in a loop, handling new events. + pub fn run(&mut self) { + // these are buffers which are re-used in each loop iteration to receive + // events and queue messages + let mut events = Events::with_capacity(self.nevent); + let mut messages = Vec::with_capacity(QUEUE_CAPACITY); + // let mut sessions = Vec::with_capacity(QUEUE_CAPACITY); - // read events are handled last - if event.is_readable() { - BACKEND_EVENT_READ.increment(); - if let Ok(session) = self.poll.get_mut_session(token) { - session.session.set_timestamp(rustcommon_time::Instant::< - rustcommon_time::Nanoseconds, - >::recent()); - } - let _ = self.do_read(token); - } + loop { + BACKEND_EVENT_LOOP.increment(); - if let Ok(session) = self.poll.get_mut_session(token) { - if session.session.read_pending() > 0 { - trace!( - "session: {:?} has {} bytes pending in read buffer", - session.session, - session.session.read_pending() - ); + // get events with timeout + if self.poll.poll(&mut events, Some(self.timeout)).is_err() { + error!("Error polling"); } - if session.session.write_pending() > 0 { - trace!( - "session: {:?} has {} bytes pending in write buffer", - session.session, - session.session.read_pending() - ); - } - } - } - pub fn handle_waker(&mut self, requests: &mut Vec>>) { - // try to get requests from the queue if we don't already - // have a backlog - if requests.is_empty() { - self.queues.try_recv_all(requests); - } + let timestamp = Instant::now(); - // as long as we have free backend connections and we - // have requests from the most recent read of the queue - // we can dispatch requests - while !self.free_queue.is_empty() && !requests.is_empty() { - let backend_token = 
self.free_queue.pop_front().unwrap(); - let request = requests.remove(0); - - // check if this token is still a valid connection - if let Ok(session) = self.poll.get_mut_session(backend_token) { - if session.token.is_none() && session.sender.is_none() { - let sender = request.sender(); - let request = request.into_inner(); - let token = request.token(); - let request = request.into_inner(); - - session.sender = Some(sender); - session.token = Some(token); - request.compose(&mut session.session); - session.session.finalize_response(); - - if session.session.write_pending() > 0 { - let _ = session.session.flush(); - if session.session.write_pending() > 0 { - self.poll.reregister(token); - } - } - } + let count = events.iter().count(); + BACKEND_EVENT_TOTAL.add(count as _); + if count == self.nevent { + BACKEND_EVENT_MAX_REACHED.increment(); + } else { + BACKEND_EVENT_DEPTH.increment(timestamp, count as _, 1); } - self.poll.reregister(backend_token); - } - } + // process all events + for event in events.iter() { + let token = event.token(); + match token { + WAKER_TOKEN => { + self.waker.reset(); + // handle all pending messages on the data queue + self.data_queue.try_recv_all(&mut messages); + for (request, fe_token) in messages.drain(..).map(|v| v.into_inner()) { + if let Some(be_token) = self.free_queue.pop_front() { + let session = &mut self.sessions[be_token.0]; + if session.send(request).is_err() { + panic!("we don't handle this right now"); + } else { + self.pending.insert(be_token, fe_token); + } + } else { + self.backlog.push_back((request, token)); + } + } - fn handle_session_read(&mut self, token: Token) -> Result<()> { - let s = self.poll.get_mut_session(token)?; - let session = &mut s.session; - match self.parser.parse(session.buffer()) { - Ok(response) => { - let consumed = response.consumed(); - let response = response.into_inner(); - session.consume(consumed); - - let fe_worker = s.sender.take().unwrap(); - let client_token = s.token.take().unwrap(); - 
- let mut message = TokenWrapper::new(response, client_token); - - for retry in 0..QUEUE_RETRIES { - if let Err(m) = self.queues.try_send_to(fe_worker, message) { - if (retry + 1) == QUEUE_RETRIES { - error!("queue full trying to send response to frontend"); - let _ = self.poll.close_session(token); + // check if we received any signals from the admin thread + while let Some(signal) = + self.signal_queue.try_recv().map(|v| v.into_inner()) + { + match signal { + Signal::FlushAll => {} + Signal::Shutdown => { + // if we received a shutdown, we can return + // and stop processing events + return; + } + } } - // try to wake frontend thread - let _ = self.queues.wake(); - message = m; - } else { - break; } - } + _ => { + if event.is_error() { + BACKEND_EVENT_ERROR.increment(); - self.free_queue.push_back(token); + self.close(token); + continue; + } - let _ = self.queues.wake(); + if event.is_writable() { + BACKEND_EVENT_WRITE.increment(); - Ok(()) - } - Err(ParseError::Incomplete) => { - trace!("incomplete response for session: {:?}", session); - Err(std::io::Error::new( - std::io::ErrorKind::WouldBlock, - "incomplete response", - )) - } - Err(_) => { - debug!("bad response for session: {:?}", session); - trace!("session: {:?} read buffer: {:?}", session, session.buffer()); - let _ = self.poll.close_session(token); - Err(std::io::Error::new( - std::io::ErrorKind::Other, - "bad response", - )) + if self.write(token).is_err() { + self.close(token); + continue; + } + } + + if event.is_readable() { + BACKEND_EVENT_READ.increment(); + + if self.read(token).is_err() { + self.close(token); + continue; + } + } + } + } } + + // wakes the storage thread if necessary + let _ = self.data_queue.wake(); } } } -impl EventLoop for BackendWorker +pub struct BackendBuilder { + builders: Vec>, +} + +impl + BackendBuilder where - Request: Compose, - Parser: Parse, + BackendParser: Parse + Clone, + BackendRequest: Compose, { - fn handle_data(&mut self, token: Token) -> Result<()> { - let _ 
= self.handle_session_read(token); - Ok(()) + pub fn new( + config: &T, + parser: BackendParser, + threads: usize, + ) -> Result { + let mut builders = Vec::new(); + for _ in 0..threads { + builders.push(BackendWorkerBuilder::new(config, parser.clone())?); + } + Ok(Self { builders }) } - fn poll(&mut self) -> &mut poll::Poll { - &mut self.poll + pub fn wakers(&self) -> Vec> { + self.builders.iter().map(|b| b.waker()).collect() + } + + #[allow(clippy::type_complexity)] + pub fn build( + mut self, + mut data_queues: Vec< + Queues<(BackendRequest, BackendResponse, Token), (BackendRequest, Token)>, + >, + mut signal_queues: Vec>, + ) -> Vec> { + self.builders + .drain(..) + .map(|b| b.build(data_queues.pop().unwrap(), signal_queues.pop().unwrap())) + .collect() } } diff --git a/src/core/proxy/src/event_loop.rs b/src/core/proxy/src/event_loop.rs deleted file mode 100644 index 487ed34b0..000000000 --- a/src/core/proxy/src/event_loop.rs +++ /dev/null @@ -1,108 +0,0 @@ -// Copyright 2021 Twitter, Inc. -// Licensed under the Apache License, Version 2.0 -// http://www.apache.org/licenses/LICENSE-2.0 - -//! A trait defining common functions for event-based threads which operate on -//! sessions. - -use std::io::{BufRead, ErrorKind, Write}; - -use mio::Token; - -use crate::poll::Poll; - -/// An `EventLoop` describes the functions which must be implemented for a basic -/// event loop and provides some default implementations and helper functions. -pub trait EventLoop { - // the following functions must be implemented - - /// Provides access to the `Poll` structure which allows polling for new - /// readiness events and managing registration for event sources. - fn poll(&mut self) -> &mut Poll; - - /// Handle new data received for the `Session` with the provided `Token`. - /// This will include parsing the incoming data and composing a response. 
- fn handle_data(&mut self, token: Token) -> Result<(), std::io::Error>; - - /// Handle a read event for the `Session` with the `Token`. - fn do_read(&mut self, token: Token) -> Result<(), ()> { - if let Ok(session) = self.poll().get_mut_session(token) { - // read from session to buffer - match session.session.fill_buf().map(|b| b.len()) { - Ok(0) => { - trace!("hangup for session: {:?}", session.session); - let _ = self.poll().close_session(token); - Err(()) - } - Ok(bytes) => { - trace!("read {} bytes for session: {:?}", bytes, session.session); - if self.handle_data(token).is_err() { - self.handle_error(token); - Err(()) - } else { - Ok(()) - } - } - Err(e) => { - match e.kind() { - ErrorKind::WouldBlock => { - trace!("would block"); - // spurious read - self.poll().reregister(token); - Ok(()) - } - ErrorKind::Interrupted => { - trace!("interrupted"); - self.do_read(token) - } - _ => { - trace!("error reading for session: {:?} {:?}", session.session, e); - // some read error - self.handle_error(token); - Err(()) - } - } - } - } - } else { - warn!("attempted to read from non-existent session: {}", token.0); - Err(()) - } - } - - /// Handle a write event for a `Session` with the `Token`. - fn do_write(&mut self, token: Token) { - if let Ok(session) = self.poll().get_mut_session(token) { - trace!("write for session: {:?}", session.session); - match session.session.flush() { - Ok(_) => { - self.poll().reregister(token); - } - Err(e) => match e.kind() { - ErrorKind::WouldBlock => {} - ErrorKind::Interrupted => self.do_write(token), - _ => { - self.handle_error(token); - } - }, - } - } else { - trace!("attempted to write to non-existent session: {}", token.0) - } - } - - /// Handle errors for the `Session` with the `Token` by logging a message - /// and closing the session. 
- fn handle_error(&mut self, token: Token) { - if let Ok(session) = self.poll().get_mut_session(token) { - trace!("handling error for session: {:?}", session.session); - let _ = session.session.flush(); - let _ = self.poll().close_session(token); - } else { - trace!( - "attempted to handle error for non-existent session: {}", - token.0 - ) - } - } -} diff --git a/src/core/proxy/src/frontend.rs b/src/core/proxy/src/frontend.rs index a55656e02..fa9be36e5 100644 --- a/src/core/proxy/src/frontend.rs +++ b/src/core/proxy/src/frontend.rs @@ -2,255 +2,388 @@ // Licensed under the Apache License, Version 2.0 // http://www.apache.org/licenses/LICENSE-2.0 +use super::map_result; use crate::*; -use common::signal::Signal; -use common::time::Instant; -use config::proxy::FrontendConfig; -use core::marker::PhantomData; -use core::time::Duration; -use mio::Waker; -use poll::*; -use protocol_common::*; -use queues::Queues; -use session::Session; -use std::sync::Arc; - -use rustcommon_metrics::*; - -counter!(FRONTEND_EVENT_ERROR); -counter!(FRONTEND_EVENT_READ); -counter!(FRONTEND_EVENT_WRITE); + +heatmap!( + FRONTEND_EVENT_DEPTH, + 100_000, + "distribution of the number of events received per iteration of the event loop" +); +counter!(FRONTEND_EVENT_ERROR, "the number of error events received"); +counter!( + FRONTEND_EVENT_LOOP, + "the number of times the event loop has run" +); counter!( FRONTEND_EVENT_MAX_REACHED, "the number of times the maximum number of events was returned" ); -heatmap!(FRONTEND_EVENT_DEPTH, 100_000); - -pub const QUEUE_RETRIES: usize = 3; +counter!(FRONTEND_EVENT_READ, "the number of read events received"); +counter!(FRONTEND_EVENT_TOTAL, "the total number of events received"); +counter!(FRONTEND_EVENT_WRITE, "the number of write events received"); -pub struct FrontendWorkerBuilder { - poll: Poll, - parser: Parser, +pub struct FrontendWorkerBuilder< + FrontendParser, + FrontendRequest, + FrontendResponse, + BackendRequest, + BackendResponse, +> { nevent: 
usize, + parser: FrontendParser, + poll: Poll, + sessions: Slab>, timeout: Duration, - _request: PhantomData, - _response: PhantomData, + waker: Arc, + _backend_request: PhantomData, + _backend_response: PhantomData, } -impl FrontendWorkerBuilder { - pub fn new(config: &T, parser: Parser) -> Result { +impl + FrontendWorkerBuilder< + FrontendParser, + FrontendRequest, + FrontendResponse, + BackendRequest, + BackendResponse, + > +{ + pub fn new(config: &T, parser: FrontendParser) -> Result { let config = config.frontend(); + let poll = Poll::new()?; + + let waker = Arc::new(Waker::from( + ::net::Waker::new(poll.registry(), WAKER_TOKEN).unwrap(), + )); + + let nevent = config.nevent(); + let timeout = Duration::from_millis(config.timeout() as u64); + Ok(Self { - poll: Poll::new()?, + nevent, parser, - nevent: config.nevent(), - timeout: Duration::from_millis(config.timeout() as u64), - _request: PhantomData, - _response: PhantomData, + poll, + sessions: Slab::new(), + timeout, + waker, + _backend_request: PhantomData, + _backend_response: PhantomData, }) } pub fn waker(&self) -> Arc { - self.poll.waker() + self.waker.clone() } pub fn build( self, + data_queue: Queues<(BackendRequest, Token), (BackendRequest, BackendResponse, Token)>, + session_queue: Queues, signal_queue: Queues<(), Signal>, - connection_queues: Queues<(), Session>, - data_queues: Queues, TokenWrapper>, - ) -> FrontendWorker { + ) -> FrontendWorker< + FrontendParser, + FrontendRequest, + FrontendResponse, + BackendRequest, + BackendResponse, + > { FrontendWorker { - poll: self.poll, - parser: self.parser, + data_queue, nevent: self.nevent, - timeout: self.timeout, + parser: self.parser, + poll: self.poll, + session_queue, + sessions: self.sessions, signal_queue, - connection_queues, - data_queues, + timeout: self.timeout, + waker: self.waker, } } } -pub struct FrontendWorker { - poll: Poll, - parser: Parser, +pub struct FrontendWorker< + FrontendParser, + FrontendRequest, + FrontendResponse, + 
BackendRequest, + BackendResponse, +> { + data_queue: Queues<(BackendRequest, Token), (BackendRequest, BackendResponse, Token)>, nevent: usize, - timeout: Duration, + parser: FrontendParser, + poll: Poll, + session_queue: Queues, + sessions: Slab>, signal_queue: Queues<(), Signal>, - connection_queues: Queues<(), Session>, - data_queues: Queues, TokenWrapper>, + timeout: Duration, + waker: Arc, } -impl FrontendWorker +impl + FrontendWorker< + FrontendParser, + FrontendRequest, + FrontendResponse, + BackendRequest, + BackendResponse, + > where - Parser: Parse, - Response: Compose, + FrontendParser: Parse + Clone, + FrontendResponse: Compose, + FrontendResponse: From, + BackendRequest: From, + BackendRequest: Compose, + BackendResponse: Compose, { - #[allow(clippy::match_single_binding)] - pub fn run(mut self) { + /// Return the `Session` to the `Listener` to handle flush/close + fn close(&mut self, token: Token) { + if self.sessions.contains(token.0) { + let mut session = self.sessions.remove(token.0).into_inner(); + let _ = session.deregister(self.poll.registry()); + let _ = self.session_queue.try_send_any(session); + let _ = self.session_queue.wake(); + } + } + + /// Handle up to one request for a session + fn read(&mut self, token: Token) -> Result<()> { + let session = self + .sessions + .get_mut(token.0) + .ok_or_else(|| Error::new(ErrorKind::Other, "non-existant session"))?; + + // fill the session + map_result(session.fill())?; + + // process up to one request + match session.receive() { + Ok(request) => self + .data_queue + .try_send_to(0, (BackendRequest::from(request), token)) + .map_err(|_| Error::new(ErrorKind::Other, "data queue is full")), + Err(e) => map_err(e), + } + } + + /// Handle write by flushing the session + fn write(&mut self, token: Token) -> Result<()> { + let session = self + .sessions + .get_mut(token.0) + .ok_or_else(|| Error::new(ErrorKind::Other, "non-existant session"))?; + + match session.flush() { + Ok(_) => Ok(()), + Err(e) => 
map_err(e), + } + } + + /// Run the worker in a loop, handling new events. + pub fn run(&mut self) { + // these are buffers which are re-used in each loop iteration to receive + // events and queue messages let mut events = Events::with_capacity(self.nevent); - let mut sessions = Vec::with_capacity(self.nevent); - let mut responses = Vec::with_capacity(self.nevent); + let mut messages = Vec::with_capacity(QUEUE_CAPACITY); + loop { - let _ = self.poll.poll(&mut events, self.timeout); - for event in &events { - match event.token() { + FRONTEND_EVENT_LOOP.increment(); + + // get events with timeout + if self.poll.poll(&mut events, Some(self.timeout)).is_err() { + error!("Error polling"); + } + + let timestamp = Instant::now(); + + let count = events.iter().count(); + FRONTEND_EVENT_TOTAL.add(count as _); + if count == self.nevent { + FRONTEND_EVENT_MAX_REACHED.increment(); + } else { + FRONTEND_EVENT_DEPTH.increment(timestamp, count as _, 1); + } + + // process all events + for event in events.iter() { + let token = event.token(); + match token { WAKER_TOKEN => { - self.connection_queues.try_recv_all(&mut sessions); - for session in sessions.drain(..).map(|v| v.into_inner()) { - if self.poll.add_session(session).is_ok() { - trace!("frontend registered new session"); + self.waker.reset(); + // handle up to one new session + if let Some(mut session) = + self.session_queue.try_recv().map(|v| v.into_inner()) + { + let s = self.sessions.vacant_entry(); + let interest = session.interest(); + if session + .register(self.poll.registry(), Token(s.key()), interest) + .is_ok() + { + s.insert(ServerSession::new(session, self.parser.clone())); } else { - warn!("frontend failed to register new session"); + let _ = self.session_queue.try_send_any(session); } + + // trigger a wake-up in case there are more sessions + let _ = self.waker.wake(); } - self.data_queues.try_recv_all(&mut responses); - for response in responses.drain(..).map(|v| v.into_inner()) { - let token = 
response.token(); - let response = response.into_inner(); - if let Ok(session) = self.poll.get_mut_session(token) { - response.compose(&mut session.session); - session.session.finalize_response(); - - // if we have pending writes, we should attempt to flush the session - // now. if we still have pending bytes, we should re-register to - // remove the read interest. - if session.session.write_pending() > 0 { - let _ = session.session.flush(); - if session.session.write_pending() > 0 { - self.poll.reregister(token); + + // handle all pending messages on the data queue + self.data_queue.try_recv_all(&mut messages); + for (_request, response, token) in + messages.drain(..).map(|v| v.into_inner()) + { + if let Some(session) = self.sessions.get_mut(token.0) { + if response.should_hangup() { + let _ = session.send(FrontendResponse::from(response)); + self.close(token); + continue; + } else if session.send(FrontendResponse::from(response)).is_err() { + self.close(token); + continue; + } else if session.write_pending() > 0 { + let interest = session.interest(); + if session + .reregister(self.poll.registry(), token, interest) + .is_err() + { + self.close(token); + continue; } } + if session.remaining() > 0 && self.read(token).is_err() { + self.close(token); + continue; + } + } + } + + // check if we received any signals from the admin thread + while let Some(signal) = + self.signal_queue.try_recv().map(|v| v.into_inner()) + { + match signal { + Signal::FlushAll => {} + Signal::Shutdown => { + // if we received a shutdown, we can return + // and stop processing events + return; + } } } } _ => { - self.handle_event(event); - } - } - } - let count = events.iter().count(); - if count == self.nevent { - FRONTEND_EVENT_MAX_REACHED.increment(); - } else { - FRONTEND_EVENT_DEPTH.increment( - common::time::Instant::>::now(), - count as _, - 1, - ); - } - let _ = self.data_queues.wake(); - } - } + if event.is_error() { + FRONTEND_EVENT_ERROR.increment(); - fn handle_event(&mut self, 
event: &Event) { - let token = event.token(); - - // handle error events first - if event.is_error() { - FRONTEND_EVENT_ERROR.increment(); - self.handle_error(token); - } + self.close(token); + continue; + } - // handle write events before read events to reduce write buffer - // growth if there is also a readable event - if event.is_writable() { - FRONTEND_EVENT_WRITE.increment(); - self.do_write(token); - } + if event.is_writable() { + FRONTEND_EVENT_WRITE.increment(); - // read events are handled last - if event.is_readable() { - FRONTEND_EVENT_READ.increment(); - if let Ok(session) = self.poll.get_mut_session(token) { - session.session.set_timestamp(rustcommon_time::Instant::< - rustcommon_time::Nanoseconds, - >::recent()); - } - let _ = self.do_read(token); - } + if self.write(token).is_err() { + self.close(token); + continue; + } + } - if let Ok(session) = self.poll.get_mut_session(token) { - if session.session.read_pending() > 0 { - trace!( - "session: {:?} has {} bytes pending in read buffer", - session.session, - session.session.read_pending() - ); - } - if session.session.write_pending() > 0 { - trace!( - "session: {:?} has {} bytes pending in write buffer", - session.session, - session.session.read_pending() - ); - } - } - } + if event.is_readable() { + FRONTEND_EVENT_READ.increment(); - fn handle_session_read(&mut self, token: Token) -> Result<()> { - let s = self.poll.get_mut_session(token)?; - let session = &mut s.session; - match self.parser.parse(session.buffer()) { - Ok(request) => { - let consumed = request.consumed(); - let request = request.into_inner(); - trace!("parsed request for sesion: {:?}", session); - session.consume(consumed); - let mut message = TokenWrapper::new(request, token); - - for retry in 0..QUEUE_RETRIES { - if let Err(m) = self.data_queues.try_send_any(message) { - if (retry + 1) == QUEUE_RETRIES { - warn!("queue full trying to send message to backend thread"); - let _ = self.poll.close_session(token); + if 
self.read(token).is_err() { + self.close(token); + continue; + } } - // try to wake backend thread - let _ = self.data_queues.wake(); - message = m; - } else { - break; } } - Ok(()) - } - Err(ParseError::Incomplete) => { - trace!("incomplete request for session: {:?}", session); - Err(std::io::Error::new( - std::io::ErrorKind::WouldBlock, - "incomplete request", - )) - } - Err(_) => { - debug!("bad request for session: {:?}", session); - trace!("session: {:?} read buffer: {:?}", session, session.buffer()); - let _ = self.poll.close_session(token); - Err(std::io::Error::new( - std::io::ErrorKind::Other, - "bad request", - )) } + + // wakes the storage thread if necessary + let _ = self.data_queue.wake(); } } +} - pub fn try_close(&mut self, token: Token) { - let _ = self.poll.remove_session(token); - } +pub struct FrontendBuilder< + FrontendParser, + FrontendRequest, + FrontendResponse, + BackendRequest, + BackendResponse, +> { + builders: Vec< + FrontendWorkerBuilder< + FrontendParser, + FrontendRequest, + FrontendResponse, + BackendRequest, + BackendResponse, + >, + >, } -impl EventLoop for FrontendWorker +impl + FrontendBuilder< + FrontendParser, + FrontendRequest, + FrontendResponse, + BackendRequest, + BackendResponse, + > where - Parser: Parse, - Response: Compose, + FrontendParser: Parse + Clone, + FrontendResponse: Compose, + FrontendResponse: From, + BackendRequest: From, + BackendRequest: Compose, { - fn handle_data(&mut self, token: Token) -> Result<()> { - let _ = self.handle_session_read(token); - Ok(()) + pub fn new( + config: &T, + parser: FrontendParser, + threads: usize, + ) -> Result { + let mut builders = Vec::new(); + for _ in 0..threads { + builders.push(FrontendWorkerBuilder::new(config, parser.clone())?); + } + Ok(Self { builders }) } - fn poll(&mut self) -> &mut poll::Poll { - &mut self.poll + pub fn wakers(&self) -> Vec> { + self.builders.iter().map(|b| b.waker()).collect() + } + + #[allow(clippy::type_complexity)] + pub fn build( + mut 
self, + mut data_queues: Vec< + Queues<(BackendRequest, Token), (BackendRequest, BackendResponse, Token)>, + >, + mut session_queues: Vec>, + mut signal_queues: Vec>, + ) -> Vec< + FrontendWorker< + FrontendParser, + FrontendRequest, + FrontendResponse, + BackendRequest, + BackendResponse, + >, + > { + self.builders + .drain(..) + .map(|b| { + b.build( + data_queues.pop().unwrap(), + session_queues.pop().unwrap(), + signal_queues.pop().unwrap(), + ) + }) + .collect() } } diff --git a/src/core/proxy/src/lib.rs b/src/core/proxy/src/lib.rs index 7d1dd0c1b..e585b51b1 100644 --- a/src/core/proxy/src/lib.rs +++ b/src/core/proxy/src/lib.rs @@ -8,105 +8,81 @@ #[macro_use] extern crate logger; -use mio::event::Event; -use mio::net::{TcpListener, TcpStream}; -use mio::{Events, Interest, Token}; -use mpmc::Queue; -use poll::Poll; +#[macro_use] +extern crate rustcommon_metrics; + +use ::net::event::{Event, Source}; +use ::net::*; +use admin::AdminBuilder; +use common::signal::Signal; +use common::ssl::tls_acceptor; +use config::proxy::*; +use config::*; +use core::marker::PhantomData; +use core::time::Duration; +use crossbeam_channel::{bounded, Receiver, Sender}; +use entrystore::EntryStore; +use logger::Drain; +use protocol_common::{Compose, Execute, Parse}; +use queues::Queues; +use rustcommon_metrics::*; +use session::{Buf, ServerSession, Session}; use slab::Slab; -use std::collections::VecDeque; -use std::io::*; -use std::net::SocketAddr; +use std::io::{Error, ErrorKind, Result}; +use std::sync::Arc; +use waker::Waker; + +type Instant = rustcommon_metrics::Instant>; -mod admin; mod backend; -mod event_loop; mod frontend; mod listener; -mod poll; mod process; -pub use admin::PERCENTILES; -use backend::BackendWorker; -use event_loop::EventLoop; -use frontend::FrontendWorker; -use listener::Listener; -pub use process::{Process, ProcessBuilder}; - -type Result = std::result::Result; +use backend::BackendBuilder; +use frontend::FrontendBuilder; +use listener::ListenerBuilder; 
-use rustcommon_metrics::*; - -counter!(TCP_ACCEPT_EX); - -// The default buffer size is matched to the upper-bound on TLS fragment size as -// per RFC 5246 https://datatracker.ietf.org/doc/html/rfc5246#section-6.2.1 -pub const DEFAULT_BUFFER_SIZE: usize = 16 * 1024; // 16KB - -// The admin thread (control plane) sessions use a fixed upper-bound on the -// session buffer size. The max buffer size for data plane sessions are to be -// specified during `Listener` initialization. This allows protocol and config -// specific upper bounds. -const ADMIN_MAX_BUFFER_SIZE: usize = 2 * 1024 * 1024; // 1MB +pub use process::{Process, ProcessBuilder}; // TODO(bmartin): this *should* be plenty safe, the queue should rarely ever be // full, and a single wakeup should drain at least one message and make room for // the response. A stat to prove that this is sufficient would be good. const QUEUE_RETRIES: usize = 3; -const THREAD_PREFIX: &str = "pelikan"; const QUEUE_CAPACITY: usize = 64 * 1024; -#[derive(PartialEq, Copy, Clone, Eq, Debug)] -pub enum ConnectionState { - Open, - HalfClosed, -} +// determines the max number of calls to accept when the listener is ready +const ACCEPT_BATCH: usize = 8; -pub struct ClientConnection { - addr: SocketAddr, - stream: TcpStream, - r_buf: Box<[u8]>, - state: ConnectionState, - pipeline_depth: usize, -} +const LISTENER_TOKEN: Token = Token(usize::MAX - 1); +const WAKER_TOKEN: Token = Token(usize::MAX); -impl ClientConnection { - #[allow(clippy::slow_vector_initialization)] - pub fn new(addr: SocketAddr, stream: TcpStream) -> Self { - let mut r_buf = Vec::with_capacity(16384); - r_buf.resize(16384, 0); - let r_buf = r_buf.into_boxed_slice(); - - Self { - addr, - stream, - r_buf, - state: ConnectionState::Open, - pipeline_depth: 0, - } - } +const THREAD_PREFIX: &str = "pelikan"; - pub fn do_read(&mut self) -> Result { - self.stream.read(&mut self.r_buf) +pub static PERCENTILES: &[(&str, f64)] = &[ + ("p25", 25.0), + ("p50", 50.0), + ("p75", 
75.0), + ("p90", 90.0), + ("p99", 99.0), + ("p999", 99.9), + ("p9999", 99.99), +]; + +fn map_err(e: std::io::Error) -> Result<()> { + match e.kind() { + ErrorKind::WouldBlock => Ok(()), + _ => Err(e), } } -pub struct TokenWrapper { - inner: T, - token: Token, -} - -impl TokenWrapper { - pub fn new(inner: T, token: Token) -> Self { - Self { inner, token } - } - - pub fn token(&self) -> Token { - self.token - } - - pub fn into_inner(self) -> T { - self.inner +fn map_result(result: Result) -> Result<()> { + match result { + Ok(0) => Err(Error::new(ErrorKind::Other, "client hangup")), + Ok(_) => Ok(()), + Err(e) => map_err(e), } } + +common::metrics::test_no_duplicates!(); diff --git a/src/core/proxy/src/listener.rs b/src/core/proxy/src/listener.rs index 43feedde2..2369b5af8 100644 --- a/src/core/proxy/src/listener.rs +++ b/src/core/proxy/src/listener.rs @@ -3,31 +3,30 @@ // http://www.apache.org/licenses/LICENSE-2.0 use crate::*; -use config::proxy::ListenerConfig; -use config::TlsConfig; -use core::time::Duration; -use mio::Waker; -use poll::*; -use queues::Queues; -use session::Session; -use std::sync::Arc; - use rustcommon_metrics::*; +use std::time::Duration; -const KB: usize = 1024; - -const SESSION_BUFFER_MIN: usize = 16 * KB; -const SESSION_BUFFER_MAX: usize = 1024 * KB; +counter!(LISTENER_EVENT_ERROR, "the number of error events received"); +counter!( + LISTENER_EVENT_LOOP, + "the number of times the event loop has run" +); +counter!(LISTENER_EVENT_READ, "the number of read events received"); +counter!(LISTENER_EVENT_TOTAL, "the total number of events received"); +counter!(LISTENER_EVENT_WRITE, "the number of write events received"); -counter!(LISTENER_EVENT_ERROR); -counter!(LISTENER_EVENT_READ); -counter!(LISTENER_EVENT_WRITE); +counter!( + LISTENER_SESSION_DISCARD, + "the number of sessions discarded by the listener" +); pub struct ListenerBuilder { - addr: SocketAddr, + listener: ::net::Listener, nevent: usize, poll: Poll, + sessions: Slab, timeout: 
Duration, + waker: Arc, } impl ListenerBuilder { @@ -35,149 +34,283 @@ impl ListenerBuilder { let tls_config = config.tls(); let config = config.listener(); - let addr = config - .socket_addr() - .map_err(|_| std::io::Error::new(std::io::ErrorKind::Other, "bad listen address"))?; + let addr = config.socket_addr().map_err(|e| { + error!("{}", e); + std::io::Error::new(std::io::ErrorKind::Other, "Bad listen address") + })?; + + let tcp_listener = TcpListener::bind(addr)?; + + let mut listener = if let Some(tls_acceptor) = tls_acceptor(tls_config)? { + ::net::Listener::from((tcp_listener, tls_acceptor)) + } else { + ::net::Listener::from(tcp_listener) + }; + + let poll = Poll::new()?; + listener.register(poll.registry(), LISTENER_TOKEN, Interest::READABLE)?; + + let waker = Arc::new(Waker::from( + ::net::Waker::new(poll.registry(), WAKER_TOKEN).unwrap(), + )); + let nevent = config.nevent(); let timeout = Duration::from_millis(config.timeout() as u64); - let mut poll = Poll::new()?; - poll.bind(addr, tls_config)?; + let sessions = Slab::new(); Ok(Self { - addr, + listener, nevent, poll, + sessions, timeout, + waker, }) } pub fn waker(&self) -> Arc { - self.poll.waker() + self.waker.clone() } - pub fn build(self, connection_queues: Queues) -> Listener { + pub fn build( + self, + signal_queue: Queues<(), Signal>, + session_queue: Queues, + ) -> Listener { Listener { - addr: self.addr, - connection_queues, + listener: self.listener, nevent: self.nevent, poll: self.poll, + sessions: self.sessions, + session_queue, + signal_queue, timeout: self.timeout, + waker: self.waker, } } } pub struct Listener { - addr: SocketAddr, - connection_queues: Queues, + /// The actual network listener server + listener: ::net::Listener, + /// The maximum number of events to process per call to poll nevent: usize, + /// The actual poll instantance poll: Poll, + /// Sessions which have been opened, but are not fully established + sessions: Slab, + /// Queues for sending established sessions to 
the worker thread(s) and to + /// receive sessions which should be closed + session_queue: Queues, + /// Queue for receieving signals from the admin thread + signal_queue: Queues<(), Signal>, + /// The timeout for each call to poll timeout: Duration, + /// The waker handle for this thread + waker: Arc, } impl Listener { - /// Handle an event on an existing session - fn handle_session_event(&mut self, event: &Event) { - let token = event.token(); - - // handle error events first - if event.is_error() { - LISTENER_EVENT_ERROR.increment(); - self.handle_error(token); + /// Accept new sessions + fn accept(&mut self) { + for _ in 0..ACCEPT_BATCH { + if let Ok(mut session) = self.listener.accept().map(Session::from) { + if session.is_handshaking() { + let s = self.sessions.vacant_entry(); + let interest = session.interest(); + if session + .register(self.poll.registry(), Token(s.key()), interest) + .is_ok() + { + s.insert(session); + } else { + // failed to register + } + } else { + for attempt in 1..=QUEUE_RETRIES { + if let Err(s) = self.session_queue.try_send_any(session) { + if attempt == QUEUE_RETRIES { + LISTENER_SESSION_DISCARD.increment(); + } else { + let _ = self.session_queue.wake(); + } + session = s; + } else { + break; + } + } + // if pushing to the session queues fails, the session will be + // closed on drop here + } + } else { + return; + } } - // handle write events before read events to reduce write - // buffer growth if there is also a readable event - if event.is_writable() { - LISTENER_EVENT_WRITE.increment(); - self.do_write(token); + // reregister is needed here so we will call accept if there is a backlog + if self + .listener + .reregister(self.poll.registry(), LISTENER_TOKEN, Interest::READABLE) + .is_err() + { + // failed to reregister listener? how do we handle this? 
} + } - // read events are handled last - if event.is_readable() { - LISTENER_EVENT_READ.increment(); - let _ = self.do_read(token); - } + /// Handle a read event for the `Session` with the `Token`. This primarily + /// just checks that there wasn't a hangup, as indicated by a zero-sized + /// return from `read()`. + fn read(&mut self, token: Token) -> Result<()> { + let session = self + .sessions + .get_mut(token.0) + .ok_or_else(|| Error::new(ErrorKind::Other, "non-existant session"))?; - if let Ok(session) = self.poll.get_mut_session(token) { - if session.session.do_handshake().is_ok() { - trace!("handshake complete for session: {:?}", session.session); - if let Ok(session) = self.poll.remove_session(token) { - if self - .connection_queues - .try_send_any(session.session) - .is_err() - { - error!("error sending session to worker"); - TCP_ACCEPT_EX.increment(); + // read from session to buffer + match session.fill() { + Ok(0) => { + // zero-length reads indicate remote side has closed connection + trace!("hangup for session: {:?}", session); + Err(Error::new(ErrorKind::Other, "client hangup")) + } + Ok(bytes) => { + trace!("read {} bytes for session: {:?}", bytes, session); + Ok(()) + } + Err(e) => { + match e.kind() { + ErrorKind::WouldBlock => { + // spurious read, ignore + Ok(()) } - } else { - error!("error removing session from poller"); - TCP_ACCEPT_EX.increment(); + _ => Err(e), } - } else { - trace!("handshake incomplete for session: {:?}", session.session); } } } - pub fn do_accept(&mut self) { - if let Ok(token) = self.poll.accept() { - match self - .poll - .get_mut_session(token) - .map(|v| v.session.is_handshaking()) - { - Ok(false) => { - if let Ok(session) = self.poll.remove_session(token) { - if self - .connection_queues - .try_send_any(session.session) - .is_err() - { - warn!("rejecting connection, client connection queue is too full"); + /// Closes the session with the given token + fn close(&mut self, token: Token) { + if 
self.sessions.contains(token.0) { + let mut session = self.sessions.remove(token.0); + let _ = session.flush(); + } + } + + fn handshake(&mut self, token: Token) -> Result<()> { + let session = self + .sessions + .get_mut(token.0) + .ok_or_else(|| Error::new(ErrorKind::Other, "non-existant session"))?; + + session.do_handshake() + } + + /// handle a single session event + fn session_event(&mut self, event: &Event) { + let token = event.token(); + + if event.is_error() { + LISTENER_EVENT_ERROR.increment(); + self.close(token); + return; + } + + if event.is_readable() { + LISTENER_EVENT_READ.increment(); + if self.read(token).is_err() { + self.close(token); + return; + } + } + + match self.handshake(token) { + Ok(_) => { + // handshake is complete, send the session to a worker thread + let mut session = self.sessions.remove(token.0); + for attempt in 1..=QUEUE_RETRIES { + if let Err(s) = self.session_queue.try_send_any(session) { + if attempt == QUEUE_RETRIES { + LISTENER_SESSION_DISCARD.increment(); } else { - trace!("sending new connection to worker threads"); + let _ = self.session_queue.wake(); } + session = s; + } else { + break; } } - Ok(true) => {} - Err(e) => { - warn!("error checking if new session is handshaking: {}", e); - } + // if pushing to the session queues fails, the session will be + // closed on drop here } + Err(e) => match e.kind() { + ErrorKind::WouldBlock => {} + _ => { + self.close(token); + } + }, } - self.poll.reregister(LISTENER_TOKEN); - let _ = self.connection_queues.wake(); } - pub fn run(mut self) { - info!("running listener on: {}", self.addr); + pub fn run(&mut self) { + info!( + "running server on: {}", + self.listener + .local_addr() + .map(|v| format!("{v}")) + .unwrap_or_else(|_| "unknown address".to_string()) + ); let mut events = Events::with_capacity(self.nevent); + + // repeatedly run accepting new connections and moving them to the worker loop { - let _ = self.poll.poll(&mut events, self.timeout); - for event in &events { + 
LISTENER_EVENT_LOOP.increment(); + if self.poll.poll(&mut events, Some(self.timeout)).is_err() { + error!("Error polling server"); + } + LISTENER_EVENT_TOTAL.add(events.iter().count() as _); + + // handle all events + for event in events.iter() { match event.token() { LISTENER_TOKEN => { - self.do_accept(); + self.accept(); + } + WAKER_TOKEN => { + self.waker.reset(); + // handle any closing sessions + if let Some(mut session) = + self.session_queue.try_recv().map(|v| v.into_inner()) + { + let _ = session.flush(); + + // wakeup to handle the possibility of more sessions + let _ = self.waker.wake(); + } + + // check if we received any signals from the admin thread + while let Some(signal) = + self.signal_queue.try_recv().map(|v| v.into_inner()) + { + match signal { + Signal::FlushAll => {} + Signal::Shutdown => { + // if we received a shutdown, we can return + // and stop processing events + return; + } + } + } } - WAKER_TOKEN => {} _ => { - self.handle_session_event(event); + self.session_event(event); } } } - } - } -} - -impl EventLoop for Listener { - fn handle_data(&mut self, _token: Token) -> Result<()> { - Ok(()) - } - fn poll(&mut self) -> &mut Poll { - &mut self.poll + let _ = self.session_queue.wake(); + } } } diff --git a/src/core/proxy/src/poll.rs b/src/core/proxy/src/poll.rs deleted file mode 100644 index 26f3d8b13..000000000 --- a/src/core/proxy/src/poll.rs +++ /dev/null @@ -1,263 +0,0 @@ -// Copyright 2021 Twitter, Inc. -// Licensed under the Apache License, Version 2.0 -// http://www.apache.org/licenses/LICENSE-2.0 - -//! This module provides common functionality for threads which are based on an -//! event loop. 
- -use crate::TCP_ACCEPT_EX; -use common::ssl::*; -use mio::event::Source; -use mio::{Events, Interest, Token, Waker}; -use session::{Session, TcpStream}; -use slab::Slab; -use std::convert::TryFrom; -use std::net::SocketAddr; -use std::sync::Arc; -use std::time::Duration; - -pub const LISTENER_TOKEN: Token = Token(usize::MAX - 1); -pub const WAKER_TOKEN: Token = Token(usize::MAX); - -const KB: usize = 1024; - -const SESSION_BUFFER_MIN: usize = 16 * KB; -const SESSION_BUFFER_MAX: usize = 1024 * KB; - -struct TcpListener { - inner: mio::net::TcpListener, - ssl_context: Option, -} - -impl TcpListener { - pub fn bind(addr: SocketAddr, tls_config: &dyn TlsConfig) -> Result { - let listener = mio::net::TcpListener::bind(addr).map_err(|e| { - error!("{}", e); - std::io::Error::new(std::io::ErrorKind::Other, "failed to start tcp listener") - })?; - - let ssl_context = common::ssl::ssl_context(tls_config)?; - - Ok(Self { - inner: listener, - ssl_context, - }) - } -} - -pub struct Poll { - listener: Option, - poll: mio::Poll, - sessions: Slab, - waker: Arc, -} - -pub struct TrackedSession { - pub session: Session, - pub sender: Option, - pub token: Option, -} - -impl Poll { - /// Create a new `Poll` instance. - pub fn new() -> Result { - let poll = mio::Poll::new().map_err(|e| { - error!("{}", e); - std::io::Error::new(std::io::ErrorKind::Other, "failed to create poll instance") - })?; - - let waker = Arc::new(Waker::new(poll.registry(), WAKER_TOKEN).unwrap()); - - let sessions = Slab::::new(); - - Ok(Self { - listener: None, - poll, - sessions, - waker, - }) - } - - /// Bind and begin listening on the provided address. 
- pub fn bind( - &mut self, - addr: SocketAddr, - tls_config: &dyn TlsConfig, - ) -> Result<(), std::io::Error> { - let mut listener = TcpListener::bind(addr, tls_config).map_err(|e| { - error!("{}", e); - std::io::Error::new(std::io::ErrorKind::Other, "failed to start tcp listener") - })?; - - // register listener to event loop - self.poll - .registry() - .register(&mut listener.inner, LISTENER_TOKEN, Interest::READABLE) - .map_err(|e| { - error!("{}", e); - std::io::Error::new( - std::io::ErrorKind::Other, - "failed to register listener with epoll", - ) - })?; - - self.listener = Some(listener); - - Ok(()) - } - - /// Get a copy of the `Waker` for this `Poll` instance - pub fn waker(&self) -> Arc { - self.waker.clone() - } - - pub fn poll(&mut self, events: &mut Events, timeout: Duration) -> Result<(), std::io::Error> { - self.poll.poll(events, Some(timeout)) - } - - pub fn accept(&mut self) -> Result { - if let Some(ref mut listener) = self.listener { - let (stream, _addr) = listener.inner.accept()?; - - // disable Nagle's algorithm - let _ = stream.set_nodelay(true); - - let stream = TcpStream::try_from(stream)?; - - let session = if let Some(ssl_context) = &listener.ssl_context { - match Ssl::new(ssl_context).map(|v| v.accept(stream)) { - // handle case where we have a fully-negotiated - // TLS stream on accept() - Ok(Ok(stream)) => { - Session::tls_with_capacity(stream, SESSION_BUFFER_MIN, SESSION_BUFFER_MAX) - } - // handle case where further negotiation is - // needed - Ok(Err(HandshakeError::WouldBlock(stream))) => { - Session::handshaking_with_capacity( - stream, - SESSION_BUFFER_MIN, - SESSION_BUFFER_MAX, - ) - } - // some other error has occurred and we drop the - // stream - Ok(Err(e)) => { - error!("accept failed: {}", e); - TCP_ACCEPT_EX.increment(); - return Err(std::io::Error::new( - std::io::ErrorKind::Other, - "accept failed", - )); - } - Err(e) => { - error!("accept failed: {}", e); - TCP_ACCEPT_EX.increment(); - return Err(std::io::Error::new( 
- std::io::ErrorKind::Other, - "accept failed", - )); - } - } - } else { - Session::plain_with_capacity(stream, SESSION_BUFFER_MIN, SESSION_BUFFER_MAX) - }; - - self.add_session(session) - } else { - Err(std::io::Error::new( - std::io::ErrorKind::Other, - "not listening", - )) - } - } - - // Session methods - - /// Add a new session - pub fn add_session(&mut self, session: Session) -> Result { - let s = self.sessions.vacant_entry(); - let token = Token(s.key()); - let mut session = TrackedSession { - session, - sender: None, - token: None, - }; - session.session.set_token(token); - session.session.register(&self.poll)?; - s.insert(session); - Ok(token) - } - - /// Close an existing session - pub fn close_session(&mut self, token: Token) -> Result<(), std::io::Error> { - let mut session = self.remove_session(token)?; - trace!("closing session: {:?}", session.session); - session.session.close(); - Ok(()) - } - - /// Remove a session from the poller and return it to the caller - pub fn remove_session(&mut self, token: Token) -> Result { - let mut session = self.take_session(token)?; - trace!("removing session: {:?}", session.session); - session.session.deregister(&self.poll)?; - Ok(session) - } - - pub fn get_mut_session(&mut self, token: Token) -> Result<&mut TrackedSession, std::io::Error> { - self.sessions - .get_mut(token.0) - .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::Other, "no such session")) - } - - fn take_session(&mut self, token: Token) -> Result { - if self.sessions.contains(token.0) { - let session = self.sessions.remove(token.0); - Ok(session) - } else { - Err(std::io::Error::new( - std::io::ErrorKind::Other, - "no such session", - )) - } - } - - pub fn reregister(&mut self, token: Token) { - match token { - LISTENER_TOKEN => { - if let Some(ref mut listener) = self.listener { - if listener - .inner - .reregister(self.poll.registry(), LISTENER_TOKEN, Interest::READABLE) - .is_err() - { - warn!("reregister of listener failed, attempting to 
recover"); - let _ = listener.inner.deregister(self.poll.registry()); - if listener - .inner - .register(self.poll.registry(), LISTENER_TOKEN, Interest::READABLE) - .is_err() - { - panic!("reregister of listener failed and was unrecoverable"); - } - } - } - } - WAKER_TOKEN => { - trace!("reregister of waker token is not supported"); - } - _ => { - if let Some(session) = self.sessions.get_mut(token.0) { - trace!("reregistering session: {:?}", session.session); - if session.session.reregister(&self.poll).is_err() { - error!("failed to reregister session"); - let _ = self.close_session(token); - } - } else { - trace!("attempted to reregister non-existent session: {}", token.0); - } - } - } - } -} diff --git a/src/core/proxy/src/process.rs b/src/core/proxy/src/process.rs index b14e8b1a6..328fd2db5 100644 --- a/src/core/proxy/src/process.rs +++ b/src/core/proxy/src/process.rs @@ -2,171 +2,173 @@ // Licensed under the Apache License, Version 2.0 // http://www.apache.org/licenses/LICENSE-2.0 -use crate::admin::Admin; -use crate::admin::AdminBuilder; -use crate::backend::BackendWorkerBuilder; -use crate::frontend::FrontendWorkerBuilder; -use crate::listener::ListenerBuilder; use crate::*; -use common::signal::Signal; -use config::proxy::{BackendConfig, FrontendConfig, ListenerConfig}; -use config::AdminConfig; -use config::ServerConfig; -use config::TlsConfig; -use crossbeam_channel::bounded; -use crossbeam_channel::Sender; -use logger::Drain; -use mio::Waker; -use protocol_common::*; -use queues::Queues; -use std::sync::Arc; +use config::proxy::BackendConfig; +use config::proxy::FrontendConfig; +use config::proxy::ListenerConfig; use std::thread::JoinHandle; -pub const FRONTEND_THREADS: usize = 1; -pub const BACKEND_THREADS: usize = 1; -pub const BACKEND_POOLSIZE: usize = 1; - -pub struct ProcessBuilder { - admin: Admin, - listener: Listener, - frontends: Vec>, - backends: Vec>, - signal_tx: Sender, +pub struct ProcessBuilder< + BackendParser, + BackendRequest, + 
BackendResponse, + FrontendParser, + FrontendRequest, + FrontendResponse, +> { + admin: AdminBuilder, + backend: BackendBuilder, + frontend: FrontendBuilder< + FrontendParser, + FrontendRequest, + FrontendResponse, + BackendRequest, + BackendResponse, + >, + listener: ListenerBuilder, + log_drain: Box, } -impl - ProcessBuilder +impl< + BackendParser, + BackendRequest, + BackendResponse, + FrontendParser, + FrontendRequest, + FrontendResponse, + > + ProcessBuilder< + BackendParser, + BackendRequest, + BackendResponse, + FrontendParser, + FrontendRequest, + FrontendResponse, + > where - RequestParser: 'static + Clone + Send + Parse, - Request: 'static + Send + Compose, - ResponseParser: 'static + Clone + Send + Parse, - Response: 'static + Send + Compose, + BackendParser: 'static + Parse + Clone + Send, + BackendRequest: 'static + Send + Compose + From + Compose, + BackendResponse: 'static + Compose + Send, + FrontendParser: 'static + Parse + Clone + Send, + FrontendRequest: 'static + Send, + FrontendResponse: 'static + Compose + Send, + FrontendResponse: From + Compose, { - pub fn new( - config: T, - request_parser: RequestParser, - response_parser: ResponseParser, + pub fn new( + config: &T, log_drain: Box, + backend_parser: BackendParser, + frontend_parser: FrontendParser, ) -> Result { - // initialize the clock - common::time::refresh_clock(); - - let admin_builder = AdminBuilder::new(&config, log_drain).unwrap_or_else(|e| { - error!("failed to initialize admin: {}", e); - std::process::exit(1); - }); - let admin_waker = admin_builder.waker(); + let admin = AdminBuilder::new(config)?; + let backend = BackendBuilder::new(config, backend_parser, 1)?; + let frontend = FrontendBuilder::new(config, frontend_parser, 1)?; + let listener = ListenerBuilder::new(config)?; - let listener_builder = ListenerBuilder::new(&config)?; - let listener_waker = listener_builder.waker(); - - let mut frontend_builders = Vec::new(); - for _ in 0..config.frontend().threads() { - 
frontend_builders.push(FrontendWorkerBuilder::new(&config, request_parser.clone())?); - } - let frontend_wakers: Vec> = - frontend_builders.iter().map(|v| v.waker()).collect(); + Ok(Self { + admin, + backend, + frontend, + listener, + log_drain, + }) + } - let mut backend_builders = Vec::new(); - for _ in 0..config.backend().threads() { - backend_builders.push(BackendWorkerBuilder::new(&config, response_parser.clone())?); - } - let backend_wakers: Vec> = backend_builders.iter().map(|v| v.waker()).collect(); + pub fn version(mut self, version: &str) -> Self { + self.admin.version(version); + self + } - let mut thread_wakers = vec![listener_waker.clone()]; - thread_wakers.extend_from_slice(&backend_wakers); - thread_wakers.extend_from_slice(&frontend_wakers); + pub fn spawn(self) -> Process { + let mut thread_wakers = vec![self.listener.waker()]; + thread_wakers.extend_from_slice(&self.backend.wakers()); + thread_wakers.extend_from_slice(&self.frontend.wakers()); // channel for the parent `Process` to send `Signal`s to the admin thread let (signal_tx, signal_rx) = bounded(QUEUE_CAPACITY); // queues for the `Admin` to send `Signal`s to all sibling threads let (mut signal_queue_tx, mut signal_queue_rx) = - Queues::new(vec![admin_waker], thread_wakers, QUEUE_CAPACITY); + Queues::new(vec![self.admin.waker()], thread_wakers, QUEUE_CAPACITY); - let (mut queues_listener_session, mut queues_worker_session) = Queues::new( - vec![listener_waker], - frontend_wakers.clone(), + // queues for the `Listener` to send `Session`s to the worker threads + let (mut listener_session_queues, worker_session_queues) = Queues::new( + vec![self.listener.waker()], + self.frontend.wakers(), QUEUE_CAPACITY, ); - let (mut queues_frontend_data, mut queues_backend_data) = - Queues::new(frontend_wakers, backend_wakers, QUEUE_CAPACITY); - let backends: Vec> = backend_builders - .drain(..) 
- .map(|v| v.build(signal_queue_rx.remove(0), queues_backend_data.remove(0))) - .collect(); + let (fe_data_queues, be_data_queues) = Queues::new( + self.frontend.wakers(), + self.backend.wakers(), + QUEUE_CAPACITY, + ); - let frontends: Vec> = frontend_builders - .drain(..) - .map(|v| { - v.build( - signal_queue_rx.remove(0), - queues_worker_session.remove(0), - queues_frontend_data.remove(0), - ) - }) - .collect(); - let listener = listener_builder.build(queues_listener_session.remove(0)); + let mut admin = self + .admin + .build(self.log_drain, signal_rx, signal_queue_tx.remove(0)); - let admin = admin_builder.build(signal_queue_tx.remove(0), signal_rx); + let mut listener = self + .listener + .build(signal_queue_rx.remove(0), listener_session_queues.remove(0)); - Ok(Self { - admin, - listener, - frontends, - backends, - signal_tx, - }) - } + let be_threads = be_data_queues.len(); + + let mut backend_workers = self.backend.build( + be_data_queues, + signal_queue_rx.drain(0..be_threads).collect(), + ); + let mut frontend_workers = + self.frontend + .build(fe_data_queues, worker_session_queues, signal_queue_rx); - #[allow(clippy::vec_init_then_push)] - pub fn spawn(mut self) -> Process { let admin = std::thread::Builder::new() - .name("pelikan_admin".to_string()) - .spawn(move || self.admin.run()) + .name(format!("{}_admin", THREAD_PREFIX)) + .spawn(move || admin.run()) .unwrap(); let listener = std::thread::Builder::new() - .name("pelikan_listener".to_string()) - .spawn(move || self.listener.run()) + .name(format!("{}_listener", THREAD_PREFIX)) + .spawn(move || listener.run()) .unwrap(); - let mut frontend = Vec::new(); - for (id, fe) in self.frontends.drain(..).enumerate() { - frontend.push( + let backend = backend_workers + .drain(..) 
+ .enumerate() + .map(|(i, mut b)| { std::thread::Builder::new() - .name(format!("pelikan_fe_{}", id)) - .spawn(move || fe.run()) - .unwrap(), - ) - } + .name(format!("{}_be_{}", THREAD_PREFIX, i)) + .spawn(move || b.run()) + .unwrap() + }) + .collect(); - let mut backend = Vec::new(); - for (id, be) in self.backends.drain(..).enumerate() { - backend.push( + let frontend = frontend_workers + .drain(..) + .enumerate() + .map(|(i, mut f)| { std::thread::Builder::new() - .name(format!("pelikan_be_{}", id)) - .spawn(move || be.run()) - .unwrap(), - ) - } + .name(format!("{}_fe_{}", THREAD_PREFIX, i)) + .spawn(move || f.run()) + .unwrap() + }) + .collect(); Process { admin, - listener, - frontend, backend, - signal_tx: self.signal_tx, + frontend, + listener, + signal_tx, } } } pub struct Process { admin: JoinHandle<()>, - listener: JoinHandle<()>, backend: Vec>, frontend: Vec>, + listener: JoinHandle<()>, signal_tx: Sender, } diff --git a/src/core/server/Cargo.toml b/src/core/server/Cargo.toml index f4a2c9ac0..98bac19c3 100644 --- a/src/core/server/Cargo.toml +++ b/src/core/server/Cargo.toml @@ -1,42 +1,25 @@ [package] name = "server" -version = "0.1.0" +version = "0.2.0" +edition = "2021" authors = ["Brian Martin "] -edition = "2018" description = "core server event loops and threads for Pelikan servers" homepage = "https://pelikan.io" repository = "https://github.com/twitter/pelikan" license = "Apache-2.0" -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - [dependencies] -ahash = "0.6.2" -backtrace = "0.3.56" -bytes = "1.0.1" +admin = { path = "../admin" } common = { path = "../../common" } config = { path = "../../config" } crossbeam-channel = "0.5.0" -libc = "0.2.83" +entrystore = { path = "../../entrystore" } logger = { path = "../../logger" } -mio = { version = "0.8.0", features = ["os-poll", "net"] } +net = { path = "../../net" } protocol-admin = { path = "../../protocol/admin" } protocol-common = { path = 
"../../protocol/common" } queues = { path = "../../queues" } -rand = "0.8.0" -rtrb = "0.1.3" -serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0.64" +rustcommon-metrics = { git = "https://github.com/twitter/rustcommon" } session = { path = "../../session" } -entrystore = { path = "../../entrystore" } slab = "0.4.2" -strum = "0.20.0" -strum_macros = "0.20.1" -sysconf = "0.3.4" -thiserror = "1.0.23" -tiny_http = "0.11.0" -rustcommon-metrics = { git = "https://github.com/twitter/rustcommon" } - -[dev-dependencies] -criterion = "0.3" - +waker = { path = "../waker" } diff --git a/src/core/server/src/lib.rs b/src/core/server/src/lib.rs index 81fa79290..a2135af83 100644 --- a/src/core/server/src/lib.rs +++ b/src/core/server/src/lib.rs @@ -91,33 +91,70 @@ #[macro_use] extern crate logger; -mod poll; -mod process; -mod threads; - -pub use process::{Process, ProcessBuilder}; -pub use threads::PERCENTILES; - +use ::net::event::{Event, Source}; +use ::net::*; +use admin::AdminBuilder; +use common::signal::Signal; +use common::ssl::tls_acceptor; +use config::*; +use core::marker::PhantomData; +use core::time::Duration; +use crossbeam_channel::{bounded, Sender}; +use entrystore::EntryStore; +use logger::{Drain, Klog}; +use protocol_common::{Compose, Execute, Parse}; +use queues::Queues; use rustcommon_metrics::*; +use session::{Buf, ServerSession, Session}; +use slab::Slab; +use std::io::{Error, ErrorKind, Result}; +use std::sync::Arc; +use waker::Waker; -counter!(TCP_ACCEPT_EX); +mod listener; +mod process; +mod workers; + +use listener::ListenerBuilder; +use workers::WorkersBuilder; -// The default buffer size is matched to the upper-bound on TLS fragment size as -// per RFC 5246 https://datatracker.ietf.org/doc/html/rfc5246#section-6.2.1 -pub const DEFAULT_BUFFER_SIZE: usize = 16 * 1024; // 16KB +pub use process::{Process, ProcessBuilder}; -// The admin thread (control plane) sessions use a fixed upper-bound on the -// session buffer size. 
The max buffer size for data plane sessions are to be -// specified during `Listener` initialization. This allows protocol and config -// specific upper bounds. -const ADMIN_MAX_BUFFER_SIZE: usize = 2 * 1024 * 1024; // 1MB +type Instant = rustcommon_metrics::Instant>; // TODO(bmartin): this *should* be plenty safe, the queue should rarely ever be // full, and a single wakeup should drain at least one message and make room for // the response. A stat to prove that this is sufficient would be good. const QUEUE_RETRIES: usize = 3; -const THREAD_PREFIX: &str = "pelikan"; const QUEUE_CAPACITY: usize = 64 * 1024; +// determines the max number of calls to accept when the listener is ready +const ACCEPT_BATCH: usize = 8; + +const LISTENER_TOKEN: Token = Token(usize::MAX - 1); +const WAKER_TOKEN: Token = Token(usize::MAX); + +const THREAD_PREFIX: &str = "pelikan"; + +pub static PERCENTILES: &[(&str, f64)] = &[ + ("p25", 25.0), + ("p50", 50.0), + ("p75", 75.0), + ("p90", 90.0), + ("p99", 99.0), + ("p999", 99.9), + ("p9999", 99.99), +]; + +// stats +counter!(PROCESS_REQ); + +fn map_err(e: std::io::Error) -> Result<()> { + match e.kind() { + ErrorKind::WouldBlock => Ok(()), + _ => Err(e), + } +} + common::metrics::test_no_duplicates!(); diff --git a/src/core/server/src/listener.rs b/src/core/server/src/listener.rs new file mode 100644 index 000000000..4aaff0098 --- /dev/null +++ b/src/core/server/src/listener.rs @@ -0,0 +1,315 @@ +// Copyright 2021 Twitter, Inc. 
+// Licensed under the Apache License, Version 2.0 +// http://www.apache.org/licenses/LICENSE-2.0 + +use crate::*; +use std::time::Duration; + +counter!(LISTENER_EVENT_ERROR, "the number of error events received"); +counter!( + LISTENER_EVENT_LOOP, + "the number of times the event loop has run" +); +counter!(LISTENER_EVENT_READ, "the number of read events received"); +counter!(LISTENER_EVENT_TOTAL, "the total number of events received"); +counter!(LISTENER_EVENT_WRITE, "the number of write events received"); + +counter!( + LISTENER_SESSION_DISCARD, + "the number of sessions discarded by the listener" +); + +pub struct Listener { + /// The actual network listener server + listener: ::net::Listener, + /// The maximum number of events to process per call to poll + nevent: usize, + /// The actual poll instantance + poll: Poll, + /// Sessions which have been opened, but are not fully established + sessions: Slab, + /// Queues for sending established sessions to the worker thread(s) and to + /// receive sessions which should be closed + session_queue: Queues, + /// Queue for receieving signals from the admin thread + signal_queue: Queues<(), Signal>, + /// The timeout for each call to poll + timeout: Duration, + /// The waker handle for this thread + waker: Arc, +} + +pub struct ListenerBuilder { + listener: ::net::Listener, + nevent: usize, + poll: Poll, + sessions: Slab, + timeout: Duration, + waker: Arc, +} + +impl ListenerBuilder { + pub fn new(config: &T) -> Result { + let tls_config = config.tls(); + let config = config.server(); + + let addr = config.socket_addr().map_err(|e| { + error!("{}", e); + std::io::Error::new(std::io::ErrorKind::Other, "Bad listen address") + })?; + + let tcp_listener = TcpListener::bind(addr)?; + + let mut listener = if let Some(tls_acceptor) = tls_acceptor(tls_config)? 
{ + ::net::Listener::from((tcp_listener, tls_acceptor)) + } else { + ::net::Listener::from(tcp_listener) + }; + + let poll = Poll::new()?; + listener.register(poll.registry(), LISTENER_TOKEN, Interest::READABLE)?; + + let waker = Arc::new(Waker::from( + ::net::Waker::new(poll.registry(), WAKER_TOKEN).unwrap(), + )); + + let nevent = config.nevent(); + let timeout = Duration::from_millis(config.timeout() as u64); + + let sessions = Slab::new(); + + Ok(Self { + listener, + nevent, + poll, + sessions, + timeout, + waker, + }) + } + + pub fn waker(&self) -> Arc { + self.waker.clone() + } + + pub fn build( + self, + signal_queue: Queues<(), Signal>, + session_queue: Queues, + ) -> Listener { + Listener { + listener: self.listener, + nevent: self.nevent, + poll: self.poll, + sessions: self.sessions, + session_queue, + signal_queue, + timeout: self.timeout, + waker: self.waker, + } + } +} + +impl Listener { + /// Accept new sessions + fn accept(&mut self) { + for _ in 0..ACCEPT_BATCH { + if let Ok(mut session) = self.listener.accept().map(Session::from) { + if session.is_handshaking() { + let s = self.sessions.vacant_entry(); + let interest = session.interest(); + if session + .register(self.poll.registry(), Token(s.key()), interest) + .is_ok() + { + s.insert(session); + } else { + // failed to register + } + } else { + for attempt in 1..=QUEUE_RETRIES { + if let Err(s) = self.session_queue.try_send_any(session) { + if attempt == QUEUE_RETRIES { + LISTENER_SESSION_DISCARD.increment(); + } else { + let _ = self.session_queue.wake(); + } + session = s; + } else { + break; + } + } + // if pushing to the session queues fails, the session will be + // closed on drop here + } + } else { + return; + } + } + + // reregister is needed here so we will call accept if there is a backlog + if self + .listener + .reregister(self.poll.registry(), LISTENER_TOKEN, Interest::READABLE) + .is_err() + { + // failed to reregister listener? how do we handle this? 
+ } + } + + /// Handle a read event for the `Session` with the `Token`. This primarily + /// just checks that there wasn't a hangup, as indicated by a zero-sized + /// return from `read()`. + fn read(&mut self, token: Token) -> Result<()> { + let session = self + .sessions + .get_mut(token.0) + .ok_or_else(|| Error::new(ErrorKind::Other, "non-existant session"))?; + + // read from session to buffer + match session.fill() { + Ok(0) => { + // zero-length reads indicate remote side has closed connection + trace!("hangup for session: {:?}", session); + Err(Error::new(ErrorKind::Other, "client hangup")) + } + Ok(bytes) => { + trace!("read {} bytes for session: {:?}", bytes, session); + Ok(()) + } + Err(e) => { + match e.kind() { + ErrorKind::WouldBlock => { + // spurious read, ignore + Ok(()) + } + _ => Err(e), + } + } + } + } + + /// Closes the session with the given token + fn close(&mut self, token: Token) { + if self.sessions.contains(token.0) { + let mut session = self.sessions.remove(token.0); + let _ = session.flush(); + } + } + + fn handshake(&mut self, token: Token) -> Result<()> { + let session = self + .sessions + .get_mut(token.0) + .ok_or_else(|| Error::new(ErrorKind::Other, "non-existant session"))?; + + session.do_handshake() + } + + /// handle a single session event + fn session_event(&mut self, event: &Event) { + let token = event.token(); + + if event.is_error() { + LISTENER_EVENT_ERROR.increment(); + self.close(token); + return; + } + + if event.is_readable() { + LISTENER_EVENT_READ.increment(); + if self.read(token).is_err() { + self.close(token); + return; + } + } + + match self.handshake(token) { + Ok(_) => { + // handshake is complete, send the session to a worker thread + let mut session = self.sessions.remove(token.0); + for attempt in 1..=QUEUE_RETRIES { + if let Err(s) = self.session_queue.try_send_any(session) { + if attempt == QUEUE_RETRIES { + LISTENER_SESSION_DISCARD.increment(); + } else { + let _ = self.session_queue.wake(); + } + 
session = s; + } else { + break; + } + } + // if pushing to the session queues fails, the session will be + // closed on drop here + } + Err(e) => match e.kind() { + ErrorKind::WouldBlock => {} + _ => { + self.close(token); + } + }, + } + } + + pub fn run(&mut self) { + info!( + "running server on: {}", + self.listener + .local_addr() + .map(|v| format!("{v}")) + .unwrap_or_else(|_| "unknown address".to_string()) + ); + + let mut events = Events::with_capacity(self.nevent); + + // repeatedly run accepting new connections and moving them to the worker + loop { + LISTENER_EVENT_LOOP.increment(); + if self.poll.poll(&mut events, Some(self.timeout)).is_err() { + error!("Error polling server"); + } + LISTENER_EVENT_TOTAL.add(events.iter().count() as _); + + // handle all events + for event in events.iter() { + match event.token() { + LISTENER_TOKEN => { + self.accept(); + } + WAKER_TOKEN => { + self.waker.reset(); + // handle any closing sessions + if let Some(mut session) = + self.session_queue.try_recv().map(|v| v.into_inner()) + { + let _ = session.flush(); + + // wakeup to handle the possibility of more sessions + let _ = self.waker.wake(); + } + + // check if we received any signals from the admin thread + while let Some(signal) = + self.signal_queue.try_recv().map(|v| v.into_inner()) + { + match signal { + Signal::FlushAll => {} + Signal::Shutdown => { + // if we received a shutdown, we can return + // and stop processing events + return; + } + } + } + } + _ => { + self.session_event(event); + } + } + } + + let _ = self.session_queue.wake(); + } + } +} diff --git a/src/core/server/src/poll/mod.rs b/src/core/server/src/poll/mod.rs deleted file mode 100644 index b2e1577f1..000000000 --- a/src/core/server/src/poll/mod.rs +++ /dev/null @@ -1,183 +0,0 @@ -// Copyright 2021 Twitter, Inc. -// Licensed under the Apache License, Version 2.0 -// http://www.apache.org/licenses/LICENSE-2.0 - -//! This module provides common functionality for threads which are based on an -//! 
event loop. - -use mio::event::Source; -use mio::net::TcpListener; -use mio::Events; -use mio::Interest; -use mio::Token; -use mio::Waker; -use session::Session; -use session::TcpStream; -use slab::Slab; -use std::convert::TryFrom; -use std::net::SocketAddr; -use std::sync::Arc; -use std::time::Duration; - -pub const LISTENER_TOKEN: Token = Token(usize::MAX - 1); -pub const WAKER_TOKEN: Token = Token(usize::MAX); - -pub struct Poll { - listener: Option, - poll: mio::Poll, - sessions: Slab, - waker: Arc, -} - -impl Poll { - /// Create a new `Poll` instance. - pub fn new() -> Result { - let poll = mio::Poll::new().map_err(|e| { - error!("{}", e); - std::io::Error::new(std::io::ErrorKind::Other, "failed to create poll instance") - })?; - - let waker = Arc::new(Waker::new(poll.registry(), WAKER_TOKEN).unwrap()); - - let sessions = Slab::::new(); - - Ok(Self { - listener: None, - poll, - sessions, - waker, - }) - } - - /// Bind and begin listening on the provided address. - pub fn bind(&mut self, addr: SocketAddr) -> Result<(), std::io::Error> { - let mut listener = TcpListener::bind(addr).map_err(|e| { - error!("{}", e); - std::io::Error::new(std::io::ErrorKind::Other, "failed to start tcp listener") - })?; - - // register listener to event loop - self.poll - .registry() - .register(&mut listener, LISTENER_TOKEN, Interest::READABLE) - .map_err(|e| { - error!("{}", e); - std::io::Error::new( - std::io::ErrorKind::Other, - "failed to register listener with epoll", - ) - })?; - - self.listener = Some(listener); - - Ok(()) - } - - /// Get a copy of the `Waker` for this `Poll` instance - pub fn waker(&self) -> Arc { - self.waker.clone() - } - - pub fn poll(&mut self, events: &mut Events, timeout: Duration) -> Result<(), std::io::Error> { - self.poll.poll(events, Some(timeout)) - } - - pub fn accept(&mut self) -> Result<(TcpStream, SocketAddr), std::io::Error> { - if let Some(ref mut listener) = self.listener { - let (stream, addr) = listener.accept()?; - - // disable 
Nagle's algorithm - let _ = stream.set_nodelay(true); - - let stream = TcpStream::try_from(stream)?; - Ok((stream, addr)) - } else { - Err(std::io::Error::new( - std::io::ErrorKind::Other, - "not listening", - )) - } - } - - // Session methods - - /// Add a new session - pub fn add_session(&mut self, mut session: Session) -> Result { - let s = self.sessions.vacant_entry(); - let token = Token(s.key()); - session.set_token(token); - session.register(&self.poll)?; - s.insert(session); - Ok(token) - } - - /// Close an existing session - pub fn close_session(&mut self, token: Token) -> Result<(), std::io::Error> { - let mut session = self.remove_session(token)?; - trace!("closing session: {:?}", session); - session.close(); - Ok(()) - } - - /// Remove a session from the poller and return it to the caller - pub fn remove_session(&mut self, token: Token) -> Result { - let mut session = self.take_session(token)?; - trace!("removing session: {:?}", session); - session.deregister(&self.poll)?; - Ok(session) - } - - pub fn get_mut_session(&mut self, token: Token) -> Result<&mut Session, std::io::Error> { - self.sessions - .get_mut(token.0) - .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::Other, "no such session")) - } - - fn take_session(&mut self, token: Token) -> Result { - if self.sessions.contains(token.0) { - let session = self.sessions.remove(token.0); - Ok(session) - } else { - Err(std::io::Error::new( - std::io::ErrorKind::Other, - "no such session", - )) - } - } - - pub fn reregister(&mut self, token: Token) { - match token { - LISTENER_TOKEN => { - if let Some(ref mut listener) = self.listener { - if listener - .reregister(self.poll.registry(), LISTENER_TOKEN, Interest::READABLE) - .is_err() - { - warn!("reregister of listener failed, attempting to recover"); - let _ = listener.deregister(self.poll.registry()); - if listener - .register(self.poll.registry(), LISTENER_TOKEN, Interest::READABLE) - .is_err() - { - fatal!("reregister of listener failed and was 
unrecoverable"); - } - } - } - } - WAKER_TOKEN => { - trace!("reregister of waker token is not supported"); - } - _ => { - if let Some(session) = self.sessions.get_mut(token.0) { - trace!("reregistering session: {:?}", session); - if session.reregister(&self.poll).is_err() { - error!("failed to reregister session"); - let _ = self.close_session(token); - } - } else { - trace!("attempted to reregister non-existent session: {}", token.0); - } - } - } - } -} diff --git a/src/core/server/src/process.rs b/src/core/server/src/process.rs new file mode 100644 index 000000000..4d98b7a2e --- /dev/null +++ b/src/core/server/src/process.rs @@ -0,0 +1,129 @@ +// Copyright 2021 Twitter, Inc. +// Licensed under the Apache License, Version 2.0 +// http://www.apache.org/licenses/LICENSE-2.0 + +use crate::*; +use std::thread::JoinHandle; + +pub struct ProcessBuilder { + admin: AdminBuilder, + listener: ListenerBuilder, + log_drain: Box, + workers: WorkersBuilder, +} + +impl ProcessBuilder +where + Parser: 'static + Parse + Clone + Send, + Request: 'static + Klog + Klog + Send, + Response: 'static + Compose + Send, + Storage: 'static + Execute + EntryStore + Send, +{ + pub fn new( + config: &T, + log_drain: Box, + parser: Parser, + storage: Storage, + ) -> Result { + let admin = AdminBuilder::new(config)?; + let listener = ListenerBuilder::new(config)?; + let workers = WorkersBuilder::new(config, parser, storage)?; + + Ok(Self { + admin, + listener, + log_drain, + workers, + }) + } + + pub fn version(mut self, version: &str) -> Self { + self.admin.version(version); + self + } + + pub fn spawn(self) -> Process { + let mut thread_wakers = vec![self.listener.waker()]; + thread_wakers.extend_from_slice(&self.workers.wakers()); + + // channel for the parent `Process` to send `Signal`s to the admin thread + let (signal_tx, signal_rx) = bounded(QUEUE_CAPACITY); + + // queues for the `Admin` to send `Signal`s to all sibling threads + let (mut signal_queue_tx, mut signal_queue_rx) = + 
Queues::new(vec![self.admin.waker()], thread_wakers, QUEUE_CAPACITY); + + // queues for the `Listener` to send `Session`s to the worker threads + let (mut listener_session_queues, worker_session_queues) = Queues::new( + vec![self.listener.waker()], + self.workers.worker_wakers(), + QUEUE_CAPACITY, + ); + + let mut admin = self + .admin + .build(self.log_drain, signal_rx, signal_queue_tx.remove(0)); + + let mut listener = self + .listener + .build(signal_queue_rx.remove(0), listener_session_queues.remove(0)); + + let workers = self.workers.build(worker_session_queues, signal_queue_rx); + + let admin = std::thread::Builder::new() + .name(format!("{}_admin", THREAD_PREFIX)) + .spawn(move || admin.run()) + .unwrap(); + + let listener = std::thread::Builder::new() + .name(format!("{}_listener", THREAD_PREFIX)) + .spawn(move || listener.run()) + .unwrap(); + + let workers = workers.spawn(); + + Process { + admin, + listener, + signal_tx, + workers, + } + } +} + +pub struct Process { + admin: JoinHandle<()>, + listener: JoinHandle<()>, + signal_tx: Sender, + workers: Vec>, +} + +impl Process { + /// Attempts to gracefully shutdown the `Process` by sending a shutdown to + /// each thread and then waiting to join those threads. + /// + /// Will terminate ungracefully if it encounters an error in sending a + /// shutdown to any of the threads. + /// + /// This function will block until all threads have terminated. + pub fn shutdown(self) { + // this sends a shutdown to the admin thread, which will broadcast the + // signal to all sibling threads in the process + if self.signal_tx.try_send(Signal::Shutdown).is_err() { + fatal!("error sending shutdown signal to thread"); + } + + // wait and join all threads + self.wait() + } + + /// Will block until all threads terminate. This should be used to keep the + /// process alive while the child threads run. 
+ pub fn wait(self) { + for thread in self.workers { + let _ = thread.join(); + } + let _ = self.listener.join(); + let _ = self.admin.join(); + } +} diff --git a/src/core/server/src/process/mod.rs b/src/core/server/src/process/mod.rs deleted file mode 100644 index ddf8272e1..000000000 --- a/src/core/server/src/process/mod.rs +++ /dev/null @@ -1,181 +0,0 @@ -// Copyright 2021 Twitter, Inc. -// Licensed under the Apache License, Version 2.0 -// http://www.apache.org/licenses/LICENSE-2.0 - -//! This module defines the server process as a collection of threads which can -//! be spawned and provides a `Process` type which is used as a control handle -//! to shutdown or wait on the threads. - -use crate::threads::*; -use crate::{QUEUE_CAPACITY, THREAD_PREFIX}; -use common::signal::Signal; -use config::*; -use crossbeam_channel::{bounded, Sender}; -use logger::Drain; -use queues::Queues; -use std::thread::JoinHandle; - -use entrystore::EntryStore; -use protocol_common::{Compose, Execute, Parse}; - -/// A builder for Pelikan server processes. -pub struct ProcessBuilder -where - Storage: Execute + EntryStore + Send, - Parser: Parse + Clone + Send, - Request: Send, - Response: Compose + std::marker::Send, -{ - admin: AdminBuilder, - listener: ListenerBuilder, - workers: WorkersBuilder, -} - -impl - ProcessBuilder -where - Storage: Execute + EntryStore + Send, - Parser: Parse + Clone + Send, - Request: Send, - Response: Compose + std::marker::Send, -{ - /// Creates a new `ProcessBuilder` - /// - /// This function will terminate the program execution if there are any - /// issues encountered while initializing the components. 
- pub fn new( - config: T, - storage: Storage, - max_buffer_size: usize, - parser: Parser, - mut log_drain: Box, - ) -> Self { - // initialize admin - let ssl_context = common::ssl::ssl_context(config.tls()).unwrap_or_else(|e| { - error!("failed to initialize TLS: {}", e); - let _ = log_drain.flush(); - std::process::exit(1); - }); - - let mut admin = AdminBuilder::new(&config, ssl_context, log_drain).unwrap_or_else(|e| { - error!("failed to initialize admin: {}", e); - std::process::exit(1); - }); - - let workers = WorkersBuilder::new(&config, storage, parser).unwrap_or_else(|e| { - error!("failed to initialize workers: {}", e); - let _ = admin.log_flush(); - std::process::exit(1); - }); - - // initialize server - let ssl_context = common::ssl::ssl_context(config.tls()).unwrap_or_else(|e| { - error!("failed to initialize TLS: {}", e); - let _ = admin.log_flush(); - std::process::exit(1); - }); - let listener = - ListenerBuilder::new(&config, ssl_context, max_buffer_size).unwrap_or_else(|e| { - error!("failed to initialize listener: {}", e); - let _ = admin.log_flush(); - std::process::exit(1); - }); - - Self { - admin, - listener, - workers, - } - } - - pub fn version(mut self, version: &str) -> Self { - self.admin.version(version); - self - } - - /// Convert the `ProcessBuilder` to a running `Process` by spawning the - /// threads for each component. Returns a `Process` which serves as a - /// control handle for the threads. 
- pub fn spawn(self) -> Process { - let mut thread_wakers = vec![self.listener.waker()]; - thread_wakers.extend_from_slice(&self.workers.wakers()); - - // channel for the parent `Process` to send `Signal`s to the admin thread - let (signal_tx, signal_rx) = bounded(QUEUE_CAPACITY); - - // queues for the `Admin` to send `Signal`s to all sibling threads - let (mut signal_queue_tx, mut signal_queue_rx) = - Queues::new(vec![self.admin.waker()], thread_wakers, QUEUE_CAPACITY); - - // queues for the `Listener` to send `Session`s to the worker threads - let (mut session_queue_tx, session_queue_rx) = Queues::new( - vec![self.listener.waker()], - self.workers.worker_wakers(), - QUEUE_CAPACITY, - ); - - let mut admin = self.admin.build(signal_queue_tx.remove(0), signal_rx); - let mut listener = self - .listener - .build(signal_queue_rx.remove(0), session_queue_tx.remove(0)); - let workers = self.workers.build(signal_queue_rx, session_queue_rx); - - let admin = std::thread::Builder::new() - .name(format!("{}_admin", THREAD_PREFIX)) - .spawn(move || admin.run()) - .unwrap(); - - let workers = workers.spawn(); - - let listener = std::thread::Builder::new() - .name(format!("{}_listener", THREAD_PREFIX)) - .spawn(move || listener.run()) - .unwrap(); - - Process { - admin, - listener, - workers, - signal_tx, - } - } -} - -/// This type provides a control handle for all the threads within the server -/// process. -pub struct Process { - admin: JoinHandle<()>, - listener: JoinHandle<()>, - workers: Vec>, - signal_tx: Sender, -} - -impl Process { - /// Attempts to gracefully shutdown the `Process` by sending a shutdown to - /// each thread and then waiting to join those threads. - /// - /// Will terminate ungracefully if it encounters an error in sending a - /// shutdown to any of the threads. - /// - /// This function will block until all threads have terminated. 
- pub fn shutdown(self) { - // this sends a shutdown to the admin thread, which will broadcast the - // signal to all sibling threads in the process - if self.signal_tx.try_send(Signal::Shutdown).is_err() { - fatal!("error sending shutdown signal to thread"); - } - - // wait and join all threads - self.wait() - } - - /// Will block until all threads terminate. This should be used to keep the - /// process alive while the child threads run. - pub fn wait(self) { - for thread in self.workers { - let _ = thread.join(); - } - let _ = self.listener.join(); - let _ = self.admin.join(); - } -} diff --git a/src/core/server/src/threads/admin.rs b/src/core/server/src/threads/admin.rs deleted file mode 100644 index 5d669b3a8..000000000 --- a/src/core/server/src/threads/admin.rs +++ /dev/null @@ -1,790 +0,0 @@ -// Copyright 2021 Twitter, Inc. -// Licensed under the Apache License, Version 2.0 -// http://www.apache.org/licenses/LICENSE-2.0 - -//! The admin thread, which handles admin requests to return stats, get version -//! info, etc. 
- -use crate::poll::{Poll, LISTENER_TOKEN, WAKER_TOKEN}; -use crate::threads::EventLoop; -use crate::QUEUE_RETRIES; -use crate::TCP_ACCEPT_EX; -use crate::*; -use common::signal::Signal; -use common::ssl::{HandshakeError, MidHandshakeSslStream, Ssl, SslContext, SslStream}; -use config::*; -use core::time::Duration; -use crossbeam_channel::Receiver; -use logger::Drain; -use mio::event::Event; -use mio::{Events, Token, Waker}; -use protocol_admin::*; -use queues::Queues; -use session::{Session, TcpStream}; -use std::io::{BufRead, Error, ErrorKind, Write}; -use std::net::SocketAddr; -use std::sync::Arc; -use tiny_http::{Method, Request, Response}; - -counter!(ADMIN_REQUEST_PARSE); -counter!(ADMIN_RESPONSE_COMPOSE); -counter!(ADMIN_EVENT_ERROR); -counter!(ADMIN_EVENT_WRITE); -counter!(ADMIN_EVENT_READ); -counter!(ADMIN_EVENT_LOOP); -counter!(ADMIN_EVENT_TOTAL); - -counter!(RU_UTIME); -counter!(RU_STIME); -gauge!(RU_MAXRSS); -gauge!(RU_IXRSS); -gauge!(RU_IDRSS); -gauge!(RU_ISRSS); -counter!(RU_MINFLT); -counter!(RU_MAJFLT); -counter!(RU_NSWAP); -counter!(RU_INBLOCK); -counter!(RU_OUBLOCK); -counter!(RU_MSGSND); -counter!(RU_MSGRCV); -counter!(RU_NSIGNALS); -counter!(RU_NVCSW); -counter!(RU_NIVCSW); - -const KB: u64 = 1024; // one kilobyte in bytes -const S: u64 = 1_000_000_000; // one second in nanoseconds -const US: u64 = 1_000; // one microsecond in nanoseconds - -pub static PERCENTILES: &[(&str, f64)] = &[ - ("p25", 25.0), - ("p50", 50.0), - ("p75", 75.0), - ("p90", 90.0), - ("p99", 99.0), - ("p999", 99.9), - ("p9999", 99.99), -]; - -pub struct AdminBuilder { - addr: SocketAddr, - nevent: usize, - poll: Poll, - timeout: Duration, - ssl_context: Option, - parser: AdminRequestParser, - log_drain: Box, - http_server: Option, - version: String, -} - -impl AdminBuilder { - /// Creates a new `Admin` event loop. 
- pub fn new( - config: &T, - ssl_context: Option, - mut log_drain: Box, - ) -> Result { - let config = config.admin(); - - let addr = config.socket_addr().map_err(|e| { - error!("{}", e); - error!("bad admin listen address"); - let _ = log_drain.flush(); - Error::new(ErrorKind::Other, "bad listen address") - })?; - let mut poll = Poll::new().map_err(|e| { - error!("{}", e); - error!("failed to create epoll instance"); - let _ = log_drain.flush(); - Error::new(ErrorKind::Other, "failed to create epoll instance") - })?; - poll.bind(addr).map_err(|e| { - error!("{}", e); - error!("failed to bind admin tcp listener"); - let _ = log_drain.flush(); - Error::new(ErrorKind::Other, "failed to bind listener") - })?; - - let ssl_context = if config.use_tls() { ssl_context } else { None }; - - let timeout = std::time::Duration::from_millis(config.timeout() as u64); - - let nevent = config.nevent(); - - let http_server = if config.http_enabled() { - let addr = config.http_socket_addr().map_err(|e| { - error!("{}", e); - error!("bad admin http listen address"); - let _ = log_drain.flush(); - Error::new(ErrorKind::Other, "bad listen address") - })?; - let server = tiny_http::Server::http(addr).map_err(|e| { - error!("{}", e); - error!("could not start admin http server"); - let _ = log_drain.flush(); - Error::new(ErrorKind::Other, "failed to create http server") - })?; - Some(server) - } else { - None - }; - - Ok(Self { - addr, - timeout, - nevent, - poll, - ssl_context, - parser: AdminRequestParser::new(), - log_drain, - http_server, - version: "unknown".to_string(), - }) - } - - pub fn waker(&self) -> Arc { - self.poll.waker() - } - - /// Triggers a flush of the log - pub fn log_flush(&mut self) -> Result<(), std::io::Error> { - self.log_drain.flush() - } - - /// Set the reported version number - pub fn version(&mut self, version: &str) { - self.version = version.to_string(); - } - - pub fn build( - self, - signal_queue_tx: Queues, - signal_queue_rx: Receiver, - ) -> Admin { - 
Admin { - addr: self.addr, - nevent: self.nevent, - poll: self.poll, - timeout: self.timeout, - ssl_context: self.ssl_context, - parser: self.parser, - log_drain: self.log_drain, - http_server: self.http_server, - signal_queue_tx, - signal_queue_rx, - version: self.version, - } - } -} - -pub struct Admin { - addr: SocketAddr, - nevent: usize, - poll: Poll, - timeout: Duration, - ssl_context: Option, - parser: AdminRequestParser, - log_drain: Box, - /// optional http server - http_server: Option, - /// used to send signals to all sibling threads - signal_queue_tx: Queues, - /// used to receive signals from the parent thread - signal_queue_rx: Receiver, - /// version number to report - version: String, -} - -impl Drop for Admin { - fn drop(&mut self) { - let _ = self.log_drain.flush(); - } -} - -impl Admin { - /// Adds a new fully established TLS session - fn add_established_tls_session(&mut self, stream: SslStream) { - let session = Session::tls_with_capacity( - stream, - crate::DEFAULT_BUFFER_SIZE, - crate::ADMIN_MAX_BUFFER_SIZE, - ); - if self.poll.add_session(session).is_err() { - TCP_ACCEPT_EX.increment(); - } - } - - /// Adds a new TLS session that requires further handshaking - fn add_handshaking_tls_session(&mut self, stream: MidHandshakeSslStream) { - let session = Session::handshaking_with_capacity( - stream, - crate::DEFAULT_BUFFER_SIZE, - crate::ADMIN_MAX_BUFFER_SIZE, - ); - trace!("accepted new session: {:?}", session.peer_addr()); - if self.poll.add_session(session).is_err() { - TCP_ACCEPT_EX.increment(); - } - } - - /// Adds a new plain (non-TLS) session - fn add_plain_session(&mut self, stream: TcpStream) { - let session = Session::plain_with_capacity( - stream, - crate::DEFAULT_BUFFER_SIZE, - crate::ADMIN_MAX_BUFFER_SIZE, - ); - trace!("accepted new session: {:?}", session.peer_addr()); - if self.poll.add_session(session).is_err() { - TCP_ACCEPT_EX.increment(); - } - } - - /// Repeatedly call accept on the listener - fn do_accept(&mut self) { - loop 
{ - match self.poll.accept() { - Ok((stream, _)) => { - // handle TLS if it is configured - if let Some(ssl_context) = &self.ssl_context { - match Ssl::new(ssl_context).map(|v| v.accept(stream)) { - // handle case where we have a fully-negotiated - // TLS stream on accept() - Ok(Ok(tls_stream)) => { - self.add_established_tls_session(tls_stream); - } - // handle case where further negotiation is - // needed - Ok(Err(HandshakeError::WouldBlock(tls_stream))) => { - self.add_handshaking_tls_session(tls_stream); - } - // some other error has occurred and we drop the - // stream - Ok(Err(_)) | Err(_) => { - TCP_ACCEPT_EX.increment(); - } - } - } else { - self.add_plain_session(stream); - }; - } - Err(e) => { - if e.kind() == ErrorKind::WouldBlock { - break; - } - } - } - } - } - - /// This is a handler for the stats commands on the legacy admin port. It - /// responses using the Memcached `stats` command response format, each stat - /// appears on its own line with a CR+LF used as end of line symbol. The - /// stats appear in sorted order. 
- /// - /// ```text - /// STAT get 0 - /// STAT get_cardinality_p25 0 - /// STAT get_cardinality_p50 0 - /// STAT get_cardinality_p75 0 - /// STAT get_cardinality_p90 0 - /// STAT get_cardinality_p99 0 - /// STAT get_cardinality_p999 0 - /// STAT get_cardinality_p9999 0 - /// STAT get_ex 0 - /// STAT get_key 0 - /// STAT get_key_hit 0 - /// STAT get_key_miss 0 - /// ``` - fn handle_stats_request(session: &mut Session) { - ADMIN_REQUEST_PARSE.increment(); - let mut data = Vec::new(); - for metric in &rustcommon_metrics::metrics() { - let any = match metric.as_any() { - Some(any) => any, - None => { - continue; - } - }; - - if let Some(counter) = any.downcast_ref::() { - data.push(format!("STAT {} {}\r\n", metric.name(), counter.value())); - } else if let Some(gauge) = any.downcast_ref::() { - data.push(format!("STAT {} {}\r\n", metric.name(), gauge.value())); - } else if let Some(heatmap) = any.downcast_ref::() { - for (label, value) in PERCENTILES { - let percentile = heatmap.percentile(*value).unwrap_or(0); - data.push(format!( - "STAT {}_{} {}\r\n", - metric.name(), - label, - percentile - )); - } - } - } - - data.sort(); - for line in data { - let _ = session.write(line.as_bytes()); - } - let _ = session.write(b"END\r\n"); - session.finalize_response(); - ADMIN_RESPONSE_COMPOSE.increment(); - } - - fn handle_version_request(session: &mut Session, version: &str) { - let _ = session.write(format!("VERSION {}\r\n", version).as_bytes()); - session.finalize_response(); - ADMIN_RESPONSE_COMPOSE.increment(); - } - - /// Handle an event on an existing session - fn handle_session_event(&mut self, event: &Event) { - let token = event.token(); - trace!("got event for admin session: {}", token.0); - - // handle error events first - if event.is_error() { - ADMIN_EVENT_ERROR.increment(); - self.handle_error(token); - } - - // handle handshaking - if let Ok(session) = self.poll.get_mut_session(token) { - if session.is_handshaking() { - if let Err(e) = session.do_handshake() { 
- if e.kind() == ErrorKind::WouldBlock { - // the session is still handshaking - return; - } else { - // some error occured while handshaking - let _ = self.poll.close_session(token); - } - } - } - } - - // handle write events before read events to reduce write - // buffer growth if there is also a readable event - if event.is_writable() { - ADMIN_EVENT_WRITE.increment(); - self.do_write(token); - } - - // read events are handled last - if event.is_readable() { - ADMIN_EVENT_READ.increment(); - let _ = self.do_read(token); - }; - } - - /// A "human-readable" exposition format which outputs one stat per line, - /// with a LF used as the end of line symbol. - /// - /// ```text - /// get: 0 - /// get_cardinality_p25: 0 - /// get_cardinality_p50: 0 - /// get_cardinality_p75: 0 - /// get_cardinality_p90: 0 - /// get_cardinality_p9999: 0 - /// get_cardinality_p999: 0 - /// get_cardinality_p99: 0 - /// get_ex: 0 - /// get_key: 0 - /// get_key_hit: 0 - /// get_key_miss: 0 - /// ``` - fn human_stats(&self) -> String { - let mut data = Vec::new(); - - for metric in &rustcommon_metrics::metrics() { - let any = match metric.as_any() { - Some(any) => any, - None => { - continue; - } - }; - - if let Some(counter) = any.downcast_ref::() { - data.push(format!("{}: {}", metric.name(), counter.value())); - } else if let Some(gauge) = any.downcast_ref::() { - data.push(format!("{}: {}", metric.name(), gauge.value())); - } else if let Some(heatmap) = any.downcast_ref::() { - for (label, value) in PERCENTILES { - let percentile = heatmap.percentile(*value).unwrap_or(0); - data.push(format!("{}_{}: {}", metric.name(), label, percentile)); - } - } - } - - data.sort(); - data.join("\n") + "\n" - } - - /// JSON stats output which follows the conventions found in Finagle and - /// TwitterServer libraries. Percentiles are appended to the metric name, - /// eg: `request_latency_p999` for the 99.9th percentile. 
For more details - /// about the Finagle / TwitterServer format see: - /// https://twitter.github.io/twitter-server/Features.html#metrics - /// - /// ```text - /// {"get": 0,"get_cardinality_p25": 0,"get_cardinality_p50": 0, ... } - /// ``` - fn json_stats(&self) -> String { - let head = "{".to_owned(); - - let mut data = Vec::new(); - - for metric in &rustcommon_metrics::metrics() { - let any = match metric.as_any() { - Some(any) => any, - None => { - continue; - } - }; - - if let Some(counter) = any.downcast_ref::() { - data.push(format!("\"{}\": {}", metric.name(), counter.value())); - } else if let Some(gauge) = any.downcast_ref::() { - data.push(format!("\"{}\": {}", metric.name(), gauge.value())); - } else if let Some(heatmap) = any.downcast_ref::() { - for (label, value) in PERCENTILES { - let percentile = heatmap.percentile(*value).unwrap_or(0); - data.push(format!("\"{}_{}\": {}", metric.name(), label, percentile)); - } - } - } - - data.sort(); - let body = data.join(","); - let mut content = head; - content += &body; - content += "}"; - content - } - - /// Prometheus / OpenTelemetry compatible stats output. Each stat is - /// annotated with a type. 
Percentiles use the label 'percentile' to - /// indicate which percentile corresponds to the value: - /// - /// ```text - /// # TYPE get counter - /// get 0 - /// # TYPE get_cardinality gauge - /// get_cardinality{percentile="p25"} 0 - /// # TYPE get_cardinality gauge - /// get_cardinality{percentile="p50"} 0 - /// # TYPE get_cardinality gauge - /// get_cardinality{percentile="p75"} 0 - /// # TYPE get_cardinality gauge - /// get_cardinality{percentile="p90"} 0 - /// # TYPE get_cardinality gauge - /// get_cardinality{percentile="p99"} 0 - /// # TYPE get_cardinality gauge - /// get_cardinality{percentile="p999"} 0 - /// # TYPE get_cardinality gauge - /// get_cardinality{percentile="p9999"} 0 - /// # TYPE get_ex counter - /// get_ex 0 - /// # TYPE get_key counter - /// get_key 0 - /// # TYPE get_key_hit counter - /// get_key_hit 0 - /// # TYPE get_key_miss counter - /// get_key_miss 0 - /// ``` - fn prometheus_stats(&self) -> String { - let mut data = Vec::new(); - - for metric in &rustcommon_metrics::metrics() { - let any = match metric.as_any() { - Some(any) => any, - None => { - continue; - } - }; - - if let Some(counter) = any.downcast_ref::() { - data.push(format!( - "# TYPE {} counter\n{} {}", - metric.name(), - metric.name(), - counter.value() - )); - } else if let Some(gauge) = any.downcast_ref::() { - data.push(format!( - "# TYPE {} gauge\n{} {}", - metric.name(), - metric.name(), - gauge.value() - )); - } else if let Some(heatmap) = any.downcast_ref::() { - for (label, value) in PERCENTILES { - let percentile = heatmap.percentile(*value).unwrap_or(0); - data.push(format!( - "# TYPE {} gauge\n{}{{percentile=\"{}\"}} {}", - metric.name(), - metric.name(), - label, - percentile - )); - } - } - } - data.sort(); - let mut content = data.join("\n"); - content += "\n"; - let parts: Vec<&str> = content.split('/').collect(); - parts.join("_") - } - - /// Handle a HTTP request - fn handle_http_request(&self, request: Request) { - let url = request.url(); - let parts: 
Vec<&str> = url.split('?').collect(); - let url = parts[0]; - match url { - // Prometheus/OpenTelemetry expect the `/metrics` URI will return - // stats in the Prometheus format - "/metrics" => match request.method() { - Method::Get => { - let _ = request.respond(Response::from_string(self.prometheus_stats())); - } - _ => { - let _ = request.respond(Response::empty(400)); - } - }, - // we export Finagle/TwitterServer format stats on a few endpoints - // for maximum compatibility with various internal conventions - "/metrics.json" | "/vars.json" | "/admin/metrics.json" => match request.method() { - Method::Get => { - let _ = request.respond(Response::from_string(self.json_stats())); - } - _ => { - let _ = request.respond(Response::empty(400)); - } - }, - // human-readable stats are exported on the `/vars` endpoint based - // on internal conventions - "/vars" => match request.method() { - Method::Get => { - let _ = request.respond(Response::from_string(self.human_stats())); - } - _ => { - let _ = request.respond(Response::empty(400)); - } - }, - _ => { - let _ = request.respond(Response::empty(404)); - } - } - } - - /// Runs the `Admin` in a loop, accepting new sessions for the admin - /// listener and handling events on existing sessions. 
- pub fn run(&mut self) { - info!("running admin on: {}", self.addr); - - let mut events = Events::with_capacity(self.nevent); - - // run in a loop, accepting new sessions and events on existing sessions - loop { - ADMIN_EVENT_LOOP.increment(); - - if self.poll.poll(&mut events, self.timeout).is_err() { - error!("Error polling"); - } - - ADMIN_EVENT_TOTAL.add(events.iter().count() as _); - - // handle all events - for event in events.iter() { - match event.token() { - LISTENER_TOKEN => { - self.do_accept(); - } - WAKER_TOKEN => { - // check if we have received signals from any sibling - // thread - while let Ok(signal) = self.signal_queue_rx.try_recv() { - match signal { - Signal::FlushAll => {} - Signal::Shutdown => { - // if a shutdown is received from any - // thread, we will broadcast it to all - // sibling threads and stop our event loop - info!("shutting down"); - let _ = self.signal_queue_tx.try_send_all(Signal::Shutdown); - if self.signal_queue_tx.wake().is_err() { - fatal!("error waking threads for shutdown"); - } - let _ = self.log_drain.flush(); - return; - } - } - } - } - _ => { - self.handle_session_event(event); - } - } - } - - // handle all http requests if the http server is enabled - if let Some(ref server) = self.http_server { - while let Ok(Some(request)) = server.try_recv() { - self.handle_http_request(request); - } - } - - // handle all signals - while let Ok(signal) = self.signal_queue_rx.try_recv() { - match signal { - Signal::FlushAll => {} - Signal::Shutdown => { - // if a shutdown is received from any - // thread, we will broadcast it to all - // sibling threads and stop our event loop - info!("shutting down"); - let _ = self.signal_queue_tx.try_send_all(Signal::Shutdown); - if self.signal_queue_tx.wake().is_err() { - fatal!("error waking threads for shutdown"); - } - let _ = self.log_drain.flush(); - return; - } - } - } - - // get updated usage - self.get_rusage(); - - // flush pending log entries to log destinations - let _ = 
self.log_drain.flush(); - } - } - - // TODO(bmartin): move this into a common module, should be shared with - // other backends - pub fn get_rusage(&self) { - let mut rusage = libc::rusage { - ru_utime: libc::timeval { - tv_sec: 0, - tv_usec: 0, - }, - ru_stime: libc::timeval { - tv_sec: 0, - tv_usec: 0, - }, - ru_maxrss: 0, - ru_ixrss: 0, - ru_idrss: 0, - ru_isrss: 0, - ru_minflt: 0, - ru_majflt: 0, - ru_nswap: 0, - ru_inblock: 0, - ru_oublock: 0, - ru_msgsnd: 0, - ru_msgrcv: 0, - ru_nsignals: 0, - ru_nvcsw: 0, - ru_nivcsw: 0, - }; - - if unsafe { libc::getrusage(libc::RUSAGE_SELF, &mut rusage) } == 0 { - RU_UTIME.set(rusage.ru_utime.tv_sec as u64 * S + rusage.ru_utime.tv_usec as u64 * US); - RU_STIME.set(rusage.ru_stime.tv_sec as u64 * S + rusage.ru_stime.tv_usec as u64 * US); - RU_MAXRSS.set(rusage.ru_maxrss * KB as i64); - RU_IXRSS.set(rusage.ru_ixrss * KB as i64); - RU_IDRSS.set(rusage.ru_idrss * KB as i64); - RU_ISRSS.set(rusage.ru_isrss * KB as i64); - RU_MINFLT.set(rusage.ru_minflt as u64); - RU_MAJFLT.set(rusage.ru_majflt as u64); - RU_NSWAP.set(rusage.ru_nswap as u64); - RU_INBLOCK.set(rusage.ru_inblock as u64); - RU_OUBLOCK.set(rusage.ru_oublock as u64); - RU_MSGSND.set(rusage.ru_msgsnd as u64); - RU_MSGRCV.set(rusage.ru_msgrcv as u64); - RU_NSIGNALS.set(rusage.ru_nsignals as u64); - RU_NVCSW.set(rusage.ru_nvcsw as u64); - RU_NIVCSW.set(rusage.ru_nivcsw as u64); - } - } -} - -impl EventLoop for Admin { - fn handle_data(&mut self, token: Token) -> Result<(), std::io::Error> { - trace!("handling request for admin session: {}", token.0); - if let Ok(session) = self.poll.get_mut_session(token) { - loop { - if session.write_capacity() == 0 { - // if the write buffer is over-full, skip processing - break; - } - match self.parser.parse(session.buffer()) { - Ok(parsed_request) => { - let consumed = parsed_request.consumed(); - let request = parsed_request.into_inner(); - session.consume(consumed); - - match request { - AdminRequest::FlushAll => { - for _ in 
0..QUEUE_RETRIES { - if self.signal_queue_tx.try_send_all(Signal::FlushAll).is_ok() { - warn!("sending flush_all signal"); - break; - } - } - for _ in 0..QUEUE_RETRIES { - if self.signal_queue_tx.wake().is_ok() { - break; - } - } - - let _ = session.write(b"OK\r\n"); - session.finalize_response(); - ADMIN_RESPONSE_COMPOSE.increment(); - } - AdminRequest::Stats => { - Self::handle_stats_request(session); - } - AdminRequest::Quit => { - let _ = self.poll.close_session(token); - return Ok(()); - } - AdminRequest::Version => { - Self::handle_version_request(session, &self.version); - } - } - } - Err(ParseError::Incomplete) => { - break; - } - Err(_) => { - self.handle_error(token); - return Err(std::io::Error::new( - std::io::ErrorKind::Other, - "bad request", - )); - } - } - } - } else { - // no session for the token - trace!( - "attempted to handle data for non-existent session: {}", - token.0 - ); - return Ok(()); - } - self.poll.reregister(token); - Ok(()) - } - - fn poll(&mut self) -> &mut Poll { - &mut self.poll - } -} diff --git a/src/core/server/src/threads/listener.rs b/src/core/server/src/threads/listener.rs deleted file mode 100644 index d5957b8d6..000000000 --- a/src/core/server/src/threads/listener.rs +++ /dev/null @@ -1,271 +0,0 @@ -// Copyright 2021 Twitter, Inc. -// Licensed under the Apache License, Version 2.0 -// http://www.apache.org/licenses/LICENSE-2.0 - -//! The server thread which accepts new connections, handles TLS handshaking, -//! and sends established sessions to the worker thread(s). 
- -use super::EventLoop; -use crate::poll::{Poll, LISTENER_TOKEN, WAKER_TOKEN}; -use crate::*; -use common::signal::Signal; -use common::ssl::{HandshakeError, MidHandshakeSslStream, Ssl, SslContext, SslStream}; -use config::ServerConfig; -use mio::event::Event; -use mio::Events; -use mio::Token; -use queues::*; -use session::{Session, TcpStream}; -use std::net::SocketAddr; -use std::sync::Arc; -use std::time::Duration; - -counter!(SERVER_EVENT_ERROR); -counter!(SERVER_EVENT_WRITE); -counter!(SERVER_EVENT_READ); -counter!(SERVER_EVENT_LOOP); -counter!(SERVER_EVENT_TOTAL); - -pub struct ListenerBuilder { - addr: SocketAddr, - max_buffer_size: usize, - nevent: usize, - poll: Poll, - ssl_context: Option, - timeout: Duration, -} - -impl ListenerBuilder { - /// Creates a new `Listener` from a `ServerConfig` and an optional - /// `SslContext`. - pub fn new( - config: &T, - ssl_context: Option, - max_buffer_size: usize, - ) -> Result { - let config = config.server(); - - let addr = config.socket_addr().map_err(|e| { - error!("{}", e); - std::io::Error::new(std::io::ErrorKind::Other, "Bad listen address") - })?; - let mut poll = Poll::new().map_err(|e| { - error!("{}", e); - std::io::Error::new(std::io::ErrorKind::Other, "Failed to create epoll instance") - })?; - - poll.bind(addr)?; - - let nevent = config.nevent(); - let timeout = Duration::from_millis(config.timeout() as u64); - - Ok(Self { - addr, - nevent, - poll, - ssl_context, - timeout, - max_buffer_size, - }) - } - - pub fn waker(&self) -> Arc { - self.poll.waker() - } - - pub fn build( - self, - signal_queue: Queues<(), Signal>, - session_queue: Queues, - ) -> Listener { - Listener { - addr: self.addr, - max_buffer_size: self.max_buffer_size, - nevent: self.nevent, - poll: self.poll, - ssl_context: self.ssl_context, - timeout: self.timeout, - signal_queue, - session_queue, - } - } -} - -pub struct Listener { - addr: SocketAddr, - max_buffer_size: usize, - nevent: usize, - poll: Poll, - ssl_context: Option, - 
timeout: Duration, - signal_queue: Queues<(), Signal>, - session_queue: Queues, -} - -impl Listener { - /// Call accept one time - // TODO(bmartin): splitting accept and negotiation into separate threads - // would allow us to handle TLS handshake with multiple threads and avoid - // the overhead of re-registering the listener after each accept. - fn do_accept(&mut self) { - if let Ok((stream, _)) = self.poll.accept() { - // handle TLS if it is configured - if let Some(ssl_context) = &self.ssl_context { - match Ssl::new(ssl_context).map(|v| v.accept(stream)) { - // handle case where we have a fully-negotiated - // TLS stream on accept() - Ok(Ok(tls_stream)) => { - self.add_established_tls_session(tls_stream); - } - // handle case where further negotiation is - // needed - Ok(Err(HandshakeError::WouldBlock(tls_stream))) => { - self.add_handshaking_tls_session(tls_stream); - } - // some other error has occurred and we drop the - // stream - Ok(Err(e)) => { - error!("accept failed: {}", e); - TCP_ACCEPT_EX.increment(); - } - Err(e) => { - error!("accept failed: {}", e); - TCP_ACCEPT_EX.increment(); - } - } - } else { - self.add_plain_session(stream); - }; - self.poll.reregister(LISTENER_TOKEN); - } - } - - /// Adds a new fully established TLS session - fn add_established_tls_session(&mut self, stream: SslStream) { - let session = - Session::tls_with_capacity(stream, crate::DEFAULT_BUFFER_SIZE, self.max_buffer_size); - trace!("accepted new session: {:?}", session); - if self.session_queue.try_send_any(session).is_err() { - error!("error sending session to worker"); - TCP_ACCEPT_EX.increment(); - } - } - - /// Adds a new TLS session that requires further handshaking - fn add_handshaking_tls_session(&mut self, stream: MidHandshakeSslStream) { - let session = Session::handshaking_with_capacity( - stream, - crate::DEFAULT_BUFFER_SIZE, - self.max_buffer_size, - ); - if self.poll.add_session(session).is_err() { - error!("failed to register handshaking TLS session with 
epoll"); - TCP_ACCEPT_EX.increment(); - } - } - - /// Adds a new plain (non-TLS) session - fn add_plain_session(&mut self, stream: TcpStream) { - let session = - Session::plain_with_capacity(stream, crate::DEFAULT_BUFFER_SIZE, self.max_buffer_size); - trace!("accepted new session: {:?}", session); - if self.session_queue.try_send_any(session).is_err() { - error!("error sending session to worker"); - TCP_ACCEPT_EX.increment(); - } - } - - /// Handle an event on an existing session - fn handle_session_event(&mut self, event: &Event) { - let token = event.token(); - - // handle error events first - if event.is_error() { - SERVER_EVENT_ERROR.increment(); - self.handle_error(token); - } - - // handle write events before read events to reduce write - // buffer growth if there is also a readable event - if event.is_writable() { - SERVER_EVENT_WRITE.increment(); - self.do_write(token); - } - - // read events are handled last - if event.is_readable() { - SERVER_EVENT_READ.increment(); - let _ = self.do_read(token); - } - - if let Ok(session) = self.poll.get_mut_session(token) { - if session.do_handshake().is_ok() { - trace!("handshake complete for session: {:?}", session); - if let Ok(session) = self.poll.remove_session(token) { - if self.session_queue.try_send_any(session).is_err() { - error!("error sending session to worker"); - TCP_ACCEPT_EX.increment(); - } - } else { - error!("error removing session from poller"); - TCP_ACCEPT_EX.increment(); - } - } else { - trace!("handshake incomplete for session: {:?}", session); - } - } - } - - /// Runs the `Listener` in a loop, accepting new sessions and moving them to - /// a worker queue. 
- pub fn run(&mut self) { - info!("running server on: {}", self.addr); - - let mut events = Events::with_capacity(self.nevent); - - // repeatedly run accepting new connections and moving them to the worker - loop { - SERVER_EVENT_LOOP.increment(); - if self.poll.poll(&mut events, self.timeout).is_err() { - error!("Error polling server"); - } - SERVER_EVENT_TOTAL.add(events.iter().count() as _); - - // handle all events - for event in events.iter() { - match event.token() { - LISTENER_TOKEN => { - self.do_accept(); - } - WAKER_TOKEN => { - while let Some(signal) = - self.signal_queue.try_recv().map(|v| v.into_inner()) - { - match signal { - Signal::FlushAll => {} - Signal::Shutdown => { - return; - } - } - } - } - _ => { - self.handle_session_event(event); - } - } - } - - let _ = self.session_queue.wake(); - } - } -} - -impl EventLoop for Listener { - fn handle_data(&mut self, _token: Token) -> Result<(), std::io::Error> { - Ok(()) - } - - fn poll(&mut self) -> &mut Poll { - &mut self.poll - } -} diff --git a/src/core/server/src/threads/mod.rs b/src/core/server/src/threads/mod.rs deleted file mode 100644 index 1667c5829..000000000 --- a/src/core/server/src/threads/mod.rs +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright 2021 Twitter, Inc. -// Licensed under the Apache License, Version 2.0 -// http://www.apache.org/licenses/LICENSE-2.0 - -//! This module contains all the threads that make-up a server as well as their -//! builders. - -mod admin; -mod listener; -mod traits; -mod workers; - -pub use admin::{Admin, AdminBuilder, PERCENTILES}; -pub use listener::{Listener, ListenerBuilder}; -pub use traits::EventLoop; -pub use workers::{Workers, WorkersBuilder}; diff --git a/src/core/server/src/threads/traits/event_loop.rs b/src/core/server/src/threads/traits/event_loop.rs deleted file mode 100644 index 7f58956a5..000000000 --- a/src/core/server/src/threads/traits/event_loop.rs +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright 2021 Twitter, Inc. 
-// Licensed under the Apache License, Version 2.0 -// http://www.apache.org/licenses/LICENSE-2.0 - -//! A trait defining common functions for event-based threads which operate on -//! sessions. - -use std::io::{BufRead, ErrorKind, Write}; - -use mio::Token; - -use crate::poll::Poll; - -/// An `EventLoop` describes the functions which must be implemented for a basic -/// event loop and provides some default implementations and helper functions. -pub trait EventLoop { - // the following functions must be implemented - - /// Provides access to the `Poll` structure which allows polling for new - /// readiness events and managing registration for event sources. - fn poll(&mut self) -> &mut Poll; - - /// Handle new data received for the `Session` with the provided `Token`. - /// This will include parsing the incoming data and composing a response. - fn handle_data(&mut self, token: Token) -> Result<(), std::io::Error>; - - /// Handle a read event for the `Session` with the `Token`. - fn do_read(&mut self, token: Token) -> Result<(), ()> { - if let Ok(session) = self.poll().get_mut_session(token) { - // read from session to buffer - match session.fill_buf().map(|b| b.len()) { - Ok(0) => { - trace!("hangup for session: {:?}", session); - let _ = self.poll().close_session(token); - Err(()) - } - Ok(bytes) => { - trace!("read {} bytes for session: {:?}", bytes, session); - if self.handle_data(token).is_err() { - self.handle_error(token); - Err(()) - } else { - Ok(()) - } - } - Err(e) => { - match e.kind() { - ErrorKind::WouldBlock => { - // spurious read - self.poll().reregister(token); - Ok(()) - } - ErrorKind::Interrupted => self.do_read(token), - _ => { - trace!("error reading for session: {:?} {:?}", session, e); - // some read error - self.handle_error(token); - Err(()) - } - } - } - } - } else { - trace!("attempted to read from non-existent session: {}", token.0); - Err(()) - } - } - - /// Handle a write event for a `Session` with the `Token`. 
- fn do_write(&mut self, token: Token) { - if let Ok(session) = self.poll().get_mut_session(token) { - trace!("write for session: {:?}", session); - match session.flush() { - Ok(_) => { - self.poll().reregister(token); - } - Err(e) => match e.kind() { - ErrorKind::WouldBlock => {} - ErrorKind::Interrupted => self.do_write(token), - _ => { - self.handle_error(token); - } - }, - } - } else { - trace!("attempted to write to non-existent session: {}", token.0) - } - } - - /// Handle errors for the `Session` with the `Token` by logging a message - /// and closing the session. - fn handle_error(&mut self, token: Token) { - if let Ok(session) = self.poll().get_mut_session(token) { - trace!("handling error for session: {:?}", session); - let _ = session.flush(); - let _ = self.poll().close_session(token); - } else { - trace!( - "attempted to handle error for non-existent session: {}", - token.0 - ) - } - } -} diff --git a/src/core/server/src/threads/traits/mod.rs b/src/core/server/src/threads/traits/mod.rs deleted file mode 100644 index 737709626..000000000 --- a/src/core/server/src/threads/traits/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -// Copyright 2021 Twitter, Inc. -// Licensed under the Apache License, Version 2.0 -// http://www.apache.org/licenses/LICENSE-2.0 - -mod event_loop; - -pub use event_loop::EventLoop; diff --git a/src/core/server/src/threads/workers/mod.rs b/src/core/server/src/threads/workers/mod.rs deleted file mode 100644 index 76b469d7d..000000000 --- a/src/core/server/src/threads/workers/mod.rs +++ /dev/null @@ -1,287 +0,0 @@ -// Copyright 2021 Twitter, Inc. -// Licensed under the Apache License, Version 2.0 -// http://www.apache.org/licenses/LICENSE-2.0 - -//! Worker threads which are used in multi or single worker mode to handle -//! 
sending and receiving data on established client sessions - -mod multi; -mod single; -mod storage; - -pub use self::storage::{StorageWorker, StorageWorkerBuilder}; -use crate::*; -use crate::{QUEUE_CAPACITY, THREAD_PREFIX}; -use common::signal::Signal; -use config::WorkerConfig; -use entrystore::EntryStore; -use mio::Waker; -pub use multi::{MultiWorker, MultiWorkerBuilder}; -use protocol_common::ExecutionResult; -use protocol_common::{Compose, Execute, Parse}; -use queues::Queues; -use session::Session; -pub use single::{SingleWorker, SingleWorkerBuilder}; -use std::io::Error; -use std::sync::Arc; -use std::thread::JoinHandle; - -use super::EventLoop; -use mio::Token; - -counter!(WORKER_EVENT_LOOP); -counter!(WORKER_EVENT_TOTAL); -counter!(WORKER_EVENT_ERROR); -counter!(WORKER_EVENT_WRITE); -counter!(WORKER_EVENT_READ); -counter!( - WORKER_EVENT_MAX_REACHED, - "the number of times the maximum number of events was returned" -); -heatmap!(WORKER_EVENT_DEPTH, 100_000); - -counter!(STORAGE_EVENT_LOOP); -heatmap!(STORAGE_QUEUE_DEPTH, 1_000_000); - -counter!(PROCESS_REQ); - -type Instant = common::time::Instant>; -type WrappedResult = TokenWrapper>>; - -pub struct TokenWrapper { - inner: T, - token: Token, -} - -impl TokenWrapper { - pub fn new(inner: T, token: Token) -> Self { - Self { inner, token } - } - - pub fn token(&self) -> Token { - self.token - } - - pub fn into_inner(self) -> T { - self.inner - } -} - -/// A builder type for the worker threads which process requests and write -/// responses. -pub enum WorkersBuilder -where - Parser: Parse, - Response: Compose, - Storage: Execute + EntryStore, -{ - /// Used to create two or more `worker` threads in addition to a shared - /// `storage` thread. - Multi { - storage: StorageWorkerBuilder, - workers: Vec>, - }, - /// Used to create a single `worker` thread with thread-local storage. 
- Single { - worker: SingleWorkerBuilder, - }, -} - -impl WorkersBuilder -where - Parser: Parse + Clone, - Response: Compose, - Storage: Execute + EntryStore, -{ - /// Create a new `WorkersBuilder` from the provided config, storage, and - /// parser. - pub fn new( - config: &T, - storage: Storage, - parser: Parser, - ) -> Result, Error> { - let worker_config = config.worker(); - - if worker_config.threads() == 1 { - Self::single_worker(config, storage, parser) - } else { - Self::multi_worker(config, storage, parser) - } - } - - // Creates a multi-worker builder - fn multi_worker( - config: &T, - storage: Storage, - parser: Parser, - ) -> Result, Error> { - let worker_config = config.worker(); - - // initialize storage - let storage = StorageWorkerBuilder::new(config, storage).unwrap_or_else(|e| { - error!("{}", e); - std::process::exit(1); - }); - - // initialize workers - let mut workers = Vec::new(); - for _ in 0..worker_config.threads() { - let worker = MultiWorkerBuilder::new(config, parser.clone()).unwrap_or_else(|e| { - error!("{}", e); - std::process::exit(1); - }); - workers.push(worker); - } - - Ok(WorkersBuilder::Multi { storage, workers }) - } - - // Creates a single-worker builder - fn single_worker( - config: &T, - storage: Storage, - parser: Parser, - ) -> Result, Error> { - // initialize worker - let worker = SingleWorkerBuilder::new(config, storage, parser).unwrap_or_else(|e| { - error!("{}", e); - std::process::exit(1); - }); - - Ok(WorkersBuilder::Single { worker }) - } - - /// Returns the wakers for all workers. Used when setting-up the queues to - /// signal to all threads. - pub(crate) fn wakers(&self) -> Vec> { - match self { - Self::Multi { storage, workers } => { - let mut wakers = vec![storage.waker()]; - for waker in workers.iter().map(|v| v.waker()) { - wakers.push(waker); - } - wakers - } - Self::Single { worker } => { - vec![worker.waker()] - } - } - } - - /// Returns the wakers for the non-storage workers. 
Used when setting-up the - /// queues to send sessions to the workers. - pub(crate) fn worker_wakers(&self) -> Vec> { - match self { - Self::Multi { workers, .. } => workers.iter().map(|v| v.waker()).collect(), - Self::Single { worker } => { - vec![worker.waker()] - } - } - } - - /// Converts the builder into the finalized `Workers` type by providing the - /// necessary queues. - pub fn build( - self, - signal_queues: Vec>, - session_queues: Vec>, - ) -> Workers { - let mut signal_queues = signal_queues; - let mut session_queues = session_queues; - match self { - Self::Multi { - storage, - mut workers, - } => { - let storage_wakers = vec![storage.waker()]; - let worker_wakers: Vec> = workers.iter().map(|v| v.waker()).collect(); - let (mut response_queues, mut request_queues) = - Queues::new(worker_wakers, storage_wakers, QUEUE_CAPACITY); - - // The storage thread precedes the worker threads in the set of - // wakers, so its signal queue is the first element of - // `signal_queues`. Its request queue is also the first (and - // only) element of `request_queues`. We remove these and build - // the storage so we can loop through the remaining signal - // queues when launching the worker threads. - let s = storage.build(signal_queues.remove(0), request_queues.remove(0)); - - let mut w = Vec::new(); - for worker_builder in workers.drain(..) { - w.push(worker_builder.build( - signal_queues.remove(0), - session_queues.remove(0), - response_queues.remove(0), - )); - } - - Workers::Multi { - storage: s, - workers: w, - } - } - Self::Single { worker } => Workers::Single { - worker: worker.build(signal_queues.remove(0), session_queues.remove(0)), - }, - } - } -} - -/// Represents the finalized `Workers`. -pub enum Workers { - /// A multi-threaded worker which includes two or more threads to handle - /// request/response as well as a shared storage thread. 
- Multi { - storage: StorageWorker, - workers: Vec>, - }, - /// A single-threaded worker which handles request/response and owns the - /// storage. - Single { - worker: SingleWorker, - }, -} - -impl< - Storage: 'static + Send, - Parser: 'static + Send, - Request: 'static + Send, - Response: 'static + Send, - > Workers -where - Parser: Parse, - Response: Compose, - Storage: Execute + EntryStore, -{ - /// Converts the `Workers` into running threads. - pub fn spawn(self) -> Vec> { - match self { - Self::Single { mut worker } => { - vec![std::thread::Builder::new() - .name(format!("{}_worker", THREAD_PREFIX)) - .spawn(move || worker.run()) - .unwrap()] - } - Self::Multi { - mut storage, - workers, - } => { - let mut threads = Vec::new(); - threads.push( - std::thread::Builder::new() - .name(format!("{}_storage", THREAD_PREFIX)) - .spawn(move || storage.run()) - .unwrap(), - ); - for mut worker in workers { - let worker_thread = std::thread::Builder::new() - .name(format!("{}_worker{}", THREAD_PREFIX, threads.len())) - .spawn(move || worker.run()) - .unwrap(); - threads.push(worker_thread); - } - threads - } - } - } -} diff --git a/src/core/server/src/threads/workers/multi.rs b/src/core/server/src/threads/workers/multi.rs deleted file mode 100644 index 70fcf3c88..000000000 --- a/src/core/server/src/threads/workers/multi.rs +++ /dev/null @@ -1,322 +0,0 @@ -// Copyright 2021 Twitter, Inc. -// Licensed under the Apache License, Version 2.0 -// http://www.apache.org/licenses/LICENSE-2.0 - -//! The multi-threaded worker, which is used when there are multiple worker -//! threads configured. This worker parses buffers to produce requests, sends -//! the requests to the storage worker. Responses from the storage worker are -//! then serialized onto the session buffer. 
- -use super::*; -use crate::poll::Poll; -use crate::QUEUE_RETRIES; -use common::signal::Signal; -use config::WorkerConfig; -use core::marker::PhantomData; -use core::time::Duration; -use entrystore::EntryStore; -use mio::event::Event; -use mio::{Events, Token, Waker}; -use protocol_common::{Compose, Execute, Parse, ParseError}; -use queues::TrackedItem; -use session::Session; -use std::io::{BufRead, Write}; -use std::sync::Arc; - -const WAKER_TOKEN: Token = Token(usize::MAX); -const STORAGE_THREAD_ID: usize = 0; - -/// A builder for the request/response worker which communicates to the storage -/// thread over a queue. -pub struct MultiWorkerBuilder { - nevent: usize, - parser: Parser, - poll: Poll, - timeout: Duration, - _storage: PhantomData, - _request: PhantomData, - _response: PhantomData, -} - -impl MultiWorkerBuilder { - /// Create a new builder from the provided config and parser. - pub fn new(config: &T, parser: Parser) -> Result { - let poll = Poll::new().map_err(|e| { - error!("{}", e); - std::io::Error::new(std::io::ErrorKind::Other, "Failed to create epoll instance") - })?; - - Ok(Self { - poll, - nevent: config.worker().nevent(), - timeout: Duration::from_millis(config.worker().timeout() as u64), - _request: PhantomData, - _response: PhantomData, - _storage: PhantomData, - parser, - }) - } - - /// Get the waker that is registered to the epoll instance. - pub(crate) fn waker(&self) -> Arc { - self.poll.waker() - } - - /// Converts the builder into a `MultiWorker` by providing the queues that - /// are necessary for communication between components. 
- pub fn build( - self, - signal_queue: Queues<(), Signal>, - session_queue: Queues<(), Session>, - storage_queue: Queues, WrappedResult>, - ) -> MultiWorker { - MultiWorker { - nevent: self.nevent, - parser: self.parser, - poll: self.poll, - timeout: self.timeout, - signal_queue, - _storage: PhantomData, - storage_queue, - session_queue, - } - } -} - -/// Represents a finalized request/response worker which is ready to be run. -pub struct MultiWorker { - nevent: usize, - parser: Parser, - poll: Poll, - timeout: Duration, - session_queue: Queues<(), Session>, - signal_queue: Queues<(), Signal>, - _storage: PhantomData, - storage_queue: Queues, WrappedResult>, -} - -impl MultiWorker -where - Parser: Parse, - Response: Compose, - Storage: Execute + EntryStore, -{ - /// Run the worker in a loop, handling new events. - pub fn run(&mut self) { - // these are buffers which are re-used in each loop iteration to receive - // events and queue messages - let mut events = Events::with_capacity(self.nevent); - let mut responses = Vec::with_capacity(QUEUE_CAPACITY); - let mut sessions = Vec::with_capacity(QUEUE_CAPACITY); - - loop { - WORKER_EVENT_LOOP.increment(); - - // get events with timeout - if self.poll.poll(&mut events, self.timeout).is_err() { - error!("Error polling"); - } - - let timestamp = Instant::now(); - - let count = events.iter().count(); - WORKER_EVENT_TOTAL.add(count as _); - if count == self.nevent { - WORKER_EVENT_MAX_REACHED.increment(); - } else { - WORKER_EVENT_DEPTH.increment(timestamp, count as _, 1); - } - - // process all events - for event in events.iter() { - match event.token() { - WAKER_TOKEN => { - self.handle_new_sessions(&mut sessions); - self.handle_storage_queue(&mut responses); - - // check if we received any signals from the admin thread - while let Some(signal) = - self.signal_queue.try_recv().map(|v| v.into_inner()) - { - match signal { - Signal::FlushAll => {} - Signal::Shutdown => { - // if we received a shutdown, we can return - // 
and stop processing events - return; - } - } - } - } - _ => { - self.handle_event(event, timestamp); - } - } - } - - // wakes the storage thread if necessary - let _ = self.storage_queue.wake(); - } - } - - fn handle_event(&mut self, event: &Event, timestamp: Instant) { - let token = event.token(); - - // handle error events first - if event.is_error() { - WORKER_EVENT_ERROR.increment(); - self.handle_error(token); - } - - // handle write events before read events to reduce write buffer - // growth if there is also a readable event - if event.is_writable() { - WORKER_EVENT_WRITE.increment(); - self.do_write(token); - } - - // read events are handled last - if event.is_readable() { - WORKER_EVENT_READ.increment(); - if let Ok(session) = self.poll.get_mut_session(token) { - session.set_timestamp(timestamp); - } - let _ = self.do_read(token); - } - - if let Ok(session) = self.poll.get_mut_session(token) { - if session.read_pending() > 0 { - trace!( - "session: {:?} has {} bytes pending in read buffer", - session, - session.read_pending() - ); - } - if session.write_pending() > 0 { - trace!( - "session: {:?} has {} bytes pending in write buffer", - session, - session.read_pending() - ); - } - } - } - - fn handle_session_read(&mut self, token: Token) -> Result<(), std::io::Error> { - let session = self.poll.get_mut_session(token)?; - match self.parser.parse(session.buffer()) { - Ok(request) => { - let consumed = request.consumed(); - let request = request.into_inner(); - trace!("parsed request for sesion: {:?}", session); - session.consume(consumed); - let mut message = TokenWrapper::new(request, token); - - for retry in 0..QUEUE_RETRIES { - if let Err(m) = self.storage_queue.try_send_to(STORAGE_THREAD_ID, message) { - if (retry + 1) == QUEUE_RETRIES { - error!("queue full trying to send message to storage thread"); - let _ = self.poll.close_session(token); - } - // try to wake storage thread - let _ = self.storage_queue.wake(); - message = m; - } else { - break; - } - 
} - Ok(()) - } - Err(ParseError::Incomplete) => { - trace!("incomplete request for session: {:?}", session); - Err(std::io::Error::new( - std::io::ErrorKind::WouldBlock, - "incomplete request", - )) - } - Err(_) => { - debug!("bad request for session: {:?}", session); - trace!("session: {:?} read buffer: {:?}", session, session.buffer()); - let _ = self.poll.close_session(token); - Err(std::io::Error::new( - std::io::ErrorKind::Other, - "bad request", - )) - } - } - } - - fn handle_storage_queue( - &mut self, - responses: &mut Vec>>, - ) { - trace!("handling event for storage queue"); - // process all storage queue responses - self.storage_queue.try_recv_all(responses); - - for message in responses.drain(..).map(|v| v.into_inner()) { - let token = message.token(); - let mut reregister = false; - if let Ok(session) = self.poll.get_mut_session(token) { - let result = message.into_inner(); - trace!("composing response for session: {:?}", session); - result.compose(session); - session.finalize_response(); - // if we have pending writes, we should attempt to flush the session - // now. if we still have pending bytes, we should re-register to - // remove the read interest. 
- if session.write_pending() > 0 { - let _ = session.flush(); - if session.write_pending() > 0 { - reregister = true; - } - } - if session.read_pending() > 0 && self.handle_session_read(token).is_ok() { - let _ = self.storage_queue.wake(); - } - } - if reregister { - self.poll.reregister(token); - } - } - let _ = self.storage_queue.wake(); - } - - fn handle_new_sessions(&mut self, sessions: &mut Vec>) { - self.session_queue.try_recv_all(sessions); - for session in sessions.drain(..).map(|v| v.into_inner()) { - let pending = session.read_pending(); - trace!( - "new session: {:?} with {} bytes pending in read buffer", - session, - pending - ); - - if let Ok(token) = self.poll.add_session(session) { - if pending > 0 { - // handle any pending data immediately - if self.handle_data(token).is_err() { - self.handle_error(token); - } - } - } - } - } -} - -impl EventLoop - for MultiWorker -where - Parser: Parse, - Response: Compose, - Storage: Execute + EntryStore, -{ - fn handle_data(&mut self, token: Token) -> Result<(), std::io::Error> { - let _ = self.handle_session_read(token); - Ok(()) - } - - fn poll(&mut self) -> &mut Poll { - &mut self.poll - } -} diff --git a/src/core/server/src/threads/workers/single.rs b/src/core/server/src/threads/workers/single.rs deleted file mode 100644 index 12eaf9a63..000000000 --- a/src/core/server/src/threads/workers/single.rs +++ /dev/null @@ -1,296 +0,0 @@ -// Copyright 2021 Twitter, Inc. -// Licensed under the Apache License, Version 2.0 -// http://www.apache.org/licenses/LICENSE-2.0 - -//! The single-threaded worker, which is used when there is only one worker -//! thread configured. This worker parses buffers to produce requests, executes -//! the request using the backing storage, and then composes a response onto the -//! session buffer. 
- -use super::EventLoop; -use super::*; -use crate::poll::{Poll, WAKER_TOKEN}; -use common::signal::Signal; -use config::WorkerConfig; -use core::marker::PhantomData; -use core::time::Duration; -use entrystore::EntryStore; -use mio::event::Event; -use mio::Events; -use mio::Token; -use mio::Waker; -use protocol_common::{Compose, Execute, Parse, ParseError}; -use session::Session; -use std::io::{BufRead, Write}; -use std::sync::Arc; - -/// A builder type for a single-threaded worker which owns the storage. -pub struct SingleWorkerBuilder { - nevent: usize, - parser: Parser, - poll: Poll, - timeout: Duration, - storage: Storage, - _request: PhantomData, - _response: PhantomData, -} - -impl SingleWorkerBuilder { - /// Create a new builder for a single-threaded worker from the provided - /// config, storage, and parser - pub fn new( - config: &T, - storage: Storage, - parser: Parser, - ) -> Result { - let poll = Poll::new().map_err(|e| { - error!("{}", e); - std::io::Error::new(std::io::ErrorKind::Other, "Failed to create epoll instance") - })?; - - Ok(Self { - poll, - nevent: config.worker().nevent(), - timeout: Duration::from_millis(config.worker().timeout() as u64), - storage, - _request: PhantomData, - _response: PhantomData, - parser, - }) - } - - /// Returns the waker for this worker. - pub(crate) fn waker(&self) -> Arc { - self.poll.waker() - } - - /// Finalize the builder and return a `SingleWorker` by providing the queues - /// that are required to communicate with other threads. - pub fn build( - self, - signal_queue: Queues<(), Signal>, - session_queue: Queues<(), Session>, - ) -> SingleWorker { - SingleWorker { - nevent: self.nevent, - parser: self.parser, - poll: self.poll, - timeout: self.timeout, - storage: self.storage, - session_queue, - signal_queue, - _request: PhantomData, - _response: PhantomData, - } - } -} - -/// A finalized single-threaded worker which is ready to be run. 
-pub struct SingleWorker { - nevent: usize, - parser: Parser, - poll: Poll, - timeout: Duration, - storage: Storage, - session_queue: Queues<(), Session>, - signal_queue: Queues<(), Signal>, - _request: PhantomData, - _response: PhantomData, -} - -impl SingleWorker -where - Parser: Parse, - Response: Compose, - Storage: Execute + EntryStore, -{ - /// Run the worker in a loop, handling new events. - pub fn run(&mut self) { - let mut events = Events::with_capacity(self.nevent); - - loop { - WORKER_EVENT_LOOP.increment(); - - self.storage.expire(); - - // get events with timeout - if self.poll.poll(&mut events, self.timeout).is_err() { - error!("Error polling"); - } - - let timestamp = Instant::now(); - - let count = events.iter().count(); - WORKER_EVENT_TOTAL.add(count as _); - if count == self.nevent { - WORKER_EVENT_MAX_REACHED.increment(); - } else { - WORKER_EVENT_DEPTH.increment(timestamp, count as _, 1); - } - - // process all events - for event in events.iter() { - match event.token() { - WAKER_TOKEN => { - self.handle_new_sessions(); - - // check if we received any signals from the admin thread - while let Some(signal) = self.signal_queue.try_recv() { - match signal.into_inner() { - Signal::FlushAll => { - warn!("received flush_all"); - self.storage.clear(); - } - Signal::Shutdown => { - // if we received a shutdown, we can return - // and stop processing events - return; - } - } - } - } - _ => { - self.handle_event(event, timestamp); - } - } - } - } - } - - fn handle_new_sessions(&mut self) { - while let Some(session) = self.session_queue.try_recv().map(|v| v.into_inner()) { - let pending = session.read_pending(); - trace!( - "new session: {:?} with {} bytes pending in read buffer", - session, - pending - ); - - // reserve vacant slab - if let Ok(token) = self.poll.add_session(session) { - if pending > 0 { - // handle any pending data immediately - if self.handle_data(token).is_err() { - self.handle_error(token); - } - } - } - } - } - - fn handle_event(&mut 
self, event: &Event, timestamp: Instant) { - let token = event.token(); - - // handle error events first - if event.is_error() { - WORKER_EVENT_ERROR.increment(); - self.handle_error(token); - } - - // handle write events before read events to reduce write buffer - // growth if there is also a readable event - if event.is_writable() { - WORKER_EVENT_WRITE.increment(); - self.do_write(token); - } - - // read events are handled last - if event.is_readable() { - WORKER_EVENT_READ.increment(); - if let Ok(session) = self.poll.get_mut_session(token) { - session.set_timestamp(timestamp); - } - let _ = self.do_read(token); - } - - if let Ok(session) = self.poll.get_mut_session(token) { - if session.read_pending() > 0 { - trace!( - "session: {:?} has {} bytes pending in read buffer", - session, - session.read_pending() - ); - } - if session.write_pending() > 0 { - trace!( - "session: {:?} has {} bytes pending in write buffer", - session, - session.read_pending() - ); - } - } - } -} - -impl EventLoop - for SingleWorker -where - Parser: Parse, - Response: Compose, - Storage: Execute + EntryStore, -{ - fn handle_data(&mut self, token: Token) -> Result<(), std::io::Error> { - if let Ok(session) = self.poll.get_mut_session(token) { - loop { - if session.write_capacity() == 0 { - // if the write buffer is over-full, skip processing - break; - } - match self.parser.parse(session.buffer()) { - Ok(parsed_request) => { - trace!("parsed request for sesion: {:?}", session); - PROCESS_REQ.increment(); - let consumed = parsed_request.consumed(); - let request = parsed_request.into_inner(); - session.consume(consumed); - - let result = self.storage.execute(request); - trace!("composing response for session: {:?}", session); - result.compose(session); - session.finalize_response(); - if result.should_hangup() { - return Err(std::io::Error::new( - std::io::ErrorKind::Other, - "response requires hangup", - )); - } - } - Err(ParseError::Incomplete) => { - trace!("incomplete request for 
session: {:?}", session); - break; - } - Err(_) => { - debug!("bad request for session: {:?}", session); - trace!("session: {:?} read buffer: {:?}", session, session.buffer()); - self.handle_error(token); - return Err(std::io::Error::new( - std::io::ErrorKind::Other, - "bad request", - )); - } - } - } - // if we have pending writes, we should attempt to flush the session - // now. if we still have pending bytes, we should re-register to - // remove the read interest. - if session.write_pending() > 0 { - let _ = session.flush(); - if session.write_pending() > 0 { - self.poll.reregister(token); - } - } - Ok(()) - } else { - // no session for the token - trace!( - "attempted to handle data for non-existent session: {}", - token.0 - ); - Ok(()) - } - } - - fn poll(&mut self) -> &mut Poll { - &mut self.poll - } -} diff --git a/src/core/server/src/workers/mod.rs b/src/core/server/src/workers/mod.rs new file mode 100644 index 000000000..80d17c9a6 --- /dev/null +++ b/src/core/server/src/workers/mod.rs @@ -0,0 +1,198 @@ +// Copyright 2021 Twitter, Inc. 
+// Licensed under the Apache License, Version 2.0 +// http://www.apache.org/licenses/LICENSE-2.0 + +use crate::*; +use std::thread::JoinHandle; + +mod multi; +mod single; +mod storage; + +use multi::*; +use single::*; +use storage::*; + +heatmap!( + WORKER_EVENT_DEPTH, + 100_000, + "distribution of the number of events received per iteration of the event loop" +); +counter!(WORKER_EVENT_ERROR, "the number of error events received"); +counter!( + WORKER_EVENT_LOOP, + "the number of times the event loop has run" +); +counter!( + WORKER_EVENT_MAX_REACHED, + "the number of times the maximum number of events was returned" +); +counter!(WORKER_EVENT_READ, "the number of read events received"); +counter!(WORKER_EVENT_TOTAL, "the total number of events received"); +counter!(WORKER_EVENT_WRITE, "the number of write events received"); + +fn map_result(result: Result) -> Result<()> { + match result { + Ok(0) => Err(Error::new(ErrorKind::Other, "client hangup")), + Ok(_) => Ok(()), + Err(e) => map_err(e), + } +} + +pub enum Workers { + Single { + worker: SingleWorker, + }, + Multi { + workers: Vec>, + storage: StorageWorker, + }, +} + +impl Workers +where + Parser: 'static + Parse + Clone + Send, + Request: 'static + Klog + Klog + Send, + Response: 'static + Compose + Send, + Storage: 'static + EntryStore + Execute + Send, +{ + pub fn spawn(self) -> Vec> { + match self { + Self::Single { mut worker } => { + vec![std::thread::Builder::new() + .name(format!("{}_work", THREAD_PREFIX)) + .spawn(move || worker.run()) + .unwrap()] + } + Self::Multi { + mut workers, + mut storage, + } => { + let mut join_handles = vec![std::thread::Builder::new() + .name(format!("{}_storage", THREAD_PREFIX)) + .spawn(move || storage.run()) + .unwrap()]; + + for (id, mut worker) in workers.drain(..).enumerate() { + join_handles.push( + std::thread::Builder::new() + .name(format!("{}_work_{}", THREAD_PREFIX, id)) + .spawn(move || worker.run()) + .unwrap(), + ) + } + + join_handles + } + } + } +} + 
+pub enum WorkersBuilder { + Single { + worker: SingleWorkerBuilder, + }, + Multi { + workers: Vec>, + storage: StorageWorkerBuilder, + }, +} + +impl WorkersBuilder +where + Parser: Parse + Clone, + Response: Compose, + Storage: Execute + EntryStore, +{ + pub fn new(config: &T, parser: Parser, storage: Storage) -> Result { + let threads = config.worker().threads(); + + if threads > 1 { + let mut workers = vec![]; + for _ in 0..threads { + workers.push(MultiWorkerBuilder::new(config, parser.clone())?) + } + + Ok(Self::Multi { + workers, + storage: StorageWorkerBuilder::new(config, storage)?, + }) + } else { + Ok(Self::Single { + worker: SingleWorkerBuilder::new(config, parser, storage)?, + }) + } + } + + pub fn worker_wakers(&self) -> Vec> { + match self { + Self::Single { worker } => { + vec![worker.waker()] + } + Self::Multi { + workers, + storage: _, + } => workers.iter().map(|w| w.waker()).collect(), + } + } + + pub fn wakers(&self) -> Vec> { + match self { + Self::Single { worker } => { + vec![worker.waker()] + } + Self::Multi { workers, storage } => { + let mut wakers = vec![storage.waker()]; + for worker in workers { + wakers.push(worker.waker()); + } + wakers + } + } + } + + pub fn build( + self, + session_queues: Vec>, + signal_queues: Vec>, + ) -> Workers { + let mut signal_queues = signal_queues; + let mut session_queues = session_queues; + match self { + Self::Multi { + storage, + mut workers, + } => { + let storage_wakers = vec![storage.waker()]; + let worker_wakers: Vec> = workers.iter().map(|v| v.waker()).collect(); + let (mut worker_data_queues, mut storage_data_queues) = + Queues::new(worker_wakers, storage_wakers, QUEUE_CAPACITY); + + // The storage thread precedes the worker threads in the set of + // wakers, so its signal queue is the first element of + // `signal_queues`. Its request queue is also the first (and + // only) element of `request_queues`. 
We remove these and build + // the storage so we can loop through the remaining signal + // queues when launching the worker threads. + let s = storage.build(storage_data_queues.remove(0), signal_queues.remove(0)); + + let mut w = Vec::new(); + for worker_builder in workers.drain(..) { + w.push(worker_builder.build( + worker_data_queues.remove(0), + session_queues.remove(0), + signal_queues.remove(0), + )); + } + + Workers::Multi { + storage: s, + workers: w, + } + } + Self::Single { worker } => Workers::Single { + worker: worker.build(session_queues.remove(0), signal_queues.remove(0)), + }, + } + } +} diff --git a/src/core/server/src/workers/multi.rs b/src/core/server/src/workers/multi.rs new file mode 100644 index 000000000..113cf08b9 --- /dev/null +++ b/src/core/server/src/workers/multi.rs @@ -0,0 +1,265 @@ +// Copyright 2021 Twitter, Inc. +// Licensed under the Apache License, Version 2.0 +// http://www.apache.org/licenses/LICENSE-2.0 + +use super::*; + +pub struct MultiWorkerBuilder { + nevent: usize, + parser: Parser, + poll: Poll, + sessions: Slab>, + timeout: Duration, + waker: Arc, +} + +impl MultiWorkerBuilder { + pub fn new(config: &T, parser: Parser) -> Result { + let config = config.worker(); + + let poll = Poll::new()?; + + let waker = Arc::new(Waker::from( + ::net::Waker::new(poll.registry(), WAKER_TOKEN).unwrap(), + )); + + let nevent = config.nevent(); + let timeout = Duration::from_millis(config.timeout() as u64); + + Ok(Self { + nevent, + parser, + poll, + sessions: Slab::new(), + timeout, + waker, + }) + } + + pub fn waker(&self) -> Arc { + self.waker.clone() + } + + pub fn build( + self, + data_queue: Queues<(Request, Token), (Request, Response, Token)>, + session_queue: Queues, + signal_queue: Queues<(), Signal>, + ) -> MultiWorker { + MultiWorker { + data_queue, + nevent: self.nevent, + parser: self.parser, + poll: self.poll, + session_queue, + sessions: self.sessions, + signal_queue, + timeout: self.timeout, + waker: self.waker, + } + } +} + 
+pub struct MultiWorker { + data_queue: Queues<(Request, Token), (Request, Response, Token)>, + nevent: usize, + parser: Parser, + poll: Poll, + session_queue: Queues, + sessions: Slab>, + signal_queue: Queues<(), Signal>, + timeout: Duration, + waker: Arc, +} + +impl MultiWorker +where + Parser: Parse + Clone, + Request: Klog + Klog, + Response: Compose, +{ + /// Return the `Session` to the `Listener` to handle flush/close + fn close(&mut self, token: Token) { + if self.sessions.contains(token.0) { + let mut session = self.sessions.remove(token.0).into_inner(); + let _ = session.deregister(self.poll.registry()); + let _ = self.session_queue.try_send_any(session); + let _ = self.session_queue.wake(); + } + } + + /// Handle up to one request for a session + fn read(&mut self, token: Token) -> Result<()> { + let session = self + .sessions + .get_mut(token.0) + .ok_or_else(|| Error::new(ErrorKind::Other, "non-existant session"))?; + + // fill the session + map_result(session.fill())?; + + // process up to one request + match session.receive() { + Ok(request) => self + .data_queue + .try_send_to(0, (request, token)) + .map_err(|_| Error::new(ErrorKind::Other, "data queue is full")), + Err(e) => map_err(e), + } + } + + /// Handle write by flushing the session + fn write(&mut self, token: Token) -> Result<()> { + let session = self + .sessions + .get_mut(token.0) + .ok_or_else(|| Error::new(ErrorKind::Other, "non-existant session"))?; + + match session.flush() { + Ok(_) => Ok(()), + Err(e) => map_err(e), + } + } + + /// Run the worker in a loop, handling new events. 
+ pub fn run(&mut self) { + // these are buffers which are re-used in each loop iteration to receive + // events and queue messages + let mut events = Events::with_capacity(self.nevent); + let mut messages = Vec::with_capacity(QUEUE_CAPACITY); + + loop { + WORKER_EVENT_LOOP.increment(); + + // get events with timeout + if self.poll.poll(&mut events, Some(self.timeout)).is_err() { + error!("Error polling"); + } + + let timestamp = Instant::now(); + + let count = events.iter().count(); + WORKER_EVENT_TOTAL.add(count as _); + if count == self.nevent { + WORKER_EVENT_MAX_REACHED.increment(); + } else { + WORKER_EVENT_DEPTH.increment(timestamp, count as _, 1); + } + + // process all events + for event in events.iter() { + let token = event.token(); + match token { + WAKER_TOKEN => { + self.waker.reset(); + // handle up to one new session + if let Some(mut session) = + self.session_queue.try_recv().map(|v| v.into_inner()) + { + let s = self.sessions.vacant_entry(); + let interest = session.interest(); + if session + .register(self.poll.registry(), Token(s.key()), interest) + .is_ok() + { + s.insert(ServerSession::new(session, self.parser.clone())); + } else { + let _ = self.session_queue.try_send_any(session); + } + + // trigger a wake-up in case there are more sessions + let _ = self.waker.wake(); + } + + // handle all pending messages on the data queue + self.data_queue.try_recv_all(&mut messages); + for (request, response, token) in messages.drain(..).map(|v| v.into_inner()) + { + request.klog(&response); + if let Some(session) = self.sessions.get_mut(token.0) { + if response.should_hangup() { + let _ = session.send(response); + self.close(token); + continue; + } else if session.send(response).is_err() { + self.close(token); + continue; + } else if session.write_pending() > 0 { + // try to immediately flush, if we still + // have pending bytes, reregister. This + // saves us one syscall when flushing would + // not block. 
+ if let Err(e) = session.flush() { + if map_err(e).is_err() { + self.close(token); + continue; + } + } + + if session.write_pending() > 0 { + let interest = session.interest(); + if session + .reregister(self.poll.registry(), token, interest) + .is_err() + { + self.close(token); + continue; + } + } + } + + if session.remaining() > 0 && self.read(token).is_err() { + self.close(token); + continue; + } + } + } + + // check if we received any signals from the admin thread + while let Some(signal) = + self.signal_queue.try_recv().map(|v| v.into_inner()) + { + match signal { + Signal::FlushAll => {} + Signal::Shutdown => { + // if we received a shutdown, we can return + // and stop processing events + return; + } + } + } + } + _ => { + if event.is_error() { + WORKER_EVENT_ERROR.increment(); + + self.close(token); + continue; + } + + if event.is_writable() { + WORKER_EVENT_WRITE.increment(); + + if self.write(token).is_err() { + self.close(token); + continue; + } + } + + if event.is_readable() { + WORKER_EVENT_READ.increment(); + + if self.read(token).is_err() { + self.close(token); + continue; + } + } + } + } + } + + // wakes the storage thread if necessary + let _ = self.data_queue.wake(); + } + } +} diff --git a/src/core/server/src/workers/single.rs b/src/core/server/src/workers/single.rs new file mode 100644 index 000000000..974955a10 --- /dev/null +++ b/src/core/server/src/workers/single.rs @@ -0,0 +1,289 @@ +// Copyright 2021 Twitter, Inc. 
+// Licensed under the Apache License, Version 2.0 +// http://www.apache.org/licenses/LICENSE-2.0 + +use super::*; +use std::collections::VecDeque; + +pub struct SingleWorkerBuilder { + nevent: usize, + parser: Parser, + pending: VecDeque, + poll: Poll, + sessions: Slab>, + storage: Storage, + timeout: Duration, + waker: Arc, +} + +impl SingleWorkerBuilder { + pub fn new(config: &T, parser: Parser, storage: Storage) -> Result { + let config = config.worker(); + + let poll = Poll::new()?; + + let waker = Arc::new(Waker::from( + ::net::Waker::new(poll.registry(), WAKER_TOKEN).unwrap(), + )); + + let nevent = config.nevent(); + let timeout = Duration::from_millis(config.timeout() as u64); + + Ok(Self { + nevent, + parser, + pending: VecDeque::new(), + poll, + sessions: Slab::new(), + storage, + timeout, + waker, + }) + } + + pub fn waker(&self) -> Arc { + self.waker.clone() + } + + pub fn build( + self, + session_queue: Queues, + signal_queue: Queues<(), Signal>, + ) -> SingleWorker { + SingleWorker { + nevent: self.nevent, + parser: self.parser, + pending: self.pending, + poll: self.poll, + session_queue, + sessions: self.sessions, + signal_queue, + storage: self.storage, + timeout: self.timeout, + waker: self.waker, + } + } +} + +pub struct SingleWorker { + nevent: usize, + parser: Parser, + pending: VecDeque, + poll: Poll, + session_queue: Queues, + sessions: Slab>, + signal_queue: Queues<(), Signal>, + storage: Storage, + timeout: Duration, + waker: Arc, +} + +impl SingleWorker +where + Parser: Parse + Clone, + Request: Klog + Klog, + Response: Compose, + Storage: EntryStore + Execute, +{ + /// Return the `Session` to the `Listener` to handle flush/close + fn close(&mut self, token: Token) { + if self.sessions.contains(token.0) { + let mut session = self.sessions.remove(token.0).into_inner(); + let _ = self.poll.registry().deregister(&mut session); + let _ = self.session_queue.try_send_any(session); + let _ = self.session_queue.wake(); + } + } + + /// Handle up to 
one request for a session + fn read(&mut self, token: Token) -> Result<()> { + let session = self + .sessions + .get_mut(token.0) + .ok_or_else(|| Error::new(ErrorKind::Other, "non-existant session"))?; + + // fill the session + map_result(session.fill())?; + + // process up to one pending request + match session.receive() { + Ok(request) => { + let response = self.storage.execute(&request); + PROCESS_REQ.increment(); + if response.should_hangup() { + let _ = session.send(response); + return Err(Error::new(ErrorKind::Other, "should hangup")); + } + request.klog(&response); + match session.send(response) { + Ok(_) => { + // attempt to flush immediately if there's now data in + // the write buffer + if session.write_pending() > 0 { + match session.flush() { + Ok(_) => Ok(()), + Err(e) => map_err(e), + }?; + } + + // reregister to get writable event + if session.write_pending() > 0 { + let interest = session.interest(); + if self + .poll + .registry() + .reregister(session, token, interest) + .is_err() + { + return Err(Error::new(ErrorKind::Other, "failed to reregister")); + } + } + + // if there's still data to read, put the token on the + // pending queue + if session.remaining() > 0 { + self.pending.push_back(token); + } + + Ok(()) + } + Err(e) => { + if e.kind() == ErrorKind::WouldBlock { + Ok(()) + } else { + Err(e) + } + } + } + } + Err(e) => { + if e.kind() == ErrorKind::WouldBlock { + Ok(()) + } else { + Err(e) + } + } + } + } + + fn write(&mut self, token: Token) -> Result<()> { + let session = self + .sessions + .get_mut(token.0) + .ok_or_else(|| Error::new(ErrorKind::Other, "non-existant session"))?; + + match session.flush() { + Ok(_) => Ok(()), + Err(e) => map_err(e), + } + } + + /// Run the worker in a loop, handling new events. 
+ pub fn run(&mut self) { + let mut events = Events::with_capacity(self.nevent); + + loop { + WORKER_EVENT_LOOP.increment(); + + self.storage.expire(); + + // we need another wakeup if there are still pending reads + if !self.pending.is_empty() { + let _ = self.waker.wake(); + } + + // get events with timeout + if self.poll.poll(&mut events, Some(self.timeout)).is_err() { + error!("Error polling"); + } + + let timestamp = Instant::now(); + + let count = events.iter().count(); + WORKER_EVENT_TOTAL.add(count as _); + if count == self.nevent { + WORKER_EVENT_MAX_REACHED.increment(); + } else { + WORKER_EVENT_DEPTH.increment(timestamp, count as _, 1); + } + + // process all events + for event in events.iter() { + let token = event.token(); + + match token { + WAKER_TOKEN => { + self.waker.reset(); + // handle outstanding reads + for _ in 0..self.pending.len() { + if let Some(token) = self.pending.pop_front() { + if self.read(token).is_err() { + self.close(token); + } + } + } + + // handle up to one new session + if let Some(mut session) = + self.session_queue.try_recv().map(|v| v.into_inner()) + { + let s = self.sessions.vacant_entry(); + let interest = session.interest(); + if session + .register(self.poll.registry(), Token(s.key()), interest) + .is_ok() + { + s.insert(ServerSession::new(session, self.parser.clone())); + } else { + let _ = self.session_queue.try_send_any(session); + } + + // trigger a wake-up in case there are more sessions + let _ = self.waker.wake(); + } + + // check if we received any signals from the admin thread + while let Some(signal) = self.signal_queue.try_recv() { + match signal.into_inner() { + Signal::FlushAll => { + self.storage.clear(); + } + Signal::Shutdown => { + // if we received a shutdown, we can return + // and stop processing events + return; + } + } + } + } + _ => { + if event.is_error() { + WORKER_EVENT_ERROR.increment(); + + self.close(token); + continue; + } + + if event.is_writable() { + WORKER_EVENT_WRITE.increment(); + + 
if self.write(token).is_err() { + self.close(token); + continue; + } + } + + if event.is_readable() { + WORKER_EVENT_READ.increment(); + + if self.read(token).is_err() { + self.close(token); + continue; + } + } + } + } + } + } + } +} diff --git a/src/core/server/src/threads/workers/storage.rs b/src/core/server/src/workers/storage.rs similarity index 53% rename from src/core/server/src/threads/workers/storage.rs rename to src/core/server/src/workers/storage.rs index 6a77ed886..10cd25ce3 100644 --- a/src/core/server/src/threads/workers/storage.rs +++ b/src/core/server/src/workers/storage.rs @@ -2,89 +2,98 @@ // Licensed under the Apache License, Version 2.0 // http://www.apache.org/licenses/LICENSE-2.0 -use super::*; -use crate::poll::Poll; -use crate::QUEUE_RETRIES; -use common::signal::Signal; -use config::WorkerConfig; -use core::time::Duration; -use entrystore::EntryStore; -use mio::{Events, Waker}; -use protocol_common::Execute; -use std::marker::PhantomData; -use std::sync::Arc; - -/// A builder type for a storage worker which owns the storage and executes -/// requests from a queue and returns responses back to the worker threads. -pub struct StorageWorkerBuilder { - poll: Poll, +use crate::*; + +counter!( + STORAGE_EVENT_LOOP, + "the number of times the event loop has run" +); +heatmap!( + STORAGE_QUEUE_DEPTH, + 1_000_000, + "the distribution of the depth of the storage queue on each loop" +); + +pub struct StorageWorkerBuilder { nevent: usize, - timeout: Duration, + poll: Poll, storage: Storage, + timeout: Duration, + waker: Arc, _request: PhantomData, _response: PhantomData, } -impl StorageWorkerBuilder { - /// Create a new `StorageWorkerBuilder` from the config and storage. 
- pub fn new(config: &T, storage: Storage) -> Result { - let poll = Poll::new().map_err(|e| { - error!("{}", e); - std::io::Error::new(std::io::ErrorKind::Other, "Failed to create epoll instance") - })?; +impl StorageWorkerBuilder { + pub fn new(config: &T, storage: Storage) -> Result { + let config = config.worker(); + + let poll = Poll::new()?; + + let waker = Arc::new(Waker::from( + ::net::Waker::new(poll.registry(), WAKER_TOKEN).unwrap(), + )); + + let nevent = config.nevent(); + let timeout = Duration::from_millis(config.timeout() as u64); Ok(Self { - nevent: config.worker().nevent(), - timeout: Duration::from_millis(config.worker().timeout() as u64), + nevent, poll, storage, + timeout, + waker, _request: PhantomData, _response: PhantomData, }) } - /// Returns the waker for the storage worker. - pub(crate) fn waker(&self) -> Arc { - self.poll.waker() + pub fn waker(&self) -> Arc { + self.waker.clone() } - /// Finalize the builder and return a `StorageWorker` by providing the - /// queues that are necessary for communication with other threads. pub fn build( self, + data_queue: Queues<(Request, Response, Token), (Request, Token)>, signal_queue: Queues<(), Signal>, - storage_queue: Queues, TokenWrapper>, - ) -> StorageWorker { + ) -> StorageWorker { StorageWorker { - poll: self.poll, + data_queue, nevent: self.nevent, - timeout: self.timeout, + poll: self.poll, signal_queue, storage: self.storage, - storage_queue, + timeout: self.timeout, + waker: self.waker, + _request: PhantomData, + _response: PhantomData, } } } -/// A finalized `StorageWorker` which is ready to be run. 
-pub struct StorageWorker { - poll: Poll, +pub struct StorageWorker { + data_queue: Queues<(Request, Response, Token), (Request, Token)>, nevent: usize, - timeout: Duration, + poll: Poll, signal_queue: Queues<(), Signal>, storage: Storage, - storage_queue: Queues, TokenWrapper>, + timeout: Duration, + #[allow(dead_code)] + waker: Arc, + _request: PhantomData, + _response: PhantomData, } -impl StorageWorker +impl StorageWorker where Storage: Execute + EntryStore, + Request: Klog + Klog, Response: Compose, { /// Run the `StorageWorker` in a loop, handling new session events. pub fn run(&mut self) { let mut events = Events::with_capacity(self.nevent); - let mut requests = Vec::with_capacity(1024); + let mut messages = Vec::with_capacity(1024); loop { STORAGE_EVENT_LOOP.increment(); @@ -92,33 +101,35 @@ where self.storage.expire(); // get events with timeout - if self.poll.poll(&mut events, self.timeout).is_err() { + if self.poll.poll(&mut events, Some(self.timeout)).is_err() { error!("Error polling"); } let timestamp = Instant::now(); if !events.is_empty() { + self.waker.reset(); + trace!("handling events"); - self.storage_queue.try_recv_all(&mut requests); + self.data_queue.try_recv_all(&mut messages); - STORAGE_QUEUE_DEPTH.increment(timestamp, requests.len() as _, 1); + STORAGE_QUEUE_DEPTH.increment(timestamp, messages.len() as _, 1); - for request in requests.drain(..) { - let sender = request.sender(); - let request = request.into_inner(); + for message in messages.drain(..) 
{ + let sender = message.sender(); + let (request, token) = message.into_inner(); trace!("handling request from worker: {}", sender); + let response = self.storage.execute(&request); PROCESS_REQ.increment(); - let token = request.token(); - let response = self.storage.execute(request.into_inner()); - let mut message = TokenWrapper::new(response, token); + let mut message = (request, response, token); for retry in 0..QUEUE_RETRIES { - if let Err(m) = self.storage_queue.try_send_to(sender, message) { + if let Err(m) = self.data_queue.try_send_to(sender, message) { if (retry + 1) == QUEUE_RETRIES { error!("error sending message to worker"); } - let _ = self.storage_queue.wake(); + // wake workers immediately + let _ = self.data_queue.wake(); message = m; } else { break; @@ -126,7 +137,7 @@ where } } - let _ = self.storage_queue.wake(); + let _ = self.data_queue.wake(); // check if we received any signals from the admin thread while let Some(s) = self.signal_queue.try_recv().map(|v| v.into_inner()) { diff --git a/src/core/waker/Cargo.toml b/src/core/waker/Cargo.toml new file mode 100644 index 000000000..c7096142c --- /dev/null +++ b/src/core/waker/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "waker" +version = "0.1.0" +edition = "2021" +authors = ["Brian Martin "] +homepage = "https://pelikan.io" +repository = "https://github.com/twitter/pelikan" +license = "Apache-2.0" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +mio = "0.8.4" + +[target.'cfg(target_os = "linux")'.dependencies] +libc = "0.2.126" diff --git a/src/core/waker/src/lib.rs b/src/core/waker/src/lib.rs new file mode 100644 index 000000000..be0a9df75 --- /dev/null +++ b/src/core/waker/src/lib.rs @@ -0,0 +1,147 @@ +// Copyright 2022 Twitter, Inc. +// Licensed under the Apache License, Version 2.0 +// http://www.apache.org/licenses/LICENSE-2.0 + +//! Provides a `Waker` trait to allow using the `Waker` from `mio` or a provided +//! 
`Waker` that uses eventfd directly (supported only on linux) interchangably. +//! +//! This is particularly useful in cases where some struct (such as a queue) may +//! be used with either `mio`-based event loops, or with io_uring. The `Waker` +//! provided by `mio` is not directly usable in io_uring based code due to the +//! fact that it must be registered to an event loop (such as epoll). + +use core::sync::atomic::{AtomicU64, Ordering}; + +pub struct Waker { + inner: Box, + pending: AtomicU64, +} + +impl From for Waker { + fn from(other: MioWaker) -> Self { + Self { + inner: Box::new(other), + pending: AtomicU64::new(0), + } + } +} + +impl Waker { + pub fn wake(&self) -> std::io::Result<()> { + if self.pending.fetch_add(1, Ordering::Relaxed) == 0 { + self.inner.wake() + } else { + Ok(()) + } + } + + pub fn as_raw_fd(&self) -> Option { + self.inner.as_raw_fd() + } + + pub fn reset(&self) { + self.pending.store(0, Ordering::Relaxed); + } +} + +pub trait GenericWaker: Send + Sync { + fn wake(&self) -> std::io::Result<()>; + + fn as_raw_fd(&self) -> Option; +} + +use std::os::unix::prelude::RawFd; + +pub use mio::Waker as MioWaker; + +impl GenericWaker for MioWaker { + fn wake(&self) -> std::io::Result<()> { + self.wake() + } + + fn as_raw_fd(&self) -> Option { + None + } +} + +#[cfg(target_os = "linux")] +pub use self::eventfd::EventfdWaker; + +#[cfg(target_os = "linux")] +mod eventfd { + use crate::*; + use std::fs::File; + use std::io::{ErrorKind, Result, Write}; + use std::os::unix::io::{AsRawFd, FromRawFd}; + use std::os::unix::prelude::RawFd; + + pub struct EventfdWaker { + inner: File, + } + + // a simple eventfd waker. 
based off the implementation in mio + impl EventfdWaker { + pub fn new() -> Result { + let ret = unsafe { libc::eventfd(0, libc::EFD_CLOEXEC | libc::EFD_NONBLOCK) }; + if ret < 0 { + Err(std::io::Error::new( + ErrorKind::Other, + "failed to create eventfd", + )) + } else { + Ok(Self { + inner: unsafe { File::from_raw_fd(ret) }, + }) + } + } + + pub fn wake(&self) -> Result<()> { + match (&self.inner).write(&[1, 0, 0, 0, 0, 0, 0, 0]) { + Ok(_) => Ok(()), + Err(e) => { + if e.kind() == ErrorKind::WouldBlock { + // writing blocks if the counter would overflow, reset it + // and wake again + self.reset()?; + self.wake() + } else { + Err(e) + } + } + } + } + + fn reset(&self) -> Result<()> { + match (&self.inner).write(&[0, 0, 0, 0, 0, 0, 0, 0]) { + Ok(_) => Ok(()), + Err(e) => { + if e.kind() == ErrorKind::WouldBlock { + // we can ignore wouldblock during reset + Ok(()) + } else { + Err(e) + } + } + } + } + } + + impl GenericWaker for EventfdWaker { + fn wake(&self) -> Result<()> { + self.wake() + } + + fn as_raw_fd(&self) -> Option { + Some(self.inner.as_raw_fd()) + } + } + + impl From for Waker { + fn from(other: EventfdWaker) -> Self { + Self { + inner: Box::new(other), + pending: AtomicU64::new(0), + } + } + } +} diff --git a/src/entrystore/src/noop/ping.rs b/src/entrystore/src/noop/ping.rs index 4be142d18..3870d4547 100644 --- a/src/entrystore/src/noop/ping.rs +++ b/src/entrystore/src/noop/ping.rs @@ -11,11 +11,9 @@ use protocol_ping::*; impl PingStorage for Noop {} impl Execute for Noop { - fn execute(&mut self, request: Request) -> Box> { - let response = match request { + fn execute(&mut self, request: &Request) -> Response { + match request { Request::Ping => Response::Pong, - }; - - Box::new(PingExecutionResult::new(request, response)) + } } } diff --git a/src/entrystore/src/seg/memcache.rs b/src/entrystore/src/seg/memcache.rs index ee61c701e..249a72a77 100644 --- a/src/entrystore/src/seg/memcache.rs +++ b/src/entrystore/src/seg/memcache.rs @@ -13,24 +13,22 
@@ use protocol_memcache::*; use std::time::Duration; impl Execute for Seg { - fn execute(&mut self, request: Request) -> Box> { - let response = match request { - Request::Get(ref get) => self.get(&get), - Request::Gets(ref gets) => self.gets(&gets), - Request::Set(ref set) => self.set(&set), - Request::Add(ref add) => self.add(&add), - Request::Replace(ref replace) => self.replace(&replace), - Request::Cas(ref cas) => self.cas(&cas), - Request::Incr(ref incr) => self.incr(&incr), - Request::Decr(ref decr) => self.decr(&decr), - Request::Append(ref append) => self.append(&append), - Request::Prepend(ref prepend) => self.prepend(&prepend), - Request::Delete(ref delete) => self.delete(&delete), - Request::FlushAll(ref flush_all) => self.flush_all(&flush_all), - Request::Quit(ref quit) => self.quit(&quit), - }; - - Box::new(MemcacheExecutionResult::new(request, response)) + fn execute(&mut self, request: &Request) -> Response { + match request { + Request::Get(get) => self.get(get), + Request::Gets(gets) => self.gets(gets), + Request::Set(set) => self.set(set), + Request::Add(add) => self.add(add), + Request::Replace(replace) => self.replace(replace), + Request::Cas(cas) => self.cas(cas), + Request::Incr(incr) => self.incr(incr), + Request::Decr(decr) => self.decr(decr), + Request::Append(append) => self.append(append), + Request::Prepend(prepend) => self.prepend(prepend), + Request::Delete(delete) => self.delete(delete), + Request::FlushAll(flush_all) => self.flush_all(flush_all), + Request::Quit(quit) => self.quit(quit), + } } } @@ -50,10 +48,12 @@ impl Storage for Seg { item.key(), flags, None, - &format!("{}", v).as_bytes(), + format!("{}", v).as_bytes(), )); } } + } else { + values.push(Value::none(key)); } } Values::new(values.into_boxed_slice()).into() @@ -74,58 +74,29 @@ impl Storage for Seg { item.key(), flags, Some(item.cas().into()), - &format!("{}", v).as_bytes(), + format!("{}", v).as_bytes(), )); } } + } else { + values.push(Value::none(key)); } } 
Values::new(values.into_boxed_slice()).into() } fn set(&mut self, set: &Set) -> Response { - if let Some(0) = set.ttl() { + if set.ttl() == Some(0) { + // immediate expire maps to a delete self.data.delete(set.key()); Response::stored(set.noreply()) - } else { - if let Ok(s) = std::str::from_utf8(set.value()) { - if let Ok(v) = s.parse::() { - if self - .data - .insert( - set.key(), - v, - Some(&set.flags().to_be_bytes()), - Duration::from_secs(set.ttl().unwrap_or(0).into()), - ) - .is_ok() - { - Response::stored(set.noreply()) - } else { - Response::server_error("") - } - } else { - if self - .data - .insert( - set.key(), - set.value(), - Some(&set.flags().to_be_bytes()), - Duration::from_secs(set.ttl().unwrap_or(0).into()), - ) - .is_ok() - { - Response::stored(set.noreply()) - } else { - Response::server_error("") - } - } - } else { + } else if let Ok(s) = std::str::from_utf8(set.value()) { + if let Ok(v) = s.parse::() { if self .data .insert( set.key(), - set.value(), + v, Some(&set.flags().to_be_bytes()), Duration::from_secs(set.ttl().unwrap_or(0).into()), ) @@ -135,7 +106,33 @@ impl Storage for Seg { } else { Response::server_error("") } + } else if self + .data + .insert( + set.key(), + set.value(), + Some(&set.flags().to_be_bytes()), + Duration::from_secs(set.ttl().unwrap_or(0).into()), + ) + .is_ok() + { + Response::stored(set.noreply()) + } else { + Response::server_error("") } + } else if self + .data + .insert( + set.key(), + set.value(), + Some(&set.flags().to_be_bytes()), + Duration::from_secs(set.ttl().unwrap_or(0).into()), + ) + .is_ok() + { + Response::stored(set.noreply()) + } else { + Response::server_error("") } } @@ -144,48 +141,17 @@ impl Storage for Seg { return Response::not_stored(add.noreply()); } - if let Some(0) = add.ttl() { + if add.ttl() == Some(0) { + // immediate expire maps to a delete self.data.delete(add.key()); Response::stored(add.noreply()) - } else { - if let Ok(s) = std::str::from_utf8(add.value()) { - if let Ok(v) = 
s.parse::() { - if self - .data - .insert( - add.key(), - v, - Some(&add.flags().to_be_bytes()), - Duration::from_secs(add.ttl().unwrap_or(0).into()), - ) - .is_ok() - { - Response::stored(add.noreply()) - } else { - Response::server_error("") - } - } else { - if self - .data - .insert( - add.key(), - add.value(), - Some(&add.flags().to_be_bytes()), - Duration::from_secs(add.ttl().unwrap_or(0).into()), - ) - .is_ok() - { - Response::stored(add.noreply()) - } else { - Response::server_error("") - } - } - } else { + } else if let Ok(s) = std::str::from_utf8(add.value()) { + if let Ok(v) = s.parse::() { if self .data .insert( add.key(), - add.value(), + v, Some(&add.flags().to_be_bytes()), Duration::from_secs(add.ttl().unwrap_or(0).into()), ) @@ -195,7 +161,33 @@ impl Storage for Seg { } else { Response::server_error("") } + } else if self + .data + .insert( + add.key(), + add.value(), + Some(&add.flags().to_be_bytes()), + Duration::from_secs(add.ttl().unwrap_or(0).into()), + ) + .is_ok() + { + Response::stored(add.noreply()) + } else { + Response::server_error("") } + } else if self + .data + .insert( + add.key(), + add.value(), + Some(&add.flags().to_be_bytes()), + Duration::from_secs(add.ttl().unwrap_or(0).into()), + ) + .is_ok() + { + Response::stored(add.noreply()) + } else { + Response::server_error("") } } @@ -204,48 +196,17 @@ impl Storage for Seg { return Response::not_stored(replace.noreply()); } - if let Some(0) = replace.ttl() { + if replace.ttl() == Some(0) { + // immediate expire maps to a delete self.data.delete(replace.key()); Response::stored(replace.noreply()) - } else { - if let Ok(s) = std::str::from_utf8(replace.value()) { - if let Ok(v) = s.parse::() { - if self - .data - .insert( - replace.key(), - v, - Some(&replace.flags().to_be_bytes()), - Duration::from_secs(replace.ttl().unwrap_or(0).into()), - ) - .is_ok() - { - Response::stored(replace.noreply()) - } else { - Response::server_error("") - } - } else { - if self - .data - .insert( - 
replace.key(), - replace.value(), - Some(&replace.flags().to_be_bytes()), - Duration::from_secs(replace.ttl().unwrap_or(0).into()), - ) - .is_ok() - { - Response::stored(replace.noreply()) - } else { - Response::server_error("") - } - } - } else { + } else if let Ok(s) = std::str::from_utf8(replace.value()) { + if let Ok(v) = s.parse::() { if self .data .insert( replace.key(), - replace.value(), + v, Some(&replace.flags().to_be_bytes()), Duration::from_secs(replace.ttl().unwrap_or(0).into()), ) @@ -255,7 +216,33 @@ impl Storage for Seg { } else { Response::server_error("") } + } else if self + .data + .insert( + replace.key(), + replace.value(), + Some(&replace.flags().to_be_bytes()), + Duration::from_secs(replace.ttl().unwrap_or(0).into()), + ) + .is_ok() + { + Response::stored(replace.noreply()) + } else { + Response::server_error("") } + } else if self + .data + .insert( + replace.key(), + replace.value(), + Some(&replace.flags().to_be_bytes()), + Duration::from_secs(replace.ttl().unwrap_or(0).into()), + ) + .is_ok() + { + Response::stored(replace.noreply()) + } else { + Response::server_error("") } } diff --git a/src/logger/src/lib.rs b/src/logger/src/lib.rs index 42fec2461..46317d123 100644 --- a/src/logger/src/lib.rs +++ b/src/logger/src/lib.rs @@ -52,6 +52,12 @@ macro_rules! 
klog { ) } +pub trait Klog { + type Response; + + fn klog(&self, response: &Self::Response); +} + pub fn configure_logging(config: &T) -> Box { let debug_config = config.debug(); diff --git a/src/net/Cargo.toml b/src/net/Cargo.toml new file mode 100644 index 000000000..8d85c415d --- /dev/null +++ b/src/net/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "net" +version = "0.1.0" +edition = "2021" +authors = ["Brian Martin "] +description = "Networking abstractions for non-blocking event loops" +homepage = "https://pelikan.io" +repository = "https://github.com/twitter/pelikan" +license = "Apache-2.0" + + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +boring = "2.0.0" +boring-sys = "2.0.0" +foreign-types-shared = "0.3.1" +libc = "0.2" +mio = { version = "0.8.0", features = ["os-poll", "net"] } +rustcommon-metrics = { git = "https://github.com/twitter/rustcommon" } diff --git a/src/net/src/connector.rs b/src/net/src/connector.rs new file mode 100644 index 000000000..c49f1ee81 --- /dev/null +++ b/src/net/src/connector.rs @@ -0,0 +1,40 @@ +// Copyright 2022 Twitter, Inc. +// Licensed under the Apache License, Version 2.0 +// http://www.apache.org/licenses/LICENSE-2.0 + +use crate::*; + +pub struct Connector { + inner: ConnectorType, +} + +enum ConnectorType { + Tcp(TcpConnector), + TlsTcp(TlsTcpConnector), +} + +impl Connector { + /// Attemps to connect to the provided address. 
+ pub fn connect(&self, addr: A) -> Result { + match &self.inner { + ConnectorType::Tcp(connector) => Ok(Stream::from(connector.connect(addr)?)), + ConnectorType::TlsTcp(connector) => Ok(Stream::from(connector.connect(addr)?)), + } + } +} + +impl From for Connector { + fn from(other: TcpConnector) -> Self { + Self { + inner: ConnectorType::Tcp(other), + } + } +} + +impl From for Connector { + fn from(other: TlsTcpConnector) -> Self { + Self { + inner: ConnectorType::TlsTcp(other), + } + } +} diff --git a/src/net/src/lib.rs b/src/net/src/lib.rs new file mode 100644 index 000000000..68c0f95a4 --- /dev/null +++ b/src/net/src/lib.rs @@ -0,0 +1,67 @@ +// Copyright 2022 Twitter, Inc. +// Licensed under the Apache License, Version 2.0 +// http://www.apache.org/licenses/LICENSE-2.0 + +mod connector; +mod listener; +mod stream; +mod tcp; +mod tls_tcp; + +pub use connector::*; +pub use listener::*; +pub use stream::*; +pub use tcp::*; +pub use tls_tcp::*; + +pub mod event { + pub use mio::event::*; +} + +pub use mio::*; + +use core::fmt::Debug; +use core::ops::Deref; +use std::io::{Error, ErrorKind, Read, Write}; +use std::net::{SocketAddr, ToSocketAddrs}; +use std::path::{Path, PathBuf}; + +use foreign_types_shared::{ForeignType, ForeignTypeRef}; +use rustcommon_metrics::*; + +type Result = std::io::Result; + +// stats + +counter!( + TCP_ACCEPT, + "number of TCP streams passively opened with accept" +); +counter!( + TCP_CONNECT, + "number of TCP streams actively opened with connect" +); +counter!(TCP_CLOSE, "number of TCP streams closed"); +gauge!(TCP_CONN_CURR, "current number of open TCP streams"); +counter!(TCP_RECV_BYTE, "number of bytes received on TCP streams"); +counter!(TCP_SEND_BYTE, "number of bytes sent on TCP streams"); + +counter!(STREAM_ACCEPT, "number of calls to accept"); +counter!( + STREAM_ACCEPT_EX, + "number of times calling accept resulted in an exception" +); +counter!(STREAM_CLOSE, "number of streams closed"); +counter!( + STREAM_HANDSHAKE, + "number 
of times stream handshaking was attempted" +); +counter!( + STREAM_HANDSHAKE_EX, + "number of exceptions while handshaking" +); +counter!(STREAM_SHUTDOWN, "number of streams gracefully shutdown"); +counter!( + STREAM_SHUTDOWN_EX, + "number of exceptions while attempting to gracefully shutdown a stream" +); diff --git a/src/net/src/listener.rs b/src/net/src/listener.rs new file mode 100644 index 000000000..e5e9f1f86 --- /dev/null +++ b/src/net/src/listener.rs @@ -0,0 +1,105 @@ +// Copyright 2022 Twitter, Inc. +// Licensed under the Apache License, Version 2.0 +// http://www.apache.org/licenses/LICENSE-2.0 + +use crate::*; + +pub struct Listener { + inner: ListenerType, +} + +enum ListenerType { + Plain(TcpListener), + Tls((TcpListener, TlsTcpAcceptor)), +} + +impl From for Listener { + fn from(other: TcpListener) -> Self { + Self { + inner: ListenerType::Plain(other), + } + } +} + +impl From<(TcpListener, TlsTcpAcceptor)> for Listener { + fn from(other: (TcpListener, TlsTcpAcceptor)) -> Self { + Self { + inner: ListenerType::Tls(other), + } + } +} + +impl Listener { + /// Accepts a new `Stream`. + /// + /// An error `e` with `e.kind()` of `ErrorKind::WouldBlock` indicates that + /// the operation should be retried again in the future. + /// + /// All other errors should be treated as failures. 
+ pub fn accept(&self) -> Result { + STREAM_ACCEPT.increment(); + let result = self._accept(); + if result.is_err() { + STREAM_ACCEPT_EX.increment(); + } + result + } + + fn _accept(&self) -> Result { + match &self.inner { + ListenerType::Plain(listener) => { + let (stream, _addr) = listener.accept()?; + Ok(Stream::from(stream)) + } + ListenerType::Tls((listener, acceptor)) => { + let (stream, _addr) = listener.accept()?; + let stream = acceptor.accept(stream)?; + Ok(Stream::from(stream)) + } + } + } + + pub fn local_addr(&self) -> Result { + match &self.inner { + ListenerType::Plain(listener) => listener.local_addr(), + ListenerType::Tls((listener, _acceptor)) => listener.local_addr(), + } + } +} + +impl event::Source for Listener { + fn register( + &mut self, + registry: &mio::Registry, + token: mio::Token, + interests: mio::Interest, + ) -> Result<()> { + match &mut self.inner { + ListenerType::Plain(listener) => listener.register(registry, token, interests), + ListenerType::Tls((listener, _acceptor)) => { + listener.register(registry, token, interests) + } + } + } + + fn reregister( + &mut self, + registry: &mio::Registry, + token: mio::Token, + interests: mio::Interest, + ) -> Result<()> { + match &mut self.inner { + ListenerType::Plain(listener) => listener.reregister(registry, token, interests), + ListenerType::Tls((listener, _acceptor)) => { + listener.reregister(registry, token, interests) + } + } + } + + fn deregister(&mut self, registry: &mio::Registry) -> Result<()> { + match &mut self.inner { + ListenerType::Plain(listener) => listener.deregister(registry), + ListenerType::Tls((listener, _acceptor)) => listener.deregister(registry), + } + } +} diff --git a/src/net/src/stream.rs b/src/net/src/stream.rs new file mode 100644 index 000000000..f232bae79 --- /dev/null +++ b/src/net/src/stream.rs @@ -0,0 +1,175 @@ +// Copyright 2022 Twitter, Inc. 
+// Licensed under the Apache License, Version 2.0 +// http://www.apache.org/licenses/LICENSE-2.0 + +pub use std::net::Shutdown; +use std::os::unix::prelude::AsRawFd; + +use crate::*; + +/// A wrapper type that unifies types which represent a stream. For example, +/// plaintext TCP streams and TLS/SSL over TCP can both be wrapped by this type. +/// This allows dynamic behaviors at runtime, such as enabling TLS/SSL through +/// configuration or allowing clients to request an upgrade to TLS/SSL from a +/// plaintext stream. +pub struct Stream { + inner: StreamType, +} + +impl AsRawFd for Stream { + fn as_raw_fd(&self) -> i32 { + match &self.inner { + StreamType::Tcp(s) => s.as_raw_fd(), + StreamType::TlsTcp(s) => s.as_raw_fd(), + } + } +} + +impl Stream { + pub fn interest(&mut self) -> Interest { + match &mut self.inner { + StreamType::Tcp(s) => { + if !s.is_established() { + Interest::READABLE.add(Interest::WRITABLE) + } else { + Interest::READABLE + } + } + StreamType::TlsTcp(s) => s.interest(), + } + } + + pub fn is_established(&mut self) -> bool { + match &mut self.inner { + StreamType::Tcp(s) => s.is_established(), + StreamType::TlsTcp(s) => !s.is_handshaking(), + } + } + + pub fn is_handshaking(&self) -> bool { + match &self.inner { + StreamType::Tcp(_) => false, + StreamType::TlsTcp(s) => s.is_handshaking(), + } + } + + pub fn do_handshake(&mut self) -> Result<()> { + match &mut self.inner { + StreamType::Tcp(_) => Ok(()), + StreamType::TlsTcp(s) => s.do_handshake(), + } + } + + pub fn set_nodelay(&mut self, nodelay: bool) -> Result<()> { + match &mut self.inner { + StreamType::Tcp(s) => s.set_nodelay(nodelay), + StreamType::TlsTcp(s) => s.set_nodelay(nodelay), + } + } + + pub fn shutdown(&mut self) -> Result { + let result = match &mut self.inner { + StreamType::Tcp(s) => s.shutdown(Shutdown::Both).map(|_| true), + StreamType::TlsTcp(s) => s.shutdown().map(|v| v == ShutdownResult::Received), + }; + + STREAM_SHUTDOWN.increment(); + if result.is_err() { + 
STREAM_SHUTDOWN_EX.increment(); + } + + result + } +} + +impl Drop for Stream { + fn drop(&mut self) { + STREAM_CLOSE.increment(); + } +} + +impl Debug for Stream { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> { + match &self.inner { + StreamType::Tcp(s) => write!(f, "{:?}", s), + StreamType::TlsTcp(s) => write!(f, "{:?}", s), + } + } +} + +impl From for Stream { + fn from(other: TcpStream) -> Self { + Self { + inner: StreamType::Tcp(other), + } + } +} + +impl From for Stream { + fn from(other: TlsTcpStream) -> Self { + Self { + inner: StreamType::TlsTcp(other), + } + } +} + +impl Read for Stream { + fn read(&mut self, buf: &mut [u8]) -> Result { + match &mut self.inner { + StreamType::Tcp(s) => s.read(buf), + StreamType::TlsTcp(s) => s.read(buf), + } + } +} + +impl Write for Stream { + fn write(&mut self, buf: &[u8]) -> Result { + match &mut self.inner { + StreamType::Tcp(s) => s.write(buf), + StreamType::TlsTcp(s) => s.write(buf), + } + } + + fn flush(&mut self) -> Result<()> { + match &mut self.inner { + StreamType::Tcp(s) => s.flush(), + StreamType::TlsTcp(s) => s.flush(), + } + } +} + +impl event::Source for Stream { + fn register(&mut self, registry: &Registry, token: Token, interest: Interest) -> Result<()> { + match &mut self.inner { + StreamType::Tcp(s) => s.register(registry, token, interest), + StreamType::TlsTcp(s) => s.register(registry, token, interest), + } + } + + fn reregister( + &mut self, + registry: &mio::Registry, + token: mio::Token, + interest: mio::Interest, + ) -> Result<()> { + match &mut self.inner { + StreamType::Tcp(s) => s.reregister(registry, token, interest), + StreamType::TlsTcp(s) => s.reregister(registry, token, interest), + } + } + + fn deregister(&mut self, registry: &mio::Registry) -> Result<()> { + match &mut self.inner { + StreamType::Tcp(s) => s.deregister(registry), + StreamType::TlsTcp(s) => s.deregister(registry), + } + } +} + +/// Provides concrete types for stream 
variants. Since the number of variants is +/// expected to be small, dispatch through enum variants should be more +/// efficient than using a trait for dynamic dispatch. +enum StreamType { + Tcp(TcpStream), + TlsTcp(TlsTcpStream), +} diff --git a/src/net/src/tcp.rs b/src/net/src/tcp.rs new file mode 100644 index 000000000..4bf5aa3e4 --- /dev/null +++ b/src/net/src/tcp.rs @@ -0,0 +1,320 @@ +// Copyright 2022 Twitter, Inc. +// Licensed under the Apache License, Version 2.0 +// http://www.apache.org/licenses/LICENSE-2.0 + +use crate::*; +use std::os::unix::prelude::FromRawFd; + +pub use std::net::Shutdown; + +#[derive(PartialEq)] +enum State { + Connecting, + Established, +} + +pub struct TcpStream { + inner: mio::net::TcpStream, + state: State, +} + +impl TcpStream { + pub fn connect(addr: SocketAddr) -> Result { + let inner = mio::net::TcpStream::connect(addr)?; + + TCP_CONN_CURR.increment(); + TCP_CONNECT.increment(); + + Ok(Self { + inner, + state: State::Connecting, + }) + } + + pub fn is_established(&mut self) -> bool { + if self.state == State::Established { + true + } else if self.inner.peer_addr().is_ok() { + self.state = State::Established; + true + } else { + false + } + } + + pub fn from_std(stream: std::net::TcpStream) -> Self { + let inner = mio::net::TcpStream::from_std(stream); + let state = if inner.peer_addr().is_ok() { + State::Established + } else { + State::Connecting + }; + + Self { inner, state } + } + + pub fn set_nodelay(&mut self, nodelay: bool) -> Result<()> { + self.inner.set_nodelay(nodelay) + } +} + +impl Drop for TcpStream { + fn drop(&mut self) { + TCP_CONN_CURR.decrement(); + TCP_CLOSE.increment(); + } +} + +impl Debug for TcpStream { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> { + write!(f, "{:?}", self.inner) + } +} + +impl Deref for TcpStream { + type Target = mio::net::TcpStream; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl Read for TcpStream { + fn read(&mut 
self, buf: &mut [u8]) -> Result { + match self.inner.read(buf) { + Ok(amt) => { + TCP_RECV_BYTE.add(amt as _); + Ok(amt) + } + Err(e) => Err(e), + } + } +} + +impl Write for TcpStream { + fn write(&mut self, buf: &[u8]) -> Result { + match self.inner.write(buf) { + Ok(amt) => { + TCP_SEND_BYTE.add(amt as _); + Ok(amt) + } + Err(e) => Err(e), + } + } + + fn flush(&mut self) -> Result<()> { + self.inner.flush() + } +} + +impl event::Source for TcpStream { + fn register( + &mut self, + registry: &mio::Registry, + token: mio::Token, + interest: mio::Interest, + ) -> Result<()> { + self.inner.register(registry, token, interest) + } + + fn reregister( + &mut self, + registry: &mio::Registry, + token: mio::Token, + interest: mio::Interest, + ) -> Result<()> { + self.inner.reregister(registry, token, interest) + } + + fn deregister(&mut self, registry: &mio::Registry) -> Result<()> { + self.inner.deregister(registry) + } +} + +impl FromRawFd for TcpStream { + unsafe fn from_raw_fd(raw_fd: i32) -> Self { + let inner = mio::net::TcpStream::from_raw_fd(raw_fd); + let state = if inner.peer_addr().is_ok() { + State::Established + } else { + State::Connecting + }; + + Self { inner, state } + } +} + +pub struct TcpListener { + inner: mio::net::TcpListener, +} + +impl Deref for TcpListener { + type Target = mio::net::TcpListener; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl TcpListener { + pub fn bind(addr: A) -> Result { + // we create from a std TcpListener so SO_REUSEADDR is not set for us + let l = std::net::TcpListener::bind(addr)?; + // this means we need to set non-blocking ourselves + l.set_nonblocking(true)?; + + let inner = mio::net::TcpListener::from_std(l); + + Ok(Self { inner }) + } + + pub fn accept(&self) -> Result<(TcpStream, SocketAddr)> { + let result = self.inner.accept().map(|(stream, addr)| { + ( + TcpStream { + inner: stream, + state: State::Established, + }, + addr, + ) + }); + + if result.is_ok() { + TCP_ACCEPT.increment(); + 
TCP_CONN_CURR.increment(); + } + + result + } + + pub fn local_addr(&self) -> Result { + self.inner.local_addr() + } +} + +impl event::Source for TcpListener { + fn register( + &mut self, + registry: &mio::Registry, + token: mio::Token, + interests: mio::Interest, + ) -> Result<()> { + self.inner.register(registry, token, interests) + } + + fn reregister( + &mut self, + registry: &mio::Registry, + token: mio::Token, + interests: mio::Interest, + ) -> Result<()> { + self.inner.reregister(registry, token, interests) + } + + fn deregister(&mut self, registry: &mio::Registry) -> Result<()> { + self.inner.deregister(registry) + } +} + +#[derive(Default)] +pub struct TcpConnector { + _inner: (), +} + +impl TcpConnector { + pub fn new() -> Self { + Self::default() + } + + pub fn connect(&self, addr: A) -> Result { + let addrs: Vec = addr.to_socket_addrs()?.collect(); + let mut s = Err(Error::new(ErrorKind::Other, "failed to resolve")); + for addr in addrs { + s = TcpStream::connect(addr); + if s.is_ok() { + break; + } + } + + s + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn create_connector() -> Connector { + let tls_connector = TcpConnector::new(); + + Connector::from(tls_connector) + } + + fn create_listener(addr: &'static str) -> Listener { + let tcp_listener = TcpListener::bind(addr).expect("failed to bind"); + + Listener::from(tcp_listener) + } + + #[test] + fn listener() { + let _ = create_listener("127.0.0.1:0"); + } + + #[test] + fn connector() { + let _ = create_connector(); + } + + #[test] + fn ping_pong() { + let connector = create_connector(); + let listener = create_listener("127.0.0.1:0"); + + let addr = listener.local_addr().expect("listener has no local addr"); + + let mut client_stream = connector.connect(addr).expect("failed to connect"); + std::thread::sleep(std::time::Duration::from_millis(100)); + let mut server_stream = listener.accept().expect("failed to accept"); + + std::thread::sleep(std::time::Duration::from_millis(100)); + + 
client_stream + .write_all(b"PING\r\n") + .expect("failed to write"); + client_stream.flush().expect("failed to flush"); + + std::thread::sleep(std::time::Duration::from_millis(100)); + + let mut buf = [0; 4096]; + + match server_stream.read(&mut buf) { + Ok(6) => { + assert_eq!(&buf[0..6], b"PING\r\n"); + server_stream + .write_all(b"PONG\r\n") + .expect("failed to write"); + } + Ok(n) => { + panic!("read: {} bytes but expected 6", n); + } + Err(e) => { + panic!("error reading: {}", e); + } + } + + std::thread::sleep(std::time::Duration::from_millis(100)); + + match client_stream.read(&mut buf) { + Ok(6) => { + assert_eq!(&buf[0..6], b"PONG\r\n"); + } + Ok(n) => { + panic!("read: {} bytes but expected 6", n); + } + Err(e) => { + panic!("error reading: {}", e); + } + } + } +} diff --git a/src/net/src/tls_tcp.rs b/src/net/src/tls_tcp.rs new file mode 100644 index 000000000..df5c76ea9 --- /dev/null +++ b/src/net/src/tls_tcp.rs @@ -0,0 +1,685 @@ +// Copyright 2022 Twitter, Inc. +// Licensed under the Apache License, Version 2.0 +// http://www.apache.org/licenses/LICENSE-2.0 + +pub use boring::ssl::{ShutdownResult, SslVerifyMode}; +use std::os::unix::prelude::AsRawFd; + +use boring::ssl::{ErrorCode, Ssl, SslFiletype, SslMethod, SslStream}; +use boring::x509::X509; + +use crate::*; + +#[derive(PartialEq)] +enum TlsState { + Handshaking, + Negotiated, +} + +/// Wraps a TLS/SSL stream so that negotiated and handshaking sessions have a +/// uniform type. 
+pub struct TlsTcpStream { + inner: SslStream, + state: TlsState, +} + +impl AsRawFd for TlsTcpStream { + fn as_raw_fd(&self) -> i32 { + self.inner.get_ref().as_raw_fd() + } +} + +impl TlsTcpStream { + pub fn set_nodelay(&mut self, nodelay: bool) -> Result<()> { + self.inner.get_mut().set_nodelay(nodelay) + } + + pub fn is_handshaking(&self) -> bool { + self.state == TlsState::Handshaking + } + + pub fn interest(&self) -> Interest { + if self.is_handshaking() { + Interest::READABLE.add(Interest::WRITABLE) + } else { + Interest::READABLE + } + } + + /// Attempts to drive the TLS/SSL handshake to completion. If the return + /// variant is `Ok` it indiates that the handshake is complete. An error + /// result of `WouldBlock` indicates that the handshake may complete in the + /// future. Other error types indiate a handshake failure with no possible + /// recovery and that the connection should be closed. + pub fn do_handshake(&mut self) -> Result<()> { + if self.is_handshaking() { + let ptr = self.inner.ssl().as_ptr(); + let ret = unsafe { boring_sys::SSL_do_handshake(ptr) }; + if ret > 0 { + STREAM_HANDSHAKE.increment(); + self.state = TlsState::Negotiated; + Ok(()) + } else { + let code = unsafe { ErrorCode::from_raw(boring_sys::SSL_get_error(ptr, ret)) }; + match code { + ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => { + Err(Error::from(ErrorKind::WouldBlock)) + } + _ => { + STREAM_HANDSHAKE.increment(); + STREAM_HANDSHAKE_EX.increment(); + Err(Error::new(ErrorKind::Other, "handshake failed")) + } + } + } + } else { + Ok(()) + } + } + + pub fn shutdown(&mut self) -> Result { + self.inner + .shutdown() + .map_err(|e| Error::new(ErrorKind::Other, e.to_string())) + } +} + +impl Debug for TlsTcpStream { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> { + write!(f, "{:?}", self.inner.get_ref()) + } +} + +impl Read for TlsTcpStream { + fn read(&mut self, buf: &mut [u8]) -> Result { + if self.is_handshaking() { + 
Err(Error::new( + ErrorKind::WouldBlock, + "read on handshaking session would block", + )) + } else { + self.inner.read(buf) + } + } +} + +impl Write for TlsTcpStream { + fn write(&mut self, buf: &[u8]) -> Result { + if self.is_handshaking() { + Err(Error::new( + ErrorKind::WouldBlock, + "write on handshaking session would block", + )) + } else { + self.inner.write(buf) + } + } + + fn flush(&mut self) -> Result<()> { + if self.is_handshaking() { + Err(Error::new( + ErrorKind::WouldBlock, + "flush on handshaking session would block", + )) + } else { + self.inner.flush() + } + } +} + +impl event::Source for TlsTcpStream { + fn register(&mut self, registry: &Registry, token: Token, interest: Interest) -> Result<()> { + self.inner.get_mut().register(registry, token, interest) + } + + fn reregister( + &mut self, + registry: &mio::Registry, + token: mio::Token, + interest: mio::Interest, + ) -> Result<()> { + self.inner.get_mut().reregister(registry, token, interest) + } + + fn deregister(&mut self, registry: &mio::Registry) -> Result<()> { + self.inner.get_mut().deregister(registry) + } +} + +/// Provides a wrapped acceptor for server-side TLS. This returns our wrapped +/// `TlsStream` type so that clients can store negotiated and handshaking +/// streams in a structure with a uniform type. 
+pub struct TlsTcpAcceptor { + inner: boring::ssl::SslContext, +} + +impl TlsTcpAcceptor { + pub fn mozilla_intermediate_v5() -> Result { + let inner = boring::ssl::SslAcceptor::mozilla_intermediate_v5(SslMethod::tls_server()) + .map_err(|e| Error::new(ErrorKind::Other, e.to_string()))?; + + Ok(TlsTcpAcceptorBuilder { + inner, + ca_file: None, + certificate_file: None, + certificate_chain_file: None, + private_key_file: None, + }) + } + + pub fn accept(&self, stream: TcpStream) -> Result { + let ssl = Ssl::new(&self.inner)?; + + let stream = unsafe { SslStream::from_raw_parts(ssl.into_ptr(), stream) }; + + let ret = unsafe { boring_sys::SSL_accept(stream.ssl().as_ptr()) }; + + if ret > 0 { + Ok(TlsTcpStream { + inner: stream, + state: TlsState::Negotiated, + }) + } else { + let code = unsafe { + ErrorCode::from_raw(boring_sys::SSL_get_error(stream.ssl().as_ptr(), ret)) + }; + match code { + ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => Ok(TlsTcpStream { + inner: stream, + state: TlsState::Handshaking, + }), + _ => Err(Error::new(ErrorKind::Other, "handshake failed")), + } + } + } +} + +/// Provides a wrapped builder for producing a `TlsAcceptor`. This has some +/// minor differences from the `boring::ssl::SslAcceptorBuilder` to provide +/// improved ergonomics. 
+pub struct TlsTcpAcceptorBuilder { + inner: boring::ssl::SslAcceptorBuilder, + ca_file: Option, + certificate_file: Option, + certificate_chain_file: Option, + private_key_file: Option, +} + +impl TlsTcpAcceptorBuilder { + pub fn build(mut self) -> Result { + // load the CA file, if provided + if let Some(f) = self.ca_file { + self.inner.set_ca_file(f).map_err(|e| { + Error::new(ErrorKind::Other, format!("failed to load CA file: {}", e)) + })?; + } + + // load the private key from file + if let Some(f) = self.private_key_file { + self.inner + .set_private_key_file(f, SslFiletype::PEM) + .map_err(|e| { + Error::new( + ErrorKind::Other, + format!("failed to load private key file: {}", e), + ) + })?; + } else { + return Err(Error::new(ErrorKind::Other, "no private key file provided")); + } + + // load the certificate chain, certificate file, or both + match (self.certificate_chain_file, self.certificate_file) { + (Some(chain), Some(cert)) => { + // assume we have the leaf in a standalone file, and the + // intermediates + root in another file + + // first load the leaf + self.inner + .set_certificate_file(cert, SslFiletype::PEM) + .map_err(|e| { + Error::new( + ErrorKind::Other, + format!("failed to load certificate file: {}", e), + ) + })?; + + // append the rest of the chain + let pem = std::fs::read(chain).map_err(|e| { + Error::new( + ErrorKind::Other, + format!("failed to load certificate chain file: {}", e), + ) + })?; + let chain = X509::stack_from_pem(&pem).map_err(|e| { + Error::new( + ErrorKind::Other, + format!("failed to load certificate chain file: {}", e), + ) + })?; + for cert in chain { + self.inner.add_extra_chain_cert(cert).map_err(|e| { + Error::new( + ErrorKind::Other, + format!("bad certificate in certificate chain file: {}", e), + ) + })?; + } + } + (Some(chain), None) => { + // assume we have a complete chain: leaf + intermediates + root in + // one file + + // load the entire chain + self.inner.set_certificate_chain_file(chain).map_err(|e| { + 
Error::new( + ErrorKind::Other, + format!("failed to load certificate chain file: {}", e), + ) + })?; + } + (None, Some(cert)) => { + // this will just load the leaf certificate from the file + self.inner + .set_certificate_file(cert, SslFiletype::PEM) + .map_err(|e| { + Error::new( + ErrorKind::Other, + format!("failed to load certificate file: {}", e), + ) + })?; + } + (None, None) => { + return Err(Error::new( + ErrorKind::Other, + "no certificate file or certificate chain file provided", + )); + } + } + + let inner = self.inner.build().into_context(); + + Ok(TlsTcpAcceptor { inner }) + } + + pub fn verify(mut self, mode: SslVerifyMode) -> Self { + self.inner.set_verify(mode); + self + } + + /// Load trusted root certificates from a file. + /// + /// The file should contain a sequence of PEM-formatted CA certificates. + pub fn ca_file>(mut self, file: P) -> Self { + self.ca_file = Some(file.as_ref().to_path_buf()); + self + } + + /// Load a leaf certificate from a file. + /// + /// This loads only a single PEM-formatted certificate from the file which + /// will be used as the leaf certifcate. + /// + /// Use `set_certificate_chain_file` to provide a complete certificate + /// chain. Use this with the `set_certifcate_chain_file` if the leaf + /// certifcate and remainder of the certificate chain are split across two + /// files. + pub fn certificate_file>(mut self, file: P) -> Self { + self.certificate_file = Some(file.as_ref().to_path_buf()); + self + } + + /// Load a certificate chain from a file. + /// + /// The file should contain a sequence of PEM-formatted certificates. If + /// used without `set_certificate_file` the provided file must contain the + /// leaf certificate and the complete chain of certificates up to and + /// including the trusted root certificate. 
If used with + /// `set_certificate_file`, this file must not contain the leaf certifcate + /// and will be treated as the complete chain of certificates up to and + /// including the trusted root certificate. + pub fn certificate_chain_file>(mut self, file: P) -> Self { + self.certificate_chain_file = Some(file.as_ref().to_path_buf()); + self + } + + /// Loads the private key from a PEM-formatted file. + pub fn private_key_file>(mut self, file: P) -> Self { + self.private_key_file = Some(file.as_ref().to_path_buf()); + self + } +} + +/// Provides a wrapped connector for client-side TLS. This returns our wrapped +/// `TlsStream` type so that clients can store negotiated and handshaking +/// streams in a structure with a uniform type. +#[allow(dead_code)] +pub struct TlsTcpConnector { + inner: boring::ssl::SslContext, +} + +impl TlsTcpConnector { + pub fn builder() -> Result { + let inner = boring::ssl::SslConnector::builder(SslMethod::tls_client()) + .map_err(|e| Error::new(ErrorKind::Other, e.to_string()))?; + + Ok(TlsTcpConnectorBuilder { + inner, + ca_file: None, + certificate_file: None, + certificate_chain_file: None, + private_key_file: None, + }) + } + + pub fn connect(&self, addr: A) -> Result { + let addrs: Vec = addr.to_socket_addrs()?.collect(); + let mut s = Err(Error::new(ErrorKind::Other, "failed to resolve")); + for addr in addrs { + s = TcpStream::connect(addr); + if s.is_ok() { + break; + } + } + + let ssl = Ssl::new(&self.inner)?; + + let stream = unsafe { SslStream::from_raw_parts(ssl.into_ptr(), s?) 
}; + + let ret = unsafe { boring_sys::SSL_connect(stream.ssl().as_ptr()) }; + + if ret > 0 { + Ok(TlsTcpStream { + inner: stream, + state: TlsState::Negotiated, + }) + } else { + let code = unsafe { + ErrorCode::from_raw(boring_sys::SSL_get_error(stream.ssl().as_ptr(), ret)) + }; + match code { + ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => Ok(TlsTcpStream { + inner: stream, + state: TlsState::Handshaking, + }), + _ => Err(Error::new(ErrorKind::Other, "handshake failed")), + } + } + } +} + +/// Provides a wrapped builder for producing a `TlsConnector`. This has some +/// minor differences from the `boring::ssl::SslConnectorBuilder` to provide +/// improved ergonomics. +pub struct TlsTcpConnectorBuilder { + inner: boring::ssl::SslConnectorBuilder, + ca_file: Option, + certificate_file: Option, + certificate_chain_file: Option, + private_key_file: Option, +} + +impl TlsTcpConnectorBuilder { + pub fn build(mut self) -> Result { + // load the CA file, if provided + if let Some(f) = self.ca_file { + self.inner.set_ca_file(f).map_err(|e| { + Error::new(ErrorKind::Other, format!("failed to load CA file: {}", e)) + })?; + } + + // load the private key from file + if let Some(f) = self.private_key_file { + self.inner + .set_private_key_file(f, SslFiletype::PEM) + .map_err(|e| { + Error::new( + ErrorKind::Other, + format!("failed to load private key file: {}", e), + ) + })?; + } else { + return Err(Error::new(ErrorKind::Other, "no private key file provided")); + } + + // load the certificate chain, certificate file, or both + match (self.certificate_chain_file, self.certificate_file) { + (Some(chain), Some(cert)) => { + // assume we have the leaf in a standalone file, and the + // intermediates + root in another file + + // first load the leaf + self.inner + .set_certificate_file(cert, SslFiletype::PEM) + .map_err(|e| { + Error::new( + ErrorKind::Other, + format!("failed to load certificate file: {}", e), + ) + })?; + + // append the rest of the chain + let pem = 
std::fs::read(chain).map_err(|e| { + Error::new( + ErrorKind::Other, + format!("failed to load certificate chain file: {}", e), + ) + })?; + let chain = X509::stack_from_pem(&pem).map_err(|e| { + Error::new( + ErrorKind::Other, + format!("failed to load certificate chain file: {}", e), + ) + })?; + for cert in chain { + self.inner.add_extra_chain_cert(cert).map_err(|e| { + Error::new( + ErrorKind::Other, + format!("bad certificate in certificate chain file: {}", e), + ) + })?; + } + } + (Some(chain), None) => { + // assume we have a complete chain: leaf + intermediates + root in + // one file + + // load the entire chain + self.inner.set_certificate_chain_file(chain).map_err(|e| { + Error::new( + ErrorKind::Other, + format!("failed to load certificate chain file: {}", e), + ) + })?; + } + (None, Some(cert)) => { + // this will just load the leaf certificate from the file + self.inner + .set_certificate_file(cert, SslFiletype::PEM) + .map_err(|e| { + Error::new( + ErrorKind::Other, + format!("failed to load certificate file: {}", e), + ) + })?; + } + (None, None) => { + return Err(Error::new( + ErrorKind::Other, + "no certificate file or certificate chain file provided", + )); + } + } + + let inner = self.inner.build().into_context(); + + Ok(TlsTcpConnector { inner }) + } + + pub fn verify(mut self, mode: SslVerifyMode) -> Self { + self.inner.set_verify(mode); + self + } + + /// Load trusted root certificates from a file. + /// + /// The file should contain a sequence of PEM-formatted CA certificates. + pub fn ca_file>(mut self, file: P) -> Self { + self.ca_file = Some(file.as_ref().to_path_buf()); + self + } + + /// Load a leaf certificate from a file. + /// + /// This loads only a single PEM-formatted certificate from the file which + /// will be used as the leaf certifcate. + /// + /// Use `set_certificate_chain_file` to provide a complete certificate + /// chain. 
Use this with the `set_certifcate_chain_file` if the leaf + /// certifcate and remainder of the certificate chain are split across two + /// files. + pub fn certificate_file>(mut self, file: P) -> Self { + self.certificate_file = Some(file.as_ref().to_path_buf()); + self + } + + /// Load a certificate chain from a file. + /// + /// The file should contain a sequence of PEM-formatted certificates. If + /// used without `set_certificate_file` the provided file must contain the + /// leaf certificate and the complete chain of certificates up to and + /// including the trusted root certificate. If used with + /// `set_certificate_file`, this file must not contain the leaf certifcate + /// and will be treated as the complete chain of certificates up to and + /// including the trusted root certificate. + pub fn certificate_chain_file>(mut self, file: P) -> Self { + self.certificate_chain_file = Some(file.as_ref().to_path_buf()); + self + } + + /// Loads the private key from a PEM-formatted file. + pub fn private_key_file>(mut self, file: P) -> Self { + self.private_key_file = Some(file.as_ref().to_path_buf()); + self + } +} + +// NOTE: these tests only work if there's a `test` folder within this crate that +// contains the necessary keys and certs. They are left here for reference and +// in the future we should automate creation of self-signed keys and certs for +// use for testing during local development and in CI. 
+ +// #[cfg(test)] +// mod tests { +// use super::*; + +// fn gen_keys() -> Result<(), ()> { + +// } + +// fn create_connector() -> Connector { +// let tls_connector = TlsTcpConnector::builder() +// .expect("failed to create builder") +// .ca_file("test/root.crt") +// .certificate_chain_file("test/client.crt") +// .private_key_file("test/client.key") +// .build() +// .expect("failed to initialize tls connector"); + +// Connector::from(tls_connector) +// } + +// fn create_listener(addr: &'static str) -> Listener { +// let tcp_listener = TcpListener::bind(addr).expect("failed to bind"); +// let tls_acceptor = TlsTcpAcceptor::mozilla_intermediate_v5() +// .expect("failed to create builder") +// .ca_file("test/root.crt") +// .certificate_chain_file("test/server.crt") +// .private_key_file("test/server.key") +// .build() +// .expect("failed to initialize tls acceptor"); + +// Listener::from((tcp_listener, tls_acceptor)) +// } + +// #[test] +// fn listener() { +// let _ = create_listener("127.0.0.1:0"); +// } + +// #[test] +// fn connector() { +// let _ = create_connector(); +// } + +// #[test] +// fn ping_pong() { +// let connector = create_connector(); +// let listener = create_listener("127.0.0.1:0"); + +// let addr = listener.local_addr().expect("listener has no local addr"); + +// let mut client_stream = connector.connect(addr).expect("failed to connect"); +// std::thread::sleep(std::time::Duration::from_millis(100)); +// let mut server_stream = listener.accept().expect("failed to accept"); + +// let mut server_handshake_complete = false; +// let mut client_handshake_complete = false; + +// while !(server_handshake_complete && client_handshake_complete) { +// if !server_handshake_complete { +// std::thread::sleep(std::time::Duration::from_millis(100)); +// if server_stream.do_handshake().is_ok() { +// server_handshake_complete = true; +// } +// } + +// if !client_handshake_complete { +// std::thread::sleep(std::time::Duration::from_millis(100)); +// if 
client_stream.do_handshake().is_ok() { +// client_handshake_complete = true; +// } +// } +// } + +// std::thread::sleep(std::time::Duration::from_millis(100)); + +// client_stream +// .write_all(b"PING\r\n") +// .expect("failed to write"); +// client_stream.flush().expect("failed to flush"); + +// std::thread::sleep(std::time::Duration::from_millis(100)); + +// let mut buf = [0; 4096]; + +// match server_stream.read(&mut buf) { +// Ok(6) => { +// assert_eq!(&buf[0..6], b"PING\r\n"); +// server_stream +// .write_all(b"PONG\r\n") +// .expect("failed to write"); +// } +// Ok(n) => { +// panic!("read: {} bytes but expected 6", n); +// } +// Err(e) => { +// panic!("error reading: {}", e); +// } +// } + +// std::thread::sleep(std::time::Duration::from_millis(100)); + +// match client_stream.read(&mut buf) { +// Ok(6) => { +// assert_eq!(&buf[0..6], b"PONG\r\n"); +// } +// Ok(n) => { +// panic!("read: {} bytes but expected 6", n); +// } +// Err(e) => { +// panic!("error reading: {}", e); +// } +// } +// } +// } diff --git a/src/protocol/admin/Cargo.toml b/src/protocol/admin/Cargo.toml index 90b3df343..64e6b2dfa 100644 --- a/src/protocol/admin/Cargo.toml +++ b/src/protocol/admin/Cargo.toml @@ -13,7 +13,7 @@ common = { path = "../../common" } config = { path = "../../config" } logger = { path = "../../logger" } protocol-common = { path = "../../protocol/common" } -session = { path = "../../session" } +rustcommon-metrics = { git = "https://github.com/twitter/rustcommon", features = ["heatmap"] } storage-types = { path = "../../storage/types" } [dev-dependencies] diff --git a/src/protocol/admin/src/admin.rs b/src/protocol/admin/src/admin.rs index 63ccce82c..a311f9390 100644 --- a/src/protocol/admin/src/admin.rs +++ b/src/protocol/admin/src/admin.rs @@ -9,6 +9,9 @@ use crate::*; use common::bytes::SliceExtension; +use rustcommon_metrics::*; + +use std::io::{Error, ErrorKind, Result}; // TODO(bmartin): see TODO for protocol::data::Request, this is cleaner here // since the 
variants are simple, but better to take the same approach in both @@ -31,7 +34,7 @@ impl AdminRequestParser { } impl Parse for AdminRequestParser { - fn parse(&self, buffer: &[u8]) -> Result, ParseError> { + fn parse(&self, buffer: &[u8]) -> Result> { // check if we got a CRLF if let Some(command_end) = buffer .windows(CRLF.len()) @@ -47,7 +50,7 @@ impl Parse for AdminRequestParser { // remove the need for ignoring this lint. #[allow(clippy::match_single_binding)] match command_verb { - _ => Err(ParseError::Unknown), + _ => Err(Error::from(ErrorKind::InvalidInput)), } } else { match &trimmed_buffer[0..] { @@ -61,11 +64,99 @@ impl Parse for AdminRequestParser { AdminRequest::Version, command_end + CRLF.len(), )), - _ => Err(ParseError::Unknown), + _ => Err(Error::from(ErrorKind::InvalidInput)), } } } else { - Err(ParseError::Incomplete) + Err(Error::from(ErrorKind::WouldBlock)) + } + } +} + +pub struct Version { + version: String, +} + +impl Compose for Version { + fn compose(&self, buf: &mut dyn BufMut) -> usize { + buf.put_slice(b"VERSION "); + buf.put_slice(self.version.as_bytes()); + buf.put_slice(b"\r\n"); + + 10 + self.version.as_bytes().len() + } +} + +pub enum AdminResponse { + Hangup, + Ok, + Stats, + Version(Version), +} + +impl AdminResponse { + pub fn hangup() -> Self { + Self::Hangup + } + + pub fn ok() -> Self { + Self::Ok + } + + pub fn stats() -> Self { + Self::Stats + } + + pub fn version(version: String) -> Self { + Self::Version(Version { version }) + } +} + +impl Compose for AdminResponse { + fn compose(&self, buf: &mut dyn BufMut) -> usize { + match self { + Self::Hangup => 0, + Self::Ok => { + buf.put_slice(b"OK\r\n"); + 4 + } + Self::Stats => { + let mut size = 0; + let mut data = Vec::new(); + for metric in &rustcommon_metrics::metrics() { + let any = match metric.as_any() { + Some(any) => any, + None => { + continue; + } + }; + + if let Some(counter) = any.downcast_ref::() { + data.push(format!("STAT {} {}\r\n", metric.name(), 
counter.value())); + } else if let Some(gauge) = any.downcast_ref::() { + data.push(format!("STAT {} {}\r\n", metric.name(), gauge.value())); + } else if let Some(heatmap) = any.downcast_ref::() { + for (label, value) in PERCENTILES { + let percentile = heatmap.percentile(*value).unwrap_or(0); + data.push(format!( + "STAT {}_{} {}\r\n", + metric.name(), + label, + percentile + )); + } + } + } + + data.sort(); + for line in data { + size += line.as_bytes().len(); + buf.put_slice(line.as_bytes()); + } + buf.put_slice(b"END\r\n"); + size + 5 + } + Self::Version(v) => v.compose(buf), } } } @@ -80,7 +171,11 @@ mod tests { let buffers: Vec<&[u8]> = vec![b"", b"stats", b"stats\r"]; for buffer in buffers.iter() { - assert_eq!(parser.parse(buffer), Err(ParseError::Incomplete)); + if let Err(e) = parser.parse(buffer) { + assert_eq!(e.kind(), ErrorKind::WouldBlock); + } else { + panic!("parser should not have returned a request"); + } } } diff --git a/src/protocol/admin/src/lib.rs b/src/protocol/admin/src/lib.rs index 650469b81..d8ba3f145 100644 --- a/src/protocol/admin/src/lib.rs +++ b/src/protocol/admin/src/lib.rs @@ -7,3 +7,15 @@ pub use protocol_common::*; mod admin; pub use admin::*; + +pub static PERCENTILES: &[(&str, f64)] = &[ + ("p25", 25.0), + ("p50", 50.0), + ("p75", 75.0), + ("p90", 90.0), + ("p99", 99.0), + ("p999", 99.9), + ("p9999", 99.99), +]; + +common::metrics::test_no_duplicates!(); diff --git a/src/protocol/common/Cargo.toml b/src/protocol/common/Cargo.toml index f18782092..5e28b3603 100644 --- a/src/protocol/common/Cargo.toml +++ b/src/protocol/common/Cargo.toml @@ -9,10 +9,10 @@ repository = "https://github.com/twitter/pelikan" license = "Apache-2.0" [dependencies] +bytes = "1.1.0" common = { path = "../../common" } config = { path = "../../config" } logger = { path = "../../logger" } -session = { path = "../../session" } storage-types = { path = "../../storage/types" } [dev-dependencies] diff --git a/src/protocol/common/src/lib.rs 
b/src/protocol/common/src/lib.rs index cef7305a5..cd1614e3d 100644 --- a/src/protocol/common/src/lib.rs +++ b/src/protocol/common/src/lib.rs @@ -6,12 +6,12 @@ //! traits so that the a server implementation can easily switch between //! protocol implementations. -use session::Session; +pub use bytes::BufMut; pub const CRLF: &str = "\r\n"; pub trait Compose { - fn compose(&self, dst: &mut Session); + fn compose(&self, dst: &mut dyn BufMut) -> usize; /// Indicates that the connection should be closed. /// Override this function as appropriate for the @@ -22,20 +22,7 @@ pub trait Compose { } pub trait Execute { - fn execute(&mut self, request: Request) -> Box>; -} - -pub trait ExecutionResult: Send + Compose { - fn request(&self) -> &Request; - - fn response(&self) -> &Response; -} - -#[derive(Debug, PartialEq)] -pub enum ParseError { - Invalid, - Incomplete, - Unknown, + fn execute(&mut self, request: &Request) -> Response; } #[derive(Debug, PartialEq)] @@ -59,5 +46,5 @@ impl ParseOk { } pub trait Parse { - fn parse(&self, buffer: &[u8]) -> Result, ParseError>; + fn parse(&self, buffer: &[u8]) -> Result, std::io::Error>; } diff --git a/src/protocol/memcache/Cargo.toml b/src/protocol/memcache/Cargo.toml index 2f8eac9f7..b57edc35c 100644 --- a/src/protocol/memcache/Cargo.toml +++ b/src/protocol/memcache/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "protocol-memcache" -version = "0.1.0" +version = "0.2.0" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -15,9 +15,7 @@ common = { path = "../../common" } logger = { path = "../../logger" } nom = "5.1.2" protocol-common = { path = "../../protocol/common" } -rustcommon-metrics = { git = "https://github.com/twitter/rustcommon" } -session = { path = "../../session" } - +rustcommon-metrics = { git = "https://github.com/twitter/rustcommon", features = ["heatmap"] } [dev-dependencies] -criterion = "0.3.4" \ No newline at end of file +criterion = "0.3.4" diff --git 
a/src/protocol/memcache/src/lib.rs b/src/protocol/memcache/src/lib.rs index 8ea1ce772..2b4cbb612 100644 --- a/src/protocol/memcache/src/lib.rs +++ b/src/protocol/memcache/src/lib.rs @@ -2,9 +2,11 @@ // Licensed under the Apache License, Version 2.0 // http://www.apache.org/licenses/LICENSE-2.0 +#[macro_use] +extern crate logger; + mod request; mod response; -mod result; mod storage; mod util; @@ -12,10 +14,16 @@ pub(crate) use util::*; pub use request::*; pub use response::*; -pub use result::*; pub use storage::*; -use common::expiry::TimeType; +pub use protocol_common::*; + +use logger::Klog; +use rustcommon_metrics::*; + +// use common::expiry::TimeType; + +const CRLF: &[u8] = b"\r\n"; pub enum MemcacheError { Error(Error), @@ -23,8 +31,6 @@ pub enum MemcacheError { ServerError(ServerError), } -use rustcommon_metrics::*; - type Instant = common::time::Instant>; counter!(GET); @@ -94,3 +100,5 @@ counter!(FLUSH_ALL); counter!(FLUSH_ALL_EX); counter!(QUIT); + +common::metrics::test_no_duplicates!(); diff --git a/src/protocol/memcache/src/request/add.rs b/src/protocol/memcache/src/request/add.rs index f4e6a373b..eddb2c2e5 100644 --- a/src/protocol/memcache/src/request/add.rs +++ b/src/protocol/memcache/src/request/add.rs @@ -65,24 +65,70 @@ impl RequestParser { } impl Compose for Add { - fn compose(&self, session: &mut session::Session) { - let _ = session.write_all(b"add "); - let _ = session.write_all(&self.key); - let _ = session.write_all(format!(" {}", self.flags).as_bytes()); - match self.ttl { - None => { - let _ = session.write_all(b" 0"); + fn compose(&self, session: &mut dyn BufMut) -> usize { + let verb = b"add "; + let flags = format!(" {}", self.flags).into_bytes(); + let ttl = convert_ttl(self.ttl); + let vlen = format!(" {}", self.value.len()).into_bytes(); + let header_end = if self.noreply { + " noreply\r\n".as_bytes() + } else { + "\r\n".as_bytes() + }; + + let size = verb.len() + + self.key.len() + + flags.len() + + ttl.len() + + vlen.len() + + 
header_end.len() + + self.value.len() + + CRLF.len(); + + session.put_slice(verb); + session.put_slice(&self.key); + session.put_slice(&flags); + session.put_slice(&ttl); + session.put_slice(&vlen); + session.put_slice(header_end); + session.put_slice(&self.value); + session.put_slice(CRLF); + + size + } +} + +impl Klog for Add { + type Response = Response; + + fn klog(&self, response: &Self::Response) { + let ttl: i64 = match self.ttl() { + None => 0, + Some(0) => -1, + Some(t) => t as _, + }; + let (code, len) = match response { + Response::Stored(ref res) => { + ADD_STORED.increment(); + (STORED, res.len()) } - Some(0) => { - let _ = session.write_all(b" -1"); + Response::NotStored(ref res) => { + ADD_NOT_STORED.increment(); + (NOT_STORED, res.len()) } - Some(s) => { - let _ = session.write_all(format!(" {}", s).as_bytes()); + _ => { + return; } - } - let _ = session.write_all(format!(" {}\r\n", self.value.len()).as_bytes()); - let _ = session.write_all(&self.value); - let _ = session.write_all(b"\r\n"); + }; + klog!( + "\"add {} {} {} {}\" {} {}", + string_key(self.key()), + self.flags(), + ttl, + self.value().len(), + code, + len + ); } } diff --git a/src/protocol/memcache/src/request/append.rs b/src/protocol/memcache/src/request/append.rs index 123f263c1..483eb4da3 100644 --- a/src/protocol/memcache/src/request/append.rs +++ b/src/protocol/memcache/src/request/append.rs @@ -65,24 +65,70 @@ impl RequestParser { } impl Compose for Append { - fn compose(&self, session: &mut session::Session) { - let _ = session.write_all(b"append "); - let _ = session.write_all(&self.key); - let _ = session.write_all(format!(" {}", self.flags).as_bytes()); - match self.ttl { - None => { - let _ = session.write_all(b" 0"); + fn compose(&self, session: &mut dyn BufMut) -> usize { + let verb = b"append "; + let flags = format!(" {}", self.flags).into_bytes(); + let ttl = convert_ttl(self.ttl); + let vlen = format!(" {}", self.value.len()); + let header_end = if self.noreply { + " 
noreply\r\n".as_bytes() + } else { + "\r\n".as_bytes() + }; + + let size = verb.len() + + self.key.len() + + flags.len() + + ttl.len() + + vlen.len() + + header_end.len() + + self.value.len() + + CRLF.len(); + + session.put_slice(verb); + session.put_slice(&self.key); + session.put_slice(&flags); + session.put_slice(&ttl); + session.put_slice(vlen.as_bytes()); + session.put_slice(header_end); + session.put_slice(&self.value); + session.put_slice(CRLF); + + size + } +} + +impl Klog for Append { + type Response = Response; + + fn klog(&self, response: &Self::Response) { + let ttl: i64 = match self.ttl() { + None => 0, + Some(0) => -1, + Some(t) => t as _, + }; + let (code, len) = match response { + Response::Stored(ref res) => { + APPEND_STORED.increment(); + (STORED, res.len()) } - Some(0) => { - let _ = session.write_all(b" -1"); + Response::NotStored(ref res) => { + APPEND_NOT_STORED.increment(); + (NOT_STORED, res.len()) } - Some(s) => { - let _ = session.write_all(format!(" {}", s).as_bytes()); + _ => { + return; } - } - let _ = session.write_all(format!(" {}\r\n", self.value.len()).as_bytes()); - let _ = session.write_all(&self.value); - let _ = session.write_all(b"\r\n"); + }; + klog!( + "\"append {} {} {} {}\" {} {}", + string_key(self.key()), + self.flags(), + ttl, + self.value().len(), + code, + len + ); } } diff --git a/src/protocol/memcache/src/request/cas.rs b/src/protocol/memcache/src/request/cas.rs index 048ed3d3f..d969952d4 100644 --- a/src/protocol/memcache/src/request/cas.rs +++ b/src/protocol/memcache/src/request/cas.rs @@ -133,24 +133,78 @@ impl RequestParser { } impl Compose for Cas { - fn compose(&self, session: &mut session::Session) { - let _ = session.write_all(b"cas "); - let _ = session.write_all(&self.key); - let _ = session.write_all(format!(" {}", self.flags).as_bytes()); - match self.ttl { - None => { - let _ = session.write_all(b" 0"); + fn compose(&self, session: &mut dyn BufMut) -> usize { + let verb = b"cas "; + let flags = 
format!(" {}", self.flags).into_bytes(); + let ttl = convert_ttl(self.ttl); + let vlen = format!(" {}", self.value.len()).into_bytes(); + let cas = format!(" {}", self.cas).into_bytes(); + let header_end = if self.noreply { + " noreply\r\n".as_bytes() + } else { + "\r\n".as_bytes() + }; + + let size = verb.len() + + self.key.len() + + flags.len() + + ttl.len() + + vlen.len() + + cas.len() + + header_end.len() + + self.value.len() + + CRLF.len(); + + session.put_slice(verb); + session.put_slice(&self.key); + session.put_slice(&flags); + session.put_slice(&ttl); + session.put_slice(&vlen); + session.put_slice(&cas); + session.put_slice(header_end); + session.put_slice(&self.value); + session.put_slice(CRLF); + + size + } +} + +impl Klog for Cas { + type Response = Response; + + fn klog(&self, response: &Self::Response) { + let ttl: i64 = match self.ttl() { + None => 0, + Some(0) => -1, + Some(t) => t as _, + }; + let (code, len) = match response { + Response::Stored(ref res) => { + CAS_STORED.increment(); + (STORED, res.len()) } - Some(0) => { - let _ = session.write_all(b" -1"); + Response::Exists(ref res) => { + CAS_EXISTS.increment(); + (EXISTS, res.len()) } - Some(s) => { - let _ = session.write_all(format!(" {}", s).as_bytes()); + Response::NotFound(ref res) => { + CAS_NOT_FOUND.increment(); + (NOT_FOUND, res.len()) } - } - let _ = session.write_all(format!(" {} {}\r\n", self.value.len(), self.cas).as_bytes()); - let _ = session.write_all(&self.value); - let _ = session.write_all(b"\r\n"); + _ => { + return; + } + }; + klog!( + "\"cas {} {} {} {} {}\" {} {}", + string_key(self.key()), + self.flags(), + ttl, + self.value().len(), + self.cas(), + code, + len + ); } } diff --git a/src/protocol/memcache/src/request/decr.rs b/src/protocol/memcache/src/request/decr.rs index c4eac35f8..c3b5ffa71 100644 --- a/src/protocol/memcache/src/request/decr.rs +++ b/src/protocol/memcache/src/request/decr.rs @@ -53,15 +53,44 @@ impl RequestParser { } impl Compose for Decr { - fn 
compose(&self, session: &mut session::Session) { - let _ = session.write_all(b"decr "); - let _ = session.write_all(&self.key); - let _ = session.write_all(format!(" {}", self.value).as_bytes()); - if self.noreply { - let _ = session.write_all(b" noreply\r\n"); + fn compose(&self, session: &mut dyn BufMut) -> usize { + let verb = b"decr "; + let value = format!(" {}", self.value).into_bytes(); + let header_end = if self.noreply { + " noreply\r\n".as_bytes() } else { - let _ = session.write_all(b"\r\n"); - } + "\r\n".as_bytes() + }; + + let size = verb.len() + self.key.len() + value.len() + header_end.len(); + + session.put_slice(verb); + session.put_slice(&self.key); + session.put_slice(&value); + session.put_slice(header_end); + + size + } +} + +impl Klog for Decr { + type Response = Response; + + fn klog(&self, response: &Self::Response) { + let (code, len) = match response { + Response::Numeric(ref res) => { + DECR_STORED.increment(); + (STORED, res.len()) + } + Response::NotFound(ref res) => { + DECR_NOT_FOUND.increment(); + (NOT_FOUND, res.len()) + } + _ => { + return; + } + }; + klog!("\"decr {}\" {} {}", string_key(self.key()), code, len); } } diff --git a/src/protocol/memcache/src/request/delete.rs b/src/protocol/memcache/src/request/delete.rs index cf2e1233a..3f10b1e8e 100644 --- a/src/protocol/memcache/src/request/delete.rs +++ b/src/protocol/memcache/src/request/delete.rs @@ -75,14 +75,42 @@ impl RequestParser { } impl Compose for Delete { - fn compose(&self, session: &mut session::Session) { - let _ = session.write_all(b"delete "); - let _ = session.write_all(&self.key); - if self.noreply { - let _ = session.write_all(b" noreply\r\n"); + fn compose(&self, session: &mut dyn BufMut) -> usize { + let verb = b"delete "; + let header_end = if self.noreply { + " noreply\r\n".as_bytes() } else { - let _ = session.write_all(b"\r\n"); - } + "\r\n".as_bytes() + }; + + let size = verb.len() + self.key.len() + header_end.len(); + + session.put_slice(verb); + 
session.put_slice(&self.key); + session.put_slice(header_end); + + size + } +} + +impl Klog for Delete { + type Response = Response; + + fn klog(&self, response: &Self::Response) { + let (code, len) = match response { + Response::Deleted(ref res) => { + DELETE_DELETED.increment(); + (DELETED, res.len()) + } + Response::NotFound(ref res) => { + DELETE_NOT_FOUND.increment(); + (NOT_FOUND, res.len()) + } + _ => { + return; + } + }; + klog!("\"delete {}\" {} {}", string_key(self.key()), code, len); } } diff --git a/src/protocol/memcache/src/request/flush_all.rs b/src/protocol/memcache/src/request/flush_all.rs index 9de7ef0b1..692af3f08 100644 --- a/src/protocol/memcache/src/request/flush_all.rs +++ b/src/protocol/memcache/src/request/flush_all.rs @@ -82,19 +82,35 @@ impl RequestParser { } impl Compose for FlushAll { - fn compose(&self, session: &mut session::Session) { - let _ = session.write_all(b"flush_all"); - if self.delay != 0 { - let _ = session.write_all(format!(" {}", self.delay).as_bytes()); - } - if self.noreply { - let _ = session.write_all(b" noreply\r\n"); + fn compose(&self, session: &mut dyn BufMut) -> usize { + let verb = b"flush_all"; + let delay = if self.delay != 0 { + format!(" {}", self.delay).into_bytes() } else { - let _ = session.write_all(b"\r\n"); - } + vec![] + }; + let header_end = if self.noreply { + " noreply\r\n".as_bytes() + } else { + "\r\n".as_bytes() + }; + + let size = verb.len() + delay.len() + header_end.len(); + + session.put_slice(verb); + session.put_slice(&delay); + session.put_slice(header_end); + + size } } +impl Klog for FlushAll { + type Response = Response; + + fn klog(&self, _response: &Self::Response) {} +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/protocol/memcache/src/request/get.rs b/src/protocol/memcache/src/request/get.rs index f79975d8b..05314b007 100644 --- a/src/protocol/memcache/src/request/get.rs +++ b/src/protocol/memcache/src/request/get.rs @@ -82,13 +82,55 @@ impl RequestParser { } impl 
Compose for Get { - fn compose(&self, session: &mut session::Session) { - let _ = session.write_all(b"get"); + fn compose(&self, session: &mut dyn BufMut) -> usize { + let verb = b"get"; + + let mut size = verb.len() + CRLF.len(); + + session.put_slice(verb); for key in self.keys.iter() { - let _ = session.write_all(b" "); - let _ = session.write_all(key); + session.put_slice(b" "); + session.put_slice(key); + size += 1 + key.len(); + } + session.put_slice(CRLF); + + size + } +} + +impl Klog for Get { + type Response = Response; + + fn klog(&self, response: &Self::Response) { + if let Response::Values(ref res) = response { + let mut hit_keys = 0; + let mut miss_keys = 0; + + for value in res.values() { + if value.len().is_none() { + miss_keys += 1; + + klog!( + "\"get {}\" {} 0", + String::from_utf8_lossy(value.key()), + MISS + ); + } else { + hit_keys += 1; + + klog!( + "\"get {}\" {} {}", + String::from_utf8_lossy(value.key()), + HIT, + value.len().unwrap(), + ); + } + } + + GET_KEY_HIT.add(hit_keys as _); + GET_KEY_MISS.add(miss_keys as _); } - let _ = session.write_all(b"\r\n"); } } diff --git a/src/protocol/memcache/src/request/gets.rs b/src/protocol/memcache/src/request/gets.rs index bbdb3233b..ae3893c2c 100644 --- a/src/protocol/memcache/src/request/gets.rs +++ b/src/protocol/memcache/src/request/gets.rs @@ -38,13 +38,55 @@ impl RequestParser { } impl Compose for Gets { - fn compose(&self, session: &mut session::Session) { - let _ = session.write_all(b"gets"); + fn compose(&self, session: &mut dyn BufMut) -> usize { + let verb = b"gets"; + + let mut size = verb.len() + CRLF.len(); + + session.put_slice(verb); for key in self.keys.iter() { - let _ = session.write_all(b" "); - let _ = session.write_all(key); + session.put_slice(b" "); + session.put_slice(key); + size += 1 + key.len(); + } + session.put_slice(CRLF); + + size + } +} + +impl Klog for Gets { + type Response = Response; + + fn klog(&self, response: &Self::Response) { + if let Response::Values(ref 
res) = response { + let mut hit_keys = 0; + let mut miss_keys = 0; + + for value in res.values() { + if value.len().is_none() { + miss_keys += 1; + + klog!( + "\"gets {}\" {} 0", + String::from_utf8_lossy(value.key()), + MISS + ); + } else { + hit_keys += 1; + + klog!( + "\"gets {}\" {} {}", + String::from_utf8_lossy(value.key()), + HIT, + value.len().unwrap(), + ); + } + } + + GETS_KEY_HIT.add(hit_keys as _); + GETS_KEY_MISS.add(miss_keys as _); } - let _ = session.write_all(b"\r\n"); } } diff --git a/src/protocol/memcache/src/request/incr.rs b/src/protocol/memcache/src/request/incr.rs index 0c52adfd9..fa722394d 100644 --- a/src/protocol/memcache/src/request/incr.rs +++ b/src/protocol/memcache/src/request/incr.rs @@ -82,15 +82,44 @@ impl RequestParser { } impl Compose for Incr { - fn compose(&self, session: &mut session::Session) { - let _ = session.write_all(b"incr "); - let _ = session.write_all(&self.key); - let _ = session.write_all(format!(" {}", self.value).as_bytes()); - if self.noreply { - let _ = session.write_all(b" noreply\r\n"); + fn compose(&self, session: &mut dyn BufMut) -> usize { + let verb = b"incr "; + let value = format!(" {}", self.value).into_bytes(); + let header_end = if self.noreply { + " noreply\r\n".as_bytes() } else { - let _ = session.write_all(b"\r\n"); - } + "\r\n".as_bytes() + }; + + let size = verb.len() + self.key.len() + value.len() + header_end.len(); + + session.put_slice(verb); + session.put_slice(&self.key); + session.put_slice(&value); + session.put_slice(header_end); + + size + } +} + +impl Klog for Incr { + type Response = Response; + + fn klog(&self, response: &Self::Response) { + let (code, len) = match response { + Response::Numeric(ref res) => { + INCR_STORED.increment(); + (STORED, res.len()) + } + Response::NotFound(ref res) => { + INCR_NOT_FOUND.increment(); + (NOT_STORED, res.len()) + } + _ => { + return; + } + }; + klog!("\"incr {}\" {} {}", string_key(self.key()), code, len); } } diff --git 
a/src/protocol/memcache/src/request/mod.rs b/src/protocol/memcache/src/request/mod.rs index b36209a57..0645f60b2 100644 --- a/src/protocol/memcache/src/request/mod.rs +++ b/src/protocol/memcache/src/request/mod.rs @@ -5,9 +5,8 @@ use crate::*; use common::expiry::TimeType; use core::fmt::{Display, Formatter}; -use protocol_common::Parse; -use protocol_common::{ParseError, ParseOk}; -use session::Session; +use protocol_common::{BufMut, Parse, ParseOk}; +use std::borrow::Cow; mod add; mod append; @@ -41,6 +40,19 @@ pub const DEFAULT_MAX_BATCH_SIZE: usize = 1024; pub const DEFAULT_MAX_KEY_LEN: usize = 250; pub const DEFAULT_MAX_VALUE_SIZE: usize = 512 * 1024 * 1024; // 512MB max value size +// response codes for klog +const MISS: u8 = 0; +const HIT: u8 = 4; +const STORED: u8 = 5; +const EXISTS: u8 = 6; +const DELETED: u8 = 7; +const NOT_FOUND: u8 = 8; +const NOT_STORED: u8 = 9; + +fn string_key(key: &[u8]) -> Cow<'_, str> { + String::from_utf8_lossy(key) +} + #[derive(Copy, Clone)] pub struct RequestParser { max_value_size: usize, @@ -168,17 +180,17 @@ impl Default for RequestParser { } impl Parse for RequestParser { - fn parse(&self, buffer: &[u8]) -> Result, protocol_common::ParseError> { + fn parse(&self, buffer: &[u8]) -> Result, std::io::Error> { match self.parse_request(buffer) { Ok((input, request)) => Ok(ParseOk::new(request, buffer.len() - input.len())), - Err(Err::Incomplete(_)) => Err(ParseError::Incomplete), - Err(_) => Err(ParseError::Invalid), + Err(Err::Incomplete(_)) => Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)), + Err(_) => Err(std::io::Error::from(std::io::ErrorKind::InvalidInput)), } } } impl Compose for Request { - fn compose(&self, session: &mut Session) { + fn compose(&self, session: &mut dyn BufMut) -> usize { match self { Self::Add(r) => r.compose(session), Self::Append(r) => r.compose(session), @@ -197,6 +209,28 @@ impl Compose for Request { } } +impl Klog for Request { + type Response = Response; + + fn klog(&self, response: 
&Self::Response) { + match self { + Self::Add(r) => r.klog(response), + Self::Append(r) => r.klog(response), + Self::Cas(r) => r.klog(response), + Self::Decr(r) => r.klog(response), + Self::Delete(r) => r.klog(response), + Self::FlushAll(r) => r.klog(response), + Self::Incr(r) => r.klog(response), + Self::Get(r) => r.klog(response), + Self::Gets(r) => r.klog(response), + Self::Prepend(r) => r.klog(response), + Self::Quit(r) => r.klog(response), + Self::Replace(r) => r.klog(response), + Self::Set(r) => r.klog(response), + } + } +} + #[derive(Debug, PartialEq, Eq)] pub enum Request { Add(Add), @@ -257,6 +291,17 @@ pub enum ExpireTime { UnixSeconds(u32), } +fn convert_ttl(ttl: Option) -> Vec { + match ttl { + None => " 0".to_owned(), + Some(0) => " -1".to_owned(), + Some(s) => { + format!(" {}", s) + } + } + .into_bytes() +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/protocol/memcache/src/request/prepend.rs b/src/protocol/memcache/src/request/prepend.rs index 27b4bfbb8..1952cab35 100644 --- a/src/protocol/memcache/src/request/prepend.rs +++ b/src/protocol/memcache/src/request/prepend.rs @@ -65,24 +65,70 @@ impl RequestParser { } impl Compose for Prepend { - fn compose(&self, session: &mut session::Session) { - let _ = session.write_all(b"prepend "); - let _ = session.write_all(&self.key); - let _ = session.write_all(format!(" {}", self.flags).as_bytes()); - match self.ttl { - None => { - let _ = session.write_all(b" 0"); + fn compose(&self, session: &mut dyn BufMut) -> usize { + let verb = b"prepend "; + let flags = format!(" {}", self.flags).into_bytes(); + let ttl = convert_ttl(self.ttl); + let vlen = format!(" {}", self.value.len()).into_bytes(); + let header_end = if self.noreply { + " noreply\r\n".as_bytes() + } else { + "\r\n".as_bytes() + }; + + let size = verb.len() + + self.key.len() + + flags.len() + + ttl.len() + + vlen.len() + + header_end.len() + + self.value.len() + + CRLF.len(); + + session.put_slice(verb); + session.put_slice(&self.key); 
+ session.put_slice(&flags); + session.put_slice(&ttl); + session.put_slice(&vlen); + session.put_slice(header_end); + session.put_slice(&self.value); + session.put_slice(CRLF); + + size + } +} + +impl Klog for Prepend { + type Response = Response; + + fn klog(&self, response: &Self::Response) { + let ttl: i64 = match self.ttl() { + None => 0, + Some(0) => -1, + Some(t) => t as _, + }; + let (code, len) = match response { + Response::Stored(ref res) => { + PREPEND_STORED.increment(); + (STORED, res.len()) } - Some(0) => { - let _ = session.write_all(b" -1"); + Response::NotStored(ref res) => { + PREPEND_NOT_STORED.increment(); + (NOT_STORED, res.len()) } - Some(s) => { - let _ = session.write_all(format!(" {}", s).as_bytes()); + _ => { + return; } - } - let _ = session.write_all(format!(" {}\r\n", self.value.len()).as_bytes()); - let _ = session.write_all(&self.value); - let _ = session.write_all(b"\r\n"); + }; + klog!( + "\"prepend {} {} {} {}\" {} {}", + string_key(self.key()), + self.flags(), + ttl, + self.value().len(), + code, + len + ); } } diff --git a/src/protocol/memcache/src/request/quit.rs b/src/protocol/memcache/src/request/quit.rs index 20ee22f26..d077864a8 100644 --- a/src/protocol/memcache/src/request/quit.rs +++ b/src/protocol/memcache/src/request/quit.rs @@ -22,11 +22,18 @@ impl RequestParser { } impl Compose for Quit { - fn compose(&self, session: &mut session::Session) { - let _ = session.write_all(b"quit\r\n"); + fn compose(&self, session: &mut dyn BufMut) -> usize { + session.put_slice(b"quit\r\n"); + 6 } } +impl Klog for Quit { + type Response = Response; + + fn klog(&self, _response: &Self::Response) {} +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/protocol/memcache/src/request/replace.rs b/src/protocol/memcache/src/request/replace.rs index c71ac1af5..465c4bd00 100644 --- a/src/protocol/memcache/src/request/replace.rs +++ b/src/protocol/memcache/src/request/replace.rs @@ -65,24 +65,70 @@ impl RequestParser { } impl Compose for 
Replace { - fn compose(&self, session: &mut session::Session) { - let _ = session.write_all(b"replace "); - let _ = session.write_all(&self.key); - let _ = session.write_all(format!(" {}", self.flags).as_bytes()); - match self.ttl { - None => { - let _ = session.write_all(b" 0"); + fn compose(&self, session: &mut dyn BufMut) -> usize { + let verb = b"replace "; + let flags = format!(" {}", self.flags).into_bytes(); + let ttl = convert_ttl(self.ttl); + let vlen = format!(" {}", self.value.len()).into_bytes(); + let header_end = if self.noreply { + " noreply\r\n".as_bytes() + } else { + "\r\n".as_bytes() + }; + + let size = verb.len() + + self.key.len() + + flags.len() + + ttl.len() + + vlen.len() + + header_end.len() + + self.value.len() + + CRLF.len(); + + session.put_slice(verb); + session.put_slice(&self.key); + session.put_slice(&flags); + session.put_slice(&ttl); + session.put_slice(&vlen); + session.put_slice(header_end); + session.put_slice(&self.value); + session.put_slice(CRLF); + + size + } +} + +impl Klog for Replace { + type Response = Response; + + fn klog(&self, response: &Self::Response) { + let ttl: i64 = match self.ttl() { + None => 0, + Some(0) => -1, + Some(t) => t as _, + }; + let (code, len) = match response { + Response::Stored(ref res) => { + REPLACE_STORED.increment(); + (STORED, res.len()) } - Some(0) => { - let _ = session.write_all(b" -1"); + Response::NotStored(ref res) => { + REPLACE_NOT_STORED.increment(); + (NOT_STORED, res.len()) } - Some(s) => { - let _ = session.write_all(format!(" {}", s).as_bytes()); + _ => { + return; } - } - let _ = session.write_all(format!(" {}\r\n", self.value.len()).as_bytes()); - let _ = session.write_all(&self.value); - let _ = session.write_all(b"\r\n"); + }; + klog!( + "\"replace {} {} {} {}\" {} {}", + string_key(self.key()), + self.flags(), + ttl, + self.value().len(), + code, + len + ); } } diff --git a/src/protocol/memcache/src/request/set.rs b/src/protocol/memcache/src/request/set.rs index 
0b2a0a27e..2aff62feb 100644 --- a/src/protocol/memcache/src/request/set.rs +++ b/src/protocol/memcache/src/request/set.rs @@ -2,9 +2,8 @@ // Licensed under the Apache License, Version 2.0 // http://www.apache.org/licenses/LICENSE-2.0 -use crate::*; -use common::time::Seconds; -use common::time::UnixInstant; +use super::*; +use common::time::{Seconds, UnixInstant}; #[derive(Debug, PartialEq, Eq)] pub struct Set { @@ -123,24 +122,70 @@ impl RequestParser { } impl Compose for Set { - fn compose(&self, session: &mut session::Session) { - let _ = session.write_all(b"set "); - let _ = session.write_all(&self.key); - let _ = session.write_all(format!(" {}", self.flags).as_bytes()); - match self.ttl { - None => { - let _ = session.write_all(b" 0"); + fn compose(&self, session: &mut dyn BufMut) -> usize { + let verb = b"set "; + let flags = format!(" {}", self.flags).into_bytes(); + let ttl = convert_ttl(self.ttl); + let vlen = format!(" {}", self.value.len()).into_bytes(); + let header_end = if self.noreply { + " noreply\r\n".as_bytes() + } else { + "\r\n".as_bytes() + }; + + let size = verb.len() + + self.key.len() + + flags.len() + + ttl.len() + + vlen.len() + + header_end.len() + + self.value.len() + + CRLF.len(); + + session.put_slice(verb); + session.put_slice(&self.key); + session.put_slice(&flags); + session.put_slice(&ttl); + session.put_slice(&vlen); + session.put_slice(header_end); + session.put_slice(&self.value); + session.put_slice(CRLF); + + size + } +} + +impl Klog for Set { + type Response = Response; + + fn klog(&self, response: &Self::Response) { + let ttl: i64 = match self.ttl() { + None => 0, + Some(0) => -1, + Some(t) => t as _, + }; + let (code, len) = match response { + Response::Stored(ref res) => { + SET_STORED.increment(); + (STORED, res.len()) } - Some(0) => { - let _ = session.write_all(b" -1"); + Response::NotStored(ref res) => { + SET_NOT_STORED.increment(); + (NOT_STORED, res.len()) } - Some(s) => { - let _ = session.write_all(format!(" {}", 
s).as_bytes()); + _ => { + return; } - } - let _ = session.write_all(format!(" {}\r\n", self.value.len()).as_bytes()); - let _ = session.write_all(&self.value); - let _ = session.write_all(b"\r\n"); + }; + klog!( + "\"set {} {} {} {}\" {} {}", + string_key(self.key()), + self.flags(), + ttl, + self.value().len(), + code, + len + ); } } diff --git a/src/protocol/memcache/src/response/client_error.rs b/src/protocol/memcache/src/response/client_error.rs index 106eb369b..0b8a32796 100644 --- a/src/protocol/memcache/src/response/client_error.rs +++ b/src/protocol/memcache/src/response/client_error.rs @@ -22,10 +22,16 @@ impl ClientError { } impl Compose for ClientError { - fn compose(&self, session: &mut session::Session) { - let _ = session.write_all(MSG_PREFIX); - let _ = session.write_all(self.inner.as_bytes()); - let _ = session.write_all(b"\r\n"); + fn compose(&self, session: &mut dyn BufMut) -> usize { + let msg = self.inner.as_bytes(); + + let size = MSG_PREFIX.len() + msg.len() + CRLF.len(); + + session.put_slice(MSG_PREFIX); + session.put_slice(msg); + session.put_slice(CRLF); + + size } } diff --git a/src/protocol/memcache/src/response/deleted.rs b/src/protocol/memcache/src/response/deleted.rs index 136f9d688..32bfb8afc 100644 --- a/src/protocol/memcache/src/response/deleted.rs +++ b/src/protocol/memcache/src/response/deleted.rs @@ -30,9 +30,12 @@ impl Deleted { } impl Compose for Deleted { - fn compose(&self, session: &mut session::Session) { + fn compose(&self, session: &mut dyn BufMut) -> usize { if !self.noreply { - let _ = session.write_all(MSG); + session.put_slice(MSG); + MSG.len() + } else { + 0 } } } diff --git a/src/protocol/memcache/src/response/error.rs b/src/protocol/memcache/src/response/error.rs index b79ead6a2..c5262f4bb 100644 --- a/src/protocol/memcache/src/response/error.rs +++ b/src/protocol/memcache/src/response/error.rs @@ -30,8 +30,9 @@ impl Error { } impl Compose for Error { - fn compose(&self, session: &mut session::Session) { - let _ 
= session.write_all(MSG); + fn compose(&self, session: &mut dyn BufMut) -> usize { + session.put_slice(MSG); + MSG.len() } } diff --git a/src/protocol/memcache/src/response/exists.rs b/src/protocol/memcache/src/response/exists.rs index 38613da52..71f8ecc72 100644 --- a/src/protocol/memcache/src/response/exists.rs +++ b/src/protocol/memcache/src/response/exists.rs @@ -30,9 +30,12 @@ impl Exists { } impl Compose for Exists { - fn compose(&self, session: &mut session::Session) { + fn compose(&self, session: &mut dyn BufMut) -> usize { if !self.noreply { - let _ = session.write_all(MSG); + session.put_slice(MSG); + MSG.len() + } else { + 0 } } } diff --git a/src/protocol/memcache/src/response/mod.rs b/src/protocol/memcache/src/response/mod.rs index 8c8ffbf23..88af9d2b1 100644 --- a/src/protocol/memcache/src/response/mod.rs +++ b/src/protocol/memcache/src/response/mod.rs @@ -3,7 +3,7 @@ // http://www.apache.org/licenses/LICENSE-2.0 use crate::*; -use protocol_common::*; +use protocol_common::{BufMut, Parse, ParseOk}; mod client_error; mod deleted; @@ -99,7 +99,7 @@ impl From for Response { } impl Compose for Response { - fn compose(&self, session: &mut session::Session) { + fn compose(&self, session: &mut dyn BufMut) -> usize { match self { Self::Error(e) => e.compose(session), Self::ClientError(e) => e.compose(session), @@ -111,7 +111,7 @@ impl Compose for Response { Self::Values(e) => e.compose(session), Self::Numeric(e) => e.compose(session), Self::Deleted(e) => e.compose(session), - Self::Hangup => {} + Self::Hangup => 0, } } @@ -223,11 +223,11 @@ pub(crate) fn response(input: &[u8]) -> IResult<&[u8], Response> { } impl Parse for ResponseParser { - fn parse(&self, buffer: &[u8]) -> Result, protocol_common::ParseError> { + fn parse(&self, buffer: &[u8]) -> Result, std::io::Error> { match response(buffer) { Ok((input, response)) => Ok(ParseOk::new(response, buffer.len() - input.len())), - Err(Err::Incomplete(_)) => Err(ParseError::Incomplete), - Err(_) => 
Err(ParseError::Invalid), + Err(Err::Incomplete(_)) => Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)), + Err(_) => Err(std::io::Error::from(std::io::ErrorKind::InvalidInput)), } } } diff --git a/src/protocol/memcache/src/response/not_found.rs b/src/protocol/memcache/src/response/not_found.rs index f979f5467..21592de75 100644 --- a/src/protocol/memcache/src/response/not_found.rs +++ b/src/protocol/memcache/src/response/not_found.rs @@ -30,9 +30,12 @@ impl NotFound { } impl Compose for NotFound { - fn compose(&self, session: &mut session::Session) { + fn compose(&self, session: &mut dyn BufMut) -> usize { if !self.noreply { - let _ = session.write_all(MSG); + session.put_slice(MSG); + MSG.len() + } else { + 0 } } } diff --git a/src/protocol/memcache/src/response/not_stored.rs b/src/protocol/memcache/src/response/not_stored.rs index a7ff2fb67..8d175cd5b 100644 --- a/src/protocol/memcache/src/response/not_stored.rs +++ b/src/protocol/memcache/src/response/not_stored.rs @@ -30,9 +30,12 @@ impl NotStored { } impl Compose for NotStored { - fn compose(&self, session: &mut session::Session) { + fn compose(&self, session: &mut dyn BufMut) -> usize { if !self.noreply { - let _ = session.write_all(MSG); + session.put_slice(MSG); + MSG.len() + } else { + 0 } } } diff --git a/src/protocol/memcache/src/response/numeric.rs b/src/protocol/memcache/src/response/numeric.rs index f884d99ca..03e77991e 100644 --- a/src/protocol/memcache/src/response/numeric.rs +++ b/src/protocol/memcache/src/response/numeric.rs @@ -29,9 +29,13 @@ impl Numeric { } impl Compose for Numeric { - fn compose(&self, session: &mut session::Session) { + fn compose(&self, session: &mut dyn BufMut) -> usize { if !self.noreply { - let _ = session.write_all(format!("{}\r\n", self.value).as_bytes()); + let response = format!("{}\r\n", self.value).into_bytes(); + session.put_slice(&response); + response.len() + } else { + 0 } } } diff --git a/src/protocol/memcache/src/response/server_error.rs 
b/src/protocol/memcache/src/response/server_error.rs index e356574cf..d4139236e 100644 --- a/src/protocol/memcache/src/response/server_error.rs +++ b/src/protocol/memcache/src/response/server_error.rs @@ -22,10 +22,16 @@ impl ServerError { } impl Compose for ServerError { - fn compose(&self, session: &mut session::Session) { - let _ = session.write_all(MSG_PREFIX); - let _ = session.write_all(self.inner.as_bytes()); - let _ = session.write_all(b"\r\n"); + fn compose(&self, session: &mut dyn BufMut) -> usize { + let msg = self.inner.as_bytes(); + + let size = MSG_PREFIX.len() + msg.len() + CRLF.len(); + + session.put_slice(MSG_PREFIX); + session.put_slice(msg); + session.put_slice(CRLF); + + size } } diff --git a/src/protocol/memcache/src/response/stored.rs b/src/protocol/memcache/src/response/stored.rs index 349c2983c..b9b926b67 100644 --- a/src/protocol/memcache/src/response/stored.rs +++ b/src/protocol/memcache/src/response/stored.rs @@ -30,9 +30,12 @@ impl Stored { } impl Compose for Stored { - fn compose(&self, session: &mut session::Session) { + fn compose(&self, session: &mut dyn BufMut) -> usize { if !self.noreply { - let _ = session.write_all(MSG); + session.put_slice(MSG); + MSG.len() + } else { + 0 } } } diff --git a/src/protocol/memcache/src/response/values.rs b/src/protocol/memcache/src/response/values.rs index 0f3e69fe9..08694c3eb 100644 --- a/src/protocol/memcache/src/response/values.rs +++ b/src/protocol/memcache/src/response/values.rs @@ -24,7 +24,7 @@ pub struct Value { key: Box<[u8]>, flags: u32, cas: Option, - data: Box<[u8]>, + data: Option>, } impl Value { @@ -33,37 +33,68 @@ impl Value { key: key.to_owned().into_boxed_slice(), flags, cas, - data: data.to_owned().into_boxed_slice(), + data: Some(data.to_owned().into_boxed_slice()), + } + } + + pub fn none(key: &[u8]) -> Self { + Self { + key: key.to_owned().into_boxed_slice(), + flags: 0, + cas: None, + data: None, } } pub fn key(&self) -> &[u8] { &self.key } + + 
#[allow(clippy::len_without_is_empty)] + pub fn len(&self) -> Option { + self.data.as_ref().map(|v| v.len()) + } } impl Compose for Values { - fn compose(&self, session: &mut session::Session) { + fn compose(&self, session: &mut dyn BufMut) -> usize { + let suffix = b"END\r\n"; + + let mut size = suffix.len(); + for value in self.values.iter() { - value.compose(session); + size += value.compose(session); } - let _ = session.write_all(b"END\r\n"); + session.put_slice(suffix); + + size } } impl Compose for Value { - fn compose(&self, session: &mut session::Session) { - let _ = session.write_all(b"VALUE "); - let _ = session.write_all(&self.key); - if let Some(cas) = self.cas { - let _ = session - .write_all(format!(" {} {} {}\r\n", self.flags, self.data.len(), cas).as_bytes()); - } else { - let _ = - session.write_all(format!(" {} {}\r\n", self.flags, self.data.len()).as_bytes()); + fn compose(&self, session: &mut dyn BufMut) -> usize { + if self.data.is_none() { + return 0; } - let _ = session.write_all(&self.data); - let _ = session.write_all(b"\r\n"); + + let data = self.data.as_ref().unwrap(); + + let prefix = b"VALUE "; + let header_fields = if let Some(cas) = self.cas { + format!(" {} {} {}\r\n", self.flags, data.len(), cas).into_bytes() + } else { + format!(" {} {}\r\n", self.flags, data.len()).into_bytes() + }; + + let size = prefix.len() + self.key.len() + header_fields.len() + data.len() + CRLF.len(); + + session.put_slice(prefix); + session.put_slice(&self.key); + session.put_slice(&header_fields); + session.put_slice(data); + session.put_slice(CRLF); + + size } } @@ -115,7 +146,7 @@ pub fn parse(input: &[u8]) -> IResult<&[u8], Values> { key: key.to_owned().into_boxed_slice(), flags, cas, - data: data.to_owned().into_boxed_slice(), + data: Some(data.to_owned().into_boxed_slice()), }); // look for a space or the start of a CRLF diff --git a/src/protocol/memcache/src/result/mod.rs b/src/protocol/memcache/src/result/mod.rs deleted file mode 100644 index 
357caa969..000000000 --- a/src/protocol/memcache/src/result/mod.rs +++ /dev/null @@ -1,325 +0,0 @@ -// Copyright 2022 Twitter, Inc. -// Licensed under the Apache License, Version 2.0 -// http://www.apache.org/licenses/LICENSE-2.0 - -use crate::*; -use logger::*; -use protocol_common::ExecutionResult; -use session::Session; -use std::borrow::Cow; -use std::ops::Deref; - -// response codes for klog -const MISS: u8 = 0; -const HIT: u8 = 4; -const STORED: u8 = 5; -const EXISTS: u8 = 6; -const DELETED: u8 = 7; -const NOT_FOUND: u8 = 8; -const NOT_STORED: u8 = 9; - -pub struct MemcacheExecutionResult { - pub(crate) request: Request, - pub(crate) response: Response, -} - -impl MemcacheExecutionResult { - pub fn new(request: Request, response: Response) -> Self { - Self { request, response } - } -} - -impl ExecutionResult for MemcacheExecutionResult { - fn request(&self) -> &Request { - &self.request - } - - fn response(&self) -> &Response { - &self.response - } -} - -impl Compose for MemcacheExecutionResult { - fn compose(&self, dst: &mut Session) { - match self.request { - Request::Get(ref req) => match self.response { - Response::Values(ref res) => { - let total_keys = req.keys.len(); - let hit_keys = res.values.len(); - let miss_keys = total_keys - hit_keys; - GET_KEY_HIT.add(hit_keys as _); - GET_KEY_MISS.add(miss_keys as _); - - let values = res.values(); - let mut value_index = 0; - - for key in req.keys() { - let key = key.deref(); - // if we are out of values or the keys don't match, it's a miss - if value_index >= values.len() || values[value_index].key() != key { - klog!("\"get {}\" 0 0", String::from_utf8_lossy(key)); - } else { - let start = dst.write_pending(); - values[value_index].compose(dst); - let size = dst.write_pending() - start; - klog!("\"get {}\" 4 {}", String::from_utf8_lossy(key), size); - value_index += 1; - } - } - - let _ = dst.write_all(b"END\r\n"); - - return; - } - _ => return Error {}.compose(dst), - }, - Request::Gets(ref req) => match 
self.response { - Response::Values(ref res) => { - let total_keys = req.keys.len(); - let hit_keys = res.values.len(); - let miss_keys = total_keys - hit_keys; - GETS_KEY_HIT.add(hit_keys as _); - GETS_KEY_MISS.add(miss_keys as _); - - let values = res.values(); - let mut value_index = 0; - - for key in req.keys() { - let key = key.deref(); - // if we are out of values or the keys don't match, it's a miss - if value_index >= values.len() || values[value_index].key() != key { - klog!("\"gets {}\" {} 0", String::from_utf8_lossy(key), MISS); - } else { - let start = dst.write_pending(); - values[value_index].compose(dst); - let size = dst.write_pending() - start; - klog!("\"gets {}\" {} {}", String::from_utf8_lossy(key), HIT, size); - value_index += 1; - } - } - - let _ = dst.write_all(b"END\r\n"); - - return; - } - _ => return Error {}.compose(dst), - }, - Request::Set(ref req) => { - let ttl: i64 = match req.ttl() { - None => 0, - Some(0) => -1, - Some(t) => t as _, - }; - let (code, len) = match self.response { - Response::Stored(ref res) => { - SET_STORED.increment(); - (STORED, res.len()) - } - Response::NotStored(ref res) => { - SET_NOT_STORED.increment(); - (NOT_STORED, res.len()) - } - _ => return Error {}.compose(dst), - }; - klog!( - "\"set {} {} {} {}\" {} {}", - string_key(req.key()), - req.flags(), - ttl, - req.value().len(), - code, - len - ); - } - Request::Add(ref req) => { - let ttl: i64 = match req.ttl() { - None => 0, - Some(0) => -1, - Some(t) => t as _, - }; - let (code, len) = match self.response { - Response::Stored(ref res) => { - ADD_STORED.increment(); - (STORED, res.len()) - } - Response::NotStored(ref res) => { - ADD_NOT_STORED.increment(); - (NOT_STORED, res.len()) - } - _ => return Error {}.compose(dst), - }; - klog!( - "\"add {} {} {} {}\" {} {}", - string_key(req.key()), - req.flags(), - ttl, - req.value().len(), - code, - len - ); - } - Request::Replace(ref req) => { - let ttl: i64 = match req.ttl() { - None => 0, - Some(0) => -1, - 
Some(t) => t as _, - }; - let (code, len) = match self.response { - Response::Stored(ref res) => { - REPLACE_STORED.increment(); - (STORED, res.len()) - } - Response::NotStored(ref res) => { - REPLACE_NOT_STORED.increment(); - (NOT_STORED, res.len()) - } - _ => return Error {}.compose(dst), - }; - klog!( - "\"replace {} {} {} {}\" {} {}", - string_key(req.key()), - req.flags(), - ttl, - req.value().len(), - code, - len - ); - } - Request::Cas(ref req) => { - let ttl: i64 = match req.ttl() { - None => 0, - Some(0) => -1, - Some(t) => t as _, - }; - let (code, len) = match self.response { - Response::Stored(ref res) => { - CAS_STORED.increment(); - (STORED, res.len()) - } - Response::Exists(ref res) => { - CAS_EXISTS.increment(); - (EXISTS, res.len()) - } - Response::NotFound(ref res) => { - CAS_NOT_FOUND.increment(); - (NOT_FOUND, res.len()) - } - _ => return Error {}.compose(dst), - }; - klog!( - "\"cas {} {} {} {} {}\" {} {}", - string_key(req.key()), - req.flags(), - ttl, - req.value().len(), - req.cas(), - code, - len - ); - } - Request::Append(ref req) => { - let ttl: i64 = match req.ttl() { - None => 0, - Some(0) => -1, - Some(t) => t as _, - }; - let (code, len) = match self.response { - Response::Stored(ref res) => { - APPEND_STORED.increment(); - (STORED, res.len()) - } - Response::NotStored(ref res) => { - APPEND_NOT_STORED.increment(); - (NOT_STORED, res.len()) - } - _ => return Error {}.compose(dst), - }; - klog!( - "\"append {} {} {} {}\" {} {}", - string_key(req.key()), - req.flags(), - ttl, - req.value().len(), - code, - len - ); - } - Request::Prepend(ref req) => { - let ttl: i64 = match req.ttl() { - None => 0, - Some(0) => -1, - Some(t) => t as _, - }; - let (code, len) = match self.response { - Response::Stored(ref res) => { - PREPEND_STORED.increment(); - (STORED, res.len()) - } - Response::NotStored(ref res) => { - PREPEND_NOT_STORED.increment(); - (NOT_STORED, res.len()) - } - _ => return Error {}.compose(dst), - }; - klog!( - "\"prepend {} {} 
{} {}\" {} {}", - string_key(req.key()), - req.flags(), - ttl, - req.value().len(), - code, - len - ); - } - Request::Incr(ref req) => { - let (code, len) = match self.response { - Response::Numeric(ref res) => { - INCR_STORED.increment(); - (STORED, res.len()) - } - Response::NotFound(ref res) => { - INCR_NOT_FOUND.increment(); - (NOT_FOUND, res.len()) - } - _ => return Error {}.compose(dst), - }; - klog!("\"incr {}\" {} {}", string_key(req.key()), code, len); - } - Request::Decr(ref req) => { - let (code, len) = match self.response { - Response::Numeric(ref res) => { - DECR_STORED.increment(); - (STORED, res.len()) - } - Response::NotFound(ref res) => { - DECR_NOT_FOUND.increment(); - (NOT_FOUND, res.len()) - } - _ => return Error {}.compose(dst), - }; - klog!("\"decr {}\" {} {}", string_key(req.key()), code, len); - } - Request::Delete(ref req) => { - let (code, len) = match self.response { - Response::Deleted(ref res) => { - DELETE_DELETED.increment(); - (DELETED, res.len()) - } - Response::NotFound(ref res) => { - DELETE_NOT_FOUND.increment(); - (NOT_FOUND, res.len()) - } - _ => return Error {}.compose(dst), - }; - klog!("\"delete {}\" {} {}", string_key(req.key()), code, len); - } - Request::FlushAll(_) => {} - Request::Quit(_) => {} - } - self.response.compose(dst) - } -} - -fn string_key(key: &[u8]) -> Cow<'_, str> { - String::from_utf8_lossy(key) -} diff --git a/src/protocol/ping/Cargo.toml b/src/protocol/ping/Cargo.toml index 820d67ccd..b9062fc00 100644 --- a/src/protocol/ping/Cargo.toml +++ b/src/protocol/ping/Cargo.toml @@ -1,9 +1,8 @@ [package] name = "protocol-ping" -version = "0.0.1" +version = "0.0.2" +edition = "2021" authors = ["Brian Martin "] -edition = "2018" -description = "protocols used in Pelikan servers" homepage = "https://pelikan.io" repository = "https://github.com/twitter/pelikan" license = "Apache-2.0" @@ -19,7 +18,6 @@ config = { path = "../../config" } logger = { path = "../../logger" } protocol-common = { path = 
"../../protocol/common" } rustcommon-metrics = { git = "https://github.com/twitter/rustcommon" } -session = { path = "../../session" } storage-types = { path = "../../storage/types" } [dev-dependencies] @@ -28,4 +26,4 @@ criterion = "0.3.4" [features] default = [] client = [] -server = [] \ No newline at end of file +server = [] diff --git a/src/protocol/ping/src/lib.rs b/src/protocol/ping/src/lib.rs index d2189bd94..da9490d91 100644 --- a/src/protocol/ping/src/lib.rs +++ b/src/protocol/ping/src/lib.rs @@ -9,6 +9,9 @@ // TODO(bmartin): this crate should probably be split into one crate per // protocol to help separate the metrics namespaces. +#[macro_use] +extern crate logger; + pub use protocol_common::*; mod ping; diff --git a/src/protocol/ping/src/ping/mod.rs b/src/protocol/ping/src/ping/mod.rs index 764ea1c0b..f7fac57d7 100644 --- a/src/protocol/ping/src/ping/mod.rs +++ b/src/protocol/ping/src/ping/mod.rs @@ -8,4 +8,4 @@ mod storage; mod wire; pub use storage::PingStorage; -pub use wire::{PingExecutionResult, Request, RequestParser, Response, ResponseParser}; +pub use wire::{Request, RequestParser, Response, ResponseParser}; diff --git a/src/protocol/ping/src/ping/wire/mod.rs b/src/protocol/ping/src/ping/wire/mod.rs index a30dd56fa..dcc9b2f87 100644 --- a/src/protocol/ping/src/ping/wire/mod.rs +++ b/src/protocol/ping/src/ping/wire/mod.rs @@ -7,42 +7,11 @@ mod request; mod response; -use protocol_common::Compose; -use protocol_common::ExecutionResult; pub use request::*; pub use response::*; -use session::Session; #[allow(unused)] use rustcommon_metrics::*; counter!(PING); counter!(PONG); - -pub struct PingExecutionResult { - request: Request, - response: Response, -} - -impl PingExecutionResult { - pub fn new(request: Request, response: Response) -> Self { - Self { request, response } - } -} - -impl ExecutionResult for PingExecutionResult { - fn request(&self) -> &Request { - &self.request - } - - fn response(&self) -> &Response { - &self.response - } -} - 
-impl Compose for PingExecutionResult { - fn compose(&self, dst: &mut Session) { - PONG.increment(); - self.response.compose(dst) - } -} diff --git a/src/protocol/ping/src/ping/wire/request/compose.rs b/src/protocol/ping/src/ping/wire/request/compose.rs index 260894736..c277b42da 100644 --- a/src/protocol/ping/src/ping/wire/request/compose.rs +++ b/src/protocol/ping/src/ping/wire/request/compose.rs @@ -1,7 +1,4 @@ -use crate::Compose; -use crate::Request; -use session::Session; -use std::io::Write; +use crate::*; // TODO(bmartin): consider a different trait bound here when reworking buffers. // We ignore the unused result warnings here because we know we're using a @@ -10,10 +7,11 @@ use std::io::Write; #[allow(unused_must_use)] impl Compose for Request { - fn compose(&self, dst: &mut Session) { + fn compose(&self, dst: &mut dyn BufMut) -> usize { match self { Self::Ping => { - dst.write_all(b"ping\r\n"); + dst.put_slice(b"ping\r\n"); + 6 } } } diff --git a/src/protocol/ping/src/ping/wire/request/keyword.rs b/src/protocol/ping/src/ping/wire/request/keyword.rs index 015cfd6ba..cce4e7208 100644 --- a/src/protocol/ping/src/ping/wire/request/keyword.rs +++ b/src/protocol/ping/src/ping/wire/request/keyword.rs @@ -4,7 +4,6 @@ //! This module defines all possible `Ping` commands. 
-use crate::ParseError; use core::convert::TryFrom; /// Ping request keywords @@ -13,13 +12,13 @@ pub enum Keyword { } impl TryFrom<&[u8]> for Keyword { - type Error = ParseError; + type Error = std::io::Error; fn try_from(value: &[u8]) -> Result { let keyword = match value { b"ping" | b"PING" => Self::Ping, _ => { - return Err(ParseError::Unknown); + return Err(std::io::Error::from(std::io::ErrorKind::InvalidInput)); } }; Ok(keyword) diff --git a/src/protocol/ping/src/ping/wire/request/mod.rs b/src/protocol/ping/src/ping/wire/request/mod.rs index 871d5dad1..f995a0b2f 100644 --- a/src/protocol/ping/src/ping/wire/request/mod.rs +++ b/src/protocol/ping/src/ping/wire/request/mod.rs @@ -11,7 +11,9 @@ mod parse; #[cfg(test)] mod test; +use crate::Response; pub use keyword::Keyword; +use logger::Klog; pub use parse::Parser as RequestParser; @@ -20,3 +22,13 @@ pub use parse::Parser as RequestParser; pub enum Request { Ping, } + +impl Klog for Request { + type Response = Response; + + fn klog(&self, _response: &Self::Response) { + match self { + Request::Ping => klog!("ping {}", 6), + } + } +} diff --git a/src/protocol/ping/src/ping/wire/request/parse.rs b/src/protocol/ping/src/ping/wire/request/parse.rs index 261e248db..1e772294d 100644 --- a/src/protocol/ping/src/ping/wire/request/parse.rs +++ b/src/protocol/ping/src/ping/wire/request/parse.rs @@ -21,7 +21,7 @@ impl Parser { } impl Parse for Parser { - fn parse(&self, buffer: &[u8]) -> Result, ParseError> { + fn parse(&self, buffer: &[u8]) -> Result, std::io::Error> { match parse_keyword(buffer)? 
{ Keyword::Ping => parse_ping(buffer), } @@ -52,7 +52,7 @@ impl<'a> ParseState<'a> { } } -fn parse_keyword(buffer: &[u8]) -> Result { +fn parse_keyword(buffer: &[u8]) -> Result { let command; { let mut parse_state = ParseState::new(buffer); @@ -63,14 +63,14 @@ fn parse_keyword(buffer: &[u8]) -> Result { command = Keyword::try_from(&buffer[0..line_end])?; } } else { - return Err(ParseError::Incomplete); + return Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)); } } Ok(command) } #[allow(clippy::unnecessary_wraps)] -fn parse_ping(buffer: &[u8]) -> Result, ParseError> { +fn parse_ping(buffer: &[u8]) -> Result, std::io::Error> { let mut parse_state = ParseState::new(buffer); // this was already checked for when determining the command diff --git a/src/protocol/ping/src/ping/wire/request/test.rs b/src/protocol/ping/src/ping/wire/request/test.rs index 49117e704..a414af774 100644 --- a/src/protocol/ping/src/ping/wire/request/test.rs +++ b/src/protocol/ping/src/ping/wire/request/test.rs @@ -4,8 +4,8 @@ //! Tests for the `Ping` protocol implementation. 
-use crate::RequestParser; use crate::*; +use std::io::ErrorKind; #[test] fn ping() { @@ -20,7 +20,7 @@ fn incomplete() { let parser = RequestParser::new(); if let Err(e) = parser.parse(b"ping") { - if e != ParseError::Incomplete { + if e.kind() != ErrorKind::WouldBlock { panic!("invalid parse result"); } } else { @@ -41,7 +41,7 @@ fn unknown() { for request in &["unknown\r\n"] { if let Err(e) = parser.parse(request.as_bytes()) { - if e != ParseError::Unknown { + if e.kind() != ErrorKind::InvalidInput { panic!("invalid parse result"); } } else { diff --git a/src/protocol/ping/src/ping/wire/response/compose.rs b/src/protocol/ping/src/ping/wire/response/compose.rs index b70e360ce..5f3a4b190 100644 --- a/src/protocol/ping/src/ping/wire/response/compose.rs +++ b/src/protocol/ping/src/ping/wire/response/compose.rs @@ -1,7 +1,4 @@ -use crate::Compose; -use crate::Response; -use session::Session; -use std::io::Write; +use crate::*; // TODO(bmartin): consider a different trait bound here when reworking buffers. // We ignore the unused result warnings here because we know we're using a @@ -10,10 +7,11 @@ use std::io::Write; #[allow(unused_must_use)] impl Compose for Response { - fn compose(&self, dst: &mut Session) { + fn compose(&self, dst: &mut dyn BufMut) -> usize { match self { Self::Pong => { - dst.write_all(b"PONG\r\n"); + dst.put_slice(b"PONG\r\n"); + 6 } } } diff --git a/src/protocol/ping/src/ping/wire/response/keyword.rs b/src/protocol/ping/src/ping/wire/response/keyword.rs index 30c1fcc9b..abd483877 100644 --- a/src/protocol/ping/src/ping/wire/response/keyword.rs +++ b/src/protocol/ping/src/ping/wire/response/keyword.rs @@ -4,7 +4,6 @@ //! This module defines all possible `Ping` commands. 
-use crate::ParseError; use core::convert::TryFrom; /// Ping response keywords @@ -13,13 +12,13 @@ pub enum Keyword { } impl TryFrom<&[u8]> for Keyword { - type Error = ParseError; + type Error = std::io::Error; fn try_from(value: &[u8]) -> Result { let keyword = match value { b"pong" | b"PONG" => Self::Pong, _ => { - return Err(ParseError::Unknown); + return Err(std::io::Error::from(std::io::ErrorKind::InvalidInput)); } }; Ok(keyword) diff --git a/src/protocol/ping/src/ping/wire/response/parse.rs b/src/protocol/ping/src/ping/wire/response/parse.rs index d85a1a67d..b582adf5f 100644 --- a/src/protocol/ping/src/ping/wire/response/parse.rs +++ b/src/protocol/ping/src/ping/wire/response/parse.rs @@ -22,7 +22,7 @@ impl Parser { } impl Parse for Parser { - fn parse(&self, buffer: &[u8]) -> Result, ParseError> { + fn parse(&self, buffer: &[u8]) -> Result, std::io::Error> { match parse_keyword(buffer)? { Keyword::Pong => parse_pong(buffer), } @@ -53,7 +53,7 @@ impl<'a> ParseState<'a> { } } -fn parse_keyword(buffer: &[u8]) -> Result { +fn parse_keyword(buffer: &[u8]) -> Result { let command; { let mut parse_state = ParseState::new(buffer); @@ -64,14 +64,14 @@ fn parse_keyword(buffer: &[u8]) -> Result { command = Keyword::try_from(&buffer[0..line_end])?; } } else { - return Err(ParseError::Incomplete); + return Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)); } } Ok(command) } #[allow(clippy::unnecessary_wraps)] -fn parse_pong(buffer: &[u8]) -> Result, ParseError> { +fn parse_pong(buffer: &[u8]) -> Result, std::io::Error> { let mut parse_state = ParseState::new(buffer); // this was already checked for when determining the command diff --git a/src/protocol/ping/src/ping/wire/response/test.rs b/src/protocol/ping/src/ping/wire/response/test.rs index a01a749a2..01ace19fd 100644 --- a/src/protocol/ping/src/ping/wire/response/test.rs +++ b/src/protocol/ping/src/ping/wire/response/test.rs @@ -4,8 +4,8 @@ //! Tests for the `Ping` protocol implementation. 
-use crate::ping::ResponseParser; use crate::*; +use std::io::ErrorKind; #[test] fn ping() { @@ -20,7 +20,7 @@ fn incomplete() { let parser = ResponseParser::new(); if let Err(e) = parser.parse(b"pong") { - if e != ParseError::Incomplete { + if e.kind() != ErrorKind::WouldBlock { panic!("invalid parse result"); } } else { @@ -41,7 +41,7 @@ fn unknown() { for request in &["unknown\r\n"] { if let Err(e) = parser.parse(request.as_bytes()) { - if e != ParseError::Unknown { + if e.kind() != ErrorKind::InvalidInput { panic!("invalid parse result"); } } else { diff --git a/src/protocol/resp/Cargo.toml b/src/protocol/resp/Cargo.toml index 3b709082f..4ef419efa 100644 --- a/src/protocol/resp/Cargo.toml +++ b/src/protocol/resp/Cargo.toml @@ -1,11 +1,16 @@ [package] name = "protocol-resp" -version = "0.1.0" +version = "0.2.0" edition = "2021" +authors = ["Brian Martin "] +homepage = "https://pelikan.io" +repository = "https://github.com/twitter/pelikan" +license = "Apache-2.0" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -nom = "*" +common = { path = "../../common" } +nom = "5.1.2" protocol-common = { path = "../../protocol/common" } -session = { path = "../../session" } \ No newline at end of file +rustcommon-metrics = { git = "https://github.com/twitter/rustcommon" } diff --git a/src/protocol/resp/src/lib.rs b/src/protocol/resp/src/lib.rs index a4151c30d..8dd68c328 100644 --- a/src/protocol/resp/src/lib.rs +++ b/src/protocol/resp/src/lib.rs @@ -11,3 +11,5 @@ pub(crate) use util::*; pub use request::*; pub use response::*; + +common::metrics::test_no_duplicates!(); diff --git a/src/protocol/resp/src/message/array.rs b/src/protocol/resp/src/message/array.rs index 9ab62ff65..785eef991 100644 --- a/src/protocol/resp/src/message/array.rs +++ b/src/protocol/resp/src/message/array.rs @@ -11,16 +11,22 @@ pub struct Array { } impl Compose for Array { - fn compose(&self, session: &mut session::Session) { + fn 
compose(&self, session: &mut dyn BufMut) -> usize { + let mut len = 0; if let Some(values) = &self.inner { - let _ = session.write_all(format!("${}\r\n", values.len()).as_bytes()); + let header = format!("${}\r\n", values.len()); + let _ = session.put_slice(header.as_bytes()); + len += header.as_bytes().len(); for value in values { - value.compose(session); + len += value.compose(session); } - let _ = session.write_all(b"\r\n"); + let _ = session.put_slice(b"\r\n"); + len += 2; } else { - let _ = session.write_all(b"*-1\r\n"); + let _ = session.put_slice(b"*-1\r\n"); + len += 5; } + len } } diff --git a/src/protocol/resp/src/message/bulk_string.rs b/src/protocol/resp/src/message/bulk_string.rs index 36dc89dba..6f9ec8d49 100644 --- a/src/protocol/resp/src/message/bulk_string.rs +++ b/src/protocol/resp/src/message/bulk_string.rs @@ -5,6 +5,8 @@ use super::*; use std::sync::Arc; +use std::io::{Error, ErrorKind}; + #[derive(Debug, PartialEq, Eq)] #[allow(clippy::redundant_allocation)] pub struct BulkString { @@ -26,28 +28,31 @@ impl From>> for BulkString { } impl TryInto for BulkString { - type Error = ParseError; + type Error = Error; - fn try_into(self) -> std::result::Result { + fn try_into(self) -> std::result::Result { if self.inner.is_none() { - return Err(ParseError::Invalid); + return Err(Error::new(ErrorKind::Other, "null bulk string")); } std::str::from_utf8(self.inner.as_ref().unwrap()) - .map_err(|_| ParseError::Invalid)? + .map_err(|_| Error::new(ErrorKind::Other, "bulk string is not valid utf8"))? 
.parse::() - .map_err(|_| ParseError::Invalid) + .map_err(|_| Error::new(ErrorKind::Other, "bulk string is not a valid u64")) } } impl Compose for BulkString { - fn compose(&self, session: &mut session::Session) { + fn compose(&self, buf: &mut dyn BufMut) -> usize { if let Some(value) = &self.inner { - let _ = session.write_all(format!("${}\r\n", value.len()).as_bytes()); - let _ = session.write_all(value); - let _ = session.write_all(b"\r\n"); + let header = format!("${}\r\n", value.len()); + let _ = buf.put_slice(header.as_bytes()); + let _ = buf.put_slice(value); + let _ = buf.put_slice(b"\r\n"); + header.as_bytes().len() + value.len() + 2 } else { - let _ = session.write_all(b"$-1\r\n"); + let _ = buf.put_slice(b"$-1\r\n"); + 5 } } } diff --git a/src/protocol/resp/src/message/error.rs b/src/protocol/resp/src/message/error.rs index 4a50f2884..07f25155c 100644 --- a/src/protocol/resp/src/message/error.rs +++ b/src/protocol/resp/src/message/error.rs @@ -10,10 +10,11 @@ pub struct Error { } impl Compose for Error { - fn compose(&self, session: &mut session::Session) { - let _ = session.write_all(b"-"); - let _ = session.write_all(self.inner.as_bytes()); - let _ = session.write_all(b"\r\n"); + fn compose(&self, buf: &mut dyn BufMut) -> usize { + let _ = buf.put_slice(b"-"); + let _ = buf.put_slice(self.inner.as_bytes()); + let _ = buf.put_slice(b"\r\n"); + self.inner.as_bytes().len() + 3 } } diff --git a/src/protocol/resp/src/message/integer.rs b/src/protocol/resp/src/message/integer.rs index d97e875bb..ac55d7e1a 100644 --- a/src/protocol/resp/src/message/integer.rs +++ b/src/protocol/resp/src/message/integer.rs @@ -10,8 +10,10 @@ pub struct Integer { } impl Compose for Integer { - fn compose(&self, session: &mut session::Session) { - let _ = session.write_all(format!(":{}\r\n", self.inner).as_bytes()); + fn compose(&self, buf: &mut dyn BufMut) -> usize { + let data = format!(":{}\r\n", self.inner); + let _ = buf.put_slice(data.as_bytes()); + data.as_bytes().len() } 
} diff --git a/src/protocol/resp/src/message/mod.rs b/src/protocol/resp/src/message/mod.rs index f6bfae3b4..f38b7f823 100644 --- a/src/protocol/resp/src/message/mod.rs +++ b/src/protocol/resp/src/message/mod.rs @@ -53,13 +53,13 @@ impl Message { } impl Compose for Message { - fn compose(&self, session: &mut session::Session) { + fn compose(&self, buf: &mut dyn BufMut) -> usize { match self { - Self::SimpleString(s) => s.compose(session), - Self::BulkString(s) => s.compose(session), - Self::Error(e) => e.compose(session), - Self::Integer(i) => i.compose(session), - Self::Array(a) => a.compose(session), + Self::SimpleString(s) => s.compose(buf), + Self::BulkString(s) => s.compose(buf), + Self::Error(e) => e.compose(buf), + Self::Integer(i) => i.compose(buf), + Self::Array(a) => a.compose(buf), } } } @@ -117,11 +117,14 @@ pub(crate) fn message(input: &[u8]) -> IResult<&[u8], Message> { } impl Parse for MessageParser { - fn parse(&self, buffer: &[u8]) -> Result, protocol_common::ParseError> { + fn parse(&self, buffer: &[u8]) -> Result, std::io::Error> { match message(buffer) { Ok((input, message)) => Ok(ParseOk::new(message, buffer.len() - input.len())), - Err(Err::Incomplete(_)) => Err(ParseError::Incomplete), - Err(_) => Err(ParseError::Invalid), + Err(Err::Incomplete(_)) => Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)), + Err(_) => Err(std::io::Error::new( + std::io::ErrorKind::Other, + "malformed message", + )), } } } diff --git a/src/protocol/resp/src/message/simple_string.rs b/src/protocol/resp/src/message/simple_string.rs index a588c5376..bc2afd9f0 100644 --- a/src/protocol/resp/src/message/simple_string.rs +++ b/src/protocol/resp/src/message/simple_string.rs @@ -10,10 +10,11 @@ pub struct SimpleString { } impl Compose for SimpleString { - fn compose(&self, session: &mut session::Session) { - let _ = session.write_all(b"+"); - let _ = session.write_all(self.inner.as_bytes()); - let _ = session.write_all(b"\r\n"); + fn compose(&self, buf: &mut dyn 
BufMut) -> usize { + let _ = buf.put_slice(b"+"); + let _ = buf.put_slice(self.inner.as_bytes()); + let _ = buf.put_slice(b"\r\n"); + self.inner.as_bytes().len() + 3 } } diff --git a/src/protocol/resp/src/request/get.rs b/src/protocol/resp/src/request/get.rs index 2dff42622..51f3b068a 100644 --- a/src/protocol/resp/src/request/get.rs +++ b/src/protocol/resp/src/request/get.rs @@ -3,6 +3,7 @@ // http://www.apache.org/licenses/LICENSE-2.0 use super::*; +use std::io::{Error, ErrorKind}; use std::sync::Arc; #[derive(Debug, PartialEq, Eq)] @@ -12,39 +13,39 @@ pub struct GetRequest { } impl TryFrom for GetRequest { - type Error = ParseError; + type Error = Error; - fn try_from(other: Message) -> Result { + fn try_from(other: Message) -> Result { if let Message::Array(array) = other { if array.inner.is_none() { - return Err(ParseError::Invalid); + return Err(Error::new(ErrorKind::Other, "malformed command")); } let mut array = array.inner.unwrap(); if array.len() != 2 { - return Err(ParseError::Invalid); + return Err(Error::new(ErrorKind::Other, "malformed command")); } let key = if let Message::BulkString(key) = array.remove(1) { if key.inner.is_none() { - return Err(ParseError::Invalid); + return Err(Error::new(ErrorKind::Other, "malformed command")); } let key = key.inner.unwrap(); if key.len() == 0 { - return Err(ParseError::Invalid); + return Err(Error::new(ErrorKind::Other, "malformed command")); } key } else { - return Err(ParseError::Invalid); + return Err(Error::new(ErrorKind::Other, "malformed command")); }; Ok(Self { key }) } else { - Err(ParseError::Invalid) + Err(Error::new(ErrorKind::Other, "malformed command")) } } } @@ -73,9 +74,9 @@ impl From<&GetRequest> for Message { } impl Compose for GetRequest { - fn compose(&self, session: &mut session::Session) { + fn compose(&self, buf: &mut dyn BufMut) -> usize { let message = Message::from(self); - message.compose(session) + message.compose(buf) } } diff --git a/src/protocol/resp/src/request/mod.rs 
b/src/protocol/resp/src/request/mod.rs index 2156eb691..b5909a9f2 100644 --- a/src/protocol/resp/src/request/mod.rs +++ b/src/protocol/resp/src/request/mod.rs @@ -4,9 +4,10 @@ use crate::message::*; use crate::*; +use protocol_common::BufMut; use protocol_common::Parse; -use protocol_common::{ParseError, ParseOk}; -use session::Session; +use protocol_common::ParseOk; +use std::io::{Error, ErrorKind}; use std::sync::Arc; mod get; @@ -29,11 +30,11 @@ impl RequestParser { } impl Parse for RequestParser { - fn parse(&self, buffer: &[u8]) -> Result, protocol_common::ParseError> { + fn parse(&self, buffer: &[u8]) -> Result, Error> { // we have two different parsers, one for RESP and one for inline // both require that there's at least one character in the buffer if buffer.is_empty() { - return Err(ParseError::Incomplete); + return Err(Error::from(ErrorKind::WouldBlock)); } let (message, consumed) = if matches!(buffer[0], b'*' | b'+' | b'-' | b':' | b'$') { @@ -65,7 +66,7 @@ impl Parse for RequestParser { } if &remaining[0..2] != b"\r\n" { - return Err(ParseError::Incomplete); + return Err(Error::from(ErrorKind::WouldBlock)); } let message = Message::Array(Array { @@ -80,13 +81,13 @@ impl Parse for RequestParser { match &message { Message::Array(array) => { if array.inner.is_none() { - return Err(ParseError::Invalid); + return Err(Error::new(ErrorKind::Other, "malformed command")); } let array = array.inner.as_ref().unwrap(); if array.is_empty() { - return Err(ParseError::Invalid); + return Err(Error::new(ErrorKind::Other, "malformed command")); } match &array[0] { @@ -97,17 +98,17 @@ impl Parse for RequestParser { Some(b"set") | Some(b"SET") => { SetRequest::try_from(message).map(Request::from) } - _ => Err(ParseError::Invalid), + _ => Err(Error::new(ErrorKind::Other, "unknown command")), }, _ => { // all valid commands are encoded as a bulk string - Err(ParseError::Invalid) + Err(Error::new(ErrorKind::Other, "malformed command")) } } } _ => { // all valid requests are 
arrays - Err(ParseError::Invalid) + Err(Error::new(ErrorKind::Other, "malformed command")) } } .map(|v| ParseOk::new(v, consumed)) @@ -115,10 +116,10 @@ impl Parse for RequestParser { } impl Compose for Request { - fn compose(&self, session: &mut Session) { + fn compose(&self, buf: &mut dyn BufMut) -> usize { match self { - Self::Get(r) => r.compose(session), - Self::Set(r) => r.compose(session), + Self::Get(r) => r.compose(buf), + Self::Set(r) => r.compose(buf), } } } diff --git a/src/protocol/resp/src/request/set.rs b/src/protocol/resp/src/request/set.rs index 46d7901d5..21e285ca9 100644 --- a/src/protocol/resp/src/request/set.rs +++ b/src/protocol/resp/src/request/set.rs @@ -3,6 +3,7 @@ // http://www.apache.org/licenses/LICENSE-2.0 use super::*; +use std::io::{Error, ErrorKind}; use std::sync::Arc; #[derive(Debug, PartialEq, Eq, Copy, Clone)] @@ -45,23 +46,23 @@ impl SetRequest { } impl TryFrom for SetRequest { - type Error = ParseError; + type Error = Error; - fn try_from(other: Message) -> Result { + fn try_from(other: Message) -> Result { if let Message::Array(array) = other { if array.inner.is_none() { - return Err(ParseError::Invalid); + return Err(Error::new(ErrorKind::Other, "malformed command")); } let mut array = array.inner.unwrap(); if array.len() < 3 { - return Err(ParseError::Invalid); + return Err(Error::new(ErrorKind::Other, "malformed command")); } let key = take_bulk_string(&mut array)?; if key.is_empty() { - return Err(ParseError::Invalid); + return Err(Error::new(ErrorKind::Other, "malformed command")); } let value = take_bulk_string(&mut array)?; @@ -75,14 +76,14 @@ impl TryFrom for SetRequest { while i < array.len() { if let Message::BulkString(field) = &array[i] { if field.inner.is_none() { - return Err(ParseError::Invalid); + return Err(Error::new(ErrorKind::Other, "malformed command")); } let field = field.inner.as_ref().unwrap(); match field.as_ref().as_ref() { b"EX" => { if expire_time.is_some() || array.len() < i + 2 { - return 
Err(ParseError::Invalid); + return Err(Error::new(ErrorKind::Other, "malformed command")); } let s = take_bulk_string_as_u64(&mut array)?; expire_time = Some(ExpireTime::Seconds(s)); @@ -90,7 +91,7 @@ impl TryFrom for SetRequest { } b"PX" => { if expire_time.is_some() || array.len() < i + 2 { - return Err(ParseError::Invalid); + return Err(Error::new(ErrorKind::Other, "malformed command")); } let ms = take_bulk_string_as_u64(&mut array)?; expire_time = Some(ExpireTime::Milliseconds(ms)); @@ -98,7 +99,7 @@ impl TryFrom for SetRequest { } b"EXAT" => { if expire_time.is_some() || array.len() < i + 2 { - return Err(ParseError::Invalid); + return Err(Error::new(ErrorKind::Other, "malformed command")); } let s = take_bulk_string_as_u64(&mut array)?; expire_time = Some(ExpireTime::UnixSeconds(s)); @@ -106,7 +107,7 @@ impl TryFrom for SetRequest { } b"PXAT" => { if expire_time.is_some() || array.len() < i + 2 { - return Err(ParseError::Invalid); + return Err(Error::new(ErrorKind::Other, "malformed command")); } let ms = take_bulk_string_as_u64(&mut array)?; expire_time = Some(ExpireTime::UnixMilliseconds(ms)); @@ -114,37 +115,37 @@ impl TryFrom for SetRequest { } b"KEEPTTL" => { if expire_time.is_some() { - return Err(ParseError::Invalid); + return Err(Error::new(ErrorKind::Other, "malformed command")); } expire_time = Some(ExpireTime::KeepTtl); } b"NX" => { if mode != SetMode::Set { - return Err(ParseError::Invalid); + return Err(Error::new(ErrorKind::Other, "malformed command")); } mode = SetMode::Add; } b"XX" => { if mode != SetMode::Set { - return Err(ParseError::Invalid); + return Err(Error::new(ErrorKind::Other, "malformed command")); } mode = SetMode::Replace; } b"GET" => { if get_old { - return Err(ParseError::Invalid); + return Err(Error::new(ErrorKind::Other, "malformed command")); } get_old = true; } _ => { - return Err(ParseError::Invalid); + return Err(Error::new(ErrorKind::Other, "malformed command")); } } } else { - return Err(ParseError::Invalid); + return 
Err(Error::new(ErrorKind::Other, "malformed command")); } i += 1; @@ -158,7 +159,7 @@ impl TryFrom for SetRequest { get_old, }) } else { - Err(ParseError::Invalid) + Err(Error::new(ErrorKind::Other, "malformed command")) } } } @@ -213,9 +214,9 @@ impl From<&SetRequest> for Message { } impl Compose for SetRequest { - fn compose(&self, session: &mut session::Session) { + fn compose(&self, buf: &mut dyn BufMut) -> usize { let message = Message::from(self); - message.compose(session) + message.compose(buf) } } diff --git a/src/protocol/resp/src/util.rs b/src/protocol/resp/src/util.rs index 7fa5571bf..f6ba32c03 100644 --- a/src/protocol/resp/src/util.rs +++ b/src/protocol/resp/src/util.rs @@ -2,14 +2,13 @@ // Licensed under the Apache License, Version 2.0 // http://www.apache.org/licenses/LICENSE-2.0 -use crate::message::*; pub use nom::bytes::streaming::*; pub use nom::character::streaming::*; -pub use nom::error::ErrorKind; pub use nom::{AsChar, Err, IResult, InputTakeAtPosition, Needed}; pub use protocol_common::Compose; -use protocol_common::ParseError; -pub use std::io::Write; +pub use std::io::{Error, ErrorKind, Write}; + +use crate::message::*; use std::sync::Arc; // consumes one or more literal spaces @@ -19,7 +18,7 @@ pub fn space1(input: &[u8]) -> IResult<&[u8], &[u8]> { let c = item.as_char(); c != ' ' }, - ErrorKind::Space, + nom::error::ErrorKind::Space, ) } @@ -39,24 +38,27 @@ pub fn string(input: &[u8]) -> IResult<&[u8], &[u8]> { } #[allow(clippy::redundant_allocation)] -pub fn take_bulk_string(array: &mut Vec) -> Result>, ParseError> { +pub fn take_bulk_string(array: &mut Vec) -> Result>, Error> { if let Message::BulkString(s) = array.remove(1) { if s.inner.is_none() { - return Err(ParseError::Invalid); + return Err(Error::new(ErrorKind::Other, "bulk string is null")); } let s = s.inner.unwrap(); Ok(s) } else { - Err(ParseError::Invalid) + Err(Error::new( + ErrorKind::Other, + "next array element is not a bulk string", + )) } } -pub fn 
take_bulk_string_as_u64(array: &mut Vec) -> Result { +pub fn take_bulk_string_as_u64(array: &mut Vec) -> Result { let s = take_bulk_string(array)?; std::str::from_utf8(&s) - .map_err(|_| ParseError::Invalid)? + .map_err(|_| Error::new(ErrorKind::Other, "bulk string not valid utf8"))? .parse::() - .map_err(|_| ParseError::Invalid) + .map_err(|_| Error::new(ErrorKind::Other, "bulk string is not a u64")) } diff --git a/src/protocol/thrift/Cargo.toml b/src/protocol/thrift/Cargo.toml index 6c1b0c086..ce58884d6 100644 --- a/src/protocol/thrift/Cargo.toml +++ b/src/protocol/thrift/Cargo.toml @@ -1,12 +1,16 @@ [package] name = "protocol-thrift" -version = "0.0.1" +version = "0.0.2" edition = "2021" +authors = ["Brian Martin "] +homepage = "https://pelikan.io" +repository = "https://github.com/twitter/pelikan" +license = "Apache-2.0" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +common = { path = "../../common" } logger = { path = "../../logger" } protocol-common = { path = "../../protocol/common" } rustcommon-metrics = { git = "https://github.com/twitter/rustcommon" } -session = { path = "../../session" } diff --git a/src/protocol/thrift/src/lib.rs b/src/protocol/thrift/src/lib.rs index cf2c35364..0aacd1440 100644 --- a/src/protocol/thrift/src/lib.rs +++ b/src/protocol/thrift/src/lib.rs @@ -4,11 +4,11 @@ //! A protocol crate for Thrift binary protocol. 
+use protocol_common::BufMut; use protocol_common::Compose; use protocol_common::Parse; -use protocol_common::{ParseError, ParseOk}; +use protocol_common::ParseOk; use rustcommon_metrics::*; -use std::io::Write; const THRIFT_HEADER_LEN: usize = std::mem::size_of::(); @@ -29,10 +29,11 @@ impl Message { } impl Compose for Message { - fn compose(&self, session: &mut session::Session) { + fn compose(&self, session: &mut dyn BufMut) -> usize { MESSAGES_COMPOSED.increment(); - let _ = session.write_all(&(self.data.len() as u32).to_be_bytes()); - let _ = session.write_all(&self.data); + session.put_slice(&(self.data.len() as u32).to_be_bytes()); + session.put_slice(&self.data); + std::mem::size_of::() + self.data.len() } } @@ -49,9 +50,9 @@ impl MessageParser { } impl Parse for MessageParser { - fn parse(&self, buffer: &[u8]) -> Result, ParseError> { + fn parse(&self, buffer: &[u8]) -> Result, std::io::Error> { if buffer.len() < THRIFT_HEADER_LEN { - return Err(ParseError::Incomplete); + return Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)); } let data_len = u32::from_be_bytes([buffer[0], buffer[1], buffer[2], buffer[3]]); @@ -59,11 +60,11 @@ impl Parse for MessageParser { let framed_len = THRIFT_HEADER_LEN + data_len as usize; if framed_len == 0 || framed_len > self.max_size { - return Err(ParseError::Invalid); + return Err(std::io::Error::from(std::io::ErrorKind::InvalidInput)); } if buffer.len() < framed_len { - Err(ParseError::Incomplete) + Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)) } else { MESSAGES_PARSED.increment(); let data = buffer[THRIFT_HEADER_LEN..framed_len] @@ -97,3 +98,5 @@ mod tests { assert_eq!(*parsed.data, body); } } + +common::metrics::test_no_duplicates!(); diff --git a/src/proxy/momento/Cargo.toml b/src/proxy/momento/Cargo.toml index 6440870f4..8c680dacf 100644 --- a/src/proxy/momento/Cargo.toml +++ b/src/proxy/momento/Cargo.toml @@ -16,10 +16,11 @@ config = { path = "../../config" } libc = "0.2.83" logger = { path = 
"../../logger" } momento = "0.3.1" +net = { path = "../../net" } protocol-admin = { path = "../../protocol/admin" } protocol-memcache = { path = "../../protocol/memcache" } protocol-resp = { path = "../../protocol/resp" } rustcommon-metrics = { git = "https://github.com/twitter/rustcommon" } session = { path = "../../session" } storage-types = { path = "../../storage/types" } -tokio = { version = "1.17.0", features = ["full"] } \ No newline at end of file +tokio = { version = "1.17.0", features = ["full"] } diff --git a/src/proxy/momento/src/admin.rs b/src/proxy/momento/src/admin.rs index 84d3bd4fd..41cd0ab81 100644 --- a/src/proxy/momento/src/admin.rs +++ b/src/proxy/momento/src/admin.rs @@ -3,6 +3,11 @@ // http://www.apache.org/licenses/LICENSE-2.0 use crate::*; +use session::Buf; + +gauge!(ADMIN_CONN_CURR); +counter!(ADMIN_CONN_ACCEPT); +counter!(ADMIN_CONN_CLOSE); pub(crate) async fn admin(mut log_drain: Box, admin_listener: TcpListener) { loop { @@ -12,12 +17,12 @@ pub(crate) async fn admin(mut log_drain: Box, admin_listener: if let Ok(Ok((socket, _))) = timeout(Duration::from_millis(1), admin_listener.accept()).await { - TCP_CONN_CURR.increment(); - TCP_ACCEPT.increment(); + ADMIN_CONN_CURR.increment(); + ADMIN_CONN_ACCEPT.increment(); tokio::spawn(async move { admin::handle_admin_client(socket).await; - TCP_CLOSE.increment(); - TCP_CONN_CURR.decrement(); + ADMIN_CONN_CLOSE.increment(); + ADMIN_CONN_CURR.decrement(); }); }; @@ -71,7 +76,7 @@ pub(crate) async fn admin(mut log_drain: Box, admin_listener: async fn handle_admin_client(mut socket: tokio::net::TcpStream) { // initialize a buffer for incoming bytes from the client - let mut buf = Buffer::with_capacity(INITIAL_BUFFER_SIZE); + let mut buf = Buffer::new(INITIAL_BUFFER_SIZE); // initialize the request parser let parser = AdminRequestParser::new(); @@ -80,15 +85,17 @@ async fn handle_admin_client(mut socket: tokio::net::TcpStream) { break; } - ADMIN_REQUEST_PARSE.increment(); - match 
parser.parse(buf.borrow()) { Ok(request) => { + ADMIN_REQUEST_PARSE.increment(); + let consumed = request.consumed(); let request = request.into_inner(); match request { AdminRequest::Stats { .. } => { + ADMIN_RESPONSE_COMPOSE.increment(); + if stats_response(&mut socket).await.is_err() { break; } @@ -97,19 +104,16 @@ async fn handle_admin_client(mut socket: tokio::net::TcpStream) { debug!("unsupported command: {:?}", request); } } - buf.consume(consumed); - } - Err(ParseError::Incomplete) => {} - Err(ParseError::Invalid) => { - // invalid request - let _ = socket.write_all(b"CLIENT_ERROR\r\n").await; - break; - } - Err(ParseError::Unknown) => { - // unknown command - let _ = socket.write_all(b"CLIENT_ERROR\r\n").await; - break; + buf.advance(consumed); } + Err(e) => match e.kind() { + ErrorKind::WouldBlock => {} + _ => { + // invalid request + let _ = socket.write_all(b"CLIENT_ERROR\r\n").await; + break; + } + }, } } } @@ -161,7 +165,6 @@ async fn stats_response(socket: &mut tokio::net::TcpStream) -> Result<(), Error> } data.sort(); - ADMIN_RESPONSE_COMPOSE.increment(); for line in data { socket.write_all(line.as_bytes()).await?; } diff --git a/src/proxy/momento/src/frontend.rs b/src/proxy/momento/src/frontend.rs index 39010bd6f..8cbb70fd4 100644 --- a/src/proxy/momento/src/frontend.rs +++ b/src/proxy/momento/src/frontend.rs @@ -4,6 +4,7 @@ use crate::protocol::*; use crate::*; +use session::Buf; pub(crate) async fn handle_memcache_client( mut socket: tokio::net::TcpStream, @@ -11,7 +12,7 @@ pub(crate) async fn handle_memcache_client( cache_name: String, ) { // initialize a buffer for incoming bytes from the client - let mut buf = Buffer::with_capacity(INITIAL_BUFFER_SIZE); + let mut buf = Buffer::new(INITIAL_BUFFER_SIZE); // initialize the request parser let parser = memcache::RequestParser::new(); @@ -48,19 +49,16 @@ pub(crate) async fn handle_memcache_client( debug!("unsupported command: {}", request); } } - buf.consume(consumed); - } - 
Err(ParseError::Incomplete) => {} - Err(ParseError::Invalid) => { - // invalid request - let _ = socket.write_all(b"CLIENT_ERROR\r\n").await; - break; - } - Err(ParseError::Unknown) => { - // unknown command - let _ = socket.write_all(b"CLIENT_ERROR\r\n").await; - break; + buf.advance(consumed); } + Err(e) => match e.kind() { + ErrorKind::WouldBlock => {} + _ => { + // invalid request + let _ = socket.write_all(b"CLIENT_ERROR\r\n").await; + break; + } + }, } } } @@ -71,7 +69,7 @@ pub(crate) async fn handle_resp_client( cache_name: String, ) { // initialize a buffer for incoming bytes from the client - let mut buf = Buffer::with_capacity(INITIAL_BUFFER_SIZE); + let mut buf = Buffer::new(INITIAL_BUFFER_SIZE); // initialize the request parser let parser = resp::RequestParser::new(); @@ -105,21 +103,16 @@ pub(crate) async fn handle_resp_client( } } } - buf.consume(consumed); - } - Err(ParseError::Incomplete) => {} - Err(ParseError::Invalid) => { - // invalid request - println!("bad request"); - let _ = socket.write_all(b"CLIENT_ERROR\r\n").await; - break; - } - Err(ParseError::Unknown) => { - // unknown command - println!("unknown command"); - let _ = socket.write_all(b"CLIENT_ERROR\r\n").await; - break; + buf.advance(consumed); } + Err(e) => match e.kind() { + ErrorKind::WouldBlock => {} + _ => { + println!("bad request"); + let _ = socket.write_all(b"CLIENT_ERROR\r\n").await; + break; + } + }, } } } diff --git a/src/proxy/momento/src/klog.rs b/src/proxy/momento/src/klog.rs index 734d7f36b..684dd81b5 100644 --- a/src/proxy/momento/src/klog.rs +++ b/src/proxy/momento/src/klog.rs @@ -1,3 +1,7 @@ +// Copyright 2022 Twitter, Inc. 
+// Licensed under the Apache License, Version 2.0 +// http://www.apache.org/licenses/LICENSE-2.0 + pub(crate) fn klog_get(key: &str, response_len: usize) { if response_len == 0 { klog!("\"get {}\" 0 {}", key, response_len); diff --git a/src/proxy/momento/src/listener.rs b/src/proxy/momento/src/listener.rs index ff9281102..54c1419a7 100644 --- a/src/proxy/momento/src/listener.rs +++ b/src/proxy/momento/src/listener.rs @@ -3,6 +3,7 @@ // http://www.apache.org/licenses/LICENSE-2.0 use crate::*; +use ::net::{TCP_ACCEPT, TCP_CLOSE, TCP_CONN_CURR}; pub(crate) async fn listener( listener: TcpListener, @@ -31,8 +32,8 @@ pub(crate) async fn listener( } } - TCP_CLOSE.increment(); TCP_CONN_CURR.decrement(); + TCP_CLOSE.increment(); }); } } diff --git a/src/proxy/momento/src/main.rs b/src/proxy/momento/src/main.rs index 17ac51382..ffbfb2123 100644 --- a/src/proxy/momento/src/main.rs +++ b/src/proxy/momento/src/main.rs @@ -19,6 +19,7 @@ use momento::response::cache_get_response::*; use momento::response::cache_set_response::*; use momento::response::error::*; use momento::simple_cache_client::*; +use net::TCP_RECV_BYTE; use protocol_admin::*; use rustcommon_metrics::*; use session::*; @@ -85,18 +86,6 @@ counter!(BACKEND_EX); counter!(BACKEND_EX_RATE_LIMITED); counter!(BACKEND_EX_TIMEOUT); -counter!(GET); -counter!(GET_EX); -counter!(GET_KEY); -counter!(GET_KEY_EX); -counter!(GET_KEY_HIT); -counter!(GET_KEY_MISS); - -counter!(SET); -counter!(SET_EX); -counter!(SET_NOT_STORED); -counter!(SET_STORED); - counter!(RU_UTIME); counter!(RU_STIME); gauge!(RU_MAXRSS); @@ -364,11 +353,13 @@ async fn do_read( TCP_RECV_BYTE.add(n as _); // non-zero means we have some data, mark the buffer as // having additional content - buf.increase_len(n); + unsafe { + buf.advance_mut(n); + } // if the buffer is low on space, we will grow the // buffer - if buf.available_capacity() * 2 < INITIAL_BUFFER_SIZE { + if buf.remaining_mut() * 2 < INITIAL_BUFFER_SIZE { buf.reserve(INITIAL_BUFFER_SIZE); } @@ 
-385,3 +376,5 @@ async fn do_read( } } } + +common::metrics::test_no_duplicates!(); diff --git a/src/proxy/momento/src/protocol/memcache/get.rs b/src/proxy/momento/src/protocol/memcache/get.rs index 1a946dd94..fe83f34d0 100644 --- a/src/proxy/momento/src/protocol/memcache/get.rs +++ b/src/proxy/momento/src/protocol/memcache/get.rs @@ -1,7 +1,11 @@ -use crate::klog::klog_get; -use crate::*; +// Copyright 2022 Twitter, Inc. +// Licensed under the Apache License, Version 2.0 +// http://www.apache.org/licenses/LICENSE-2.0 -pub use protocol_memcache::{Request, RequestParser}; +use crate::klog::klog_get; +use crate::{Error, *}; +use ::net::*; +use protocol_memcache::*; pub async fn get( client: &mut SimpleCacheClient, @@ -9,8 +13,6 @@ pub async fn get( socket: &mut tokio::net::TcpStream, keys: &[Box<[u8]>], ) -> Result<(), Error> { - GET.increment(); - // check if any of the keys are invalid before // sending the requests to the backend for key in keys.iter() { @@ -25,9 +27,8 @@ pub async fn get( let mut response_buf = Vec::new(); - for key in keys.iter() { + for key in keys { BACKEND_REQUEST.increment(); - GET_KEY.increment(); // we've already checked the keys, so we // know this unwrap is safe @@ -41,13 +42,8 @@ pub async fn get( // the backend. BACKEND_EX.increment(); - GET_KEY_EX.increment(); - - // TODO: what is the right - // way to handle this? 
- // - // currently ignoring and - // moving on to the next key + // ignore and move on to the next key, treating this as + // a cache miss } MomentoGetStatus::HIT => { GET_KEY_HIT.increment(); @@ -56,9 +52,7 @@ pub async fn get( let item_header = format!("VALUE {} 0 {}\r\n", key, length); - let response_len = 2 + item_header.len() + response.value.len(); - - klog_get(key, response_len); + klog_get(key, response.value.len()); response_buf.extend_from_slice(item_header.as_bytes()); response_buf.extend_from_slice(&response.value); @@ -76,20 +70,17 @@ pub async fn get( Ok(Err(MomentoError::LimitExceeded(_))) => { BACKEND_EX.increment(); BACKEND_EX_RATE_LIMITED.increment(); - GET_KEY_EX.increment(); } Ok(Err(e)) => { // we got some error from the momento client // log and incr stats and move on treating it // as a miss error!("error for get: {}", e); - GET_KEY_EX.increment(); BACKEND_EX.increment(); } Err(_) => { // we had a timeout, incr stats and move on // treating it as a miss - GET_KEY_EX.increment(); BACKEND_EX.increment(); BACKEND_EX_TIMEOUT.increment(); } diff --git a/src/proxy/momento/src/protocol/memcache/mod.rs b/src/proxy/momento/src/protocol/memcache/mod.rs index 345b9cff1..8652401d6 100644 --- a/src/proxy/momento/src/protocol/memcache/mod.rs +++ b/src/proxy/momento/src/protocol/memcache/mod.rs @@ -1,3 +1,9 @@ +// Copyright 2022 Twitter, Inc. +// Licensed under the Apache License, Version 2.0 +// http://www.apache.org/licenses/LICENSE-2.0 + +pub use protocol_memcache::{Request, RequestParser}; + mod get; mod set; diff --git a/src/proxy/momento/src/protocol/memcache/set.rs b/src/proxy/momento/src/protocol/memcache/set.rs index 50e576d88..7651cdf4e 100644 --- a/src/proxy/momento/src/protocol/memcache/set.rs +++ b/src/proxy/momento/src/protocol/memcache/set.rs @@ -1,5 +1,11 @@ +// Copyright 2022 Twitter, Inc. 
+// Licensed under the Apache License, Version 2.0 +// http://www.apache.org/licenses/LICENSE-2.0 + use crate::klog::klog_set; -use crate::*; +use crate::{Error, *}; +use ::net::*; +use protocol_memcache::*; pub async fn set( client: &mut SimpleCacheClient, @@ -20,15 +26,19 @@ pub async fn set( if value.is_empty() { error!("empty values are not supported by momento"); - SESSION_SEND.increment(); - SESSION_SEND_BYTE.add(7); - TCP_SEND_BYTE.add(7); - if socket.write_all(b"ERROR\r\n").await.is_err() { - SESSION_SEND_EX.increment(); - } + let _ = socket.write_all(b"ERROR\r\n").await; + return Err(Error::from(ErrorKind::InvalidInput)); } + let value = if let Ok(value) = std::str::from_utf8(request.value()) { + value.to_owned() + } else { + debug!("value is not valid utf8: {:?}", request.value()); + let _ = socket.write_all(b"ERROR\r\n").await; + return Err(Error::from(ErrorKind::InvalidInput)); + }; + BACKEND_REQUEST.increment(); let ttl = if let Some(ttl) = request.ttl() { @@ -98,6 +108,7 @@ pub async fn set( SESSION_SEND.increment(); SESSION_SEND_BYTE.add(12); TCP_SEND_BYTE.add(12); + // let client know this wasn't stored if let Err(e) = socket.write_all(b"NOT_STORED\r\n").await { SESSION_SEND_EX.increment(); diff --git a/src/proxy/momento/src/protocol/mod.rs b/src/proxy/momento/src/protocol/mod.rs index 95c1a461f..c07a4ed8c 100644 --- a/src/proxy/momento/src/protocol/mod.rs +++ b/src/proxy/momento/src/protocol/mod.rs @@ -1,2 +1,6 @@ +// Copyright 2022 Twitter, Inc. +// Licensed under the Apache License, Version 2.0 +// http://www.apache.org/licenses/LICENSE-2.0 + pub mod memcache; pub mod resp; diff --git a/src/proxy/momento/src/protocol/resp/get.rs b/src/proxy/momento/src/protocol/resp/get.rs index e55069adb..0195bed89 100644 --- a/src/proxy/momento/src/protocol/resp/get.rs +++ b/src/proxy/momento/src/protocol/resp/get.rs @@ -1,7 +1,11 @@ -use crate::klog::klog_get; -use crate::*; +// Copyright 2022 Twitter, Inc. 
+// Licensed under the Apache License, Version 2.0 +// http://www.apache.org/licenses/LICENSE-2.0 -pub use protocol_resp::{Request, RequestParser}; +use crate::klog::klog_get; +use crate::{Error, *}; +use ::net::*; +use protocol_memcache::*; pub async fn get( client: &mut SimpleCacheClient, @@ -37,8 +41,6 @@ pub async fn get( // the backend. BACKEND_EX.increment(); - GET_KEY_EX.increment(); - // TODO: what is the right // way to handle this? // @@ -72,7 +74,6 @@ pub async fn get( Ok(Err(MomentoError::LimitExceeded(_))) => { BACKEND_EX.increment(); BACKEND_EX_RATE_LIMITED.increment(); - GET_KEY_EX.increment(); response_buf.extend_from_slice(b"-ERR ratelimit exceed\r\n"); } Ok(Err(e)) => { @@ -80,14 +81,12 @@ pub async fn get( // log and incr stats and move on treating it // as a miss error!("error for get: {}", e); - GET_KEY_EX.increment(); BACKEND_EX.increment(); response_buf.extend_from_slice(b"-ERR backend error\r\n"); } Err(_) => { // we had a timeout, incr stats and move on // treating it as a miss - GET_KEY_EX.increment(); BACKEND_EX.increment(); BACKEND_EX_TIMEOUT.increment(); response_buf.extend_from_slice(b"-ERR backend timeout\r\n"); diff --git a/src/proxy/momento/src/protocol/resp/mod.rs b/src/proxy/momento/src/protocol/resp/mod.rs index 345b9cff1..fb65889e3 100644 --- a/src/proxy/momento/src/protocol/resp/mod.rs +++ b/src/proxy/momento/src/protocol/resp/mod.rs @@ -1,3 +1,9 @@ +// Copyright 2022 Twitter, Inc. +// Licensed under the Apache License, Version 2.0 +// http://www.apache.org/licenses/LICENSE-2.0 + +pub use protocol_resp::{Request, RequestParser}; + mod get; mod set; diff --git a/src/proxy/momento/src/protocol/resp/set.rs b/src/proxy/momento/src/protocol/resp/set.rs index 3a01ce702..e929815e3 100644 --- a/src/proxy/momento/src/protocol/resp/set.rs +++ b/src/proxy/momento/src/protocol/resp/set.rs @@ -1,6 +1,11 @@ -use crate::klog::klog_set; -use crate::*; +// Copyright 2022 Twitter, Inc. 
+// Licensed under the Apache License, Version 2.0 +// http://www.apache.org/licenses/LICENSE-2.0 +use crate::klog::klog_set; +use crate::{Error, *}; +use ::net::*; +use protocol_memcache::*; use protocol_resp::SetRequest; pub async fn set( @@ -25,6 +30,7 @@ pub async fn set( SESSION_SEND.increment(); SESSION_SEND_BYTE.add(7); TCP_SEND_BYTE.add(7); + if socket.write_all(b"ERROR\r\n").await.is_err() { SESSION_SEND_EX.increment(); } @@ -68,6 +74,7 @@ pub async fn set( SESSION_SEND.increment(); SESSION_SEND_BYTE.add(8); TCP_SEND_BYTE.add(8); + if let Err(e) = socket.write_all(b"+OK\r\n").await { SESSION_SEND_EX.increment(); // hangup if we can't send a response back @@ -87,6 +94,7 @@ pub async fn set( SESSION_SEND.increment(); SESSION_SEND_BYTE.add(12); TCP_SEND_BYTE.add(12); + // let client know this wasn't stored if let Err(e) = socket.write_all(b"-ERR backend error\r\n").await { SESSION_SEND_EX.increment(); @@ -142,7 +150,7 @@ pub async fn set( // let client know this wasn't stored if let Err(e) = socket.write_all(b"-ERR backend error\r\n").await { - SESSION_SEND_EX.increment(); + // SESSION_SEND_EX.increment(); // hangup if we can't send a response back return Err(e); } diff --git a/src/proxy/ping/src/lib.rs b/src/proxy/ping/src/lib.rs index 328f4702b..ca5bb0088 100644 --- a/src/proxy/ping/src/lib.rs +++ b/src/proxy/ping/src/lib.rs @@ -7,6 +7,14 @@ use logger::configure_logging; use protocol_ping::*; use proxy::{Process, ProcessBuilder}; +type BackendParser = ResponseParser; +type BackendRequest = Request; +type BackendResponse = Response; + +type FrontendParser = RequestParser; +type FrontendRequest = Request; +type FrontendResponse = Response; + #[allow(dead_code)] pub struct Pingproxy { process: Process, @@ -32,9 +40,15 @@ impl Pingproxy { let response_parser = ResponseParser::new(); // initialize process - let process_builder = - ProcessBuilder::new(config, request_parser, response_parser, log_drain) - .expect("failed to launch"); + let process_builder = 
ProcessBuilder::< + BackendParser, + BackendRequest, + BackendResponse, + FrontendParser, + FrontendRequest, + FrontendResponse, + >::new(&config, log_drain, response_parser, request_parser) + .expect("failed to launch"); let process = process_builder.spawn(); Self { process } diff --git a/src/proxy/thrift/src/lib.rs b/src/proxy/thrift/src/lib.rs index 1fbcd81f7..b05169e0e 100644 --- a/src/proxy/thrift/src/lib.rs +++ b/src/proxy/thrift/src/lib.rs @@ -9,6 +9,14 @@ use proxy::{Process, ProcessBuilder}; const MAX_SIZE: usize = 16 * 1024 * 1024; // 16MB +type BackendParser = MessageParser; +type BackendRequest = Message; +type BackendResponse = Message; + +type FrontendParser = MessageParser; +type FrontendRequest = Message; +type FrontendResponse = Message; + #[allow(dead_code)] pub struct Thriftproxy { process: Process, @@ -34,9 +42,15 @@ impl Thriftproxy { let response_parser = MessageParser::new(MAX_SIZE); // initialize process - let process_builder = - ProcessBuilder::new(config, request_parser, response_parser, log_drain) - .expect("failed to launch"); + let process_builder = ProcessBuilder::< + BackendParser, + BackendRequest, + BackendResponse, + FrontendParser, + FrontendRequest, + FrontendResponse, + >::new(&config, log_drain, response_parser, request_parser) + .expect("failed to launch"); let process = process_builder.spawn(); Self { process } diff --git a/src/queues/Cargo.toml b/src/queues/Cargo.toml index 7252527bd..c6024e577 100644 --- a/src/queues/Cargo.toml +++ b/src/queues/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "queues" -version = "0.2.0" +version = "0.3.0" authors = ["Brian Martin "] edition = "2018" description = "queue types for inter-process communication" @@ -12,6 +12,9 @@ license = "Apache-2.0" [dependencies] crossbeam-queue = "0.3.5" -mio = { version = "0.8.0", features = ["os-poll", "net"] } rand = "0.8.5" rand_chacha = "0.3.1" +waker = { path = "../core/waker" } + +[dev-dependencies] +net = { path = "../net" } diff --git 
a/src/queues/src/lib.rs b/src/queues/src/lib.rs index a24032214..1aa0cf5c3 100644 --- a/src/queues/src/lib.rs +++ b/src/queues/src/lib.rs @@ -4,7 +4,7 @@ //! Queue type for inter-process communication (IPC). -pub use mio::Waker; +pub use waker::Waker; use crossbeam_queue::*; use rand::distributions::Uniform; @@ -270,18 +270,24 @@ impl TrackedItem { #[cfg(test)] mod tests { use crate::Queues; - use mio::*; + use ::net::Waker as MioWaker; + use ::net::{Poll, Token}; use std::sync::Arc; + use waker::Waker; const WAKER_TOKEN: Token = Token(usize::MAX); #[test] fn basic() { let poll = Poll::new().expect("failed to create event loop"); - let waker = - Arc::new(Waker::new(poll.registry(), WAKER_TOKEN).expect("failed to create waker")); + let waker = Arc::new(Waker::from( + MioWaker::new(poll.registry(), WAKER_TOKEN).expect("failed to create waker"), + )); - let (mut a, mut b) = Queues::::new(vec![waker.clone()], vec![waker], 1024); + let a_wakers = vec![waker.clone()]; + let b_wakers = vec![waker]; + + let (mut a, mut b) = Queues::::new(&a_wakers, &b_wakers, 1024); let mut a = a.remove(0); let mut b = b.remove(0); diff --git a/src/server/pingserver/Cargo.toml b/src/server/pingserver/Cargo.toml index b239a736b..b453516e2 100644 --- a/src/server/pingserver/Cargo.toml +++ b/src/server/pingserver/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pingserver" -version = "0.1.0" +version = "0.2.0" authors = ["Brian Martin "] edition = "2018" description = "a simple ascii ping/pong server" @@ -28,8 +28,6 @@ name = "benchmark" path = "benches/benchmark.rs" harness = false -[features] - [dependencies] backtrace = "0.3.56" clap = "2.33.3" @@ -41,6 +39,5 @@ protocol-ping = { path = "../../protocol/ping", features = ["server"] } rustcommon-metrics = { git = "https://github.com/twitter/rustcommon" } server = { path = "../../core/server" } - [dev-dependencies] criterion = "0.3" diff --git a/src/server/pingserver/benches/benchmark.rs b/src/server/pingserver/benches/benchmark.rs index 
a4f3600dd..eaa8573a1 100644 --- a/src/server/pingserver/benches/benchmark.rs +++ b/src/server/pingserver/benches/benchmark.rs @@ -23,7 +23,7 @@ fn ping_benchmark(c: &mut Criterion) { let config = PingserverConfig::default(); // launch the server - let server = Pingserver::new(config); + let server = Pingserver::new(config).expect("failed to launch pingserver"); // wait for server to startup. duration is chosen to be longer than we'd // expect startup to take in a slow ci environment. diff --git a/src/server/pingserver/src/lib.rs b/src/server/pingserver/src/lib.rs index 6af94cf8a..ac9c51bce 100644 --- a/src/server/pingserver/src/lib.rs +++ b/src/server/pingserver/src/lib.rs @@ -25,7 +25,7 @@ pub struct Pingserver { impl Pingserver { /// Creates a new `Pingserver` process from the given `PingserverConfig`. - pub fn new(config: PingserverConfig) -> Self { + pub fn new(config: PingserverConfig) -> Result { // initialize logging let log_drain = configure_logging(&config); @@ -35,26 +35,19 @@ impl Pingserver { // initialize storage let storage = Storage::new(); - // use a fixed buffer size for the pingserver - let max_buffer_size = server::DEFAULT_BUFFER_SIZE; - // initialize parser let parser = Parser::new(); // initialize process - let process_builder = ProcessBuilder::::new( - config, - storage, - max_buffer_size, - parser, - log_drain, - ) + let process_builder = ProcessBuilder::::new( + &config, log_drain, parser, storage, + )? .version(env!("CARGO_PKG_VERSION")); // spawn threads let process = process_builder.spawn(); - Self { process } + Ok(Self { process }) } /// Wait for all threads to complete. 
Blocks until the process has fully diff --git a/src/server/pingserver/src/main.rs b/src/server/pingserver/src/main.rs index 3e51c5477..285e07129 100644 --- a/src/server/pingserver/src/main.rs +++ b/src/server/pingserver/src/main.rs @@ -96,7 +96,7 @@ fn main() { match PingserverConfig::load(file) { Ok(c) => c, Err(e) => { - error!("{}", e); + println!("error launching pingserver: {}", e); std::process::exit(1); } } @@ -105,5 +105,11 @@ fn main() { }; // launch - Pingserver::new(config).wait() + match Pingserver::new(config) { + Ok(s) => s.wait(), + Err(e) => { + println!("error launching pingserver: {}", e); + std::process::exit(1); + } + } } diff --git a/src/server/pingserver/tests/integration.rs b/src/server/pingserver/tests/integration.rs index 626a74185..da3e5a77e 100644 --- a/src/server/pingserver/tests/integration.rs +++ b/src/server/pingserver/tests/integration.rs @@ -16,7 +16,7 @@ use std::time::Duration; fn main() { debug!("launching server"); - let server = Pingserver::new(PingserverConfig::default()); + let server = Pingserver::new(PingserverConfig::default()).expect("failed to launch"); // wait for server to startup. duration is chosen to be longer than we'd // expect startup to take in a slow ci environment. 
@@ -75,14 +75,19 @@ fn test(name: &str, data: &[(&str, Option<&str>)]) { let mut buf = vec![0; 4096]; if let Some(response) = response { - if stream.read(&mut buf).is_err() { - panic!("error reading response"); - } else if response.as_bytes() != &buf[0..response.len()] { - error!("expected: {:?}", response.as_bytes()); - error!("received: {:?}", &buf[0..response.len()]); - panic!("status: failed\n"); - } else { - debug!("correct response"); + match stream.read(&mut buf) { + Err(e) => { + panic!("error reading response: {}", e); + } + Ok(_) => { + if response.as_bytes() != &buf[0..response.len()] { + error!("expected: {:?}", response.as_bytes()); + error!("received: {:?}", &buf[0..response.len()]); + panic!("status: failed\n"); + } else { + debug!("correct response"); + } + } } assert_eq!(response.as_bytes(), &buf[0..response.len()]); } else if let Err(e) = stream.read(&mut buf) { diff --git a/src/server/segcache/Cargo.toml b/src/server/segcache/Cargo.toml index 0ee298ae7..ae1ec4444 100644 --- a/src/server/segcache/Cargo.toml +++ b/src/server/segcache/Cargo.toml @@ -1,8 +1,8 @@ [package] name = "segcache" -version = "0.1.0" +version = "0.2.0" +edition = "2021" authors = ["Brian Martin "] -edition = "2018" description = "a Memcache protocol server with segment-structured storage" homepage = "https://pelikan.io" repository = "https://github.com/twitter/pelikan" diff --git a/src/server/segcache/src/lib.rs b/src/server/segcache/src/lib.rs index c555eaa93..04b7fb03b 100644 --- a/src/server/segcache/src/lib.rs +++ b/src/server/segcache/src/lib.rs @@ -33,24 +33,15 @@ impl Segcache { // initialize storage let storage = Storage::new(&config)?; - let max_buffer_size = std::cmp::max( - server::DEFAULT_BUFFER_SIZE, - config.seg().segment_size() as usize * 2, - ); - // initialize parser let parser = Parser::new() .max_value_size(config.seg().segment_size() as usize) .time_type(config.time().time_type()); // initialize process - let process_builder = ProcessBuilder::::new( - 
config, - storage, - max_buffer_size, - parser, - log_drain, - ) + let process_builder = ProcessBuilder::::new( + &config, log_drain, parser, storage, + )? .version(env!("CARGO_PKG_VERSION")); // spawn threads diff --git a/src/server/segcache/src/main.rs b/src/server/segcache/src/main.rs index bf4322ebe..1d37dd846 100644 --- a/src/server/segcache/src/main.rs +++ b/src/server/segcache/src/main.rs @@ -120,7 +120,8 @@ fn main() { match Segcache::new(config) { Ok(segcache) => segcache.wait(), Err(e) => { - error!("error launching segcache: {}", e); + println!("error launching segcache: {}", e); + std::process::exit(1); } } } diff --git a/src/session/Cargo.toml b/src/session/Cargo.toml index 0d96fbab7..8fc6ad74b 100644 --- a/src/session/Cargo.toml +++ b/src/session/Cargo.toml @@ -1,9 +1,8 @@ [package] name = "session" -version = "0.0.2" +version = "0.1.0" +edition = "2021" authors = ["Brian Martin "] -edition = "2018" -description = "TCP Sessions with or without TLS for use with mio event loops" homepage = "https://pelikan.io" repository = "https://github.com/twitter/pelikan" license = "Apache-2.0" @@ -11,20 +10,10 @@ license = "Apache-2.0" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -common = { path = "../common" } -config = { path = "../config" } -logger = { path = "../logger" } -mio = { version = "0.8.0", features = ["os-poll", "net"] } -rand = "0.8.0" -rtrb = "0.1.3" -serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0.64" -slab = "0.4.2" -strum = "0.20.0" -strum_macros = "0.20.1" -sysconf = "0.3.4" -thiserror = "1.0.23" - -[dependencies.rustcommon-metrics] -git = "https://github.com/twitter/rustcommon" -features = ["heatmap"] +# buffer = { path = "../buffer" } +bytes = "1.1.0" +log = "0.4.17" +net = { path = "../net" } +protocol-common = { path = "../protocol/common" } +rustcommon-metrics = { git = "https://github.com/twitter/rustcommon", features = ["heatmap"] } +rustcommon-time 
= { git = "https://github.com/twitter/rustcommon" } diff --git a/src/session/src/buffer.rs b/src/session/src/buffer.rs index fb4c03757..e66d535a8 100644 --- a/src/session/src/buffer.rs +++ b/src/session/src/buffer.rs @@ -1,194 +1,294 @@ -// Copyright 2021 Twitter, Inc. +// Copyright 2022 Twitter, Inc. // Licensed under the Apache License, Version 2.0 // http://www.apache.org/licenses/LICENSE-2.0 -//! A very simple buffer type that can be replaced in the future. - -use crate::SESSION_BUFFER_BYTE; +pub use bytes::buf::UninitSlice; +pub use bytes::{Buf, BufMut}; +use crate::*; use core::borrow::{Borrow, BorrowMut}; +use std::alloc::*; + +const KB: usize = 1024; +const MB: usize = 1024 * KB; -/// A growable byte buffer +/// A simple growable byte buffer, represented as a contiguous range of bytes pub struct Buffer { - buffer: Vec, + ptr: *mut u8, + cap: usize, read_offset: usize, write_offset: usize, - target_capacity: usize, + target_size: usize, } -impl Buffer { - /// Create a new `Buffer` that can hold up to `capacity` bytes without - /// re-allocating. - #[allow(clippy::slow_vector_initialization)] - pub fn with_capacity(capacity: usize) -> Self { - let mut buffer = Vec::with_capacity(capacity); - buffer.resize(capacity, 0); +unsafe impl Send for Buffer {} +unsafe impl Sync for Buffer {} - SESSION_BUFFER_BYTE.add(buffer.capacity() as _); +impl Buffer { + /// Create a new buffer that can hold up to `target_size` bytes without + /// resizing. The buffer may grow beyond the `target_size`, but will shrink + /// back down to the `target_size` when possible. 
+ pub fn new(target_size: usize) -> Self { + let target_size = target_size.next_power_of_two(); + let layout = Layout::array::(target_size).unwrap(); + let ptr = unsafe { alloc(layout) }; + let cap = target_size; + let read_offset = 0; + let write_offset = 0; + + SESSION_BUFFER_BYTE.add(cap as _); Self { - buffer, - read_offset: 0, - write_offset: 0, - target_capacity: capacity, + ptr, + cap, + read_offset, + write_offset, + target_size, } } - /// Returns the amount of space available to write into the buffer without - /// reallocating. - pub fn available_capacity(&self) -> usize { - self.buffer.len() - self.write_offset + /// Returns the current capacity of the buffer. + pub fn capacity(&self) -> usize { + self.cap } - /// Return the number of bytes currently in the buffer. - pub fn len(&self) -> usize { - self.write_offset - self.read_offset - } - - /// Check if the buffer is empty. - pub fn is_empty(&self) -> bool { - self.len() == 0 - } + /// Reserve space for `amt` additional bytes. + pub fn reserve(&mut self, amt: usize) { + // if the buffer is empty, reset the offsets + if self.remaining() == 0 { + self.read_offset = 0; + self.write_offset = 0; + } - // TODO(bmartin): we're currently relying on the resize behaviors of the - // underlying `Vec` storage. This currently results in growth to the next - // nearest power of two. Effectively resulting in buffer doubling when a - // resize is required. - /// Reserve room for `additional` bytes in the buffer. This may reserve more - /// space than requested to avoid frequent allocations. If the buffer - /// already has sufficient available capacity, this is a no-op. 
- pub fn reserve(&mut self, additional: usize) { - let old_cap = self.buffer.capacity(); - let needed = additional.saturating_sub(self.available_capacity()); - if needed > 0 { - let current = self.buffer.len(); - let target = (current + needed).next_power_of_two(); - self.buffer.resize(target, 0); - SESSION_BUFFER_BYTE.add((self.buffer.capacity() - old_cap) as _); + // grow the buffer if needed, uses a multiple of the target size + if amt > self.remaining_mut() { + // calculate the required buffer size + let size = self.write_offset + amt; + + // determine what power of the target size would be required to + // hold the new size + let pow = (size).next_power_of_two(); + + // determine how much to grow the buffer by + let amt = if size > MB || pow > MB { + // if it would be above a MB, determine the next whole MB and + // subtract the current capacity to determine the amount to grow + // by + (size / MB + 1) * MB - self.cap + } else { + // if it would be 1 MB or less, set it to the next power of two + // multiple of the target size, minus the current capacity + pow - self.cap + }; + + SESSION_BUFFER_BYTE.add(amt as _); + + // new size will be the current capacity plus the amount needed + let size = self.cap + amt; + let layout = Layout::array::(self.cap).unwrap(); + self.ptr = unsafe { realloc(self.ptr, layout, size) }; + self.cap = size; } } - /// Append the bytes from `other` onto `self`. - pub fn extend_from_slice(&mut self, other: &[u8]) { - self.reserve(other.len()); - self.buffer[self.write_offset..(self.write_offset + other.len())].copy_from_slice(other); - self.increase_len(other.len()); - } + /// Clear the buffer. + pub fn clear(&mut self) { + self.read_offset = 0; + self.write_offset = 0; + + // if the buffer is oversized, shrink to the target size + if self.cap > self.target_size { + trace!("shrinking buffer"); - /// Mark that `amt` bytes have been consumed and should not be returned in - /// future reads from the buffer. 
- pub fn consume(&mut self, bytes: usize) { - let old_capacity = self.buffer.capacity(); - self.read_offset = std::cmp::min(self.read_offset + bytes, self.write_offset); + SESSION_BUFFER_BYTE.sub((self.cap - self.target_size) as _); - // if we have content, before shrinking we must shift content left - if !self.is_empty() { - self.buffer - .copy_within(self.read_offset..self.write_offset, 0); + let layout = Layout::array::(self.cap).unwrap(); + self.ptr = unsafe { realloc(self.ptr, layout, self.target_size) }; + self.cap = self.target_size; } + } - self.write_offset -= self.read_offset; - self.read_offset = 0; + /// Compact the buffer by moving contents to the beginning and freeing any + /// excess space. As an optimization, this will not always compact the + /// buffer to its `target_size`. + pub fn compact(&mut self) { + // if the buffer is empty, we clear the buffer and return + if self.remaining() == 0 { + self.clear(); + return; + } - // determine the target size of the buffer - let target_size = if self.len() * 2 > self.buffer.len() { - // buffer too full to shrink, early return + // if its not too large, we don't compact + if self.cap == self.target_size { return; - } else if self.len() > self.target_capacity { - // should shrink, but not to target capacity - self.buffer.len() / 2 + } + + // if the buffer data is deep into the buffer, we can copy the data to + // the start of the buffer to make additional space available for writes + if self.read_offset > self.target_size { + if self.remaining() < self.read_offset { + unsafe { + std::ptr::copy_nonoverlapping( + self.ptr.add(self.read_offset), + self.ptr, + self.remaining(), + ); + } + } else { + unsafe { + std::ptr::copy(self.ptr.add(self.read_offset), self.ptr, self.remaining()); + } + } + self.write_offset = self.remaining(); + self.read_offset = 0; + } + + let target = if self.write_offset > MB { + (1 + (self.write_offset / MB)) * MB } else { - // shrink down to target capacity - self.target_capacity + 
self.write_offset.next_power_of_two() }; - // buffer can be reduced to the target_size determined above - self.buffer.truncate(target_size); - self.buffer.shrink_to_fit(); + SESSION_BUFFER_BYTE.sub((self.cap - target) as _); + let layout = Layout::array::(self.cap).unwrap(); + self.ptr = unsafe { realloc(self.ptr, layout, target) }; + self.cap = target; + } - // update stats if the buffer has resized - SESSION_BUFFER_BYTE.sub(old_capacity as i64 - self.buffer.capacity() as i64); + /// Get the current write position as a pointer. `remaining_mut` should be + /// used as the length. + pub fn write_ptr(&mut self) -> *mut u8 { + unsafe { self.ptr.add(self.write_offset) } } - /// Marks the buffer as now containing `amt` additional bytes. This function - /// prevents advancing the write offset beyond the initialized area of the - /// underlying storage. - pub fn increase_len(&mut self, amt: usize) { - self.write_offset = std::cmp::min(self.write_offset + amt, self.buffer.len()); + /// Get the current read position as a pointer. `remaining` should be used + /// as the length. 
+ pub fn read_ptr(&mut self) -> *mut u8 { + unsafe { self.ptr.add(self.read_offset) } + } +} + +impl Drop for Buffer { + fn drop(&mut self) { + SESSION_BUFFER_BYTE.sub(self.cap as _); } } impl Borrow<[u8]> for Buffer { fn borrow(&self) -> &[u8] { - &self.buffer[self.read_offset..self.write_offset] + unsafe { std::slice::from_raw_parts(self.ptr.add(self.read_offset), self.remaining()) } } } impl BorrowMut<[u8]> for Buffer { - fn borrow_mut(&mut self) -> &mut [u8] { - let available = self.buffer.len(); - &mut self.buffer[self.write_offset..available] + fn borrow_mut(self: &mut Buffer) -> &mut [u8] { + unsafe { + std::slice::from_raw_parts_mut(self.ptr.add(self.write_offset), self.remaining_mut()) + } } } -impl Drop for Buffer { - fn drop(&mut self) { - SESSION_BUFFER_BYTE.sub(self.buffer.capacity() as _); +impl Buf for Buffer { + fn remaining(&self) -> usize { + self.write_offset - self.read_offset + } + + fn chunk(&self) -> &[u8] { + self.borrow() + } + + fn advance(&mut self, amt: usize) { + self.read_offset = std::cmp::min(self.read_offset + amt, self.write_offset); + self.compact(); + } +} + +unsafe impl BufMut for Buffer { + fn remaining_mut(&self) -> usize { + self.cap - self.write_offset + } + + unsafe fn advance_mut(&mut self, amt: usize) { + self.write_offset = std::cmp::min(self.write_offset + amt, self.cap); + } + + fn chunk_mut(&mut self) -> &mut bytes::buf::UninitSlice { + unsafe { + UninitSlice::from_raw_parts_mut(self.ptr.add(self.write_offset), self.remaining_mut()) + } + } + + fn put(&mut self, mut src: T) + where + Self: Sized, + { + while src.has_remaining() { + let chunk = src.chunk(); + let len = chunk.len(); + self.put_slice(chunk); + src.advance(len); + } + } + + fn put_slice(&mut self, src: &[u8]) { + self.reserve(src.len()); + assert!(self.remaining_mut() >= src.len()); + unsafe { + std::ptr::copy_nonoverlapping(src.as_ptr(), self.ptr.add(self.write_offset), src.len()); + } + unsafe { + self.advance_mut(src.len()); + } } } #[cfg(test)] mod 
tests { - use crate::Buffer; + use crate::*; use std::borrow::Borrow; #[test] // test buffer initialization with various capacities fn new() { - let buffer = Buffer::with_capacity(1024); - assert_eq!(buffer.len(), 0); - assert_eq!(buffer.available_capacity(), 1024); - assert!(buffer.is_empty()); - - let buffer = Buffer::with_capacity(2048); - assert_eq!(buffer.len(), 0); - assert_eq!(buffer.available_capacity(), 2048); - assert!(buffer.is_empty()); - - // test zero capacity buffer - let buffer = Buffer::with_capacity(0); - assert_eq!(buffer.len(), 0); - assert_eq!(buffer.available_capacity(), 0); - assert!(buffer.is_empty()); - - // test with non power of 2 - let buffer = Buffer::with_capacity(100); - assert_eq!(buffer.len(), 0); - assert_eq!(buffer.available_capacity(), 100); - assert!(buffer.is_empty()); + let buffer = Buffer::new(1024); + assert_eq!(buffer.remaining(), 0); + assert_eq!(buffer.remaining_mut(), 1024); + + let buffer = Buffer::new(2048); + assert_eq!(buffer.remaining(), 0); + assert_eq!(buffer.remaining_mut(), 2048); + + // test zero capacity buffer, rounds to 1 byte buffer + let buffer = Buffer::new(0); + assert_eq!(buffer.remaining(), 0); + assert_eq!(buffer.remaining_mut(), 1); + + // test with non power of 2, rounds to next power of two + let buffer = Buffer::new(100); + assert_eq!(buffer.remaining(), 0); + assert_eq!(buffer.remaining_mut(), 128); } #[test] // tests a small buffer growing only on second write fn write_1() { - let mut buffer = Buffer::with_capacity(8); - assert_eq!(buffer.len(), 0); - assert_eq!(buffer.available_capacity(), 8); - assert!(buffer.is_empty()); + let mut buffer = Buffer::new(8); + assert_eq!(buffer.remaining(), 0); + assert_eq!(buffer.remaining_mut(), 8); // first write fits in buffer - buffer.extend_from_slice(b"GET "); - assert_eq!(buffer.len(), 4); - assert_eq!(buffer.available_capacity(), 4); - assert!(!buffer.is_empty()); + buffer.put_slice(b"GET "); + assert_eq!(buffer.remaining(), 4); + 
assert_eq!(buffer.remaining_mut(), 4); + let content: &[u8] = buffer.borrow(); assert_eq!(content, b"GET "); // second write causes buffer to grow - buffer.extend_from_slice(b"SOME_KEY\r\n"); - assert_eq!(buffer.len(), 14); - assert_eq!(buffer.available_capacity(), 2); - assert!(!buffer.is_empty()); + buffer.put_slice(b"SOME_KEY\r\n"); + assert_eq!(buffer.remaining(), 14); + assert_eq!(buffer.remaining_mut(), 2); + let content: &[u8] = buffer.borrow(); assert_eq!(content, b"GET SOME_KEY\r\n"); } @@ -196,24 +296,23 @@ mod tests { #[test] // test a zero capacity buffer growing on two consecutive writes fn write_2() { - let mut buffer = Buffer::with_capacity(0); - assert_eq!(buffer.len(), 0); - assert_eq!(buffer.available_capacity(), 0); - assert!(buffer.is_empty()); + let mut buffer = Buffer::new(0); + assert_eq!(buffer.remaining(), 0); + assert_eq!(buffer.remaining_mut(), 1); // zero capacity buffer grows on first write - buffer.extend_from_slice(b"GET KEY\r\n"); - assert_eq!(buffer.len(), 9); - assert_eq!(buffer.available_capacity(), 7); - assert!(!buffer.is_empty()); + buffer.put_slice(b"GET KEY\r\n"); + assert_eq!(buffer.remaining(), 9); + assert_eq!(buffer.remaining_mut(), 7); + let content: &[u8] = buffer.borrow(); assert_eq!(content, b"GET KEY\r\n"); // and again on second write - buffer.extend_from_slice(b"SET OTHER_KEY 0 0 1\r\nA\r\n"); - assert_eq!(buffer.len(), 33); - assert_eq!(buffer.available_capacity(), 31); - assert!(!buffer.is_empty()); + buffer.put_slice(b"SET OTHER_KEY 0 0 1\r\nA\r\n"); + assert_eq!(buffer.remaining(), 33); + assert_eq!(buffer.remaining_mut(), 31); + let content: &[u8] = buffer.borrow(); assert_eq!(content, b"GET KEY\r\nSET OTHER_KEY 0 0 1\r\nA\r\n"); } @@ -221,164 +320,136 @@ mod tests { #[test] // tests a large buffer that grows on first write fn write_3() { - let mut buffer = Buffer::with_capacity(16); - assert_eq!(buffer.len(), 0); - assert_eq!(buffer.available_capacity(), 16); - assert!(buffer.is_empty()); - - 
buffer.extend_from_slice(b"SET SOME_REALLY_LONG_KEY 0 0 1\r\nA\r\n"); - assert_eq!(buffer.len(), 35); - assert_eq!(buffer.available_capacity(), 29); + let mut buffer = Buffer::new(16); + assert_eq!(buffer.remaining(), 0); + assert_eq!(buffer.remaining_mut(), 16); + + buffer.put_slice(b"SET SOME_REALLY_LONG_KEY 0 0 1\r\nA\r\n"); + assert_eq!(buffer.remaining(), 35); + assert_eq!(buffer.remaining_mut(), 29); } #[test] // tests a consume operation where all bytes are consumed and the buffer // remains its original size fn consume_1() { - let mut buffer = Buffer::with_capacity(16); - assert_eq!(buffer.len(), 0); - assert_eq!(buffer.available_capacity(), 16); - assert!(buffer.is_empty()); - - buffer.extend_from_slice(b"END\r\n"); - assert_eq!(buffer.len(), 5); - assert_eq!(buffer.available_capacity(), 11); - assert!(!buffer.is_empty()); - - buffer.consume(5); - assert_eq!(buffer.len(), 0); - assert_eq!(buffer.available_capacity(), 16); - assert!(buffer.is_empty()); + let mut buffer = Buffer::new(16); + assert_eq!(buffer.remaining(), 0); + assert_eq!(buffer.remaining_mut(), 16); + + buffer.put_slice(b"END\r\n"); + assert_eq!(buffer.remaining(), 5); + assert_eq!(buffer.remaining_mut(), 11); + + buffer.advance(5); + assert_eq!(buffer.remaining(), 0); + assert_eq!(buffer.remaining_mut(), 16); } #[test] // tests a consume operation where all bytes are consumed and the buffer // shrinks to its original size fn consume_2() { - let mut buffer = Buffer::with_capacity(2); - assert_eq!(buffer.len(), 0); - assert_eq!(buffer.available_capacity(), 2); - assert!(buffer.is_empty()); + let mut buffer = Buffer::new(2); + assert_eq!(buffer.remaining(), 0); + assert_eq!(buffer.remaining_mut(), 2); // buffer extends to the next power of two // with 5 byte message we need 8 bytes for the buffer - buffer.extend_from_slice(b"END\r\n"); - assert_eq!(buffer.len(), 5); - assert_eq!(buffer.available_capacity(), 3); - assert!(!buffer.is_empty()); - - buffer.consume(5); - assert_eq!(buffer.len(), 
0); - assert_eq!(buffer.available_capacity(), 2); - assert!(buffer.is_empty()); + buffer.put_slice(b"END\r\n"); + assert_eq!(buffer.remaining(), 5); + assert_eq!(buffer.remaining_mut(), 3); + + buffer.advance(5); + assert_eq!(buffer.remaining(), 0); + assert_eq!(buffer.remaining_mut(), 2); } #[test] // tests a consume operation where not all bytes are consumed and buffer // remains its original size fn consume_3() { - let mut buffer = Buffer::with_capacity(8); - assert_eq!(buffer.len(), 0); - assert_eq!(buffer.available_capacity(), 8); - assert!(buffer.is_empty()); + let mut buffer = Buffer::new(8); + assert_eq!(buffer.remaining(), 0); + assert_eq!(buffer.remaining_mut(), 8); let content = b"END\r\n"; - let len = content.len(); - buffer.extend_from_slice(content); - assert_eq!(buffer.len(), len); - assert_eq!(buffer.available_capacity(), 3); - assert!(!buffer.is_empty()); + buffer.put_slice(content); + assert_eq!(buffer.remaining(), 5); + assert_eq!(buffer.remaining_mut(), 3); // consume all but the last byte of content in the buffer, one byte at // a time // - buffer len decreases with each call to consume() - // - buffer available capacity increases with each call to consume() - for i in 1..len { - buffer.consume(1); - assert_eq!(buffer.len(), len - i); - assert_eq!(buffer.available_capacity(), 3 + i); - assert!(!buffer.is_empty()); + // - buffer available capacity stays the same + for i in 1..5 { + buffer.advance(1); + assert_eq!(buffer.remaining(), 5 - i); + assert_eq!(buffer.remaining_mut(), 3); } // when consuming the final byte, the read/write offsets move to the // start of the buffer, and available capacity should be the original // buffer size - buffer.consume(1); - assert_eq!(buffer.len(), 0); - assert_eq!(buffer.available_capacity(), 8); - assert!(buffer.is_empty()); + buffer.advance(1); + assert_eq!(buffer.remaining(), 0); + assert_eq!(buffer.remaining_mut(), 8); } #[test] // tests a consume operation where not all bytes are consumed and buffer // 
shrinks as bytes are consumed fn consume_4() { - let mut buffer = Buffer::with_capacity(16); - assert_eq!(buffer.len(), 0); - assert_eq!(buffer.available_capacity(), 16); - assert!(buffer.is_empty()); + let mut buffer = Buffer::new(16); + assert_eq!(buffer.remaining(), 0); + assert_eq!(buffer.remaining_mut(), 16); let content = b"VALUE SOME_REALLY_LONG_KEY 0 1\r\n1\r\nEND\r\n"; // buffer resizes up to 64 bytes to hold 40 bytes - // length = 40, size = 64, capacity = 24 - buffer.extend_from_slice(content); - assert_eq!(buffer.len(), 40); - assert_eq!(buffer.available_capacity(), 24); - assert!(!buffer.is_empty()); - - // partial consume, len decrease, buffer shrinks by half - // length = 32, size = 32, capacity = 0 - buffer.consume(8); - assert_eq!(buffer.len(), 32); - assert_eq!(buffer.available_capacity(), 0); - assert!(!buffer.is_empty()); - - // consume one more byte and we should get available capacity - // length = 31, size = 32, capacity = 1 - buffer.consume(1); - assert_eq!(buffer.len(), 31); - assert_eq!(buffer.available_capacity(), 1); - assert!(!buffer.is_empty()); - - // partial consume, len decrease, buffer shrinks down to target capacity - // length = 16, size = 16, capacity = 0 - buffer.consume(15); - assert_eq!(buffer.len(), 16); - assert_eq!(buffer.available_capacity(), 0); - - // from here on, buffer will not shrink below target capacity - - // consume one more byte - // length = 15, size = 16, capacity = 0 - buffer.consume(1); - assert_eq!(buffer.len(), 15); - assert_eq!(buffer.available_capacity(), 1); - - // partial consume, len decrease - // length = 8, size = 16, capacity = 8 - buffer.consume(7); - assert_eq!(buffer.len(), 8); - assert_eq!(buffer.available_capacity(), 8); - - // partial consume, len decrease - // length = 7, size = 16, capacity = 9 - buffer.consume(1); - assert_eq!(buffer.len(), 7); - assert_eq!(buffer.available_capacity(), 9); + buffer.put_slice(content); + assert_eq!(buffer.remaining(), 40); + 
assert_eq!(buffer.remaining_mut(), 24); + + // partial consume, len decrease, no compact + buffer.advance(8); + assert_eq!(buffer.remaining(), 32); + assert_eq!(buffer.remaining_mut(), 24); + + // consume one more byte, still no compact + buffer.advance(1); + assert_eq!(buffer.remaining(), 31); + assert_eq!(buffer.remaining_mut(), 24); + + // partial consume, remaining drops to best fitting power of two + buffer.advance(15); + assert_eq!(buffer.remaining(), 16); + assert_eq!(buffer.remaining_mut(), 0); + + // from here on, buffer will not shrink below target capacity and will + // not compact + + // partial consume, since the buffer is the target size already, there + // will be no compaction + buffer.advance(1); + assert_eq!(buffer.remaining(), 15); + assert_eq!(buffer.remaining_mut(), 0); // consume all but the final byte // partial consume, len decrease // length = 1, size = 16, capacity = 15 - buffer.consume(6); - assert_eq!(buffer.len(), 1); - assert_eq!(buffer.available_capacity(), 15); + buffer.advance(14); + assert_eq!(buffer.remaining(), 1); + assert_eq!(buffer.remaining_mut(), 0); + + // on the final advance, all bytes are consumed and the entire buffer + // is now clear // consume the final byte // length = 0, size = 16, capacity = 16 - buffer.consume(1); - assert_eq!(buffer.len(), 0); - assert_eq!(buffer.available_capacity(), 16); + buffer.advance(1); + assert_eq!(buffer.remaining(), 0); + assert_eq!(buffer.remaining_mut(), 16); } } diff --git a/src/session/src/client.rs b/src/session/src/client.rs new file mode 100644 index 000000000..0188c67d8 --- /dev/null +++ b/src/session/src/client.rs @@ -0,0 +1,194 @@ +// Copyright 2022 Twitter, Inc. +// Licensed under the Apache License, Version 2.0 +// http://www.apache.org/licenses/LICENSE-2.0 + +use super::*; + +/// A basic session to represent the client side of a framed session, meaning +/// that is is used by a client to talk to a server. 
+/// +/// `ClientSession` latency tracking counts all of the time from a message being +/// sent until a corresponding message is returned by a call to receive. This +/// means any delays within the client between actual bytes being received in +/// the session buffer until actually handling the returned message is counted +/// towards the latency. For example, if the client sleeps between filling the +/// session buffer and receiving a message from the session, that time is +/// counted towards the latency. +pub struct ClientSession { + // the actual session + session: Session, + // a parser which produces messages from the session buffer + parser: Parser, + // a queue of time and message pairs that are awaiting responses + pending: VecDeque<(Instant, Tx)>, + // a marker for the received message type + _rx: PhantomData, +} + +impl Debug for ClientSession { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> { + write!(f, "{:?}", self.session) + } +} + +impl AsRawFd for ClientSession { + fn as_raw_fd(&self) -> i32 { + self.session.as_raw_fd() + } +} + +impl ClientSession +where + Tx: Compose, + Parser: Parse, +{ + /// Create a new `ClientSession` from a `Session` and a `Parser`. + pub fn new(session: Session, parser: Parser) -> Self { + Self { + session, + parser, + pending: VecDeque::with_capacity(NUM_PENDING), + _rx: PhantomData, + } + } + + /// Sends the message to the underlying session but does *not* flush the + /// session buffer. This function also adds a timestamp to a queue so that + /// response latencies can be determined. The latency will include any time + /// that it takes to compose the message onto the session buffer, time to + /// flush the session buffer, and any additional calls to flush which may be + /// required. 
+ pub fn send(&mut self, tx: Tx) -> Result { + SESSION_SEND.increment(); + let now = Instant::now(); + let size = tx.compose(&mut self.session); + self.pending.push_back((now, tx)); + Ok(size) + } + + /// Attempts to return a pair of messages, the one sent to the server as + /// well as the one received from the server, from the underlying session + /// buffer. This operates only on buffered data and does not result in a + /// read() of the underlying session. + pub fn receive(&mut self) -> Result<(Tx, Rx)> { + let src: &[u8] = self.session.borrow(); + match self.parser.parse(src) { + Ok(res) => { + SESSION_RECV.increment(); + let now = Instant::now(); + let (timestamp, request) = self + .pending + .pop_front() + .ok_or_else(|| Error::from(ErrorKind::InvalidInput))?; + let latency = now - timestamp; + REQUEST_LATENCY.increment(now, latency.as_nanos(), 1); + let consumed = res.consumed(); + let msg = res.into_inner(); + self.session.consume(consumed); + Ok((request, msg)) + } + Err(e) => { + if e.kind() != ErrorKind::WouldBlock { + SESSION_RECV_EX.increment(); + } + Err(e) + } + } + } + + /// Attempts to flush the session write buffer. + pub fn flush(&mut self) -> Result<()> { + self.session.flush()?; + Ok(()) + } + + /// Returns the number of bytes currently in the write buffer. + pub fn write_pending(&self) -> usize { + self.session.write_pending() + } + + /// Performs a read of the underlying session to fill the read buffer. + pub fn fill(&mut self) -> Result { + self.session.fill() + } + + /// Returns the current event interest for this session. + pub fn interest(&mut self) -> Interest { + self.session.interest() + } + + /// Attempt to handshake the underlying session. + pub fn do_handshake(&mut self) -> Result<()> { + self.session.do_handshake() + } + + /// Get direct access to the read buffer. + pub fn read_buffer_mut(&mut self) -> &mut Buffer { + self.session.read_buffer_mut() + } + + /// Get direct access to the write buffer. 
+ pub fn write_buffer_mut(&mut self) -> &mut Buffer { + self.session.write_buffer_mut() + } +} + +impl Borrow<[u8]> for ClientSession { + fn borrow(&self) -> &[u8] { + self.session.borrow() + } +} + +impl Buf for ClientSession { + fn remaining(&self) -> usize { + self.session.remaining() + } + + fn chunk(&self) -> &[u8] { + self.session.chunk() + } + + fn advance(&mut self, amt: usize) { + self.session.advance(amt) + } +} + +unsafe impl BufMut for ClientSession { + fn remaining_mut(&self) -> usize { + self.session.remaining_mut() + } + + unsafe fn advance_mut(&mut self, amt: usize) { + self.session.advance_mut(amt) + } + + fn chunk_mut(&mut self) -> &mut UninitSlice { + self.session.chunk_mut() + } + + #[allow(unused_mut)] + fn put(&mut self, mut src: T) + where + Self: Sized, + { + self.session.put(src) + } + + fn put_slice(&mut self, src: &[u8]) { + self.session.put_slice(src) + } +} + +impl event::Source for ClientSession { + fn register(&mut self, registry: &Registry, token: Token, interest: Interest) -> Result<()> { + self.session.register(registry, token, interest) + } + + fn reregister(&mut self, registry: &Registry, token: Token, interest: Interest) -> Result<()> { + self.session.reregister(registry, token, interest) + } + + fn deregister(&mut self, registry: &Registry) -> Result<()> { + self.session.deregister(registry) + } +} diff --git a/src/session/src/lib.rs b/src/session/src/lib.rs index 389bcad72..13a124128 100644 --- a/src/session/src/lib.rs +++ b/src/session/src/lib.rs @@ -1,51 +1,49 @@ -// Copyright 2021 Twitter, Inc. +// Copyright 2022 Twitter, Inc. // Licensed under the Apache License, Version 2.0 // http://www.apache.org/licenses/LICENSE-2.0 -//! This crate provides buffered TCP sessions with or without TLS which can be -//! used with [`::mio`]. TLS/SSL is provided by BoringSSL with the [`::boring`] -//! crate. +//! Abstractions for bi-directional buffered communications on top of streams. +//! 
This allows for efficient reading and writing for stream-oriented +//! communication and provides abstractions for request/response oriented +//! client/server communications. + +// pub use buffer::*; #[macro_use] -extern crate logger; +extern crate log; mod buffer; -mod stream; -mod tcp_stream; -use common::ssl::{MidHandshakeSslStream, SslStream}; -use mio::event::Source; -use mio::{Interest, Poll, Token}; -use rustcommon_metrics::{counter, gauge, heatmap, metric, Counter, Gauge, Heatmap, Relaxed}; -use std::borrow::{Borrow, BorrowMut}; -use std::cmp::Ordering; -use std::io::{BufRead, ErrorKind, Read, Write}; -use std::net::SocketAddr; - -pub use buffer::Buffer; -use stream::Stream; - -type Instant = common::time::Instant>; - -pub use tcp_stream::TcpStream; +mod client; +mod server; + +pub use buffer::*; +pub use client::ClientSession; +pub use server::ServerSession; + +use std::os::unix::prelude::AsRawFd; + +use ::net::*; +use core::borrow::{Borrow, BorrowMut}; +use core::fmt::Debug; +use core::marker::PhantomData; +use protocol_common::Compose; +use protocol_common::Parse; +use rustcommon_metrics::*; +use rustcommon_time::Nanoseconds; +use std::collections::VecDeque; +use std::io::Error; +use std::io::ErrorKind; +use std::io::Read; +use std::io::Result; +use std::io::Write; + +const ONE_SECOND: u64 = 1_000_000_000; // in nanoseconds gauge!( SESSION_BUFFER_BYTE, "current size of the session buffers in bytes" ); -counter!( - TCP_ACCEPT, - "number of times accept has been called on listening sockets" -); -counter!(TCP_CLOSE, "number of times TCP streams have been closed"); -gauge!(TCP_CONN_CURR, "current number of open TCP streams"); -counter!(TCP_RECV_BYTE, "number of bytes received on TCP streams"); -counter!(TCP_SEND_BYTE, "number of bytes sent on TCP streams"); -counter!( - TCP_SEND_PARTIAL, - "number of partial writes to the system socket buffer" -); - counter!(SESSION_RECV, "number of reads from sessions"); counter!( SESSION_RECV_EX, @@ -61,440 +59,269 @@ 
counter!(SESSION_SEND_BYTE, "number of bytes written to sessions"); heatmap!( REQUEST_LATENCY, - 1_000_000_000, + ONE_SECOND, "distribution of request latencies in nanoseconds" ); -heatmap!( - PIPELINE_DEPTH, - 100_000, - "distribution of request pipeline depth" -); -// TODO(bmartin): implement connect/reconnect so we can use this in clients too. -/// The core `Session` type which represents a TCP stream (with or without TLS), -/// the session buffer, the mio [`::mio::Token`], +type Instant = rustcommon_time::Instant>; + +// The size of one kilobyte, in bytes +const KB: usize = 1024; + +// If the read buffer has less than this amount available before a read, we will +// grow the read buffer. The selected value is set to the size of a single page. +const BUFFER_MIN_FREE: usize = 4 * KB; + +// The target size of the read operations, the selected value is the upper-bound +// on TLS fragment size as per RFC 5246: +// https://datatracker.ietf.org/doc/html/rfc5246#section-6.2.1 +const TARGET_READ_SIZE: usize = 16 * KB; + +// The initial size of any queues which track pending requests and responses. +// This is *not* a hard bound, but is used to size the initial allocations. +const NUM_PENDING: usize = 256; + +/// A `Session` is an underlying `Stream` with its read and write buffers. This +/// abstraction allows the caller to efficiently read from the underlying stream +/// by buffering the incoming bytes. It also allows for efficient writing by +/// first buffering writes to the underlying stream. pub struct Session { - token: Token, stream: Stream, read_buffer: Buffer, write_buffer: Buffer, - min_capacity: usize, - max_capacity: usize, - // hold current interest set - interest: Interest, - // TODO(bmartin): consider moving these fields and associated logic - // out into a response tracking struct. It would make the session - // type more applicable to clients if we move this out. 
- // - /// A timestamp which is used to calculate response latency - timestamp: Instant, - /// This is a queue of pending response sizes. When a response is finalized, - /// the bytes in that response are pushed onto the back of the queue. As the - /// session flushes out to the underlying socket, we can calculate when a - /// response is completely flushed to the underlying socket and record a - /// response latency. - pending_responses: [usize; 256], - /// This is the index of the first pending response. - pending_head: usize, - /// This is the count of pending responses. - pending_count: usize, - /// This holds the total number of bytes pending for finalized responses. By - /// tracking this, we can determine the size of a response even if it is - /// written into the session with multiple calls to write. It is essentially - /// a cached value of `write_buffer.pending_bytes()` that does not reflect - /// bytes from responses which are not yet finalized. - pending_bytes: usize, - /// This tracks the pipeline depth by tracking the number of responses - /// between resets of the session timestamp. - processed: usize, } -impl std::fmt::Debug for Session { +impl AsRawFd for Session { + fn as_raw_fd(&self) -> i32 { + self.stream.as_raw_fd() + } +} + +impl Debug for Session { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> { - if let Ok(peer_addr) = self.peer_addr() { - write!(f, "{}", peer_addr) - } else { - write!(f, "no peer address") - } + write!(f, "{:?}", self.stream) } } impl Session { - /// Create a new `Session` with representing a plain `TcpStream` with - /// internal buffers which can hold up to capacity bytes without - /// reallocating. - pub fn plain_with_capacity( - stream: TcpStream, - min_capacity: usize, - max_capacity: usize, - ) -> Self { - Self::new(Stream::plain(stream), min_capacity, max_capacity) + /// Construct a new `Session` from a `Stream` and read and write + /// `SessionBuffer`s. 
+ pub fn new(stream: Stream, read_buffer: Buffer, write_buffer: Buffer) -> Self { + Self { + stream, + read_buffer, + write_buffer, + } } - /// Create a new `Session` representing a negotiated `SslStream` - pub fn tls_with_capacity( - stream: SslStream, - min_capacity: usize, - max_capacity: usize, - ) -> Self { - Self::new(Stream::tls(stream), min_capacity, max_capacity) + /// Return the event `Interest`s for the `Session`. + pub fn interest(&mut self) -> Interest { + if self.write_buffer.has_remaining() { + self.stream.interest().add(Interest::WRITABLE) + } else { + self.stream.interest() + } } - /// Create a new `Session` representing a `MidHandshakeSslStream` - pub fn handshaking_with_capacity( - stream: MidHandshakeSslStream, - min_capacity: usize, - max_capacity: usize, - ) -> Self { - Self::new(Stream::handshaking(stream), min_capacity, max_capacity) + /// Indicates if the `Session` can be considered established, meaning that + /// any underlying stream negotation and handshaking is completed. + pub fn is_established(&mut self) -> bool { + self.stream.is_established() } - /// Create a new `Session` - fn new(stream: Stream, min_capacity: usize, max_capacity: usize) -> Self { - TCP_ACCEPT.increment(); - TCP_CONN_CURR.add(1); - Self { - token: Token(0), - stream, - read_buffer: Buffer::with_capacity(min_capacity), - write_buffer: Buffer::with_capacity(min_capacity), - min_capacity, - max_capacity, - interest: Interest::READABLE, - timestamp: Instant::now(), - pending_responses: [0; 256], - pending_head: 0, - pending_count: 0, - pending_bytes: 0, - processed: 0, - } + pub fn is_handshaking(&self) -> bool { + self.stream.is_handshaking() } - /// Register the `Session` with the event loop - pub fn register(&mut self, poll: &Poll) -> Result<(), std::io::Error> { - let interest = self.readiness(); - self.stream.register(poll.registry(), self.token, interest) - } + /// Fill the read buffer by calling read on the underlying stream until read + /// would block. 
Returns the number of bytes read. `Ok(0)` indicates that + /// the remote side has closed the stream. + pub fn fill(&mut self) -> Result { + let mut read = 0; - /// Deregister the `Session` from the event loop - pub fn deregister(&mut self, poll: &Poll) -> Result<(), std::io::Error> { - self.stream.deregister(poll.registry()) - } + loop { + // if the buffer has too little space available, expand it + if self.read_buffer.remaining_mut() < BUFFER_MIN_FREE { + self.read_buffer.reserve(TARGET_READ_SIZE); + } - /// Reregister the `Session` with the event loop - pub fn reregister(&mut self, poll: &Poll) -> Result<(), std::io::Error> { - let interest = self.readiness(); - if interest == self.interest { - return Ok(()); + // read directly into the read buffer + match self.stream.read(self.read_buffer.borrow_mut()) { + Ok(0) => { + // This means the underlying stream is closed, we need to + // notify the caller by returning this result. + return Ok(0); + } + Ok(n) => { + // Successfully read 'n' bytes from the stream into the + // buffer. Advance the write position. + unsafe { + self.read_buffer.advance_mut(n); + } + read += n; + } + Err(e) => match e.kind() { + ErrorKind::WouldBlock => { + if read == 0 { + return Err(e); + } else { + return Ok(read); + } + } + ErrorKind::Interrupted => {} + _ => { + return Err(e); + } + }, + } } - debug!("reregister: {:?}", interest); - self.interest = interest; - self.stream - .reregister(poll.registry(), self.token, interest) } - /// Get the token which is used with the event loop - pub fn token(&self) -> Token { - self.token + /// Mark `amt` bytes as consumed from the read buffer. + pub fn consume(&mut self, amt: usize) { + self.read_buffer.advance(amt) } - /// Set the token which is used with the event loop - pub fn set_token(&mut self, token: Token) { - self.token = token; + /// Return the number of bytes currently in the write buffer. 
+ pub fn write_pending(&self) -> usize { + self.write_buffer.remaining() } - /// Get the set of readiness events the session is waiting for - /// - /// NOTE: we effectively block additional reads when there are writes - /// pending. This may not be an appropriate choice for all use-cases, but - /// for a server, it can be used to apply back-pressure. - // - // TODO(bmartin): we could make this behavior conditional if we have a - // use-case that requires different handling, but it comes with complexity - // of having to set the behavior for each session. - fn readiness(&self) -> Interest { - if self.write_buffer.is_empty() { - Interest::READABLE - } else { - Interest::WRITABLE + /// Attempts to flush the `Session` to the underlying `Stream`. This may + /// result in multiple calls + pub fn flush(&mut self) -> Result { + let mut flushed = 0; + while self.write_buffer.has_remaining() { + match self.stream.write(self.write_buffer.borrow()) { + Ok(amt) => { + // successfully wrote `amt` bytes to the stream, advance the + // write buffer and increment the flushed stat + self.write_buffer.advance(amt); + flushed += amt; + } + Err(e) => match e.kind() { + ErrorKind::WouldBlock => { + // returns `WouldBlock` if this is the first time + if flushed == 0 { + return Err(e); + } + // otherwise, break from the loop and return the amount + // written until now + break; + } + ErrorKind::Interrupted => { + // this should be retried immediately + } + _ => { + // all other errors get bubbled up + return Err(e); + } + }, + } } - } - /// Returns a boolean which indicates if the session is handshaking - pub fn is_handshaking(&self) -> bool { - self.stream.is_handshaking() + SESSION_SEND_BYTE.add(flushed as _); + + Ok(flushed) } - /// Drives the handshake for the session. A successful result indicates that - /// the session hadshake is completed successfully. 
The error result should - /// be checked to determine if the operation would block, resulted in some - /// unrecoverable error, or if the session was not in a handshaking state - /// when this was called. - pub fn do_handshake(&mut self) -> Result<(), std::io::Error> { + pub fn do_handshake(&mut self) -> Result<()> { self.stream.do_handshake() } - /// Closes the session and the underlying stream. - pub fn close(&mut self) { - self.stream.close(); + pub fn read_buffer_mut(&mut self) -> &mut Buffer { + &mut self.read_buffer } - /// Returns the number of bytes in the read buffer - pub fn read_pending(&self) -> usize { - self.read_buffer.len() + pub fn write_buffer_mut(&mut self) -> &mut Buffer { + &mut self.write_buffer } +} - /// Returns the number of bytes in the write buffer - pub fn write_pending(&self) -> usize { - self.write_buffer.len() +// NOTE: this is opioniated in that we set the buffer sizes, but should be an +// acceptable default for most session construction +impl From for Session { + fn from(other: Stream) -> Self { + Self::new( + other, + Buffer::new(TARGET_READ_SIZE), + Buffer::new(TARGET_READ_SIZE), + ) } +} - /// Returns the number of bytes free in the write buffer relative to the - /// minimum buffer size. This allows us to use it as a signal that we should - /// apply some backpressure on handling requests for the session. - pub fn write_capacity(&self) -> usize { - self.min_capacity.saturating_sub(self.write_pending()) +impl From for Session { + fn from(other: TcpStream) -> Self { + Self::new( + Stream::from(other), + Buffer::new(TARGET_READ_SIZE), + Buffer::new(TARGET_READ_SIZE), + ) } +} - /// Returns a reference to the internally buffered data. - /// - /// Unlike [`fill_buf`], this will not attempt to fill the buffer if it is - /// empty. 
- /// - /// [`fill_buf`]: BufRead::fill_buf - pub fn buffer(&self) -> &[u8] { +impl Borrow<[u8]> for Session { + fn borrow(&self) -> &[u8] { self.read_buffer.borrow() } +} - pub fn peer_addr(&self) -> Result { - self.stream.peer_addr() +impl Borrow<[u8]> for &mut Session { + fn borrow(&self) -> &[u8] { + self.read_buffer.borrow() } +} - pub fn timestamp(&self) -> Instant { - self.timestamp +impl Buf for Session { + fn remaining(&self) -> usize { + self.read_buffer.remaining() } - pub fn set_timestamp(&mut self, timestamp: Instant) { - if self.processed > 0 { - PIPELINE_DEPTH.increment(self.timestamp, self.processed as _, 1); - self.processed = 0; - } - self.timestamp = timestamp; + fn chunk(&self) -> &[u8] { + self.read_buffer.chunk() } - pub fn finalize_response(&mut self) { - self.processed += 1; - let previous = self.pending_bytes; - let current = self.write_pending(); - - match current.cmp(&previous) { - Ordering::Greater => { - // We've finalized a response that has some pending bytes to - // track. If there's room in the tracking struct, we add it so - // we can determine latency later. - if self.pending_count < self.pending_responses.len() { - let mut idx = self.pending_head + self.pending_count; - if idx >= self.pending_responses.len() { - idx %= self.pending_responses.len(); - } - self.pending_responses[idx] = current - previous; - self.pending_count += 1; - } - } - Ordering::Equal => { - // We've finalized a response that is zero-length. This is - // expected for empty responses such as when handling memcache - // requests which specify `NOREPLY`. Since there are no pending - // bytes for a zero-length response, we can determine the - // latency now. - let now = Instant::now(); - let latency = (now - self.timestamp()).as_nanos() as u64; - REQUEST_LATENCY.increment(now, latency, 1); - } - Ordering::Less => { - // This indicates that our tracking is off. This could be due to - // a protocol failing to finalize some type of response. 
- // - // NOTE: this does not indicate corruption of the buffer and - // only indicates some issue with the pending response tracking - // used to calculate latencies. This path is an attempt to - // recover by skipping the tracking for this request. - error!( - "Failed to calculate length of finalized response. \ - Previous pending bytes: {} Current write buffer length: {}", - previous, current - ); - - // If it's a debug build, we will also assert that this is - // unexpected. - debug_assert!(false); - } - } - - self.pending_bytes = current; + fn advance(&mut self, amt: usize) { + self.read_buffer.advance(amt) } } -impl Read for Session { - fn read(&mut self, buf: &mut [u8]) -> Result { - if self.read_buffer.is_empty() { - self.fill_buf()?; - } - let bytes = std::cmp::min(buf.len(), self.read_buffer.len()); - let buffer: &[u8] = self.read_buffer.borrow(); - buf[0..bytes].copy_from_slice(&buffer[0..bytes]); - self.consume(bytes); - Ok(bytes) +unsafe impl BufMut for Session { + fn remaining_mut(&self) -> usize { + self.write_buffer.remaining_mut() } -} -impl BufRead for Session { - fn fill_buf(&mut self) -> Result<&[u8], std::io::Error> { - SESSION_RECV.increment(); - let mut total_bytes = 0; - loop { - if self.read_buffer.len() == self.max_capacity { - return Err(std::io::Error::new(ErrorKind::Other, "buffer full")); - } + unsafe fn advance_mut(&mut self, amt: usize) { + self.write_buffer.advance_mut(amt) + } - // reserve additional space in the buffer if needed - if self.read_buffer.available_capacity() == 0 { - self.read_buffer.reserve(self.min_capacity); - } + fn chunk_mut(&mut self) -> &mut UninitSlice { + self.write_buffer.chunk_mut() + } - match self.stream.read(self.read_buffer.borrow_mut()) { - Ok(0) => { - // Stream is disconnected, stop reading - break; - } - Ok(bytes) => { - self.read_buffer.increase_len(bytes); - total_bytes += bytes; - } - Err(e) => { - if e.kind() == ErrorKind::WouldBlock { - // check if we blocked on first read or subsequent read. 
- // if blocked on a subsequent read, we stop reading and - // allow the function to return the number of bytes read - // until now. - if total_bytes == 0 { - return Err(e); - } else { - break; - } - } else { - SESSION_RECV_EX.increment(); - return Err(e); - } - } - } - } - SESSION_RECV_BYTE.add(total_bytes as _); - Ok(self.read_buffer.borrow()) + #[allow(unused_mut)] + fn put(&mut self, mut src: T) + where + Self: Sized, + { + self.write_buffer.put(src) } - fn consume(&mut self, amt: usize) { - self.read_buffer.consume(amt); + fn put_slice(&mut self, src: &[u8]) { + self.write_buffer.put_slice(src) } } -impl Write for Session { - fn write(&mut self, src: &[u8]) -> Result { - self.write_buffer.reserve(src.len()); - self.write_buffer.extend_from_slice(src); - Ok(src.len()) +impl event::Source for Session { + fn register(&mut self, registry: &Registry, token: Token, interest: Interest) -> Result<()> { + self.stream.register(registry, token, interest) } - // need a different flush - fn flush(&mut self) -> Result<(), std::io::Error> { - SESSION_SEND.increment(); - match self.stream.write((self.write_buffer).borrow()) { - Ok(0) => Ok(()), - Ok(mut bytes) => { - let flushed_bytes = bytes; - SESSION_SEND_BYTE.add(bytes as _); - self.write_buffer.consume(bytes); - - // NOTE: we expect that the stream flush is essentially a no-op - // based on the implementation for `TcpStream` - - let now = Instant::now(); - let latency = (now - self.timestamp()).as_nanos() as u64; - let mut completed = 0; - - // iterate through the pending response lengths and perform the - // bookkeeping to calculate how many have been flushed to the - // `TcpStream` in this call of `flush()` - while bytes > 0 && self.pending_count > 0 { - // first response out of the buffer - let head = &mut self.pending_responses[self.pending_head]; - - if bytes >= *head { - // we flushed all (or more) than the first response - bytes -= *head; - *head = 0; - completed += 1; - self.pending_count -= 1; - - // move the 
head pointer forward - if self.pending_head + 1 < self.pending_responses.len() { - self.pending_head += 1; - } else { - self.pending_head = 0; - } - } else { - // we only flushed part of the first response - *head -= bytes; - bytes = 0; - } - } - - match flushed_bytes.cmp(&self.pending_bytes) { - Ordering::Less => { - // The buffer is not completely flushed to the - // underlying stream, we will still have more pending - // bytes. - self.pending_bytes -= flushed_bytes; - } - Ordering::Equal => { - // The buffer is completely flushed. We have no more - // pending bytes. - self.pending_bytes = 0; - } - Ordering::Greater => { - // This indicates that the tracking is off. Potentially - // due to a protocol implementation that failed to - // finalize some response. - // - // NOTE: this does not indicate corruption of the buffer - // and only indicates some issue with the pending - // response tracking used to calculate latencies. This - // path is an attempt to recover and resume tracking by - // setting the pending bytes to the current write buffer - // length. - error!( - "Session flushed {} bytes, but only had {} pending bytes to track", - flushed_bytes, self.pending_bytes - ); - self.pending_bytes = self.write_pending(); - - // If it's a debug build, we will also assert that this - // is unexpected. - debug_assert!(false); - } - } - - // Increment the histogram with the calculated latency. 
- REQUEST_LATENCY.increment(now, latency, completed); + fn reregister(&mut self, registry: &Registry, token: Token, interest: Interest) -> Result<()> { + self.stream.reregister(registry, token, interest) + } - Ok(()) - } - Err(e) => { - SESSION_SEND_EX.increment(); - Err(e) - } - } + fn deregister(&mut self, registry: &Registry) -> Result<()> { + self.stream.deregister(registry) } } - -common::metrics::test_no_duplicates!(); diff --git a/src/session/src/server.rs b/src/session/src/server.rs new file mode 100644 index 000000000..7577e1542 --- /dev/null +++ b/src/session/src/server.rs @@ -0,0 +1,260 @@ +// Copyright 2022 Twitter, Inc. +// Licensed under the Apache License, Version 2.0 +// http://www.apache.org/licenses/LICENSE-2.0 + +use super::*; + +/// A basic session to represent the server side of a framed session, meaning +/// that it is used by a server to talk to a client. +/// +/// `ServerSession` latency tracking counts the time from data being read into +/// the session buffer until a corresponding message is flushed entirely to the +/// socket buffer. This means that if it takes multiple reads of the socket +/// buffer until a message from the client is parsable, the time of the last +/// read is the start of our latency tracking. Similarly, if it takes multiple +/// writes to flush a message back to the client, the time of the last write to +/// the socket buffer is the stop point for latency tracking. This means that +/// the server latency will be lower than that seen by the client, as delays in +/// the kernel, network, and client itself do not count towards server latency. +/// Instead, the latency represents the time it takes to parse the message, take +/// some possible action, and write a response back out to the socket buffer. 
+pub struct ServerSession { + // the actual session + session: Session, + // a parser which produces requests from the session buffer + parser: Parser, + // tracks the timestamps of any pending requests + pending: VecDeque, + // tracks outstanding responses and the number of bytes remaining for each + outstanding: VecDeque<(Option, usize)>, + // tracks the time the session buffer was last filled + timestamp: Instant, + // markers for the receive and transmit types + _rx: PhantomData, + _tx: PhantomData, +} + +impl AsRawFd for ServerSession { + fn as_raw_fd(&self) -> i32 { + self.session.as_raw_fd() + } +} + +impl Debug for ServerSession { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> { + write!(f, "{:?}", self.session) + } +} + +impl ServerSession +where + Tx: Compose, + Parser: Parse, +{ + // Create a new `ServerSession` from a `Session` and a `Parser` + pub fn new(session: Session, parser: Parser) -> Self { + Self { + session, + parser, + pending: VecDeque::with_capacity(NUM_PENDING), + outstanding: VecDeque::with_capacity(NUM_PENDING), + timestamp: Instant::now(), + _rx: PhantomData, + _tx: PhantomData, + } + } + + /// Consume the `ServerSession` and return the inner `Session` + pub fn into_inner(self) -> Session { + self.session + } + + /// Attempt to receive a single message from the current session buffer. + pub fn receive(&mut self) -> Result { + let src: &[u8] = self.session.borrow(); + match self.parser.parse(src) { + Ok(res) => { + self.pending.push_back(self.timestamp); + let consumed = res.consumed(); + let msg = res.into_inner(); + self.session.consume(consumed); + Ok(msg) + } + Err(e) => Err(e), + } + } + + /// Send a message to the session buffer. 
+ pub fn send(&mut self, tx: Tx) -> Result { + SESSION_SEND.increment(); + + let timestamp = self.pending.pop_front(); + + let size = tx.compose(&mut self.session); + + if size == 0 { + // we have a zero sized response, increment heatmap now + if let Some(timestamp) = timestamp { + let now = Instant::now(); + let latency = now - timestamp; + REQUEST_LATENCY.increment(now, latency.as_nanos(), 1); + } + } else { + // we have bytes in our response, we need to add it on the + // outstanding response queue + self.outstanding.push_back((timestamp, size)); + } + + Ok(size) + } + + /// Advances the read pointer for the session write buffer by `amt` bytes. + /// This is used to mark the data as sent to the underlying session. + pub fn advance_write(&mut self, amt: usize) { + if amt == 0 { + return; + } + + let now = Instant::now(); + + let mut amt = amt; + + while amt > 0 { + if let Some(mut front) = self.outstanding.pop_front() { + if front.1 > amt { + front.1 -= amt; + self.outstanding.push_front(front); + break; + } else { + amt -= front.1; + if let Some(ts) = front.0 { + let latency = now - ts; + REQUEST_LATENCY.increment(now, latency.as_nanos(), 1); + } + } + } else { + break; + } + } + } + + /// Attempts to flush all bytes currently in the write buffer to the + /// underlying stream. Also handles bookkeeping necessary to determine the + /// server-side response latency. + pub fn flush(&mut self) -> Result<()> { + let current_pending = self.session.write_pending(); + self.session.flush()?; + let final_pending = self.session.write_pending(); + + let flushed = current_pending - final_pending; + + self.advance_write(flushed); + + Ok(()) + } + + /// Returns the number of bytes pending in the write buffer. + pub fn write_pending(&self) -> usize { + self.session.write_pending() + } + + /// Reads from the underlying stream into the read buffer and returns the + /// number of bytes read. 
+ pub fn fill(&mut self) -> Result { + SESSION_RECV.increment(); + self.timestamp = Instant::now(); + + match self.session.fill() { + Ok(amt) => { + SESSION_RECV_BYTE.add(amt as _); + Ok(amt) + } + Err(e) => { + if e.kind() != ErrorKind::WouldBlock { + SESSION_RECV_EX.increment(); + } + Err(e) + } + } + } + + /// Returns the current event interest for this session. + pub fn interest(&mut self) -> Interest { + self.session.interest() + } + + /// Attempt to handshake the underlying session. + pub fn do_handshake(&mut self) -> Result<()> { + self.session.do_handshake() + } + + /// Get direct access to the read buffer. + pub fn read_buffer_mut(&mut self) -> &mut Buffer { + self.session.read_buffer_mut() + } + + /// Get direct access to the write buffer. + pub fn write_buffer_mut(&mut self) -> &mut Buffer { + self.session.write_buffer_mut() + } +} + +impl Borrow<[u8]> for ServerSession { + fn borrow(&self) -> &[u8] { + self.session.borrow() + } +} + +impl Buf for ServerSession { + fn remaining(&self) -> usize { + self.session.remaining() + } + + fn chunk(&self) -> &[u8] { + self.session.chunk() + } + + fn advance(&mut self, amt: usize) { + self.session.advance(amt) + } +} + +unsafe impl BufMut for ServerSession { + fn remaining_mut(&self) -> usize { + self.session.remaining_mut() + } + + unsafe fn advance_mut(&mut self, amt: usize) { + self.session.advance_mut(amt) + } + + fn chunk_mut(&mut self) -> &mut UninitSlice { + self.session.chunk_mut() + } + + #[allow(unused_mut)] + fn put(&mut self, mut src: T) + where + Self: Sized, + { + self.session.put(src) + } + + fn put_slice(&mut self, src: &[u8]) { + self.session.put_slice(src) + } +} + +impl event::Source for ServerSession { + fn register(&mut self, registry: &Registry, token: Token, interest: Interest) -> Result<()> { + self.session.register(registry, token, interest) + } + + fn reregister(&mut self, registry: &Registry, token: Token, interest: Interest) -> Result<()> { + self.session.reregister(registry, token, 
interest) + } + + fn deregister(&mut self, registry: &Registry) -> Result<()> { + self.session.deregister(registry) + } +} diff --git a/src/session/src/stream.rs b/src/session/src/stream.rs deleted file mode 100644 index 5d2f8b37e..000000000 --- a/src/session/src/stream.rs +++ /dev/null @@ -1,262 +0,0 @@ -// Copyright 2021 Twitter, Inc. -// Licensed under the Apache License, Version 2.0 -// http://www.apache.org/licenses/LICENSE-2.0 - -//! Encapsulates plaintext and TLS TCP streams into a single type. - -use std::io::{Error, ErrorKind}; -use std::io::{Read, Write}; -use std::net::SocketAddr; - -use common::ssl::{HandshakeError, MidHandshakeSslStream, SslStream}; - -use super::TcpStream; -use crate::{TCP_CLOSE, TCP_CONN_CURR}; - -pub struct Stream { - inner: Option, -} - -pub enum StreamType { - /// An established plaintext TCP connection - Plain(TcpStream), - /// A TLS/SSL TCP stream which is fully negotiated - Tls(SslStream), - /// A TLS/SSL TCP stream which is still handshaking - Handshaking(MidHandshakeSslStream), -} - -impl Stream { - pub fn plain(tcp_stream: TcpStream) -> Self { - Self { - inner: Some(StreamType::Plain(tcp_stream)), - } - } - - pub fn tls(ssl_stream: SslStream) -> Self { - Self { - inner: Some(StreamType::Tls(ssl_stream)), - } - } - - pub fn handshaking(handshaking_ssl_stream: MidHandshakeSslStream) -> Self { - Self { - inner: Some(StreamType::Handshaking(handshaking_ssl_stream)), - } - } - - pub fn is_handshaking(&self) -> bool { - matches!(self.inner, Some(StreamType::Handshaking(_))) - } - - pub fn do_handshake(&mut self) -> Result<(), std::io::Error> { - if let Some(StreamType::Handshaking(stream)) = self.inner.take() { - let ret; - let result = stream.handshake(); - self.inner = match result { - Ok(established) => { - ret = Ok(()); - Some(StreamType::Tls(established)) - } - Err(HandshakeError::WouldBlock(handshaking)) => { - ret = Err(Error::new(ErrorKind::WouldBlock, "handshake would block")); - Some(StreamType::Handshaking(handshaking)) 
- } - _ => { - ret = Err(Error::new(ErrorKind::Other, "handshaking error")); - None - } - }; - ret - } else { - Err(Error::new( - ErrorKind::Other, - "session is not in handshaking state", - )) - } - } - - pub fn close(&mut self) { - TCP_CLOSE.increment(); - TCP_CONN_CURR.sub(1); - if let Some(stream) = self.inner.take() { - self.inner = match stream { - StreamType::Plain(s) => { - let _ = s.shutdown(std::net::Shutdown::Both); - Some(StreamType::Plain(s)) - } - StreamType::Tls(mut s) => { - // TODO(bmartin): session resume requires that a full graceful - // shutdown occurs - let _ = s.shutdown(); - Some(StreamType::Tls(s)) - } - StreamType::Handshaking(mut s) => { - // since we don't have a fully established session, just - // shutdown the underlying tcp stream - let _ = s.get_mut().shutdown(std::net::Shutdown::Both); - Some(StreamType::Handshaking(s)) - } - } - } - } - - pub fn peer_addr(&self) -> Result { - if let Some(ref stream) = self.inner.as_ref() { - Ok(match stream { - StreamType::Plain(s) => s.peer_addr()?, - StreamType::Tls(s) => s.get_ref().peer_addr()?, - StreamType::Handshaking(s) => s.get_ref().peer_addr()?, - }) - } else { - Err(Error::new( - ErrorKind::NotConnected, - "session is not connected", - )) - } - } -} - -impl Read for Stream { - fn read(&mut self, buf: &mut [u8]) -> Result { - if let Some(stream) = &mut self.inner { - stream.read(buf) - } else { - Err(Error::new( - ErrorKind::NotConnected, - "session is not connected", - )) - } - } -} - -impl Read for StreamType { - fn read(&mut self, buf: &mut [u8]) -> Result { - match self { - Self::Plain(s) => s.read(buf), - Self::Tls(s) => s.read(buf), - Self::Handshaking(_) => Err(Error::new( - ErrorKind::WouldBlock, - "handshaking tls stream would block on read", - )), - } - } -} - -impl Write for Stream { - fn write(&mut self, buf: &[u8]) -> Result { - if let Some(stream) = &mut self.inner { - stream.write(buf) - } else { - Err(Error::new( - ErrorKind::NotConnected, - "session is not connected", - 
)) - } - } - fn flush(&mut self) -> Result<(), std::io::Error> { - if let Some(stream) = &mut self.inner { - stream.flush() - } else { - Ok(()) - } - } -} - -impl Write for StreamType { - fn write(&mut self, buf: &[u8]) -> Result { - match self { - Self::Plain(s) => s.write(buf), - Self::Tls(s) => s.write(buf), - Self::Handshaking(_) => Err(Error::new( - ErrorKind::WouldBlock, - "handshaking tls stream would block on write", - )), - } - } - - fn flush(&mut self) -> std::result::Result<(), std::io::Error> { - Ok(()) - } -} - -impl mio::event::Source for Stream { - fn register( - &mut self, - registry: &mio::Registry, - token: mio::Token, - interest: mio::Interest, - ) -> std::result::Result<(), std::io::Error> { - if let Some(stream) = &mut self.inner { - stream.register(registry, token, interest) - } else { - Err(Error::new( - ErrorKind::NotConnected, - "session is not connected", - )) - } - } - - fn reregister( - &mut self, - registry: &mio::Registry, - token: mio::Token, - interest: mio::Interest, - ) -> std::result::Result<(), std::io::Error> { - if let Some(stream) = &mut self.inner { - stream.reregister(registry, token, interest) - } else { - Err(Error::new( - ErrorKind::NotConnected, - "session is not connected", - )) - } - } - - fn deregister(&mut self, registry: &mio::Registry) -> std::result::Result<(), std::io::Error> { - if let Some(stream) = &mut self.inner { - stream.deregister(registry) - } else { - Err(Error::new( - ErrorKind::NotConnected, - "session is not connected", - )) - } - } -} - -impl mio::event::Source for StreamType { - fn register( - &mut self, - registry: &mio::Registry, - token: mio::Token, - interest: mio::Interest, - ) -> std::result::Result<(), std::io::Error> { - match self { - Self::Plain(s) => registry.register(s, token, interest), - Self::Tls(s) => registry.register(s.get_mut(), token, interest), - Self::Handshaking(s) => registry.register(s.get_mut(), token, interest), - } - } - - fn reregister( - &mut self, - registry: 
&mio::Registry, - token: mio::Token, - interest: mio::Interest, - ) -> std::result::Result<(), std::io::Error> { - match self { - Self::Plain(s) => registry.reregister(s, token, interest), - Self::Tls(s) => registry.reregister(s.get_mut(), token, interest), - Self::Handshaking(s) => registry.reregister(s.get_mut(), token, interest), - } - } - - fn deregister(&mut self, registry: &mio::Registry) -> std::result::Result<(), std::io::Error> { - match self { - Self::Plain(s) => registry.deregister(s), - Self::Tls(s) => registry.deregister(s.get_mut()), - Self::Handshaking(s) => registry.deregister(s.get_mut()), - } - } -} diff --git a/src/session/src/tcp_stream.rs b/src/session/src/tcp_stream.rs deleted file mode 100644 index 47bf04bd4..000000000 --- a/src/session/src/tcp_stream.rs +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright 2021 Twitter, Inc. -// Licensed under the Apache License, Version 2.0 -// http://www.apache.org/licenses/LICENSE-2.0 - -//! A new type wrapper for `TcpStream`s which allows for capturing metrics about -//! operations on the underlying TCP stream. 
- -use std::convert::TryFrom; -use std::io::{Read, Write}; -use std::net::SocketAddr; - -use crate::{TCP_RECV_BYTE, TCP_SEND_BYTE, TCP_SEND_PARTIAL}; - -pub struct TcpStream { - inner: mio::net::TcpStream, -} - -impl TcpStream { - pub fn shutdown(&self, how: std::net::Shutdown) -> Result<(), std::io::Error> { - self.inner.shutdown(how) - } - - pub fn peer_addr(&self) -> Result { - self.inner.peer_addr() - } -} - -impl TryFrom for TcpStream { - type Error = std::io::Error; - - fn try_from(other: mio::net::TcpStream) -> Result { - let _ = other.peer_addr()?; - Ok(Self { inner: other }) - } -} - -impl Read for TcpStream { - fn read(&mut self, buf: &mut [u8]) -> std::result::Result { - let result = self.inner.read(buf); - if let Ok(bytes) = result { - TCP_RECV_BYTE.add(bytes as _); - } - result - } -} - -impl Write for TcpStream { - fn write(&mut self, buf: &[u8]) -> std::result::Result { - let result = self.inner.write(buf); - if let Ok(bytes) = result { - if bytes != buf.len() { - TCP_SEND_PARTIAL.increment(); - } - TCP_SEND_BYTE.add(bytes as _); - } - result - } - fn flush(&mut self) -> std::result::Result<(), std::io::Error> { - self.inner.flush() - } -} - -impl mio::event::Source for TcpStream { - fn register( - &mut self, - registry: &mio::Registry, - token: mio::Token, - interest: mio::Interest, - ) -> std::result::Result<(), std::io::Error> { - self.inner.register(registry, token, interest) - } - - fn reregister( - &mut self, - registry: &mio::Registry, - token: mio::Token, - interest: mio::Interest, - ) -> std::result::Result<(), std::io::Error> { - self.inner.reregister(registry, token, interest) - } - - fn deregister(&mut self, registry: &mio::Registry) -> std::result::Result<(), std::io::Error> { - self.inner.deregister(registry) - } -}