Skip to content

Commit

Permalink
added test suggestion
Browse files Browse the repository at this point in the history
Co-authored-by: Adam Li <[email protected]>
  • Loading branch information
tsbinns and adam2392 authored Sep 9, 2023
1 parent 402cfaa commit ebe0a69
Showing 1 changed file with 38 additions and 31 deletions.
69 changes: 38 additions & 31 deletions mne_connectivity/spectral/tests/test_spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -1315,46 +1315,53 @@ def test_multivar_save_load(tmp_path):
assert a == b


def test_spectral_connectivity_indices_maintained(tmp_path):
@pytest.mark.parametrize("method", ["coh", "plv", "pli", "wpli", "ciplv", "mic", "mim"])
@pytest.mark.parametrize("indices", [None,
(np.array([0, 1]), np.array([2, 3])),
(np.array([[0, 1]]), np.array([[2, 3]]))
])
def test_spectral_connectivity_indices_roundtrip_io(tmp_path, method, indices):
"""Test that indices values and type is maintained after saving.
If `indices` is None, `indices` in the returned connectivity object should
be None, otherwise, `indices` should be a tuple. The type of `indices` and
its values should be retained after saving and reloading.
"""
rng = np.random.RandomState(0)
n_epochs, n_chs, n_times, sfreq, f = 5, 4, 2000, 1000., 20.
n_epochs, n_chs, n_times, sfreq, f = 5, 4, 200, 100.0, 20.0
data = rng.randn(n_epochs, n_chs, n_times)
sig = np.sin(2 * np.pi * f * np.arange(1000) / sfreq) * np.hanning(1000)
data[:, :, 500:1500] += sig
info = create_info(n_chs, sfreq, 'eeg')
info = create_info(n_chs, sfreq, "eeg")
tmin = -1
epochs = EpochsArray(data, info, tmin=tmin)
freqs = np.arange(10, 31)
tmp_file = os.path.join(tmp_path, 'foo_mvc.nc')
tmp_file = os.path.join(tmp_path, "foo_mvc.nc")

bivar_indices = (np.array([0, 1]), np.array([2, 3]))
multivar_indices = (np.array([[0, 1]]), np.array([[2, 3]]))
indices = [None, bivar_indices, None, multivar_indices]
methods = ['coh', 'coh', 'mic', 'mic']

for this_indices, this_method in zip(indices, methods):
con_epochs = spectral_connectivity_epochs(
epochs, method=this_method, indices=this_indices, sfreq=sfreq,
fmin=10, fmax=30)
con_time = spectral_connectivity_time(
epochs, freqs, method=this_method, indices=this_indices,
sfreq=sfreq)

for con in [con_epochs, con_time]:
con.save(tmp_file)
read_con = read_connectivity(tmp_file)
if this_indices is not None:
# check indices of same type (tuples)
assert (isinstance(con.indices, tuple) and
isinstance(read_con.indices, tuple))
# check indices have same values
assert np.all(np.array(con.indices) ==
np.array(read_con.indices))
else:
assert con.indices is None and read_con.indices is None
# mutlivariate methods and bivariate methods require the right indices shape
if method in ["mic", "mim"]:
if indices is not None and indices[0].ndim == 1:
pytest.skip()
else:
if indices is not None and indices[0].ndim == 2:
pytest.skip()

# actually test the pair of method and indices defined to check the output indices
con_epochs = spectral_connectivity_epochs(
epochs, method=method, indices=indices, sfreq=sfreq, fmin=10, fmax=30
)
con_time = spectral_connectivity_time(
epochs, freqs, method=method, indices=indices, sfreq=sfreq
)

for con in [con_epochs, con_time]:
con.save(tmp_file)
read_con = read_connectivity(tmp_file)

if indices is not None:
# check indices of same type (tuples)
assert isinstance(con.indices, tuple) and isinstance(
read_con.indices, tuple
)
# check indices have same values
assert np.all(np.array(con.indices) == np.array(read_con.indices))
else:
assert con.indices is None and read_con.indices is None

0 comments on commit ebe0a69

Please sign in to comment.