-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #58 from qsib-cbie/user/jtrueb/fftconvolve
Convolution and Correlation via FFT
- Loading branch information
Showing
6 changed files
with
278 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
use std::io::Write; | ||
|
||
/// Debug utility function that will run a python script to plot the data. | ||
/// | ||
/// This function generates a Python script to create plots of the input data and their autocorrelations. | ||
/// It then executes the script using the system's Python interpreter. | ||
/// | ||
/// Note: This function will open a new window to display the plots and will block execution until the window is closed. | ||
/// It also suppresses stdout and stderr from the Python process. | ||
pub fn python_plot(xs: Vec<&[f32]>) { | ||
let script = format!( | ||
r#" | ||
import matplotlib.pyplot as plt | ||
import matplotlib.gridspec as gridspec | ||
from scipy.signal import correlate | ||
xs = {:?} | ||
fig = plt.figure(figsize=(12, 12)) | ||
gs = gridspec.GridSpec(len(xs), 2) | ||
for i, x in enumerate(xs): | ||
ax = plt.subplot(gs[i, 0]) | ||
ax.plot(x, label = f"C{{i}}") | ||
ax.legend() | ||
ax.set_xlabel("Samples") | ||
ax = plt.subplot(gs[i, 1]) | ||
autocorr = correlate(x, x, mode='full') | ||
normcorr = autocorr / autocorr.max() | ||
offsets = range(-len(x) + 1, len(x)) | ||
ax.plot(offsets, normcorr, label = f"Autocorrelation of C{{i}}") | ||
ax.legend() | ||
ax.set_xlabel("Lag") | ||
plt.show() | ||
"#, | ||
xs | ||
); | ||
// Run the script with python | ||
let script = script.as_bytes(); | ||
let mut python = match std::process::Command::new("python") | ||
.stdin(std::process::Stdio::piped()) | ||
.stdout(std::process::Stdio::null()) // noisy | ||
.stderr(std::process::Stdio::null()) // noisy | ||
.spawn() | ||
{ | ||
Ok(p) => p, | ||
Err(_) => return, // Return early if python fails to start | ||
}; | ||
|
||
if let Some(mut stdin) = python.stdin.take() { | ||
if stdin.write_all(script).is_err() { | ||
return; // Return early if writing fails | ||
} | ||
} else { | ||
return; // Return early if we can't get stdin | ||
} | ||
|
||
// Wait for the python process to finish, ignoring any errors | ||
let _ = python.wait(); // Ignore errors as this may be called from CI | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,202 @@ | ||
use nalgebra::Complex; | ||
use num_traits::{Float, FromPrimitive, Signed, Zero}; | ||
use rustfft::{FftNum, FftPlanner}; | ||
|
||
/// Convolution mode determines behavior near edges and output size | ||
pub enum ConvolveMode { | ||
/// Full convolution, output size is `in1.len() + in2.len() - 1` | ||
Full, | ||
/// Valid convolution, output size is `max(in1.len(), in2.len()) - min(in1.len(), in2.len()) + 1` | ||
Valid, | ||
/// Same convolution, output size is `in1.len()` | ||
Same, | ||
} | ||
|
||
/// Performs FFT-based convolution on two slices of floating point values. | ||
/// | ||
/// According to Python docs, this is generally much faster than direct convolution | ||
/// for large arrays (n > ~500), but can be slower when only a few output values are needed. | ||
/// We only implement the FFT version in Rust for now. | ||
/// | ||
/// # Arguments | ||
/// - `in1`: First input signal | ||
/// - `in2`: Second input signal | ||
/// - `mode`: Convolution mode (currently only Full is supported) | ||
/// | ||
/// # Returns | ||
/// A Vec containing the discrete linear convolution of `in1` with `in2`. | ||
/// For Full mode, the output length will be `in1.len() + in2.len() - 1`. | ||
pub fn fftconvolve<F: Float + FftNum>(in1: &[F], in2: &[F], mode: ConvolveMode) -> Vec<F> { | ||
// Determine the size of the FFT (next power of 2 for zero-padding) | ||
let n1 = in1.len(); | ||
let n2 = in2.len(); | ||
let n = n1 + n2 - 1; | ||
let fft_size = n.next_power_of_two(); | ||
|
||
// Prepare input buffers as Complex<F> with zero-padding to fft_size | ||
let mut padded_in1 = vec![Complex::zero(); fft_size]; | ||
let mut padded_in2 = vec![Complex::zero(); fft_size]; | ||
|
||
// Copy input data into zero-padded buffers | ||
padded_in1.iter_mut().zip(in1.iter()).for_each(|(p, &v)| { | ||
*p = Complex::new(v, F::zero()); | ||
}); | ||
padded_in2.iter_mut().zip(in2.iter()).for_each(|(p, &v)| { | ||
*p = Complex::new(v, F::zero()); | ||
}); | ||
|
||
// Perform the FFT | ||
let mut planner = FftPlanner::new(); | ||
let fft = planner.plan_fft_forward(fft_size); | ||
fft.process(&mut padded_in1); | ||
fft.process(&mut padded_in2); | ||
|
||
// Multiply element-wise in the frequency domain | ||
let mut result_freq: Vec<Complex<F>> = padded_in1 | ||
.iter() | ||
.zip(&padded_in2) | ||
.map(|(a, b)| a * b) | ||
.collect(); | ||
|
||
// Perform the inverse FFT | ||
let ifft = planner.plan_fft_inverse(fft_size); | ||
ifft.process(&mut result_freq); | ||
|
||
// Take only the real part, normalize, and truncate to the original output size (n) | ||
let fft_size = F::from(fft_size).unwrap(); | ||
let full_convolution = result_freq | ||
.iter() | ||
.take(n) | ||
.map(|x| x.re / fft_size) | ||
.collect(); | ||
|
||
// Extract the appropriate slice based on the mode | ||
match mode { | ||
ConvolveMode::Full => full_convolution, | ||
ConvolveMode::Valid => { | ||
if n1 >= n2 { | ||
full_convolution[(n2 - 1)..(n1)].to_vec() | ||
} else { | ||
Vec::new() | ||
} | ||
} | ||
ConvolveMode::Same => { | ||
let start = (n2 - 1) / 2; | ||
let end = start + n1; | ||
full_convolution[start..end].to_vec() | ||
} | ||
} | ||
} | ||
|
||
/// Compute the convolution of two signals using FFT. | ||
/// | ||
/// # Arguments | ||
/// * `in1` - First input array | ||
/// * `in2` - Second input array | ||
/// | ||
/// # Returns | ||
/// A Vec containing the convolution of `in1` with `in2`. | ||
/// With Full mode, the output length will be `in1.len() + in2.len() - 1`. | ||
pub fn convolve<F: Float + FftNum>(in1: &[F], in2: &[F], mode: ConvolveMode) -> Vec<F> { | ||
fftconvolve(in1, in2, mode) | ||
} | ||
|
||
/// Compute the cross-correlation of two signals using FFT. | ||
/// | ||
/// Cross-correlation is similar to convolution but with flipping one of the signals. | ||
/// This function uses FFT to compute the correlation efficiently. | ||
/// | ||
/// # Arguments | ||
/// * `in1` - First input array | ||
/// * `in2` - Second input array | ||
/// | ||
/// # Returns | ||
/// A Vec containing the cross-correlation of `in1` with `in2`. | ||
/// With Full mode, the output length will be `in1.len() + in2.len() - 1`. | ||
pub fn correlate<F: Float + FftNum>(in1: &[F], in2: &[F], mode: ConvolveMode) -> Vec<F> { | ||
// For correlation, we need to reverse in2 | ||
let mut in2_rev = in2.to_vec(); | ||
in2_rev.reverse(); | ||
fftconvolve(in1, &in2_rev, mode) | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use super::*; | ||
use approx::assert_relative_eq; | ||
|
||
#[test] | ||
fn test_convolve() { | ||
let in1 = vec![1.0, 2.0, 3.0]; | ||
let in2 = vec![4.0, 5.0, 6.0]; | ||
let result = convolve(&in1, &in2, ConvolveMode::Full); | ||
let expected = vec![4.0, 13.0, 28.0, 27.0, 18.0]; | ||
|
||
for (a, b) in result.iter().zip(expected.iter()) { | ||
assert_relative_eq!(a, b, epsilon = 1e-10); | ||
} | ||
} | ||
|
||
#[test] | ||
fn test_correlate() { | ||
let in1 = vec![1.0, 2.0, 3.0]; | ||
let in2 = vec![4.0, 5.0, 6.0]; | ||
let result = correlate(&in1, &in2, ConvolveMode::Full); | ||
let expected = vec![6.0, 17.0, 32.0, 23.0, 12.0]; | ||
for (a, b) in result.iter().zip(expected.iter()) { | ||
assert_relative_eq!(a, b, epsilon = 1e-10); | ||
} | ||
} | ||
|
||
#[test] | ||
fn test_convolve_valid() { | ||
let in1 = vec![1.0, 2.0, 3.0, 4.0]; | ||
let in2 = vec![1.0, 2.0]; | ||
let result = convolve(&in1, &in2, ConvolveMode::Valid); | ||
let expected = vec![4.0, 7.0, 10.0]; | ||
for (a, b) in result.iter().zip(expected.iter()) { | ||
assert_relative_eq!(a, b, epsilon = 1e-10); | ||
} | ||
} | ||
|
||
#[test] | ||
fn test_convolve_same() { | ||
let in1 = vec![1.0, 2.0, 3.0, 4.0]; | ||
let in2 = vec![1.0, 2.0, 1.0]; | ||
let result = convolve(&in1, &in2, ConvolveMode::Same); | ||
let expected = vec![4.0, 8.0, 12.0, 11.0]; | ||
for (a, b) in result.iter().zip(expected.iter()) { | ||
assert_relative_eq!(a, b, epsilon = 1e-10); | ||
} | ||
} | ||
|
||
#[test] | ||
fn test_scipy_example() { | ||
use rand::distributions::{Distribution, Standard}; | ||
use rand::thread_rng; | ||
|
||
// Generate 1000 random samples from standard normal distribution | ||
let mut rng = thread_rng(); | ||
let sig: Vec<f64> = Standard.sample_iter(&mut rng).take(1000).collect(); | ||
|
||
// Compute autocorrelation using correlate directly | ||
let autocorr = correlate(&sig, &sig, ConvolveMode::Full); | ||
|
||
// Basic sanity checks | ||
assert_eq!(autocorr.len(), 1999); // Full convolution length should be 2N-1 | ||
assert!(autocorr.iter().all(|&x| !x.is_nan())); // No NaN values | ||
|
||
// Maximum correlation should be near the middle since it's autocorrelation | ||
let max_idx = autocorr | ||
.iter() | ||
.enumerate() | ||
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap()) | ||
.unwrap() | ||
.0; | ||
assert!((max_idx as i32 - 999).abs() <= 1); // Should be near index 999 | ||
|
||
let sig: Vec<f32> = sig.iter().map(|x| *x as f32).collect(); | ||
let autocorr: Vec<f32> = autocorr.iter().map(|x| *x as f32).collect(); | ||
crate::plot::python_plot(vec![&sig, &autocorr]); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters