diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index 7b9e0ee6..6b7ab5c4 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -9,10 +9,11 @@ jobs: runs-on: ubuntu-22.04 steps: - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 - - uses: moonrepo/setup-rust@v1 + - uses: moonrepo/setup-rust@b8edcc56aab474d90c7cf0bb8beeaf8334c15e9f with: channel: '1.74.0' bins: cargo-deny + - run: sudo apt-get install libpam0g-dev - run: cargo deny --all-features check postsubmit: diff --git a/.github/workflows/presubmit.yml b/.github/workflows/presubmit.yml index 1e940eae..c1dac3e7 100644 --- a/.github/workflows/presubmit.yml +++ b/.github/workflows/presubmit.yml @@ -7,15 +7,18 @@ jobs: runs-on: ubuntu-22.04 steps: - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 - - uses: moonrepo/setup-rust@v1 + - uses: moonrepo/setup-rust@b8edcc56aab474d90c7cf0bb8beeaf8334c15e9f with: channel: '1.74.0' - - run: sudo apt-get install zsh fish - - run: cargo test --all-features + - run: sudo apt-get install zsh fish libpam0g-dev + - run: SHPOOL_LEAVE_TEST_LOGS=true cargo test --all-features - uses: actions/upload-artifact@v4 + id: upload-logs-step with: name: test-logs path: /tmp/shpool-test*/*.log + - name: Output artifact ID + run: echo 'Artifact ID is ${{ steps.upload-logs-step.outputs.artifact-id }}' # miri does not handle all the IO we do, disabled for now. # @@ -36,10 +39,11 @@ jobs: runs-on: ubuntu-22.04 steps: - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 - - uses: moonrepo/setup-rust@v1 + - uses: moonrepo/setup-rust@b8edcc56aab474d90c7cf0bb8beeaf8334c15e9f with: components: rustfmt channel: nightly + - run: sudo apt-get install libpam0g-dev - run: cargo +nightly fmt -- --check cranky: @@ -47,12 +51,12 @@ jobs: runs-on: ubuntu-22.04 steps: - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 - - uses: moonrepo/setup-rust@v1 + - uses: moonrepo/setup-rust@b8edcc56aab474d90c7cf0bb8beeaf8334c15e9f with: components: clippy bins: cargo-cranky@0.3.0 channel: nightly - - run: sudo apt-get install zsh fish + - run: sudo apt-get install zsh fish libpam0g-dev - run: cargo +nightly cranky --all-targets -- -D warnings deny: @@ -60,8 +64,9 @@ jobs: runs-on: ubuntu-22.04 steps: - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 - - uses: moonrepo/setup-rust@v1 + - uses: moonrepo/setup-rust@b8edcc56aab474d90c7cf0bb8beeaf8334c15e9f with: channel: '1.74.0' bins: cargo-deny + - run: sudo apt-get install libpam0g-dev - run: cargo deny --all-features check licenses diff --git a/Cargo.lock b/Cargo.lock index c55650ee..d882f72d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -218,6 +218,55 @@ version = "0.8.19" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345" +[[package]] +name = "dirs" +version = "4.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca3aa72a6f96ea37bbc5aa912f6788242832f75369bdfdadcb0e38423f100059" +dependencies = [ + "dirs-sys", +] + +[[package]] +name = "dirs-sys" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b1d1d91c932ef41c0f2663aa8b0ca0342d444d842c06914aa0a7e352d0bada6" +dependencies = [ + "libc", + "redox_users", + "winapi", +] + +[[package]] +name = "dlopen2" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e1297103d2bbaea85724fcee6294c2d50b1081f9ad47d0f6f6f61eda65315a6" +dependencies = [ + "dlopen2_derive", + "libc", + "once_cell", + "winapi", +] + +[[package]] +name = "dlopen2_derive" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2b99bf03862d7f545ebc28ddd33a665b50865f4dfd84031a393823879bd4c54" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.52", +] + +[[package]] +name = "either" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11157ac094ffbdde99aa67b23417ebdd801842852b500e395a45a9c0aac03e4a" + [[package]] name = "equivalent" version = "1.0.1" @@ -261,6 +310,23 @@ version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "25cbce373ec4653f1a01a31e8a5e5ec0c622dc27ff9c4e6606eefef5cbbed4a5" +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "getrandom" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + [[package]] name = "hashbrown" version = "0.14.3" @@ -273,6 +339,15 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" +[[package]] +name = "home" +version = "0.5.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5" +dependencies = [ + "windows-sys", +] + [[package]] name = "iana-time-zone" version = "0.1.60" @@ -333,6 +408,17 @@ version = "0.2.153" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" +[[package]] +name = "libredox" +version = "0.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85c833ca1e66078851dba29046874e38f08b2c883700aa29a03ddd3b23814ee8" +dependencies = [ + "bitflags 2.4.2", + "libc", + "redox_syscall", +] + [[package]] name = "libshpool" version = "0.5.0" @@ -346,6 +432,7 @@ dependencies = [ "lazy_static", "libc", "log", + "motd", "nix", "ntest", "serde", @@ -354,6 +441,7 @@ dependencies = [ "shpool_pty", "shpool_vt100", "signal-hook", + "terminfo", "toml", "tracing", "tracing-subscriber", @@ -386,6 +474,31 @@ dependencies = [ "autocfg", ] +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + +[[package]] +name = "motd" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8b1327b3888ed1ed4b1f0ba708f1e5c4f66b6140a5816792fe05f643c3efc9d" +dependencies = [ + "dlopen2", + "lazy_static", + "libc", + "log", + "pam-sys", + "serde", + "serde_derive", + "serde_json", + "tempfile", + "walkdir", + "which", +] + [[package]] name = "nix" version = "0.26.4" @@ -399,6 +512,16 @@ dependencies = [ "pin-utils", ] +[[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + "memchr", + "minimal-lexical", +] + [[package]] name = "ntest" version = "0.9.2" @@ -447,6 +570,53 @@ version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" +[[package]] +name = "pam-sys" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd4858311a097f01a0006ef7d0cd50bca81ec430c949d7bf95cbefd202282434" +dependencies = [ + "libc", +] + +[[package]] +name = "phf" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ade2d8b8f33c7333b51bcf0428d37e217e9f32192ae4772156f65063b8ce03dc" +dependencies = [ + "phf_shared", +] + +[[package]] +name = "phf_codegen" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8d39688d359e6b34654d328e262234662d16cc0f60ec8dcbe5e718709342a5a" +dependencies = [ + "phf_generator", + "phf_shared", +] + +[[package]] +name = "phf_generator" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48e4cc64c2ad9ebe670cb8fd69dd50ae301650392e81c05f9bfcb2d5bdbc24b0" +dependencies = [ + "phf_shared", + "rand", +] + +[[package]] +name = "phf_shared" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90fcb95eef784c2ac79119d1dd819e162b5da872ce6f3c3abe1e8ca1c082f72b" +dependencies = [ + "siphasher", +] + [[package]] name = "pin-project-lite" version = "0.2.13" @@ -486,6 +656,41 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" + +[[package]] +name = "redox_syscall" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa" +dependencies = [ + "bitflags 1.3.2", +] + +[[package]] +name = "redox_users" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a18479200779601e498ada4e8c1e1f50e3ee19deb0259c25825a98b5603b2cb4" +dependencies = [ + "getrandom", + "libredox", + "thiserror", +] + [[package]] name = "regex" version = "1.10.3" @@ -534,6 +739,15 @@ version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e86697c916019a8588c99b5fac3cead74ec0b4b819707a682fd4d23fa0ce1ba1" +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + [[package]] name = "serde" version = "1.0.197" @@ -598,6 +812,7 @@ dependencies = [ "crossbeam-channel", "lazy_static", "libshpool", + "motd", "nix", "ntest", "regex", @@ -647,6 +862,12 @@ dependencies = [ "libc", ] +[[package]] +name = "siphasher" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38b58827f4464d87d377d175e90bf58eb00fd8716ff0a62f80356b5e61555d0d" + [[package]] name = "smallvec" version = "1.13.1" @@ -693,6 +914,39 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "terminfo" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "666cd3a6681775d22b200409aad3b089c5b99fb11ecdd8a204d9d62f8148498f" +dependencies = [ + "dirs", + "fnv", + "nom", + "phf", + "phf_codegen", +] + +[[package]] +name = "thiserror" +version = "1.0.58" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03468839009160513471e86a034bb2c5c0e4baae3b43f79ffc55c4a5427b3297" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.58" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.52", +] + [[package]] name = "thread_local" version = "1.1.8" @@ -849,6 +1103,22 @@ dependencies = [ "quote", ] +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + [[package]] name = "wasm-bindgen" version = "0.2.92" @@ -903,6 +1173,19 @@ version = "0.2.92" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" +[[package]] +name = "which" +version = "6.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7fa5e0c10bf77f44aac573e498d1a82d5fbd5e91f6fc0a99e7be4b38e85e101c" +dependencies = [ + "either", + "home", + "once_cell", + "rustix", + "windows-sys", +] + [[package]] name = "winapi" version = "0.3.9" @@ -919,6 +1202,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" +[[package]] +name = "winapi-util" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f29e6f9198ba0d26b4c9f07dbe6f9ed633e1f3d5b8b414090084349e46a52596" +dependencies = [ + "winapi", +] + [[package]] name = "winapi-x86_64-pc-windows-gnu" version = "0.4.0" diff --git a/libshpool/Cargo.toml b/libshpool/Cargo.toml index 9c0d8c1d..d758ad8f 100644 --- a/libshpool/Cargo.toml +++ b/libshpool/Cargo.toml @@ -37,6 +37,8 @@ tracing = "0.1" # logging and performance monitoring facade bincode = "1" # serialization for the control protocol shpool_vt100 = "0.1.2" # terminal emulation for the scrollback buffer shell-words = "1" # parsing the -c/--cmd argument +motd = "0.2.0" # getting the message-of-the-day +terminfo = "0.8.0" # resolving terminal escape codes [dependencies.tracing-subscriber] version = "0.3" diff --git a/libshpool/README.md b/libshpool/README.md index 9d87764f..63ded92c 100644 --- a/libshpool/README.md +++ b/libshpool/README.md @@ -8,3 +8,14 @@ to an internal google version of the tool, but don't believe that telemetry belongs in an open-source tool. Other potential use-cases such as incorporating a shpool daemon into an IDE that hosts remote terminals could be imagined though. + +## Integrating + +In order to call libshpool, you must keep a few things in mind. +In spirit, you just need to call `libshpool::run(libshpoo::Args::parse())`, +but you need to take care of a few things manually. + +1. Handle the `version` subcommand. Since libshpool is a library, the output + will not be very good if the library handles the versioning. +2. Depend on the `motd` crate and call `motd::handle_reexec()` in your `main` + function. diff --git a/libshpool/src/config.rs b/libshpool/src/config.rs index dcc47259..81518af7 100644 --- a/libshpool/src/config.rs +++ b/libshpool/src/config.rs @@ -110,6 +110,21 @@ pub struct Config { /// verbatim except that the string '$SHPOOL_SESSION_NAME' will /// get replaced with the actual name of the shpool session. pub prompt_prefix: Option, + + /// Control when and how shpool will display the message of the day. + pub motd: Option, + + /// Override arguments to pass to pam_motd.so when resolving the + /// message of the day. Normally, you want to leave this blank + /// so that shpool will scrape the default arguments used in + /// `/etc/pam.d/{ssh,login}` which typically produces the expected + /// result, but in some cases you may need to override the argument + /// list. You can also use this to make a custom message of the + /// day that is only displayed when using shpool. + /// + /// See https://man7.org/linux/man-pages/man8/pam_motd.8.html + /// for more info. + pub motd_args: Option>, } #[derive(Deserialize, Debug, Clone)] @@ -140,6 +155,31 @@ pub enum SessionRestoreMode { Lines(u16), } +#[derive(Deserialize, Debug, Clone, Default)] +#[serde(rename_all = "lowercase")] +pub enum MotdDisplayMode { + /// Never display the message of the day. + #[default] + Never, + + /// Display the message of the day using the given program + /// as the pager. The pager will be invoked like `pager /tmp/motd.txt`, + /// and normal connection will only proceed once the pager has + /// exited. + /// + /// Display the message of the day each time a user attaches + /// (wether to a new session or reattaching to an existing session). + /// + /// `less` by default. + // Pager(String), + + /// Just dump the message of the day directly to the screen. + /// Dumps are only performed when a new session is created. + /// There is no safe way to dump directly when reattaching, + /// so we don't attempt it. + Dump, +} + #[cfg(test)] mod test { use super::*; diff --git a/libshpool/src/daemon/control_codes.rs b/libshpool/src/daemon/control_codes.rs new file mode 100644 index 00000000..f07dc2c1 --- /dev/null +++ b/libshpool/src/daemon/control_codes.rs @@ -0,0 +1,82 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! The escape codes module provides an online (trie based) matcher +//! to scan for escape codes we are interested in in the output of +//! the subshell. For the moment, we just use this to scan for +//! the ClearScreen code emitted by the prompt prefix injection shell +//! code. We need to scan for this to avoid a race that can lead to +//! the motd getting clobbered when in dump mode. + +use anyhow::{anyhow, Context}; + +use super::trie::{Trie, TrieCursor}; + +#[derive(Debug, Clone, Copy)] +pub enum Code { + ClearScreen, +} + +#[derive(Debug)] +pub struct Matcher { + codes: Trie>>, + codes_cursor: TrieCursor, +} + +impl Matcher { + pub fn new(term_db: &terminfo::Database) -> anyhow::Result { + let clear_code_bytes = match term_db.get::() { + Some(code) => code.expand().to_vec().context("expanding clear code")?, + None => { + // If we somehow have a wacky terminfo db with no clear code, we fall + // back on xterm clear since we still need something to scan for. + let xterm_db = + terminfo::Database::from_name("xterm").context("building fallback xterm db")?; + let code = xterm_db + .get::() + .ok_or(anyhow!("no fallback clear screen code"))?; + code.expand().to_vec().context("expanding fallback clear code")? + } + }; + + let raw_bindings = vec![ + // We need to scan for the clear code that gets emitted by the prompt prefix + // shell injection code so that we can make sure that the message of the day + // won't get clobbered immediately. + (clear_code_bytes, Code::ClearScreen), + ]; + let mut codes = Trie::new(); + for (raw_bytes, code) in raw_bindings.into_iter() { + codes.insert(raw_bytes.into_iter(), code); + } + + Ok(Matcher { codes, codes_cursor: TrieCursor::Start }) + } + + pub fn transition(&mut self, byte: u8) -> Option { + self.codes_cursor = self.codes.advance(self.codes_cursor, byte); + match self.codes_cursor { + TrieCursor::NoMatch => { + self.codes_cursor = TrieCursor::Start; + None + } + TrieCursor::Match { is_partial, .. } if !is_partial => { + let code = self.codes.get(self.codes_cursor).copied(); + self.codes_cursor = TrieCursor::Start; + code + } + _ => None, + } + } +} diff --git a/libshpool/src/daemon/keybindings.rs b/libshpool/src/daemon/keybindings.rs index e61504a9..b106802f 100644 --- a/libshpool/src/daemon/keybindings.rs +++ b/libshpool/src/daemon/keybindings.rs @@ -48,11 +48,13 @@ //! be singletons besides 'Ctrl' or of the form 'Ctrl-x' where //! x is some non-'Ctrl' key. -use std::{collections::HashMap, fmt, hash}; +use std::{collections::HashMap, fmt}; use anyhow::{anyhow, Context}; use serde_derive::Deserialize; +use super::trie::{Trie, TrieCursor, TrieTab}; + // // Keybindings table // @@ -96,6 +98,20 @@ pub enum BindingResult { #[derive(Eq, PartialEq, Copy, Clone, Hash)] struct ChordAtom(u8); +impl TrieTab for Vec> { + fn new() -> Self { + vec![None; u8::MAX as usize] + } + + fn get(&self, index: ChordAtom) -> Option<&usize> { + self[index.0 as usize].as_ref() + } + + fn set(&mut self, index: ChordAtom, elem: usize) { + self[index.0 as usize] = Some(elem) + } +} + impl Bindings { /// new builds a bindings matching engine, parsing the given binding->action /// mapping and compiling it into the pair of tries that we use to perform @@ -396,164 +412,6 @@ impl Lexer { } } -// -// Trie (used in both the parser and the execution engine) -// - -#[derive(Debug)] -struct Trie { - // The nodes which form the tree. The first node is the root - // node, afterwards the order is undefined. - nodes: Vec>, -} - -#[derive(Eq, PartialEq, Copy, Clone, Debug)] -enum TrieCursor { - /// A cursor to use to start a char-wise match - Start, - /// Represents a state in the middle or end of a match - Match { idx: usize, is_partial: bool }, - /// A terminal state indicating a failure to match - NoMatch, -} - -#[derive(Debug)] -struct TrieNode { - // We need to store a phantom symbol here so we can have the - // Sym type parameter available for the TrieTab trait constraint - // in the impl block. Apologies for the type tetris. - phantom: std::marker::PhantomData, - value: Option, - tab: TT, -} - -impl Trie -where - TT: TrieTab, - Sym: Copy, -{ - fn new() -> Self { - Trie { nodes: vec![TrieNode::new(None)] } - } - - fn insert>(&mut self, seq: Seq, value: V) { - let mut current_node = 0; - for sym in seq { - current_node = if let Some(next_node) = self.nodes[current_node].tab.get(sym) { - *next_node - } else { - let idx = self.nodes.len(); - self.nodes.push(TrieNode::new(None)); - self.nodes[current_node].tab.set(sym, idx); - idx - }; - } - self.nodes[current_node].value = Some(value); - } - - #[allow(dead_code)] - fn contains>(&self, seq: Seq) -> bool { - let mut match_state = TrieCursor::Start; - for sym in seq { - match_state = self.advance(match_state, sym); - if let TrieCursor::NoMatch = match_state { - return false; - } - } - if let TrieCursor::Start = match_state { - return self.nodes[0].value.is_some(); - } - - if let TrieCursor::Match { is_partial, .. } = match_state { !is_partial } else { false } - } - - fn advance(&self, cursor: TrieCursor, sym: Sym) -> TrieCursor { - let node = match cursor { - TrieCursor::Start => &self.nodes[0], - TrieCursor::Match { idx, .. } => &self.nodes[idx], - TrieCursor::NoMatch => return TrieCursor::NoMatch, - }; - - if let Some(idx) = node.tab.get(sym) { - TrieCursor::Match { idx: *idx, is_partial: self.nodes[*idx].value.is_none() } - } else { - TrieCursor::NoMatch - } - } - - fn get(&self, cursor: TrieCursor) -> Option<&V> { - if let TrieCursor::Match { idx, .. } = cursor { - self.nodes[idx].value.as_ref() - } else { - None - } - } -} - -impl TrieNode -where - TT: TrieTab, -{ - fn new(value: Option) -> Self { - TrieNode { phantom: std::marker::PhantomData, value, tab: TT::new() } - } -} - -/// The backing table the trie uses to associate symbols with state -/// indexes. This is basically std::ops::IndexMut plus a new function. -/// We can't just make this a sub-trait of IndexMut because u8 does -/// not implement IndexMut for vectors. -trait TrieTab { - fn new() -> Self; - fn get(&self, index: Idx) -> Option<&usize>; - fn set(&mut self, index: Idx, elem: usize); -} - -impl TrieTab for HashMap -where - Sym: hash::Hash + Eq + PartialEq, -{ - fn new() -> Self { - HashMap::new() - } - - fn get(&self, index: Sym) -> Option<&usize> { - self.get(&index) - } - - fn set(&mut self, index: Sym, elem: usize) { - self.insert(index, elem); - } -} - -impl TrieTab for Vec> { - fn new() -> Self { - vec![None; u8::MAX as usize] - } - - fn get(&self, index: u8) -> Option<&usize> { - self[index as usize].as_ref() - } - - fn set(&mut self, index: u8, elem: usize) { - self[index as usize] = Some(elem) - } -} - -impl TrieTab for Vec> { - fn new() -> Self { - vec![None; u8::MAX as usize] - } - - fn get(&self, index: ChordAtom) -> Option<&usize> { - self[index.0 as usize].as_ref() - } - - fn set(&mut self, index: ChordAtom, elem: usize) { - self[index.0 as usize] = Some(elem) - } -} - // // Data Tables // diff --git a/libshpool/src/daemon/mod.rs b/libshpool/src/daemon/mod.rs index 00591da4..e1081ba6 100644 --- a/libshpool/src/daemon/mod.rs +++ b/libshpool/src/daemon/mod.rs @@ -19,14 +19,17 @@ use tracing::{info, instrument}; use super::{config, hooks}; +mod control_codes; mod etc_environment; mod exit_notify; pub mod keybindings; mod prompt; mod server; mod shell; +mod show_motd; mod signals; mod systemd; +mod trie; mod ttl_reaper; #[instrument(skip_all)] @@ -39,7 +42,7 @@ pub fn run( info!("\n\n======================== STARTING DAEMON ============================\n\n"); let config = config::read_config(&config_file)?; - let server = server::Server::new(config, hooks, runtime_dir); + let server = server::Server::new(config, hooks, runtime_dir)?; let (cleanup_socket, listener) = match systemd::activation_socket() { Ok(l) => { diff --git a/libshpool/src/daemon/prompt.rs b/libshpool/src/daemon/prompt.rs index 794e6aaa..f9d71b81 100644 --- a/libshpool/src/daemon/prompt.rs +++ b/libshpool/src/daemon/prompt.rs @@ -29,9 +29,10 @@ pub fn inject_prefix( shell: &str, prompt_prefix: &str, session_name: &str, + needs_default_term: bool, ) -> anyhow::Result<()> { let prompt_prefix = prompt_prefix.replace("$SHPOOL_SESSION_NAME", session_name); - let script = if shell.ends_with("bash") { + let mut script = if shell.ends_with("bash") { format!( r#" if [[ -z "${{PROMPT_COMMAND+x}}" ]]; then @@ -46,7 +47,6 @@ pub fn inject_prefix( }} PROMPT_COMMAND=__shpool__prompt_command fi - clear "# ) } else if shell.ends_with("zsh") { @@ -62,7 +62,6 @@ pub fn inject_prefix( PROMPT="{prompt_prefix}${{PROMPT}}" }} precmd_functions+=(__shpool__prompt_command) - clear "# ) } else if shell.ends_with("fish") { @@ -70,13 +69,18 @@ pub fn inject_prefix( r#" functions --copy fish_prompt shpool__old_prompt function fish_prompt; echo -n "{prompt_prefix}"; shpool__old_prompt; end - clear "# ) } else { return Err(anyhow!("don't know how to inject a prefix for shell '{}'", shell)); }; + if needs_default_term { + script.push_str("\nclear\n"); + } else { + script.push_str("\nTERM=xterm clear\n"); + } + let mut pty_master = pty_master.is_parent().context("expected parent")?; pty_master.write_all(script.as_bytes()).context("running prefix script")?; diff --git a/libshpool/src/daemon/server.rs b/libshpool/src/daemon/server.rs index af621b2a..3cb34e7e 100644 --- a/libshpool/src/daemon/server.rs +++ b/libshpool/src/daemon/server.rs @@ -14,7 +14,9 @@ use std::{ collections::HashMap, - env, fs, io, net, + env, fs, io, + io::Write, + net, ops::Add, os, os::unix::{ @@ -36,7 +38,7 @@ use tracing::{error, info, instrument, span, trace, warn, Level}; use super::{ super::{config, consts, protocol, test_hooks, tty, user}, - etc_environment, hooks, prompt, shell, ttl_reaper, + etc_environment, hooks, prompt, shell, show_motd, ttl_reaper, }; use crate::daemon::exit_notify::ExitNotifier; @@ -56,6 +58,7 @@ pub struct Server { runtime_dir: PathBuf, register_new_reapable_session: crossbeam_channel::Sender<(String, Instant)>, hooks: Box, + motd_shower: Arc, } impl Server { @@ -64,7 +67,7 @@ impl Server { config: config::Config, hooks: Box, runtime_dir: PathBuf, - ) -> Arc { + ) -> anyhow::Result> { let shells = Arc::new(Mutex::new(HashMap::new())); // buffered so that we are unlikely to block when setting up a // new session @@ -76,13 +79,18 @@ impl Server { } }); - Arc::new(Server { + let motd_shower = Arc::new(show_motd::Shower::new( + config.motd.clone().unwrap_or_default(), + config.motd_args.clone(), + )?); + Ok(Arc::new(Server { config, shells, runtime_dir, register_new_reapable_session: new_sess_tx, hooks, - }) + motd_shower, + })) } #[instrument(skip_all)] @@ -241,11 +249,19 @@ impl Server { } if matches!(status, protocol::AttachStatus::Created { .. }) { + use config::MotdDisplayMode; + info!("creating new subshell"); if let Err(err) = self.hooks.on_new_session(&header.name) { warn!("new_session hook: {:?}", err); } - let session = self.spawn_subshell(conn_id, stream, &header)?; + let motd = self.config.motd.clone().unwrap_or_default(); + let session = self.spawn_subshell( + conn_id, + stream, + &header, + matches!(motd, MotdDisplayMode::Dump), + )?; shells.insert(header.name.clone(), Box::new(session)); // fallthrough to bidi streaming @@ -280,7 +296,8 @@ impl Server { } }; - let reply_status = write_reply(client_stream, protocol::AttachReplyHeader { status }); + let reply_status = + write_reply(client_stream, protocol::AttachReplyHeader { status: status.clone() }); if let Err(e) = reply_status { error!("error writing reply status: {:?}", e); } @@ -523,6 +540,7 @@ impl Server { conn_id: usize, client_stream: UnixStream, header: &protocol::AttachHeader, + dump_motd_on_new_session: bool, ) -> anyhow::Result { let user_info = user::info()?; let shell = if let Some(s) = &self.config.shell { @@ -563,7 +581,13 @@ impl Server { // to avoid breakage and vars the user has asked us to inject. .env_clear(); - self.inject_env(&mut cmd, &user_info, header).context("setting up shell env")?; + let term = self.inject_env(&mut cmd, &user_info, header).context("setting up shell env")?; + let term_db = if let Some(term) = &term { + terminfo::Database::from_name(term).context("resolving terminfo")? + } else { + warn!("no $TERM, using default terminfo"); + terminfo::Database::from_env().context("resolving default terminfo")? + }; let shell_basename = if header.cmd.is_none() { // spawn the shell as a login shell by setting @@ -625,15 +649,36 @@ impl Server { info!("reaped child shell: {:?}", waitable_child); }); + let has_clear_screen = term_db.get::().is_some(); + let needs_default_term = !has_clear_screen || term.is_none(); + // inject the prompt prefix, if any + info!("injecting prompt prefix"); let prompt_prefix = self.config.prompt_prefix.clone().unwrap_or(String::from("")); if let Some(shell_basename) = shell_basename { if !prompt_prefix.is_empty() { - if let Err(err) = - prompt::inject_prefix(&mut fork, shell_basename, &prompt_prefix, &header.name) - { + if let Err(err) = prompt::inject_prefix( + &mut fork, + shell_basename, + &prompt_prefix, + &header.name, + needs_default_term, + ) { warn!("issue injecting prefix: {:?}", err); } + } else { + // issue a clear even if we don't have a prompt to inject for consistency + // and to simplify motd handling + let script = if needs_default_term { + "clear\n" + } else { + // If we don't have a $TERM value or we have some wacky $TERM value for which + // there is no ClearScreen code, set TERM to xterm so that we won't get a + // warning and will generate a code we can scan for. + "TERM=xterm clear\n" + }; + let mut pty_master = fork.is_parent().context("expected parent")?; + pty_master.write_all(script.as_bytes()).context("running initial clear")?; } } @@ -655,6 +700,9 @@ impl Server { client_stream: Some(client_stream), config: self.config.clone(), reader_join_h: None, + term_db, + motd_shower: Arc::clone(&self.motd_shower), + needs_initial_motd_dump: dump_motd_on_new_session, }; let child_pid = session_inner.pty_master.child_pid().ok_or(anyhow!("no child pid"))?; session_inner.reader_join_h = Some(session_inner.spawn_reader(shell::ReaderArgs { @@ -691,13 +739,14 @@ impl Server { }) } + /// Set up the environment for the shell, returning the right TERM value. #[instrument(skip_all)] fn inject_env( &self, cmd: &mut process::Command, user_info: &user::Info, header: &protocol::AttachHeader, - ) -> anyhow::Result<()> { + ) -> anyhow::Result> { cmd.env("HOME", &user_info.home_dir) .env( "PATH", @@ -753,7 +802,7 @@ impl Server { } } info!("injecting TERM into shell {:?}", term); - if let Some(t) = term { + if let Some(t) = &term { cmd.env("TERM", t); } @@ -780,7 +829,7 @@ impl Server { } } - Ok(()) + Ok(term) } fn ssh_auth_sock_symlink(&self, session_name: PathBuf) -> PathBuf { diff --git a/libshpool/src/daemon/shell.rs b/libshpool/src/daemon/shell.rs index 7441f0f4..a3a45294 100644 --- a/libshpool/src/daemon/shell.rs +++ b/libshpool/src/daemon/shell.rs @@ -32,7 +32,7 @@ use tracing::{debug, error, info, instrument, span, trace, warn, Level}; use crate::{ consts, - daemon::{config, exit_notify::ExitNotifier, keybindings}, + daemon::{config, control_codes, exit_notify::ExitNotifier, keybindings, show_motd}, protocol, test_hooks, tty, }; @@ -101,6 +101,9 @@ pub struct SessionInner { pub pty_master: shpool_pty::fork::Fork, pub client_stream: Option, pub config: config::Config, + pub term_db: terminfo::Database, + pub motd_shower: Arc, + pub needs_initial_motd_dump: bool, /// The join handle for the always-on background reader thread. /// Only wrapped in an option so we can spawn the thread after @@ -189,6 +192,13 @@ impl SessionInner { ) -> anyhow::Result>> { use nix::poll; + let term_db = self.term_db.clone(); + let mut control_code_matcher = + control_codes::Matcher::new(&self.term_db).context("building control code matcher")?; + + let motd_shower = Arc::clone(&self.motd_shower); + let mut needs_initial_motd_dump = self.needs_initial_motd_dump; + let mut pty_master = self.pty_master.is_parent()?; let name = self.name.clone(); let mut closure = move || { @@ -342,7 +352,7 @@ impl SessionInner { .context("sending size change ack")?; } Err(err) => { - info!("size change: bailing due to: {:?}", err); + warn!("size change: bailing due to: {:?}", err); return Ok(()); } } @@ -452,14 +462,62 @@ impl SessionInner { } } + // scan for control codes we need to handle let mut reset_client_conn = false; + let mut snip_buf_to = None; + if needs_initial_motd_dump { + for (i, byte) in buf[..len].iter().enumerate() { + match control_code_matcher.transition(*byte) { + Some(control_codes::Code::ClearScreen) if needs_initial_motd_dump => { + debug!("detected initial ClearScreen code"); + if let ClientConnectionMsg::New(conn) = &client_conn { + let mut s = conn.sink.lock().unwrap(); + + // write the clear code ahead of time so we don't + // immediately clobber ourselves + let write_to = i + 1; + let chunk = protocol::Chunk { + kind: protocol::ChunkKind::Data, + buf: &buf[..write_to], + }; + let write_result = + chunk.write_to(&mut *s).and_then(|_| s.flush()); + if let Err(err) = write_result { + info!( + "client_stream write err (1), assuming hangup: {:?}", + err + ); + reset_client_conn = true; + } else { + test_hooks::emit("daemon-wrote-s2c-chunk"); + } + snip_buf_to = Some(write_to); + + if let Err(e) = motd_shower.dump(&mut *s, &term_db) { + warn!("Error handling clear: {:?}", e); + } + } + needs_initial_motd_dump = false; + } + _ => {} + } + } + } + if let Some(snip_to) = snip_buf_to { + if snip_to < buf.len() { + buf = Vec::from(&buf[snip_to..]); + } else { + buf.clear(); + } + } + if let ClientConnectionMsg::New(conn) = &client_conn { let chunk = protocol::Chunk { kind: protocol::ChunkKind::Data, buf: &buf[..len] }; let mut s = conn.sink.lock().unwrap(); let write_result = chunk.write_to(&mut *s).and_then(|_| s.flush()); if let Err(err) = write_result { - info!("client_stream write err, assuming hangup: {:?}", err); + info!("client_stream write err (2), assuming hangup: {:?}", err); reset_client_conn = true; } else { test_hooks::emit("daemon-wrote-s2c-chunk"); @@ -718,7 +776,9 @@ impl SessionInner { NoMatch => { partial_keybinding.clear(); } - Partial => partial_keybinding.push(*byte), + Partial => { + partial_keybinding.push(*byte); + } Match(action) => { info!("{:?} keybinding action fired", action); let keybinding_len = partial_keybinding.len() + 1; diff --git a/libshpool/src/daemon/show_motd.rs b/libshpool/src/daemon/show_motd.rs new file mode 100644 index 00000000..ac317429 --- /dev/null +++ b/libshpool/src/daemon/show_motd.rs @@ -0,0 +1,95 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::io; + +use anyhow::{anyhow, Context}; + +use super::super::{config, protocol}; + +/// Showers know how to show the message of the day. +#[derive(Debug, Clone)] +pub struct Shower { + motd_resolver: motd::Resolver, + mode: config::MotdDisplayMode, + args: Option>, +} + +impl Shower { + /// Make a new Shower. + pub fn new(mode: config::MotdDisplayMode, args: Option>) -> anyhow::Result { + Ok(Shower { + motd_resolver: motd::Resolver::new(motd::PamMotdResolutionStrategy::Auto) + .context("creating motd resolver")?, + mode, + args, + }) + } + + pub fn dump( + &self, + mut stream: W, + term_db: &terminfo::Database, + ) -> anyhow::Result<()> { + assert!(matches!(self.mode, config::MotdDisplayMode::Dump)); + + let raw_motd_value = self.get_raw_motd_value(term_db)?; + + let chunk = + protocol::Chunk { kind: protocol::ChunkKind::Data, buf: raw_motd_value.as_slice() }; + + chunk.write_to(&mut stream).context("dumping motd") + } + + fn get_raw_motd_value(&self, term_db: &terminfo::Database) -> anyhow::Result> { + let motd_value = self + .motd_resolver + .value(match &self.args { + Some(args) => { + let mut args = args.clone(); + // On debian based systems we need to set noupdate in order to get + // the motd from userspace. It should be ignored on non-debian systems. + if !args.iter().any(|a| a == "noupdate") { + args.push(String::from("noupdate")); + } + motd::ArgResolutionStrategy::Exact(args) + } + None => motd::ArgResolutionStrategy::Auto, + }) + .context("resolving motd")?; + Self::convert_to_raw(term_db, &motd_value) + } + + /// Convert the given motd into a byte buffer suitable to be written to the + /// terminal. The only real transformation we perform is injecting carrage + /// returns after newlines. + fn convert_to_raw(term_db: &terminfo::Database, motd: &str) -> anyhow::Result> { + let carrage_return_code = term_db + .get::() + .ok_or(anyhow!("no carrage return code"))?; + let carrage_return_bytes = + carrage_return_code.expand().to_vec().context("expanding carrage return code")?; + + let mut buf: Vec = vec![]; + + let lines = motd.split('\n'); + for line in lines { + buf.extend(line.as_bytes()); + buf.push(b'\n'); + buf.extend(&carrage_return_bytes); + } + + Ok(buf) + } +} diff --git a/libshpool/src/daemon/trie.rs b/libshpool/src/daemon/trie.rs new file mode 100644 index 00000000..faccf8f0 --- /dev/null +++ b/libshpool/src/daemon/trie.rs @@ -0,0 +1,146 @@ +use std::{collections::HashMap, hash}; + +#[derive(Debug)] +pub struct Trie { + // The nodes which form the tree. The first node is the root + // node, afterwards the order is undefined. + nodes: Vec>, +} + +#[derive(Eq, PartialEq, Copy, Clone, Debug)] +pub enum TrieCursor { + /// A cursor to use to start a char-wise match + Start, + /// Represents a state in the middle or end of a match + Match { idx: usize, is_partial: bool }, + /// A terminal state indicating a failure to match + NoMatch, +} + +#[derive(Debug)] +pub struct TrieNode { + // We need to store a phantom symbol here so we can have the + // Sym type parameter available for the TrieTab trait constraint + // in the impl block. Apologies for the type tetris. + phantom: std::marker::PhantomData, + value: Option, + tab: TT, +} + +impl Trie +where + TT: TrieTab, + Sym: Copy, +{ + pub fn new() -> Self { + Trie { nodes: vec![TrieNode::new(None)] } + } + + /// Insert a seq, value pair into the trie + pub fn insert>(&mut self, seq: Seq, value: V) { + let mut current_node = 0; + for sym in seq { + current_node = if let Some(next_node) = self.nodes[current_node].tab.get(sym) { + *next_node + } else { + let idx = self.nodes.len(); + self.nodes.push(TrieNode::new(None)); + self.nodes[current_node].tab.set(sym, idx); + idx + }; + } + self.nodes[current_node].value = Some(value); + } + + /// Check if the given sequence exists in the trie, used by tests. + #[allow(dead_code)] + pub fn contains>(&self, seq: Seq) -> bool { + let mut match_state = TrieCursor::Start; + for sym in seq { + match_state = self.advance(match_state, sym); + if let TrieCursor::NoMatch = match_state { + return false; + } + } + if let TrieCursor::Start = match_state { + return self.nodes[0].value.is_some(); + } + + if let TrieCursor::Match { is_partial, .. } = match_state { !is_partial } else { false } + } + + /// Process a single token of input, returning the current state. + /// To start a new match, use TrieCursor::Start. + pub fn advance(&self, cursor: TrieCursor, sym: Sym) -> TrieCursor { + let node = match cursor { + TrieCursor::Start => &self.nodes[0], + TrieCursor::Match { idx, .. } => &self.nodes[idx], + TrieCursor::NoMatch => return TrieCursor::NoMatch, + }; + + if let Some(idx) = node.tab.get(sym) { + TrieCursor::Match { idx: *idx, is_partial: self.nodes[*idx].value.is_none() } + } else { + TrieCursor::NoMatch + } + } + + /// Get the value for a match cursor. + pub fn get(&self, cursor: TrieCursor) -> Option<&V> { + if let TrieCursor::Match { idx, .. } = cursor { + self.nodes[idx].value.as_ref() + } else { + None + } + } +} + +impl TrieNode +where + TT: TrieTab, +{ + fn new(value: Option) -> Self { + TrieNode { phantom: std::marker::PhantomData, value, tab: TT::new() } + } +} + +/// The backing table the trie uses to associate symbols with state +/// indexes. This is basically std::ops::IndexMut plus a new function. +/// We can't just make this a sub-trait of IndexMut because u8 does +/// not implement IndexMut for vectors. +pub trait TrieTab { + fn new() -> Self; + fn get(&self, index: Idx) -> Option<&usize>; + fn set(&mut self, index: Idx, elem: usize); +} + +impl TrieTab for HashMap +where + Sym: hash::Hash + Eq + PartialEq, +{ + fn new() -> Self { + HashMap::new() + } + + fn get(&self, index: Sym) -> Option<&usize> { + self.get(&index) + } + + fn set(&mut self, index: Sym, elem: usize) { + self.insert(index, elem); + } +} + +impl TrieTab for Vec> { + fn new() -> Self { + vec![None; u8::MAX as usize] + } + + fn get(&self, index: u8) -> Option<&usize> { + self[index as usize].as_ref() + } + + fn set(&mut self, index: u8, elem: usize) { + self[index as usize] = Some(elem) + } +} diff --git a/libshpool/src/protocol.rs b/libshpool/src/protocol.rs index da8ec5a3..2fe03aba 100644 --- a/libshpool/src/protocol.rs +++ b/libshpool/src/protocol.rs @@ -218,7 +218,7 @@ impl fmt::Display for SessionStatus { } /// AttachStatus indicates what happened during an attach attempt. -#[derive(PartialEq, Eq, Serialize, Deserialize, Debug)] +#[derive(PartialEq, Eq, Serialize, Deserialize, Debug, Clone)] pub enum AttachStatus { /// Attached indicates that there was an existing shell session with /// the given name, and `shpool attach` successfully connected to it. diff --git a/shpool/Cargo.toml b/shpool/Cargo.toml index 7f5c7ecc..f2a43533 100644 --- a/shpool/Cargo.toml +++ b/shpool/Cargo.toml @@ -19,6 +19,7 @@ rust-version = "1.74" clap = { version = "4", features = ["derive"] } # cli parsing anyhow = "1" # dynamic, unstructured errors libshpool = { version = "0.5.0", path = "../libshpool" } +motd = "0.2.0" # getting the message-of-the-day [dev-dependencies] lazy_static = "1" # globals diff --git a/shpool/src/main.rs b/shpool/src/main.rs index d2852d97..fbd306d6 100644 --- a/shpool/src/main.rs +++ b/shpool/src/main.rs @@ -20,6 +20,8 @@ use clap::Parser; const VERSION: &str = env!("CARGO_PKG_VERSION"); fn main() -> anyhow::Result<()> { + motd::handle_reexec(); + let args = libshpool::Args::parse(); if args.version() { diff --git a/shpool/tests/attach.rs b/shpool/tests/attach.rs index b5050d0e..a8d1cc3a 100644 --- a/shpool/tests/attach.rs +++ b/shpool/tests/attach.rs @@ -1,7 +1,8 @@ use std::{ - fs, + env, fs, io::BufRead, - io::Read, + io::{Read, Write}, + path::PathBuf, process::{Command, Stdio}, thread, time, }; @@ -32,7 +33,7 @@ fn happy_path() -> anyhow::Result<()> { daemon_proc.await_event("daemon-about-to-listen")?; attach_proc.run_cmd("echo hi")?; - line_matcher.match_re("hi$")?; + line_matcher.scan_until_re("hi$")?; attach_proc.run_cmd("echo ping")?; line_matcher.match_re("ping$")?; @@ -105,7 +106,7 @@ fn forward_env() -> anyhow::Result<()> { let mut line_matcher = attach_proc.line_matcher()?; attach_proc.run_cmd(r#"echo "$FOO:$BAR:$BAZ" "#)?; - line_matcher.match_re("foo:bar:$")?; + line_matcher.scan_until_re("foo:bar:$")?; Ok(()) }) @@ -139,7 +140,8 @@ fn symlink_ssh_auth_sock() -> anyhow::Result<()> { waiter.wait_event("daemon-wrote-s2c-chunk")?; // resize prompt redraw attach_proc.run_cmd("ls -l $SSH_AUTH_SOCK")?; - line_matcher.match_re(r#".*sh1/ssh-auth-sock.socket ->.*ssh-auth-sock-target.fake$"#)?; + line_matcher + .scan_until_re(r#".*sh1/ssh-auth-sock.socket ->.*ssh-auth-sock-target.fake$"#)?; Ok(()) }) @@ -163,7 +165,7 @@ fn missing_ssh_auth_sock() -> anyhow::Result<()> { waiter.wait_event("daemon-wrote-s2c-chunk")?; // resize prompt re-draw attach_proc.run_cmd("ls -l $SSH_AUTH_SOCK")?; - line_matcher.match_re(r#".*No such file or directory$"#)?; + line_matcher.scan_until_re(r#".*No such file or directory$"#)?; Ok(()) }) @@ -220,7 +222,7 @@ fn config_disable_symlink_ssh_auth_sock() -> anyhow::Result<()> { waiter.wait_event("daemon-wrote-s2c-chunk")?; // resize prompt re-draw attach_proc.run_cmd("ls -l $SSH_AUTH_SOCK")?; - line_matcher.match_re(r#".*No such file or directory$"#)?; + line_matcher.scan_until_re(r#".*No such file or directory$"#)?; Ok(()) }) @@ -244,7 +246,7 @@ fn bounce() -> anyhow::Result<()> { attach_proc.run_cmd("export MYVAR=1")?; attach_proc.run_cmd("echo $MYVAR")?; - line_matcher.match_re("1$")?; + line_matcher.scan_until_re("1$")?; } // falling out of scope kills attach_proc // wait until the daemon has noticed that the connection @@ -281,10 +283,10 @@ fn two_at_once() -> anyhow::Result<()> { let mut line_matcher2 = attach_proc2.line_matcher()?; attach_proc1.run_cmd("echo proc1").context("proc1 echo")?; - line_matcher1.match_re("proc1$").context("proc1 match")?; + line_matcher1.scan_until_re("proc1$").context("proc1 match")?; attach_proc2.run_cmd("echo proc2").context("proc2 echo")?; - line_matcher2.match_re("proc2$").context("proc2 match")?; + line_matcher2.scan_until_re("proc2$").context("proc2 match")?; Ok(()) }) @@ -308,7 +310,7 @@ fn explicit_exit() -> anyhow::Result<()> { attach_proc.run_cmd("export MYVAR=first")?; attach_proc.run_cmd("echo $MYVAR")?; - line_matcher.match_re("first$")?; + line_matcher.scan_until_re("first$")?; attach_proc.run_cmd("exit")?; @@ -324,7 +326,7 @@ fn explicit_exit() -> anyhow::Result<()> { let mut line_matcher = attach_proc.line_matcher()?; attach_proc.run_cmd("echo ${MYVAR:-second}")?; - line_matcher.match_re("second$")?; + line_matcher.scan_until_re("second$")?; } Ok(()) @@ -426,7 +428,7 @@ fn force_attach() -> anyhow::Result<()> { tty1.run_cmd("echo $MYVAR")?; // read some output to make sure the var is set by the time // we force-attach - line_matcher1.match_re("set_from_tty1$")?; + line_matcher1.scan_until_re("set_from_tty1$")?; let mut tty2 = daemon_proc .attach("sh1", AttachArgs { force: true, ..Default::default() }) @@ -450,12 +452,12 @@ fn busy() -> anyhow::Result<()> { daemon_proc.attach("sh1", Default::default()).context("attaching from tty1")?; let mut line_matcher1 = tty1.line_matcher()?; tty1.run_cmd("echo foo")?; // make sure the shell is up and running - line_matcher1.match_re("foo$")?; + line_matcher1.scan_until_re("foo$")?; let mut tty2 = daemon_proc.attach("sh1", Default::default()).context("attaching from tty2")?; let mut line_matcher2 = tty2.stderr_line_matcher()?; - line_matcher2.match_re("already has a terminal attached$")?; + line_matcher2.scan_until_re("already has a terminal attached$")?; Ok(()) }) @@ -473,7 +475,7 @@ fn daemon_hangup() -> anyhow::Result<()> { // make sure the shell is up and running let mut line_matcher = attach_proc.line_matcher()?; attach_proc.run_cmd("echo foo")?; - line_matcher.match_re("foo$")?; + line_matcher.scan_until_re("foo$")?; daemon_proc.proc_kill()?; @@ -498,7 +500,7 @@ fn default_keybinding_detach() -> anyhow::Result<()> { a1.run_cmd("export MYVAR=someval")?; a1.run_cmd("echo $MYVAR")?; - lm1.match_re("someval$")?; + lm1.scan_until_re("someval$")?; a1.run_raw_cmd(vec![0, 17])?; // Ctrl-Space Ctrl-q a1.proc.wait()?; @@ -510,7 +512,7 @@ fn default_keybinding_detach() -> anyhow::Result<()> { let mut lm2 = a2.line_matcher()?; a2.run_cmd("echo $MYVAR")?; - lm2.match_re("someval$")?; + lm2.scan_until_re("someval$")?; Ok(()) }) @@ -532,7 +534,7 @@ fn keybinding_input_shear() -> anyhow::Result<()> { a1.run_cmd("export MYVAR=someval")?; a1.run_cmd("echo $MYVAR")?; - lm1.match_re("someval$")?; + lm1.scan_until_re("someval$")?; a1.run_raw(vec![0])?; // Ctrl-Space thread::sleep(time::Duration::from_millis(100)); @@ -546,7 +548,7 @@ fn keybinding_input_shear() -> anyhow::Result<()> { let mut lm2 = a2.line_matcher()?; a2.run_cmd("echo $MYVAR")?; - lm2.match_re("someval$")?; + lm2.scan_until_re("someval$")?; Ok(()) }) @@ -564,7 +566,7 @@ fn keybinding_strip_keys() -> anyhow::Result<()> { // the keybinding is 5 'a' chars in a row, so they should get stripped out a1.run_cmd("echo baaaaad")?; - lm1.match_re("bd$")?; + lm1.scan_until_re("bd$")?; Ok(()) }) @@ -586,7 +588,7 @@ fn keybinding_strip_keys_split() -> anyhow::Result<()> { a1.run_raw("aa".bytes().collect())?; thread::sleep(time::Duration::from_millis(50)); a1.run_raw("aad\n".bytes().collect())?; - lm1.match_re("bd$")?; + lm1.scan_until_re("bd$")?; Ok(()) }) @@ -604,7 +606,7 @@ fn keybinding_partial_match_nostrip() -> anyhow::Result<()> { // the keybinding is 5 'a' chars in a row, this has only 3 a1.run_cmd("echo baaad")?; - lm1.match_re("baaad$")?; + lm1.scan_until_re("baaad$")?; Ok(()) }) @@ -626,7 +628,7 @@ fn keybinding_partial_match_nostrip_split() -> anyhow::Result<()> { a1.run_raw("a".bytes().collect())?; thread::sleep(time::Duration::from_millis(50)); a1.run_raw("ad\n".bytes().collect())?; - lm1.match_re("baaad$")?; + lm1.scan_until_re("baaad$")?; Ok(()) }) @@ -646,7 +648,7 @@ fn custom_keybinding_detach() -> anyhow::Result<()> { a1.run_cmd("export MYVAR=someval")?; a1.run_cmd("echo $MYVAR")?; - lm1.match_re("someval$")?; + lm1.scan_until_re("someval$")?; a1.run_raw_cmd(vec![22, 23, 7])?; // Ctrl-v Ctrl-w Ctrl-g a1.proc.wait()?; @@ -686,7 +688,7 @@ fn injects_term_even_with_env_config() -> anyhow::Result<()> { waiter.wait_event("daemon-wrote-s2c-chunk")?; // resize prompt redraw attach_proc.run_cmd("echo $SOME_CUSTOM_ENV_VAR")?; - line_matcher.match_re("customvalue$")?; + line_matcher.scan_until_re("customvalue$")?; attach_proc.run_cmd("echo $TERM")?; line_matcher.match_re("dumb$")?; @@ -716,7 +718,7 @@ fn injects_local_env_vars() -> anyhow::Result<()> { let mut line_matcher = attach_proc.line_matcher()?; attach_proc.run_cmd("echo $DISPLAY")?; - line_matcher.match_re(":0$")?; + line_matcher.scan_until_re(":0$")?; attach_proc.run_cmd("echo $LANG")?; line_matcher.match_re("fakelang$")?; @@ -737,7 +739,7 @@ fn has_right_default_path() -> anyhow::Result<()> { let mut line_matcher = attach_proc.line_matcher()?; attach_proc.run_cmd("echo $PATH")?; - line_matcher.match_re("/usr/bin:/bin:/usr/sbin:/sbin$")?; + line_matcher.scan_until_re("/usr/bin:/bin:/usr/sbin:/sbin$")?; Ok(()) }) @@ -757,7 +759,7 @@ fn screen_restore() -> anyhow::Result<()> { let mut line_matcher = attach_proc.line_matcher()?; attach_proc.run_cmd("echo foo")?; - line_matcher.match_re("foo$")?; + line_matcher.scan_until_re("foo$")?; } // wait until the daemon has noticed that the connection @@ -771,7 +773,7 @@ fn screen_restore() -> anyhow::Result<()> { // the re-attach should redraw the screen for us, so we should // get a line with "foo" as part of the re-drawn screen. - line_matcher.match_re("foo$")?; + line_matcher.scan_until_re("foo$")?; } Ok(()) @@ -792,7 +794,7 @@ fn screen_wide_restore() -> anyhow::Result<()> { let mut line_matcher = attach_proc.line_matcher()?; attach_proc.run_cmd("echo ooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooy")?; - line_matcher.match_re("ooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooy$")?; + line_matcher.scan_until_re("ooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooy$")?; } // wait until the daemon has noticed that the connection @@ -806,7 +808,7 @@ fn screen_wide_restore() -> anyhow::Result<()> { // the re-attach should redraw the screen for us, so we should // get a line with the full echo result as part of the re-drawn screen. - line_matcher.match_re("ooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooy$")?; + line_matcher.scan_until_re("ooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooyooooxooooy$")?; } Ok(()) @@ -827,7 +829,8 @@ fn lines_restore() -> anyhow::Result<()> { let mut line_matcher = attach_proc.line_matcher()?; attach_proc.run_cmd("echo foo")?; - line_matcher.match_re("foo$")?; + attach_proc.run_cmd("echo")?; + line_matcher.scan_until_re("foo$")?; } // wait until the daemon has noticed that the connection @@ -841,7 +844,7 @@ fn lines_restore() -> anyhow::Result<()> { // the re-attach should redraw the last 2 lines for us, so we should // get a line with "foo" as part of the re-drawn screen. - line_matcher.match_re("foo$")?; + line_matcher.scan_until_re("foo$")?; } Ok(()) @@ -871,7 +874,7 @@ fn lines_big_chunk_restore() -> anyhow::Result<()> { // for a single chunk let blob = format!("echo {}", (0..max_chunk_size).map(|_| "x").collect::()); attach_proc.run_cmd(blob.as_str())?; - line_matcher.match_re("xx$")?; + line_matcher.scan_until_re("xx$")?; attach_proc.run_cmd("echo food")?; line_matcher.match_re("food$")?; @@ -939,7 +942,7 @@ fn ttl_hangup() -> anyhow::Result<()> { // ensure the shell is up and running let mut line_matcher = attach_proc.line_matcher()?; attach_proc.run_cmd("echo hi")?; - line_matcher.match_re("hi$")?; + line_matcher.scan_until_re("hi$")?; // sleep long enough for the reaper to clobber the thread thread::sleep(time::Duration::from_millis(1200)); @@ -967,7 +970,7 @@ fn ttl_no_hangup_yet() -> anyhow::Result<()> { // ensure the shell is up and running let mut line_matcher = attach_proc.line_matcher()?; attach_proc.run_cmd("echo hi")?; - line_matcher.match_re("hi$")?; + line_matcher.scan_until_re("hi$")?; let listout = daemon_proc.list()?; assert!(String::from_utf8_lossy(listout.stdout.as_slice()).contains("sh1")); @@ -997,7 +1000,7 @@ fn prompt_prefix_bash() -> anyhow::Result<()> { .arg("attach") .arg("sh1") .spawn() - .context("spawning daemon process")?; + .context("spawning attach process")?; // The attach shell should be spawned and have read the // initial prompt after half a second. @@ -1037,7 +1040,7 @@ fn prompt_prefix_zsh() -> anyhow::Result<()> { .arg("attach") .arg("sh1") .spawn() - .context("spawning daemon process")?; + .context("spawning attach process")?; // The attach shell should be spawned and have read the // initial prompt after half a second. @@ -1077,7 +1080,7 @@ fn prompt_prefix_fish() -> anyhow::Result<()> { .arg("attach") .arg("sh1") .spawn() - .context("spawning daemon process")?; + .context("spawning attach process")?; // The attach shell should be spawned and have read the // initial prompt after half a second. @@ -1099,6 +1102,73 @@ fn prompt_prefix_fish() -> anyhow::Result<()> { }) } +#[test] +#[timeout(30000)] +fn motd_dump() -> anyhow::Result<()> { + support::dump_err(|| { + // set up the config + let tmp_dir = tempfile::TempDir::with_prefix("shpool-test-config")?; + let tmp_dir_path = if env::var("SHPOOL_LEAVE_TEST_LOGS").is_ok() { + // leave the tmp files around for later inspection if we have been asked + // to leave the logs in place. + tmp_dir.into_path() + } else { + PathBuf::from(tmp_dir.path()) + }; + eprintln!("building config in {:?}", tmp_dir_path); + let motd_file = tmp_dir_path.join("motd.txt"); + { + let mut f = fs::File::create(&motd_file)?; + f.write_all("MOTD_MSG\n".as_bytes())?; + } + let config_tmpl = fs::read_to_string(support::testdata_file("motd_dump.toml.tmpl"))?; + let config_contents = config_tmpl.replace("TMP_MOTD_MSG_FILE", motd_file.to_str().unwrap()); + let config_file = tmp_dir_path.join("motd_dump.toml"); + { + let mut f = fs::File::create(&config_file)?; + f.write_all(config_contents.as_bytes())?; + } + + // spawn a daemon based on our custom config + let daemon_proc = + support::daemon::Proc::new(&config_file, true).context("starting daemon proc")?; + + // We need to manually spawn our attach proc because + // the motd gets printed immediately, so we can't always + // attach a line matcher in time. + let mut child = Command::new(support::shpool_bin()?) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .arg("--socket") + .arg(&daemon_proc.socket_path) + .arg("--config-file") + .arg(config_file) + .arg("attach") + .arg("sh1") + .spawn() + .context("spawning attach process")?; + + // The attach shell should be spawned and have read the + // initial prompt after half a second. + std::thread::sleep(time::Duration::from_millis(500)); + child.kill().context("killing child")?; + + let mut stderr = child.stderr.take().context("missing stderr")?; + let mut stderr_str = String::from(""); + stderr.read_to_string(&mut stderr_str).context("slurping stderr")?; + assert!(stderr_str.is_empty()); + + let mut stdout = child.stdout.take().context("missing stdout")?; + let mut stdout_str = String::from(""); + stdout.read_to_string(&mut stdout_str).context("slurping stdout")?; + let stdout_re = Regex::new(".*MOTD_MSG.*")?; + // eprintln!("stdout_str='{}'", stdout_str); + assert!(stdout_re.is_match(&stdout_str)); + + Ok(()) + }) +} + #[ignore] // TODO: re-enable, this test if flaky #[test] fn up_arrow_no_crash() -> anyhow::Result<()> { diff --git a/shpool/tests/daemon.rs b/shpool/tests/daemon.rs index e651c009..30d79d37 100644 --- a/shpool/tests/daemon.rs +++ b/shpool/tests/daemon.rs @@ -215,7 +215,7 @@ fn hooks() -> anyhow::Result<()> { // sequencing let mut sh1_matcher = sh1_proc.line_matcher()?; sh1_proc.run_cmd("echo hi")?; - sh1_matcher.match_re("hi$")?; + sh1_matcher.scan_until_re("hi$")?; // 1 busy let mut busy_proc = diff --git a/shpool/tests/data/motd_dump.toml.tmpl b/shpool/tests/data/motd_dump.toml.tmpl new file mode 100644 index 00000000..8b4feae5 --- /dev/null +++ b/shpool/tests/data/motd_dump.toml.tmpl @@ -0,0 +1,11 @@ +norc = true +noecho = true +shell = "/bin/bash" +session_restore_mode = "simple" + +motd = "dump" +motd_args = ["motd=TMP_MOTD_MSG_FILE"] + +[env] +PS1 = "prompt> " +TERM = "xterm" diff --git a/shpool/tests/detach.rs b/shpool/tests/detach.rs index abe046bc..dd139b4a 100644 --- a/shpool/tests/detach.rs +++ b/shpool/tests/detach.rs @@ -126,7 +126,7 @@ fn reattach() -> anyhow::Result<()> { let mut lm1 = sess1.line_matcher()?; sess1.run_cmd("export MYVAR=first ; echo hi")?; - lm1.match_re("hi$")?; + lm1.scan_until_re("hi$")?; let out = daemon_proc.detach(vec![String::from("sh1")])?; assert!(out.status.success(), "not successful"); diff --git a/shpool/tests/kill.rs b/shpool/tests/kill.rs index 10bb8f48..5d9176dd 100644 --- a/shpool/tests/kill.rs +++ b/shpool/tests/kill.rs @@ -120,7 +120,7 @@ fn reattach_after_kill() -> anyhow::Result<()> { let mut lm1 = sess1.line_matcher()?; sess1.run_cmd("export MYVAR=first")?; sess1.run_cmd("echo $MYVAR")?; - lm1.match_re("first$")?; + lm1.scan_until_re("first$")?; let out = daemon_proc.kill(vec![String::from("sh1")])?; assert!(out.status.success()); @@ -138,7 +138,7 @@ fn reattach_after_kill() -> anyhow::Result<()> { daemon_proc.attach("sh1", Default::default()).context("starting attach proc")?; let mut lm2 = sess2.line_matcher()?; sess2.run_cmd("echo ${MYVAR:-second}")?; - lm2.match_re("second$")?; + lm2.scan_until_re("second$")?; Ok(()) }) diff --git a/shpool/tests/support/daemon.rs b/shpool/tests/support/daemon.rs index 7ae97e27..5f3c3830 100644 --- a/shpool/tests/support/daemon.rs +++ b/shpool/tests/support/daemon.rs @@ -102,6 +102,12 @@ impl Proc { let log_file = tmp_dir.join("daemon.log"); eprintln!("spawning daemon proc with log {:?}", &log_file); + let resolved_config = if config.as_ref().exists() { + PathBuf::from(config.as_ref()) + } else { + testdata_file(config) + }; + let mut cmd = Command::new(shpool_bin()?); cmd.stdout(Stdio::piped()) .stderr(Stdio::piped()) @@ -111,7 +117,7 @@ impl Proc { .arg("--socket") .arg(&socket_path) .arg("--config-file") - .arg(testdata_file(config)) + .arg(resolved_config) .arg("daemon"); if listen_events { cmd.env("SHPOOL_TEST_HOOK_SOCKET_PATH", &test_hook_socket_path); diff --git a/shpool/tests/support/line_matcher.rs b/shpool/tests/support/line_matcher.rs index 8e8654d6..4e2a31a7 100644 --- a/shpool/tests/support/line_matcher.rs +++ b/shpool/tests/support/line_matcher.rs @@ -14,6 +14,51 @@ impl LineMatcher where R: std::io::Read, { + /// Scan lines until one matches the given regex + pub fn scan_until_re(&mut self, re: &str) -> anyhow::Result<()> { + let compiled_re = Regex::new(re)?; + let start = time::Instant::now(); + loop { + let mut line = String::new(); + match self.out.read_line(&mut line) { + Ok(0) => { + return Err(anyhow!("LineMatcher: EOF")); + } + Err(e) => { + if e.kind() == io::ErrorKind::WouldBlock { + if start.elapsed() > CMD_READ_TIMEOUT { + return Err(io::Error::new( + io::ErrorKind::TimedOut, + "timed out reading line", + ))?; + } + + std::thread::sleep(CMD_READ_SLEEP_DUR); + continue; + } + + return Err(e).context("reading line from shell output")?; + } + Ok(_) => { + if line.ends_with('\n') { + line.pop(); + if line.ends_with('\r') { + line.pop(); + } + } + } + } + + eprint!("scanning for /{}/... ", re); + if compiled_re.is_match(&line) { + eprintln!(" match"); + return Ok(()); + } else { + eprintln!(" no match"); + } + } + } + pub fn match_re(&mut self, re: &str) -> anyhow::Result<()> { match self.capture_re(re) { Ok(_) => Ok(()),