Skip to content

Commit

Permalink
updated spectral tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tsbinns committed Sep 1, 2023
1 parent fe0fe68 commit b7fcf12
Showing 1 changed file with 46 additions and 1 deletion.
47 changes: 46 additions & 1 deletion mne_connectivity/spectral/tests/test_spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -1286,7 +1286,7 @@ def test_save(tmp_path):
def test_multivar_save_load(tmp_path):
"""Test saving and loading results of multivariate connectivity."""
rng = np.random.RandomState(0)
n_epochs, n_chs, n_times, sfreq, f = 10, 4, 2000, 1000., 20.
n_epochs, n_chs, n_times, sfreq, f = 5, 4, 2000, 1000., 20.
data = rng.randn(n_epochs, n_chs, n_times)
sig = np.sin(2 * np.pi * f * np.arange(1000) / sfreq) * np.hanning(1000)
data[:, :, 500:1500] += sig
Expand All @@ -1313,3 +1313,48 @@ def test_multivar_save_load(tmp_path):
a = repr(this_con).split('~')[0]
b = repr(read_con).split('~')[0]
assert a == b


def test_spectral_connectivity_indices_maintained(tmp_path):
"""Test that indices values and type is maintained after saving.
If `indices` is None, `indices` in the returned connectivity object should
be None, otherwise, `indices` should be a tuple. The type of `indices` and
its values should be retained after saving and reloading.
"""
rng = np.random.RandomState(0)
n_epochs, n_chs, n_times, sfreq, f = 5, 4, 2000, 1000., 20.
data = rng.randn(n_epochs, n_chs, n_times)
sig = np.sin(2 * np.pi * f * np.arange(1000) / sfreq) * np.hanning(1000)
data[:, :, 500:1500] += sig
info = create_info(n_chs, sfreq, 'eeg')
tmin = -1
epochs = EpochsArray(data, info, tmin=tmin)
freqs = np.arange(10, 31)
tmp_file = os.path.join(tmp_path, 'foo_mvc.nc')

bivar_indices = (np.array([0, 1]), np.array([2, 3]))
multivar_indices = (np.array([[0, 1]]), np.array([[2, 3]]))
indices = [None, bivar_indices, None, multivar_indices]
methods = ['coh', 'coh', 'mic', 'mic']

for this_indices, this_method in zip(indices, methods):
con_epochs = spectral_connectivity_epochs(
epochs, method=this_method, indices=this_indices, sfreq=sfreq,
fmin=10, fmax=30)
con_time = spectral_connectivity_time(
epochs, freqs, method=this_method, indices=this_indices,
sfreq=sfreq)

for con in [con_epochs, con_time]:
con.save(tmp_file)
read_con = read_connectivity(tmp_file)
if this_indices is not None:
# check indices of same type (tuples)
assert (isinstance(con.indices, tuple) and
isinstance(read_con.indices, tuple))
# check indices have same values
assert np.all(np.array(con.indices) ==
np.array(read_con.indices))
else:
assert con.indices is None and read_con.indices is None

0 comments on commit b7fcf12

Please sign in to comment.