Skip to content

Commit

Permalink
Merge pull request #58 from qsib-cbie/user/jtrueb/fftconvolve
Browse files Browse the repository at this point in the history
Convolution and Correlation via FFT
  • Loading branch information
trueb2 authored Nov 17, 2024
2 parents a199276 + cc4c9e6 commit 924f7ae
Show file tree
Hide file tree
Showing 6 changed files with 278 additions and 2 deletions.
3 changes: 2 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
{
"editor.formatOnSave": true,
"rust-analyzer.cargo.features": [
"std"
"std",
"plot"
],
"rust-analyzer.diagnostics.disabled": [
"mismatched-arg-count"
Expand Down
9 changes: 8 additions & 1 deletion sci-rs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,23 @@ all-features = true

[features]
default = ['alloc']

# Allow allocating vecs, matrices, etc.
alloc = ['nalgebra/alloc', 'nalgebra/libm', 'kalmanfilt/alloc']

# Enable FFT and standard library features
std = ['nalgebra/std', 'nalgebra/macros', 'rustfft', 'alloc']

# Enable debug plotting through python system calls
plot = ['std']

[dependencies]
num-traits = { version = "0.2.15", default-features = false }
itertools = { version = "0.13.0", default-features = false }
nalgebra = { version = "0.33.2", default-features = false }
ndarray = { version = "0.16.1", default-features = false }
lstsq = { version = "0.6.0", default-features = false }
rustfft = { version = "6.1.0", features = ["neon"], optional = true }
rustfft = { version = "6.2.0", optional = true }
kalmanfilt = { version = "0.3.0", default-features = false }
gaussfilt = { version = "0.1.3", default-features = false }

Expand Down
4 changes: 4 additions & 0 deletions sci-rs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,7 @@ pub mod stats;

/// Special math functions
pub mod special;

/// Debug plotting
#[cfg(feature = "plot")]
pub mod plot;
58 changes: 58 additions & 0 deletions sci-rs/src/plot.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
use std::io::Write;

/// Debug utility function that will run a python script to plot the data.
///
/// This function generates a Python script to create plots of the input data and their autocorrelations.
/// It then executes the script using the system's Python interpreter.
///
/// Note: This function will open a new window to display the plots and will block execution until the window is closed.
/// It also suppresses stdout and stderr from the Python process.
pub fn python_plot(xs: Vec<&[f32]>) {
let script = format!(
r#"
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from scipy.signal import correlate
xs = {:?}
fig = plt.figure(figsize=(12, 12))
gs = gridspec.GridSpec(len(xs), 2)
for i, x in enumerate(xs):
ax = plt.subplot(gs[i, 0])
ax.plot(x, label = f"C{{i}}")
ax.legend()
ax.set_xlabel("Samples")
ax = plt.subplot(gs[i, 1])
autocorr = correlate(x, x, mode='full')
normcorr = autocorr / autocorr.max()
offsets = range(-len(x) + 1, len(x))
ax.plot(offsets, normcorr, label = f"Autocorrelation of C{{i}}")
ax.legend()
ax.set_xlabel("Lag")
plt.show()
"#,
xs
);
// Run the script with python
let script = script.as_bytes();
let mut python = match std::process::Command::new("python")
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::null()) // noisy
.stderr(std::process::Stdio::null()) // noisy
.spawn()
{
Ok(p) => p,
Err(_) => return, // Return early if python fails to start
};

if let Some(mut stdin) = python.stdin.take() {
if stdin.write_all(script).is_err() {
return; // Return early if writing fails
}
} else {
return; // Return early if we can't get stdin
}

// Wait for the python process to finish, ignoring any errors
let _ = python.wait(); // Ignore errors as this may be called from CI
}
202 changes: 202 additions & 0 deletions sci-rs/src/signal/convolve.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
use nalgebra::Complex;
use num_traits::{Float, FromPrimitive, Signed, Zero};
use rustfft::{FftNum, FftPlanner};

/// Convolution mode determines behavior near edges and output size
pub enum ConvolveMode {
/// Full convolution, output size is `in1.len() + in2.len() - 1`
Full,
/// Valid convolution, output size is `max(in1.len(), in2.len()) - min(in1.len(), in2.len()) + 1`
Valid,
/// Same convolution, output size is `in1.len()`
Same,
}

/// Performs FFT-based convolution on two slices of floating point values.
///
/// According to Python docs, this is generally much faster than direct convolution
/// for large arrays (n > ~500), but can be slower when only a few output values are needed.
/// We only implement the FFT version in Rust for now.
///
/// # Arguments
/// - `in1`: First input signal
/// - `in2`: Second input signal
/// - `mode`: Convolution mode (currently only Full is supported)
///
/// # Returns
/// A Vec containing the discrete linear convolution of `in1` with `in2`.
/// For Full mode, the output length will be `in1.len() + in2.len() - 1`.
pub fn fftconvolve<F: Float + FftNum>(in1: &[F], in2: &[F], mode: ConvolveMode) -> Vec<F> {
// Determine the size of the FFT (next power of 2 for zero-padding)
let n1 = in1.len();
let n2 = in2.len();
let n = n1 + n2 - 1;
let fft_size = n.next_power_of_two();

// Prepare input buffers as Complex<F> with zero-padding to fft_size
let mut padded_in1 = vec![Complex::zero(); fft_size];
let mut padded_in2 = vec![Complex::zero(); fft_size];

// Copy input data into zero-padded buffers
padded_in1.iter_mut().zip(in1.iter()).for_each(|(p, &v)| {
*p = Complex::new(v, F::zero());
});
padded_in2.iter_mut().zip(in2.iter()).for_each(|(p, &v)| {
*p = Complex::new(v, F::zero());
});

// Perform the FFT
let mut planner = FftPlanner::new();
let fft = planner.plan_fft_forward(fft_size);
fft.process(&mut padded_in1);
fft.process(&mut padded_in2);

// Multiply element-wise in the frequency domain
let mut result_freq: Vec<Complex<F>> = padded_in1
.iter()
.zip(&padded_in2)
.map(|(a, b)| a * b)
.collect();

// Perform the inverse FFT
let ifft = planner.plan_fft_inverse(fft_size);
ifft.process(&mut result_freq);

// Take only the real part, normalize, and truncate to the original output size (n)
let fft_size = F::from(fft_size).unwrap();
let full_convolution = result_freq
.iter()
.take(n)
.map(|x| x.re / fft_size)
.collect();

// Extract the appropriate slice based on the mode
match mode {
ConvolveMode::Full => full_convolution,
ConvolveMode::Valid => {
if n1 >= n2 {
full_convolution[(n2 - 1)..(n1)].to_vec()
} else {
Vec::new()
}
}
ConvolveMode::Same => {
let start = (n2 - 1) / 2;
let end = start + n1;
full_convolution[start..end].to_vec()
}
}
}

/// Compute the convolution of two signals using FFT.
///
/// # Arguments
/// * `in1` - First input array
/// * `in2` - Second input array
///
/// # Returns
/// A Vec containing the convolution of `in1` with `in2`.
/// With Full mode, the output length will be `in1.len() + in2.len() - 1`.
pub fn convolve<F: Float + FftNum>(in1: &[F], in2: &[F], mode: ConvolveMode) -> Vec<F> {
fftconvolve(in1, in2, mode)
}

/// Compute the cross-correlation of two signals using FFT.
///
/// Cross-correlation is similar to convolution but with flipping one of the signals.
/// This function uses FFT to compute the correlation efficiently.
///
/// # Arguments
/// * `in1` - First input array
/// * `in2` - Second input array
///
/// # Returns
/// A Vec containing the cross-correlation of `in1` with `in2`.
/// With Full mode, the output length will be `in1.len() + in2.len() - 1`.
pub fn correlate<F: Float + FftNum>(in1: &[F], in2: &[F], mode: ConvolveMode) -> Vec<F> {
// For correlation, we need to reverse in2
let mut in2_rev = in2.to_vec();
in2_rev.reverse();
fftconvolve(in1, &in2_rev, mode)
}

#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;

#[test]
fn test_convolve() {
let in1 = vec![1.0, 2.0, 3.0];
let in2 = vec![4.0, 5.0, 6.0];
let result = convolve(&in1, &in2, ConvolveMode::Full);
let expected = vec![4.0, 13.0, 28.0, 27.0, 18.0];

for (a, b) in result.iter().zip(expected.iter()) {
assert_relative_eq!(a, b, epsilon = 1e-10);
}
}

#[test]
fn test_correlate() {
let in1 = vec![1.0, 2.0, 3.0];
let in2 = vec![4.0, 5.0, 6.0];
let result = correlate(&in1, &in2, ConvolveMode::Full);
let expected = vec![6.0, 17.0, 32.0, 23.0, 12.0];
for (a, b) in result.iter().zip(expected.iter()) {
assert_relative_eq!(a, b, epsilon = 1e-10);
}
}

#[test]
fn test_convolve_valid() {
let in1 = vec![1.0, 2.0, 3.0, 4.0];
let in2 = vec![1.0, 2.0];
let result = convolve(&in1, &in2, ConvolveMode::Valid);
let expected = vec![4.0, 7.0, 10.0];
for (a, b) in result.iter().zip(expected.iter()) {
assert_relative_eq!(a, b, epsilon = 1e-10);
}
}

#[test]
fn test_convolve_same() {
let in1 = vec![1.0, 2.0, 3.0, 4.0];
let in2 = vec![1.0, 2.0, 1.0];
let result = convolve(&in1, &in2, ConvolveMode::Same);
let expected = vec![4.0, 8.0, 12.0, 11.0];
for (a, b) in result.iter().zip(expected.iter()) {
assert_relative_eq!(a, b, epsilon = 1e-10);
}
}

#[test]
fn test_scipy_example() {
use rand::distributions::{Distribution, Standard};
use rand::thread_rng;

// Generate 1000 random samples from standard normal distribution
let mut rng = thread_rng();
let sig: Vec<f64> = Standard.sample_iter(&mut rng).take(1000).collect();

// Compute autocorrelation using correlate directly
let autocorr = correlate(&sig, &sig, ConvolveMode::Full);

// Basic sanity checks
assert_eq!(autocorr.len(), 1999); // Full convolution length should be 2N-1
assert!(autocorr.iter().all(|&x| !x.is_nan())); // No NaN values

// Maximum correlation should be near the middle since it's autocorrelation
let max_idx = autocorr
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
.unwrap()
.0;
assert!((max_idx as i32 - 999).abs() <= 1); // Should be near index 999

let sig: Vec<f32> = sig.iter().map(|x| *x as f32).collect();
let autocorr: Vec<f32> = autocorr.iter().map(|x| *x as f32).collect();
crate::plot::python_plot(vec![&sig, &autocorr]);
}
}
4 changes: 4 additions & 0 deletions sci-rs/src/signal/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ pub mod filter;
/// Signal Generation
pub mod wave;

/// Convolution
#[cfg(feature = "std")]
pub mod convolve;

/// Signal Resampling
#[cfg(feature = "std")]
pub mod resample;

0 comments on commit 924f7ae

Please sign in to comment.