[DRAFT] Protocol redesign #57

Draft: wants to merge 5 commits into develop-eeg
1 change: 1 addition & 0 deletions benchmarks/MOABB/Commands_run_experiment.txt
@@ -0,0 +1 @@
python run_experiments.py --hparams=hparams/MotorImagery/BNCI2014001/EEGNet.yaml --data_folder=./eeg_data --output_folder=./results/test_run --nsbj=9 --nsess=2 --seed=12346 --nruns=1 --train_mode=leave-one-session-out
111 changes: 77 additions & 34 deletions benchmarks/MOABB/dataio/preprocessing.py
@@ -24,39 +24,82 @@ def to_tensor(epoch):
cached_create_filter = cache(mne.filter.create_filter)


@takes("epoch", "info", "target_sfreq", "fmin", "fmax")
@provides("epoch", "sfreq", "target_sfreq", "fmin", "fmax")
def bandpass_resample(epoch, info, target_sfreq, fmin, fmax):
"""Bandpass filter and resample an epoch."""

bandpass = cached_create_filter(
None,
info["sfreq"],
l_freq=fmin,
h_freq=fmax,
method="fir",
fir_design="firwin",
verbose=False,
)

# Check that filter length is reasonable
filter_length = len(bandpass)
len_x = epoch.shape[-1]
if filter_length > len_x:
# TODO: These long filters result in massive performance degradation... Do we
# want to throw an error instead? This usually happens when fmin is used
logging.warning(
"filter_length (%i) is longer than the signal (%i), "
"distortion is likely. Reduce filter length or filter a longer signal.",
filter_length,
len_x,
def bandpass_resample():
    @takes("epoch", "info", "target_sfreq", "fmin", "fmax")
    @provides("epoch", "sfreq", "target_sfreq", "fmin", "fmax")
    def _bandpass_resample(epoch, info, target_sfreq, fmin, fmax):
        """Bandpass filter and resample an epoch."""
Comment on lines +27 to +31
Drew-Wagner (Collaborator) commented on Apr 7, 2025:
This is more what I had in mind, as it makes it easy to configure from the hparams file. See how it is a function that creates a dynamic item with the desired parameters?

Don't commit the suggestion directly; it is just to illustrate.

Suggested change (replace this):

def bandpass_resample():
    @takes("epoch", "info", "target_sfreq", "fmin", "fmax")
    @provides("epoch", "sfreq", "target_sfreq", "fmin", "fmax")
    def _bandpass_resample(epoch, info, target_sfreq, fmin, fmax):
        """Bandpass filter and resample an epoch."""

with this:

def bandpass_resample(target_sfreq, fmin, fmax):
    @takes("epoch", "info")
    @provides("epoch")
    def _bandpass_resample(epoch, info):
        """Bandpass filter and resample an epoch."""

This will allow us to do something like this in the yaml file:

dynamic_items:
- !apply:bandpass_resample
  fmin: 0.5
  fmax: 22
  target_sfreq: 128
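
For orientation only (not part of the diff): a minimal, self-contained Python sketch of the factory pattern being suggested here. The names mirror the suggestion above, the decorators and MNE calls are left out, and the parameter values are placeholders, so this illustrates the closure shape rather than the final API.

def bandpass_resample(target_sfreq, fmin, fmax):
    # The factory bakes the preprocessing parameters into a closure, so the
    # hparams file (or a Python caller) supplies them once, up front.
    def _bandpass_resample(epoch, info):
        # ... filter `epoch` between fmin and fmax, then resample it from
        # info["sfreq"] to target_sfreq, and yield the result ...
        yield epoch

    return _bandpass_resample

# Called once with the desired parameters, the factory returns a dynamic item
# that only needs `epoch` and `info` at runtime:
resample_item = bandpass_resample(target_sfreq=128, fmin=0.5, fmax=22)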


        bandpass = cached_create_filter(
            None,
            info["sfreq"],
            l_freq=fmin,
            h_freq=fmax,
            method="fir",
            fir_design="firwin",
            verbose=False,
        )

        # Check that filter length is reasonable
        filter_length = len(bandpass)
        len_x = epoch.shape[-1]
        if filter_length > len_x:
            # TODO: These long filters result in massive performance degradation... Do we
            # want to throw an error instead? This usually happens when fmin is used
            logging.warning(
                "filter_length (%i) is longer than the signal (%i), "
                "distortion is likely. Reduce filter length or filter a longer signal.",
                filter_length,
                len_x,
            )

        yield mne.filter.resample(
            epoch,
            up=target_sfreq,
            down=info["sfreq"],
            method="polyphase",
            window=bandpass,
        )
        yield target_sfreq

    return _bandpass_resample


'''def bandpass_resample(target_sfreq, fmin, fmax):
    """Create a dynamic item that bandpass filters and resamples an epoch."""

    @takes("epoch", "info")
    @provides("epoch")
    def _bandpass_resample(epoch, info):
        bandpass = cached_create_filter(
            None,
            info["sfreq"],
            l_freq=fmin,
            h_freq=fmax,
            method="fir",
            fir_design="firwin",
            verbose=False,
        )

        # Check that filter length is reasonable
        filter_length = len(bandpass)
        len_x = epoch.shape[-1]
        if filter_length > len_x:
            logging.warning(
                "filter_length (%i) is longer than the signal (%i), "
                "distortion is likely. Reduce filter length or filter a longer signal.",
                filter_length,
                len_x,
            )

        filtered = mne.filter.resample(
            epoch,
            up=target_sfreq,
            down=info["sfreq"],
            method="polyphase",
            window=bandpass,
        )
        yield filtered

    return _bandpass_resample
'''
33 changes: 32 additions & 1 deletion benchmarks/MOABB/hparams/MotorImagery/BNCI2014001/EEGNet.yaml
@@ -1,6 +1,9 @@
seed: 1234
__set_torchseed: !apply:torch.manual_seed [!ref <seed>]

# OVERRIDES


# DIRECTORIES
data_folder: !PLACEHOLDER #'/path/to/dataset'. The dataset will be automatically downloaded in this folder
cached_data_folder: !PLACEHOLDER #'path/to/pickled/dataset'
@@ -39,6 +42,34 @@ C: 22
test_with: 'last' # 'last' or 'best'
test_key: "acc" # Possible opts: "loss", "f1", "auc", "acc"

# DATASET

# Get target subject
#target_subject: !ref <dataset>.subject_list[!ref <target_subject_idx>]#

# Create the subjects list
#subjects: [!ref <target_subject>]

# Create dataset using EpochedEEGDataset
#dataset_class: !new:dataio.datasets.EpochedEEGDataset

#json_path: !apply:os.path.join [!ref <cached_data_folder>, "index.json"]
#save_path: !ref <data_folder>
#dynamic_items:
# - !name:dataio.preprocessing.to_tensor
Comment on lines +58 to +59
Collaborator:
This should include all preprocessing steps.
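
For illustration only (not part of the diff): a hedged sketch of what a fuller dynamic_items list could look like once the factory-style bandpass_resample lands, so that all preprocessing steps are listed. The module path and parameter values are assumptions taken from this PR's other files, not a confirmed configuration.

dynamic_items:
    - !name:dataio.preprocessing.to_tensor
    - !apply:dataio.preprocessing.bandpass_resample
      fmin: 0.5
      fmax: 22
      target_sfreq: 128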

#output_keys: ["label", "subject", "session", "epoch"]
#preload: True

#from_moabb_dataset: !apply: !ref <dataset_class>.from_moabb
# - !ref <dataset>
# - !ref <json_path>
# - !ref <subjects>
# - !ref <save_path>
# - !ref <dynamic_items>
# - !ref <output_keys>
# - !ref <preload>
# - !ref <tmin>
# - !ref <tmax>
# METRICS
f1: !name:sklearn.metrics.f1_score
average: 'macro'
@@ -52,7 +83,7 @@ metrics:
n_train_examples: 100 # it will be replaced in the train script
# checkpoints to average
avg_models: 10 # @orion_step1: --avg_models~"uniform(1, 15,discrete=True)"
number_of_epochs: 862 # @orion_step1: --number_of_epochs~"uniform(250, 1000, discrete=True)"
number_of_epochs: 800 # @orion_step1: --number_of_epochs~"uniform(250, 1000, discrete=True)"
lr: 0.0001 # @orion_step1: --lr~"choices([0.01, 0.005, 0.001, 0.0005, 0.0001])"
# Learning rate scheduling (cyclic learning rate is used here)
max_lr: !ref <lr> # Upper bound of the cycle (max value of the lr)
1 change: 1 addition & 0 deletions benchmarks/MOABB/models/EEGNet.py
@@ -240,6 +240,7 @@ def forward(self, x):
        x : torch.Tensor (batch, time, EEG channel, channel)
            Input to convolve. 4d tensors are expected.
        """
        x = x.transpose(1, 2)
        x = self.conv_module(x)
        x = self.dense_module(x)
        return x
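
For reference (not part of the diff): a tiny sketch of what the added transpose(1, 2) does mechanically to a 4-D EEG batch; the example shape is illustrative only.

import torch

x = torch.randn(32, 500, 22, 1)   # e.g. (batch, time, EEG channel, channel)
y = x.transpose(1, 2)             # swaps dims 1 and 2
print(y.shape)                    # torch.Size([32, 22, 500, 1])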