diff --git a/.vscode/settings.json b/.vscode/settings.json index df89c94..dc9e82b 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,7 +1,8 @@ { "editor.formatOnSave": true, "rust-analyzer.cargo.features": [ - "std" + "std", + "plot" ], "rust-analyzer.diagnostics.disabled": [ "mismatched-arg-count" diff --git a/sci-rs/Cargo.toml b/sci-rs/Cargo.toml index 08aa74b..244a884 100644 --- a/sci-rs/Cargo.toml +++ b/sci-rs/Cargo.toml @@ -17,16 +17,23 @@ all-features = true [features] default = ['alloc'] + +# Allow allocating vecs, matrices, etc. alloc = ['nalgebra/alloc', 'nalgebra/libm', 'kalmanfilt/alloc'] + +# Enable FFT and standard library features std = ['nalgebra/std', 'nalgebra/macros', 'rustfft', 'alloc'] +# Enable debug plotting through python system calls +plot = ['std'] + [dependencies] num-traits = { version = "0.2.15", default-features = false } itertools = { version = "0.13.0", default-features = false } nalgebra = { version = "0.33.2", default-features = false } ndarray = { version = "0.16.1", default-features = false } lstsq = { version = "0.6.0", default-features = false } -rustfft = { version = "6.1.0", features = ["neon"], optional = true } +rustfft = { version = "6.2.0", optional = true } kalmanfilt = { version = "0.3.0", default-features = false } gaussfilt = { version = "0.1.3", default-features = false } diff --git a/sci-rs/src/lib.rs b/sci-rs/src/lib.rs index a257494..7e47de0 100644 --- a/sci-rs/src/lib.rs +++ b/sci-rs/src/lib.rs @@ -37,3 +37,7 @@ pub mod stats; /// Special math functions pub mod special; + +/// Debug plotting +#[cfg(feature = "plot")] +pub mod plot; diff --git a/sci-rs/src/plot.rs b/sci-rs/src/plot.rs new file mode 100644 index 0000000..02ff7e7 --- /dev/null +++ b/sci-rs/src/plot.rs @@ -0,0 +1,58 @@ +use std::io::Write; + +/// Debug utility function that will run a python script to plot the data. +/// +/// This function generates a Python script to create plots of the input data and their autocorrelations. +/// It then executes the script using the system's Python interpreter. +/// +/// Note: This function will open a new window to display the plots and will block execution until the window is closed. +/// It also suppresses stdout and stderr from the Python process. +pub fn python_plot(xs: Vec<&[f32]>) { + let script = format!( + r#" +import matplotlib.pyplot as plt +import matplotlib.gridspec as gridspec +from scipy.signal import correlate + +xs = {:?} +fig = plt.figure(figsize=(12, 12)) +gs = gridspec.GridSpec(len(xs), 2) +for i, x in enumerate(xs): + ax = plt.subplot(gs[i, 0]) + ax.plot(x, label = f"C{{i}}") + ax.legend() + ax.set_xlabel("Samples") + ax = plt.subplot(gs[i, 1]) + autocorr = correlate(x, x, mode='full') + normcorr = autocorr / autocorr.max() + offsets = range(-len(x) + 1, len(x)) + ax.plot(offsets, normcorr, label = f"Autocorrelation of C{{i}}") + ax.legend() + ax.set_xlabel("Lag") +plt.show() +"#, + xs + ); + // Run the script with python + let script = script.as_bytes(); + let mut python = match std::process::Command::new("python") + .stdin(std::process::Stdio::piped()) + .stdout(std::process::Stdio::null()) // noisy + .stderr(std::process::Stdio::null()) // noisy + .spawn() + { + Ok(p) => p, + Err(_) => return, // Return early if python fails to start + }; + + if let Some(mut stdin) = python.stdin.take() { + if stdin.write_all(script).is_err() { + return; // Return early if writing fails + } + } else { + return; // Return early if we can't get stdin + } + + // Wait for the python process to finish, ignoring any errors + let _ = python.wait(); // Ignore errors as this may be called from CI +} diff --git a/sci-rs/src/signal/convolve.rs b/sci-rs/src/signal/convolve.rs new file mode 100644 index 0000000..ab925e8 --- /dev/null +++ b/sci-rs/src/signal/convolve.rs @@ -0,0 +1,202 @@ +use nalgebra::Complex; +use num_traits::{Float, FromPrimitive, Signed, Zero}; +use rustfft::{FftNum, FftPlanner}; + +/// Convolution mode determines behavior near edges and output size +pub enum ConvolveMode { + /// Full convolution, output size is `in1.len() + in2.len() - 1` + Full, + /// Valid convolution, output size is `max(in1.len(), in2.len()) - min(in1.len(), in2.len()) + 1` + Valid, + /// Same convolution, output size is `in1.len()` + Same, +} + +/// Performs FFT-based convolution on two slices of floating point values. +/// +/// According to Python docs, this is generally much faster than direct convolution +/// for large arrays (n > ~500), but can be slower when only a few output values are needed. +/// We only implement the FFT version in Rust for now. +/// +/// # Arguments +/// - `in1`: First input signal +/// - `in2`: Second input signal +/// - `mode`: Convolution mode (currently only Full is supported) +/// +/// # Returns +/// A Vec containing the discrete linear convolution of `in1` with `in2`. +/// For Full mode, the output length will be `in1.len() + in2.len() - 1`. +pub fn fftconvolve(in1: &[F], in2: &[F], mode: ConvolveMode) -> Vec { + // Determine the size of the FFT (next power of 2 for zero-padding) + let n1 = in1.len(); + let n2 = in2.len(); + let n = n1 + n2 - 1; + let fft_size = n.next_power_of_two(); + + // Prepare input buffers as Complex with zero-padding to fft_size + let mut padded_in1 = vec![Complex::zero(); fft_size]; + let mut padded_in2 = vec![Complex::zero(); fft_size]; + + // Copy input data into zero-padded buffers + padded_in1.iter_mut().zip(in1.iter()).for_each(|(p, &v)| { + *p = Complex::new(v, F::zero()); + }); + padded_in2.iter_mut().zip(in2.iter()).for_each(|(p, &v)| { + *p = Complex::new(v, F::zero()); + }); + + // Perform the FFT + let mut planner = FftPlanner::new(); + let fft = planner.plan_fft_forward(fft_size); + fft.process(&mut padded_in1); + fft.process(&mut padded_in2); + + // Multiply element-wise in the frequency domain + let mut result_freq: Vec> = padded_in1 + .iter() + .zip(&padded_in2) + .map(|(a, b)| a * b) + .collect(); + + // Perform the inverse FFT + let ifft = planner.plan_fft_inverse(fft_size); + ifft.process(&mut result_freq); + + // Take only the real part, normalize, and truncate to the original output size (n) + let fft_size = F::from(fft_size).unwrap(); + let full_convolution = result_freq + .iter() + .take(n) + .map(|x| x.re / fft_size) + .collect(); + + // Extract the appropriate slice based on the mode + match mode { + ConvolveMode::Full => full_convolution, + ConvolveMode::Valid => { + if n1 >= n2 { + full_convolution[(n2 - 1)..(n1)].to_vec() + } else { + Vec::new() + } + } + ConvolveMode::Same => { + let start = (n2 - 1) / 2; + let end = start + n1; + full_convolution[start..end].to_vec() + } + } +} + +/// Compute the convolution of two signals using FFT. +/// +/// # Arguments +/// * `in1` - First input array +/// * `in2` - Second input array +/// +/// # Returns +/// A Vec containing the convolution of `in1` with `in2`. +/// With Full mode, the output length will be `in1.len() + in2.len() - 1`. +pub fn convolve(in1: &[F], in2: &[F], mode: ConvolveMode) -> Vec { + fftconvolve(in1, in2, mode) +} + +/// Compute the cross-correlation of two signals using FFT. +/// +/// Cross-correlation is similar to convolution but with flipping one of the signals. +/// This function uses FFT to compute the correlation efficiently. +/// +/// # Arguments +/// * `in1` - First input array +/// * `in2` - Second input array +/// +/// # Returns +/// A Vec containing the cross-correlation of `in1` with `in2`. +/// With Full mode, the output length will be `in1.len() + in2.len() - 1`. +pub fn correlate(in1: &[F], in2: &[F], mode: ConvolveMode) -> Vec { + // For correlation, we need to reverse in2 + let mut in2_rev = in2.to_vec(); + in2_rev.reverse(); + fftconvolve(in1, &in2_rev, mode) +} + +#[cfg(test)] +mod tests { + use super::*; + use approx::assert_relative_eq; + + #[test] + fn test_convolve() { + let in1 = vec![1.0, 2.0, 3.0]; + let in2 = vec![4.0, 5.0, 6.0]; + let result = convolve(&in1, &in2, ConvolveMode::Full); + let expected = vec![4.0, 13.0, 28.0, 27.0, 18.0]; + + for (a, b) in result.iter().zip(expected.iter()) { + assert_relative_eq!(a, b, epsilon = 1e-10); + } + } + + #[test] + fn test_correlate() { + let in1 = vec![1.0, 2.0, 3.0]; + let in2 = vec![4.0, 5.0, 6.0]; + let result = correlate(&in1, &in2, ConvolveMode::Full); + let expected = vec![6.0, 17.0, 32.0, 23.0, 12.0]; + for (a, b) in result.iter().zip(expected.iter()) { + assert_relative_eq!(a, b, epsilon = 1e-10); + } + } + + #[test] + fn test_convolve_valid() { + let in1 = vec![1.0, 2.0, 3.0, 4.0]; + let in2 = vec![1.0, 2.0]; + let result = convolve(&in1, &in2, ConvolveMode::Valid); + let expected = vec![4.0, 7.0, 10.0]; + for (a, b) in result.iter().zip(expected.iter()) { + assert_relative_eq!(a, b, epsilon = 1e-10); + } + } + + #[test] + fn test_convolve_same() { + let in1 = vec![1.0, 2.0, 3.0, 4.0]; + let in2 = vec![1.0, 2.0, 1.0]; + let result = convolve(&in1, &in2, ConvolveMode::Same); + let expected = vec![4.0, 8.0, 12.0, 11.0]; + for (a, b) in result.iter().zip(expected.iter()) { + assert_relative_eq!(a, b, epsilon = 1e-10); + } + } + + #[test] + fn test_scipy_example() { + use rand::distributions::{Distribution, Standard}; + use rand::thread_rng; + + // Generate 1000 random samples from standard normal distribution + let mut rng = thread_rng(); + let sig: Vec = Standard.sample_iter(&mut rng).take(1000).collect(); + + // Compute autocorrelation using correlate directly + let autocorr = correlate(&sig, &sig, ConvolveMode::Full); + + // Basic sanity checks + assert_eq!(autocorr.len(), 1999); // Full convolution length should be 2N-1 + assert!(autocorr.iter().all(|&x| !x.is_nan())); // No NaN values + + // Maximum correlation should be near the middle since it's autocorrelation + let max_idx = autocorr + .iter() + .enumerate() + .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap()) + .unwrap() + .0; + assert!((max_idx as i32 - 999).abs() <= 1); // Should be near index 999 + + let sig: Vec = sig.iter().map(|x| *x as f32).collect(); + let autocorr: Vec = autocorr.iter().map(|x| *x as f32).collect(); + crate::plot::python_plot(vec![&sig, &autocorr]); + } +} diff --git a/sci-rs/src/signal/mod.rs b/sci-rs/src/signal/mod.rs index 2e64722..8aede03 100644 --- a/sci-rs/src/signal/mod.rs +++ b/sci-rs/src/signal/mod.rs @@ -4,6 +4,10 @@ pub mod filter; /// Signal Generation pub mod wave; +/// Convolution +#[cfg(feature = "std")] +pub mod convolve; + /// Signal Resampling #[cfg(feature = "std")] pub mod resample;