Commit

Add comprehensive tests for preprocessing and motion integration in PyMoBI
snesmaeili committed Nov 20, 2024
1 parent ca3a5cd commit aac77ab
Showing 8 changed files with 1,060 additions and 0 deletions.
79 changes: 79 additions & 0 deletions examples/complete_pipeline_example.py
@@ -0,0 +1,79 @@
# examples/complete_pipeline_example.py

import mne
import numpy as np
from pathlib import Path

from pymobi import PyMoBIConfig, PyMoBIData, create_default_pipeline
# NOTE: SignalVisualizer and ProcessingReport are assumed to be importable from
# the top-level pymobi package; adjust the import to the package's actual layout.
from pymobi import SignalVisualizer, ProcessingReport


def run_complete_pipeline_example():
"""
Complete example showing the full PyMoBI pipeline functionality.
Uses sample data and demonstrates all major features.
"""
# Create configuration
config = PyMoBIConfig(
study_folder=Path("example_study"),
filename_prefix="sub-",
resample_freq=250.0,
channels_to_remove=[],
eog_channels=['EOG_l', 'EOG_r'],
ref_channel='FCz',

# Channel detection parameters
chancorr_crit=0.8,
chan_max_broken_time=0.3,
chan_detect_num_iter=20,

# AMICA parameters
filter_lowCutoffFreqAMICA=1.75,
num_models=1,
max_threads=8,
amica_autoreject=True,
amica_n_rej=10,

# ICLabel settings
iclabel_classifier='lite',
iclabel_classes=[1],
iclabel_threshold=-1,

# Final filtering
final_filter_lower_edge=0.2,

# Processing control
save_intermediate=True
)

# Load sample data
raw = load_sample_data()

# Create data container
data = PyMoBIData(raw, subject_id=1)

# Create and run pipeline
pipeline = create_default_pipeline(config)
processed_data = pipeline.run(data)

# Generate visualizations
visualizer = SignalVisualizer(config)
visualizer.plot_data_overview(processed_data)

# Generate processing report
report = ProcessingReport(config)
report.generate_report(processed_data)

return processed_data

def load_sample_data():
"""Load sample EEG data."""
sample_data_folder = mne.datasets.sample.data_path()
raw_fname = sample_data_folder / 'MEG' / 'sample' / 'sample_audvis_raw.fif'
raw = mne.io.read_raw_fif(raw_fname, preload=True)

# Keep only EEG channels
raw.pick_types(meg=False, eeg=True, eog=True)

return raw

if __name__ == '__main__':
# Run complete pipeline example
processed_data = run_complete_pipeline_example()
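
For readers adapting the example to their own recordings, the sketch below (hypothetical, not part of this commit) swaps load_sample_data() for a loader based on MNE's BrainVision reader; the file path and function name are illustrative only.

# examples/load_own_data_example.py (hypothetical, not part of this commit)

import mne


def load_own_data(vhdr_path):
    """Load a BrainVision recording and keep EEG/EOG channels, mirroring load_sample_data()."""
    raw = mne.io.read_raw_brainvision(vhdr_path, preload=True)
    raw.pick_types(meg=False, eeg=True, eog=True)
    return raw

# Usage (illustrative path):
# raw = load_own_data('sub-01_task-walking_eeg.vhdr')
# data = PyMoBIData(raw, subject_id=1)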
155 changes: 155 additions & 0 deletions tests/benchmarks/test_performance.py
@@ -0,0 +1,155 @@
# tests/benchmarks/test_performance.py

import pytest
import time
import mne
import numpy as np
from pathlib import Path
from pymobi import PyMoBIConfig, PyMoBIData, create_default_pipeline
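# NOTE: these benchmarks use the `benchmark` fixture, which is provided by the
# pytest-benchmark plugin (assumed to be installed alongside pytest); the
# memory-usage test additionally requires `psutil`.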

def generate_test_data(duration: float, n_channels: int, sfreq: float) -> mne.io.BaseRaw:
"""Generate synthetic EEG data for testing."""
n_samples = int(duration * sfreq)
data = np.random.randn(n_channels, n_samples)
ch_names = [f'EEG{i:03d}' for i in range(n_channels)]
ch_types = ['eeg'] * n_channels

info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types)
return mne.io.RawArray(data, info)

@pytest.mark.benchmark
class TestPipelinePerformance:
"""Test suite for pipeline performance benchmarks."""

@pytest.fixture
def config(self):
"""Create test configuration."""
return PyMoBIConfig(
study_folder=Path("test_data"),
resample_freq=250.0,
channels_to_remove=[],
ref_channel='FCz',
asr_cutoff=20,
use_asr=True,
num_models=1,
max_threads=4
)

@pytest.mark.parametrize("duration", [60, 300, 600])
def test_processing_speed(self, duration, config, benchmark):
"""Benchmark processing speed for different data lengths."""
# Generate test data
raw = generate_test_data(
duration=duration,
n_channels=64,
sfreq=1000.0
)

data = PyMoBIData(raw, subject_id=1)
pipeline = create_default_pipeline(config)

# Run benchmark
result = benchmark(pipeline.run, data)

# Verify result
assert isinstance(result, PyMoBIData)
assert result.mne_raw.info['sfreq'] == config.resample_freq

@pytest.mark.parametrize("n_channels", [32, 64, 128])
def test_channel_scaling(self, n_channels, config, benchmark):
"""Benchmark performance scaling with number of channels."""
# Generate test data
raw = generate_test_data(
duration=60,
n_channels=n_channels,
sfreq=1000.0
)

data = PyMoBIData(raw, subject_id=1)
pipeline = create_default_pipeline(config)

# Run benchmark
result = benchmark(pipeline.run, data)

# Verify result
assert isinstance(result, PyMoBIData)
assert len(result.mne_raw.ch_names) == n_channels

def test_memory_usage(self, config):
"""Test memory usage during processing."""
import psutil
import os

process = psutil.Process(os.getpid())
initial_memory = process.memory_info().rss

# Generate large test data
raw = generate_test_data(
duration=300,
n_channels=128,
sfreq=1000.0
)

data = PyMoBIData(raw, subject_id=1)
pipeline = create_default_pipeline(config)

# Process data
result = pipeline.run(data)

final_memory = process.memory_info().rss
memory_increase = (final_memory - initial_memory) / 1024 / 1024 # MB

# Check memory usage
assert memory_increase < 2000 # Less than 2GB increase

@pytest.mark.parametrize("n_threads", [1, 2, 4, 8])
def test_parallel_scaling(self, n_threads, config, benchmark):
"""Test processing speed with different numbers of threads."""
config.max_threads = n_threads

# Generate test data
raw = generate_test_data(
duration=60,
n_channels=64,
sfreq=1000.0
)

data = PyMoBIData(raw, subject_id=1)
pipeline = create_default_pipeline(config)

# Run benchmark
result = benchmark(pipeline.run, data)

# Verify result
assert isinstance(result, PyMoBIData)

def test_continuous_processing(self, config):
"""Test continuous processing of streaming data."""
chunk_duration = 1.0 # 1 second chunks
total_duration = 60.0 # 60 seconds total

processing_times = []

for i in range(int(total_duration / chunk_duration)):
# Generate chunk
raw = generate_test_data(
duration=chunk_duration,
n_channels=64,
sfreq=1000.0
)

data = PyMoBIData(raw, subject_id=1)
pipeline = create_default_pipeline(config)

# Process chunk and measure time
start_time = time.time()
result = pipeline.run(data)
processing_times.append(time.time() - start_time)

# Calculate statistics
mean_time = np.mean(processing_times)
std_time = np.std(processing_times)

# Check processing speed consistency
assert mean_time < chunk_duration # Processing faster than real-time
assert std_time < 0.1 # Consistent processing time
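
The suite above relies on the `benchmark` fixture and the `benchmark` marker; below is a minimal sketch (assuming pytest and the pytest-benchmark plugin are installed, and that the marker is registered in the project's pytest configuration) of invoking it programmatically.

# run_benchmarks.py (hypothetical helper, not part of this commit)

import sys

import pytest

if __name__ == '__main__':
    # Select only tests marked @pytest.mark.benchmark, let pytest-benchmark
    # collect the timing statistics, and propagate pytest's exit code.
    sys.exit(pytest.main([
        'tests/benchmarks/test_performance.py',
        '-m', 'benchmark',
        '--benchmark-only',
    ]))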