Skip to content

[DRAFT] Protocol redesign #57

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 16 commits into
base: develop-eeg
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -157,5 +157,8 @@ dmypy.json
# Log folders
**/log/

# Datasets folders
**/eeg_data/

# Mac OS
.DS_Store
3 changes: 3 additions & 0 deletions benchmarks/MOABB/Commands_run_experiment.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
python run_experiments.py --hparams=hparams/MotorImagery/BNCI2014001/EEGNet.yaml --data_folder=./eeg_data --output_folder=./results/test_run --nsbj=9 --nsess=2 --seed=12346 --nruns=1 --train_mode=leave-one-session-out

python run_sweep.py --hparams hparams/MotorImagery/BNCI2014001/EEGNet.yaml --sweep_type optuna --n_trials 2 --data_folder eeg_data --output_folder results/htop/test_run5 --cached_data_folder eeg_data/cache --nsbj 9 --nsess 2 --seed 1234 --nruns 1 --eval_metric acc --eval_set test --data_iterator_name leave-one-session-out --device cuda
72 changes: 40 additions & 32 deletions benchmarks/MOABB/dataio/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
from speechbrain.utils.data_pipeline import provides, takes


mne.set_log_level("ERROR")


@takes("epoch")
@provides("epoch")
def to_tensor(epoch):
Expand All @@ -24,39 +27,44 @@ def to_tensor(epoch):
cached_create_filter = cache(mne.filter.create_filter)


def bandpass_resample(target_sfreq, fmin, fmax):
    """Build a dynamic-item stage that bandpass-filters and resamples epochs.

    This is a factory: the filter parameters are captured by closure so the
    returned callable matches the SpeechBrain dynamic-item signature of
    takes ("epoch", "info") / provides ("epoch", "sfreq").

    Arguments
    ---------
    target_sfreq : int
        Target sampling rate (Hz) after resampling.
    fmin : float
        Low cut-off frequency of the bandpass filter (Hz).
    fmax : float
        High cut-off frequency of the bandpass filter (Hz).

    Returns
    -------
    callable
        The decorated generator dynamic item.
    """

    # NOTE(review): the generator yields two values (epoch, sfreq), so the
    # @provides decorator must declare both keys; declaring only "epoch"
    # leaves the second yield without a matching output key.
    @takes("epoch", "info")
    @provides("epoch", "sfreq")
    def _bandpass_resample(epoch, info):
        """Bandpass filter and resample one epoch; yields epoch then sfreq."""
        # Filter design is cached (cached_create_filter) because the same
        # (sfreq, fmin, fmax) combination recurs for every epoch.
        bandpass = cached_create_filter(
            None,
            info["sfreq"],
            l_freq=fmin,
            h_freq=fmax,
            method="fir",
            fir_design="firwin",
            verbose=False,
        )

        # Check that the filter length is reasonable relative to the signal.
        filter_length = len(bandpass)
        len_x = epoch.shape[-1]
        if filter_length > len_x:
            # TODO: These long filters result in massive performance
            # degradation... Do we want to throw an error instead? This
            # usually happens when a low fmin is used.
            logging.warning(
                "filter_length (%i) is longer than the signal (%i), "
                "distortion is likely. Reduce filter length or filter a longer signal.",
                filter_length,
                len_x,
            )

        # Polyphase resampling applies the bandpass filter as the
        # anti-aliasing window in the same pass.
        yield mne.filter.resample(
            epoch,
            up=target_sfreq,
            down=info["sfreq"],
            method="polyphase",
            window=bandpass,
        )
        yield target_sfreq

    return _bandpass_resample
82 changes: 79 additions & 3 deletions benchmarks/MOABB/hparams/MotorImagery/BNCI2014001/EEGNet.yaml
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
seed: 1234
__set_torchseed: !apply:torch.manual_seed [!ref <seed>]

#OVERRIDES
num_workers: 4

# DIRECTORIES
data_folder: !PLACEHOLDER #'/path/to/dataset'. The dataset will be automatically downloaded in this folder
cached_data_folder: !PLACEHOLDER #'path/to/pickled/dataset'
output_folder: !PLACEHOLDER #'path/to/results'

