Skip to content

Commit

Permalink
partially handle ITC; comments
Browse files Browse the repository at this point in the history
  • Loading branch information
drammock committed Apr 14, 2023
1 parent 897a1e4 commit 67d35de
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions mne/time_frequency/spectrogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ def __init__(self, inst, method, freqs, tmin, tmax, picks,
del freqs
# make sure we get taper weights if needed.
default = signature(tfr_array_multitaper).parameters['output'].default
# Default for tfr_array_multitaper
self._is_complex_mt = (method == 'multitaper' and
method_kw.get('output', default) == 'complex')
if self._is_complex_mt:
Expand Down Expand Up @@ -335,7 +334,7 @@ def __init__(self, inst, method, freqs, *, tmin=None, tmax=None,
rba = 'NaN' if reject_by_annotation else None
data = self.inst.get_data(self._picks, start, stop + 1,
reject_by_annotation=rba)
# compute the spectra
# compute the TFR
self._compute_tfr(data, n_jobs, method_kw, verbose)
# check for correct shape and bad values
self._check_values()
Expand All @@ -345,16 +344,22 @@ def __init__(self, inst, method, freqs, *, tmin=None, tmax=None,

def _compute_tfr(self, data, n_jobs, method_kw, verbose):
# kwargs are already incorporated (self._tfr_func is a partial)
result = self._tfr_func([data], self.sfreq, decim=self._decim,
n_jobs=n_jobs, verbose=verbose)
result = self._tfr_func(
data[np.newaxis], # prepend a singleton "epochs" axis
self.sfreq,
decim=self._decim,
n_jobs=n_jobs,
verbose=verbose)
# assign ._data (handling unaggregated multitaper output)
if self._is_complex_mt:
if method_kw.get('return_itc', False):
self._data, self._itc, freqs = result # TODO actually return ITC
elif self._is_complex_mt:
self._data, self._mt_weights = result
else:
self._data = result
# remove fake "epoch" dimension
self._data = self._data[0]
# TODO triaging `result` is needed to split out ITC (for epochs)

#
# output shape of TFR funcs is:
# - morlet, average → ( n_ch, n_freq, n_time)
Expand Down

0 comments on commit 67d35de

Please sign in to comment.