From c21c5615b9d83e524c148df757ada75ef5e7bc90 Mon Sep 17 00:00:00 2001 From: irving ou Date: Thu, 19 Sep 2024 13:57:19 -0400 Subject: [PATCH 01/15] feat(wasm): Add wasm support, update cryptovec --- cryptovec/Cargo.toml | 3 + cryptovec/src/cryptovec.rs | 471 ++++++++++++++++++++++++++++++ cryptovec/src/lib.rs | 428 +-------------------------- cryptovec/src/platform/mod.rs | 41 +++ cryptovec/src/platform/unix.rs | 27 ++ cryptovec/src/platform/wasm.rs | 15 + cryptovec/src/platform/windows.rs | 23 ++ 7 files changed, 585 insertions(+), 423 deletions(-) create mode 100644 cryptovec/src/cryptovec.rs create mode 100644 cryptovec/src/platform/mod.rs create mode 100644 cryptovec/src/platform/unix.rs create mode 100644 cryptovec/src/platform/wasm.rs create mode 100644 cryptovec/src/platform/windows.rs diff --git a/cryptovec/Cargo.toml b/cryptovec/Cargo.toml index 92450230..c210ba87 100644 --- a/cryptovec/Cargo.toml +++ b/cryptovec/Cargo.toml @@ -12,3 +12,6 @@ version = "0.7.0" [dependencies] libc = "0.2" winapi = {version = "0.3", features = ["basetsd", "minwindef", "memoryapi"]} + +[dev-dependencies] +wasm-bindgen-test = "0.3" \ No newline at end of file diff --git a/cryptovec/src/cryptovec.rs b/cryptovec/src/cryptovec.rs new file mode 100644 index 00000000..0e7b63f5 --- /dev/null +++ b/cryptovec/src/cryptovec.rs @@ -0,0 +1,471 @@ +use crate::platform; +use std::ops::{Deref, DerefMut, Index, IndexMut, Range, RangeFrom, RangeFull, RangeTo}; + +/// A buffer which zeroes its memory on `.clear()`, `.resize()`, and +/// reallocations, to avoid copying secrets around. +#[derive(Debug)] +pub struct CryptoVec { + p: *mut u8, // `pub(crate)` allows access from platform modules + size: usize, + capacity: usize, +} + +impl Unpin for CryptoVec {} +unsafe impl Send for CryptoVec {} +unsafe impl Sync for CryptoVec {} + +// Common traits implementations +impl AsRef<[u8]> for CryptoVec { + fn as_ref(&self) -> &[u8] { + self.deref() + } +} + +impl AsMut<[u8]> for CryptoVec { + fn as_mut(&mut self) -> &mut [u8] { + self.deref_mut() + } +} + +impl Deref for CryptoVec { + type Target = [u8]; + fn deref(&self) -> &[u8] { + unsafe { std::slice::from_raw_parts(self.p, self.size) } + } +} + +impl DerefMut for CryptoVec { + fn deref_mut(&mut self) -> &mut [u8] { + unsafe { std::slice::from_raw_parts_mut(self.p, self.size) } + } +} + +impl From for CryptoVec { + fn from(e: String) -> Self { + CryptoVec::from(e.into_bytes()) + } +} + +impl From> for CryptoVec { + fn from(e: Vec) -> Self { + let mut c = CryptoVec::new_zeroed(e.len()); + c.clone_from_slice(&e[..]); + c + } +} + +// Indexing implementations +impl Index> for CryptoVec { + type Output = [u8]; + fn index(&self, index: RangeFrom) -> &[u8] { + self.deref().index(index) + } +} +impl Index> for CryptoVec { + type Output = [u8]; + fn index(&self, index: RangeTo) -> &[u8] { + self.deref().index(index) + } +} +impl Index> for CryptoVec { + type Output = [u8]; + fn index(&self, index: Range) -> &[u8] { + self.deref().index(index) + } +} +impl Index for CryptoVec { + type Output = [u8]; + fn index(&self, _: RangeFull) -> &[u8] { + self.deref() + } +} + +impl IndexMut for CryptoVec { + fn index_mut(&mut self, _: RangeFull) -> &mut [u8] { + self.deref_mut() + } +} +impl IndexMut> for CryptoVec { + fn index_mut(&mut self, index: RangeFrom) -> &mut [u8] { + self.deref_mut().index_mut(index) + } +} +impl IndexMut> for CryptoVec { + fn index_mut(&mut self, index: RangeTo) -> &mut [u8] { + self.deref_mut().index_mut(index) + } +} +impl IndexMut> for CryptoVec { + fn index_mut(&mut self, index: Range) -> &mut [u8] { + self.deref_mut().index_mut(index) + } +} + +impl Index for CryptoVec { + type Output = u8; + fn index(&self, index: usize) -> &u8 { + self.deref().index(index) + } +} + +// IO-related implementation +impl std::io::Write for CryptoVec { + fn write(&mut self, buf: &[u8]) -> Result { + self.extend(buf); + Ok(buf.len()) + } + + fn flush(&mut self) -> Result<(), std::io::Error> { + Ok(()) + } +} + +// Default implementation +impl Default for CryptoVec { + fn default() -> Self { + CryptoVec { + p: std::ptr::NonNull::dangling().as_ptr(), + size: 0, + capacity: 0, + } + } +} + +// Memory management methods +impl CryptoVec { + /// Creates a new `CryptoVec`. + pub fn new() -> CryptoVec { + CryptoVec::default() + } + + /// Creates a new `CryptoVec` with `n` zeros. + pub fn new_zeroed(size: usize) -> CryptoVec { + unsafe { + let capacity = size.next_power_of_two(); + let layout = std::alloc::Layout::from_size_align_unchecked(capacity, 1); + let p = std::alloc::alloc_zeroed(layout); + platform::mlock(p, capacity); + CryptoVec { p, capacity, size } + } + } + + /// Resize this CryptoVec, appending zeros at the end. This may + /// perform at most one reallocation, overwriting the previous + /// version with zeros. + pub fn resize(&mut self, size: usize) { + if size <= self.capacity && size > self.size { + // If this is an expansion, just resize. + self.size = size + } else if size <= self.size { + unsafe { + platform::memset(self.p.add(size), 0, self.size - size); + } + self.size = size; + } else { + // realloc ! and erase the previous memory. + unsafe { + let next_capacity = size.next_power_of_two(); + let old_ptr = self.p; + let next_layout = std::alloc::Layout::from_size_align_unchecked(next_capacity, 1); + let new_ptr = std::alloc::alloc_zeroed(next_layout); + if new_ptr.is_null() { + #[allow(clippy::panic)] + { + panic!("Realloc failed, pointer = {:?} {:?}", self, size) + } + } + + self.p = new_ptr; + platform::mlock(self.p, next_capacity); + + if self.capacity > 0 { + std::ptr::copy_nonoverlapping(old_ptr, self.p, self.size); + for i in 0..self.size { + std::ptr::write_volatile(old_ptr.add(i), 0) + } + platform::munlock(old_ptr, self.capacity); + let layout = std::alloc::Layout::from_size_align_unchecked(self.capacity, 1); + std::alloc::dealloc(old_ptr, layout); + } + + self.capacity = next_capacity; + self.size = size; + } + } + } + + pub fn clear(&mut self) { + self.resize(0); + } + + pub fn push(&mut self, s: u8) { + let size = self.size; + self.resize(size + 1); + unsafe { *self.p.add(size) = s } + } + + pub fn extend(&mut self, s: &[u8]) { + let size = self.size; + self.resize(size + s.len()); + unsafe { + std::ptr::copy_nonoverlapping(s.as_ptr(), self.p.add(size), s.len()); + } + } + + pub fn push_u32_be(&mut self, s: u32) { + let s = s.to_be(); + let x: [u8; 4] = s.to_ne_bytes(); + self.extend(&x) + } + + pub fn read_u32_be(&self, i: usize) -> u32 { + assert!(i + 4 <= self.size); + let mut x: u32 = 0; + unsafe { + std::ptr::copy_nonoverlapping(self.p.add(i) as *const u32, &mut x as *mut u32, 1); + } + u32::from_be(x) + } + + pub fn write_all_from( + &self, + offset: usize, + mut w: W, + ) -> Result { + assert!(offset < self.size); + // if we're past this point, self.p cannot be null. + unsafe { + let s = std::slice::from_raw_parts(self.p.add(offset), self.size - offset); + w.write(s) + } + } + + pub fn resize_mut(&mut self, n: usize) -> &mut [u8] { + let size = self.size; + self.resize(size + n); + unsafe { std::slice::from_raw_parts_mut(self.p.add(size), n) } + } + + pub fn from_slice(s: &[u8]) -> CryptoVec { + let mut v = CryptoVec::new(); + v.resize(s.len()); + unsafe { + std::ptr::copy_nonoverlapping(s.as_ptr(), v.p, s.len()); + } + v + } +} + +impl Clone for CryptoVec { + fn clone(&self) -> Self { + let mut v = Self::new(); + v.extend(self); + v + } +} + +// Drop implementation +impl Drop for CryptoVec { + fn drop(&mut self) { + if self.capacity > 0 { + unsafe { + for i in 0..self.size { + std::ptr::write_volatile(self.p.add(i), 0); + } + platform::munlock(self.p, self.capacity); + let layout = std::alloc::Layout::from_size_align_unchecked(self.capacity, 1); + std::alloc::dealloc(self.p, layout); + } + } + } +} + +#[cfg(test)] +mod test { + + use wasm_bindgen_test::wasm_bindgen_test; + + use super::CryptoVec; + + #[cfg(target_arch = "wasm32")] + wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser); + + #[test] + #[wasm_bindgen_test] + fn test_new() { + let crypto_vec = CryptoVec::new(); + assert_eq!(crypto_vec.size, 0); + assert_eq!(crypto_vec.capacity, 0); + } + + #[test] + #[wasm_bindgen_test] + fn test_resize_expand() { + let mut crypto_vec = CryptoVec::new_zeroed(5); + crypto_vec.resize(10); + assert_eq!(crypto_vec.size, 10); + assert!(crypto_vec.capacity >= 10); + assert!(crypto_vec.iter().skip(5).all(|&x| x == 0)); // Ensure newly added elements are zeroed + } + + #[test] + #[wasm_bindgen_test] + fn test_resize_shrink() { + let mut crypto_vec = CryptoVec::new_zeroed(10); + crypto_vec.resize(5); + assert_eq!(crypto_vec.size, 5); + // Ensure shrinking keeps the previous elements intact + assert_eq!(crypto_vec.len(), 5); + } + + #[test] + #[wasm_bindgen_test] + fn test_push() { + let mut crypto_vec = CryptoVec::new(); + crypto_vec.push(1); + crypto_vec.push(2); + assert_eq!(crypto_vec.size, 2); + assert_eq!(crypto_vec[0], 1); + assert_eq!(crypto_vec[1], 2); + } + + #[test] + #[wasm_bindgen_test] + fn test_write_trait() { + use std::io::Write; + + let mut crypto_vec = CryptoVec::new(); + let bytes_written = crypto_vec.write(&[1, 2, 3]).unwrap(); + assert_eq!(bytes_written, 3); + assert_eq!(crypto_vec.size, 3); + assert_eq!(crypto_vec.as_ref(), &[1, 2, 3]); + } + + #[test] + #[wasm_bindgen_test] + fn test_as_ref_as_mut() { + let mut crypto_vec = CryptoVec::new_zeroed(5); + let slice_ref: &[u8] = crypto_vec.as_ref(); + assert_eq!(slice_ref.len(), 5); + let slice_mut: &mut [u8] = crypto_vec.as_mut(); + slice_mut[0] = 1; + assert_eq!(crypto_vec[0], 1); + } + + #[test] + #[wasm_bindgen_test] + fn test_from_string() { + let input = String::from("hello"); + let crypto_vec: CryptoVec = input.into(); + assert_eq!(crypto_vec.as_ref(), b"hello"); + } + + #[test] + #[wasm_bindgen_test] + fn test_from_vec() { + let input = vec![1, 2, 3, 4]; + let crypto_vec: CryptoVec = input.into(); + assert_eq!(crypto_vec.as_ref(), &[1, 2, 3, 4]); + } + + #[test] + #[wasm_bindgen_test] + fn test_index() { + let crypto_vec = CryptoVec::from(vec![1, 2, 3, 4, 5]); + assert_eq!(crypto_vec[0], 1); + assert_eq!(crypto_vec[4], 5); + assert_eq!(&crypto_vec[1..3], &[2, 3]); + } + + #[test] + #[wasm_bindgen_test] + fn test_drop() { + let mut crypto_vec = CryptoVec::new_zeroed(10); + // Ensure vector is filled with non-zero data + crypto_vec.extend(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + drop(crypto_vec); + + // Check that memory zeroing was done during the drop + // This part is more difficult to test directly since it involves + // private memory management. However, with Rust's unsafe features, + // it may be checked using tools like Valgrind or manual inspection. + } + + #[test] + #[wasm_bindgen_test] + fn test_new_zeroed() { + let crypto_vec = CryptoVec::new_zeroed(10); + assert_eq!(crypto_vec.size, 10); + assert!(crypto_vec.capacity >= 10); + assert!(crypto_vec.iter().all(|&x| x == 0)); // Ensure all bytes are zeroed + } + + #[test] + #[wasm_bindgen_test] + fn test_push_u32_be() { + let mut crypto_vec = CryptoVec::new(); + let value = 43554u32; + crypto_vec.push_u32_be(value); + assert_eq!(crypto_vec.len(), 4); // u32 is 4 bytes long + assert_eq!(crypto_vec.read_u32_be(0), value); + } + + #[test] + #[wasm_bindgen_test] + + fn test_read_u32_be() { + let mut crypto_vec = CryptoVec::new(); + let value = 99485710u32; + crypto_vec.push_u32_be(value); + assert_eq!(crypto_vec.read_u32_be(0), value); + } + + #[test] + #[wasm_bindgen_test] + fn test_clear() { + let mut crypto_vec = CryptoVec::new(); + crypto_vec.extend(b"blabla"); + crypto_vec.clear(); + assert!(crypto_vec.is_empty()); + } + + #[test] + #[wasm_bindgen_test] + fn test_extend() { + let mut crypto_vec = CryptoVec::new(); + crypto_vec.extend(b"test"); + assert_eq!(crypto_vec.as_ref(), b"test"); + } + + #[test] + #[wasm_bindgen_test] + fn test_write_all_from() { + let mut crypto_vec = CryptoVec::new(); + crypto_vec.extend(b"blabla"); + + let mut output: Vec = Vec::new(); + let written_size = crypto_vec.write_all_from(0, &mut output).unwrap(); + assert_eq!(written_size, 6); // "blabla" has 6 bytes + assert_eq!(output, b"blabla"); + } + + #[test] + #[wasm_bindgen_test] + fn test_resize_mut() { + let mut crypto_vec = CryptoVec::new(); + crypto_vec.resize_mut(4).clone_from_slice(b"test"); + assert_eq!(crypto_vec.as_ref(), b"test"); + } + + #[cfg(target_pointer_width = "64")] + #[test] + fn test_large_resize_panics() { + let result = std::panic::catch_unwind(|| { + let mut vec = CryptoVec::new(); + vec.push(42); // Write something into the vector + + vec.resize(1_000_000_000_000); // Intentionally large resize + }); + assert!(result.is_err()); // Expecting a panic on large allocation + } +} diff --git a/cryptovec/src/lib.rs b/cryptovec/src/lib.rs index 8ecd1f0d..056b4f55 100644 --- a/cryptovec/src/lib.rs +++ b/cryptovec/src/lib.rs @@ -4,428 +4,10 @@ clippy::indexing_slicing, clippy::panic )] -// Copyright 2016 Pierre-Étienne Meunier -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -use std::ops::{Deref, DerefMut, Index, IndexMut, Range, RangeFrom, RangeFull, RangeTo}; -use libc::c_void; -#[cfg(not(windows))] -use libc::size_t; +// Re-export CryptoVec from the cryptovec module +mod cryptovec; +pub use cryptovec::CryptoVec; -/// A buffer which zeroes its memory on `.clear()`, `.resize()` and -/// reallocations, to avoid copying secrets around. -#[derive(Debug)] -pub struct CryptoVec { - p: *mut u8, - size: usize, - capacity: usize, -} - -impl Unpin for CryptoVec {} - -unsafe impl Send for CryptoVec {} -unsafe impl Sync for CryptoVec {} - -impl AsRef<[u8]> for CryptoVec { - fn as_ref(&self) -> &[u8] { - self.deref() - } -} -impl AsMut<[u8]> for CryptoVec { - fn as_mut(&mut self) -> &mut [u8] { - self.deref_mut() - } -} -impl Deref for CryptoVec { - type Target = [u8]; - fn deref(&self) -> &[u8] { - unsafe { std::slice::from_raw_parts(self.p, self.size) } - } -} -impl DerefMut for CryptoVec { - fn deref_mut(&mut self) -> &mut [u8] { - unsafe { std::slice::from_raw_parts_mut(self.p, self.size) } - } -} - -impl From for CryptoVec { - fn from(e: String) -> Self { - CryptoVec::from(e.into_bytes()) - } -} - -impl From> for CryptoVec { - fn from(e: Vec) -> Self { - let mut c = CryptoVec::new_zeroed(e.len()); - c.clone_from_slice(&e[..]); - c - } -} - -impl Index> for CryptoVec { - type Output = [u8]; - fn index(&self, index: RangeFrom) -> &[u8] { - self.deref().index(index) - } -} -impl Index> for CryptoVec { - type Output = [u8]; - fn index(&self, index: RangeTo) -> &[u8] { - self.deref().index(index) - } -} -impl Index> for CryptoVec { - type Output = [u8]; - fn index(&self, index: Range) -> &[u8] { - self.deref().index(index) - } -} -impl Index for CryptoVec { - type Output = [u8]; - fn index(&self, _: RangeFull) -> &[u8] { - self.deref() - } -} -impl IndexMut for CryptoVec { - fn index_mut(&mut self, _: RangeFull) -> &mut [u8] { - self.deref_mut() - } -} - -impl IndexMut> for CryptoVec { - fn index_mut(&mut self, index: RangeFrom) -> &mut [u8] { - self.deref_mut().index_mut(index) - } -} -impl IndexMut> for CryptoVec { - fn index_mut(&mut self, index: RangeTo) -> &mut [u8] { - self.deref_mut().index_mut(index) - } -} -impl IndexMut> for CryptoVec { - fn index_mut(&mut self, index: Range) -> &mut [u8] { - self.deref_mut().index_mut(index) - } -} - -impl Index for CryptoVec { - type Output = u8; - fn index(&self, index: usize) -> &u8 { - self.deref().index(index) - } -} - -impl std::io::Write for CryptoVec { - fn write(&mut self, buf: &[u8]) -> Result { - self.extend(buf); - Ok(buf.len()) - } - fn flush(&mut self) -> Result<(), std::io::Error> { - Ok(()) - } -} - -impl Default for CryptoVec { - fn default() -> Self { - CryptoVec { - p: std::ptr::NonNull::dangling().as_ptr(), - size: 0, - capacity: 0, - } - } -} - -#[cfg(not(windows))] -unsafe fn mlock(ptr: *const u8, len: usize) { - libc::mlock(ptr as *const c_void, len as size_t); -} -#[cfg(not(windows))] -unsafe fn munlock(ptr: *const u8, len: usize) { - libc::munlock(ptr as *const c_void, len as size_t); -} - -#[cfg(windows)] -use winapi::shared::basetsd::SIZE_T; -#[cfg(windows)] -use winapi::shared::minwindef::LPVOID; -#[cfg(windows)] -use winapi::um::memoryapi::{VirtualLock, VirtualUnlock}; -#[cfg(windows)] -unsafe fn mlock(ptr: *const u8, len: usize) { - VirtualLock(ptr as LPVOID, len as SIZE_T); -} -#[cfg(windows)] -unsafe fn munlock(ptr: *const u8, len: usize) { - VirtualUnlock(ptr as LPVOID, len as SIZE_T); -} - -impl Clone for CryptoVec { - fn clone(&self) -> Self { - let mut v = Self::new(); - v.extend(self); - v - } -} - -impl CryptoVec { - /// Creates a new `CryptoVec`. - pub fn new() -> CryptoVec { - CryptoVec::default() - } - - /// Creates a new `CryptoVec` with `n` zeros. - pub fn new_zeroed(size: usize) -> CryptoVec { - unsafe { - let capacity = size.next_power_of_two(); - let layout = std::alloc::Layout::from_size_align_unchecked(capacity, 1); - let p = std::alloc::alloc_zeroed(layout); - mlock(p, capacity); - CryptoVec { p, capacity, size } - } - } - - /// Creates a new `CryptoVec` with capacity `capacity`. - pub fn with_capacity(capacity: usize) -> CryptoVec { - unsafe { - let capacity = capacity.next_power_of_two(); - let layout = std::alloc::Layout::from_size_align_unchecked(capacity, 1); - let p = std::alloc::alloc_zeroed(layout); - mlock(p, capacity); - CryptoVec { - p, - capacity, - size: 0, - } - } - } - - /// Length of this `CryptoVec`. - /// - /// ``` - /// assert_eq!(russh_cryptovec::CryptoVec::new().len(), 0) - /// ``` - pub fn len(&self) -> usize { - self.size - } - - /// Returns `true` if and only if this CryptoVec is empty. - /// - /// ``` - /// assert!(russh_cryptovec::CryptoVec::new().is_empty()) - /// ``` - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - - /// Resize this CryptoVec, appending zeros at the end. This may - /// perform at most one reallocation, overwriting the previous - /// version with zeros. - pub fn resize(&mut self, size: usize) { - if size <= self.capacity && size > self.size { - // If this is an expansion, just resize. - self.size = size - } else if size <= self.size { - // If this is a truncation, resize and erase the extra memory. - unsafe { - libc::memset(self.p.add(size) as *mut c_void, 0, self.size - size); - } - self.size = size; - } else { - // realloc ! and erase the previous memory. - unsafe { - let next_capacity = size.next_power_of_two(); - let old_ptr = self.p; - let next_layout = std::alloc::Layout::from_size_align_unchecked(next_capacity, 1); - self.p = std::alloc::alloc_zeroed(next_layout); - mlock(self.p, next_capacity); - - if self.capacity > 0 { - std::ptr::copy_nonoverlapping(old_ptr, self.p, self.size); - for i in 0..self.size { - std::ptr::write_volatile(old_ptr.add(i), 0) - } - munlock(old_ptr, self.capacity); - let layout = std::alloc::Layout::from_size_align_unchecked(self.capacity, 1); - std::alloc::dealloc(old_ptr, layout); - } - - if self.p.is_null() { - #[allow(clippy::panic)] - { - panic!("Realloc failed, pointer = {:?} {:?}", self, size) - } - } else { - self.capacity = next_capacity; - self.size = size; - } - } - } - } - - /// Clear this CryptoVec (retaining the memory). - /// - /// ``` - /// let mut v = russh_cryptovec::CryptoVec::new(); - /// v.extend(b"blabla"); - /// v.clear(); - /// assert!(v.is_empty()) - /// ``` - pub fn clear(&mut self) { - self.resize(0); - } - - /// Append a new byte at the end of this CryptoVec. - pub fn push(&mut self, s: u8) { - let size = self.size; - self.resize(size + 1); - unsafe { *self.p.add(size) = s } - } - - /// Append a new u32, big endian-encoded, at the end of this CryptoVec. - /// - /// ``` - /// let mut v = russh_cryptovec::CryptoVec::new(); - /// let n = 43554; - /// v.push_u32_be(n); - /// assert_eq!(n, v.read_u32_be(0)) - /// ``` - pub fn push_u32_be(&mut self, s: u32) { - let s = s.to_be(); - let x: [u8; 4] = s.to_ne_bytes(); - self.extend(&x) - } - - /// Read a big endian-encoded u32 from this CryptoVec, with the - /// first byte at position `i`. - /// - /// ``` - /// let mut v = russh_cryptovec::CryptoVec::new(); - /// let n = 99485710; - /// v.push_u32_be(n); - /// assert_eq!(n, v.read_u32_be(0)) - /// ``` - pub fn read_u32_be(&self, i: usize) -> u32 { - assert!(i + 4 <= self.size); - let mut x: u32 = 0; - unsafe { - libc::memcpy( - (&mut x) as *mut u32 as *mut c_void, - self.p.add(i) as *const c_void, - 4, - ); - } - u32::from_be(x) - } - - /// Read `n_bytes` from `r`, and append them at the end of this - /// `CryptoVec`. Returns the number of bytes read (and appended). - pub fn read( - &mut self, - n_bytes: usize, - mut r: R, - ) -> Result { - let cur_size = self.size; - self.resize(cur_size + n_bytes); - let s = unsafe { std::slice::from_raw_parts_mut(self.p.add(cur_size), n_bytes) }; - // Resize the buffer to its appropriate size. - match r.read(s) { - Ok(n) => { - self.resize(cur_size + n); - Ok(n) - } - Err(e) => { - self.resize(cur_size); - Err(e) - } - } - } - - /// Write all this CryptoVec to the provided `Write`. Returns the - /// number of bytes actually written. - /// - /// ``` - /// let mut v = russh_cryptovec::CryptoVec::new(); - /// v.extend(b"blabla"); - /// let mut s = std::io::stdout(); - /// v.write_all_from(0, &mut s).unwrap(); - /// ``` - pub fn write_all_from( - &self, - offset: usize, - mut w: W, - ) -> Result { - assert!(offset < self.size); - // if we're past this point, self.p cannot be null. - unsafe { - let s = std::slice::from_raw_parts(self.p.add(offset), self.size - offset); - w.write(s) - } - } - - /// Resize this CryptoVec, returning a mutable borrow to the extra bytes. - /// - /// ``` - /// let mut v = russh_cryptovec::CryptoVec::new(); - /// v.resize_mut(4).clone_from_slice(b"test"); - /// ``` - pub fn resize_mut(&mut self, n: usize) -> &mut [u8] { - let size = self.size; - self.resize(size + n); - unsafe { std::slice::from_raw_parts_mut(self.p.add(size), n) } - } - - /// Append a slice at the end of this CryptoVec. - /// - /// ``` - /// let mut v = russh_cryptovec::CryptoVec::new(); - /// v.extend(b"test"); - /// ``` - pub fn extend(&mut self, s: &[u8]) { - let size = self.size; - self.resize(size + s.len()); - unsafe { - std::ptr::copy_nonoverlapping(s.as_ptr(), self.p.add(size), s.len()); - } - } - - /// Create a `CryptoVec` from a slice - /// - /// ``` - /// russh_cryptovec::CryptoVec::from_slice(b"test"); - /// ``` - pub fn from_slice(s: &[u8]) -> CryptoVec { - let mut v = CryptoVec::new(); - v.resize(s.len()); - unsafe { - std::ptr::copy_nonoverlapping(s.as_ptr(), v.p, s.len()); - } - v - } -} - -impl Drop for CryptoVec { - fn drop(&mut self) { - if self.capacity > 0 { - unsafe { - for i in 0..self.size { - std::ptr::write_volatile(self.p.add(i), 0) - } - munlock(self.p, self.capacity); - let layout = std::alloc::Layout::from_size_align_unchecked(self.capacity, 1); - std::alloc::dealloc(self.p, layout); - } - } - } -} +// Platform-specific modules +mod platform; \ No newline at end of file diff --git a/cryptovec/src/platform/mod.rs b/cryptovec/src/platform/mod.rs new file mode 100644 index 00000000..ab23326c --- /dev/null +++ b/cryptovec/src/platform/mod.rs @@ -0,0 +1,41 @@ +#[cfg(windows)] +mod windows; + +#[cfg(not(windows))] +#[cfg(not(target_arch = "wasm32"))] +mod unix; + +#[cfg(target_arch = "wasm32")] +mod wasm; + +// Re-export functions based on the platform +#[cfg(windows)] +pub use windows::{munlock,mlock,memset}; + +#[cfg(not(windows))] +#[cfg(not(target_arch = "wasm32"))] +pub use unix::{munlock, mlock, memset}; + +#[cfg(target_arch = "wasm32")] +pub use wasm::{munlock, mlock, memset}; + +#[cfg(test)] +mod tests { + use wasm_bindgen_test::wasm_bindgen_test; + + use super::*; + + #[wasm_bindgen_test] + fn test_memset() { + let mut buf = vec![0u8; 10]; + memset(buf.as_mut_ptr(), 0xff, buf.len()); + assert_eq!(buf, vec![0xff; 10]); + } + + #[wasm_bindgen_test] + fn test_memset_partial() { + let mut buf = vec![0u8; 10]; + memset(buf.as_mut_ptr(), 0xff, 5); + assert_eq!(buf, [0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0]); + } +} \ No newline at end of file diff --git a/cryptovec/src/platform/unix.rs b/cryptovec/src/platform/unix.rs new file mode 100644 index 00000000..45707c0f --- /dev/null +++ b/cryptovec/src/platform/unix.rs @@ -0,0 +1,27 @@ +use crate::CryptoVec; +use libc::{mlock, munlock, c_void}; +use std::alloc; + + +/// Unlock memory on drop for Unix-based systems. +pub fn munlock(ptr: *const u8, len: usize) { + unsafe { + if munlock(ptr as *const c_void, len) != 0 { + panic!("Failed to unlock memory."); + } + } +} + +pub fn mlock (ptr: *const u8, len: usize) { + unsafe { + if mlock(ptr as *const c_void, len) != 0 { + panic!("Failed to lock memory."); + } + } +} + +pub fn memset(ptr: *mut u8, value: i32, size: usize) { + unsafe { + libc::memset(ptr as *mut c_void, value, size); + } +} \ No newline at end of file diff --git a/cryptovec/src/platform/wasm.rs b/cryptovec/src/platform/wasm.rs new file mode 100644 index 00000000..fd41c3ec --- /dev/null +++ b/cryptovec/src/platform/wasm.rs @@ -0,0 +1,15 @@ +// WASM does not support synchronization primitives +pub fn munlock(_ptr: *const u8, _len: usize) { + // No-op +} + +pub fn mlock(_ptr: *const u8, _len: usize) -> i32 { + 0 +} + +pub fn memset(ptr: *mut u8, value: i32, size: usize) { + let byte_value = value as u8; // Extract the least significant byte directly + unsafe { + std::ptr::write_bytes(ptr, byte_value, size); + } +} diff --git a/cryptovec/src/platform/windows.rs b/cryptovec/src/platform/windows.rs new file mode 100644 index 00000000..930dbbf5 --- /dev/null +++ b/cryptovec/src/platform/windows.rs @@ -0,0 +1,23 @@ +use winapi::shared::{basetsd::SIZE_T, minwindef::LPVOID}; +use winapi::um::memoryapi::{VirtualLock, VirtualUnlock}; + +use libc::c_void; + +/// Unlock memory on drop for Windows. +pub fn munlock(ptr: *const u8, len: usize) { + unsafe { + VirtualUnlock(ptr as LPVOID, len as SIZE_T); + } +} + +pub fn mlock(ptr: *const u8, len: usize) { + unsafe { + VirtualLock(ptr as LPVOID, len as SIZE_T); + } +} + +pub fn memset(ptr: *mut u8, value: i32, size: usize) { + unsafe { + libc::memset(ptr as *mut c_void, value, size); + } +} \ No newline at end of file From ddfa6be9746205af397ea8a3964739c66deed465 Mon Sep 17 00:00:00 2001 From: irving ou Date: Thu, 19 Sep 2024 14:10:22 -0400 Subject: [PATCH 02/15] fix test and comments --- cryptovec/src/cryptovec.rs | 190 +++++++++++++++++++++--------- cryptovec/src/lib.rs | 15 +++ cryptovec/src/platform/mod.rs | 8 +- cryptovec/src/platform/unix.rs | 6 + cryptovec/src/platform/wasm.rs | 10 ++ cryptovec/src/platform/windows.rs | 6 + 6 files changed, 176 insertions(+), 59 deletions(-) diff --git a/cryptovec/src/cryptovec.rs b/cryptovec/src/cryptovec.rs index 0e7b63f5..ec1c3446 100644 --- a/cryptovec/src/cryptovec.rs +++ b/cryptovec/src/cryptovec.rs @@ -1,4 +1,4 @@ -use crate::platform; +use crate::platform::{self, memcpy, memset, mlock, munlock}; use std::ops::{Deref, DerefMut, Index, IndexMut, Range, RangeFrom, RangeFull, RangeTo}; /// A buffer which zeroes its memory on `.clear()`, `.resize()`, and @@ -131,7 +131,6 @@ impl Default for CryptoVec { } } -// Memory management methods impl CryptoVec { /// Creates a new `CryptoVec`. pub fn new() -> CryptoVec { @@ -144,11 +143,44 @@ impl CryptoVec { let capacity = size.next_power_of_two(); let layout = std::alloc::Layout::from_size_align_unchecked(capacity, 1); let p = std::alloc::alloc_zeroed(layout); - platform::mlock(p, capacity); + mlock(p, capacity); CryptoVec { p, capacity, size } } } + /// Creates a new `CryptoVec` with capacity `capacity`. + pub fn with_capacity(capacity: usize) -> CryptoVec { + unsafe { + let capacity = capacity.next_power_of_two(); + let layout = std::alloc::Layout::from_size_align_unchecked(capacity, 1); + let p = std::alloc::alloc_zeroed(layout); + mlock(p, capacity); + CryptoVec { + p, + capacity, + size: 0, + } + } + } + + /// Length of this `CryptoVec`. + /// + /// ``` + /// assert_eq!(russh_cryptovec::CryptoVec::new().len(), 0) + /// ``` + pub fn len(&self) -> usize { + self.size + } + + /// Returns `true` if and only if this CryptoVec is empty. + /// + /// ``` + /// assert!(russh_cryptovec::CryptoVec::new().is_empty()) + /// ``` + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + /// Resize this CryptoVec, appending zeros at the end. This may /// perform at most one reallocation, overwriting the previous /// version with zeros. @@ -157,8 +189,9 @@ impl CryptoVec { // If this is an expansion, just resize. self.size = size } else if size <= self.size { + // If this is a truncation, resize and erase the extra memory. unsafe { - platform::memset(self.p.add(size), 0, self.size - size); + memset(self.p.add(size), 0, self.size - size); } self.size = size; } else { @@ -167,66 +200,115 @@ impl CryptoVec { let next_capacity = size.next_power_of_two(); let old_ptr = self.p; let next_layout = std::alloc::Layout::from_size_align_unchecked(next_capacity, 1); - let new_ptr = std::alloc::alloc_zeroed(next_layout); - if new_ptr.is_null() { - #[allow(clippy::panic)] - { - panic!("Realloc failed, pointer = {:?} {:?}", self, size) - } - } - - self.p = new_ptr; - platform::mlock(self.p, next_capacity); + self.p = std::alloc::alloc_zeroed(next_layout); + mlock(self.p, next_capacity); if self.capacity > 0 { std::ptr::copy_nonoverlapping(old_ptr, self.p, self.size); for i in 0..self.size { std::ptr::write_volatile(old_ptr.add(i), 0) } - platform::munlock(old_ptr, self.capacity); + munlock(old_ptr, self.capacity); let layout = std::alloc::Layout::from_size_align_unchecked(self.capacity, 1); std::alloc::dealloc(old_ptr, layout); } - self.capacity = next_capacity; - self.size = size; + if self.p.is_null() { + #[allow(clippy::panic)] + { + panic!("Realloc failed, pointer = {:?} {:?}", self, size) + } + } else { + self.capacity = next_capacity; + self.size = size; + } } } } + /// Clear this CryptoVec (retaining the memory). + /// + /// ``` + /// let mut v = russh_cryptovec::CryptoVec::new(); + /// v.extend(b"blabla"); + /// v.clear(); + /// assert!(v.is_empty()) + /// ``` pub fn clear(&mut self) { self.resize(0); } + /// Append a new byte at the end of this CryptoVec. pub fn push(&mut self, s: u8) { let size = self.size; self.resize(size + 1); unsafe { *self.p.add(size) = s } } - pub fn extend(&mut self, s: &[u8]) { - let size = self.size; - self.resize(size + s.len()); - unsafe { - std::ptr::copy_nonoverlapping(s.as_ptr(), self.p.add(size), s.len()); - } - } - + /// Append a new u32, big endian-encoded, at the end of this CryptoVec. + /// + /// ``` + /// let mut v = russh_cryptovec::CryptoVec::new(); + /// let n = 43554; + /// v.push_u32_be(n); + /// assert_eq!(n, v.read_u32_be(0)) + /// ``` pub fn push_u32_be(&mut self, s: u32) { let s = s.to_be(); let x: [u8; 4] = s.to_ne_bytes(); self.extend(&x) } + /// Read a big endian-encoded u32 from this CryptoVec, with the + /// first byte at position `i`. + /// + /// ``` + /// let mut v = russh_cryptovec::CryptoVec::new(); + /// let n = 99485710; + /// v.push_u32_be(n); + /// assert_eq!(n, v.read_u32_be(0)) + /// ``` pub fn read_u32_be(&self, i: usize) -> u32 { assert!(i + 4 <= self.size); let mut x: u32 = 0; unsafe { - std::ptr::copy_nonoverlapping(self.p.add(i) as *const u32, &mut x as *mut u32, 1); + memcpy((&mut x) as *mut u32, self.p.add(i), 4); } u32::from_be(x) } + /// Read `n_bytes` from `r`, and append them at the end of this + /// `CryptoVec`. Returns the number of bytes read (and appended). + pub fn read( + &mut self, + n_bytes: usize, + mut r: R, + ) -> Result { + let cur_size = self.size; + self.resize(cur_size + n_bytes); + let s = unsafe { std::slice::from_raw_parts_mut(self.p.add(cur_size), n_bytes) }; + // Resize the buffer to its appropriate size. + match r.read(s) { + Ok(n) => { + self.resize(cur_size + n); + Ok(n) + } + Err(e) => { + self.resize(cur_size); + Err(e) + } + } + } + + /// Write all this CryptoVec to the provided `Write`. Returns the + /// number of bytes actually written. + /// + /// ``` + /// let mut v = russh_cryptovec::CryptoVec::new(); + /// v.extend(b"blabla"); + /// let mut s = std::io::stdout(); + /// v.write_all_from(0, &mut s).unwrap(); + /// ``` pub fn write_all_from( &self, offset: usize, @@ -240,12 +322,37 @@ impl CryptoVec { } } + /// Resize this CryptoVec, returning a mutable borrow to the extra bytes. + /// + /// ``` + /// let mut v = russh_cryptovec::CryptoVec::new(); + /// v.resize_mut(4).clone_from_slice(b"test"); + /// ``` pub fn resize_mut(&mut self, n: usize) -> &mut [u8] { let size = self.size; self.resize(size + n); unsafe { std::slice::from_raw_parts_mut(self.p.add(size), n) } } + /// Append a slice at the end of this CryptoVec. + /// + /// ``` + /// let mut v = russh_cryptovec::CryptoVec::new(); + /// v.extend(b"test"); + /// ``` + pub fn extend(&mut self, s: &[u8]) { + let size = self.size; + self.resize(size + s.len()); + unsafe { + std::ptr::copy_nonoverlapping(s.as_ptr(), self.p.add(size), s.len()); + } + } + + /// Create a `CryptoVec` from a slice + /// + /// ``` + /// russh_cryptovec::CryptoVec::from_slice(b"test"); + /// ``` pub fn from_slice(s: &[u8]) -> CryptoVec { let mut v = CryptoVec::new(); v.resize(s.len()); @@ -280,17 +387,17 @@ impl Drop for CryptoVec { } } +// DocTests cannot be run on with wasm_bindgen_test #[cfg(test)] +#[cfg(target_arch = "wasm32")] mod test { use wasm_bindgen_test::wasm_bindgen_test; use super::CryptoVec; - #[cfg(target_arch = "wasm32")] wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser); - #[test] #[wasm_bindgen_test] fn test_new() { let crypto_vec = CryptoVec::new(); @@ -298,7 +405,6 @@ mod test { assert_eq!(crypto_vec.capacity, 0); } - #[test] #[wasm_bindgen_test] fn test_resize_expand() { let mut crypto_vec = CryptoVec::new_zeroed(5); @@ -308,7 +414,6 @@ mod test { assert!(crypto_vec.iter().skip(5).all(|&x| x == 0)); // Ensure newly added elements are zeroed } - #[test] #[wasm_bindgen_test] fn test_resize_shrink() { let mut crypto_vec = CryptoVec::new_zeroed(10); @@ -318,7 +423,6 @@ mod test { assert_eq!(crypto_vec.len(), 5); } - #[test] #[wasm_bindgen_test] fn test_push() { let mut crypto_vec = CryptoVec::new(); @@ -329,7 +433,6 @@ mod test { assert_eq!(crypto_vec[1], 2); } - #[test] #[wasm_bindgen_test] fn test_write_trait() { use std::io::Write; @@ -341,7 +444,6 @@ mod test { assert_eq!(crypto_vec.as_ref(), &[1, 2, 3]); } - #[test] #[wasm_bindgen_test] fn test_as_ref_as_mut() { let mut crypto_vec = CryptoVec::new_zeroed(5); @@ -352,7 +454,6 @@ mod test { assert_eq!(crypto_vec[0], 1); } - #[test] #[wasm_bindgen_test] fn test_from_string() { let input = String::from("hello"); @@ -360,7 +461,6 @@ mod test { assert_eq!(crypto_vec.as_ref(), b"hello"); } - #[test] #[wasm_bindgen_test] fn test_from_vec() { let input = vec![1, 2, 3, 4]; @@ -368,7 +468,6 @@ mod test { assert_eq!(crypto_vec.as_ref(), &[1, 2, 3, 4]); } - #[test] #[wasm_bindgen_test] fn test_index() { let crypto_vec = CryptoVec::from(vec![1, 2, 3, 4, 5]); @@ -377,7 +476,6 @@ mod test { assert_eq!(&crypto_vec[1..3], &[2, 3]); } - #[test] #[wasm_bindgen_test] fn test_drop() { let mut crypto_vec = CryptoVec::new_zeroed(10); @@ -391,7 +489,6 @@ mod test { // it may be checked using tools like Valgrind or manual inspection. } - #[test] #[wasm_bindgen_test] fn test_new_zeroed() { let crypto_vec = CryptoVec::new_zeroed(10); @@ -400,7 +497,6 @@ mod test { assert!(crypto_vec.iter().all(|&x| x == 0)); // Ensure all bytes are zeroed } - #[test] #[wasm_bindgen_test] fn test_push_u32_be() { let mut crypto_vec = CryptoVec::new(); @@ -410,7 +506,6 @@ mod test { assert_eq!(crypto_vec.read_u32_be(0), value); } - #[test] #[wasm_bindgen_test] fn test_read_u32_be() { @@ -420,7 +515,6 @@ mod test { assert_eq!(crypto_vec.read_u32_be(0), value); } - #[test] #[wasm_bindgen_test] fn test_clear() { let mut crypto_vec = CryptoVec::new(); @@ -429,7 +523,6 @@ mod test { assert!(crypto_vec.is_empty()); } - #[test] #[wasm_bindgen_test] fn test_extend() { let mut crypto_vec = CryptoVec::new(); @@ -437,7 +530,6 @@ mod test { assert_eq!(crypto_vec.as_ref(), b"test"); } - #[test] #[wasm_bindgen_test] fn test_write_all_from() { let mut crypto_vec = CryptoVec::new(); @@ -449,7 +541,6 @@ mod test { assert_eq!(output, b"blabla"); } - #[test] #[wasm_bindgen_test] fn test_resize_mut() { let mut crypto_vec = CryptoVec::new(); @@ -457,15 +548,4 @@ mod test { assert_eq!(crypto_vec.as_ref(), b"test"); } - #[cfg(target_pointer_width = "64")] - #[test] - fn test_large_resize_panics() { - let result = std::panic::catch_unwind(|| { - let mut vec = CryptoVec::new(); - vec.push(42); // Write something into the vector - - vec.resize(1_000_000_000_000); // Intentionally large resize - }); - assert!(result.is_err()); // Expecting a panic on large allocation - } } diff --git a/cryptovec/src/lib.rs b/cryptovec/src/lib.rs index 056b4f55..9403987e 100644 --- a/cryptovec/src/lib.rs +++ b/cryptovec/src/lib.rs @@ -5,6 +5,21 @@ clippy::panic )] +// Copyright 2016 Pierre-Étienne Meunier +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + // Re-export CryptoVec from the cryptovec module mod cryptovec; pub use cryptovec::CryptoVec; diff --git a/cryptovec/src/platform/mod.rs b/cryptovec/src/platform/mod.rs index ab23326c..63b715a7 100644 --- a/cryptovec/src/platform/mod.rs +++ b/cryptovec/src/platform/mod.rs @@ -10,14 +10,14 @@ mod wasm; // Re-export functions based on the platform #[cfg(windows)] -pub use windows::{munlock,mlock,memset}; +pub use windows::{memcpy, memset, mlock, munlock}; #[cfg(not(windows))] #[cfg(not(target_arch = "wasm32"))] -pub use unix::{munlock, mlock, memset}; +pub use unix::{memcpy, memset, mlock, munlock}; #[cfg(target_arch = "wasm32")] -pub use wasm::{munlock, mlock, memset}; +pub use wasm::{memcpy, memset, mlock, munlock}; #[cfg(test)] mod tests { @@ -38,4 +38,4 @@ mod tests { memset(buf.as_mut_ptr(), 0xff, 5); assert_eq!(buf, [0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0]); } -} \ No newline at end of file +} diff --git a/cryptovec/src/platform/unix.rs b/cryptovec/src/platform/unix.rs index 45707c0f..1c5ab44a 100644 --- a/cryptovec/src/platform/unix.rs +++ b/cryptovec/src/platform/unix.rs @@ -24,4 +24,10 @@ pub fn memset(ptr: *mut u8, value: i32, size: usize) { unsafe { libc::memset(ptr as *mut c_void, value, size); } +} + +pub fn memcpy(dest: *mut u32, src: *const u8, size: usize) { + unsafe { + libc::memcpy(dest as *mut c_void, src as *const c_void, size); + } } \ No newline at end of file diff --git a/cryptovec/src/platform/wasm.rs b/cryptovec/src/platform/wasm.rs index fd41c3ec..d1d192e5 100644 --- a/cryptovec/src/platform/wasm.rs +++ b/cryptovec/src/platform/wasm.rs @@ -13,3 +13,13 @@ pub fn memset(ptr: *mut u8, value: i32, size: usize) { std::ptr::write_bytes(ptr, byte_value, size); } } + +pub fn memcpy(dest: *mut u32, src: *const u8, size: usize) { + unsafe { + // Convert dest to *mut u8 for byte-wise copying + let dest_bytes = dest as *mut u8; + + // Use std::ptr::copy_nonoverlapping to copy the data + std::ptr::copy_nonoverlapping(src, dest_bytes, size); + } +} diff --git a/cryptovec/src/platform/windows.rs b/cryptovec/src/platform/windows.rs index 930dbbf5..ae03018c 100644 --- a/cryptovec/src/platform/windows.rs +++ b/cryptovec/src/platform/windows.rs @@ -20,4 +20,10 @@ pub fn memset(ptr: *mut u8, value: i32, size: usize) { unsafe { libc::memset(ptr as *mut c_void, value, size); } +} + +pub fn memcpy(dest: *mut u32, src: *const u8, size: usize) { + unsafe { + libc::memcpy(dest as *mut c_void, src as *const c_void, size); + } } \ No newline at end of file From f1f8d579fab57cb15f4d9d3ac3826246b01375bb Mon Sep 17 00:00:00 2001 From: irving ou Date: Thu, 19 Sep 2024 13:57:19 -0400 Subject: [PATCH 03/15] feat(wasm): Add wasm support, update cryptovec --- cryptovec/Cargo.toml | 3 + cryptovec/src/cryptovec.rs | 471 ++++++++++++++++++++++++++++++ cryptovec/src/lib.rs | 449 +--------------------------- cryptovec/src/platform/mod.rs | 41 +++ cryptovec/src/platform/unix.rs | 27 ++ cryptovec/src/platform/wasm.rs | 15 + cryptovec/src/platform/windows.rs | 23 ++ 7 files changed, 585 insertions(+), 444 deletions(-) create mode 100644 cryptovec/src/cryptovec.rs create mode 100644 cryptovec/src/platform/mod.rs create mode 100644 cryptovec/src/platform/unix.rs create mode 100644 cryptovec/src/platform/wasm.rs create mode 100644 cryptovec/src/platform/windows.rs diff --git a/cryptovec/Cargo.toml b/cryptovec/Cargo.toml index e04cd6f5..fafe1c8d 100644 --- a/cryptovec/Cargo.toml +++ b/cryptovec/Cargo.toml @@ -15,3 +15,6 @@ libc = "0.2" [target.'cfg(target_os = "windows")'.dependencies] winapi = {version = "0.3", features = ["basetsd", "minwindef", "memoryapi"]} + +[dev-dependencies] +wasm-bindgen-test = "0.3" \ No newline at end of file diff --git a/cryptovec/src/cryptovec.rs b/cryptovec/src/cryptovec.rs new file mode 100644 index 00000000..0e7b63f5 --- /dev/null +++ b/cryptovec/src/cryptovec.rs @@ -0,0 +1,471 @@ +use crate::platform; +use std::ops::{Deref, DerefMut, Index, IndexMut, Range, RangeFrom, RangeFull, RangeTo}; + +/// A buffer which zeroes its memory on `.clear()`, `.resize()`, and +/// reallocations, to avoid copying secrets around. +#[derive(Debug)] +pub struct CryptoVec { + p: *mut u8, // `pub(crate)` allows access from platform modules + size: usize, + capacity: usize, +} + +impl Unpin for CryptoVec {} +unsafe impl Send for CryptoVec {} +unsafe impl Sync for CryptoVec {} + +// Common traits implementations +impl AsRef<[u8]> for CryptoVec { + fn as_ref(&self) -> &[u8] { + self.deref() + } +} + +impl AsMut<[u8]> for CryptoVec { + fn as_mut(&mut self) -> &mut [u8] { + self.deref_mut() + } +} + +impl Deref for CryptoVec { + type Target = [u8]; + fn deref(&self) -> &[u8] { + unsafe { std::slice::from_raw_parts(self.p, self.size) } + } +} + +impl DerefMut for CryptoVec { + fn deref_mut(&mut self) -> &mut [u8] { + unsafe { std::slice::from_raw_parts_mut(self.p, self.size) } + } +} + +impl From for CryptoVec { + fn from(e: String) -> Self { + CryptoVec::from(e.into_bytes()) + } +} + +impl From> for CryptoVec { + fn from(e: Vec) -> Self { + let mut c = CryptoVec::new_zeroed(e.len()); + c.clone_from_slice(&e[..]); + c + } +} + +// Indexing implementations +impl Index> for CryptoVec { + type Output = [u8]; + fn index(&self, index: RangeFrom) -> &[u8] { + self.deref().index(index) + } +} +impl Index> for CryptoVec { + type Output = [u8]; + fn index(&self, index: RangeTo) -> &[u8] { + self.deref().index(index) + } +} +impl Index> for CryptoVec { + type Output = [u8]; + fn index(&self, index: Range) -> &[u8] { + self.deref().index(index) + } +} +impl Index for CryptoVec { + type Output = [u8]; + fn index(&self, _: RangeFull) -> &[u8] { + self.deref() + } +} + +impl IndexMut for CryptoVec { + fn index_mut(&mut self, _: RangeFull) -> &mut [u8] { + self.deref_mut() + } +} +impl IndexMut> for CryptoVec { + fn index_mut(&mut self, index: RangeFrom) -> &mut [u8] { + self.deref_mut().index_mut(index) + } +} +impl IndexMut> for CryptoVec { + fn index_mut(&mut self, index: RangeTo) -> &mut [u8] { + self.deref_mut().index_mut(index) + } +} +impl IndexMut> for CryptoVec { + fn index_mut(&mut self, index: Range) -> &mut [u8] { + self.deref_mut().index_mut(index) + } +} + +impl Index for CryptoVec { + type Output = u8; + fn index(&self, index: usize) -> &u8 { + self.deref().index(index) + } +} + +// IO-related implementation +impl std::io::Write for CryptoVec { + fn write(&mut self, buf: &[u8]) -> Result { + self.extend(buf); + Ok(buf.len()) + } + + fn flush(&mut self) -> Result<(), std::io::Error> { + Ok(()) + } +} + +// Default implementation +impl Default for CryptoVec { + fn default() -> Self { + CryptoVec { + p: std::ptr::NonNull::dangling().as_ptr(), + size: 0, + capacity: 0, + } + } +} + +// Memory management methods +impl CryptoVec { + /// Creates a new `CryptoVec`. + pub fn new() -> CryptoVec { + CryptoVec::default() + } + + /// Creates a new `CryptoVec` with `n` zeros. + pub fn new_zeroed(size: usize) -> CryptoVec { + unsafe { + let capacity = size.next_power_of_two(); + let layout = std::alloc::Layout::from_size_align_unchecked(capacity, 1); + let p = std::alloc::alloc_zeroed(layout); + platform::mlock(p, capacity); + CryptoVec { p, capacity, size } + } + } + + /// Resize this CryptoVec, appending zeros at the end. This may + /// perform at most one reallocation, overwriting the previous + /// version with zeros. + pub fn resize(&mut self, size: usize) { + if size <= self.capacity && size > self.size { + // If this is an expansion, just resize. + self.size = size + } else if size <= self.size { + unsafe { + platform::memset(self.p.add(size), 0, self.size - size); + } + self.size = size; + } else { + // realloc ! and erase the previous memory. + unsafe { + let next_capacity = size.next_power_of_two(); + let old_ptr = self.p; + let next_layout = std::alloc::Layout::from_size_align_unchecked(next_capacity, 1); + let new_ptr = std::alloc::alloc_zeroed(next_layout); + if new_ptr.is_null() { + #[allow(clippy::panic)] + { + panic!("Realloc failed, pointer = {:?} {:?}", self, size) + } + } + + self.p = new_ptr; + platform::mlock(self.p, next_capacity); + + if self.capacity > 0 { + std::ptr::copy_nonoverlapping(old_ptr, self.p, self.size); + for i in 0..self.size { + std::ptr::write_volatile(old_ptr.add(i), 0) + } + platform::munlock(old_ptr, self.capacity); + let layout = std::alloc::Layout::from_size_align_unchecked(self.capacity, 1); + std::alloc::dealloc(old_ptr, layout); + } + + self.capacity = next_capacity; + self.size = size; + } + } + } + + pub fn clear(&mut self) { + self.resize(0); + } + + pub fn push(&mut self, s: u8) { + let size = self.size; + self.resize(size + 1); + unsafe { *self.p.add(size) = s } + } + + pub fn extend(&mut self, s: &[u8]) { + let size = self.size; + self.resize(size + s.len()); + unsafe { + std::ptr::copy_nonoverlapping(s.as_ptr(), self.p.add(size), s.len()); + } + } + + pub fn push_u32_be(&mut self, s: u32) { + let s = s.to_be(); + let x: [u8; 4] = s.to_ne_bytes(); + self.extend(&x) + } + + pub fn read_u32_be(&self, i: usize) -> u32 { + assert!(i + 4 <= self.size); + let mut x: u32 = 0; + unsafe { + std::ptr::copy_nonoverlapping(self.p.add(i) as *const u32, &mut x as *mut u32, 1); + } + u32::from_be(x) + } + + pub fn write_all_from( + &self, + offset: usize, + mut w: W, + ) -> Result { + assert!(offset < self.size); + // if we're past this point, self.p cannot be null. + unsafe { + let s = std::slice::from_raw_parts(self.p.add(offset), self.size - offset); + w.write(s) + } + } + + pub fn resize_mut(&mut self, n: usize) -> &mut [u8] { + let size = self.size; + self.resize(size + n); + unsafe { std::slice::from_raw_parts_mut(self.p.add(size), n) } + } + + pub fn from_slice(s: &[u8]) -> CryptoVec { + let mut v = CryptoVec::new(); + v.resize(s.len()); + unsafe { + std::ptr::copy_nonoverlapping(s.as_ptr(), v.p, s.len()); + } + v + } +} + +impl Clone for CryptoVec { + fn clone(&self) -> Self { + let mut v = Self::new(); + v.extend(self); + v + } +} + +// Drop implementation +impl Drop for CryptoVec { + fn drop(&mut self) { + if self.capacity > 0 { + unsafe { + for i in 0..self.size { + std::ptr::write_volatile(self.p.add(i), 0); + } + platform::munlock(self.p, self.capacity); + let layout = std::alloc::Layout::from_size_align_unchecked(self.capacity, 1); + std::alloc::dealloc(self.p, layout); + } + } + } +} + +#[cfg(test)] +mod test { + + use wasm_bindgen_test::wasm_bindgen_test; + + use super::CryptoVec; + + #[cfg(target_arch = "wasm32")] + wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser); + + #[test] + #[wasm_bindgen_test] + fn test_new() { + let crypto_vec = CryptoVec::new(); + assert_eq!(crypto_vec.size, 0); + assert_eq!(crypto_vec.capacity, 0); + } + + #[test] + #[wasm_bindgen_test] + fn test_resize_expand() { + let mut crypto_vec = CryptoVec::new_zeroed(5); + crypto_vec.resize(10); + assert_eq!(crypto_vec.size, 10); + assert!(crypto_vec.capacity >= 10); + assert!(crypto_vec.iter().skip(5).all(|&x| x == 0)); // Ensure newly added elements are zeroed + } + + #[test] + #[wasm_bindgen_test] + fn test_resize_shrink() { + let mut crypto_vec = CryptoVec::new_zeroed(10); + crypto_vec.resize(5); + assert_eq!(crypto_vec.size, 5); + // Ensure shrinking keeps the previous elements intact + assert_eq!(crypto_vec.len(), 5); + } + + #[test] + #[wasm_bindgen_test] + fn test_push() { + let mut crypto_vec = CryptoVec::new(); + crypto_vec.push(1); + crypto_vec.push(2); + assert_eq!(crypto_vec.size, 2); + assert_eq!(crypto_vec[0], 1); + assert_eq!(crypto_vec[1], 2); + } + + #[test] + #[wasm_bindgen_test] + fn test_write_trait() { + use std::io::Write; + + let mut crypto_vec = CryptoVec::new(); + let bytes_written = crypto_vec.write(&[1, 2, 3]).unwrap(); + assert_eq!(bytes_written, 3); + assert_eq!(crypto_vec.size, 3); + assert_eq!(crypto_vec.as_ref(), &[1, 2, 3]); + } + + #[test] + #[wasm_bindgen_test] + fn test_as_ref_as_mut() { + let mut crypto_vec = CryptoVec::new_zeroed(5); + let slice_ref: &[u8] = crypto_vec.as_ref(); + assert_eq!(slice_ref.len(), 5); + let slice_mut: &mut [u8] = crypto_vec.as_mut(); + slice_mut[0] = 1; + assert_eq!(crypto_vec[0], 1); + } + + #[test] + #[wasm_bindgen_test] + fn test_from_string() { + let input = String::from("hello"); + let crypto_vec: CryptoVec = input.into(); + assert_eq!(crypto_vec.as_ref(), b"hello"); + } + + #[test] + #[wasm_bindgen_test] + fn test_from_vec() { + let input = vec![1, 2, 3, 4]; + let crypto_vec: CryptoVec = input.into(); + assert_eq!(crypto_vec.as_ref(), &[1, 2, 3, 4]); + } + + #[test] + #[wasm_bindgen_test] + fn test_index() { + let crypto_vec = CryptoVec::from(vec![1, 2, 3, 4, 5]); + assert_eq!(crypto_vec[0], 1); + assert_eq!(crypto_vec[4], 5); + assert_eq!(&crypto_vec[1..3], &[2, 3]); + } + + #[test] + #[wasm_bindgen_test] + fn test_drop() { + let mut crypto_vec = CryptoVec::new_zeroed(10); + // Ensure vector is filled with non-zero data + crypto_vec.extend(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + drop(crypto_vec); + + // Check that memory zeroing was done during the drop + // This part is more difficult to test directly since it involves + // private memory management. However, with Rust's unsafe features, + // it may be checked using tools like Valgrind or manual inspection. + } + + #[test] + #[wasm_bindgen_test] + fn test_new_zeroed() { + let crypto_vec = CryptoVec::new_zeroed(10); + assert_eq!(crypto_vec.size, 10); + assert!(crypto_vec.capacity >= 10); + assert!(crypto_vec.iter().all(|&x| x == 0)); // Ensure all bytes are zeroed + } + + #[test] + #[wasm_bindgen_test] + fn test_push_u32_be() { + let mut crypto_vec = CryptoVec::new(); + let value = 43554u32; + crypto_vec.push_u32_be(value); + assert_eq!(crypto_vec.len(), 4); // u32 is 4 bytes long + assert_eq!(crypto_vec.read_u32_be(0), value); + } + + #[test] + #[wasm_bindgen_test] + + fn test_read_u32_be() { + let mut crypto_vec = CryptoVec::new(); + let value = 99485710u32; + crypto_vec.push_u32_be(value); + assert_eq!(crypto_vec.read_u32_be(0), value); + } + + #[test] + #[wasm_bindgen_test] + fn test_clear() { + let mut crypto_vec = CryptoVec::new(); + crypto_vec.extend(b"blabla"); + crypto_vec.clear(); + assert!(crypto_vec.is_empty()); + } + + #[test] + #[wasm_bindgen_test] + fn test_extend() { + let mut crypto_vec = CryptoVec::new(); + crypto_vec.extend(b"test"); + assert_eq!(crypto_vec.as_ref(), b"test"); + } + + #[test] + #[wasm_bindgen_test] + fn test_write_all_from() { + let mut crypto_vec = CryptoVec::new(); + crypto_vec.extend(b"blabla"); + + let mut output: Vec = Vec::new(); + let written_size = crypto_vec.write_all_from(0, &mut output).unwrap(); + assert_eq!(written_size, 6); // "blabla" has 6 bytes + assert_eq!(output, b"blabla"); + } + + #[test] + #[wasm_bindgen_test] + fn test_resize_mut() { + let mut crypto_vec = CryptoVec::new(); + crypto_vec.resize_mut(4).clone_from_slice(b"test"); + assert_eq!(crypto_vec.as_ref(), b"test"); + } + + #[cfg(target_pointer_width = "64")] + #[test] + fn test_large_resize_panics() { + let result = std::panic::catch_unwind(|| { + let mut vec = CryptoVec::new(); + vec.push(42); // Write something into the vector + + vec.resize(1_000_000_000_000); // Intentionally large resize + }); + assert!(result.is_err()); // Expecting a panic on large allocation + } +} diff --git a/cryptovec/src/lib.rs b/cryptovec/src/lib.rs index 256b7fc2..056b4f55 100644 --- a/cryptovec/src/lib.rs +++ b/cryptovec/src/lib.rs @@ -4,449 +4,10 @@ clippy::indexing_slicing, clippy::panic )] -// Copyright 2016 Pierre-Étienne Meunier -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -use std::ops::{Deref, DerefMut, Index, IndexMut, Range, RangeFrom, RangeFull, RangeTo}; -use libc::c_void; -#[cfg(not(windows))] -use libc::size_t; +// Re-export CryptoVec from the cryptovec module +mod cryptovec; +pub use cryptovec::CryptoVec; -/// A buffer which zeroes its memory on `.clear()`, `.resize()` and -/// reallocations, to avoid copying secrets around. -#[derive(Debug)] -pub struct CryptoVec { - p: *mut u8, - size: usize, - capacity: usize, -} - -impl Unpin for CryptoVec {} - -unsafe impl Send for CryptoVec {} -unsafe impl Sync for CryptoVec {} - -impl AsRef<[u8]> for CryptoVec { - fn as_ref(&self) -> &[u8] { - self.deref() - } -} -impl AsMut<[u8]> for CryptoVec { - fn as_mut(&mut self) -> &mut [u8] { - self.deref_mut() - } -} -impl Deref for CryptoVec { - type Target = [u8]; - fn deref(&self) -> &[u8] { - unsafe { std::slice::from_raw_parts(self.p, self.size) } - } -} -impl DerefMut for CryptoVec { - fn deref_mut(&mut self) -> &mut [u8] { - unsafe { std::slice::from_raw_parts_mut(self.p, self.size) } - } -} - -impl From for CryptoVec { - fn from(e: String) -> Self { - CryptoVec::from(e.into_bytes()) - } -} - -impl From> for CryptoVec { - fn from(e: Vec) -> Self { - let mut c = CryptoVec::new_zeroed(e.len()); - c.clone_from_slice(&e[..]); - c - } -} - -impl Index> for CryptoVec { - type Output = [u8]; - fn index(&self, index: RangeFrom) -> &[u8] { - self.deref().index(index) - } -} -impl Index> for CryptoVec { - type Output = [u8]; - fn index(&self, index: RangeTo) -> &[u8] { - self.deref().index(index) - } -} -impl Index> for CryptoVec { - type Output = [u8]; - fn index(&self, index: Range) -> &[u8] { - self.deref().index(index) - } -} -impl Index for CryptoVec { - type Output = [u8]; - fn index(&self, _: RangeFull) -> &[u8] { - self.deref() - } -} -impl IndexMut for CryptoVec { - fn index_mut(&mut self, _: RangeFull) -> &mut [u8] { - self.deref_mut() - } -} - -impl IndexMut> for CryptoVec { - fn index_mut(&mut self, index: RangeFrom) -> &mut [u8] { - self.deref_mut().index_mut(index) - } -} -impl IndexMut> for CryptoVec { - fn index_mut(&mut self, index: RangeTo) -> &mut [u8] { - self.deref_mut().index_mut(index) - } -} -impl IndexMut> for CryptoVec { - fn index_mut(&mut self, index: Range) -> &mut [u8] { - self.deref_mut().index_mut(index) - } -} - -impl Index for CryptoVec { - type Output = u8; - fn index(&self, index: usize) -> &u8 { - self.deref().index(index) - } -} - -impl std::io::Write for CryptoVec { - fn write(&mut self, buf: &[u8]) -> Result { - self.extend(buf); - Ok(buf.len()) - } - fn flush(&mut self) -> Result<(), std::io::Error> { - Ok(()) - } -} - -impl Default for CryptoVec { - fn default() -> Self { - CryptoVec { - p: std::ptr::NonNull::dangling().as_ptr(), - size: 0, - capacity: 0, - } - } -} - -#[cfg(not(windows))] -unsafe fn mlock(ptr: *const u8, len: usize) { - libc::mlock(ptr as *const c_void, len as size_t); -} -#[cfg(not(windows))] -unsafe fn munlock(ptr: *const u8, len: usize) { - libc::munlock(ptr as *const c_void, len as size_t); -} - -#[cfg(windows)] -use winapi::shared::basetsd::SIZE_T; -#[cfg(windows)] -use winapi::shared::minwindef::LPVOID; -#[cfg(windows)] -use winapi::um::memoryapi::{VirtualLock, VirtualUnlock}; -#[cfg(windows)] -unsafe fn mlock(ptr: *const u8, len: usize) { - VirtualLock(ptr as LPVOID, len as SIZE_T); -} -#[cfg(windows)] -unsafe fn munlock(ptr: *const u8, len: usize) { - VirtualUnlock(ptr as LPVOID, len as SIZE_T); -} - -impl Clone for CryptoVec { - fn clone(&self) -> Self { - let mut v = Self::new(); - v.extend(self); - v - } -} - -impl CryptoVec { - /// Creates a new `CryptoVec`. - pub fn new() -> CryptoVec { - CryptoVec::default() - } - - /// Creates a new `CryptoVec` with `n` zeros. - pub fn new_zeroed(size: usize) -> CryptoVec { - unsafe { - let capacity = size.next_power_of_two(); - let layout = std::alloc::Layout::from_size_align_unchecked(capacity, 1); - let p = std::alloc::alloc_zeroed(layout); - mlock(p, capacity); - CryptoVec { p, capacity, size } - } - } - - /// Creates a new `CryptoVec` with capacity `capacity`. - pub fn with_capacity(capacity: usize) -> CryptoVec { - unsafe { - let capacity = capacity.next_power_of_two(); - let layout = std::alloc::Layout::from_size_align_unchecked(capacity, 1); - let p = std::alloc::alloc_zeroed(layout); - mlock(p, capacity); - CryptoVec { - p, - capacity, - size: 0, - } - } - } - - /// Length of this `CryptoVec`. - /// - /// ``` - /// assert_eq!(russh_cryptovec::CryptoVec::new().len(), 0) - /// ``` - pub fn len(&self) -> usize { - self.size - } - - /// Returns `true` if and only if this CryptoVec is empty. - /// - /// ``` - /// assert!(russh_cryptovec::CryptoVec::new().is_empty()) - /// ``` - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - - /// Resize this CryptoVec, appending zeros at the end. This may - /// perform at most one reallocation, overwriting the previous - /// version with zeros. - pub fn resize(&mut self, size: usize) { - if size <= self.capacity && size > self.size { - // If this is an expansion, just resize. - self.size = size - } else if size <= self.size { - // If this is a truncation, resize and erase the extra memory. - unsafe { - libc::memset(self.p.add(size) as *mut c_void, 0, self.size - size); - } - self.size = size; - } else { - // realloc ! and erase the previous memory. - unsafe { - let next_capacity = size.next_power_of_two(); - let old_ptr = self.p; - let next_layout = std::alloc::Layout::from_size_align_unchecked(next_capacity, 1); - let new_ptr = std::alloc::alloc_zeroed(next_layout); - if new_ptr.is_null() { - #[allow(clippy::panic)] - { - panic!("Realloc failed, pointer = {:?} {:?}", self, size) - } - } - - self.p = new_ptr; - mlock(self.p, next_capacity); - - if self.capacity > 0 { - std::ptr::copy_nonoverlapping(old_ptr, self.p, self.size); - for i in 0..self.size { - std::ptr::write_volatile(old_ptr.add(i), 0) - } - munlock(old_ptr, self.capacity); - let layout = std::alloc::Layout::from_size_align_unchecked(self.capacity, 1); - std::alloc::dealloc(old_ptr, layout); - } - - self.capacity = next_capacity; - self.size = size; - } - } - } - - /// Clear this CryptoVec (retaining the memory). - /// - /// ``` - /// let mut v = russh_cryptovec::CryptoVec::new(); - /// v.extend(b"blabla"); - /// v.clear(); - /// assert!(v.is_empty()) - /// ``` - pub fn clear(&mut self) { - self.resize(0); - } - - /// Append a new byte at the end of this CryptoVec. - pub fn push(&mut self, s: u8) { - let size = self.size; - self.resize(size + 1); - unsafe { *self.p.add(size) = s } - } - - /// Append a new u32, big endian-encoded, at the end of this CryptoVec. - /// - /// ``` - /// let mut v = russh_cryptovec::CryptoVec::new(); - /// let n = 43554; - /// v.push_u32_be(n); - /// assert_eq!(n, v.read_u32_be(0)) - /// ``` - pub fn push_u32_be(&mut self, s: u32) { - let s = s.to_be(); - let x: [u8; 4] = s.to_ne_bytes(); - self.extend(&x) - } - - /// Read a big endian-encoded u32 from this CryptoVec, with the - /// first byte at position `i`. - /// - /// ``` - /// let mut v = russh_cryptovec::CryptoVec::new(); - /// let n = 99485710; - /// v.push_u32_be(n); - /// assert_eq!(n, v.read_u32_be(0)) - /// ``` - pub fn read_u32_be(&self, i: usize) -> u32 { - assert!(i + 4 <= self.size); - let mut x: u32 = 0; - unsafe { - libc::memcpy( - (&mut x) as *mut u32 as *mut c_void, - self.p.add(i) as *const c_void, - 4, - ); - } - u32::from_be(x) - } - - /// Read `n_bytes` from `r`, and append them at the end of this - /// `CryptoVec`. Returns the number of bytes read (and appended). - pub fn read( - &mut self, - n_bytes: usize, - mut r: R, - ) -> Result { - let cur_size = self.size; - self.resize(cur_size + n_bytes); - let s = unsafe { std::slice::from_raw_parts_mut(self.p.add(cur_size), n_bytes) }; - // Resize the buffer to its appropriate size. - match r.read(s) { - Ok(n) => { - self.resize(cur_size + n); - Ok(n) - } - Err(e) => { - self.resize(cur_size); - Err(e) - } - } - } - - /// Write all this CryptoVec to the provided `Write`. Returns the - /// number of bytes actually written. - /// - /// ``` - /// let mut v = russh_cryptovec::CryptoVec::new(); - /// v.extend(b"blabla"); - /// let mut s = std::io::stdout(); - /// v.write_all_from(0, &mut s).unwrap(); - /// ``` - pub fn write_all_from( - &self, - offset: usize, - mut w: W, - ) -> Result { - assert!(offset < self.size); - // if we're past this point, self.p cannot be null. - unsafe { - let s = std::slice::from_raw_parts(self.p.add(offset), self.size - offset); - w.write(s) - } - } - - /// Resize this CryptoVec, returning a mutable borrow to the extra bytes. - /// - /// ``` - /// let mut v = russh_cryptovec::CryptoVec::new(); - /// v.resize_mut(4).clone_from_slice(b"test"); - /// ``` - pub fn resize_mut(&mut self, n: usize) -> &mut [u8] { - let size = self.size; - self.resize(size + n); - unsafe { std::slice::from_raw_parts_mut(self.p.add(size), n) } - } - - /// Append a slice at the end of this CryptoVec. - /// - /// ``` - /// let mut v = russh_cryptovec::CryptoVec::new(); - /// v.extend(b"test"); - /// ``` - pub fn extend(&mut self, s: &[u8]) { - let size = self.size; - self.resize(size + s.len()); - unsafe { - std::ptr::copy_nonoverlapping(s.as_ptr(), self.p.add(size), s.len()); - } - } - - /// Create a `CryptoVec` from a slice - /// - /// ``` - /// russh_cryptovec::CryptoVec::from_slice(b"test"); - /// ``` - pub fn from_slice(s: &[u8]) -> CryptoVec { - let mut v = CryptoVec::new(); - v.resize(s.len()); - unsafe { - std::ptr::copy_nonoverlapping(s.as_ptr(), v.p, s.len()); - } - v - } -} - -impl Drop for CryptoVec { - fn drop(&mut self) { - if self.capacity > 0 { - unsafe { - for i in 0..self.size { - std::ptr::write_volatile(self.p.add(i), 0) - } - munlock(self.p, self.capacity); - let layout = std::alloc::Layout::from_size_align_unchecked(self.capacity, 1); - std::alloc::dealloc(self.p, layout); - } - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - // If `resize` is called with a size that is too large to be allocated, it - // should panic, and not segfault or fail silently. - #[test] - fn large_resize_panics() { - let result = std::panic::catch_unwind(|| { - let mut vec = CryptoVec::new(); - // Write something into the vector, so that there is something to - // copy when reallocating, to test all code paths. - vec.push(42); - - vec.resize(1_000_000_000_000) - }); - assert!(result.is_err()); - } -} +// Platform-specific modules +mod platform; \ No newline at end of file diff --git a/cryptovec/src/platform/mod.rs b/cryptovec/src/platform/mod.rs new file mode 100644 index 00000000..ab23326c --- /dev/null +++ b/cryptovec/src/platform/mod.rs @@ -0,0 +1,41 @@ +#[cfg(windows)] +mod windows; + +#[cfg(not(windows))] +#[cfg(not(target_arch = "wasm32"))] +mod unix; + +#[cfg(target_arch = "wasm32")] +mod wasm; + +// Re-export functions based on the platform +#[cfg(windows)] +pub use windows::{munlock,mlock,memset}; + +#[cfg(not(windows))] +#[cfg(not(target_arch = "wasm32"))] +pub use unix::{munlock, mlock, memset}; + +#[cfg(target_arch = "wasm32")] +pub use wasm::{munlock, mlock, memset}; + +#[cfg(test)] +mod tests { + use wasm_bindgen_test::wasm_bindgen_test; + + use super::*; + + #[wasm_bindgen_test] + fn test_memset() { + let mut buf = vec![0u8; 10]; + memset(buf.as_mut_ptr(), 0xff, buf.len()); + assert_eq!(buf, vec![0xff; 10]); + } + + #[wasm_bindgen_test] + fn test_memset_partial() { + let mut buf = vec![0u8; 10]; + memset(buf.as_mut_ptr(), 0xff, 5); + assert_eq!(buf, [0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0]); + } +} \ No newline at end of file diff --git a/cryptovec/src/platform/unix.rs b/cryptovec/src/platform/unix.rs new file mode 100644 index 00000000..45707c0f --- /dev/null +++ b/cryptovec/src/platform/unix.rs @@ -0,0 +1,27 @@ +use crate::CryptoVec; +use libc::{mlock, munlock, c_void}; +use std::alloc; + + +/// Unlock memory on drop for Unix-based systems. +pub fn munlock(ptr: *const u8, len: usize) { + unsafe { + if munlock(ptr as *const c_void, len) != 0 { + panic!("Failed to unlock memory."); + } + } +} + +pub fn mlock (ptr: *const u8, len: usize) { + unsafe { + if mlock(ptr as *const c_void, len) != 0 { + panic!("Failed to lock memory."); + } + } +} + +pub fn memset(ptr: *mut u8, value: i32, size: usize) { + unsafe { + libc::memset(ptr as *mut c_void, value, size); + } +} \ No newline at end of file diff --git a/cryptovec/src/platform/wasm.rs b/cryptovec/src/platform/wasm.rs new file mode 100644 index 00000000..fd41c3ec --- /dev/null +++ b/cryptovec/src/platform/wasm.rs @@ -0,0 +1,15 @@ +// WASM does not support synchronization primitives +pub fn munlock(_ptr: *const u8, _len: usize) { + // No-op +} + +pub fn mlock(_ptr: *const u8, _len: usize) -> i32 { + 0 +} + +pub fn memset(ptr: *mut u8, value: i32, size: usize) { + let byte_value = value as u8; // Extract the least significant byte directly + unsafe { + std::ptr::write_bytes(ptr, byte_value, size); + } +} diff --git a/cryptovec/src/platform/windows.rs b/cryptovec/src/platform/windows.rs new file mode 100644 index 00000000..930dbbf5 --- /dev/null +++ b/cryptovec/src/platform/windows.rs @@ -0,0 +1,23 @@ +use winapi::shared::{basetsd::SIZE_T, minwindef::LPVOID}; +use winapi::um::memoryapi::{VirtualLock, VirtualUnlock}; + +use libc::c_void; + +/// Unlock memory on drop for Windows. +pub fn munlock(ptr: *const u8, len: usize) { + unsafe { + VirtualUnlock(ptr as LPVOID, len as SIZE_T); + } +} + +pub fn mlock(ptr: *const u8, len: usize) { + unsafe { + VirtualLock(ptr as LPVOID, len as SIZE_T); + } +} + +pub fn memset(ptr: *mut u8, value: i32, size: usize) { + unsafe { + libc::memset(ptr as *mut c_void, value, size); + } +} \ No newline at end of file From eeddc3e1336348f0d4d83b9d56ca319f8df71880 Mon Sep 17 00:00:00 2001 From: irving ou Date: Thu, 19 Sep 2024 14:10:22 -0400 Subject: [PATCH 04/15] fix test and comments --- cryptovec/src/cryptovec.rs | 190 +++++++++++++++++++++--------- cryptovec/src/lib.rs | 15 +++ cryptovec/src/platform/mod.rs | 8 +- cryptovec/src/platform/unix.rs | 6 + cryptovec/src/platform/wasm.rs | 10 ++ cryptovec/src/platform/windows.rs | 6 + 6 files changed, 176 insertions(+), 59 deletions(-) diff --git a/cryptovec/src/cryptovec.rs b/cryptovec/src/cryptovec.rs index 0e7b63f5..ec1c3446 100644 --- a/cryptovec/src/cryptovec.rs +++ b/cryptovec/src/cryptovec.rs @@ -1,4 +1,4 @@ -use crate::platform; +use crate::platform::{self, memcpy, memset, mlock, munlock}; use std::ops::{Deref, DerefMut, Index, IndexMut, Range, RangeFrom, RangeFull, RangeTo}; /// A buffer which zeroes its memory on `.clear()`, `.resize()`, and @@ -131,7 +131,6 @@ impl Default for CryptoVec { } } -// Memory management methods impl CryptoVec { /// Creates a new `CryptoVec`. pub fn new() -> CryptoVec { @@ -144,11 +143,44 @@ impl CryptoVec { let capacity = size.next_power_of_two(); let layout = std::alloc::Layout::from_size_align_unchecked(capacity, 1); let p = std::alloc::alloc_zeroed(layout); - platform::mlock(p, capacity); + mlock(p, capacity); CryptoVec { p, capacity, size } } } + /// Creates a new `CryptoVec` with capacity `capacity`. + pub fn with_capacity(capacity: usize) -> CryptoVec { + unsafe { + let capacity = capacity.next_power_of_two(); + let layout = std::alloc::Layout::from_size_align_unchecked(capacity, 1); + let p = std::alloc::alloc_zeroed(layout); + mlock(p, capacity); + CryptoVec { + p, + capacity, + size: 0, + } + } + } + + /// Length of this `CryptoVec`. + /// + /// ``` + /// assert_eq!(russh_cryptovec::CryptoVec::new().len(), 0) + /// ``` + pub fn len(&self) -> usize { + self.size + } + + /// Returns `true` if and only if this CryptoVec is empty. + /// + /// ``` + /// assert!(russh_cryptovec::CryptoVec::new().is_empty()) + /// ``` + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + /// Resize this CryptoVec, appending zeros at the end. This may /// perform at most one reallocation, overwriting the previous /// version with zeros. @@ -157,8 +189,9 @@ impl CryptoVec { // If this is an expansion, just resize. self.size = size } else if size <= self.size { + // If this is a truncation, resize and erase the extra memory. unsafe { - platform::memset(self.p.add(size), 0, self.size - size); + memset(self.p.add(size), 0, self.size - size); } self.size = size; } else { @@ -167,66 +200,115 @@ impl CryptoVec { let next_capacity = size.next_power_of_two(); let old_ptr = self.p; let next_layout = std::alloc::Layout::from_size_align_unchecked(next_capacity, 1); - let new_ptr = std::alloc::alloc_zeroed(next_layout); - if new_ptr.is_null() { - #[allow(clippy::panic)] - { - panic!("Realloc failed, pointer = {:?} {:?}", self, size) - } - } - - self.p = new_ptr; - platform::mlock(self.p, next_capacity); + self.p = std::alloc::alloc_zeroed(next_layout); + mlock(self.p, next_capacity); if self.capacity > 0 { std::ptr::copy_nonoverlapping(old_ptr, self.p, self.size); for i in 0..self.size { std::ptr::write_volatile(old_ptr.add(i), 0) } - platform::munlock(old_ptr, self.capacity); + munlock(old_ptr, self.capacity); let layout = std::alloc::Layout::from_size_align_unchecked(self.capacity, 1); std::alloc::dealloc(old_ptr, layout); } - self.capacity = next_capacity; - self.size = size; + if self.p.is_null() { + #[allow(clippy::panic)] + { + panic!("Realloc failed, pointer = {:?} {:?}", self, size) + } + } else { + self.capacity = next_capacity; + self.size = size; + } } } } + /// Clear this CryptoVec (retaining the memory). + /// + /// ``` + /// let mut v = russh_cryptovec::CryptoVec::new(); + /// v.extend(b"blabla"); + /// v.clear(); + /// assert!(v.is_empty()) + /// ``` pub fn clear(&mut self) { self.resize(0); } + /// Append a new byte at the end of this CryptoVec. pub fn push(&mut self, s: u8) { let size = self.size; self.resize(size + 1); unsafe { *self.p.add(size) = s } } - pub fn extend(&mut self, s: &[u8]) { - let size = self.size; - self.resize(size + s.len()); - unsafe { - std::ptr::copy_nonoverlapping(s.as_ptr(), self.p.add(size), s.len()); - } - } - + /// Append a new u32, big endian-encoded, at the end of this CryptoVec. + /// + /// ``` + /// let mut v = russh_cryptovec::CryptoVec::new(); + /// let n = 43554; + /// v.push_u32_be(n); + /// assert_eq!(n, v.read_u32_be(0)) + /// ``` pub fn push_u32_be(&mut self, s: u32) { let s = s.to_be(); let x: [u8; 4] = s.to_ne_bytes(); self.extend(&x) } + /// Read a big endian-encoded u32 from this CryptoVec, with the + /// first byte at position `i`. + /// + /// ``` + /// let mut v = russh_cryptovec::CryptoVec::new(); + /// let n = 99485710; + /// v.push_u32_be(n); + /// assert_eq!(n, v.read_u32_be(0)) + /// ``` pub fn read_u32_be(&self, i: usize) -> u32 { assert!(i + 4 <= self.size); let mut x: u32 = 0; unsafe { - std::ptr::copy_nonoverlapping(self.p.add(i) as *const u32, &mut x as *mut u32, 1); + memcpy((&mut x) as *mut u32, self.p.add(i), 4); } u32::from_be(x) } + /// Read `n_bytes` from `r`, and append them at the end of this + /// `CryptoVec`. Returns the number of bytes read (and appended). + pub fn read( + &mut self, + n_bytes: usize, + mut r: R, + ) -> Result { + let cur_size = self.size; + self.resize(cur_size + n_bytes); + let s = unsafe { std::slice::from_raw_parts_mut(self.p.add(cur_size), n_bytes) }; + // Resize the buffer to its appropriate size. + match r.read(s) { + Ok(n) => { + self.resize(cur_size + n); + Ok(n) + } + Err(e) => { + self.resize(cur_size); + Err(e) + } + } + } + + /// Write all this CryptoVec to the provided `Write`. Returns the + /// number of bytes actually written. + /// + /// ``` + /// let mut v = russh_cryptovec::CryptoVec::new(); + /// v.extend(b"blabla"); + /// let mut s = std::io::stdout(); + /// v.write_all_from(0, &mut s).unwrap(); + /// ``` pub fn write_all_from( &self, offset: usize, @@ -240,12 +322,37 @@ impl CryptoVec { } } + /// Resize this CryptoVec, returning a mutable borrow to the extra bytes. + /// + /// ``` + /// let mut v = russh_cryptovec::CryptoVec::new(); + /// v.resize_mut(4).clone_from_slice(b"test"); + /// ``` pub fn resize_mut(&mut self, n: usize) -> &mut [u8] { let size = self.size; self.resize(size + n); unsafe { std::slice::from_raw_parts_mut(self.p.add(size), n) } } + /// Append a slice at the end of this CryptoVec. + /// + /// ``` + /// let mut v = russh_cryptovec::CryptoVec::new(); + /// v.extend(b"test"); + /// ``` + pub fn extend(&mut self, s: &[u8]) { + let size = self.size; + self.resize(size + s.len()); + unsafe { + std::ptr::copy_nonoverlapping(s.as_ptr(), self.p.add(size), s.len()); + } + } + + /// Create a `CryptoVec` from a slice + /// + /// ``` + /// russh_cryptovec::CryptoVec::from_slice(b"test"); + /// ``` pub fn from_slice(s: &[u8]) -> CryptoVec { let mut v = CryptoVec::new(); v.resize(s.len()); @@ -280,17 +387,17 @@ impl Drop for CryptoVec { } } +// DocTests cannot be run on with wasm_bindgen_test #[cfg(test)] +#[cfg(target_arch = "wasm32")] mod test { use wasm_bindgen_test::wasm_bindgen_test; use super::CryptoVec; - #[cfg(target_arch = "wasm32")] wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser); - #[test] #[wasm_bindgen_test] fn test_new() { let crypto_vec = CryptoVec::new(); @@ -298,7 +405,6 @@ mod test { assert_eq!(crypto_vec.capacity, 0); } - #[test] #[wasm_bindgen_test] fn test_resize_expand() { let mut crypto_vec = CryptoVec::new_zeroed(5); @@ -308,7 +414,6 @@ mod test { assert!(crypto_vec.iter().skip(5).all(|&x| x == 0)); // Ensure newly added elements are zeroed } - #[test] #[wasm_bindgen_test] fn test_resize_shrink() { let mut crypto_vec = CryptoVec::new_zeroed(10); @@ -318,7 +423,6 @@ mod test { assert_eq!(crypto_vec.len(), 5); } - #[test] #[wasm_bindgen_test] fn test_push() { let mut crypto_vec = CryptoVec::new(); @@ -329,7 +433,6 @@ mod test { assert_eq!(crypto_vec[1], 2); } - #[test] #[wasm_bindgen_test] fn test_write_trait() { use std::io::Write; @@ -341,7 +444,6 @@ mod test { assert_eq!(crypto_vec.as_ref(), &[1, 2, 3]); } - #[test] #[wasm_bindgen_test] fn test_as_ref_as_mut() { let mut crypto_vec = CryptoVec::new_zeroed(5); @@ -352,7 +454,6 @@ mod test { assert_eq!(crypto_vec[0], 1); } - #[test] #[wasm_bindgen_test] fn test_from_string() { let input = String::from("hello"); @@ -360,7 +461,6 @@ mod test { assert_eq!(crypto_vec.as_ref(), b"hello"); } - #[test] #[wasm_bindgen_test] fn test_from_vec() { let input = vec![1, 2, 3, 4]; @@ -368,7 +468,6 @@ mod test { assert_eq!(crypto_vec.as_ref(), &[1, 2, 3, 4]); } - #[test] #[wasm_bindgen_test] fn test_index() { let crypto_vec = CryptoVec::from(vec![1, 2, 3, 4, 5]); @@ -377,7 +476,6 @@ mod test { assert_eq!(&crypto_vec[1..3], &[2, 3]); } - #[test] #[wasm_bindgen_test] fn test_drop() { let mut crypto_vec = CryptoVec::new_zeroed(10); @@ -391,7 +489,6 @@ mod test { // it may be checked using tools like Valgrind or manual inspection. } - #[test] #[wasm_bindgen_test] fn test_new_zeroed() { let crypto_vec = CryptoVec::new_zeroed(10); @@ -400,7 +497,6 @@ mod test { assert!(crypto_vec.iter().all(|&x| x == 0)); // Ensure all bytes are zeroed } - #[test] #[wasm_bindgen_test] fn test_push_u32_be() { let mut crypto_vec = CryptoVec::new(); @@ -410,7 +506,6 @@ mod test { assert_eq!(crypto_vec.read_u32_be(0), value); } - #[test] #[wasm_bindgen_test] fn test_read_u32_be() { @@ -420,7 +515,6 @@ mod test { assert_eq!(crypto_vec.read_u32_be(0), value); } - #[test] #[wasm_bindgen_test] fn test_clear() { let mut crypto_vec = CryptoVec::new(); @@ -429,7 +523,6 @@ mod test { assert!(crypto_vec.is_empty()); } - #[test] #[wasm_bindgen_test] fn test_extend() { let mut crypto_vec = CryptoVec::new(); @@ -437,7 +530,6 @@ mod test { assert_eq!(crypto_vec.as_ref(), b"test"); } - #[test] #[wasm_bindgen_test] fn test_write_all_from() { let mut crypto_vec = CryptoVec::new(); @@ -449,7 +541,6 @@ mod test { assert_eq!(output, b"blabla"); } - #[test] #[wasm_bindgen_test] fn test_resize_mut() { let mut crypto_vec = CryptoVec::new(); @@ -457,15 +548,4 @@ mod test { assert_eq!(crypto_vec.as_ref(), b"test"); } - #[cfg(target_pointer_width = "64")] - #[test] - fn test_large_resize_panics() { - let result = std::panic::catch_unwind(|| { - let mut vec = CryptoVec::new(); - vec.push(42); // Write something into the vector - - vec.resize(1_000_000_000_000); // Intentionally large resize - }); - assert!(result.is_err()); // Expecting a panic on large allocation - } } diff --git a/cryptovec/src/lib.rs b/cryptovec/src/lib.rs index 056b4f55..9403987e 100644 --- a/cryptovec/src/lib.rs +++ b/cryptovec/src/lib.rs @@ -5,6 +5,21 @@ clippy::panic )] +// Copyright 2016 Pierre-Étienne Meunier +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + // Re-export CryptoVec from the cryptovec module mod cryptovec; pub use cryptovec::CryptoVec; diff --git a/cryptovec/src/platform/mod.rs b/cryptovec/src/platform/mod.rs index ab23326c..63b715a7 100644 --- a/cryptovec/src/platform/mod.rs +++ b/cryptovec/src/platform/mod.rs @@ -10,14 +10,14 @@ mod wasm; // Re-export functions based on the platform #[cfg(windows)] -pub use windows::{munlock,mlock,memset}; +pub use windows::{memcpy, memset, mlock, munlock}; #[cfg(not(windows))] #[cfg(not(target_arch = "wasm32"))] -pub use unix::{munlock, mlock, memset}; +pub use unix::{memcpy, memset, mlock, munlock}; #[cfg(target_arch = "wasm32")] -pub use wasm::{munlock, mlock, memset}; +pub use wasm::{memcpy, memset, mlock, munlock}; #[cfg(test)] mod tests { @@ -38,4 +38,4 @@ mod tests { memset(buf.as_mut_ptr(), 0xff, 5); assert_eq!(buf, [0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0]); } -} \ No newline at end of file +} diff --git a/cryptovec/src/platform/unix.rs b/cryptovec/src/platform/unix.rs index 45707c0f..1c5ab44a 100644 --- a/cryptovec/src/platform/unix.rs +++ b/cryptovec/src/platform/unix.rs @@ -24,4 +24,10 @@ pub fn memset(ptr: *mut u8, value: i32, size: usize) { unsafe { libc::memset(ptr as *mut c_void, value, size); } +} + +pub fn memcpy(dest: *mut u32, src: *const u8, size: usize) { + unsafe { + libc::memcpy(dest as *mut c_void, src as *const c_void, size); + } } \ No newline at end of file diff --git a/cryptovec/src/platform/wasm.rs b/cryptovec/src/platform/wasm.rs index fd41c3ec..d1d192e5 100644 --- a/cryptovec/src/platform/wasm.rs +++ b/cryptovec/src/platform/wasm.rs @@ -13,3 +13,13 @@ pub fn memset(ptr: *mut u8, value: i32, size: usize) { std::ptr::write_bytes(ptr, byte_value, size); } } + +pub fn memcpy(dest: *mut u32, src: *const u8, size: usize) { + unsafe { + // Convert dest to *mut u8 for byte-wise copying + let dest_bytes = dest as *mut u8; + + // Use std::ptr::copy_nonoverlapping to copy the data + std::ptr::copy_nonoverlapping(src, dest_bytes, size); + } +} diff --git a/cryptovec/src/platform/windows.rs b/cryptovec/src/platform/windows.rs index 930dbbf5..ae03018c 100644 --- a/cryptovec/src/platform/windows.rs +++ b/cryptovec/src/platform/windows.rs @@ -20,4 +20,10 @@ pub fn memset(ptr: *mut u8, value: i32, size: usize) { unsafe { libc::memset(ptr as *mut c_void, value, size); } +} + +pub fn memcpy(dest: *mut u32, src: *const u8, size: usize) { + unsafe { + libc::memcpy(dest as *mut c_void, src as *const c_void, size); + } } \ No newline at end of file From 9413a6b4e49f5a31ec482add46f9e00adccb0417 Mon Sep 17 00:00:00 2001 From: irving ou Date: Thu, 19 Sep 2024 14:52:53 -0400 Subject: [PATCH 05/15] fmt --- cryptovec/src/cryptovec.rs | 1 - cryptovec/src/lib.rs | 2 +- cryptovec/src/platform/unix.rs | 7 +++---- cryptovec/src/platform/wasm.rs | 2 +- cryptovec/src/platform/windows.rs | 2 +- 5 files changed, 6 insertions(+), 8 deletions(-) diff --git a/cryptovec/src/cryptovec.rs b/cryptovec/src/cryptovec.rs index ec1c3446..21709673 100644 --- a/cryptovec/src/cryptovec.rs +++ b/cryptovec/src/cryptovec.rs @@ -547,5 +547,4 @@ mod test { crypto_vec.resize_mut(4).clone_from_slice(b"test"); assert_eq!(crypto_vec.as_ref(), b"test"); } - } diff --git a/cryptovec/src/lib.rs b/cryptovec/src/lib.rs index 9403987e..e2b3d54d 100644 --- a/cryptovec/src/lib.rs +++ b/cryptovec/src/lib.rs @@ -25,4 +25,4 @@ mod cryptovec; pub use cryptovec::CryptoVec; // Platform-specific modules -mod platform; \ No newline at end of file +mod platform; diff --git a/cryptovec/src/platform/unix.rs b/cryptovec/src/platform/unix.rs index 1c5ab44a..c857a85c 100644 --- a/cryptovec/src/platform/unix.rs +++ b/cryptovec/src/platform/unix.rs @@ -1,8 +1,7 @@ use crate::CryptoVec; -use libc::{mlock, munlock, c_void}; +use libc::{c_void, mlock, munlock}; use std::alloc; - /// Unlock memory on drop for Unix-based systems. pub fn munlock(ptr: *const u8, len: usize) { unsafe { @@ -12,7 +11,7 @@ pub fn munlock(ptr: *const u8, len: usize) { } } -pub fn mlock (ptr: *const u8, len: usize) { +pub fn mlock(ptr: *const u8, len: usize) { unsafe { if mlock(ptr as *const c_void, len) != 0 { panic!("Failed to lock memory."); @@ -30,4 +29,4 @@ pub fn memcpy(dest: *mut u32, src: *const u8, size: usize) { unsafe { libc::memcpy(dest as *mut c_void, src as *const c_void, size); } -} \ No newline at end of file +} diff --git a/cryptovec/src/platform/wasm.rs b/cryptovec/src/platform/wasm.rs index d1d192e5..77f1cc8e 100644 --- a/cryptovec/src/platform/wasm.rs +++ b/cryptovec/src/platform/wasm.rs @@ -10,7 +10,7 @@ pub fn mlock(_ptr: *const u8, _len: usize) -> i32 { pub fn memset(ptr: *mut u8, value: i32, size: usize) { let byte_value = value as u8; // Extract the least significant byte directly unsafe { - std::ptr::write_bytes(ptr, byte_value, size); + std::ptr::write_bytes(ptr, byte_value, size); } } diff --git a/cryptovec/src/platform/windows.rs b/cryptovec/src/platform/windows.rs index ae03018c..968a3e5d 100644 --- a/cryptovec/src/platform/windows.rs +++ b/cryptovec/src/platform/windows.rs @@ -26,4 +26,4 @@ pub fn memcpy(dest: *mut u32, src: *const u8, size: usize) { unsafe { libc::memcpy(dest as *mut c_void, src as *const c_void, size); } -} \ No newline at end of file +} From 02c0b1fa45ec8ce5126f658678857ef4c4ba03a6 Mon Sep 17 00:00:00 2001 From: irving ou Date: Fri, 20 Sep 2024 09:24:49 -0400 Subject: [PATCH 06/15] fix build on linux --- cryptovec/src/platform/unix.rs | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/cryptovec/src/platform/unix.rs b/cryptovec/src/platform/unix.rs index c857a85c..e82feca3 100644 --- a/cryptovec/src/platform/unix.rs +++ b/cryptovec/src/platform/unix.rs @@ -1,20 +1,18 @@ -use crate::CryptoVec; -use libc::{c_void, mlock, munlock}; -use std::alloc; +use libc::c_void; /// Unlock memory on drop for Unix-based systems. pub fn munlock(ptr: *const u8, len: usize) { unsafe { - if munlock(ptr as *const c_void, len) != 0 { - panic!("Failed to unlock memory."); + if libc::munlock(ptr as *const c_void, len) != 0 { + panic!("Failed to unlock memory"); } } } pub fn mlock(ptr: *const u8, len: usize) { unsafe { - if mlock(ptr as *const c_void, len) != 0 { - panic!("Failed to lock memory."); + if libc::mlock(ptr as *const c_void, len) != 0 { + panic!("Failed to lock memory"); } } } From 60af4fcb71b433e503b9f2a3502bb39a3ddedb3c Mon Sep 17 00:00:00 2001 From: irving ou Date: Fri, 20 Sep 2024 09:35:58 -0400 Subject: [PATCH 07/15] Squashed commit of the following: commit 686cd892ae1d27b21db7a28defcd69cc97086fa1 Author: irving ou Date: Fri Sep 20 09:29:24 2024 -0400 update cargo.toml commit 25c0847d32c9bf8a0bff60cc529418196eadf114 Author: irving ou Date: Thu Sep 19 14:54:43 2024 -0400 Update time.rs commit 0d7ad5cd63aa3b68d2b1a75ceee3e21404edf477 Author: irving ou Date: Thu Sep 19 14:51:45 2024 -0400 fmt commit c9667a6958a25780772fa1c765b558cd30822ea4 Author: irving ou Date: Thu Sep 19 14:50:56 2024 -0400 allow dead code commit e1e9e8ca871d8cadb582fe693121d9ed7a195ade Author: irving ou Date: Thu Sep 19 14:45:38 2024 -0400 Create runtime.rs commit 77bea6214a7573a0a53376d2847267246067c4b5 Author: irving ou Date: Thu Sep 19 14:45:03 2024 -0400 Add comments commit 2cf485ce8ca75237d922f6ccb43722c02544c3c9 Author: irving ou Date: Thu Sep 19 14:44:03 2024 -0400 feat(wasm): Add util creates for runtime,future and time primitives for wasm --- Cargo.toml | 2 +- russh-util/Cargo.toml | 26 +++++++ russh-util/src/future.rs | 158 ++++++++++++++++++++++++++++++++++++++ russh-util/src/lib.rs | 5 ++ russh-util/src/runtime.rs | 16 ++++ russh-util/src/time.rs | 26 +++++++ 6 files changed, 232 insertions(+), 1 deletion(-) create mode 100644 russh-util/Cargo.toml create mode 100644 russh-util/src/future.rs create mode 100644 russh-util/src/lib.rs create mode 100644 russh-util/src/runtime.rs create mode 100644 russh-util/src/time.rs diff --git a/Cargo.toml b/Cargo.toml index 8ba84e34..2b0a41ee 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = ["russh-keys", "russh", "russh-config", "cryptovec", "pageant"] +members = ["russh-keys", "russh", "russh-config", "cryptovec", "pageant", "russh-util"] [patch.crates-io] russh = { path = "russh" } diff --git a/russh-util/Cargo.toml b/russh-util/Cargo.toml new file mode 100644 index 00000000..e44901c2 --- /dev/null +++ b/russh-util/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "russh-util" +version = "0.1.0" +edition = "2021" +rust-version = "1.65" + +[dependencies] +chrono = "0.4.38" + +[dev-dependencies] +futures-executor = "0.3.13" +static_assertions = "1.1.0" + + +[target.'cfg(target_arch = "wasm32")'.dependencies] +wasm-bindgen = "0.2" +wasm-bindgen-futures = "0.4.43" + +[target.'cfg(not(target_arch = "wasm32"))'.dependencies] +tokio = { version = "1.17", features = [ + "io-util", + "macros", + "sync", + "rt-multi-thread", + "rt", +] } diff --git a/russh-util/src/future.rs b/russh-util/src/future.rs new file mode 100644 index 00000000..dab5f8d3 --- /dev/null +++ b/russh-util/src/future.rs @@ -0,0 +1,158 @@ +/// This file is a copy of the `ReusableBoxFuture` type from the `tokio-util` crate. +use std::alloc::Layout; +use std::fmt; +use std::future::{self, Future}; +use std::mem::{self, ManuallyDrop}; +use std::pin::Pin; +use std::ptr; +use std::task::{Context, Poll}; + +/// A reusable `Pin + Send + 'a>>`. +/// +/// This type lets you replace the future stored in the box without +/// reallocating when the size and alignment permits this. +pub struct ReusableBoxFuture<'a, T> { + boxed: Pin + Send + 'a>>, +} + +impl<'a, T> ReusableBoxFuture<'a, T> { + /// Create a new `ReusableBoxFuture` containing the provided future. + pub fn new(future: F) -> Self + where + F: Future + Send + 'a, + { + Self { + boxed: Box::pin(future), + } + } + + /// Replace the future currently stored in this box. + /// + /// This reallocates if and only if the layout of the provided future is + /// different from the layout of the currently stored future. + pub fn set(&mut self, future: F) + where + F: Future + Send + 'a, + { + if let Err(future) = self.try_set(future) { + *self = Self::new(future); + } + } + + /// Replace the future currently stored in this box. + /// + /// This function never reallocates, but returns an error if the provided + /// future has a different size or alignment from the currently stored + /// future. + pub fn try_set(&mut self, future: F) -> Result<(), F> + where + F: Future + Send + 'a, + { + // If we try to inline the contents of this function, the type checker complains because + // the bound `T: 'a` is not satisfied in the call to `pending()`. But by putting it in an + // inner function that doesn't have `T` as a generic parameter, we implicitly get the bound + // `F::Output: 'a` transitively through `F: 'a`, allowing us to call `pending()`. + #[inline(always)] + fn real_try_set<'a, F>( + this: &mut ReusableBoxFuture<'a, F::Output>, + future: F, + ) -> Result<(), F> + where + F: Future + Send + 'a, + { + // future::Pending is a ZST so this never allocates. + let boxed = mem::replace(&mut this.boxed, Box::pin(future::pending())); + reuse_pin_box(boxed, future, |boxed| this.boxed = Pin::from(boxed)) + } + + real_try_set(self, future) + } + + /// Get a pinned reference to the underlying future. + pub fn get_pin(&mut self) -> Pin<&mut (dyn Future + Send)> { + self.boxed.as_mut() + } + + /// Poll the future stored inside this box. + pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll { + self.get_pin().poll(cx) + } +} + +impl Future for ReusableBoxFuture<'_, T> { + type Output = T; + + /// Poll the future stored inside this box. + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::into_inner(self).get_pin().poll(cx) + } +} + +// The only method called on self.boxed is poll, which takes &mut self, so this +// struct being Sync does not permit any invalid access to the Future, even if +// the future is not Sync. +unsafe impl Sync for ReusableBoxFuture<'_, T> {} + +impl fmt::Debug for ReusableBoxFuture<'_, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ReusableBoxFuture").finish() + } +} + +fn reuse_pin_box(boxed: Pin>, new_value: U, callback: F) -> Result +where + F: FnOnce(Box) -> O, +{ + let layout = Layout::for_value::(&*boxed); + if layout != Layout::new::() { + return Err(new_value); + } + + // SAFETY: We don't ever construct a non-pinned reference to the old `T` from now on, and we + // always drop the `T`. + let raw: *mut T = Box::into_raw(unsafe { Pin::into_inner_unchecked(boxed) }); + + // When dropping the old value panics, we still want to call `callback` — so move the rest of + // the code into a guard type. + let guard = CallOnDrop::new(|| { + let raw: *mut U = raw.cast::(); + unsafe { raw.write(new_value) }; + + // SAFETY: + // - `T` and `U` have the same layout. + // - `raw` comes from a `Box` that uses the same allocator as this one. + // - `raw` points to a valid instance of `U` (we just wrote it in). + let boxed = unsafe { Box::from_raw(raw) }; + + callback(boxed) + }); + + // Drop the old value. + unsafe { ptr::drop_in_place(raw) }; + + // Run the rest of the code. + Ok(guard.call()) +} + +struct CallOnDrop O> { + f: ManuallyDrop, +} + +impl O> CallOnDrop { + fn new(f: F) -> Self { + let f = ManuallyDrop::new(f); + Self { f } + } + fn call(self) -> O { + let mut this = ManuallyDrop::new(self); + let f = unsafe { ManuallyDrop::take(&mut this.f) }; + f() + } +} + +impl O> Drop for CallOnDrop { + fn drop(&mut self) { + let f = unsafe { ManuallyDrop::take(&mut self.f) }; + f(); + } +} diff --git a/russh-util/src/lib.rs b/russh-util/src/lib.rs new file mode 100644 index 00000000..a92f8aee --- /dev/null +++ b/russh-util/src/lib.rs @@ -0,0 +1,5 @@ +#![allow(dead_code)] // To be removed when full wasm support is added. + +pub mod future; +pub mod runtime; +pub mod time; diff --git a/russh-util/src/runtime.rs b/russh-util/src/runtime.rs new file mode 100644 index 00000000..341e1c42 --- /dev/null +++ b/russh-util/src/runtime.rs @@ -0,0 +1,16 @@ +use std::future::Future; + +pub fn spawn(future: F) +where + F: Future + 'static + Send, +{ + #[cfg(target_arch = "wasm32")] + { + wasm_bindgen_futures::spawn_local(future); + } + + #[cfg(not(target_arch = "wasm32"))] + { + tokio::spawn(future); + } +} diff --git a/russh-util/src/time.rs b/russh-util/src/time.rs new file mode 100644 index 00000000..b26d0982 --- /dev/null +++ b/russh-util/src/time.rs @@ -0,0 +1,26 @@ +#[cfg(not(target_arch = "wasm32"))] +pub use std::time::Instant; + +#[cfg(target_arch = "wasm32")] +pub use wasm::Instant; + +mod wasm { + #[derive(Debug, Clone, Copy)] + pub struct Instant { + inner: chrono::DateTime, + } + + impl Instant { + pub fn now() -> Self { + Instant { + inner: chrono::Utc::now(), + } + } + + pub fn duration_since(&self, earlier: Instant) -> std::time::Duration { + (self.inner - earlier.inner) + .to_std() + .expect("Duration is negative") + } + } +} From 6da4e568c501560552a4d1e53cb1cfc96bf748ca Mon Sep 17 00:00:00 2001 From: irving ou Date: Fri, 20 Sep 2024 13:58:30 -0400 Subject: [PATCH 08/15] feat(wasm): basic integration for wasm --- Cargo.toml | 1 + russh-keys/Cargo.toml | 19 ++- russh-keys/src/known_hosts.rs | 245 ++++++++++++++++++++++++++++++++++ russh-keys/src/lib.rs | 243 ++------------------------------- russh-util/Cargo.toml | 10 +- russh-util/src/future.rs | 158 ---------------------- russh-util/src/lib.rs | 3 - russh-util/src/runtime.rs | 106 ++++++++++++++- russh-util/src/time.rs | 1 + russh/Cargo.toml | 18 +-- russh/src/auth.rs | 8 +- russh/src/client/encrypted.rs | 2 +- russh/src/client/kex.rs | 7 +- russh/src/client/mod.rs | 10 +- russh/src/kex/mod.rs | 1 + russh/src/lib.rs | 8 +- russh/src/msg.rs | 18 ++- russh/src/negotiation.rs | 37 +++++ russh/src/session.rs | 10 +- 19 files changed, 462 insertions(+), 443 deletions(-) create mode 100644 russh-keys/src/known_hosts.rs delete mode 100644 russh-util/src/future.rs diff --git a/Cargo.toml b/Cargo.toml index 2b0a41ee..9dc05561 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,6 @@ [workspace] members = ["russh-keys", "russh", "russh-config", "cryptovec", "pageant", "russh-util"] +resolver = "2" [patch.crates-io] russh = { path = "russh" } diff --git a/russh-keys/Cargo.toml b/russh-keys/Cargo.toml index 80fbb610..aca59e5f 100644 --- a/russh-keys/Cargo.toml +++ b/russh-keys/Cargo.toml @@ -22,7 +22,6 @@ byteorder = { workspace = true } data-encoding = "2.3" digest = { workspace = true } der = "0.7" -home = "0.5" ecdsa = "0.16" ed25519-dalek = { version = "2.0", features = ["rand_core", "pkcs8"] } elliptic-curve = "0.13" @@ -52,6 +51,16 @@ spki = "0.7" ssh-encoding = { workspace = true } ssh-key = { workspace = true } thiserror = { workspace = true } +typenum = "1.17" +yasna = { version = "0.5.0", features = ["bit-vec", "num-bigint"], optional = true } +zeroize = "1.7" +getrandom = { version = "0.2.15", features = ["js"] } + +[features] +vendored-openssl = ["openssl", "openssl/vendored"] +legacy-ed25519-pkcs8-parser = ["yasna"] + +[target.'cfg(not(target_arch = "wasm32"))'.dependencies] tokio = { workspace = true, features = [ "io-util", "rt-multi-thread", @@ -59,13 +68,8 @@ tokio = { workspace = true, features = [ "net", ] } tokio-stream = { workspace = true } -typenum = "1.17" -yasna = { version = "0.5.0", features = ["bit-vec", "num-bigint"], optional = true } -zeroize = "1.7" +home = "0.5" -[features] -vendored-openssl = ["openssl", "openssl/vendored"] -legacy-ed25519-pkcs8-parser = ["yasna"] [target.'cfg(windows)'.dependencies] pageant = { version = "0.0.1-beta.3", path = "../pageant" } @@ -77,3 +81,4 @@ tokio = { workspace = true, features = ["test-util", "macros", "process"] } [package.metadata.docs.rs] features = ["openssl"] + diff --git a/russh-keys/src/known_hosts.rs b/russh-keys/src/known_hosts.rs new file mode 100644 index 00000000..ee9d01f7 --- /dev/null +++ b/russh-keys/src/known_hosts.rs @@ -0,0 +1,245 @@ +use std::borrow::Cow; +use std::fs::{File, OpenOptions}; +use std::io::{BufRead, BufReader, Read, Seek, SeekFrom, Write}; +use std::path::{Path, PathBuf}; + +use crate::{key, Error, PublicKeyBase64}; +use data_encoding::BASE64_MIME; +use hmac::{Hmac, Mac}; +use log::debug; +use sha1::Sha1; + +/// Check whether the host is known, from its standard location. +pub fn check_known_hosts(host: &str, port: u16, pubkey: &key::PublicKey) -> Result { + check_known_hosts_path(host, port, pubkey, known_hosts_path()?) +} + +/// Check that a server key matches the one recorded in file `path`. +pub fn check_known_hosts_path>( + host: &str, + port: u16, + pubkey: &key::PublicKey, + path: P, +) -> Result { + let check = known_host_keys_path(host, port, path)? + .into_iter() + .map( + |(line, recorded)| match (pubkey.name() == recorded.name(), *pubkey == recorded) { + (true, true) => Ok(true), + (true, false) => Err(Error::KeyChanged { line }), + _ => Ok(false), + }, + ) + // If any Err was returned, we stop here + .collect::, Error>>()? + .into_iter() + // Now we check the results for a match + .any(|x| x); + + Ok(check) +} + +#[cfg(target_os = "windows")] +fn known_hosts_path() -> Result { + if let Some(home_dir) = home::home_dir() { + Ok(home_dir.join("ssh").join("known_hosts")) + } else { + Err(Error::NoHomeDir) + } +} + +#[cfg(not(target_os = "windows"))] +fn known_hosts_path() -> Result { + if let Some(home_dir) = home::home_dir() { + Ok(home_dir.join(".ssh").join("known_hosts")) + } else { + Err(Error::NoHomeDir) + } +} + +/// Get the server key that matches the one recorded in the user's known_hosts file. +pub fn known_host_keys(host: &str, port: u16) -> Result, Error> { + known_host_keys_path(host, port, known_hosts_path()?) +} + +/// Get the server key that matches the one recorded in `path`. +pub fn known_host_keys_path>( + host: &str, + port: u16, + path: P, +) -> Result, Error> { + use crate::parse_public_key_base64; + + let mut f = if let Ok(f) = File::open(path) { + BufReader::new(f) + } else { + return Ok(vec![]); + }; + let mut buffer = String::new(); + + let host_port = if port == 22 { + Cow::Borrowed(host) + } else { + Cow::Owned(format!("[{}]:{}", host, port)) + }; + debug!("host_port = {:?}", host_port); + let mut line = 1; + let mut matches = vec![]; + while f.read_line(&mut buffer)? > 0 { + { + if buffer.as_bytes().first() == Some(&b'#') { + buffer.clear(); + continue; + } + debug!("line = {:?}", buffer); + let mut s = buffer.split(' '); + let hosts = s.next(); + let _ = s.next(); + let key = s.next(); + if let (Some(h), Some(k)) = (hosts, key) { + debug!("{:?} {:?}", h, k); + if match_hostname(&host_port, h) { + matches.push((line, parse_public_key_base64(k)?)); + } + } + } + buffer.clear(); + line += 1; + } + Ok(matches) +} + +fn match_hostname(host: &str, pattern: &str) -> bool { + for entry in pattern.split(',') { + if entry.starts_with("|1|") { + let mut parts = entry.split('|').skip(2); + let Some(Ok(salt)) = parts.next().map(|p| BASE64_MIME.decode(p.as_bytes())) else { + continue; + }; + let Some(Ok(hash)) = parts.next().map(|p| BASE64_MIME.decode(p.as_bytes())) else { + continue; + }; + if let Ok(hmac) = Hmac::::new_from_slice(&salt) { + if hmac.chain_update(host).verify_slice(&hash).is_ok() { + return true; + } + } + } else if host == entry { + return true; + } + } + false +} + +/// Record a host's public key into the user's known_hosts file. +pub fn learn_known_hosts(host: &str, port: u16, pubkey: &key::PublicKey) -> Result<(), Error> { + learn_known_hosts_path(host, port, pubkey, known_hosts_path()?) +} + +/// Record a host's public key into a nonstandard location. +pub fn learn_known_hosts_path>( + host: &str, + port: u16, + pubkey: &key::PublicKey, + path: P, +) -> Result<(), Error> { + if let Some(parent) = path.as_ref().parent() { + std::fs::create_dir_all(parent)? + } + let mut file = OpenOptions::new() + .read(true) + .append(true) + .create(true) + .open(path)?; + + // Test whether the known_hosts file ends with a \n + let mut buf = [0; 1]; + let mut ends_in_newline = false; + if file.seek(SeekFrom::End(-1)).is_ok() { + file.read_exact(&mut buf)?; + ends_in_newline = buf[0] == b'\n'; + } + + // Write the key. + file.seek(SeekFrom::End(0))?; + let mut file = std::io::BufWriter::new(file); + if !ends_in_newline { + file.write_all(b"\n")?; + } + if port != 22 { + write!(file, "[{}]:{} ", host, port)? + } else { + write!(file, "{} ", host)? + } + write_public_key_base64(&mut file, pubkey)?; + file.write_all(b"\n")?; + Ok(()) +} + +/// Write a public key onto the provided `Write`, encoded in base-64. +pub fn write_public_key_base64( + mut w: W, + publickey: &key::PublicKey, +) -> Result<(), Error> { + let pk = publickey.public_key_base64(); + writeln!(w, "{} {}", publickey.name(), pk)?; + Ok(()) +} + +#[cfg(test)] +mod test { + use crate::parse_public_key_base64; + use std::fs::File; + + use super::*; + + #[test] + fn test_check_known_hosts() { + env_logger::try_init().unwrap_or(()); + let dir = tempdir::TempDir::new("russh").unwrap(); + let path = dir.path().join("known_hosts"); + { + let mut f = File::create(&path).unwrap(); + f.write_all(b"[localhost]:13265 ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIJdD7y3aLq454yWBdwLWbieU1ebz9/cu7/QEXn9OIeZJ\n").unwrap(); + f.write_all(b"#pijul.org,37.120.161.53 ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIA6rWI3G2sz07DnfFlrouTcysQlj2P+jpNSOEWD9OJ3X\n").unwrap(); + f.write_all(b"pijul.org,37.120.161.53 ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIA6rWI3G1sz07DnfFlrouTcysQlj2P+jpNSOEWD9OJ3X\n").unwrap(); + f.write_all(b"|1|O33ESRMWPVkMYIwJ1Uw+n877jTo=|nuuC5vEqXlEZ/8BXQR7m619W6Ak= ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILIG2T/B0l0gaqj3puu510tu9N1OkQ4znY3LYuEm5zCF\n").unwrap(); + } + + // Valid key, non-standard port. + let host = "localhost"; + let port = 13265; + let hostkey = parse_public_key_base64( + "AAAAC3NzaC1lZDI1NTE5AAAAIJdD7y3aLq454yWBdwLWbieU1ebz9/cu7/QEXn9OIeZJ", + ) + .unwrap(); + assert!(check_known_hosts_path(host, port, &hostkey, &path).unwrap()); + + // Valid key, hashed. + let host = "example.com"; + let port = 22; + let hostkey = parse_public_key_base64( + "AAAAC3NzaC1lZDI1NTE5AAAAILIG2T/B0l0gaqj3puu510tu9N1OkQ4znY3LYuEm5zCF", + ) + .unwrap(); + assert!(check_known_hosts_path(host, port, &hostkey, &path).unwrap()); + + // Valid key, several hosts, port 22 + let host = "pijul.org"; + let port = 22; + let hostkey = parse_public_key_base64( + "AAAAC3NzaC1lZDI1NTE5AAAAIA6rWI3G1sz07DnfFlrouTcysQlj2P+jpNSOEWD9OJ3X", + ) + .unwrap(); + assert!(check_known_hosts_path(host, port, &hostkey, &path).unwrap()); + + // Now with the key in a comment above, check that it's not recognized + let host = "pijul.org"; + let port = 22; + let hostkey = parse_public_key_base64( + "AAAAC3NzaC1lZDI1NTE5AAAAIA6rWI3G2sz07DnfFlrouTcysQlj2P+jpNSOEWD9OJ3X", + ) + .unwrap(); + assert!(check_known_hosts_path(host, port, &hostkey, &path).is_err()); + } +} diff --git a/russh-keys/src/lib.rs b/russh-keys/src/lib.rs index f7b82b7c..f13f3ec3 100644 --- a/russh-keys/src/lib.rs +++ b/russh-keys/src/lib.rs @@ -63,18 +63,14 @@ //! //! ``` -use std::borrow::Cow; -use std::fs::{File, OpenOptions}; -use std::io::{BufRead, BufReader, Read, Seek, SeekFrom, Write}; -use std::path::{Path, PathBuf}; +use std::fs::File; +use std::io::Read; +use std::path::Path; use aes::cipher::block_padding::UnpadError; use aes::cipher::inout::PadError; use byteorder::{BigEndian, WriteBytesExt}; use data_encoding::BASE64_MIME; -use hmac::{Hmac, Mac}; -use log::debug; -use sha1::Sha1; use ssh_key::Certificate; use thiserror::Error; @@ -95,8 +91,15 @@ mod backend; mod backend; /// OpenSSH agent protocol implementation +#[cfg(not(target_arch = "wasm32"))] pub mod agent; +#[cfg(not(target_arch = "wasm32"))] +pub mod known_hosts; + +#[cfg(not(target_arch = "wasm32"))] +pub use known_hosts::{check_known_hosts, check_known_hosts_path}; + #[derive(Debug, Error)] pub enum Error { /// The key could not be read, for an unknown reason @@ -319,16 +322,6 @@ fn write_ec_public_key(buf: &mut Vec, key: &ec::PublicKey) { buf.extend_ssh_string(&q); } -/// Write a public key onto the provided `Write`, encoded in base-64. -pub fn write_public_key_base64( - mut w: W, - publickey: &key::PublicKey, -) -> Result<(), Error> { - let pk = publickey.public_key_base64(); - writeln!(w, "{} {}", publickey.name(), pk)?; - Ok(()) -} - /// Load a secret key, deciphering it with the supplied password if necessary. pub fn load_secret_key>( secret_: P, @@ -358,171 +351,6 @@ fn is_base64_char(c: char) -> bool { || c == '=' } -/// Record a host's public key into the user's known_hosts file. -pub fn learn_known_hosts(host: &str, port: u16, pubkey: &key::PublicKey) -> Result<(), Error> { - learn_known_hosts_path(host, port, pubkey, known_hosts_path()?) -} - -/// Record a host's public key into a nonstandard location. -pub fn learn_known_hosts_path>( - host: &str, - port: u16, - pubkey: &key::PublicKey, - path: P, -) -> Result<(), Error> { - if let Some(parent) = path.as_ref().parent() { - std::fs::create_dir_all(parent)? - } - let mut file = OpenOptions::new() - .read(true) - .append(true) - .create(true) - .open(path)?; - - // Test whether the known_hosts file ends with a \n - let mut buf = [0; 1]; - let mut ends_in_newline = false; - if file.seek(SeekFrom::End(-1)).is_ok() { - file.read_exact(&mut buf)?; - ends_in_newline = buf[0] == b'\n'; - } - - // Write the key. - file.seek(SeekFrom::End(0))?; - let mut file = std::io::BufWriter::new(file); - if !ends_in_newline { - file.write_all(b"\n")?; - } - if port != 22 { - write!(file, "[{}]:{} ", host, port)? - } else { - write!(file, "{} ", host)? - } - write_public_key_base64(&mut file, pubkey)?; - file.write_all(b"\n")?; - Ok(()) -} - -/// Get the server key that matches the one recorded in the user's known_hosts file. -pub fn known_host_keys(host: &str, port: u16) -> Result, Error> { - known_host_keys_path(host, port, known_hosts_path()?) -} - -/// Get the server key that matches the one recorded in `path`. -pub fn known_host_keys_path>( - host: &str, - port: u16, - path: P, -) -> Result, Error> { - let mut f = if let Ok(f) = File::open(path) { - BufReader::new(f) - } else { - return Ok(vec![]); - }; - let mut buffer = String::new(); - - let host_port = if port == 22 { - Cow::Borrowed(host) - } else { - Cow::Owned(format!("[{}]:{}", host, port)) - }; - debug!("host_port = {:?}", host_port); - let mut line = 1; - let mut matches = vec![]; - while f.read_line(&mut buffer)? > 0 { - { - if buffer.as_bytes().first() == Some(&b'#') { - buffer.clear(); - continue; - } - debug!("line = {:?}", buffer); - let mut s = buffer.split(' '); - let hosts = s.next(); - let _ = s.next(); - let key = s.next(); - if let (Some(h), Some(k)) = (hosts, key) { - debug!("{:?} {:?}", h, k); - if match_hostname(&host_port, h) { - matches.push((line, parse_public_key_base64(k)?)); - } - } - } - buffer.clear(); - line += 1; - } - Ok(matches) -} - -fn match_hostname(host: &str, pattern: &str) -> bool { - for entry in pattern.split(',') { - if entry.starts_with("|1|") { - let mut parts = entry.split('|').skip(2); - let Some(Ok(salt)) = parts.next().map(|p| BASE64_MIME.decode(p.as_bytes())) else { - continue; - }; - let Some(Ok(hash)) = parts.next().map(|p| BASE64_MIME.decode(p.as_bytes())) else { - continue; - }; - if let Ok(hmac) = Hmac::::new_from_slice(&salt) { - if hmac.chain_update(host).verify_slice(&hash).is_ok() { - return true; - } - } - } else if host == entry { - return true; - } - } - false -} - -/// Check whether the host is known, from its standard location. -pub fn check_known_hosts(host: &str, port: u16, pubkey: &key::PublicKey) -> Result { - check_known_hosts_path(host, port, pubkey, known_hosts_path()?) -} - -/// Check that a server key matches the one recorded in file `path`. -pub fn check_known_hosts_path>( - host: &str, - port: u16, - pubkey: &key::PublicKey, - path: P, -) -> Result { - let check = known_host_keys_path(host, port, path)? - .into_iter() - .map( - |(line, recorded)| match (pubkey.name() == recorded.name(), *pubkey == recorded) { - (true, true) => Ok(true), - (true, false) => Err(Error::KeyChanged { line }), - _ => Ok(false), - }, - ) - // If any Err was returned, we stop here - .collect::, Error>>()? - .into_iter() - // Now we check the results for a match - .any(|x| x); - - Ok(check) -} - -#[cfg(target_os = "windows")] -fn known_hosts_path() -> Result { - if let Some(home_dir) = home::home_dir() { - Ok(home_dir.join("ssh").join("known_hosts")) - } else { - Err(Error::NoHomeDir) - } -} - -#[cfg(not(target_os = "windows"))] -fn known_hosts_path() -> Result { - if let Some(home_dir) = home::home_dir() { - Ok(home_dir.join(".ssh").join("known_hosts")) - } else { - Err(Error::NoHomeDir) - } -} - #[cfg(test)] mod test { use std::fs::File; @@ -530,6 +358,7 @@ mod test { #[cfg(unix)] use futures::Future; + use log::debug; use super::*; @@ -709,56 +538,6 @@ Ve0k2ddxoEsSE15H4lgNHM2iuYKzIqZJOReHRCTff6QGgMYPDqDfFfL1Hc1Ntql0pwAAAA ); } - #[test] - fn test_check_known_hosts() { - env_logger::try_init().unwrap_or(()); - let dir = tempdir::TempDir::new("russh").unwrap(); - let path = dir.path().join("known_hosts"); - { - let mut f = File::create(&path).unwrap(); - f.write_all(b"[localhost]:13265 ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIJdD7y3aLq454yWBdwLWbieU1ebz9/cu7/QEXn9OIeZJ\n").unwrap(); - f.write_all(b"#pijul.org,37.120.161.53 ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIA6rWI3G2sz07DnfFlrouTcysQlj2P+jpNSOEWD9OJ3X\n").unwrap(); - f.write_all(b"pijul.org,37.120.161.53 ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIA6rWI3G1sz07DnfFlrouTcysQlj2P+jpNSOEWD9OJ3X\n").unwrap(); - f.write_all(b"|1|O33ESRMWPVkMYIwJ1Uw+n877jTo=|nuuC5vEqXlEZ/8BXQR7m619W6Ak= ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILIG2T/B0l0gaqj3puu510tu9N1OkQ4znY3LYuEm5zCF\n").unwrap(); - } - - // Valid key, non-standard port. - let host = "localhost"; - let port = 13265; - let hostkey = parse_public_key_base64( - "AAAAC3NzaC1lZDI1NTE5AAAAIJdD7y3aLq454yWBdwLWbieU1ebz9/cu7/QEXn9OIeZJ", - ) - .unwrap(); - assert!(check_known_hosts_path(host, port, &hostkey, &path).unwrap()); - - // Valid key, hashed. - let host = "example.com"; - let port = 22; - let hostkey = parse_public_key_base64( - "AAAAC3NzaC1lZDI1NTE5AAAAILIG2T/B0l0gaqj3puu510tu9N1OkQ4znY3LYuEm5zCF", - ) - .unwrap(); - assert!(check_known_hosts_path(host, port, &hostkey, &path).unwrap()); - - // Valid key, several hosts, port 22 - let host = "pijul.org"; - let port = 22; - let hostkey = parse_public_key_base64( - "AAAAC3NzaC1lZDI1NTE5AAAAIA6rWI3G1sz07DnfFlrouTcysQlj2P+jpNSOEWD9OJ3X", - ) - .unwrap(); - assert!(check_known_hosts_path(host, port, &hostkey, &path).unwrap()); - - // Now with the key in a comment above, check that it's not recognized - let host = "pijul.org"; - let port = 22; - let hostkey = parse_public_key_base64( - "AAAAC3NzaC1lZDI1NTE5AAAAIA6rWI3G2sz07DnfFlrouTcysQlj2P+jpNSOEWD9OJ3X", - ) - .unwrap(); - assert!(check_known_hosts_path(host, port, &hostkey, &path).is_err()); - } - #[test] fn test_parse_p256_public_key() { env_logger::try_init().unwrap_or(()); diff --git a/russh-util/Cargo.toml b/russh-util/Cargo.toml index e44901c2..a409480f 100644 --- a/russh-util/Cargo.toml +++ b/russh-util/Cargo.toml @@ -6,21 +6,15 @@ rust-version = "1.65" [dependencies] chrono = "0.4.38" +tokio = { version = "1.17", features = ["sync", "macros"] } [dev-dependencies] futures-executor = "0.3.13" static_assertions = "1.1.0" - [target.'cfg(target_arch = "wasm32")'.dependencies] wasm-bindgen = "0.2" wasm-bindgen-futures = "0.4.43" [target.'cfg(not(target_arch = "wasm32"))'.dependencies] -tokio = { version = "1.17", features = [ - "io-util", - "macros", - "sync", - "rt-multi-thread", - "rt", -] } +tokio = { version = "1.17", features = ["io-util", "rt-multi-thread", "rt"] } diff --git a/russh-util/src/future.rs b/russh-util/src/future.rs deleted file mode 100644 index dab5f8d3..00000000 --- a/russh-util/src/future.rs +++ /dev/null @@ -1,158 +0,0 @@ -/// This file is a copy of the `ReusableBoxFuture` type from the `tokio-util` crate. -use std::alloc::Layout; -use std::fmt; -use std::future::{self, Future}; -use std::mem::{self, ManuallyDrop}; -use std::pin::Pin; -use std::ptr; -use std::task::{Context, Poll}; - -/// A reusable `Pin + Send + 'a>>`. -/// -/// This type lets you replace the future stored in the box without -/// reallocating when the size and alignment permits this. -pub struct ReusableBoxFuture<'a, T> { - boxed: Pin + Send + 'a>>, -} - -impl<'a, T> ReusableBoxFuture<'a, T> { - /// Create a new `ReusableBoxFuture` containing the provided future. - pub fn new(future: F) -> Self - where - F: Future + Send + 'a, - { - Self { - boxed: Box::pin(future), - } - } - - /// Replace the future currently stored in this box. - /// - /// This reallocates if and only if the layout of the provided future is - /// different from the layout of the currently stored future. - pub fn set(&mut self, future: F) - where - F: Future + Send + 'a, - { - if let Err(future) = self.try_set(future) { - *self = Self::new(future); - } - } - - /// Replace the future currently stored in this box. - /// - /// This function never reallocates, but returns an error if the provided - /// future has a different size or alignment from the currently stored - /// future. - pub fn try_set(&mut self, future: F) -> Result<(), F> - where - F: Future + Send + 'a, - { - // If we try to inline the contents of this function, the type checker complains because - // the bound `T: 'a` is not satisfied in the call to `pending()`. But by putting it in an - // inner function that doesn't have `T` as a generic parameter, we implicitly get the bound - // `F::Output: 'a` transitively through `F: 'a`, allowing us to call `pending()`. - #[inline(always)] - fn real_try_set<'a, F>( - this: &mut ReusableBoxFuture<'a, F::Output>, - future: F, - ) -> Result<(), F> - where - F: Future + Send + 'a, - { - // future::Pending is a ZST so this never allocates. - let boxed = mem::replace(&mut this.boxed, Box::pin(future::pending())); - reuse_pin_box(boxed, future, |boxed| this.boxed = Pin::from(boxed)) - } - - real_try_set(self, future) - } - - /// Get a pinned reference to the underlying future. - pub fn get_pin(&mut self) -> Pin<&mut (dyn Future + Send)> { - self.boxed.as_mut() - } - - /// Poll the future stored inside this box. - pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll { - self.get_pin().poll(cx) - } -} - -impl Future for ReusableBoxFuture<'_, T> { - type Output = T; - - /// Poll the future stored inside this box. - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - Pin::into_inner(self).get_pin().poll(cx) - } -} - -// The only method called on self.boxed is poll, which takes &mut self, so this -// struct being Sync does not permit any invalid access to the Future, even if -// the future is not Sync. -unsafe impl Sync for ReusableBoxFuture<'_, T> {} - -impl fmt::Debug for ReusableBoxFuture<'_, T> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("ReusableBoxFuture").finish() - } -} - -fn reuse_pin_box(boxed: Pin>, new_value: U, callback: F) -> Result -where - F: FnOnce(Box) -> O, -{ - let layout = Layout::for_value::(&*boxed); - if layout != Layout::new::() { - return Err(new_value); - } - - // SAFETY: We don't ever construct a non-pinned reference to the old `T` from now on, and we - // always drop the `T`. - let raw: *mut T = Box::into_raw(unsafe { Pin::into_inner_unchecked(boxed) }); - - // When dropping the old value panics, we still want to call `callback` — so move the rest of - // the code into a guard type. - let guard = CallOnDrop::new(|| { - let raw: *mut U = raw.cast::(); - unsafe { raw.write(new_value) }; - - // SAFETY: - // - `T` and `U` have the same layout. - // - `raw` comes from a `Box` that uses the same allocator as this one. - // - `raw` points to a valid instance of `U` (we just wrote it in). - let boxed = unsafe { Box::from_raw(raw) }; - - callback(boxed) - }); - - // Drop the old value. - unsafe { ptr::drop_in_place(raw) }; - - // Run the rest of the code. - Ok(guard.call()) -} - -struct CallOnDrop O> { - f: ManuallyDrop, -} - -impl O> CallOnDrop { - fn new(f: F) -> Self { - let f = ManuallyDrop::new(f); - Self { f } - } - fn call(self) -> O { - let mut this = ManuallyDrop::new(self); - let f = unsafe { ManuallyDrop::take(&mut this.f) }; - f() - } -} - -impl O> Drop for CallOnDrop { - fn drop(&mut self) { - let f = unsafe { ManuallyDrop::take(&mut self.f) }; - f(); - } -} diff --git a/russh-util/src/lib.rs b/russh-util/src/lib.rs index a92f8aee..ba4302eb 100644 --- a/russh-util/src/lib.rs +++ b/russh-util/src/lib.rs @@ -1,5 +1,2 @@ -#![allow(dead_code)] // To be removed when full wasm support is added. - -pub mod future; pub mod runtime; pub mod time; diff --git a/russh-util/src/runtime.rs b/russh-util/src/runtime.rs index 341e1c42..348e0aad 100644 --- a/russh-util/src/runtime.rs +++ b/russh-util/src/runtime.rs @@ -1,16 +1,110 @@ -use std::future::Future; +#[cfg(not(target_arch = "wasm32"))] +pub use native::*; +#[cfg(target_arch = "wasm32")] +pub use wasm::*; -pub fn spawn(future: F) +#[derive(Debug)] +pub struct JoinError { + #[cfg(not(target_arch = "wasm32"))] + inner: tokio::task::JoinError, + #[cfg(target_arch = "wasm32")] + inner: tokio::sync::oneshot::error::RecvError, +} + +impl std::fmt::Display for JoinError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "JoinError: {}", self.inner) + } +} + +impl std::error::Error for JoinError {} + +pub struct JoinHandle where - F: Future + 'static + Send, + T: Send, { #[cfg(target_arch = "wasm32")] + handle: tokio::sync::oneshot::Receiver, + #[cfg(not(target_arch = "wasm32"))] + handle: tokio::task::JoinHandle, +} + +#[cfg(target_arch = "wasm32")] +pub mod wasm { + + use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, + }; + + use crate::runtime::{JoinError, JoinHandle}; + + pub fn spawn(future: F) -> JoinHandle + where + F: Future + 'static + Send, + T: Send + 'static, { - wasm_bindgen_futures::spawn_local(future); + let (sender, receiver) = tokio::sync::oneshot::channel(); + wasm_bindgen_futures::spawn_local(async { + let result = future.await; + let result = sender.send(result); + if result.is_err() { + panic!("Failed to send result to receiver"); + } + }); + + JoinHandle { handle: receiver } } - #[cfg(not(target_arch = "wasm32"))] + impl Future for JoinHandle + where + T: Send, + { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match Pin::new(&mut self.handle).poll(cx) { + Poll::Ready(Ok(val)) => Poll::Ready(Ok(val)), + Poll::Ready(Err(e)) => Poll::Ready(Err(JoinError { inner: e })), + Poll::Pending => Poll::Pending, + } + } + } +} + +#[cfg(not(target_arch = "wasm32"))] +pub mod native { + + use crate::runtime::{JoinError, JoinHandle}; + + use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, + }; + + pub fn spawn(future: F) -> JoinHandle + where + F: Future + 'static + Send, + T: Send + 'static, { - tokio::spawn(future); + let handle = tokio::spawn(future); + JoinHandle { handle } + } + + impl Future for JoinHandle + where + T: Send, + { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match Pin::new(&mut self.handle).poll(cx) { + Poll::Ready(Ok(val)) => Poll::Ready(Ok(val)), + Poll::Ready(Err(e)) => Poll::Ready(Err(JoinError { inner: e })), + Poll::Pending => Poll::Pending, + } + } } } diff --git a/russh-util/src/time.rs b/russh-util/src/time.rs index b26d0982..a5e1adc2 100644 --- a/russh-util/src/time.rs +++ b/russh-util/src/time.rs @@ -4,6 +4,7 @@ pub use std::time::Instant; #[cfg(target_arch = "wasm32")] pub use wasm::Instant; +#[cfg(target_arch = "wasm32")] mod wasm { #[derive(Debug, Clone, Copy)] pub struct Instant { diff --git a/russh/Cargo.toml b/russh/Cargo.toml index fecdf24d..673d07fb 100644 --- a/russh/Cargo.toml +++ b/russh/Cargo.toml @@ -38,7 +38,6 @@ hmac = { workspace = true } log = { workspace = true } num-bigint = { version = "0.4", features = ["rand"] } once_cell = "1.13" -openssl = { workspace = true, optional = true } p256 = { version = "0.13", features = ["ecdh"] } p384 = { version = "0.13", features = ["ecdh"] } p521 = { version = "0.13", features = ["ecdh"] } @@ -53,16 +52,9 @@ ssh-encoding = { workspace = true } ssh-key = { workspace = true } subtle = "2.4" thiserror = { workspace = true } -tokio = { workspace = true, features = [ - "io-util", - "rt-multi-thread", - "time", - "net", - "sync", - "macros", - "process", -] } +russh-util = { path = "../russh-util" } des = "0.8.1" +tokio = { workspace = true, features = ["io-util", "sync", "time"] } [dev-dependencies] anyhow = "1.0" @@ -77,7 +69,6 @@ tokio = { version = "1.17.0", features = [ "sync", "macros", ] } -russh-sftp = "2.0.0-beta.2" rand = "0.8.5" shell-escape = "0.1" tokio-fd = "0.3" @@ -86,3 +77,8 @@ ratatui = "0.26.0" [package.metadata.docs.rs] features = ["openssl"] + +[target.'cfg(not(target_arch = "wasm32"))'.dependencies] +openssl = { workspace = true, optional = true } +russh-sftp = "2.0.0-beta.2" +tokio = { workspace = true } diff --git a/russh/src/auth.rs b/russh/src/auth.rs index 649d5ccd..e76bb9ea 100644 --- a/russh/src/auth.rs +++ b/russh/src/auth.rs @@ -18,7 +18,6 @@ use std::sync::Arc; use bitflags::bitflags; use ssh_key::Certificate; use thiserror::Error; -use tokio::io::{AsyncRead, AsyncWrite}; use crate::keys::{encoding, key}; use crate::CryptoVec; @@ -59,7 +58,8 @@ pub enum AgentAuthError { Key(#[from] russh_keys::Error), } -impl Signer +#[cfg(not(target_arch = "wasm32"))] +impl Signer for russh_keys::agent::client::AgentClient { type Error = AgentAuthError; @@ -128,14 +128,17 @@ impl MethodSet { #[derive(Debug)] pub struct AuthRequest { pub methods: MethodSet, + #[cfg_attr(target_arch = "wasm32", allow(dead_code))] pub partial_success: bool, pub current: Option, + #[cfg_attr(target_arch = "wasm32", allow(dead_code))] pub rejection_count: usize, } #[doc(hidden)] #[derive(Debug)] pub enum CurrentRequest { + #[cfg_attr(target_arch = "wasm32", allow(dead_code))] PublicKey { #[allow(dead_code)] key: CryptoVec, @@ -144,6 +147,7 @@ pub enum CurrentRequest { sent_pk_ok: bool, }, KeyboardInteractive { + #[cfg_attr(target_arch = "wasm32", allow(dead_code))] submethods: String, }, } diff --git a/russh/src/client/encrypted.rs b/russh/src/client/encrypted.rs index 1a3ac7ba..c7998b06 100644 --- a/russh/src/client/encrypted.rs +++ b/russh/src/client/encrypted.rs @@ -123,7 +123,7 @@ impl Session { return Err(crate::Error::Kex.into()); } self.common.write_buffer.bytes = 0; - enc.last_rekey = std::time::Instant::now(); + enc.last_rekey = russh_util::time::Instant::now(); // Ok, NEWKEYS received, now encrypted. enc.flush_all_pending(); diff --git a/russh/src/client/kex.rs b/russh/src/client/kex.rs index 92de368a..e9b469c9 100644 --- a/russh/src/client/kex.rs +++ b/russh/src/client/kex.rs @@ -69,7 +69,12 @@ impl KexInit { write_buffer: &mut SSHBuffer, ) -> Result<(), crate::Error> { self.exchange.client_kex_init.clear(); - negotiation::write_kex(&config.preferred, &mut self.exchange.client_kex_init, None)?; + negotiation::write_kex( + &config.preferred, + &mut self.exchange.client_kex_init, + #[cfg(not(target_arch = "wasm32"))] + None, + )?; self.sent = true; cipher.write(&self.exchange.client_kex_init, write_buffer); Ok(()) diff --git a/russh/src/client/mod.rs b/russh/src/client/mod.rs index daca7ed9..7082492b 100644 --- a/russh/src/client/mod.rs +++ b/russh/src/client/mod.rs @@ -47,7 +47,6 @@ use futures::Future; use log::{debug, error, info, trace}; use ssh_key::Certificate; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf}; -use tokio::net::{TcpStream, ToSocketAddrs}; use tokio::pin; use tokio::sync::mpsc::{ channel, unbounded_channel, Receiver, Sender, UnboundedReceiver, UnboundedSender, @@ -222,7 +221,7 @@ pub enum DisconnectReason + Send> { pub struct Handle { sender: Sender, receiver: UnboundedReceiver, - join: tokio::task::JoinHandle>, + join: russh_util::runtime::JoinHandle>, } impl Drop for Handle { @@ -709,12 +708,13 @@ impl Future for Handle { /// commands, etc. The future will resolve to an error if the connection fails. /// This function creates a connection to the `addr` specified using a /// [`tokio::net::TcpStream`] and then calls [`connect_stream`] under the hood. -pub async fn connect( +#[cfg(not(target_arch = "wasm32"))] +pub async fn connect( config: Arc, addrs: A, handler: H, ) -> Result, H::Error> { - let socket = TcpStream::connect(addrs) + let socket = tokio::net::TcpStream::connect(addrs) .await .map_err(crate::Error::from)?; connect_stream(config, socket, handler).await @@ -779,7 +779,7 @@ where ); session.read_ssh_id(sshid)?; let (kex_done_signal, kex_done_signal_rx) = oneshot::channel(); - let join = tokio::spawn(session.run(stream, handler, Some(kex_done_signal))); + let join = russh_util::runtime::spawn(session.run(stream, handler, Some(kex_done_signal))); if kex_done_signal_rx.await.is_err() { // kex_done_signal Sender is dropped when the session diff --git a/russh/src/kex/mod.rs b/russh/src/kex/mod.rs index c01ef42d..59a58633 100644 --- a/russh/src/kex/mod.rs +++ b/russh/src/kex/mod.rs @@ -51,6 +51,7 @@ impl Debug for dyn KexAlgorithm + Send { pub(crate) trait KexAlgorithm { fn skip_exchange(&self) -> bool; + #[cfg_attr(target_arch = "wasm32", allow(dead_code))] fn server_dh(&mut self, exchange: &mut Exchange, payload: &[u8]) -> Result<(), crate::Error>; fn client_dh( diff --git a/russh/src/lib.rs b/russh/src/lib.rs index 2605bf67..0e667ea2 100644 --- a/russh/src/lib.rs +++ b/russh/src/lib.rs @@ -150,6 +150,7 @@ mod parsing; mod session; /// Server side of this library. +#[cfg(not(target_arch = "wasm32"))] pub mod server; /// Client side of this library. @@ -293,7 +294,11 @@ pub enum Error { Decompress(#[from] flate2::DecompressError), #[error(transparent)] - Join(#[from] tokio::task::JoinError), + Join(#[from] russh_util::runtime::JoinError), + + #[cfg(not(target_arch = "wasm32"))] + #[error(transparent)] + TokioJoin(#[from] tokio::task::JoinError), #[error(transparent)] #[cfg(feature = "openssl")] @@ -518,6 +523,7 @@ pub(crate) struct ChannelParams { sender_maximum_packet_size: u32, /// Has the other side confirmed the channel? pub confirmed: bool, + #[cfg_attr(target_arch = "wasm32", allow(dead_code))] wants_reply: bool, pending_data: std::collections::VecDeque<(CryptoVec, Option, usize)>, pending_eof: bool, diff --git a/russh/src/msg.rs b/russh/src/msg.rs index f6bb67bc..f8fe1a58 100644 --- a/russh/src/msg.rs +++ b/russh/src/msg.rs @@ -13,6 +13,10 @@ // limitations under the License. // // https://tools.ietf.org/html/rfc4253#section-12 + +#[cfg(not(target_arch = "wasm32"))] +pub use server::*; + pub const DISCONNECT: u8 = 1; #[allow(dead_code)] pub const IGNORE: u8 = 2; @@ -21,7 +25,6 @@ pub const UNIMPLEMENTED: u8 = 3; #[allow(dead_code)] pub const DEBUG: u8 = 4; -pub const SERVICE_REQUEST: u8 = 5; pub const SERVICE_ACCEPT: u8 = 6; pub const EXT_INFO: u8 = 7; pub const KEXINIT: u8 = 20; @@ -36,10 +39,7 @@ pub const USERAUTH_REQUEST: u8 = 50; pub const USERAUTH_FAILURE: u8 = 51; pub const USERAUTH_SUCCESS: u8 = 52; pub const USERAUTH_BANNER: u8 = 53; -pub const USERAUTH_PK_OK: u8 = 60; -// https://tools.ietf.org/html/rfc4256#section-5 -pub const USERAUTH_INFO_REQUEST: u8 = 60; pub const USERAUTH_INFO_RESPONSE: u8 = 61; // some numbers have same meaning @@ -62,9 +62,17 @@ pub const CHANNEL_REQUEST: u8 = 98; pub const CHANNEL_SUCCESS: u8 = 99; pub const CHANNEL_FAILURE: u8 = 100; -pub const SSH_OPEN_ADMINISTRATIVELY_PROHIBITED: u8 = 1; #[allow(dead_code)] pub const SSH_OPEN_CONNECT_FAILED: u8 = 2; pub const SSH_OPEN_UNKNOWN_CHANNEL_TYPE: u8 = 3; #[allow(dead_code)] pub const SSH_OPEN_RESOURCE_SHORTAGE: u8 = 4; + +#[cfg(not(target_arch = "wasm32"))] +pub mod server { + // https://tools.ietf.org/html/rfc4256#section-5 + pub const USERAUTH_INFO_REQUEST: u8 = 60; + pub const USERAUTH_PK_OK: u8 = 60; + pub const SERVICE_REQUEST: u8 = 5; + pub const SSH_OPEN_ADMINISTRATIVELY_PROHIBITED: u8 = 1; +} diff --git a/russh/src/negotiation.rs b/russh/src/negotiation.rs index c6726c96..5931168d 100644 --- a/russh/src/negotiation.rs +++ b/russh/src/negotiation.rs @@ -23,6 +23,7 @@ use crate::kex::{EXTENSION_OPENSSH_STRICT_KEX_AS_CLIENT, EXTENSION_OPENSSH_STRIC use crate::keys::encoding::{Encoding, Reader}; use crate::keys::key; use crate::keys::key::{KeyPair, PublicKey}; +#[cfg(not(target_arch = "wasm32"))] use crate::server::Config; use crate::{cipher, compression, kex, mac, msg, AlgorithmKind, CryptoVec, Error}; @@ -378,6 +379,7 @@ impl Select for Client { } } +#[cfg(not(target_arch = "wasm32"))] pub fn write_kex( prefs: &Preferred, buf: &mut CryptoVec, @@ -432,3 +434,38 @@ pub fn write_kex( buf.extend(&[0, 0, 0, 0]); // reserved Ok(()) } + +#[cfg(target_arch = "wasm32")] +pub fn write_kex(prefs: &Preferred, buf: &mut CryptoVec) -> Result<(), Error> { + // buf.clear(); + buf.push(msg::KEXINIT); + + let mut cookie = [0; 16]; + rand::thread_rng().fill_bytes(&mut cookie); + + buf.extend(&cookie); // cookie + buf.extend_list(prefs.kex.iter().filter(|k| { + !({ + [ + crate::kex::EXTENSION_SUPPORT_AS_SERVER, + crate::kex::EXTENSION_OPENSSH_STRICT_KEX_AS_SERVER, + ] + }) + .contains(*k) + })); // kex algo + + buf.extend_list(prefs.cipher.iter()); // cipher client to server + buf.extend_list(prefs.cipher.iter()); // cipher server to client + + buf.extend_list(prefs.mac.iter()); // mac client to server + buf.extend_list(prefs.mac.iter()); // mac server to client + buf.extend_list(prefs.compression.iter()); // compress client to server + buf.extend_list(prefs.compression.iter()); // compress server to client + + buf.write_empty_list(); // languages client to server + buf.write_empty_list(); // languagesserver to client + + buf.push(0); // doesn't follow + buf.extend(&[0, 0, 0, 0]); // reserved + Ok(()) +} diff --git a/russh/src/session.rs b/russh/src/session.rs index 0a1f633b..f901aef4 100644 --- a/russh/src/session.rs +++ b/russh/src/session.rs @@ -45,7 +45,7 @@ pub(crate) struct Encrypted { pub last_channel_id: Wrapping, pub write: CryptoVec, pub write_cursor: usize, - pub last_rekey: std::time::Instant, + pub last_rekey: russh_util::time::Instant, pub server_compression: crate::compression::Compression, pub client_compression: crate::compression::Compression, pub compress: crate::compression::Compress, @@ -59,6 +59,7 @@ pub(crate) struct CommonSession { pub config: Config, pub encrypted: Option, pub auth_method: Option, + #[cfg_attr(target_arch = "wasm32", allow(dead_code))] pub(crate) auth_attempts: usize, pub write_buffer: SSHBuffer, pub kex: Option, @@ -125,7 +126,7 @@ impl CommonSession { last_channel_id: Wrapping(1), write: CryptoVec::new(), write_cursor: 0, - last_rekey: std::time::Instant::now(), + last_rekey: russh_util::time::Instant::now(), server_compression: newkeys.names.server_compression, client_compression: newkeys.names.client_compression, compress: crate::compression::Compress::None, @@ -157,6 +158,7 @@ impl CommonSession { } /// Send a single byte message onto the channel. + #[cfg(not(target_arch = "wasm32"))] pub fn byte(&mut self, channel: ChannelId, msg: u8) { if let Some(ref mut enc) = self.encrypted { enc.byte(channel, msg) @@ -425,7 +427,7 @@ impl Encrypted { return Ok(false); } - let now = std::time::Instant::now(); + let now = russh_util::time::Instant::now(); let dur = now.duration_since(self.last_rekey); Ok(write_buffer.bytes >= limits.rekey_write_limit || dur >= limits.rekey_time_limit) } @@ -502,6 +504,7 @@ pub(crate) enum Kex { Init(KexInit), /// Algorithms have been determined, the DH algorithm should run. + #[cfg_attr(target_arch = "wasm32", allow(dead_code))] Dh(KexDh), /// The kex has run. @@ -551,6 +554,7 @@ impl KexInit { } #[derive(Debug)] +#[cfg_attr(target_arch = "wasm32", allow(dead_code))] pub(crate) struct KexDh { pub exchange: Exchange, pub names: negotiation::Names, From d13a45ad960d9f2a53907713cfb693cff276abc1 Mon Sep 17 00:00:00 2001 From: Eugene Date: Fri, 20 Sep 2024 21:50:13 +0200 Subject: [PATCH 09/15] Update rust.yml --- .github/workflows/rust.yml | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 3174f403..ca5b274c 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -27,6 +27,18 @@ jobs: with: package: russh + Build-WASM: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + + - name: Install target + run: rustup target add wasm32-wasip1 + + - name: Build (all features enabled) + run: cargo build --verbose --target wasm32-wasip1 -p russh + Formatting: runs-on: ubuntu-latest From 6f00a782392e8beec4a54a53e2965bb14afc4ab1 Mon Sep 17 00:00:00 2001 From: Eugene Date: Fri, 20 Sep 2024 21:50:21 +0200 Subject: [PATCH 10/15] panic lints --- README.md | 1 + cryptovec/src/platform/unix.rs | 2 ++ 2 files changed, 3 insertions(+) diff --git a/README.md b/README.md index 68655d5b..f3e433e4 100644 --- a/README.md +++ b/README.md @@ -75,6 +75,7 @@ This is a fork of [Thrussh](https://nest.pijul.com/pijul/thrussh) by Pierre-Éti ### Panics * When the Rust allocator fails to allocate memory during a CryptoVec being resized. +* When `mlock`/`munlock` fails to protect sensitive data in memory. ### Unsafe code diff --git a/cryptovec/src/platform/unix.rs b/cryptovec/src/platform/unix.rs index e82feca3..0f7ed9e5 100644 --- a/cryptovec/src/platform/unix.rs +++ b/cryptovec/src/platform/unix.rs @@ -3,6 +3,7 @@ use libc::c_void; /// Unlock memory on drop for Unix-based systems. pub fn munlock(ptr: *const u8, len: usize) { unsafe { + #[allow(clippy::panic)] if libc::munlock(ptr as *const c_void, len) != 0 { panic!("Failed to unlock memory"); } @@ -11,6 +12,7 @@ pub fn munlock(ptr: *const u8, len: usize) { pub fn mlock(ptr: *const u8, len: usize) { unsafe { + #[allow(clippy::panic)] if libc::mlock(ptr as *const c_void, len) != 0 { panic!("Failed to lock memory"); } From 3619bf0377be2d08a6ba780012c3437ff5fa82f9 Mon Sep 17 00:00:00 2001 From: Eugene Date: Fri, 20 Sep 2024 21:50:26 +0200 Subject: [PATCH 11/15] lint --- russh-keys/src/lib.rs | 3 --- 1 file changed, 3 deletions(-) diff --git a/russh-keys/src/lib.rs b/russh-keys/src/lib.rs index f13f3ec3..edf52849 100644 --- a/russh-keys/src/lib.rs +++ b/russh-keys/src/lib.rs @@ -353,9 +353,6 @@ fn is_base64_char(c: char) -> bool { #[cfg(test)] mod test { - use std::fs::File; - use std::io::Write; - #[cfg(unix)] use futures::Future; use log::debug; From e17ee82e321e08b811609c638b88601486781edc Mon Sep 17 00:00:00 2001 From: Eugene Date: Fri, 20 Sep 2024 21:50:35 +0200 Subject: [PATCH 12/15] simplify spawn runtime --- russh-util/src/runtime.rs | 134 +++++++++++++++----------------------- 1 file changed, 53 insertions(+), 81 deletions(-) diff --git a/russh-util/src/runtime.rs b/russh-util/src/runtime.rs index 348e0aad..5c5230ca 100644 --- a/russh-util/src/runtime.rs +++ b/russh-util/src/runtime.rs @@ -1,110 +1,82 @@ -#[cfg(not(target_arch = "wasm32"))] -pub use native::*; -#[cfg(target_arch = "wasm32")] -pub use wasm::*; +use std::io::ErrorKind; +use std::ops::Deref; +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; #[derive(Debug)] -pub struct JoinError { - #[cfg(not(target_arch = "wasm32"))] - inner: tokio::task::JoinError, - #[cfg(target_arch = "wasm32")] - inner: tokio::sync::oneshot::error::RecvError, +pub struct JoinError(Box); + +impl Deref for JoinError { + type Target = dyn std::error::Error + Send + Sync; + + fn deref(&self) -> &Self::Target { + &*self.0 + } } impl std::fmt::Display for JoinError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "JoinError: {}", self.inner) + self.0.fmt(f) } } -impl std::error::Error for JoinError {} +impl Default for JoinError { + fn default() -> Self { + Self(Box::new(std::io::Error::new( + ErrorKind::Other, + "aborted".to_string(), + ))) + } +} pub struct JoinHandle where T: Send, { - #[cfg(target_arch = "wasm32")] handle: tokio::sync::oneshot::Receiver, - #[cfg(not(target_arch = "wasm32"))] - handle: tokio::task::JoinHandle, } #[cfg(target_arch = "wasm32")] -pub mod wasm { - - use std::{ - future::Future, - pin::Pin, - task::{Context, Poll}, +macro_rules! spawn_impl { + ($fn:expr) => { + wasm_bindgen_futures::spawn_local($fn) }; - - use crate::runtime::{JoinError, JoinHandle}; - - pub fn spawn(future: F) -> JoinHandle - where - F: Future + 'static + Send, - T: Send + 'static, - { - let (sender, receiver) = tokio::sync::oneshot::channel(); - wasm_bindgen_futures::spawn_local(async { - let result = future.await; - let result = sender.send(result); - if result.is_err() { - panic!("Failed to send result to receiver"); - } - }); - - JoinHandle { handle: receiver } - } - - impl Future for JoinHandle - where - T: Send, - { - type Output = Result; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match Pin::new(&mut self.handle).poll(cx) { - Poll::Ready(Ok(val)) => Poll::Ready(Ok(val)), - Poll::Ready(Err(e)) => Poll::Ready(Err(JoinError { inner: e })), - Poll::Pending => Poll::Pending, - } - } - } } #[cfg(not(target_arch = "wasm32"))] -pub mod native { - - use crate::runtime::{JoinError, JoinHandle}; - - use std::{ - future::Future, - pin::Pin, - task::{Context, Poll}, +macro_rules! spawn_impl { + ($fn:expr) => { + tokio::spawn($fn) }; +} - pub fn spawn(future: F) -> JoinHandle - where - F: Future + 'static + Send, - T: Send + 'static, - { - let handle = tokio::spawn(future); - JoinHandle { handle } - } +pub fn spawn(future: F) -> JoinHandle +where + F: Future + 'static + Send, + T: Send + 'static, +{ + let (sender, receiver) = tokio::sync::oneshot::channel(); + spawn_impl!(async { + let result = future.await; + let _ = sender.send(result); + }); + JoinHandle { handle: receiver } +} - impl Future for JoinHandle - where - T: Send, - { - type Output = Result; +impl Future for JoinHandle +where + T: Send, +{ + type Output = Result; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match Pin::new(&mut self.handle).poll(cx) { - Poll::Ready(Ok(val)) => Poll::Ready(Ok(val)), - Poll::Ready(Err(e)) => Poll::Ready(Err(JoinError { inner: e })), - Poll::Pending => Poll::Pending, - } + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match Pin::new(&mut self.handle).poll(cx) { + Poll::Ready(Ok(val)) => Poll::Ready(Ok(val)), + Poll::Ready(Err(_)) => Poll::Ready(Err(JoinError::default())), + Poll::Pending => Poll::Pending, } } } From dd4e1eb520a151e80a059b2cabbd6ecb96a78e7a Mon Sep 17 00:00:00 2001 From: Eugene Date: Fri, 20 Sep 2024 21:53:56 +0200 Subject: [PATCH 13/15] un-gate agent impl --- russh-keys/Cargo.toml | 6 +++++- russh-keys/src/agent/server.rs | 4 ++-- russh-keys/src/lib.rs | 1 - russh/src/auth.rs | 4 ++-- 4 files changed, 9 insertions(+), 6 deletions(-) diff --git a/russh-keys/Cargo.toml b/russh-keys/Cargo.toml index aca59e5f..9e027afc 100644 --- a/russh-keys/Cargo.toml +++ b/russh-keys/Cargo.toml @@ -43,6 +43,7 @@ rand = { workspace = true } rand_core = { version = "0.6.4", features = ["std"] } rsa = "0.9" russh-cryptovec = { version = "0.7.0", path = "../cryptovec" } +russh-util = { version = "0.1.0", path = "../russh-util" } sec1 = { version = "0.7", features = ["pkcs8"] } serde = { version = "1.0", features = ["derive"] } sha1 = { workspace = true } @@ -55,6 +56,10 @@ typenum = "1.17" yasna = { version = "0.5.0", features = ["bit-vec", "num-bigint"], optional = true } zeroize = "1.7" getrandom = { version = "0.2.15", features = ["js"] } +tokio = { workspace = true, features = [ + "io-util", + "time", +] } [features] vendored-openssl = ["openssl", "openssl/vendored"] @@ -81,4 +86,3 @@ tokio = { workspace = true, features = ["test-util", "macros", "process"] } [package.metadata.docs.rs] features = ["openssl"] - diff --git a/russh-keys/src/agent/server.rs b/russh-keys/src/agent/server.rs index be89509a..e52306a5 100644 --- a/russh-keys/src/agent/server.rs +++ b/russh-keys/src/agent/server.rs @@ -66,7 +66,7 @@ where while let Some(Ok(stream)) = listener.next().await { let mut buf = CryptoVec::new(); buf.resize(4); - tokio::spawn( + russh_util::runtime::spawn( (Connection { lock: lock.clone(), keys: keys.clone(), @@ -276,7 +276,7 @@ impl Signer +impl Signer for russh_keys::agent::client::AgentClient { type Error = AgentAuthError; From 6811f1d27b2fe5a0add70fb46a08be8a70af7c9c Mon Sep 17 00:00:00 2001 From: Eugene Date: Fri, 20 Sep 2024 22:03:29 +0200 Subject: [PATCH 14/15] cleanup --- russh-util/src/runtime.rs | 25 ++++------------------- russh/src/client/kex.rs | 1 - russh/src/lib.rs | 4 ---- russh/src/msg.rs | 2 +- russh/src/negotiation.rs | 42 ++++++--------------------------------- russh/src/server/mod.rs | 6 +++--- 6 files changed, 14 insertions(+), 66 deletions(-) diff --git a/russh-util/src/runtime.rs b/russh-util/src/runtime.rs index 5c5230ca..d42183d7 100644 --- a/russh-util/src/runtime.rs +++ b/russh-util/src/runtime.rs @@ -1,5 +1,3 @@ -use std::io::ErrorKind; -use std::ops::Deref; use std::{ future::Future, pin::Pin, @@ -7,30 +5,15 @@ use std::{ }; #[derive(Debug)] -pub struct JoinError(Box); - -impl Deref for JoinError { - type Target = dyn std::error::Error + Send + Sync; - - fn deref(&self) -> &Self::Target { - &*self.0 - } -} +pub struct JoinError; impl std::fmt::Display for JoinError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - self.0.fmt(f) + write!(f, "JoinError") } } -impl Default for JoinError { - fn default() -> Self { - Self(Box::new(std::io::Error::new( - ErrorKind::Other, - "aborted".to_string(), - ))) - } -} +impl std::error::Error for JoinError {} pub struct JoinHandle where @@ -75,7 +58,7 @@ where fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match Pin::new(&mut self.handle).poll(cx) { Poll::Ready(Ok(val)) => Poll::Ready(Ok(val)), - Poll::Ready(Err(_)) => Poll::Ready(Err(JoinError::default())), + Poll::Ready(Err(_)) => Poll::Ready(Err(JoinError)), Poll::Pending => Poll::Pending, } } diff --git a/russh/src/client/kex.rs b/russh/src/client/kex.rs index e9b469c9..8f73b7ed 100644 --- a/russh/src/client/kex.rs +++ b/russh/src/client/kex.rs @@ -72,7 +72,6 @@ impl KexInit { negotiation::write_kex( &config.preferred, &mut self.exchange.client_kex_init, - #[cfg(not(target_arch = "wasm32"))] None, )?; self.sent = true; diff --git a/russh/src/lib.rs b/russh/src/lib.rs index 0e667ea2..fed29b36 100644 --- a/russh/src/lib.rs +++ b/russh/src/lib.rs @@ -296,10 +296,6 @@ pub enum Error { #[error(transparent)] Join(#[from] russh_util::runtime::JoinError), - #[cfg(not(target_arch = "wasm32"))] - #[error(transparent)] - TokioJoin(#[from] tokio::task::JoinError), - #[error(transparent)] #[cfg(feature = "openssl")] Openssl(#[from] openssl::error::ErrorStack), diff --git a/russh/src/msg.rs b/russh/src/msg.rs index f8fe1a58..52311a82 100644 --- a/russh/src/msg.rs +++ b/russh/src/msg.rs @@ -69,7 +69,7 @@ pub const SSH_OPEN_UNKNOWN_CHANNEL_TYPE: u8 = 3; pub const SSH_OPEN_RESOURCE_SHORTAGE: u8 = 4; #[cfg(not(target_arch = "wasm32"))] -pub mod server { +mod server { // https://tools.ietf.org/html/rfc4256#section-5 pub const USERAUTH_INFO_REQUEST: u8 = 60; pub const USERAUTH_PK_OK: u8 = 60; diff --git a/russh/src/negotiation.rs b/russh/src/negotiation.rs index 5931168d..2bec9f4a 100644 --- a/russh/src/negotiation.rs +++ b/russh/src/negotiation.rs @@ -27,6 +27,12 @@ use crate::keys::key::{KeyPair, PublicKey}; use crate::server::Config; use crate::{cipher, compression, kex, mac, msg, AlgorithmKind, CryptoVec, Error}; +#[cfg(target_arch = "wasm32")] +/// WASM-only stub +pub struct Config { + keys: Vec, +} + #[derive(Debug, Clone)] pub struct Names { pub kex: kex::Name, @@ -379,7 +385,6 @@ impl Select for Client { } } -#[cfg(not(target_arch = "wasm32"))] pub fn write_kex( prefs: &Preferred, buf: &mut CryptoVec, @@ -434,38 +439,3 @@ pub fn write_kex( buf.extend(&[0, 0, 0, 0]); // reserved Ok(()) } - -#[cfg(target_arch = "wasm32")] -pub fn write_kex(prefs: &Preferred, buf: &mut CryptoVec) -> Result<(), Error> { - // buf.clear(); - buf.push(msg::KEXINIT); - - let mut cookie = [0; 16]; - rand::thread_rng().fill_bytes(&mut cookie); - - buf.extend(&cookie); // cookie - buf.extend_list(prefs.kex.iter().filter(|k| { - !({ - [ - crate::kex::EXTENSION_SUPPORT_AS_SERVER, - crate::kex::EXTENSION_OPENSSH_STRICT_KEX_AS_SERVER, - ] - }) - .contains(*k) - })); // kex algo - - buf.extend_list(prefs.cipher.iter()); // cipher client to server - buf.extend_list(prefs.cipher.iter()); // cipher server to client - - buf.extend_list(prefs.mac.iter()); // mac client to server - buf.extend_list(prefs.mac.iter()); // mac server to client - buf.extend_list(prefs.compression.iter()); // compress client to server - buf.extend_list(prefs.compression.iter()); // compress server to client - - buf.write_empty_list(); // languages client to server - buf.write_empty_list(); // languagesserver to client - - buf.push(0); // doesn't follow - buf.extend(&[0, 0, 0, 0]); // reserved - Ok(()) -} diff --git a/russh/src/server/mod.rs b/russh/src/server/mod.rs index a5e21508..e345d86a 100644 --- a/russh/src/server/mod.rs +++ b/russh/src/server/mod.rs @@ -38,10 +38,10 @@ use std::task::{Context, Poll}; use async_trait::async_trait; use futures::future::Future; use log::{debug, error}; +use russh_util::runtime::JoinHandle; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::net::{TcpListener, ToSocketAddrs}; use tokio::pin; -use tokio::task::JoinHandle; use crate::cipher::{clear, CipherPair, OpeningKey}; use crate::keys::key; @@ -577,7 +577,7 @@ pub trait Server { let config = config.clone(); let handler = self.new_client(socket.peer_addr().ok()); let error_tx = error_tx.clone(); - tokio::spawn(async move { + russh_util::runtime::spawn(async move { let session = match run_stream(config, socket, handler).await { Ok(s) => s, Err(e) => { @@ -698,7 +698,7 @@ where channels: HashMap::new(), open_global_requests: VecDeque::new(), }; - let join = tokio::spawn(session.run(stream, handler)); + let join = russh_util::runtime::spawn(session.run(stream, handler)); Ok(RunningSession { handle, join }) } From d48619c25332993775ed44f79558fc18ddbfdec9 Mon Sep 17 00:00:00 2001 From: Eugene Date: Fri, 20 Sep 2024 22:08:20 +0200 Subject: [PATCH 15/15] fmt --- cryptovec/src/cryptovec.rs | 3 ++- cryptovec/src/platform/mod.rs | 6 ++---- cryptovec/src/platform/windows.rs | 6 +++--- russh-keys/src/known_hosts.rs | 5 +++-- russh-util/src/runtime.rs | 8 +++----- russh/src/client/kex.rs | 6 +----- 6 files changed, 14 insertions(+), 20 deletions(-) diff --git a/cryptovec/src/cryptovec.rs b/cryptovec/src/cryptovec.rs index 21709673..1821ddab 100644 --- a/cryptovec/src/cryptovec.rs +++ b/cryptovec/src/cryptovec.rs @@ -1,6 +1,7 @@ -use crate::platform::{self, memcpy, memset, mlock, munlock}; use std::ops::{Deref, DerefMut, Index, IndexMut, Range, RangeFrom, RangeFull, RangeTo}; +use crate::platform::{self, memcpy, memset, mlock, munlock}; + /// A buffer which zeroes its memory on `.clear()`, `.resize()`, and /// reallocations, to avoid copying secrets around. #[derive(Debug)] diff --git a/cryptovec/src/platform/mod.rs b/cryptovec/src/platform/mod.rs index 63b715a7..78fbbffe 100644 --- a/cryptovec/src/platform/mod.rs +++ b/cryptovec/src/platform/mod.rs @@ -9,15 +9,13 @@ mod unix; mod wasm; // Re-export functions based on the platform -#[cfg(windows)] -pub use windows::{memcpy, memset, mlock, munlock}; - #[cfg(not(windows))] #[cfg(not(target_arch = "wasm32"))] pub use unix::{memcpy, memset, mlock, munlock}; - #[cfg(target_arch = "wasm32")] pub use wasm::{memcpy, memset, mlock, munlock}; +#[cfg(windows)] +pub use windows::{memcpy, memset, mlock, munlock}; #[cfg(test)] mod tests { diff --git a/cryptovec/src/platform/windows.rs b/cryptovec/src/platform/windows.rs index 968a3e5d..72e0b8f2 100644 --- a/cryptovec/src/platform/windows.rs +++ b/cryptovec/src/platform/windows.rs @@ -1,7 +1,7 @@ -use winapi::shared::{basetsd::SIZE_T, minwindef::LPVOID}; -use winapi::um::memoryapi::{VirtualLock, VirtualUnlock}; - use libc::c_void; +use winapi::shared::basetsd::SIZE_T; +use winapi::shared::minwindef::LPVOID; +use winapi::um::memoryapi::{VirtualLock, VirtualUnlock}; /// Unlock memory on drop for Windows. pub fn munlock(ptr: *const u8, len: usize) { diff --git a/russh-keys/src/known_hosts.rs b/russh-keys/src/known_hosts.rs index ee9d01f7..2355e970 100644 --- a/russh-keys/src/known_hosts.rs +++ b/russh-keys/src/known_hosts.rs @@ -3,12 +3,13 @@ use std::fs::{File, OpenOptions}; use std::io::{BufRead, BufReader, Read, Seek, SeekFrom, Write}; use std::path::{Path, PathBuf}; -use crate::{key, Error, PublicKeyBase64}; use data_encoding::BASE64_MIME; use hmac::{Hmac, Mac}; use log::debug; use sha1::Sha1; +use crate::{key, Error, PublicKeyBase64}; + /// Check whether the host is known, from its standard location. pub fn check_known_hosts(host: &str, port: u16, pubkey: &key::PublicKey) -> Result { check_known_hosts_path(host, port, pubkey, known_hosts_path()?) @@ -188,10 +189,10 @@ pub fn write_public_key_base64( #[cfg(test)] mod test { - use crate::parse_public_key_base64; use std::fs::File; use super::*; + use crate::parse_public_key_base64; #[test] fn test_check_known_hosts() { diff --git a/russh-util/src/runtime.rs b/russh-util/src/runtime.rs index d42183d7..ad6d280a 100644 --- a/russh-util/src/runtime.rs +++ b/russh-util/src/runtime.rs @@ -1,8 +1,6 @@ -use std::{ - future::Future, - pin::Pin, - task::{Context, Poll}, -}; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; #[derive(Debug)] pub struct JoinError; diff --git a/russh/src/client/kex.rs b/russh/src/client/kex.rs index 8f73b7ed..92de368a 100644 --- a/russh/src/client/kex.rs +++ b/russh/src/client/kex.rs @@ -69,11 +69,7 @@ impl KexInit { write_buffer: &mut SSHBuffer, ) -> Result<(), crate::Error> { self.exchange.client_kex_init.clear(); - negotiation::write_kex( - &config.preferred, - &mut self.exchange.client_kex_init, - None, - )?; + negotiation::write_kex(&config.preferred, &mut self.exchange.client_kex_init, None)?; self.sent = true; cipher.write(&self.exchange.client_kex_init, write_buffer); Ok(())