# DATASET HPARS
# Defining the MOABB dataset.
dataset: !new:moabb.datasets.BNCI2014001
dataset: !new:moabb.datasets.BNCI2014_001
save_prepared_dataset: True # set to True if you want to save the prepared dataset as a pkl file to load and use afterwards
data_iterator_name: !PLACEHOLDER
target_subject_idx: !PLACEHOLDER
Expand All @@ -17,7 +20,7 @@ events_to_load: null # all events will be loaded
original_sample_rate: 250 # Original sampling rate provided by dataset authors
sample_rate: 125 # Target sampling rate (Hz)
# band-pass filtering cut-off frequencies
fmin: 0.13 # @orion_step1: --fmin~"uniform(0.1, 5, precision=2)"
fmin: 1.0 # @orion_step1: --fmin~"uniform(0.1, 5, precision=2)" # note: filter behavior is undefined when fmin is below 0.5
fmax: 46.0 # @orion_step1: --fmax~"uniform(20.0, 50.0, precision=3)"
n_classes: 4
# tmin, tmax respect to stimulus onset that define the interval attribute of the dataset class
Expand All @@ -39,6 +42,53 @@ C: 22
test_with: 'last' # 'last' or 'best'
test_key: "acc" # Possible opts: "loss", "f1", "auc", "acc"

# DATASET
# ─── Subject extraction helpers ──────────────────────────────────────────────
# 1) Grab the whole subject_list from the BNCI2014001 object
subject_list: !apply:getattr # → dataset.subject_list
- !ref <dataset> # first arg = the object
- subject_list # second arg = attribute name

# 2) Pick the single subject we want with operator.getitem(list, idx)
target_subject: !apply:operator.getitem
- !ref <subject_list> # the list
- !ref <target_subject_idx> # the integer index supplied on CLI

# Get target subject
#target_subject: # TODO

# Create the subjects list
subjects: [!ref <target_subject>]

# Create dataset using EpochedEEGDataset
#dataset_class: !new:dataio.datasets.EpochedEEGDataset

json_path: !apply:os.path.join [!ref <cached_data_folder>, "index.json"]
save_path: !ref <data_folder>
# dynamic items list
bandpass_resample: !apply:dataio.preprocessing.bandpass_resample
target_sfreq: !ref <sample_rate>
fmin: !ref <fmin>
fmax: !ref <fmax>

dynamic_items:
- !ref <bandpass_resample>
- !name:dataio.preprocessing.to_tensor
output_keys: ["label", "subject", "session", "epoch"]
preload: True

EEG_dataset: !apply:dataio.datasets.EpochedEEGDataset.from_moabb
dataset: !ref <dataset>
json_path: !ref <json_path>
subjects: !ref <subjects>
save_path: !ref <save_path>
dynamic_items: !ref <dynamic_items>
output_keys: !ref <output_keys>
preload: !ref <preload>
tmin: !ref <tmin>
tmax: !ref <tmax>


# METRICS
f1: !name:sklearn.metrics.f1_score
average: 'macro'
Expand All @@ -52,7 +102,7 @@ metrics:
n_train_examples: 100 # it will be replaced in the train script
# checkpoints to average
avg_models: 10 # @orion_step1: --avg_models~"uniform(1, 15,discrete=True)"
number_of_epochs: 862 # @orion_step1: --number_of_epochs~"uniform(250, 1000, discrete=True)"
number_of_epochs: 10 # @orion_step1: --number_of_epochs~"uniform(250, 1000, discrete=True)"
lr: 0.0001 # @orion_step1: --lr~"choices([0.01, 0.005, 0.001, 0.0005, 0.0001])"
# Learning rate scheduling (cyclic learning rate is used here)
max_lr: !ref <lr> # Upper bound of the cycle (max value of the lr)
Expand Down Expand Up @@ -165,3 +215,29 @@ model: !new:models.EEGNet.EEGNet
dense_max_norm: !ref <dense_max_norm>
dropout: !ref <dropout>
dense_n_neurons: !ref <n_classes>

# Search Space

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the search space concept!

I wonder if there's a way to reference the search space in the rest of the yaml file -- like:

dropout: !ref <search_space.dropout>

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might make it slightly clearer with how it works rather than having to use overrides. You might still have to reload the file each time, depending on how the reference worked exactly.

# the search space is defined as a dictionary of parameter names and a dictionary of possible values
# the values can be sampled from a uniform distribution, a discrete uniform distribution or a choice of values
search_space:
fmin:
type: uniform
min: 0.1
max: 5.0
precision: 2
dropout:
type: uniform
min: 0.0
max: 0.5
precision: 3
cnn_temporal_kernels:
type: discrete_uniform
min: 4
max: 64
batch_size_exponent:
type: discrete_uniform
min: 4
max: 6
lr:
type: choice
values: [0.01, 0.005, 0.001, 0.0005, 0.0001]
1 change: 1 addition & 0 deletions benchmarks/MOABB/models/EEGNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ def forward(self, x):
x : torch.Tensor (batch, time, EEG channel, channel)
Input to convolve. 4d tensors are expected.
"""
x = x.transpose(1, 2)
x = self.conv_module(x)
x = self.dense_module(x)
return x
Loading
Loading