diff --git a/Cargo.lock b/Cargo.lock index c55650ee..d882f72d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -218,6 +218,55 @@ version = "0.8.19" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345" +[[package]] +name = "dirs" +version = "4.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca3aa72a6f96ea37bbc5aa912f6788242832f75369bdfdadcb0e38423f100059" +dependencies = [ + "dirs-sys", +] + +[[package]] +name = "dirs-sys" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b1d1d91c932ef41c0f2663aa8b0ca0342d444d842c06914aa0a7e352d0bada6" +dependencies = [ + "libc", + "redox_users", + "winapi", +] + +[[package]] +name = "dlopen2" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e1297103d2bbaea85724fcee6294c2d50b1081f9ad47d0f6f6f61eda65315a6" +dependencies = [ + "dlopen2_derive", + "libc", + "once_cell", + "winapi", +] + +[[package]] +name = "dlopen2_derive" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2b99bf03862d7f545ebc28ddd33a665b50865f4dfd84031a393823879bd4c54" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.52", +] + +[[package]] +name = "either" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11157ac094ffbdde99aa67b23417ebdd801842852b500e395a45a9c0aac03e4a" + [[package]] name = "equivalent" version = "1.0.1" @@ -261,6 +310,23 @@ version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "25cbce373ec4653f1a01a31e8a5e5ec0c622dc27ff9c4e6606eefef5cbbed4a5" +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "getrandom" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + [[package]] name = "hashbrown" version = "0.14.3" @@ -273,6 +339,15 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" +[[package]] +name = "home" +version = "0.5.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5" +dependencies = [ + "windows-sys", +] + [[package]] name = "iana-time-zone" version = "0.1.60" @@ -333,6 +408,17 @@ version = "0.2.153" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" +[[package]] +name = "libredox" +version = "0.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85c833ca1e66078851dba29046874e38f08b2c883700aa29a03ddd3b23814ee8" +dependencies = [ + "bitflags 2.4.2", + "libc", + "redox_syscall", +] + [[package]] name = "libshpool" version = "0.5.0" @@ -346,6 +432,7 @@ dependencies = [ "lazy_static", "libc", "log", + "motd", "nix", "ntest", "serde", @@ -354,6 +441,7 @@ dependencies = [ "shpool_pty", "shpool_vt100", "signal-hook", + "terminfo", "toml", "tracing", "tracing-subscriber", @@ -386,6 +474,31 @@ dependencies = [ "autocfg", ] +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + +[[package]] +name = "motd" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8b1327b3888ed1ed4b1f0ba708f1e5c4f66b6140a5816792fe05f643c3efc9d" +dependencies = [ + "dlopen2", + "lazy_static", + "libc", + "log", + "pam-sys", + "serde", + "serde_derive", + "serde_json", + "tempfile", + "walkdir", + "which", +] + [[package]] name = "nix" version = "0.26.4" @@ -399,6 +512,16 @@ dependencies = [ "pin-utils", ] +[[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + "memchr", + "minimal-lexical", +] + [[package]] name = "ntest" version = "0.9.2" @@ -447,6 +570,53 @@ version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" +[[package]] +name = "pam-sys" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd4858311a097f01a0006ef7d0cd50bca81ec430c949d7bf95cbefd202282434" +dependencies = [ + "libc", +] + +[[package]] +name = "phf" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ade2d8b8f33c7333b51bcf0428d37e217e9f32192ae4772156f65063b8ce03dc" +dependencies = [ + "phf_shared", +] + +[[package]] +name = "phf_codegen" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8d39688d359e6b34654d328e262234662d16cc0f60ec8dcbe5e718709342a5a" +dependencies = [ + "phf_generator", + "phf_shared", +] + +[[package]] +name = "phf_generator" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48e4cc64c2ad9ebe670cb8fd69dd50ae301650392e81c05f9bfcb2d5bdbc24b0" +dependencies = [ + "phf_shared", + "rand", +] + +[[package]] +name = "phf_shared" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90fcb95eef784c2ac79119d1dd819e162b5da872ce6f3c3abe1e8ca1c082f72b" +dependencies = [ + "siphasher", +] + [[package]] name = "pin-project-lite" version = "0.2.13" @@ -486,6 +656,41 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" + +[[package]] +name = "redox_syscall" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa" +dependencies = [ + "bitflags 1.3.2", +] + +[[package]] +name = "redox_users" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a18479200779601e498ada4e8c1e1f50e3ee19deb0259c25825a98b5603b2cb4" +dependencies = [ + "getrandom", + "libredox", + "thiserror", +] + [[package]] name = "regex" version = "1.10.3" @@ -534,6 +739,15 @@ version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e86697c916019a8588c99b5fac3cead74ec0b4b819707a682fd4d23fa0ce1ba1" +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + [[package]] name = "serde" version = "1.0.197" @@ -598,6 +812,7 @@ dependencies = [ "crossbeam-channel", "lazy_static", "libshpool", + "motd", "nix", "ntest", "regex", @@ -647,6 +862,12 @@ dependencies = [ "libc", ] +[[package]] +name = "siphasher" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38b58827f4464d87d377d175e90bf58eb00fd8716ff0a62f80356b5e61555d0d" + [[package]] name = "smallvec" version = "1.13.1" @@ -693,6 +914,39 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "terminfo" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "666cd3a6681775d22b200409aad3b089c5b99fb11ecdd8a204d9d62f8148498f" +dependencies = [ + "dirs", + "fnv", + "nom", + "phf", + "phf_codegen", +] + +[[package]] +name = "thiserror" +version = "1.0.58" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03468839009160513471e86a034bb2c5c0e4baae3b43f79ffc55c4a5427b3297" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.58" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.52", +] + [[package]] name = "thread_local" version = "1.1.8" @@ -849,6 +1103,22 @@ dependencies = [ "quote", ] +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + [[package]] name = "wasm-bindgen" version = "0.2.92" @@ -903,6 +1173,19 @@ version = "0.2.92" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" +[[package]] +name = "which" +version = "6.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7fa5e0c10bf77f44aac573e498d1a82d5fbd5e91f6fc0a99e7be4b38e85e101c" +dependencies = [ + "either", + "home", + "once_cell", + "rustix", + "windows-sys", +] + [[package]] name = "winapi" version = "0.3.9" @@ -919,6 +1202,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" +[[package]] +name = "winapi-util" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f29e6f9198ba0d26b4c9f07dbe6f9ed633e1f3d5b8b414090084349e46a52596" +dependencies = [ + "winapi", +] + [[package]] name = "winapi-x86_64-pc-windows-gnu" version = "0.4.0" diff --git a/libshpool/Cargo.toml b/libshpool/Cargo.toml index 9c0d8c1d..d758ad8f 100644 --- a/libshpool/Cargo.toml +++ b/libshpool/Cargo.toml @@ -37,6 +37,8 @@ tracing = "0.1" # logging and performance monitoring facade bincode = "1" # serialization for the control protocol shpool_vt100 = "0.1.2" # terminal emulation for the scrollback buffer shell-words = "1" # parsing the -c/--cmd argument +motd = "0.2.0" # getting the message-of-the-day +terminfo = "0.8.0" # resolving terminal escape codes [dependencies.tracing-subscriber] version = "0.3" diff --git a/libshpool/README.md b/libshpool/README.md index 9d87764f..63ded92c 100644 --- a/libshpool/README.md +++ b/libshpool/README.md @@ -8,3 +8,14 @@ to an internal google version of the tool, but don't believe that telemetry belongs in an open-source tool. Other potential use-cases such as incorporating a shpool daemon into an IDE that hosts remote terminals could be imagined though. + +## Integrating + +In order to call libshpool, you must keep a few things in mind. +In spirit, you just need to call `libshpool::run(libshpoo::Args::parse())`, +but you need to take care of a few things manually. + +1. Handle the `version` subcommand. Since libshpool is a library, the output + will not be very good if the library handles the versioning. +2. Depend on the `motd` crate and call `motd::handle_reexec()` in your `main` + function. diff --git a/libshpool/src/config.rs b/libshpool/src/config.rs index dcc47259..a9391f5b 100644 --- a/libshpool/src/config.rs +++ b/libshpool/src/config.rs @@ -110,6 +110,9 @@ pub struct Config { /// verbatim except that the string '$SHPOOL_SESSION_NAME' will /// get replaced with the actual name of the shpool session. pub prompt_prefix: Option, + + /// Control when and how shpool will display the message of the day. + pub motd: Option, } #[derive(Deserialize, Debug, Clone)] @@ -140,6 +143,31 @@ pub enum SessionRestoreMode { Lines(u16), } +#[derive(Deserialize, Debug, Clone, Default)] +#[serde(rename_all = "lowercase")] +pub enum MotdDisplayMode { + /// Never display the message of the day. + #[default] + Never, + + /// Display the message of the day using the given program + /// as the pager. The pager will be invoked like `pager /tmp/motd.txt`, + /// and normal connection will only proceed once the pager has + /// exited. + /// + /// Display the message of the day each time a user attaches + /// (wether to a new session or reattaching to an existing session). + /// + /// `less` by default. + // Pager(String), + + /// Just dump the message of the day directly to the screen. + /// Dumps are only performed when a new session is created. + /// There is no safe way to dump directly when reattaching, + /// so we don't attempt it. + Dump, +} + #[cfg(test)] mod test { use super::*; diff --git a/libshpool/src/daemon/control_codes.rs b/libshpool/src/daemon/control_codes.rs new file mode 100644 index 00000000..8c73b386 --- /dev/null +++ b/libshpool/src/daemon/control_codes.rs @@ -0,0 +1,82 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! The escape codes module provides an online (trie based) matcher +//! to scan for escape codes we are interested in in the output of +//! the subshell. For the moment, we just use this to scan for +//! the ClearScreen code emitted by the prompt prefix injection shell +//! code. We need to scan for this to avoid a race that can lead to +//! the motd getting clobbered when in dump mode. + +use anyhow::{anyhow, Context}; + +use super::trie::{TrieCursor, Trie}; + +#[derive(Debug, Clone, Copy)] +pub enum Code { + ClearScreen, +} + +#[derive(Debug)] +pub struct Matcher { + codes: Trie>>, + codes_cursor: TrieCursor, +} + +impl Matcher { + pub fn new(term_db: &terminfo::Database) -> anyhow::Result { + let clear_code = term_db.get::() + .ok_or(anyhow!("no clear screen code"))?; + let clear_code_bytes = clear_code.expand().to_vec().context("expanding clear code")?; + + // TODO: delete + let clear_code_hex_bytes = clear_code_bytes.iter().map(|b| format!("{:02x}", b)).collect::>().join(" "); + tracing::debug!("clear_code_hex_bytes={}", clear_code_hex_bytes); + + let raw_bindings = vec![ + // We need to scan for the clear code that gets emitted by the prompt prefix + // shell injection code so that we can make sure that the message of the day + // won't get clobbered immediately. + (clear_code_bytes, Code::ClearScreen), + ]; + let mut codes = Trie::new(); + for (raw_bytes, code) in raw_bindings.into_iter() { + codes.insert(raw_bytes.into_iter(), code); + } + + Ok(Matcher { + codes, + codes_cursor: TrieCursor::Start, + }) + } + + pub fn transition(&mut self, byte: u8) -> Option { + let old_cursor = self.codes_cursor; + self.codes_cursor = self.codes.advance(self.codes_cursor, byte); + tracing::debug!("TRANSITION({:?}, {:02x}) -> {:?}", old_cursor, byte, self.codes_cursor); // TODO: delete + match self.codes_cursor { + TrieCursor::NoMatch => { + self.codes_cursor = TrieCursor::Start; + None + }, + TrieCursor::Match { is_partial, .. } if !is_partial => { + let code = self.codes.get(self.codes_cursor).map(|c| *c); + tracing::debug!("TRANSITION MATCH: code={:?}", code); + self.codes_cursor = TrieCursor::Start; + code + } + _ => None, + } + } +} diff --git a/libshpool/src/daemon/keybindings.rs b/libshpool/src/daemon/keybindings.rs index e61504a9..b106802f 100644 --- a/libshpool/src/daemon/keybindings.rs +++ b/libshpool/src/daemon/keybindings.rs @@ -48,11 +48,13 @@ //! be singletons besides 'Ctrl' or of the form 'Ctrl-x' where //! x is some non-'Ctrl' key. -use std::{collections::HashMap, fmt, hash}; +use std::{collections::HashMap, fmt}; use anyhow::{anyhow, Context}; use serde_derive::Deserialize; +use super::trie::{Trie, TrieCursor, TrieTab}; + // // Keybindings table // @@ -96,6 +98,20 @@ pub enum BindingResult { #[derive(Eq, PartialEq, Copy, Clone, Hash)] struct ChordAtom(u8); +impl TrieTab for Vec> { + fn new() -> Self { + vec![None; u8::MAX as usize] + } + + fn get(&self, index: ChordAtom) -> Option<&usize> { + self[index.0 as usize].as_ref() + } + + fn set(&mut self, index: ChordAtom, elem: usize) { + self[index.0 as usize] = Some(elem) + } +} + impl Bindings { /// new builds a bindings matching engine, parsing the given binding->action /// mapping and compiling it into the pair of tries that we use to perform @@ -396,164 +412,6 @@ impl Lexer { } } -// -// Trie (used in both the parser and the execution engine) -// - -#[derive(Debug)] -struct Trie { - // The nodes which form the tree. The first node is the root - // node, afterwards the order is undefined. - nodes: Vec>, -} - -#[derive(Eq, PartialEq, Copy, Clone, Debug)] -enum TrieCursor { - /// A cursor to use to start a char-wise match - Start, - /// Represents a state in the middle or end of a match - Match { idx: usize, is_partial: bool }, - /// A terminal state indicating a failure to match - NoMatch, -} - -#[derive(Debug)] -struct TrieNode { - // We need to store a phantom symbol here so we can have the - // Sym type parameter available for the TrieTab trait constraint - // in the impl block. Apologies for the type tetris. - phantom: std::marker::PhantomData, - value: Option, - tab: TT, -} - -impl Trie -where - TT: TrieTab, - Sym: Copy, -{ - fn new() -> Self { - Trie { nodes: vec![TrieNode::new(None)] } - } - - fn insert>(&mut self, seq: Seq, value: V) { - let mut current_node = 0; - for sym in seq { - current_node = if let Some(next_node) = self.nodes[current_node].tab.get(sym) { - *next_node - } else { - let idx = self.nodes.len(); - self.nodes.push(TrieNode::new(None)); - self.nodes[current_node].tab.set(sym, idx); - idx - }; - } - self.nodes[current_node].value = Some(value); - } - - #[allow(dead_code)] - fn contains>(&self, seq: Seq) -> bool { - let mut match_state = TrieCursor::Start; - for sym in seq { - match_state = self.advance(match_state, sym); - if let TrieCursor::NoMatch = match_state { - return false; - } - } - if let TrieCursor::Start = match_state { - return self.nodes[0].value.is_some(); - } - - if let TrieCursor::Match { is_partial, .. } = match_state { !is_partial } else { false } - } - - fn advance(&self, cursor: TrieCursor, sym: Sym) -> TrieCursor { - let node = match cursor { - TrieCursor::Start => &self.nodes[0], - TrieCursor::Match { idx, .. } => &self.nodes[idx], - TrieCursor::NoMatch => return TrieCursor::NoMatch, - }; - - if let Some(idx) = node.tab.get(sym) { - TrieCursor::Match { idx: *idx, is_partial: self.nodes[*idx].value.is_none() } - } else { - TrieCursor::NoMatch - } - } - - fn get(&self, cursor: TrieCursor) -> Option<&V> { - if let TrieCursor::Match { idx, .. } = cursor { - self.nodes[idx].value.as_ref() - } else { - None - } - } -} - -impl TrieNode -where - TT: TrieTab, -{ - fn new(value: Option) -> Self { - TrieNode { phantom: std::marker::PhantomData, value, tab: TT::new() } - } -} - -/// The backing table the trie uses to associate symbols with state -/// indexes. This is basically std::ops::IndexMut plus a new function. -/// We can't just make this a sub-trait of IndexMut because u8 does -/// not implement IndexMut for vectors. -trait TrieTab { - fn new() -> Self; - fn get(&self, index: Idx) -> Option<&usize>; - fn set(&mut self, index: Idx, elem: usize); -} - -impl TrieTab for HashMap -where - Sym: hash::Hash + Eq + PartialEq, -{ - fn new() -> Self { - HashMap::new() - } - - fn get(&self, index: Sym) -> Option<&usize> { - self.get(&index) - } - - fn set(&mut self, index: Sym, elem: usize) { - self.insert(index, elem); - } -} - -impl TrieTab for Vec> { - fn new() -> Self { - vec![None; u8::MAX as usize] - } - - fn get(&self, index: u8) -> Option<&usize> { - self[index as usize].as_ref() - } - - fn set(&mut self, index: u8, elem: usize) { - self[index as usize] = Some(elem) - } -} - -impl TrieTab for Vec> { - fn new() -> Self { - vec![None; u8::MAX as usize] - } - - fn get(&self, index: ChordAtom) -> Option<&usize> { - self[index.0 as usize].as_ref() - } - - fn set(&mut self, index: ChordAtom, elem: usize) { - self[index.0 as usize] = Some(elem) - } -} - // // Data Tables // diff --git a/libshpool/src/daemon/mod.rs b/libshpool/src/daemon/mod.rs index 00591da4..f551c3b3 100644 --- a/libshpool/src/daemon/mod.rs +++ b/libshpool/src/daemon/mod.rs @@ -28,6 +28,9 @@ mod shell; mod signals; mod systemd; mod ttl_reaper; +mod show_motd; +mod trie; +mod control_codes; #[instrument(skip_all)] pub fn run( @@ -39,7 +42,7 @@ pub fn run( info!("\n\n======================== STARTING DAEMON ============================\n\n"); let config = config::read_config(&config_file)?; - let server = server::Server::new(config, hooks, runtime_dir); + let server = server::Server::new(config, hooks, runtime_dir)?; let (cleanup_socket, listener) = match systemd::activation_socket() { Ok(l) => { diff --git a/libshpool/src/daemon/server.rs b/libshpool/src/daemon/server.rs index af621b2a..965b743a 100644 --- a/libshpool/src/daemon/server.rs +++ b/libshpool/src/daemon/server.rs @@ -14,7 +14,7 @@ use std::{ collections::HashMap, - env, fs, io, net, + env, fs, io, io::Write, net, ops::Add, os, os::unix::{ @@ -36,7 +36,7 @@ use tracing::{error, info, instrument, span, trace, warn, Level}; use super::{ super::{config, consts, protocol, test_hooks, tty, user}, - etc_environment, hooks, prompt, shell, ttl_reaper, + etc_environment, hooks, prompt, shell, ttl_reaper, show_motd }; use crate::daemon::exit_notify::ExitNotifier; @@ -56,6 +56,7 @@ pub struct Server { runtime_dir: PathBuf, register_new_reapable_session: crossbeam_channel::Sender<(String, Instant)>, hooks: Box, + motd_shower: Arc, } impl Server { @@ -64,7 +65,7 @@ impl Server { config: config::Config, hooks: Box, runtime_dir: PathBuf, - ) -> Arc { + ) -> anyhow::Result> { let shells = Arc::new(Mutex::new(HashMap::new())); // buffered so that we are unlikely to block when setting up a // new session @@ -76,13 +77,16 @@ impl Server { } }); - Arc::new(Server { + let motd_shower = Arc::new(show_motd::Shower::new( + config.motd.clone().unwrap_or_default())?); + Ok(Arc::new(Server { config, shells, runtime_dir, register_new_reapable_session: new_sess_tx, hooks, - }) + motd_shower, + })) } #[instrument(skip_all)] @@ -241,11 +245,15 @@ impl Server { } if matches!(status, protocol::AttachStatus::Created { .. }) { + use config::MotdDisplayMode; + info!("creating new subshell"); if let Err(err) = self.hooks.on_new_session(&header.name) { warn!("new_session hook: {:?}", err); } - let session = self.spawn_subshell(conn_id, stream, &header)?; + let motd = self.config.motd.clone().unwrap_or_default(); + let session = self.spawn_subshell( + conn_id, stream, &header, matches!(motd, MotdDisplayMode::Dump))?; shells.insert(header.name.clone(), Box::new(session)); // fallthrough to bidi streaming @@ -280,11 +288,39 @@ impl Server { } }; - let reply_status = write_reply(client_stream, protocol::AttachReplyHeader { status }); + let reply_status = write_reply(client_stream, protocol::AttachReplyHeader { + status: status.clone(), + }); if let Err(e) = reply_status { error!("error writing reply status: {:?}", e); } + /* TODO: delete + // TODO: need to figure out how to block until the prompt injection is processed. + // Not clear how to do it without relying on gross sleeps. + let motd_at = self.config.motd.clone().unwrap_or_default(); + match (motd_at, status) { + (config::MotdDisplayAt::NewSession, protocol::AttachStatus::Created { .. }) | + (config::MotdDisplayAt::Attach, _) => { + info!("displaying motd"); + if let Err(e) = self.motd_shower.show(client_stream) { + warn!("Error showing motd: {}", e); + } + } + (_, _) => { + info!("not displaying motd"); + }, + } + */ + /* TODO: delete + let motd_at = self.config.motd.clone().unwrap_or_default(); + let show_motd = match (motd_at, status) { + (config::MotdDisplayAt::NewSession, protocol::AttachStatus::Created { .. }) | + (config::MotdDisplayAt::Attach, _) => true, + (_, _) => false, + }; + */ + info!("starting bidi stream loop"); match inner.bidi_stream(conn_id, header.local_tty_size.clone(), child_exit_notifier) { Ok(done) => { @@ -523,6 +559,7 @@ impl Server { conn_id: usize, client_stream: UnixStream, header: &protocol::AttachHeader, + dump_motd_on_new_session: bool, ) -> anyhow::Result { let user_info = user::info()?; let shell = if let Some(s) = &self.config.shell { @@ -563,7 +600,12 @@ impl Server { // to avoid breakage and vars the user has asked us to inject. .env_clear(); - self.inject_env(&mut cmd, &user_info, header).context("setting up shell env")?; + let term_db = if let Some(term) = self.inject_env(&mut cmd, &user_info, header).context("setting up shell env")? { + terminfo::Database::from_name(term).context("resolving terminfo")? + } else { + warn!("no $TERM, using default terminfo"); + terminfo::Database::from_env().context("resolving default terminfo")? + }; let shell_basename = if header.cmd.is_none() { // spawn the shell as a login shell by setting @@ -626,6 +668,7 @@ impl Server { }); // inject the prompt prefix, if any + info!("injecting prompt prefix"); let prompt_prefix = self.config.prompt_prefix.clone().unwrap_or(String::from("")); if let Some(shell_basename) = shell_basename { if !prompt_prefix.is_empty() { @@ -634,6 +677,12 @@ impl Server { { warn!("issue injecting prefix: {:?}", err); } + } else { + // issue a clear even if we don't have a prompt to inject for consistency + // and to simplify motd handling + let script = "clear\n"; + let mut pty_master = fork.is_parent().context("expected parent")?; + pty_master.write_all(script.as_bytes()).context("running initial clear")?; } } @@ -655,6 +704,9 @@ impl Server { client_stream: Some(client_stream), config: self.config.clone(), reader_join_h: None, + term_db, + motd_shower: Arc::clone(&self.motd_shower), + needs_initial_motd_dump: dump_motd_on_new_session, }; let child_pid = session_inner.pty_master.child_pid().ok_or(anyhow!("no child pid"))?; session_inner.reader_join_h = Some(session_inner.spawn_reader(shell::ReaderArgs { @@ -691,13 +743,14 @@ impl Server { }) } + /// Set up the environment for the shell, returning the right TERM value. #[instrument(skip_all)] fn inject_env( &self, cmd: &mut process::Command, user_info: &user::Info, header: &protocol::AttachHeader, - ) -> anyhow::Result<()> { + ) -> anyhow::Result> { cmd.env("HOME", &user_info.home_dir) .env( "PATH", @@ -753,7 +806,7 @@ impl Server { } } info!("injecting TERM into shell {:?}", term); - if let Some(t) = term { + if let Some(t) = &term { cmd.env("TERM", t); } @@ -780,7 +833,7 @@ impl Server { } } - Ok(()) + Ok(term) } fn ssh_auth_sock_symlink(&self, session_name: PathBuf) -> PathBuf { diff --git a/libshpool/src/daemon/shell.rs b/libshpool/src/daemon/shell.rs index 7441f0f4..07a0d231 100644 --- a/libshpool/src/daemon/shell.rs +++ b/libshpool/src/daemon/shell.rs @@ -18,6 +18,7 @@ use std::{ net, ops::Add, os::unix::net::UnixStream, + sync, sync::{ atomic::{AtomicBool, Ordering}, Arc, Mutex, @@ -32,7 +33,7 @@ use tracing::{debug, error, info, instrument, span, trace, warn, Level}; use crate::{ consts, - daemon::{config, exit_notify::ExitNotifier, keybindings}, + daemon::{config, exit_notify::ExitNotifier, keybindings, show_motd, control_codes}, protocol, test_hooks, tty, }; @@ -101,6 +102,9 @@ pub struct SessionInner { pub pty_master: shpool_pty::fork::Fork, pub client_stream: Option, pub config: config::Config, + pub term_db: terminfo::Database, + pub motd_shower: Arc, + pub needs_initial_motd_dump: bool, /// The join handle for the always-on background reader thread. /// Only wrapped in an option so we can spawn the thread after @@ -189,11 +193,20 @@ impl SessionInner { ) -> anyhow::Result>> { use nix::poll; + let term_db = self.term_db.clone(); + let mut control_code_matcher = control_codes::Matcher::new(&self.term_db) + .context("building control code matcher")?; + + let motd_shower = Arc::clone(&self.motd_shower); + let mut needs_initial_motd_dump = self.needs_initial_motd_dump; + debug!("needs_initial_motd_dump={}", needs_initial_motd_dump); // TODO: delete + let mut pty_master = self.pty_master.is_parent()?; let name = self.name.clone(); let mut closure = move || { let _s = span!(Level::INFO, "reader", s = name, cid = args.conn_id).entered(); + let mut output_spool = if matches!(args.session_restore_mode, config::SessionRestoreMode::Simple) { None @@ -342,7 +355,7 @@ impl SessionInner { .context("sending size change ack")?; } Err(err) => { - info!("size change: bailing due to: {:?}", err); + warn!("size change: bailing due to: {:?}", err); return Ok(()); } } @@ -445,6 +458,7 @@ impl SessionInner { continue; } trace!("read pty master len={} '{}'", len, String::from_utf8_lossy(&buf[..len])); + trace!("raw: {}", buf[..len].iter().map(|b| format!("{:02x}", b)).collect::>().join(" ")); // TODO: delete if !matches!(args.session_restore_mode, config::SessionRestoreMode::Simple) { if let Some(s) = output_spool.as_mut() { @@ -452,14 +466,60 @@ impl SessionInner { } } + // scan for control codes we need to handle let mut reset_client_conn = false; + let mut snip_buf_to = None; + if needs_initial_motd_dump { + for (i, byte) in buf[..len].iter().enumerate() { + match control_code_matcher.transition(*byte) { + Some(control_codes::Code::ClearScreen) if needs_initial_motd_dump => { + debug!("handling ClearScreen code"); + debug!("handling ClearScreen code 1"); // TODO: delete + if let ClientConnectionMsg::New(conn) = &client_conn { + debug!("handling ClearScreen code 2"); // TODO: delete + let mut s = conn.sink.lock().unwrap(); + + // write the clear code ahead of time so we don't + // immediately clobber ourselves + let write_to = i + 1; + let chunk = protocol::Chunk { kind: protocol::ChunkKind::Data, buf: &buf[..write_to] }; + let write_result = chunk.write_to(&mut *s).and_then(|_| s.flush()); + if let Err(err) = write_result { + info!("client_stream write err (1), assuming hangup: {:?}", err); + reset_client_conn = true; + } else { + test_hooks::emit("daemon-wrote-s2c-chunk"); + } + snip_buf_to = Some(write_to); + + debug!("handling ClearScreen code 3"); // TODO: delete + if let Err(e) = motd_shower.dump(&mut *s, &term_db) { + debug!("handling ClearScreen code 4"); // TODO: delete + warn!("Error handling clear: {}", e); + } + debug!("handling ClearScreen code 5"); // TODO: delete + } + needs_initial_motd_dump = false; + } + _ => {}, + } + } + } + if let Some(snip_to) = snip_buf_to { + if snip_to < buf.len() { + buf = Vec::from(&buf[snip_to..]); + } else { + buf.clear(); + } + } + if let ClientConnectionMsg::New(conn) = &client_conn { let chunk = protocol::Chunk { kind: protocol::ChunkKind::Data, buf: &buf[..len] }; let mut s = conn.sink.lock().unwrap(); let write_result = chunk.write_to(&mut *s).and_then(|_| s.flush()); if let Err(err) = write_result { - info!("client_stream write err, assuming hangup: {:?}", err); + info!("client_stream write err (2), assuming hangup: {:?}", err); reset_client_conn = true; } else { test_hooks::emit("daemon-wrote-s2c-chunk"); @@ -505,6 +565,20 @@ impl SessionInner { client_stream.try_clone().context("wrapping stream in bufwriter")?, ))); + /* TODO: delete + // TODO: need to figure out how to block until the prompt injection is processed. + // Not clear how to do it without relying on gross sleeps. + if display_motd { + info!("displaying motd"); + let mut client_stream = client_stream_m.lock().unwrap(); + if let Err(e) = self.motd_shower.show(&mut *client_stream) { + warn!("Error showing motd: {}", e); + } + } else { + info!("not displaying motd"); + } + */ + { let reader_ctl = self.reader_ctl.lock().unwrap(); reader_ctl @@ -698,6 +772,7 @@ impl SessionInner { if !partial_keybinding.is_empty() && i < partial_keybinding.len() => { + debug!("NO_MATCH(1): {:02x}", *byte); // TODO: delete // it turned out the partial keybinding match was not // a real match, so flush it to the output stream debug!( @@ -716,10 +791,15 @@ impl SessionInner { partial_keybinding.clear() } NoMatch => { + debug!("NO_MATCH(2): {:02x}", *byte); // TODO: delete partial_keybinding.clear(); } - Partial => partial_keybinding.push(*byte), + Partial => { + debug!("PARTIAL: {:02x}", *byte); // TODO: delete + partial_keybinding.push(*byte); + } Match(action) => { + debug!("MATCH: {:02x}", *byte); // TODO: delete info!("{:?} keybinding action fired", action); let keybinding_len = partial_keybinding.len() + 1; if keybinding_len < i { @@ -737,7 +817,7 @@ impl SessionInner { use keybindings::Action::*; match action { Detach => self.action_detach()?, - NoOp => {} + NoOp => {}, } } } diff --git a/libshpool/src/daemon/show_motd.rs b/libshpool/src/daemon/show_motd.rs new file mode 100644 index 00000000..513ef026 --- /dev/null +++ b/libshpool/src/daemon/show_motd.rs @@ -0,0 +1,84 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::io; + +use anyhow::{anyhow, Context}; + +use super::super::{config, protocol}; + +/// Showers know how to show the message of the day. +#[derive(Debug, Clone)] +pub struct Shower { + motd_resolver: motd::Resolver, + mode: config::MotdDisplayMode, +} + +impl Shower { + /// Make a new Shower. + pub fn new(mode: config::MotdDisplayMode) -> anyhow::Result { + Ok(Shower { + motd_resolver: motd::Resolver::new(motd::PamMotdResolutionStrategy::Auto) + .context("creating motd resolver")?, + mode, + }) + } + + pub fn dump(&self, mut stream: W, term_db: &terminfo::Database) -> anyhow::Result<()> { + assert!(matches!(self.mode, config::MotdDisplayMode::Dump)); + + let motd_value = self.motd_resolver.value(motd::ArgResolutionStrategy::Auto) + .context("resolving motd")?; + let raw_motd_value = Self::convert_to_raw(term_db, &motd_value)?; + + let chunk = protocol::Chunk { + kind: protocol::ChunkKind::Data, + buf: raw_motd_value.as_slice(), + }; + + chunk.write_to(&mut stream).context("dumping motd") + } + + /// Convert the given motd into a byte buffer suitable to be written to the + /// terminal. The only real transformation we perform is injecting carrage + /// returns after newlines. + fn convert_to_raw(term_db: &terminfo::Database, motd: &str) -> anyhow::Result> { + let carrage_return_code = term_db.get::() + .ok_or(anyhow!("no carrage return code"))?; + let carrage_return_bytes = carrage_return_code.expand().to_vec() + .context("expanding carrage return code")?; + + let mut buf: Vec = vec![]; + + let lines = motd.split("\n"); + for line in lines { + buf.extend(line.as_bytes()); + buf.extend("\n".as_bytes()); + buf.extend(&carrage_return_bytes); + } + + Ok(buf) + } + + /* + pub fn pager(&self, _stream: W) -> anyhow::Result<()> { + // plan: + // write it to a tmp file + // fork a background process in a pty + // shuffle the bytes + // return Ok(()) once the pager exits + unimplemented!() + } + */ +} diff --git a/libshpool/src/daemon/trie.rs b/libshpool/src/daemon/trie.rs new file mode 100644 index 00000000..faccf8f0 --- /dev/null +++ b/libshpool/src/daemon/trie.rs @@ -0,0 +1,146 @@ +use std::{collections::HashMap, hash}; + +#[derive(Debug)] +pub struct Trie { + // The nodes which form the tree. The first node is the root + // node, afterwards the order is undefined. + nodes: Vec>, +} + +#[derive(Eq, PartialEq, Copy, Clone, Debug)] +pub enum TrieCursor { + /// A cursor to use to start a char-wise match + Start, + /// Represents a state in the middle or end of a match + Match { idx: usize, is_partial: bool }, + /// A terminal state indicating a failure to match + NoMatch, +} + +#[derive(Debug)] +pub struct TrieNode { + // We need to store a phantom symbol here so we can have the + // Sym type parameter available for the TrieTab trait constraint + // in the impl block. Apologies for the type tetris. + phantom: std::marker::PhantomData, + value: Option, + tab: TT, +} + +impl Trie +where + TT: TrieTab, + Sym: Copy, +{ + pub fn new() -> Self { + Trie { nodes: vec![TrieNode::new(None)] } + } + + /// Insert a seq, value pair into the trie + pub fn insert>(&mut self, seq: Seq, value: V) { + let mut current_node = 0; + for sym in seq { + current_node = if let Some(next_node) = self.nodes[current_node].tab.get(sym) { + *next_node + } else { + let idx = self.nodes.len(); + self.nodes.push(TrieNode::new(None)); + self.nodes[current_node].tab.set(sym, idx); + idx + }; + } + self.nodes[current_node].value = Some(value); + } + + /// Check if the given sequence exists in the trie, used by tests. + #[allow(dead_code)] + pub fn contains>(&self, seq: Seq) -> bool { + let mut match_state = TrieCursor::Start; + for sym in seq { + match_state = self.advance(match_state, sym); + if let TrieCursor::NoMatch = match_state { + return false; + } + } + if let TrieCursor::Start = match_state { + return self.nodes[0].value.is_some(); + } + + if let TrieCursor::Match { is_partial, .. } = match_state { !is_partial } else { false } + } + + /// Process a single token of input, returning the current state. + /// To start a new match, use TrieCursor::Start. + pub fn advance(&self, cursor: TrieCursor, sym: Sym) -> TrieCursor { + let node = match cursor { + TrieCursor::Start => &self.nodes[0], + TrieCursor::Match { idx, .. } => &self.nodes[idx], + TrieCursor::NoMatch => return TrieCursor::NoMatch, + }; + + if let Some(idx) = node.tab.get(sym) { + TrieCursor::Match { idx: *idx, is_partial: self.nodes[*idx].value.is_none() } + } else { + TrieCursor::NoMatch + } + } + + /// Get the value for a match cursor. + pub fn get(&self, cursor: TrieCursor) -> Option<&V> { + if let TrieCursor::Match { idx, .. } = cursor { + self.nodes[idx].value.as_ref() + } else { + None + } + } +} + +impl TrieNode +where + TT: TrieTab, +{ + fn new(value: Option) -> Self { + TrieNode { phantom: std::marker::PhantomData, value, tab: TT::new() } + } +} + +/// The backing table the trie uses to associate symbols with state +/// indexes. This is basically std::ops::IndexMut plus a new function. +/// We can't just make this a sub-trait of IndexMut because u8 does +/// not implement IndexMut for vectors. +pub trait TrieTab { + fn new() -> Self; + fn get(&self, index: Idx) -> Option<&usize>; + fn set(&mut self, index: Idx, elem: usize); +} + +impl TrieTab for HashMap +where + Sym: hash::Hash + Eq + PartialEq, +{ + fn new() -> Self { + HashMap::new() + } + + fn get(&self, index: Sym) -> Option<&usize> { + self.get(&index) + } + + fn set(&mut self, index: Sym, elem: usize) { + self.insert(index, elem); + } +} + +impl TrieTab for Vec> { + fn new() -> Self { + vec![None; u8::MAX as usize] + } + + fn get(&self, index: u8) -> Option<&usize> { + self[index as usize].as_ref() + } + + fn set(&mut self, index: u8, elem: usize) { + self[index as usize] = Some(elem) + } +} diff --git a/libshpool/src/protocol.rs b/libshpool/src/protocol.rs index da8ec5a3..18da60dd 100644 --- a/libshpool/src/protocol.rs +++ b/libshpool/src/protocol.rs @@ -218,7 +218,7 @@ impl fmt::Display for SessionStatus { } /// AttachStatus indicates what happened during an attach attempt. -#[derive(PartialEq, Eq, Serialize, Deserialize, Debug)] +#[derive(PartialEq, Eq, Serialize, Deserialize, Debug, Clone)] pub enum AttachStatus { /// Attached indicates that there was an existing shell session with /// the given name, and `shpool attach` successfully connected to it. @@ -417,6 +417,7 @@ impl Client { chunk.kind, chunk.buf.len() ); + trace!("RAW: {}", chunk.buf.iter().map(|b| format!("{:02x}", b)).collect::>().join(" ")); // TODO: delete } match chunk.kind { diff --git a/shpool/Cargo.toml b/shpool/Cargo.toml index 7f5c7ecc..f2a43533 100644 --- a/shpool/Cargo.toml +++ b/shpool/Cargo.toml @@ -19,6 +19,7 @@ rust-version = "1.74" clap = { version = "4", features = ["derive"] } # cli parsing anyhow = "1" # dynamic, unstructured errors libshpool = { version = "0.5.0", path = "../libshpool" } +motd = "0.2.0" # getting the message-of-the-day [dev-dependencies] lazy_static = "1" # globals diff --git a/shpool/src/main.rs b/shpool/src/main.rs index d2852d97..fbd306d6 100644 --- a/shpool/src/main.rs +++ b/shpool/src/main.rs @@ -20,6 +20,8 @@ use clap::Parser; const VERSION: &str = env!("CARGO_PKG_VERSION"); fn main() -> anyhow::Result<()> { + motd::handle_reexec(); + let args = libshpool::Args::parse(); if args.version() {