diff --git a/.github/workflows/docker_tests.py b/.github/workflows/docker_tests.py index c2c0be8e7..03c5b2e9a 100644 --- a/.github/workflows/docker_tests.py +++ b/.github/workflows/docker_tests.py @@ -16,8 +16,7 @@ print("Validating image") want = np.ones(size) dists = np.sqrt( - np.linspace(-1, 1, size[0])[:, np.newaxis] ** 2 - + np.linspace(-1, 1, size[1]) ** 2 + np.linspace(-1, 1, size[0])[:, np.newaxis] ** 2 + np.linspace(-1, 1, size[1]) ** 2 ) want = (dists > 0.5).astype(float) corr = np.corrcoef(want.ravel(), data.ravel())[0, 1] diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ac727ce29..1eee77252 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,6 +5,4 @@ repos: hooks: - id: ruff args: ["--fix"] - # - id: ruff-format -ci: - skip: [ruff] + - id: ruff-format diff --git a/doc/conf.py b/doc/conf.py index 41baf6592..6322e9c20 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # project-template documentation build configuration file, created by # sphinx-quickstart on Mon Jan 18 14:44:12 2016. @@ -12,18 +11,18 @@ # All configuration values have a default; values that are commented out # serve to show the default. -from datetime import datetime, timezone -import sys import os +import sys import warnings -import sphinx.util.logging -from sphinx_gallery.sorting import FileNameSortKey +from datetime import datetime, timezone -sys.path.append("../") import mne +import sphinx.util.logging from mne.fixes import _compare_version -import mne_nirs from mne.tests.test_docstring_parameters import error_ignores +from sphinx_gallery.sorting import FileNameSortKey + +import mne_nirs sphinx_logger = sphinx.util.logging.getLogger("mne") @@ -42,12 +41,12 @@ # We need to triage which date type we use so that incremental builds work # (Sphinx looks at variable changes and rewrites all files if some change) -copyright = ( +copyright = ( # noqa: A001 f'2012–{td.year}, MNE Developers. Last updated \n' # noqa: E501 '' # noqa: E501 ) if os.getenv("MNE_FULL_DATE", "false").lower() != "true": - copyright = f"2012–{td.year}, {project} Developers. Last updated locally." + copyright = f"2012–{td.year}, {project} Developers. Last updated locally." # noqa: A001, E501 # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the @@ -55,7 +54,9 @@ # # The full version, including alpha/beta/rc tags. release = mne_nirs.__version__ -sphinx_logger.info(f"Building documentation for {project} {release} ({mne_nirs.__file__})") +sphinx_logger.info( + f"Building documentation for {project} {release} ({mne_nirs.__file__})" +) # The short X.Y version. version = ".".join(release.split(".")[:2]) @@ -68,15 +69,15 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.autosummary', - 'sphinx.ext.doctest', - 'sphinx.ext.intersphinx', - 'sphinx.ext.viewcode', - 'sphinx_copybutton', - 'sphinx_gallery.gen_gallery', - 'numpydoc', - 'sphinxcontrib.bibtex', + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.doctest", + "sphinx.ext.intersphinx", + "sphinx.ext.viewcode", + "sphinx_copybutton", + "sphinx_gallery.gen_gallery", + "numpydoc", + "sphinxcontrib.bibtex", ] # Add any paths that contain templates here, relative to this directory. @@ -84,7 +85,7 @@ # generate autosummary even if no references. 
autosummary_generate = True -autodoc_default_options = {'inherited-members': None} +autodoc_default_options = {"inherited-members": None} # The suffix of source filenames. source_suffix = ".rst" @@ -101,7 +102,7 @@ # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. -exclude_patterns = ['_build', '_templates'] +exclude_patterns = ["_build", "_templates"] # A list of ignored prefixes for module index sorting. modindex_common_prefix = ["mne_nirs."] @@ -110,7 +111,7 @@ # default_role = "py:obj" # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" # NumPyDoc configuration ----------------------------------------------------- @@ -118,19 +119,40 @@ numpydoc_class_members_toctree = False numpydoc_attributes_as_param_list = True numpydoc_validate = True -numpydoc_validation_checks = {'all'} | set(error_ignores) +numpydoc_validation_checks = {"all"} | set(error_ignores) numpydoc_validation_exclude = { # set of regex # dict subclasses - r'\.clear', r'\.get$', r'\.copy$', r'\.fromkeys', r'\.items', r'\.keys', - r'\.pop', r'\.popitem', r'\.setdefault', r'\.update', r'\.values', + r"\.clear", + r"\.get$", + r"\.copy$", + r"\.fromkeys", + r"\.items", + r"\.keys", + r"\.pop", + r"\.popitem", + r"\.setdefault", + r"\.update", + r"\.values", # list subclasses - r'\.append', r'\.count', r'\.extend', r'\.index', r'\.insert', r'\.remove', - r'\.sort', + r"\.append", + r"\.count", + r"\.extend", + r"\.index", + r"\.insert", + r"\.remove", + r"\.sort", # we currently don't document these properly (probably okay) - r'\.__getitem__', r'\.__contains__', r'\.__hash__', r'\.__mul__', - r'\.__sub__', r'\.__add__', r'\.__iter__', r'\.__div__', r'\.__neg__', + r"\.__getitem__", + r"\.__contains__", + r"\.__hash__", + r"\.__mul__", + r"\.__sub__", + r"\.__add__", + r"\.__iter__", + r"\.__div__", + r"\.__neg__", # copied from sklearn - r'mne\.utils\.deprecated', + r"mne\.utils\.deprecated", } numpydoc_xref_param_type = True numpydoc_xref_aliases = { @@ -202,9 +224,9 @@ } # sphinxcontrib-bibtex -bibtex_bibfiles = ['./references.bib', './references-nirs.bib'] -bibtex_style = 'unsrt' -bibtex_footbibliography_header = '' +bibtex_bibfiles = ["./references.bib", "./references-nirs.bib"] +bibtex_style = "unsrt" +bibtex_footbibliography_header = "" nitpick_ignore_regex = [ # Type hints for undocumented types @@ -216,13 +238,13 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. -html_theme = 'pydata_sphinx_theme' +html_theme = "pydata_sphinx_theme" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. 
# html_theme_options = {} -switcher_version_match = 'dev' if release.endswith('dev0') else version +switcher_version_match = "dev" if release.endswith("dev0") else version html_context = { "default_mode": "auto", # next 3 are for the "edit this page" button @@ -248,11 +270,11 @@ "footer_start": ["copyright"], "analytics": dict(google_analytics_id="UA-188272121-1"), "switcher": { - "json_url": 'https://mne.tools/mne-nirs/dev/_static/versions.json', + "json_url": "https://mne.tools/mne-nirs/dev/_static/versions.json", "version_match": switcher_version_match, }, - 'pygment_light_style': 'default', - 'pygment_dark_style': 'github-dark', + "pygment_light_style": "default", + "pygment_dark_style": "github-dark", } # The name of an image file (relative to this directory) to place at the top @@ -268,7 +290,7 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] # If true, links to the reST sources are added to the pages. html_show_sourcelink = False @@ -277,7 +299,7 @@ html_show_sphinx = False # Output file base name for HTML help builder. -htmlhelp_basename = 'mnenirsdoc' +htmlhelp_basename = "mnenirsdoc" # -- Options for LaTeX output --------------------------------------------- @@ -291,8 +313,13 @@ # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - ('index', 'project-template.tex', u'project-template Documentation', - u'Robert Luke', 'manual'), + ( + "index", + "project-template.tex", + "project-template Documentation", + "Robert Luke", + "manual", + ), ] @@ -301,8 +328,7 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [ - ('index', 'project-template', u'project-template Documentation', - [u'Robert Luke'], 1) + ("index", "project-template", "project-template Documentation", ["Robert Luke"], 1) ] @@ -312,9 +338,15 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - ('index', 'project-template', u'project-template Documentation', - u'Robert Luke', 'project-template', 'One line description of project.', - 'Miscellaneous'), + ( + "index", + "project-template", + "project-template Documentation", + "Robert Luke", + "project-template", + "One line description of project.", + "Miscellaneous", + ), ] # Example configuration for intersphinx: refer to the Python standard library. 
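A brief aside before the next hunk: the `switcher_version_match` line reformatted above drives the pydata theme's version dropdown. A minimal sketch of how it resolves, assuming a made-up development release string (the real value comes from `mne_nirs.__version__`):

    release = "1.0.0.dev0"  # hypothetical value of mne_nirs.__version__
    version = ".".join(release.split(".")[:2])  # short X.Y version, here "1.0"
    switcher_version_match = "dev" if release.endswith("dev0") else version
    print(switcher_version_match)  # "dev" here; a tagged release would give "1.0"

Development builds therefore always point the switcher at the `dev` docs, while tagged releases match their own X.Y entry in `versions.json`.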
@@ -332,14 +364,14 @@ "statsmodels": ("https://www.statsmodels.org/stable", None), } -scrapers = ('matplotlib',) +scrapers = ("matplotlib",) try: mne.viz.set_3d_backend(mne.viz.get_3d_backend()) except Exception: report_scraper = None else: backend = mne.viz.get_3d_backend() - if backend in ('notebook', 'pyvistaqt'): + if backend in ("notebook", "pyvistaqt"): with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=DeprecationWarning) import pyvista @@ -348,15 +380,16 @@ scrapers += ( mne.gui._GUIScraper(), mne.viz._brain._BrainScraper(), - 'pyvista', + "pyvista", ) report_scraper = mne.report._ReportScraper() scrapers += (report_scraper,) del backend try: import mne_qt_browser - _min_ver = _compare_version(mne_qt_browser.__version__, '>=', '0.2') - if mne.viz.get_browser_backend() == 'qt' and _min_ver: + + _min_ver = _compare_version(mne_qt_browser.__version__, ">=", "0.2") + if mne.viz.get_browser_backend() == "qt" and _min_ver: scrapers += (mne.viz._scraper._MNEQtBrowserScraper(),) except ImportError: pass @@ -367,33 +400,32 @@ # instead of in the root." # we will store dev docs in a `dev` subdirectory and all other docs in a # directory "v" + version_str. E.g., "v0.3" -if 'dev' in version: - filepath_prefix = 'dev' +if "dev" in version: + filepath_prefix = "dev" else: - filepath_prefix = 'stable' + filepath_prefix = "stable" # sphinx-gallery configuration sphinx_gallery_conf = { - 'doc_module': 'mne_nirs', - 'backreferences_dir': os.path.join('generated'), - 'image_scrapers': scrapers, - 'reference_url': { - 'mne_nirs': None}, - 'download_all_examples': False, - 'show_memory': sys.platform.startswith("linux"), - 'within_subsection_order': FileNameSortKey, - 'junit': os.path.join('..', 'test-results', 'sphinx-gallery', 'junit.xml'), - 'binder': { - # Required keys - 'org': 'mne-tools', - 'repo': 'mne-nirs', - 'branch': 'gh-pages', # noqa: E501 Can be any branch, tag, or commit hash. Use a branch that hosts your docs. - 'binderhub_url': 'https://mybinder.org', # noqa: E501 Any URL of a binderhub deployment. Must be full URL (e.g. https://mybinder.org). - 'filepath_prefix': filepath_prefix, # noqa: E501 A prefix to prepend to any filepaths in Binder links. - 'dependencies': [ - '../requirements.txt', - '../requirements_doc.txt', + "doc_module": "mne_nirs", + "backreferences_dir": os.path.join("generated"), + "image_scrapers": scrapers, + "reference_url": {"mne_nirs": None}, + "download_all_examples": False, + "show_memory": sys.platform.startswith("linux"), + "within_subsection_order": FileNameSortKey, + "junit": os.path.join("..", "test-results", "sphinx-gallery", "junit.xml"), + "binder": { + # Required keys + "org": "mne-tools", + "repo": "mne-nirs", + "branch": "gh-pages", # noqa: E501 Can be any branch, tag, or commit hash. Use a branch that hosts your docs. + "binderhub_url": "https://mybinder.org", # noqa: E501 Any URL of a binderhub deployment. Must be full URL (e.g. https://mybinder.org). + "filepath_prefix": filepath_prefix, # noqa: E501 A prefix to prepend to any filepaths in Binder links. + "dependencies": [ + "../requirements.txt", + "../requirements_doc.txt", ], }, - 'plot_gallery': 'True', # Avoid annoying str/bool default warning + "plot_gallery": "True", # Avoid annoying str/bool default warning } diff --git a/examples/general/plot_01_data_io.py b/examples/general/plot_01_data_io.py index ed375b6c2..68eac3b94 100644 --- a/examples/general/plot_01_data_io.py +++ b/examples/general/plot_01_data_io.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- r""" .. 
_tut-importing-fnirs-data:
@@ -166,9 +165,10 @@
 # %%

 import os.path as op
+
+import mne
 import numpy as np
 import pandas as pd
-import mne

 # sphinx_gallery_thumbnail_number = 2

@@ -184,7 +184,7 @@
 # %%
 # Next, we will load the example CSV file.

-data = pd.read_csv('fnirs.csv')
+data = pd.read_csv("fnirs.csv")


 # %%
@@ -196,15 +196,43 @@
 # detector numbers and type is either ``hbo``, ``hbr`` or the
 # wavelength.

-ch_names = ['S1_D1 hbo', 'S1_D1 hbr', 'S2_D1 hbo', 'S2_D1 hbr',
-            'S3_D1 hbo', 'S3_D1 hbr', 'S4_D1 hbo', 'S4_D1 hbr',
-            'S5_D2 hbo', 'S5_D2 hbr', 'S6_D2 hbo', 'S6_D2 hbr',
-            'S7_D2 hbo', 'S7_D2 hbr', 'S8_D2 hbo', 'S8_D2 hbr']
-ch_types = ['hbo', 'hbr', 'hbo', 'hbr',
-            'hbo', 'hbr', 'hbo', 'hbr',
-            'hbo', 'hbr', 'hbo', 'hbr',
-            'hbo', 'hbr', 'hbo', 'hbr']
-sfreq = 10.  # in Hz
+ch_names = [
+    "S1_D1 hbo",
+    "S1_D1 hbr",
+    "S2_D1 hbo",
+    "S2_D1 hbr",
+    "S3_D1 hbo",
+    "S3_D1 hbr",
+    "S4_D1 hbo",
+    "S4_D1 hbr",
+    "S5_D2 hbo",
+    "S5_D2 hbr",
+    "S6_D2 hbo",
+    "S6_D2 hbr",
+    "S7_D2 hbo",
+    "S7_D2 hbr",
+    "S8_D2 hbo",
+    "S8_D2 hbr",
+]
+ch_types = [
+    "hbo",
+    "hbr",
+    "hbo",
+    "hbr",
+    "hbo",
+    "hbr",
+    "hbo",
+    "hbr",
+    "hbo",
+    "hbr",
+    "hbo",
+    "hbr",
+    "hbo",
+    "hbr",
+    "hbo",
+    "hbr",
+]
+sfreq = 10.0  # in Hz


 # %%
@@ -244,7 +272,7 @@
 # fNIRS with :func:`mne.channels.read_custom_montage` by setting
 # ``coord_frame`` to ``'mri'``.

-montage = mne.channels.make_standard_montage('artinis-octamon')
+montage = mne.channels.make_standard_montage("artinis-octamon")
 raw.set_montage(montage)

 # View the position of optodes in 2D to confirm the positions are correct.
@@ -257,11 +285,12 @@
 # The fiducials are marked in blue, green and red.
 # See :ref:`tut-source-alignment` for more details.

-subjects_dir = op.join(mne.datasets.sample.data_path(), 'subjects')
+subjects_dir = op.join(mne.datasets.sample.data_path(), "subjects")
 mne.datasets.fetch_fsaverage(subjects_dir=subjects_dir)

-brain = mne.viz.Brain('fsaverage', subjects_dir=subjects_dir,
-                      alpha=0.5, cortex='low_contrast')
+brain = mne.viz.Brain(
+    "fsaverage", subjects_dir=subjects_dir, alpha=0.5, cortex="low_contrast"
+)
 brain.add_head()
-brain.add_sensors(raw.info, trans='fsaverage')
+brain.add_sensors(raw.info, trans="fsaverage")
 brain.show_view(azimuth=90, elevation=90, distance=500)
diff --git a/examples/general/plot_05_datasets.py b/examples/general/plot_05_datasets.py
index 2884d98c7..4606fc975 100644
--- a/examples/general/plot_05_datasets.py
+++ b/examples/general/plot_05_datasets.py
@@ -25,9 +25,9 @@
 # License: BSD (3-clause)

-import mne_nirs
 import mne_bids.stats

+import mne_nirs

 # %%
 # *******************
diff --git a/examples/general/plot_06_gowerlabs.py b/examples/general/plot_06_gowerlabs.py
index 6198f2441..90de42bfe 100644
--- a/examples/general/plot_06_gowerlabs.py
+++ b/examples/general/plot_06_gowerlabs.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 r"""
 .. _tut-gowerlabs-data:

@@ -6,7 +5,7 @@
 Read Gowerlabs LUMO data
 ========================

-`LUMO `__ is a modular, wearable,
+`LUMO `__ is a modular, wearable,
 high-density diffuse optical tomography (HD-DOT) system produced by
 `Gowerlabs `__.
This tutorial demonstrates how to load data from LUMO, and how to utilise 3D digitisation @@ -31,12 +30,11 @@ # License: BSD (3-clause) import os.path as op + import mne from mne.datasets.testing import data_path - from mne.viz import set_3d_view - # %% # Import Gowerlabs Example File # ----------------------------- @@ -51,7 +49,7 @@ import mne_nirs.io testing_path = data_path(download=True) -fname = op.join(testing_path, 'SNIRF', 'GowerLabs', 'lumomat-1-1-0.snirf') +fname = op.join(testing_path, "SNIRF", "GowerLabs", "lumomat-1-1-0.snirf") # %% # We can view the path to the data by calling the variable `fname`. @@ -81,7 +79,7 @@ # # We observe valid data in each channel, and note that the file includes a # number of event annotations. -# Annotations are a flexible tool to represent events in your experiment. +# Annotations are a flexible tool to represent events in your experiment. # They can also be used to annotate other useful information such as bad # segments of data, participant movements, etc. We can inspect the # annotations to ensure they match what we expect from our experiment. @@ -93,8 +91,8 @@ # The implementation of annotations varies between manufacturers. Rather # than recording the onset and duration of a stimulus condition, LUMO records # discrete event markers which have a nominal one second duration. Each -# marker can consist of an arbitrary character or string. In this sample, -# there were six `A` annotations, one `Cat` annotation, and two `Dog` +# marker can consist of an arbitrary character or string. In this sample, +# there were six `A` annotations, one `Cat` annotation, and two `Dog` # annotations. We can view the specific data for each annotation by converting # the annotations to a dataframe. @@ -123,11 +121,17 @@ # Below we see that there are three LUMO tiles, each with three sources # and four detectors. -subjects_dir = op.join(mne.datasets.sample.data_path(), 'subjects') +subjects_dir = op.join(mne.datasets.sample.data_path(), "subjects") mne.datasets.fetch_fsaverage(subjects_dir=subjects_dir) -brain = mne.viz.Brain('fsaverage', subjects_dir=subjects_dir, alpha=0.0, cortex='low_contrast', background="w") -brain.add_sensors(raw.info, trans='fsaverage', fnirs=["sources", "detectors"]) +brain = mne.viz.Brain( + "fsaverage", + subjects_dir=subjects_dir, + alpha=0.0, + cortex="low_contrast", + background="w", +) +brain.add_sensors(raw.info, trans="fsaverage", fnirs=["sources", "detectors"]) brain.show_view(azimuth=130, elevation=80, distance=700) @@ -142,10 +146,16 @@ # to coregistration. You can also use the MNE-Python # coregistration GUI :func:`mne:mne.gui.coregistration`. -plot_kwargs = dict(subjects_dir=subjects_dir, - surfaces="brain", dig=True, eeg=[], - fnirs=['sources', 'detectors'], show_axes=True, - coord_frame='head', mri_fiducials=True) +plot_kwargs = dict( + subjects_dir=subjects_dir, + surfaces="brain", + dig=True, + eeg=[], + fnirs=["sources", "detectors"], + show_axes=True, + coord_frame="head", + mri_fiducials=True, +) fig = mne.viz.plot_alignment(trans="fsaverage", subject="fsaverage", **plot_kwargs) set_3d_view(figure=fig, azimuth=90, elevation=0, distance=1) @@ -168,7 +178,9 @@ # including :ref:`mne:tut-auto-coreg`. 
#

-fig = mne.viz.plot_alignment(raw.info, trans="fsaverage", subject="fsaverage", **plot_kwargs)
+fig = mne.viz.plot_alignment(
+    raw.info, trans="fsaverage", subject="fsaverage", **plot_kwargs
+)
 set_3d_view(figure=fig, azimuth=90, elevation=0, distance=1)

 # %%
@@ -179,10 +191,14 @@
 # a rotation and translation of the optode frame to minimise the
 # distance between the fiducials.

-coreg = mne.coreg.Coregistration(raw.info, "fsaverage", subjects_dir, fiducials="estimated")
-coreg.fit_fiducials(lpa_weight=1., nasion_weight=1., rpa_weight=1.)
+coreg = mne.coreg.Coregistration(
+    raw.info, "fsaverage", subjects_dir, fiducials="estimated"
+)
+coreg.fit_fiducials(lpa_weight=1.0, nasion_weight=1.0, rpa_weight=1.0)

-fig = mne.viz.plot_alignment(raw.info, trans=coreg.trans, subject="fsaverage", **plot_kwargs)
+fig = mne.viz.plot_alignment(
+    raw.info, trans=coreg.trans, subject="fsaverage", **plot_kwargs
+)
 set_3d_view(figure=fig, azimuth=90, elevation=0, distance=1)


@@ -194,8 +210,10 @@
 # an individualised MRI scan. You can also use :func:`mne:mne.scale_mri` to scale
 # the generic MRI head.

-brain = mne.viz.Brain('fsaverage', subjects_dir=subjects_dir, background='w', cortex='0.5', alpha=0.3)
-brain.add_sensors(raw.info, trans=coreg.trans, fnirs=['sources', 'detectors'])
+brain = mne.viz.Brain(
+    "fsaverage", subjects_dir=subjects_dir, background="w", cortex="0.5", alpha=0.3
+)
+brain.add_sensors(raw.info, trans=coreg.trans, fnirs=["sources", "detectors"])
 brain.show_view(azimuth=90, elevation=90, distance=500)


@@ -223,8 +241,10 @@
 raw_w_coreg = mne.io.read_raw_snirf("raw_coregistered_to_fsaverage.snirf")

 # Now you can simply use `trans = "fsaverage"`.
-brain = mne.viz.Brain('fsaverage', subjects_dir=subjects_dir, background='w', cortex='0.5', alpha=0.3)
-brain.add_sensors(raw_w_coreg.info, trans="fsaverage", fnirs=['sources', 'detectors'])
+brain = mne.viz.Brain(
+    "fsaverage", subjects_dir=subjects_dir, background="w", cortex="0.5", alpha=0.3
+)
+brain.add_sensors(raw_w_coreg.info, trans="fsaverage", fnirs=["sources", "detectors"])


 # %%
diff --git a/examples/general/plot_10_hrf_simulation.py b/examples/general/plot_10_hrf_simulation.py
index 5ce343cf7..6f03bd424 100644
--- a/examples/general/plot_10_hrf_simulation.py
+++ b/examples/general/plot_10_hrf_simulation.py
@@ -9,11 +9,6 @@
 experiment and analyse the simulated signal. We investigate the effect that
 additive noise and measurement length have on response amplitude estimates.
-
-.. contents:: Page contents
-   :local:
-   :depth: 2
-
 """
 # sphinx_gallery_thumbnail_number = 3

 #
 # License: BSD (3-clause)

-import mne
-import mne_nirs
 import matplotlib.pylab as plt
+import mne
 import numpy as np
+from nilearn.plotting import plot_design_matrix
+
+import mne_nirs
 from mne_nirs.experimental_design import make_first_level_design_matrix
 from mne_nirs.statistics import run_glm
-from nilearn.plotting import plot_design_matrix
+

 np.random.seed(1)
@@ -41,11 +38,12 @@
 # The amplitude of the simulated signal is 4 uMol and the sample rate is 3 Hz.
 # The simulated signal is plotted below.

-sfreq = 3.
-amp = 4.
+sfreq = 3.0
+amp = 4.0

 raw = mne_nirs.simulation.simulate_nirs_raw(
-    sfreq=sfreq, sig_dur=60 * 5, amplitude=amp, isi_min=15., isi_max=45.)
+    sfreq=sfreq, sig_dur=60 * 5, amplitude=amp, isi_min=15.0, isi_max=45.0
+)
 raw.plot(duration=300, show_scrollbars=False)
@@ -57,9 +55,9 @@
 # data. We use the nilearn plotting function to visualise the design matrix.
 # For more details on this procedure see :ref:`tut-fnirs-hrf`.

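Since the design matrix created just below is a plain pandas DataFrame, it can be inspected directly once built. A minimal sketch; the column names here are assumptions based on nilearn's naming conventions (one column per annotation, plus the polynomial drift terms and a constant):

    # Assumes the `design_matrix` returned by make_first_level_design_matrix below.
    print(design_matrix.columns.tolist())  # e.g. ["A", "drift_1", "constant"]
    print(design_matrix.shape)  # (n_samples, n_regressors)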
-design_matrix = make_first_level_design_matrix(raw, stim_dur=5.0,
-                                               drift_order=1,
-                                               drift_model='polynomial')
+design_matrix = make_first_level_design_matrix(
+    raw, stim_dur=5.0, drift_order=1, drift_model="polynomial"
+)

 fig, ax1 = plt.subplots(figsize=(10, 6), constrained_layout=True)
 fig = plot_design_matrix(design_matrix, ax=ax1)
@@ -79,13 +77,18 @@
 def print_results(glm_est, truth):
-    """Function to print the results of GLM estimate"""
-    print("Estimate:", glm_est.theta()[0][0],
-          " MSE:", glm_est.MSE()[0],
-          " Error (uM):", 1e6*(glm_est.theta()[0][0] - truth * 1e-6))
+    """Print the results of a GLM estimate."""
+    print(
+        "Estimate:",
+        glm_est.theta()[0][0],
+        " MSE:",
+        glm_est.MSE()[0],
+        " Error (uM):",
+        1e6 * (glm_est.theta()[0][0] - truth * 1e-6),
+    )


-print_results(glm_est, amp)
+print_results(glm_est, amp)


 # %%
@@ -98,7 +101,8 @@
 # and plot the noisy data and the GLM fitted model.
 # We print the response estimate and see that it is close, but not exactly correct;
 # we observe that the mean square error is similar to the added noise.
-# Note that the clean data plot is so similar to the GLM estimate that it is hard to see unless zoomed in.
+# Note that the clean data plot is so similar to the GLM estimate that it is hard to
+# see unless zoomed in.

 # First take a copy of noise free data for comparison
 raw_noise_free = raw.copy()
@@ -127,15 +131,17 @@
 # However, the error is greater for the colored than white noise.

 raw = raw_noise_free.copy()
-cov = mne.Covariance(np.ones(1) * 1e-11, raw.ch_names,
-                     raw.info['bads'], raw.info['projs'], nfree=0)
-raw = mne.simulation.add_noise(raw, cov,
-                               iir_filter=[1., -0.58853134, -0.29575669,
-                                           -0.52246482, 0.38735476,
-                                           0.02428681])
-design_matrix = make_first_level_design_matrix(raw, stim_dur=5.0,
-                                               drift_order=1,
-                                               drift_model='polynomial')
+cov = mne.Covariance(
+    np.ones(1) * 1e-11, raw.ch_names, raw.info["bads"], raw.info["projs"], nfree=0
+)
+raw = mne.simulation.add_noise(
+    raw,
+    cov,
+    iir_filter=[1.0, -0.58853134, -0.29575669, -0.52246482, 0.38735476, 0.02428681],
+)
+design_matrix = make_first_level_design_matrix(
+    raw, stim_dur=5.0, drift_order=1, drift_model="polynomial"
+)
 glm_est = run_glm(raw, design_matrix)

 fig, ax = plt.subplots(constrained_layout=True)
@@ -158,17 +164,20 @@
 # approximately 0.6 uM for 5 minutes of data to 0.25 uM for 30 minutes of data.

 raw = mne_nirs.simulation.simulate_nirs_raw(
-    sfreq=sfreq, sig_dur=60 * 30, amplitude=amp, isi_min=15., isi_max=45.)
-cov = mne.Covariance(np.ones(1) * 1e-11, raw.ch_names, - raw.info['bads'], raw.info['projs'], nfree=0) -raw = mne.simulation.add_noise(raw, cov, - iir_filter=[1., -0.58853134, -0.29575669, - -0.52246482, 0.38735476, - 0.02428681]) - -design_matrix = make_first_level_design_matrix(raw, stim_dur=5.0, - drift_order=1, - drift_model='polynomial') + sfreq=sfreq, sig_dur=60 * 30, amplitude=amp, isi_min=15.0, isi_max=45.0 +) +cov = mne.Covariance( + np.ones(1) * 1e-11, raw.ch_names, raw.info["bads"], raw.info["projs"], nfree=0 +) +raw = mne.simulation.add_noise( + raw, + cov, + iir_filter=[1.0, -0.58853134, -0.29575669, -0.52246482, 0.38735476, 0.02428681], +) + +design_matrix = make_first_level_design_matrix( + raw, stim_dur=5.0, drift_order=1, drift_model="polynomial" +) glm_est = run_glm(raw, design_matrix) fig, ax = plt.subplots(constrained_layout=True) @@ -191,7 +200,7 @@ def print_results(glm_est, truth): # properties was extracted from the data and if this # improved the response estimate. -glm_est = run_glm(raw, design_matrix, noise_model='ar5') +glm_est = run_glm(raw, design_matrix, noise_model="ar5") fig, ax = plt.subplots(figsize=(15, 6), constrained_layout=True) # actual values from model above diff --git a/examples/general/plot_11_hrf_measured.py b/examples/general/plot_11_hrf_measured.py index ca6b1a3ab..9cf2da2c6 100644 --- a/examples/general/plot_11_hrf_measured.py +++ b/examples/general/plot_11_hrf_measured.py @@ -32,20 +32,16 @@ # License: BSD (3-clause) import os -import numpy as np -import matplotlib.pyplot as plt +import matplotlib.pyplot as plt import mne -import mne_nirs +import numpy as np +from nilearn.plotting import plot_design_matrix +import mne_nirs +from mne_nirs.channels import get_long_channels, get_short_channels, picks_pair_to_idx from mne_nirs.experimental_design import make_first_level_design_matrix from mne_nirs.statistics import run_glm -from mne_nirs.channels import (get_long_channels, - get_short_channels, - picks_pair_to_idx) - -from nilearn.plotting import plot_design_matrix - # %% # Import raw NIRS data @@ -69,7 +65,7 @@ # stimulus interval. fnirs_data_folder = mne.datasets.fnirs_motor.data_path() -fnirs_raw_dir = os.path.join(fnirs_data_folder, 'Participant-1') +fnirs_raw_dir = os.path.join(fnirs_data_folder, "Participant-1") raw_intensity = mne.io.read_raw_nirx(fnirs_raw_dir).load_data() raw_intensity.resample(0.7) @@ -84,10 +80,10 @@ # # Because of limitations with ``nilearn``, we use ``'_'`` to separate conditions # rather than the standard ``'/'``. -raw_intensity.annotations.rename({'1.0': 'Control', - '2.0': 'Tapping_Left', - '3.0': 'Tapping_Right'}) -raw_intensity.annotations.delete(raw_intensity.annotations.description == '15.0') +raw_intensity.annotations.rename( + {"1.0": "Control", "2.0": "Tapping_Left", "3.0": "Tapping_Right"} +) +raw_intensity.annotations.delete(raw_intensity.annotations.description == "15.0") raw_intensity.annotations.set_durations(5) @@ -128,7 +124,7 @@ # events is also randomised. events, event_dict = mne.events_from_annotations(raw_haemo, verbose=False) -mne.viz.plot_events(events, event_id=event_dict, sfreq=raw_haemo.info['sfreq']) +mne.viz.plot_events(events, event_id=event_dict, sfreq=raw_haemo.info["sfreq"]) # %% @@ -141,7 +137,7 @@ fig, ax = plt.subplots(figsize=(15, 6), constrained_layout=True) ax.plot(raw_haemo.times, s) ax.legend(["Control", "Left", "Right"], loc="upper right") -ax.set_xlabel("Time (s)"); +ax.set_xlabel("Time (s)") # %% @@ -165,13 +161,16 @@ # parameter value. 
See the nilearn documentation for recommendations on setting
# these values. In short, they suggest "The cutoff period (1/high_pass) should be
# set as the longest period between two trials of the same condition multiplied by 2.
-# For instance, if the longest period is 32s, the high_pass frequency shall be 1/64 Hz ~ 0.016 Hz".
+# For instance, if the longest period is 32s, the high_pass frequency shall be
+# 1/64 Hz ~ 0.016 Hz".

-design_matrix = make_first_level_design_matrix(raw_haemo,
-                                               drift_model='cosine',
-                                               high_pass=0.005,  # Must be specified per experiment
-                                               hrf_model='spm',
-                                               stim_dur=5.0)
+design_matrix = make_first_level_design_matrix(
+    raw_haemo,
+    drift_model="cosine",
+    high_pass=0.005,  # Must be specified per experiment
+    hrf_model="spm",
+    stim_dur=5.0,
+)


 # %%
@@ -182,11 +181,13 @@
 # related to each experimental condition
 # uncontaminated by systemic effects.

-design_matrix["ShortHbO"] = np.mean(short_chs.copy().pick(
-    picks="hbo").get_data(), axis=0)
+design_matrix["ShortHbO"] = np.mean(
+    short_chs.copy().pick(picks="hbo").get_data(), axis=0
+)

-design_matrix["ShortHbR"] = np.mean(short_chs.copy().pick(
-    picks="hbr").get_data(), axis=0)
+design_matrix["ShortHbR"] = np.mean(
+    short_chs.copy().pick(picks="hbr").get_data(), axis=0
+)


 # %%
@@ -225,7 +226,7 @@
 fig, ax = plt.subplots(constrained_layout=True)
 s = mne_nirs.experimental_design.create_boxcar(raw_intensity, stim_dur=5.0)
 ax.plot(raw_intensity.times, s[:, 1])
-ax.plot(design_matrix['Tapping_Left'])
+ax.plot(design_matrix["Tapping_Left"])
 ax.legend(["Stimulus", "Expected Response"])
 ax.set(xlim=(180, 300), xlabel="Time (s)", ylabel="Amplitude")

@@ -264,7 +265,7 @@
 # Note: as we wish to retain both channels for further analysis below,
 # we operate on a copy to demonstrate this channel picking functionality.

-glm_est.copy().pick('S1_D1 hbr')
+glm_est.copy().pick("S1_D1 hbr")

 # %%
 #
@@ -282,7 +283,7 @@
 # For example, to determine the MSE for channel `S1` `D1` for the hbr type
 # you would call:

-glm_est.copy().pick('S1_D1 hbr').MSE()
+glm_est.copy().pick("S1_D1 hbr").MSE()


 # %%
@@ -320,7 +321,7 @@
 # negative of HbO as expected.

 glm_est = run_glm(raw_haemo, design_matrix)
-glm_est.plot_topo(conditions=['Tapping_Left', 'Tapping_Right'])
+glm_est.plot_topo(conditions=["Tapping_Left", "Tapping_Right"])


 # %%
@@ -345,15 +346,21 @@
 # apparent that the data does not indicate that activity spreads across
 # the center of the head.
-fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 6), gridspec_kw=dict(width_ratios=[0.92, 1])) +fig, axes = plt.subplots( + nrows=1, ncols=2, figsize=(10, 6), gridspec_kw=dict(width_ratios=[0.92, 1]) +) glm_hbo = glm_est.copy().pick(picks="hbo") -conditions = ['Tapping_Right'] +conditions = ["Tapping_Right"] glm_hbo.plot_topo(axes=axes[0], colorbar=False, conditions=conditions) -glm_hbo.copy().pick(picks=range(10)).plot_topo(conditions=conditions, axes=axes[1], colorbar=False, vlim=(-16, 16)) -glm_hbo.copy().pick(picks=range(10, 20)).plot_topo(conditions=conditions, axes=axes[1], colorbar=False, vlim=(-16, 16)) +glm_hbo.copy().pick(picks=range(10)).plot_topo( + conditions=conditions, axes=axes[1], colorbar=False, vlim=(-16, 16) +) +glm_hbo.copy().pick(picks=range(10, 20)).plot_topo( + conditions=conditions, axes=axes[1], colorbar=False, vlim=(-16, 16) +) axes[0].set_title("Smoothed across hemispheres") axes[1].set_title("Hemispheres plotted independently") @@ -364,7 +371,9 @@ # Another way to view the data is to project the GLM estimates to the nearest # cortical surface -glm_est.copy().surface_projection(condition="Tapping_Right", view="dorsal", chroma="hbo") +glm_est.copy().surface_projection( + condition="Tapping_Right", view="dorsal", chroma="hbo" +) # %% @@ -388,15 +397,15 @@ # The fOLD toolbox can be used to assist in the design of ROIs. # And consideration should be paid to ensure optimal size ROIs are selected. -left = [[1, 1], [1, 2], [1, 3], [2, 1], [2, 3], - [2, 4], [3, 2], [3, 3], [4, 3], [4, 4]] -right = [[5, 5], [5, 6], [5, 7], [6, 5], [6, 7], - [6, 8], [7, 6], [7, 7], [8, 7], [8, 8]] +left = [[1, 1], [1, 2], [1, 3], [2, 1], [2, 3], [2, 4], [3, 2], [3, 3], [4, 3], [4, 4]] +right = [[5, 5], [5, 6], [5, 7], [6, 5], [6, 7], [6, 8], [7, 6], [7, 7], [8, 7], [8, 8]] -groups = dict(Left_ROI=picks_pair_to_idx(raw_haemo, left), - Right_ROI=picks_pair_to_idx(raw_haemo, right)) +groups = dict( + Left_ROI=picks_pair_to_idx(raw_haemo, left), + Right_ROI=picks_pair_to_idx(raw_haemo, right), +) -conditions = ['Control', 'Tapping_Left', 'Tapping_Right'] +conditions = ["Control", "Tapping_Left", "Tapping_Right"] df = glm_est.to_dataframe_region_of_interest(groups, conditions) @@ -421,9 +430,10 @@ # from tapping on the right hand. contrast_matrix = np.eye(design_matrix.shape[1]) -basic_conts = dict([(column, contrast_matrix[i]) - for i, column in enumerate(design_matrix.columns)]) -contrast_LvR = basic_conts['Tapping_Left'] - basic_conts['Tapping_Right'] +basic_conts = dict( + [(column, contrast_matrix[i]) for i, column in enumerate(design_matrix.columns)] +) +contrast_LvR = basic_conts["Tapping_Left"] - basic_conts["Tapping_Right"] contrast = glm_est.compute_contrast(contrast_LvR) contrast.plot_topo() @@ -457,9 +467,9 @@ # motor cortex, so we dont expect 100% of channels to detect responses to # the tapping, but we do expect 5% or less for the false positive rate. 
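The 5% expectation stated above can be checked directly, since the `df` built earlier carries a `p_value` column (it is dropped from the summary table just below, but is available at this point). A rough sketch:

    # Fraction of Control-condition estimates reaching p < 0.05; a value around
    # 0.05 or below is consistent with the nominal false-positive rate.
    control = df.query("Condition == 'Control'")
    print((control["p_value"] < 0.05).mean())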
-(df - .query('Condition in ["Control", "Tapping_Left", "Tapping_Right"]') - .drop(['df', 'mse', 'p_value', 't'], axis=1) - .groupby(['Condition', 'Chroma', 'ch_name']) - .agg(['mean']) - ) +( + df.query('Condition in ["Control", "Tapping_Left", "Tapping_Right"]') + .drop(["df", "mse", "p_value", "t"], axis=1) + .groupby(["Condition", "Chroma", "ch_name"]) + .agg(["mean"]) +) diff --git a/examples/general/plot_12_group_glm.py b/examples/general/plot_12_group_glm.py index c31cc8189..0cb35f20d 100644 --- a/examples/general/plot_12_group_glm.py +++ b/examples/general/plot_12_group_glm.py @@ -61,11 +61,7 @@ Participants tapped their thumb to their fingers for 5s. Conditions were presented in a random order with a randomised inter stimulus interval. - -.. contents:: Page contents - :local: - :depth: 2 -""" +""" # noqa: E501 # sphinx_gallery_thumbnail_number = 2 # Authors: Robert Luke @@ -74,34 +70,31 @@ # Import common libraries +import matplotlib as mpl + +# Import Plotting Library +import matplotlib.pyplot as plt import numpy as np import pandas as pd +import seaborn as sns -# Import MNE processing -from mne.preprocessing.nirs import optical_density, beer_lambert_law +# Import StatsModels +import statsmodels.formula.api as smf -# Import MNE-NIRS processing -from mne_nirs.statistics import run_glm -from mne_nirs.experimental_design import make_first_level_design_matrix -from mne_nirs.statistics import statsmodels_to_results -from mne_nirs.channels import get_short_channels, get_long_channels -from mne_nirs.channels import picks_pair_to_idx -from mne_nirs.visualisation import plot_glm_group_topo -from mne_nirs.datasets import fnirs_motor_group -from mne_nirs.visualisation import plot_glm_surface_projection -from mne_nirs.io.fold import fold_channel_specificity +# Import MNE processing +from mne.preprocessing.nirs import beer_lambert_law, optical_density # Import MNE-BIDS processing -from mne_bids import BIDSPath, read_raw_bids, get_entity_vals +from mne_bids import BIDSPath, get_entity_vals, read_raw_bids -# Import StatsModels -import statsmodels.formula.api as smf - -# Import Plotting Library -import matplotlib.pyplot as plt -import matplotlib as mpl -import seaborn as sns +from mne_nirs.channels import get_long_channels, get_short_channels, picks_pair_to_idx +from mne_nirs.datasets import fnirs_motor_group +from mne_nirs.experimental_design import make_first_level_design_matrix +from mne_nirs.io.fold import fold_channel_specificity +# Import MNE-NIRS processing +from mne_nirs.statistics import run_glm, statsmodels_to_results +from mne_nirs.visualisation import plot_glm_group_topo, plot_glm_surface_projection # %% # Set up directories @@ -123,15 +116,16 @@ # We inform the software that we are analysing nirs data that is saved in # the snirf format. -dataset = BIDSPath(root=root, task="tapping", - datatype="nirs", suffix="nirs", extension=".snirf") +dataset = BIDSPath( + root=root, task="tapping", datatype="nirs", suffix="nirs", extension=".snirf" +) print(dataset.directory) # %% # For example we can automatically query the subjects, tasks, and sessions. -subjects = get_entity_vals(root, 'subject') +subjects = get_entity_vals(root, "subject") print(subjects) @@ -172,13 +166,13 @@ def individual_analysis(bids_path, ID): - raw_intensity = read_raw_bids(bids_path=bids_path, verbose=False) - # Delete annotation labeled 15, as these just signify the start and end of experiment. 
- raw_intensity.annotations.delete(raw_intensity.annotations.description == '15.0') + # Delete annotation labeled 15, as these just signify the experiment start and end. + raw_intensity.annotations.delete(raw_intensity.annotations.description == "15.0") # sanitize event names raw_intensity.annotations.description[:] = [ - d.replace('/', '_') for d in raw_intensity.annotations.description] + d.replace("/", "_") for d in raw_intensity.annotations.description + ] # Convert signal to haemoglobin and resample raw_od = optical_density(raw_intensity) @@ -193,8 +187,12 @@ def individual_analysis(bids_path, ID): design_matrix = make_first_level_design_matrix(raw_haemo, stim_dur=5.0) # Append short channels mean to design matrix - design_matrix["ShortHbO"] = np.mean(sht_chans.copy().pick(picks="hbo").get_data(), axis=0) - design_matrix["ShortHbR"] = np.mean(sht_chans.copy().pick(picks="hbr").get_data(), axis=0) + design_matrix["ShortHbO"] = np.mean( + sht_chans.copy().pick(picks="hbo").get_data(), axis=0 + ) + design_matrix["ShortHbR"] = np.mean( + sht_chans.copy().pick(picks="hbr").get_data(), axis=0 + ) # Run GLM glm_est = run_glm(raw_haemo, design_matrix) @@ -205,22 +203,24 @@ def individual_analysis(bids_path, ID): right = [[8, 7], [5, 7], [7, 7], [5, 6], [6, 7], [5, 5]] # Then generate the correct indices for each pair groups = dict( - Left_Hemisphere=picks_pair_to_idx(raw_haemo, left, on_missing='ignore'), - Right_Hemisphere=picks_pair_to_idx(raw_haemo, right, on_missing='ignore')) + Left_Hemisphere=picks_pair_to_idx(raw_haemo, left, on_missing="ignore"), + Right_Hemisphere=picks_pair_to_idx(raw_haemo, right, on_missing="ignore"), + ) # Extract channel metrics cha = glm_est.to_dataframe() # Compute region of interest results from channel data - roi = glm_est.to_dataframe_region_of_interest(groups, - design_matrix.columns, - demographic_info=True) + roi = glm_est.to_dataframe_region_of_interest( + groups, design_matrix.columns, demographic_info=True + ) # Define left vs right tapping contrast contrast_matrix = np.eye(design_matrix.shape[1]) - basic_conts = dict([(column, contrast_matrix[i]) - for i, column in enumerate(design_matrix.columns)]) - contrast_LvR = basic_conts['Tapping_Left'] - basic_conts['Tapping_Right'] + basic_conts = dict( + [(column, contrast_matrix[i]) for i, column in enumerate(design_matrix.columns)] + ) + contrast_LvR = basic_conts["Tapping_Left"] - basic_conts["Tapping_Right"] # Compute defined contrast contrast = glm_est.compute_contrast(contrast_LvR) @@ -230,9 +230,9 @@ def individual_analysis(bids_path, ID): roi["ID"] = cha["ID"] = con["ID"] = ID # Convert to uM for nicer plotting below. 
- cha["theta"] = [t * 1.e6 for t in cha["theta"]] - roi["theta"] = [t * 1.e6 for t in roi["theta"]] - con["effect"] = [t * 1.e6 for t in con["effect"]] + cha["theta"] = [t * 1.0e6 for t in cha["theta"]] + roi["theta"] = [t * 1.0e6 for t in roi["theta"]] + con["effect"] = [t * 1.0e6 for t in con["effect"]] return raw_haemo, roi, cha, con @@ -251,7 +251,6 @@ def individual_analysis(bids_path, ID): df_con = pd.DataFrame() # To store channel level contrast results for sub in subjects: # Loop from first to fifth subject - # Create path to file based on experiment info bids_path = dataset.update(subject=sub) @@ -278,7 +277,18 @@ def individual_analysis(bids_path, ID): grp_results = df_roi.query("Condition in ['Control', 'Tapping_Left', 'Tapping_Right']") grp_results = grp_results.query("Chroma in ['hbo']") -sns.catplot(x="Condition", y="theta", col="ID", hue="ROI", data=grp_results, col_wrap=5, errorbar=None, palette="muted", height=4, s=10) +sns.catplot( + x="Condition", + y="theta", + col="ID", + hue="ROI", + data=grp_results, + col_wrap=5, + errorbar=None, + palette="muted", + height=4, + s=10, +) # %% @@ -316,8 +326,9 @@ def individual_analysis(bids_path, ID): grp_results = df_roi.query("Condition in ['Control','Tapping_Left', 'Tapping_Right']") -roi_model = smf.mixedlm("theta ~ -1 + ROI:Condition:Chroma", - grp_results, groups=grp_results["ID"]).fit(method='nm') +roi_model = smf.mixedlm( + "theta ~ -1 + ROI:Condition:Chroma", grp_results, groups=grp_results["ID"] +).fit(method="nm") roi_model.summary() @@ -348,8 +359,9 @@ def individual_analysis(bids_path, ID): grp_results = grp_results.query("Chroma in ['hbo']") grp_results = grp_results.query("ROI in ['Right_Hemisphere']") -roi_model = smf.mixedlm("theta ~ Condition + Sex", - grp_results, groups=grp_results["ID"]).fit(method='nm') +roi_model = smf.mixedlm( + "theta ~ Condition + Sex", grp_results, groups=grp_results["ID"] +).fit(method="nm") roi_model.summary() # %% @@ -367,12 +379,22 @@ def individual_analysis(bids_path, ID): # Regenerate the results from the original group model above grp_results = df_roi.query("Condition in ['Control','Tapping_Left', 'Tapping_Right']") -roi_model = smf.mixedlm("theta ~ -1 + ROI:Condition:Chroma", - grp_results, groups=grp_results["ID"]).fit(method='nm') +roi_model = smf.mixedlm( + "theta ~ -1 + ROI:Condition:Chroma", grp_results, groups=grp_results["ID"] +).fit(method="nm") df = statsmodels_to_results(roi_model) -sns.catplot(x="Condition", y="Coef.", hue="ROI", data=df.query("Chroma == 'hbo'"), errorbar=None, palette="muted", height=4, s=10) +sns.catplot( + x="Condition", + y="Coef.", + hue="ROI", + data=df.query("Chroma == 'hbo'"), + errorbar=None, + palette="muted", + height=4, + s=10, +) # %% @@ -386,47 +408,66 @@ def individual_analysis(bids_path, ID): # than region of interest as above). # Then we pass these results to the topomap function. 
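Before the plotting code, a note on the `smf.mixedlm` formulas used in the hunks that follow: `-1` removes the intercept, and the `a:b:c` interaction yields one coefficient per channel/chroma/condition cell rather than separate main effects. A toy sketch of the expansion, assuming patsy (the formula engine behind statsmodels) and a fabricated two-channel frame; `0 +` is patsy's spelling of the `-1 +` used in the models:

    import pandas as pd
    from patsy import dmatrix

    toy = pd.DataFrame(
        {
            "ch_name": ["S1_D1", "S1_D1", "S2_D1", "S2_D1"],
            "Chroma": ["hbo"] * 4,
            "Condition": ["Tapping_Left", "Tapping_Right"] * 2,
        }
    )
    # One design column per ch_name/Chroma/Condition combination
    print(dmatrix("0 + ch_name:Chroma:Condition", toy).design_info.column_names)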
-fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 10), - gridspec_kw=dict(width_ratios=[1, 1])) +fig, axes = plt.subplots( + nrows=2, ncols=2, figsize=(10, 10), gridspec_kw=dict(width_ratios=[1, 1]) +) # Cut down the dataframe just to the conditions we are interested in ch_summary = df_cha.query("Condition in ['Tapping_Left', 'Tapping_Right']") ch_summary = ch_summary.query("Chroma in ['hbo']") # Run group level model and convert to dataframe -ch_model = smf.mixedlm("theta ~ -1 + ch_name:Chroma:Condition", - ch_summary, groups=ch_summary["ID"]).fit(method='nm') +ch_model = smf.mixedlm( + "theta ~ -1 + ch_name:Chroma:Condition", ch_summary, groups=ch_summary["ID"] +).fit(method="nm") ch_model_df = statsmodels_to_results(ch_model) # Plot the two conditions -plot_glm_group_topo(raw_haemo.copy().pick(picks="hbo"), - ch_model_df.query("Condition in ['Tapping_Left']"), - colorbar=False, axes=axes[0, 0], - vlim=(0, 20), cmap=mpl.cm.Oranges) - -plot_glm_group_topo(raw_haemo.copy().pick(picks="hbo"), - ch_model_df.query("Condition in ['Tapping_Right']"), - colorbar=True, axes=axes[0, 1], - vlim=(0, 20), cmap=mpl.cm.Oranges) +plot_glm_group_topo( + raw_haemo.copy().pick(picks="hbo"), + ch_model_df.query("Condition in ['Tapping_Left']"), + colorbar=False, + axes=axes[0, 0], + vlim=(0, 20), + cmap=mpl.cm.Oranges, +) + +plot_glm_group_topo( + raw_haemo.copy().pick(picks="hbo"), + ch_model_df.query("Condition in ['Tapping_Right']"), + colorbar=True, + axes=axes[0, 1], + vlim=(0, 20), + cmap=mpl.cm.Oranges, +) # Cut down the dataframe just to the conditions we are interested in ch_summary = df_cha.query("Condition in ['Tapping_Left', 'Tapping_Right']") ch_summary = ch_summary.query("Chroma in ['hbr']") # Run group level model and convert to dataframe -ch_model = smf.mixedlm("theta ~ -1 + ch_name:Chroma:Condition", - ch_summary, groups=ch_summary["ID"]).fit(method='nm') +ch_model = smf.mixedlm( + "theta ~ -1 + ch_name:Chroma:Condition", ch_summary, groups=ch_summary["ID"] +).fit(method="nm") ch_model_df = statsmodels_to_results(ch_model) # Plot the two conditions -plot_glm_group_topo(raw_haemo.copy().pick(picks="hbr"), - ch_model_df.query("Condition in ['Tapping_Left']"), - colorbar=False, axes=axes[1, 0], - vlim=(-10, 0), cmap=mpl.cm.Blues_r) -plot_glm_group_topo(raw_haemo.copy().pick(picks="hbr"), - ch_model_df.query("Condition in ['Tapping_Right']"), - colorbar=True, axes=axes[1, 1], - vlim=(-10, 0), cmap=mpl.cm.Blues_r) +plot_glm_group_topo( + raw_haemo.copy().pick(picks="hbr"), + ch_model_df.query("Condition in ['Tapping_Left']"), + colorbar=False, + axes=axes[1, 0], + vlim=(-10, 0), + cmap=mpl.cm.Blues_r, +) +plot_glm_group_topo( + raw_haemo.copy().pick(picks="hbr"), + ch_model_df.query("Condition in ['Tapping_Right']"), + colorbar=True, + axes=axes[1, 1], + vlim=(-10, 0), + cmap=mpl.cm.Blues_r, +) # %% @@ -441,14 +482,16 @@ def individual_analysis(bids_path, ID): con_summary = df_con.query("Chroma in ['hbo']") # Run group level model and convert to dataframe -con_model = smf.mixedlm("effect ~ -1 + ch_name:Chroma", - con_summary, groups=con_summary["ID"]).fit(method='nm') -con_model_df = statsmodels_to_results(con_model, - order=raw_haemo.copy().pick( - picks="hbo").ch_names) +con_model = smf.mixedlm( + "effect ~ -1 + ch_name:Chroma", con_summary, groups=con_summary["ID"] +).fit(method="nm") +con_model_df = statsmodels_to_results( + con_model, order=raw_haemo.copy().pick(picks="hbo").ch_names +) -plot_glm_group_topo(raw_haemo.copy().pick(picks="hbo"), - con_model_df, colorbar=True, axes=axes) 
+plot_glm_group_topo( + raw_haemo.copy().pick(picks="hbo"), con_model_df, colorbar=True, axes=axes +) # %% @@ -457,8 +500,12 @@ def individual_analysis(bids_path, ID): # And set all channels that dont have a significant response to zero. # -plot_glm_group_topo(raw_haemo.copy().pick(picks="hbo").pick(picks=range(10)), - con_model_df, colorbar=True, threshold=True) +plot_glm_group_topo( + raw_haemo.copy().pick(picks="hbo").pick(picks=range(10)), + con_model_df, + colorbar=True, + threshold=True, +) # %% @@ -479,28 +526,40 @@ def individual_analysis(bids_path, ID): # Generate brain figure from data -clim = dict(kind='value', pos_lims=(0, 8, 11)) -brain = plot_glm_surface_projection(raw_haemo.copy().pick("hbo"), - con_model_df, clim=clim, view='dorsal', - colorbar=True, size=(800, 700)) -brain.add_text(0.05, 0.95, "Left-Right", 'title', font_size=16, color='k') +clim = dict(kind="value", pos_lims=(0, 8, 11)) +brain = plot_glm_surface_projection( + raw_haemo.copy().pick("hbo"), + con_model_df, + clim=clim, + view="dorsal", + colorbar=True, + size=(800, 700), +) +brain.add_text(0.05, 0.95, "Left-Right", "title", font_size=16, color="k") # Run model code as above -clim = dict(kind='value', pos_lims=(0, 11.5, 17)) -for idx, cond in enumerate(['Tapping_Left', 'Tapping_Right']): - +clim = dict(kind="value", pos_lims=(0, 11.5, 17)) +for idx, cond in enumerate(["Tapping_Left", "Tapping_Right"]): # Run same model as explained in the sections above ch_summary = df_cha.query("Condition in [@cond]") ch_summary = ch_summary.query("Chroma in ['hbo']") - ch_model = smf.mixedlm("theta ~ -1 + ch_name", ch_summary, - groups=ch_summary["ID"]).fit(method='nm') - model_df = statsmodels_to_results(ch_model, order=raw_haemo.copy().pick("hbo").ch_names) + ch_model = smf.mixedlm( + "theta ~ -1 + ch_name", ch_summary, groups=ch_summary["ID"] + ).fit(method="nm") + model_df = statsmodels_to_results( + ch_model, order=raw_haemo.copy().pick("hbo").ch_names + ) # Generate brain figure from data - brain = plot_glm_surface_projection(raw_haemo.copy().pick("hbo"), - model_df, clim=clim, view='dorsal', - colorbar=True, size=(800, 700)) - brain.add_text(0.05, 0.95, cond, 'title', font_size=16, color='k') + brain = plot_glm_surface_projection( + raw_haemo.copy().pick("hbo"), + model_df, + clim=clim, + view="dorsal", + colorbar=True, + size=(800, 700), + ) + brain.add_text(0.05, 0.95, cond, "title", font_size=16, color="k") # %% @@ -514,16 +573,17 @@ def individual_analysis(bids_path, ID): ch_summary = ch_summary.query("Chroma in ['hbo']") # Run group level model and convert to dataframe -ch_model = smf.mixedlm("theta ~ -1 + ch_name:Chroma:Condition", - ch_summary, groups=ch_summary["ID"]).fit(method='nm') +ch_model = smf.mixedlm( + "theta ~ -1 + ch_name:Chroma:Condition", ch_summary, groups=ch_summary["ID"] +).fit(method="nm") # Here we can use the order argument to ensure the channel name order -ch_model_df = statsmodels_to_results(ch_model, - order=raw_haemo.copy().pick( - picks="hbo").ch_names) +ch_model_df = statsmodels_to_results( + ch_model, order=raw_haemo.copy().pick(picks="hbo").ch_names +) # And make the table prettier ch_model_df.reset_index(drop=True, inplace=True) -ch_model_df = ch_model_df.set_index(['ch_name', 'Condition']) +ch_model_df = ch_model_df.set_index(["ch_name", "Condition"]) ch_model_df @@ -540,7 +600,9 @@ def individual_analysis(bids_path, ID): # The tool is very intuitive and easy to use. 
# Be sure to cite the authors if you use their tool or data: # -# Morais, Guilherme Augusto Zimeo, Joana Bisol Balardin, and João Ricardo Sato. "fNIRS optodes’ location decider (fOLD): a toolbox for probe arrangement guided by brain regions-of-interest." Scientific reports 8.1 (2018): 1-11. +# Morais, Guilherme Augusto Zimeo, Joana Bisol Balardin, and João Ricardo Sato. +# "fNIRS optodes’ location decider (fOLD): a toolbox for probe arrangement guided by +# brain regions-of-interest." Scientific reports 8.1 (2018): 1-11. # # It can be useful to understand what brain structures # the measured response may have resulted from. Here we illustrate @@ -555,7 +617,7 @@ def individual_analysis(bids_path, ID): # that they provide. See the Notes section of # :func:`mne_nirs.io.fold_channel_specificity` for more information. -largest_response_channel = ch_model_df.loc[ch_model_df['Coef.'].idxmax()] +largest_response_channel = ch_model_df.loc[ch_model_df["Coef."].idxmax()] largest_response_channel diff --git a/examples/general/plot_13_fir_glm.py b/examples/general/plot_13_fir_glm.py index 8ad38e8ff..4d5a812cd 100644 --- a/examples/general/plot_13_fir_glm.py +++ b/examples/general/plot_13_fir_glm.py @@ -50,12 +50,7 @@ Simply modify the ``read_raw_`` function to match your data type. See :ref:`data importing tutorial ` to learn how to use your data with MNE-Python. - -.. contents:: Page contents - :local: - :depth: 2 - -""" +""" # noqa: E501 # sphinx_gallery_thumbnail_number = 1 # Authors: Robert Luke @@ -64,28 +59,26 @@ # Import common libraries +# Import Plotting Library +import matplotlib.pyplot as plt import numpy as np import pandas as pd -# Import MNE processing -from mne.preprocessing.nirs import optical_density, beer_lambert_law +# Import StatsModels +import statsmodels.formula.api as smf -# Import MNE-NIRS processing -from mne_nirs.statistics import run_glm -from mne_nirs.experimental_design import make_first_level_design_matrix -from mne_nirs.statistics import statsmodels_to_results -from mne_nirs.datasets import fnirs_motor_group -from mne_nirs.channels import get_short_channels, get_long_channels +# Import MNE processing +from mne.preprocessing.nirs import beer_lambert_law, optical_density # Import MNE-BIDS processing from mne_bids import BIDSPath, read_raw_bids -# Import StatsModels -import statsmodels.formula.api as smf - -# Import Plotting Library -import matplotlib.pyplot as plt +from mne_nirs.channels import get_long_channels, get_short_channels +from mne_nirs.datasets import fnirs_motor_group +from mne_nirs.experimental_design import make_first_level_design_matrix +# Import MNE-NIRS processing +from mne_nirs.statistics import run_glm, statsmodels_to_results # %% # Define FIR analysis @@ -98,14 +91,15 @@ # Due to the chosen sample rate of 0.5 Hz, these delays # correspond to 0, 2, 4... seconds from the onset of the stimulus. -def analysis(fname, ID): +def analysis(fname, ID): raw_intensity = read_raw_bids(bids_path=fname, verbose=False) - # Delete annotation labeled 15, as these just signify the start and end of experiment. - raw_intensity.annotations.delete(raw_intensity.annotations.description == '15.0') + # Delete annotation labeled 15, as these just signify the experiment start and end. 
+ raw_intensity.annotations.delete(raw_intensity.annotations.description == "15.0") # sanitize event names raw_intensity.annotations.description[:] = [ - d.replace('/', '_') for d in raw_intensity.annotations.description] + d.replace("/", "_") for d in raw_intensity.annotations.description + ] # Convert signal to haemoglobin and just keep hbo raw_od = optical_density(raw_intensity) @@ -117,13 +111,15 @@ def analysis(fname, ID): raw_haemo = get_long_channels(raw_haemo) # Create a design matrix - design_matrix = make_first_level_design_matrix(raw_haemo, - hrf_model='fir', - stim_dur=1.0, - fir_delays=range(10), - drift_model='cosine', - high_pass=0.01, - oversampling=1) + design_matrix = make_first_level_design_matrix( + raw_haemo, + hrf_model="fir", + stim_dur=1.0, + fir_delays=range(10), + drift_model="cosine", + high_pass=0.01, + oversampling=1, + ) # Add short channels as regressor in GLM for chan in range(len(short_chans.ch_names)): design_matrix[f"short_{chan}"] = short_chans.get_data(chan).T @@ -139,7 +135,7 @@ def analysis(fname, ID): df_ind = glm_est.to_dataframe_region_of_interest(rois, conditions) df_ind["ID"] = ID - df_ind["theta"] = [t * 1.e6 for t in df_ind["theta"]] + df_ind["theta"] = [t * 1.0e6 for t in df_ind["theta"]] return df_ind, raw_haemo, design_matrix @@ -155,12 +151,17 @@ def analysis(fname, ID): df = pd.DataFrame() for sub in range(1, 6): # Loop from first to fifth subject - ID = '%02d' % sub # Tidy the subject name + ID = "%02d" % sub # Tidy the subject name # Create path to file based on experiment info - bids_path = BIDSPath(subject=ID, task="tapping", - root=fnirs_motor_group.data_path(), - datatype="nirs", suffix="nirs", extension=".snirf") + bids_path = BIDSPath( + subject=ID, + task="tapping", + root=fnirs_motor_group.data_path(), + datatype="nirs", + suffix="nirs", + extension=".snirf", + ) df_individual, raw, dm = analysis(bids_path, ID) @@ -183,9 +184,9 @@ def analysis(fname, ID): df = df.query("isTapping in [True]") # Make a new column that stores the condition name for tidier model below df.loc[:, "TidyCond"] = "" -df.loc[df["isTapping"] == True, "TidyCond"] = "Tapping" +df.loc[df["isTapping"] == True, "TidyCond"] = "Tapping" # noqa: E712 # Finally, extract the FIR delay in to its own column in data frame -df.loc[:, "delay"] = [n.split('_')[-1] for n in df.Condition] +df.loc[:, "delay"] = [n.split("_")[-1] for n in df.Condition] # To simplify this example we will only look at the right hand tapping # condition so we now remove the left tapping conditions from the @@ -202,8 +203,7 @@ def analysis(fname, ID): # of FIR delay for each chromophore on the evoked response with participant # (ID) as a random variable. -lme = smf.mixedlm('theta ~ -1 + delay:TidyCond:Chroma', df, - groups=df["ID"]).fit() +lme = smf.mixedlm("theta ~ -1 + delay:TidyCond:Chroma", df, groups=df["ID"]).fit() # The model is summarised below, and is not displayed here. 
# You can display the model output using: lme.summary() @@ -221,7 +221,7 @@ def analysis(fname, ID): # Create a dataframe from LME model for plotting below df_sum = statsmodels_to_results(lme) df_sum["delay"] = [int(n) for n in df_sum["delay"]] -df_sum = df_sum.sort_values('delay') +df_sum = df_sum.sort_values("delay") # Print the result for the oxyhaemoglobin data in the tapping condition df_sum.query("TidyCond in ['Tapping']").query("Chroma in ['hbo']") @@ -272,9 +272,9 @@ def analysis(fname, ID): # Plot the result axes[0].plot(index_values, np.asarray(dm_cond)) -axes[1].plot(index_values,np.asarray(dm_cond_scaled_hbo)) -axes[2].plot(index_values, np.sum(dm_cond_scaled_hbo, axis=1), 'r') -axes[2].plot(index_values, np.sum(dm_cond_scaled_hbr, axis=1), 'b') +axes[1].plot(index_values, np.asarray(dm_cond_scaled_hbo)) +axes[2].plot(index_values, np.sum(dm_cond_scaled_hbo, axis=1), "r") +axes[2].plot(index_values, np.sum(dm_cond_scaled_hbr, axis=1), "b") # Format the plot for ax in range(3): @@ -310,16 +310,22 @@ def analysis(fname, ID): fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(7, 7)) # Plot the result -axes.plot(index_values, np.sum(dm_cond_scaled_hbo, axis=1), 'r') -axes.plot(index_values, np.sum(dm_cond_scaled_hbr, axis=1), 'b') -axes.fill_between(index_values, - np.asarray(np.sum(dm_cond_scaled_hbo_l95, axis=1)), - np.asarray(np.sum(dm_cond_scaled_hbo_u95, axis=1)), - facecolor='red', alpha=0.25) -axes.fill_between(index_values, - np.asarray(np.sum(dm_cond_scaled_hbr_l95, axis=1)), - np.asarray(np.sum(dm_cond_scaled_hbr_u95, axis=1)), - facecolor='blue', alpha=0.25) +axes.plot(index_values, np.sum(dm_cond_scaled_hbo, axis=1), "r") +axes.plot(index_values, np.sum(dm_cond_scaled_hbr, axis=1), "b") +axes.fill_between( + index_values, + np.asarray(np.sum(dm_cond_scaled_hbo_l95, axis=1)), + np.asarray(np.sum(dm_cond_scaled_hbo_u95, axis=1)), + facecolor="red", + alpha=0.25, +) +axes.fill_between( + index_values, + np.asarray(np.sum(dm_cond_scaled_hbr_l95, axis=1)), + np.asarray(np.sum(dm_cond_scaled_hbr_u95, axis=1)), + facecolor="blue", + alpha=0.25, +) # Format the plot axes.set_xlim(-5, 30) diff --git a/examples/general/plot_14_glm_components.py b/examples/general/plot_14_glm_components.py index 5604b1108..396552721 100644 --- a/examples/general/plot_14_glm_components.py +++ b/examples/general/plot_14_glm_components.py @@ -40,13 +40,7 @@ various choices available in your analysis. However, this is just to illustrate various points. In reality (see all other tutorials), MNE-NIRS will wrap all required Nilearn functions so you don't need to access them directly. - - -.. 
contents:: Page contents
-   :local:
-   :depth: 2
-
-"""
+"""  # noqa: E501
 # sphinx_gallery_thumbnail_number = 1

 # Authors: Robert Luke 
@@ -56,21 +50,24 @@
 # Import common libraries
 import os
-import numpy as np
-import mne
-# Import MNE-NIRS processing
-from mne_nirs.experimental_design import make_first_level_design_matrix, \
-    longest_inter_annotation_interval, drift_high_pass
+import matplotlib as mpl
+
+# Import Plotting Library
+import matplotlib.pyplot as plt
+import mne
+import numpy as np

 # Import Nilearn
 from nilearn.glm import first_level
 from nilearn.plotting import plot_design_matrix

-# Import Plotting Library
-import matplotlib.pyplot as plt
-import matplotlib as mpl
-
+# Import MNE-NIRS processing
+from mne_nirs.experimental_design import (
+    drift_high_pass,
+    longest_inter_annotation_interval,
+    make_first_level_design_matrix,
+)

 # %%
 # Haemodynamic Response Function
@@ -133,9 +130,9 @@
 # Modifying the duration changes the regressor timecourse. Below we demonstrate
 # how this varies for several duration values with the Glover HRF.

+
 # Convenient functions so we don't need to repeat code below
 def generate_stim(onset, amplitude, duration, hrf_model, maxtime=30):
-
     # Generate signal with specified duration and onset
     frame_times = np.linspace(0, maxtime, 601)
     exp_condition = np.array((onset, duration, amplitude)).reshape(3, 1)
@@ -150,8 +147,7 @@ def generate_stim(onset, amplitude, duration, hrf_model, maxtime=30):


 def plot_regressor(onset, amplitude, duration, hrf_model):
-    frame_times, stim, signal = generate_stim(
-        onset, amplitude, duration, hrf_model)
+    frame_times, stim, signal = generate_stim(onset, amplitude, duration, hrf_model)
     plt.fill(frame_times, stim, "k", alpha=0.5, label="stimulus")
     plt.plot(frame_times, signal.T[0], label="Regressor")
     plt.xlabel("Time (s)")
@@ -205,7 +201,8 @@ def plot_regressor(onset, amplitude, duration, hrf_model):

 for n in [1, 3, 5, 10, 15, 20, 25, 30, 35]:
     frame_times, stim, signal = generate_stim(
-        onset, amplitude, n, hrf_model, maxtime=50)
+        onset, amplitude, n, hrf_model, maxtime=50
+    )
     axes.plot(frame_times, signal.T[0], label="Regressor", c=cmap(norm(n)))

 axes.set_xlabel("Time (s)")
@@ -238,13 +235,13 @@ def plot_regressor(onset, amplitude, duration, hrf_model):
 # and give names to the annotations.

 fnirs_data_folder = mne.datasets.fnirs_motor.data_path()
-fnirs_raw_dir = os.path.join(fnirs_data_folder, 'Participant-1')
+fnirs_raw_dir = os.path.join(fnirs_data_folder, "Participant-1")
 raw_intensity = mne.io.read_raw_nirx(fnirs_raw_dir).load_data().crop(tmax=300)
 # raw_intensity.resample(0.7)
-raw_intensity.annotations.rename({'1.0': 'Control',
-                                  '2.0': 'Tapping/Left',
-                                  '3.0': 'Tapping/Right'})
-raw_intensity.annotations.delete(raw_intensity.annotations.description == '15.0')
+raw_intensity.annotations.rename(
+    {"1.0": "Control", "2.0": "Tapping/Left", "3.0": "Tapping/Right"}
+)
+raw_intensity.annotations.delete(raw_intensity.annotations.description == "15.0")
 raw_intensity.annotations.set_durations(5)


@@ -256,17 +253,19 @@ def plot_regressor(onset, amplitude, duration, hrf_model):
 # axis and is specified in scan number (fMRI hangover) or sample.
 # There is no colorbar for this plot, as specified in Nilearn.
 #
-# We can see that when each event occurs the model value increases before returning to baseline.
-# this is the same information as was shown in the time courses above, except displayed differently
-# with color representing amplitude.
+# We can see that when each event occurs the model value increases before returning to +# baseline. This is the same information as was shown in the time courses above, except +# displayed differently with color representing amplitude. -design_matrix = make_first_level_design_matrix(raw_intensity, - # Ignore drift model for now, see section below - drift_model='polynomial', - drift_order=0, - # Here we specify the HRF and duration - hrf_model='glover', - stim_dur=3.0) +design_matrix = make_first_level_design_matrix( + raw_intensity, + # Ignore drift model for now, see section below + drift_model="polynomial", + drift_order=0, + # Here we specify the HRF and duration + hrf_model="glover", + stim_dur=3.0, +) fig, ax1 = plt.subplots(figsize=(10, 6), nrows=1, ncols=1) fig = plot_design_matrix(design_matrix, ax=ax1) @@ -277,13 +276,15 @@ def plot_regressor(onset, amplitude, duration, hrf_model): # As before we can explore the effect of modifying the duration, # the resulting regressor for each annotation is elongated. -design_matrix = make_first_level_design_matrix(raw_intensity, - # Ignore drift model for now, see section below - drift_model='polynomial', - drift_order=0, - # Here we specify the HRF and duration - hrf_model='glover', - stim_dur=13.0) +design_matrix = make_first_level_design_matrix( + raw_intensity, + # Ignore drift model for now, see section below + drift_model="polynomial", + drift_order=0, + # Here we specify the HRF and duration + hrf_model="glover", + stim_dur=13.0, +) fig, ax1 = plt.subplots(figsize=(10, 6), nrows=1, ncols=1) fig = plot_design_matrix(design_matrix, ax=ax1) @@ -295,13 +296,15 @@ def plot_regressor(onset, amplitude, duration, hrf_model): # may overlap (for example an event related design). # This is not an issue, the design matrix can handle overlapping responses. -design_matrix = make_first_level_design_matrix(raw_intensity, - # Ignore drift model for now, see section below - drift_model='polynomial', - drift_order=0, - # Here we specify the HRF and duration - hrf_model='glover', - stim_dur=30.0) +design_matrix = make_first_level_design_matrix( + raw_intensity, + # Ignore drift model for now, see section below + drift_model="polynomial", + drift_order=0, + # Here we specify the HRF and duration + hrf_model="glover", + stim_dur=30.0, +) fig, ax1 = plt.subplots(figsize=(10, 6), nrows=1, ncols=1) fig = plot_design_matrix(design_matrix, ax=ax1) @@ -343,9 +346,9 @@ def plot_regressor(onset, amplitude, duration, hrf_model): # You can observe that with increasing polynomial order, # higher frequency components will be regressed from the signal. -design_matrix = make_first_level_design_matrix(raw_intensity, - drift_model='polynomial', - drift_order=5) +design_matrix = make_first_level_design_matrix( + raw_intensity, drift_model="polynomial", drift_order=5 +) fig, ax1 = plt.subplots(figsize=(10, 6), nrows=1, ncols=1) fig = plot_design_matrix(design_matrix, ax=ax1) @@ -362,9 +365,9 @@ def plot_regressor(onset, amplitude, duration, hrf_model): # In the example below we demonstrate how to regress our signals up to 0.01 Hz. # We observe that the function has included 6 drift regressors in the design matrix. 
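# A rough, hedged check of where that number comes from: Nilearn's cosine
# basis contains approximately floor(2 * T * high_pass) drift terms for a
# recording of duration T seconds. The recording above was cropped to 300 s,
# so a 0.01 Hz cut off gives:
#
#     int(np.floor(2 * 300 * 0.01))  # -> 6 drift regressors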
-design_matrix = make_first_level_design_matrix(raw_intensity, - drift_model='cosine', - high_pass=0.01) +design_matrix = make_first_level_design_matrix( + raw_intensity, drift_model="cosine", high_pass=0.01 +) fig, ax1 = plt.subplots(figsize=(10, 6), nrows=1, ncols=1) fig = plot_design_matrix(design_matrix, ax=ax1) @@ -376,9 +379,9 @@ def plot_regressor(onset, amplitude, duration, hrf_model): # higher frequency components. So we can increase the high pass cut off and # this should add more regressors. -design_matrix = make_first_level_design_matrix(raw_intensity, - drift_model='cosine', - high_pass=0.03) +design_matrix = make_first_level_design_matrix( + raw_intensity, drift_model="cosine", high_pass=0.03 +) fig, ax1 = plt.subplots(figsize=(10, 6), nrows=1, ncols=1) fig = plot_design_matrix(design_matrix, ax=ax1) @@ -395,17 +398,18 @@ def plot_regressor(onset, amplitude, duration, hrf_model): # the high pass cut off can be set on first principles. # # The Nilearn documentation states that -# "The cutoff period (1/high_pass) should be set as the longest period between two trials of the same condition multiplied by 2. -# For instance, if the longest period is 32s, the high_pass frequency shall be 1/64 Hz ~ 0.016 Hz." +# "The cutoff period (1/high_pass) should be set as the longest period between two +# trials of the same condition multiplied by 2. For instance, if the longest period is +# 32s, the high_pass frequency shall be 1/64 Hz ~ 0.016 Hz." # `(reference) `__. # -# To assist in selecting a high pass value a few convenience functions are included in MNE-NIRS. -# First we can query what the longest ISI is per annotation, but first we must be sure -# to remove annotations we aren't interested in (in this experiment the trigger +# To assist in selecting a high pass value a few convenience functions are included in +# MNE-NIRS. First we can query what the longest ISI is per annotation, but first we must +# be sure to remove annotations we aren't interested in (in this experiment the trigger # 15 is not of interest). raw_original = mne.io.read_raw_nirx(fnirs_raw_dir) -raw_original.annotations.delete(raw_original.annotations.description == '15.0') +raw_original.annotations.delete(raw_original.annotations.description == "15.0") isis, names = longest_inter_annotation_interval(raw_original) print(isis) @@ -430,8 +434,8 @@ def plot_regressor(onset, amplitude, duration, hrf_model): # sense to include them as a single condition when computing the ISI. # This would be achieved by renaming the triggers. -raw_original.annotations.rename({'2.0': 'Tapping', '3.0': 'Tapping'}) -raw_original.annotations.delete(raw_original.annotations.description == '1.0') +raw_original.annotations.rename({"2.0": "Tapping", "3.0": "Tapping"}) +raw_original.annotations.delete(raw_original.annotations.description == "1.0") isis, names = longest_inter_annotation_interval(raw_original) print(isis) print(drift_high_pass(raw_original)) diff --git a/examples/general/plot_16_waveform_group.py b/examples/general/plot_16_waveform_group.py index 96e761cad..a04790034 100644 --- a/examples/general/plot_16_waveform_group.py +++ b/examples/general/plot_16_waveform_group.py @@ -64,10 +64,6 @@ Participants tapped their thumb to their fingers for the entire 5 sec. Conditions were presented in a random order with a randomised inter- stimulus interval. - -.. 
contents:: Page contents - :local: - :depth: 2 """ # sphinx_gallery_thumbnail_number = 2 @@ -76,33 +72,36 @@ # License: BSD (3-clause) # Import common libraries -import pandas as pd -from itertools import compress from collections import defaultdict from copy import deepcopy +from itertools import compress from pprint import pprint -# Import MNE processing -from mne.viz import plot_compare_evokeds +# Import Plotting Library +import matplotlib.pyplot as plt +import pandas as pd +import seaborn as sns + +# Import StatsModels +import statsmodels.formula.api as smf from mne import Epochs, events_from_annotations, set_log_level +from mne.preprocessing.nirs import ( + beer_lambert_law, + optical_density, + scalp_coupling_index, + temporal_derivative_distribution_repair, +) -# Import MNE-NIRS processing -from mne_nirs.channels import get_long_channels -from mne_nirs.channels import picks_pair_to_idx -from mne_nirs.datasets import fnirs_motor_group -from mne.preprocessing.nirs import beer_lambert_law, optical_density,\ - temporal_derivative_distribution_repair, scalp_coupling_index -from mne_nirs.signal_enhancement import enhance_negative_correlation +# Import MNE processing +from mne.viz import plot_compare_evokeds # Import MNE-BIDS processing from mne_bids import BIDSPath, read_raw_bids -# Import StatsModels -import statsmodels.formula.api as smf - -# Import Plotting Library -import matplotlib.pyplot as plt -import seaborn as sns +# Import MNE-NIRS processing +from mne_nirs.channels import get_long_channels, picks_pair_to_idx +from mne_nirs.datasets import fnirs_motor_group +from mne_nirs.signal_enhancement import enhance_negative_correlation # Set general parameters set_log_level("WARNING") # Don't show info, as it is repetitive for many subjects @@ -124,8 +123,8 @@ # and :ref:`artifact correction tutorial `. # As such, this example will skim over the individual-level details. -def individual_analysis(bids_path): +def individual_analysis(bids_path): # Read data with annotations in BIDS format raw_intensity = read_raw_bids(bids_path=bids_path, verbose=False) raw_intensity = get_long_channels(raw_intensity, min_dist=0.01) @@ -142,20 +141,31 @@ def individual_analysis(bids_path): # Convert to haemoglobin and filter raw_haemo = beer_lambert_law(raw_od, ppf=0.1) - raw_haemo = raw_haemo.filter(0.02, 0.3, - h_trans_bandwidth=0.1, l_trans_bandwidth=0.01, - verbose=False) + raw_haemo = raw_haemo.filter( + 0.02, 0.3, h_trans_bandwidth=0.1, l_trans_bandwidth=0.01, verbose=False + ) # Apply further data cleaning techniques and extract epochs raw_haemo = enhance_negative_correlation(raw_haemo) # Extract events but ignore those with # the word Ends (i.e. 
drop ExperimentEnds events) - events, event_dict = events_from_annotations(raw_haemo, verbose=False, - regexp='^(?![Ends]).*$') - epochs = Epochs(raw_haemo, events, event_id=event_dict, tmin=-5, tmax=20, - reject=dict(hbo=200e-6), reject_by_annotation=True, - proj=True, baseline=(None, 0), detrend=0, - preload=True, verbose=False) + events, event_dict = events_from_annotations( + raw_haemo, verbose=False, regexp="^(?![Ends]).*$" + ) + epochs = Epochs( + raw_haemo, + events, + event_id=event_dict, + tmin=-5, + tmax=20, + reject=dict(hbo=200e-6), + reject_by_annotation=True, + proj=True, + baseline=(None, 0), + detrend=0, + preload=True, + verbose=False, + ) return raw_haemo, epochs @@ -174,11 +184,15 @@ def individual_analysis(bids_path): all_evokeds = defaultdict(list) for sub in range(1, 6): # Loop from first to fifth subject - # Create path to file based on experiment info - bids_path = BIDSPath(subject="%02d" % sub, task="tapping", datatype="nirs", - root=fnirs_motor_group.data_path(), suffix="nirs", - extension=".snirf") + bids_path = BIDSPath( + subject="%02d" % sub, + task="tapping", + datatype="nirs", + root=fnirs_motor_group.data_path(), + suffix="nirs", + extension=".snirf", + ) # Analyse data and return both ROI and channel results raw_haemo, epochs = individual_analysis(bids_path) @@ -209,13 +223,21 @@ def individual_analysis(bids_path): fig, axes = plt.subplots(nrows=1, ncols=len(all_evokeds), figsize=(17, 5)) lims = dict(hbo=[-5, 12], hbr=[-5, 12]) -for (pick, color) in zip(['hbo', 'hbr'], ['r', 'b']): +for pick, color in zip(["hbo", "hbr"], ["r", "b"]): for idx, evoked in enumerate(all_evokeds): - plot_compare_evokeds({evoked: all_evokeds[evoked]}, combine='mean', - picks=pick, axes=axes[idx], show=False, - colors=[color], legend=False, ylim=lims, ci=0.95, - show_sensors=idx == 2) - axes[idx].set_title('{}'.format(evoked)) + plot_compare_evokeds( + {evoked: all_evokeds[evoked]}, + combine="mean", + picks=pick, + axes=axes[idx], + show=False, + colors=[color], + legend=False, + ylim=lims, + ci=0.95, + show_sensors=idx == 2, + ) + axes[idx].set_title(f"{evoked}") axes[0].legend(["Oxyhaemoglobin", "Deoxyhaemoglobin"]) # %% @@ -255,8 +277,10 @@ def individual_analysis(bids_path): right = [[8, 7], [5, 7], [7, 7], [5, 6], [6, 7], [5, 5]] # Then generate the correct indices for each pair and store in dictionary -rois = dict(Left_Hemisphere=picks_pair_to_idx(raw_haemo, left), - Right_Hemisphere=picks_pair_to_idx(raw_haemo, right)) +rois = dict( + Left_Hemisphere=picks_pair_to_idx(raw_haemo, left), + Right_Hemisphere=picks_pair_to_idx(raw_haemo, right), +) pprint(rois) @@ -269,22 +293,29 @@ def individual_analysis(bids_path): # regions of the brain for each condition. # Specify the figure size and limits per chromophore. 
-fig, axes = plt.subplots(nrows=len(rois), ncols=len(all_evokeds), - figsize=(15, 6)) +fig, axes = plt.subplots(nrows=len(rois), ncols=len(all_evokeds), figsize=(15, 6)) lims = dict(hbo=[-8, 16], hbr=[-8, 16]) -for (pick, color) in zip(['hbo', 'hbr'], ['r', 'b']): +for pick, color in zip(["hbo", "hbr"], ["r", "b"]): for ridx, roi in enumerate(rois): for cidx, evoked in enumerate(all_evokeds): - if pick == 'hbr': + if pick == "hbr": picks = rois[roi][1::2] # Select only the hbr channels else: picks = rois[roi][0::2] # Select only the hbo channels - plot_compare_evokeds({evoked: all_evokeds[evoked]}, combine='mean', - picks=picks, axes=axes[ridx, cidx], - show=False, colors=[color], legend=False, - ylim=lims, ci=0.95, show_sensors=cidx == 2) + plot_compare_evokeds( + {evoked: all_evokeds[evoked]}, + combine="mean", + picks=picks, + axes=axes[ridx, cidx], + show=False, + colors=[color], + legend=False, + ylim=lims, + ci=0.95, + show_sensors=cidx == 2, + ) axes[ridx, cidx].set_title("") axes[0, cidx].set_title(f"{evoked}") axes[ridx, 0].set_ylabel(f"{roi}\nChromophore (ΔμMol)") @@ -309,7 +340,7 @@ def individual_analysis(bids_path): # to a csv file for easy analysis in any statistical analysis software. # We also demonstrate two example analyses on these values below. -df = pd.DataFrame(columns=['ID', 'ROI', 'Chroma', 'Condition', 'Value']) +df = pd.DataFrame(columns=["ID", "ROI", "Chroma", "Condition", "Value"]) for idx, evoked in enumerate(all_evokeds): subj_id = 0 @@ -322,11 +353,18 @@ def individual_analysis(bids_path): # Append metadata and extracted feature to dataframe this_df = pd.DataFrame( - {'ID': subj_id, 'ROI': roi, 'Chroma': chroma, - 'Condition': evoked, 'Value': value}, index=[0]) + { + "ID": subj_id, + "ROI": roi, + "Chroma": chroma, + "Condition": evoked, + "Value": value, + }, + index=[0], + ) df = pd.concat([df, this_df], ignore_index=True) df.reset_index(inplace=True, drop=True) -df['Value'] = pd.to_numeric(df['Value']) # some Pandas have this as object +df["Value"] = pd.to_numeric(df["Value"]) # some Pandas have this as object # You can export the dataframe for analysis in your favorite stats program df.to_csv("stats-export.csv") @@ -349,12 +387,21 @@ def individual_analysis(bids_path): # For this reason, fNIRS is most appropriate for detecting changes within a # single ROI between conditions. -sns.catplot(x="Condition", y="Value", hue="ID", data=df.query("Chroma == 'hbo'"), errorbar=None, palette="muted", height=4, s=10) +sns.catplot( + x="Condition", + y="Value", + hue="ID", + data=df.query("Chroma == 'hbo'"), + errorbar=None, + palette="muted", + height=4, + s=10, +) # %% # Research question 1: Comparison of conditions -# --------------------------------------------------------------------------------------------------- +# --------------------------------------------- # # In this example question we ask: is the HbO response in the # left ROI to tapping with the right hand larger @@ -366,8 +413,7 @@ def individual_analysis(bids_path): input_data = input_data.query("Chroma in ['hbo']") input_data = input_data.query("ROI in ['Left_Hemisphere']") -roi_model = smf.mixedlm("Value ~ Condition", input_data, - groups=input_data["ID"]).fit() +roi_model = smf.mixedlm("Value ~ Condition", input_data, groups=input_data["ID"]).fit() roi_model.summary() # %% @@ -390,22 +436,28 @@ def individual_analysis(bids_path): # Encode the ROIs as ipsi- or contralateral to the hand that is tapping. 
df["Hemisphere"] = "Unknown" -df.loc[(df["Condition"] == "Tapping/Right") & - (df["ROI"] == "Right_Hemisphere"), "Hemisphere"] = "Ipsilateral" -df.loc[(df["Condition"] == "Tapping/Right") & - (df["ROI"] == "Left_Hemisphere"), "Hemisphere"] = "Contralateral" -df.loc[(df["Condition"] == "Tapping/Left") & - (df["ROI"] == "Left_Hemisphere"), "Hemisphere"] = "Ipsilateral" -df.loc[(df["Condition"] == "Tapping/Left") & - (df["ROI"] == "Right_Hemisphere"), "Hemisphere"] = "Contralateral" +df.loc[ + (df["Condition"] == "Tapping/Right") & (df["ROI"] == "Right_Hemisphere"), + "Hemisphere", +] = "Ipsilateral" +df.loc[ + (df["Condition"] == "Tapping/Right") & (df["ROI"] == "Left_Hemisphere"), + "Hemisphere", +] = "Contralateral" +df.loc[ + (df["Condition"] == "Tapping/Left") & (df["ROI"] == "Left_Hemisphere"), "Hemisphere" +] = "Ipsilateral" +df.loc[ + (df["Condition"] == "Tapping/Left") & (df["ROI"] == "Right_Hemisphere"), + "Hemisphere", +] = "Contralateral" # Subset the data for example model input_data = df.query("Condition in ['Tapping/Right', 'Tapping/Left']") input_data = input_data.query("Chroma in ['hbo']") assert len(input_data) -roi_model = smf.mixedlm("Value ~ Hemisphere", input_data, - groups=input_data["ID"]).fit() +roi_model = smf.mixedlm("Value ~ Hemisphere", input_data, groups=input_data["ID"]).fit() roi_model.summary() # %% diff --git a/examples/general/plot_19_snirf.py b/examples/general/plot_19_snirf.py index 14b9e0e94..90e5260d6 100644 --- a/examples/general/plot_19_snirf.py +++ b/examples/general/plot_19_snirf.py @@ -17,15 +17,10 @@ MNE Python and MNE-NIRS can be used to read and write SNIRF files respectively. In this tutorial we demonstrate how to convert your MNE data to -SNIRF and write it to disk and also how to read SNIRF files. We also demonstrate how to validate -that a SNIRF file conforms to the SNIRF specification. +SNIRF and write it to disk and also how to read SNIRF files. We also demonstrate how to +validate that a SNIRF file conforms to the SNIRF specification. You can read the SNIRF protocol at the official site https://github.com/fNIRS/snirf. - -.. contents:: Page contents - :local: - :depth: 2 - """ @@ -34,14 +29,14 @@ # License: BSD (3-clause) import os + import mne import snirf - from mne.io import read_raw_nirx, read_raw_snirf -from mne.preprocessing.nirs import optical_density, beer_lambert_law -from mne_nirs.io import write_raw_snirf +from mne.preprocessing.nirs import beer_lambert_law, optical_density from numpy.testing import assert_allclose +from mne_nirs.io import write_raw_snirf # %% # Import raw NIRS data from vendor @@ -51,7 +46,7 @@ fnirs_data_folder = mne.datasets.fnirs_motor.data_path() -fnirs_raw_dir = os.path.join(fnirs_data_folder, 'Participant-1') +fnirs_raw_dir = os.path.join(fnirs_data_folder, "Participant-1") raw_intensity = read_raw_nirx(fnirs_raw_dir).load_data() @@ -61,22 +56,22 @@ # # Now we can write this data back to disk in the SNIRF format. -write_raw_snirf(raw_intensity, 'test_raw.snirf') +write_raw_snirf(raw_intensity, "test_raw.snirf") # %% # Read back SNIRF file # -------------------- -# +# # Next we can read back the snirf file. -snirf_intensity = read_raw_snirf('test_raw.snirf') +snirf_intensity = read_raw_snirf("test_raw.snirf") # %% # Compare files # ------------- -# +# # Finally we can compare the data of the original to the SNIRF format and # ensure that the values are the same. @@ -95,7 +90,7 @@ # https://github.com/BUNPC/pysnirf2. 
Below we demonstrate that the files created
 # by MNE-NIRS are compliant with the specification.

-result = snirf.validateSnirf('test_raw.snirf')
+result = snirf.validateSnirf("test_raw.snirf")
 assert result.is_valid()
 result.display()

@@ -107,9 +102,9 @@
 # MNE-NIRS can also be used to write optical density data to SNIRF files.

 raw_od = optical_density(raw_intensity)
-write_raw_snirf(raw_od, 'test_raw_od.snirf')
+write_raw_snirf(raw_od, "test_raw_od.snirf")

-result = snirf.validateSnirf('test_raw_od.snirf')
+result = snirf.validateSnirf("test_raw_od.snirf")
 assert result.is_valid()
 result.display()

@@ -121,8 +116,8 @@
 # And it can write valid haemoglobin data to SNIRF files.

 raw_hb = beer_lambert_law(raw_od)
-write_raw_snirf(raw_hb, 'test_raw_hb.snirf')
+write_raw_snirf(raw_hb, "test_raw_hb.snirf")

-result = snirf.validateSnirf('test_raw_hb.snirf')
+result = snirf.validateSnirf("test_raw_hb.snirf")
 assert result.is_valid()
 result.display()
diff --git a/examples/general/plot_20_enhance.py b/examples/general/plot_20_enhance.py
index 6ea43af71..985bd6817 100644
--- a/examples/general/plot_20_enhance.py
+++ b/examples/general/plot_20_enhance.py
@@ -8,11 +8,6 @@
 techniques on functional near-infrared spectroscopy
 (fNIRS) data.

-
-.. contents:: Page contents
-   :local:
-   :depth: 2
-
 """


@@ -23,12 +18,11 @@
 import os

 import matplotlib.pyplot as plt
-
 import mne
+
 import mne_nirs
 from mne_nirs.channels import picks_pair_to_idx

-
 # %%
 # Import and preprocess data
 # --------------------------
@@ -39,18 +33,17 @@
 # for more details.

 fnirs_data_folder = mne.datasets.fnirs_motor.data_path()
-fnirs_raw_dir = os.path.join(fnirs_data_folder, 'Participant-1')
+fnirs_raw_dir = os.path.join(fnirs_data_folder, "Participant-1")
 raw_intensity = mne.io.read_raw_nirx(fnirs_raw_dir, verbose=True).load_data()

 raw_od = mne.preprocessing.nirs.optical_density(raw_intensity)
 raw_haemo = mne.preprocessing.nirs.beer_lambert_law(raw_od, ppf=0.1)
 raw_haemo = mne_nirs.channels.get_long_channels(raw_haemo)
-raw_haemo = raw_haemo.filter(0.05, 0.7, h_trans_bandwidth=0.2,
-                             l_trans_bandwidth=0.02)
-events, _ = mne.events_from_annotations(raw_haemo, event_id={'1.0': 1,
-                                                             '2.0': 2,
-                                                             '3.0': 3})
-event_dict = {'Control': 1, 'Tapping/Left': 2, 'Tapping/Right': 3}
+raw_haemo = raw_haemo.filter(0.05, 0.7, h_trans_bandwidth=0.2, l_trans_bandwidth=0.02)
+events, _ = mne.events_from_annotations(
+    raw_haemo, event_id={"1.0": 1, "2.0": 2, "3.0": 3}
+)
+event_dict = {"Control": 1, "Tapping/Left": 2, "Tapping/Right": 3}


 # %%
@@ -63,16 +56,27 @@
 reject_criteria = dict(hbo=100e-6)
 tmin, tmax = -5, 15

-epochs = mne.Epochs(raw_haemo, events, event_id=event_dict,
-                    tmin=tmin, tmax=tmax,
-                    reject=reject_criteria, reject_by_annotation=True,
-                    proj=True, baseline=(None, 0), preload=True,
-                    detrend=None, verbose=True)
+epochs = mne.Epochs(
+    raw_haemo,
+    events,
+    event_id=event_dict,
+    tmin=tmin,
+    tmax=tmax,
+    reject=reject_criteria,
+    reject_by_annotation=True,
+    proj=True,
+    baseline=(None, 0),
+    preload=True,
+    detrend=None,
+    verbose=True,
+)

-evoked_dict = {'Tapping/HbO': epochs['Tapping'].average(picks='hbo'),
-               'Tapping/HbR': epochs['Tapping'].average(picks='hbr'),
-               'Control/HbO': epochs['Control'].average(picks='hbo'),
-               'Control/HbR': epochs['Control'].average(picks='hbr')}
+evoked_dict = {
+    "Tapping/HbO": epochs["Tapping"].average(picks="hbo"),
+    "Tapping/HbR": epochs["Tapping"].average(picks="hbr"),
+    "Control/HbO": epochs["Control"].average(picks="hbo"),
+    "Control/HbR": epochs["Control"].average(picks="hbr"),
+}

 # Rename channels
until the encoding of frequency in ch_name is fixed for condition in evoked_dict: @@ -87,16 +91,27 @@ raw_anti = mne_nirs.signal_enhancement.enhance_negative_correlation(raw_haemo) -epochs_anti = mne.Epochs(raw_anti, events, event_id=event_dict, - tmin=tmin, tmax=tmax, - reject=reject_criteria, reject_by_annotation=True, - proj=True, baseline=(None, 0), preload=True, - detrend=None, verbose=True) +epochs_anti = mne.Epochs( + raw_anti, + events, + event_id=event_dict, + tmin=tmin, + tmax=tmax, + reject=reject_criteria, + reject_by_annotation=True, + proj=True, + baseline=(None, 0), + preload=True, + detrend=None, + verbose=True, +) -evoked_dict_anti = {'Tapping/HbO': epochs_anti['Tapping'].average(picks='hbo'), - 'Tapping/HbR': epochs_anti['Tapping'].average(picks='hbr'), - 'Control/HbO': epochs_anti['Control'].average(picks='hbo'), - 'Control/HbR': epochs_anti['Control'].average(picks='hbr')} +evoked_dict_anti = { + "Tapping/HbO": epochs_anti["Tapping"].average(picks="hbo"), + "Tapping/HbR": epochs_anti["Tapping"].average(picks="hbr"), + "Control/HbO": epochs_anti["Control"].average(picks="hbo"), + "Control/HbR": epochs_anti["Control"].average(picks="hbr"), +} # Rename channels until the encoding of frequency in ch_name is fixed for condition in evoked_dict_anti: @@ -113,20 +128,29 @@ raw_haemo = mne.preprocessing.nirs.beer_lambert_law(od_corrected, ppf=0.1) raw_haemo = mne_nirs.channels.get_long_channels(raw_haemo) -raw_haemo = raw_haemo.filter(0.05, 0.7, h_trans_bandwidth=0.2, - l_trans_bandwidth=0.02) - -epochs_corr = mne.Epochs(raw_haemo, events, event_id=event_dict, - tmin=tmin, tmax=tmax, - reject=reject_criteria, reject_by_annotation=True, - proj=True, baseline=(None, 0), preload=True, - detrend=None, verbose=True) +raw_haemo = raw_haemo.filter(0.05, 0.7, h_trans_bandwidth=0.2, l_trans_bandwidth=0.02) + +epochs_corr = mne.Epochs( + raw_haemo, + events, + event_id=event_dict, + tmin=tmin, + tmax=tmax, + reject=reject_criteria, + reject_by_annotation=True, + proj=True, + baseline=(None, 0), + preload=True, + detrend=None, + verbose=True, +) evoked_dict_corr = { - 'Tapping/HbO': epochs_corr['Tapping'].average(picks='hbo'), - 'Tapping/HbR': epochs_corr['Tapping'].average(picks='hbr'), - 'Control/HbO': epochs_corr['Control'].average(picks='hbo'), - 'Control/HbR': epochs_corr['Control'].average(picks='hbr')} + "Tapping/HbO": epochs_corr["Tapping"].average(picks="hbo"), + "Tapping/HbR": epochs_corr["Tapping"].average(picks="hbr"), + "Control/HbO": epochs_corr["Control"].average(picks="hbo"), + "Control/HbR": epochs_corr["Control"].average(picks="hbr"), +} # Rename channels until the encoding of frequency in ch_name is fixed for condition in evoked_dict_corr: @@ -141,28 +165,43 @@ fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 6)) -color_dict = dict(HbO='#AA3377', HbR='b') -styles_dict = dict(Control=dict(linestyle='dashed')) - -mne.viz.plot_compare_evokeds(evoked_dict, combine="mean", ci=0.95, - axes=axes[0], colors=color_dict, - styles=styles_dict, - ylim=dict(hbo=[-10, 15])) - -mne.viz.plot_compare_evokeds(evoked_dict_anti, combine="mean", ci=0.95, - axes=axes[1], colors=color_dict, - styles=styles_dict, - ylim=dict(hbo=[-10, 15])) - -mne.viz.plot_compare_evokeds(evoked_dict_corr, combine="mean", ci=0.95, - axes=axes[2], colors=color_dict, - styles=styles_dict, - ylim=dict(hbo=[-10, 15])) - -for column, condition in enumerate(['Original Data', - 'With Enhanced Anticorrelation', - 'With Short Regression']): - axes[column].set_title('{}'.format(condition)) +color_dict = 
dict(HbO="#AA3377", HbR="b") +styles_dict = dict(Control=dict(linestyle="dashed")) + +mne.viz.plot_compare_evokeds( + evoked_dict, + combine="mean", + ci=0.95, + axes=axes[0], + colors=color_dict, + styles=styles_dict, + ylim=dict(hbo=[-10, 15]), +) + +mne.viz.plot_compare_evokeds( + evoked_dict_anti, + combine="mean", + ci=0.95, + axes=axes[1], + colors=color_dict, + styles=styles_dict, + ylim=dict(hbo=[-10, 15]), +) + +mne.viz.plot_compare_evokeds( + evoked_dict_corr, + combine="mean", + ci=0.95, + axes=axes[2], + colors=color_dict, + styles=styles_dict, + ylim=dict(hbo=[-10, 15]), +) + +for column, condition in enumerate( + ["Original Data", "With Enhanced Anticorrelation", "With Short Regression"] +): + axes[column].set_title(f"{condition}") # %% @@ -175,80 +214,112 @@ left = [[1, 3], [2, 3], [1, 2], [4, 3]] right = [[5, 7], [6, 7], [5, 6], [8, 7]] -groups = dict(Left_ROI=picks_pair_to_idx(raw_anti.pick(picks='hbo'), left, - on_missing='warning'), - Right_ROI=picks_pair_to_idx(raw_anti.pick(picks='hbo'), right, - on_missing='warning')) +groups = dict( + Left_ROI=picks_pair_to_idx(raw_anti.pick(picks="hbo"), left, on_missing="warning"), + Right_ROI=picks_pair_to_idx( + raw_anti.pick(picks="hbo"), right, on_missing="warning" + ), +) evoked_dict = { - 'Left/HbO': epochs['Tapping/Left'].average(picks='hbo'), - 'Left/HbR': epochs['Tapping/Left'].average(picks='hbr'), - 'Right/HbO': epochs['Tapping/Right'].average(picks='hbo'), - 'Right/HbR': epochs['Tapping/Right'].average(picks='hbr')} + "Left/HbO": epochs["Tapping/Left"].average(picks="hbo"), + "Left/HbR": epochs["Tapping/Left"].average(picks="hbr"), + "Right/HbO": epochs["Tapping/Right"].average(picks="hbo"), + "Right/HbR": epochs["Tapping/Right"].average(picks="hbr"), +} for condition in evoked_dict: evoked_dict[condition].rename_channels(lambda x: x[:-4]) evoked_dict_anti = { - 'Left/HbO': epochs_anti['Tapping/Left'].average(picks='hbo'), - 'Left/HbR': epochs_anti['Tapping/Left'].average(picks='hbr'), - 'Right/HbO': epochs_anti['Tapping/Right'].average(picks='hbo'), - 'Right/HbR': epochs_anti['Tapping/Right'].average(picks='hbr')} + "Left/HbO": epochs_anti["Tapping/Left"].average(picks="hbo"), + "Left/HbR": epochs_anti["Tapping/Left"].average(picks="hbr"), + "Right/HbO": epochs_anti["Tapping/Right"].average(picks="hbo"), + "Right/HbR": epochs_anti["Tapping/Right"].average(picks="hbr"), +} for condition in evoked_dict_anti: evoked_dict_anti[condition].rename_channels(lambda x: x[:-4]) evoked_dict_corr = { - 'Left/HbO': epochs_corr['Tapping/Left'].average(picks='hbo'), - 'Left/HbR': epochs_corr['Tapping/Left'].average(picks='hbr'), - 'Right/HbO': epochs_corr['Tapping/Right'].average(picks='hbo'), - 'Right/HbR': epochs_corr['Tapping/Right'].average(picks='hbr')} + "Left/HbO": epochs_corr["Tapping/Left"].average(picks="hbo"), + "Left/HbR": epochs_corr["Tapping/Left"].average(picks="hbr"), + "Right/HbO": epochs_corr["Tapping/Right"].average(picks="hbo"), + "Right/HbR": epochs_corr["Tapping/Right"].average(picks="hbr"), +} for condition in evoked_dict_corr: evoked_dict_corr[condition].rename_channels(lambda x: x[:-4]) -color_dict = dict(HbO='#AA3377', HbR='b') -styles_dict = dict(Left=dict(linestyle='dashed')) +color_dict = dict(HbO="#AA3377", HbR="b") +styles_dict = dict(Left=dict(linestyle="dashed")) fig, axes = plt.subplots(nrows=3, ncols=2, figsize=(15, 16)) -mne.viz.plot_compare_evokeds(evoked_dict, combine="mean", ci=0.95, - picks=groups['Left_ROI'], - axes=axes[0, 0], colors=color_dict, - styles=styles_dict, - ylim=dict(hbo=[-10, 
15])) - -mne.viz.plot_compare_evokeds(evoked_dict, combine="mean", ci=0.95, - picks=groups['Right_ROI'], - axes=axes[0, 1], colors=color_dict, - styles=styles_dict, - ylim=dict(hbo=[-10, 15])) - -mne.viz.plot_compare_evokeds(evoked_dict_anti, combine="mean", ci=0.95, - picks=groups['Left_ROI'], - axes=axes[1, 0], colors=color_dict, - styles=styles_dict, - ylim=dict(hbo=[-10, 15])) - -mne.viz.plot_compare_evokeds(evoked_dict_anti, combine="mean", ci=0.95, - picks=groups['Right_ROI'], - axes=axes[1, 1], colors=color_dict, - styles=styles_dict, - ylim=dict(hbo=[-10, 15])) - -mne.viz.plot_compare_evokeds(evoked_dict_corr, combine="mean", ci=0.95, - picks=groups['Left_ROI'], - axes=axes[2, 0], colors=color_dict, - styles=styles_dict, - ylim=dict(hbo=[-10, 15])) - -mne.viz.plot_compare_evokeds(evoked_dict_corr, combine="mean", ci=0.95, - picks=groups['Right_ROI'], - axes=axes[2, 1], colors=color_dict, - styles=styles_dict, - ylim=dict(hbo=[-10, 15])) - -for row, condition in enumerate(['Original', - 'Anticorrelation', - 'Short Regression']): - for column, hemi in enumerate(['Left', 'Right']): - axes[row, column].set_title('{}: {}'.format(condition, hemi)) - +mne.viz.plot_compare_evokeds( + evoked_dict, + combine="mean", + ci=0.95, + picks=groups["Left_ROI"], + axes=axes[0, 0], + colors=color_dict, + styles=styles_dict, + ylim=dict(hbo=[-10, 15]), +) + +mne.viz.plot_compare_evokeds( + evoked_dict, + combine="mean", + ci=0.95, + picks=groups["Right_ROI"], + axes=axes[0, 1], + colors=color_dict, + styles=styles_dict, + ylim=dict(hbo=[-10, 15]), +) + +mne.viz.plot_compare_evokeds( + evoked_dict_anti, + combine="mean", + ci=0.95, + picks=groups["Left_ROI"], + axes=axes[1, 0], + colors=color_dict, + styles=styles_dict, + ylim=dict(hbo=[-10, 15]), +) + +mne.viz.plot_compare_evokeds( + evoked_dict_anti, + combine="mean", + ci=0.95, + picks=groups["Right_ROI"], + axes=axes[1, 1], + colors=color_dict, + styles=styles_dict, + ylim=dict(hbo=[-10, 15]), +) + +mne.viz.plot_compare_evokeds( + evoked_dict_corr, + combine="mean", + ci=0.95, + picks=groups["Left_ROI"], + axes=axes[2, 0], + colors=color_dict, + styles=styles_dict, + ylim=dict(hbo=[-10, 15]), +) + +mne.viz.plot_compare_evokeds( + evoked_dict_corr, + combine="mean", + ci=0.95, + picks=groups["Right_ROI"], + axes=axes[2, 1], + colors=color_dict, + styles=styles_dict, + ylim=dict(hbo=[-10, 15]), +) + +for row, condition in enumerate(["Original", "Anticorrelation", "Short Regression"]): + for column, hemi in enumerate(["Left", "Right"]): + axes[row, column].set_title(f"{condition}: {hemi}") diff --git a/examples/general/plot_21_artifacts.py b/examples/general/plot_21_artifacts.py index 0d50bbbdd..c9a46b5c7 100644 --- a/examples/general/plot_21_artifacts.py +++ b/examples/general/plot_21_artifacts.py @@ -15,10 +15,12 @@ # License: BSD (3-clause) import os -import mne -from mne.preprocessing.nirs import (optical_density, - temporal_derivative_distribution_repair) +import mne +from mne.preprocessing.nirs import ( + optical_density, + temporal_derivative_distribution_repair, +) # %% # Import data @@ -30,12 +32,13 @@ # and plot these signals. 
fnirs_data_folder = mne.datasets.fnirs_motor.data_path()
-fnirs_cw_amplitude_dir = os.path.join(fnirs_data_folder, 'Participant-1')
+fnirs_cw_amplitude_dir = os.path.join(fnirs_data_folder, "Participant-1")
 raw_intensity = mne.io.read_raw_nirx(fnirs_cw_amplitude_dir, verbose=True)
 raw_intensity.load_data().resample(3, npad="auto")
 raw_od = optical_density(raw_intensity)
-new_annotations = mne.Annotations([31, 187, 317], [8, 8, 8],
-                                  ["Movement", "Movement", "Movement"])
+new_annotations = mne.Annotations(
+    [31, 187, 317], [8, 8, 8], ["Movement", "Movement", "Movement"]
+)
 raw_od.set_annotations(new_annotations)

 raw_od.plot(n_channels=15, duration=400, show_scrollbars=False)
@@ -60,10 +63,10 @@
 corrupted_data = raw_od.get_data()
 corrupted_data[:, 298:302] = corrupted_data[:, 298:302] - 0.06
 corrupted_data[:, 450:750] = corrupted_data[:, 450:750] + 0.03
-corrupted_od = mne.io.RawArray(corrupted_data, raw_od.info,
-                               first_samp=raw_od.first_samp)
-new_annotations.append([95, 145, 245], [10, 10, 10],
-                       ["Spike", "Baseline", "Baseline"])
+corrupted_od = mne.io.RawArray(
+    corrupted_data, raw_od.info, first_samp=raw_od.first_samp
+)
+new_annotations.append([95, 145, 245], [10, 10, 10], ["Spike", "Baseline", "Baseline"])
 corrupted_od.set_annotations(new_annotations)

 corrupted_od.plot(n_channels=15, duration=400, show_scrollbars=False)
@@ -93,4 +96,4 @@
 # Frank A Fishburn, Ruth S Ludlum, Chandan J Vaidya, and Andrei V Medvedev.
 # Temporal derivative distribution repair (tddr): a motion correction method
 # for fNIRS. NeuroImage,
-# 184:171–179, 2019. doi:10.1016/j.neuroimage.2018.09.025.
\ No newline at end of file
+# 184:171–179, 2019. doi:10.1016/j.neuroimage.2018.09.025.
diff --git a/examples/general/plot_22_quality.py b/examples/general/plot_22_quality.py
index 7c19640bc..0b98eaf82 100644
--- a/examples/general/plot_22_quality.py
+++ b/examples/general/plot_22_quality.py
@@ -53,12 +53,13 @@
 # License: BSD (3-clause)

 import os
-import mne
-import numpy as np
 from itertools import compress
-import matplotlib.pyplot as plt

+import matplotlib.pyplot as plt
+import mne
+import numpy as np
 from mne.preprocessing.nirs import optical_density
+
 from mne_nirs.preprocessing import peak_power, scalp_coupling_index_windowed
 from mne_nirs.visualisation import plot_timechannel_quality_metric

@@ -72,7 +73,7 @@
 # We then convert the data to optical density and plot the raw signal.

 fnirs_data_folder = mne.datasets.fnirs_motor.data_path()
-fnirs_cw_amplitude_dir = os.path.join(fnirs_data_folder, 'Participant-1')
+fnirs_cw_amplitude_dir = os.path.join(fnirs_data_folder, "Participant-1")
 raw_intensity = mne.io.read_raw_nirx(fnirs_cw_amplitude_dir, verbose=True)
 raw_intensity.load_data().resample(4.0, npad="auto")
 raw_od = optical_density(raw_intensity)
@@ -104,7 +105,7 @@
 sci = mne.preprocessing.nirs.scalp_coupling_index(raw_od)
 fig, ax = plt.subplots()
 ax.hist(sci)
-ax.set(xlabel='Scalp Coupling Index', ylabel='Count', xlim=[0, 1])
+ax.set(xlabel="Scalp Coupling Index", ylabel="Count", xlim=[0, 1])

 # %%
 # We observe that most of the channels have a good SCI of 1, but a few channels
@@ -115,8 +116,8 @@
 # We then print a list of the bad channels and observe there are 10 channels
 # (five source-detector pairs) that are marked as bad.
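# A hedged aside on the ``compress`` call below: this is plain itertools
# behaviour, nothing fNIRS specific. It keeps each channel name whose boolean
# mask entry is True, e.g. with hypothetical channel names:
#
#     list(compress(["S1_D1 760", "S1_D1 850"], [False, True]))
#     # -> ["S1_D1 850"]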
-raw_od.info['bads'] = list(compress(raw_od.ch_names, sci < 0.7)) -print(raw_od.info['bads']) +raw_od.info["bads"] = list(compress(raw_od.ch_names, sci < 0.7)) +print(raw_od.info["bads"]) # %% # We can plot the time course of the signal again and note that the bad @@ -143,7 +144,7 @@ sci = mne.preprocessing.nirs.scalp_coupling_index(raw_od.copy().crop(10)) fig, ax = plt.subplots() ax.hist(sci) -ax.set(xlabel='Scalp Coupling Index', ylabel='Count', xlim=[0, 1]) +ax.set(xlabel="Scalp Coupling Index", ylabel="Count", xlim=[0, 1]) # %% # SCI evaluated over moving window @@ -164,9 +165,13 @@ # define a window length that is appropriate for the experiment. _, scores, times = scalp_coupling_index_windowed(raw_od, time_window=60) -plot_timechannel_quality_metric(raw_od, scores, times, threshold=0.7, - title="Scalp Coupling Index " - "Quality Evaluation") +plot_timechannel_quality_metric( + raw_od, + scores, + times, + threshold=0.7, + title="Scalp Coupling Index " "Quality Evaluation", +) # %% # ********** @@ -180,8 +185,9 @@ # of the recording. raw_od, scores, times = peak_power(raw_od, time_window=10) -plot_timechannel_quality_metric(raw_od, scores, times, threshold=0.1, - title="Peak Power Quality Evaluation") +plot_timechannel_quality_metric( + raw_od, scores, times, threshold=0.1, title="Peak Power Quality Evaluation" +) # %% @@ -205,9 +211,13 @@ # Next we plot just these channels to demonstrate that indeed an artifact # has been added. -raw_od.copy().pick(picks=[12, 13, 34, 35]).\ - plot(n_channels=55, duration=40000, show_scrollbars=False, - clipping=None, scalings={'fnirs_od': 0.2}) +raw_od.copy().pick(picks=[12, 13, 34, 35]).plot( + n_channels=55, + duration=40000, + show_scrollbars=False, + clipping=None, + scalings={"fnirs_od": 0.2}, +) # %% @@ -246,9 +256,13 @@ # channels would be generated for S5-D13 (as the artifact was only present # on S2-D4). -raw_od.copy().pick(picks=[12, 13, 34, 35]).\ - plot(n_channels=55, duration=40000, show_scrollbars=False, - clipping=None, scalings={'fnirs_od': 0.2}) +raw_od.copy().pick(picks=[12, 13, 34, 35]).plot( + n_channels=55, + duration=40000, + show_scrollbars=False, + clipping=None, + scalings={"fnirs_od": 0.2}, +) # %% # These channel and time specific annotations are used by downstream diff --git a/examples/general/plot_30_frequency.py b/examples/general/plot_30_frequency.py index d06777a26..dcfa6ed30 100644 --- a/examples/general/plot_30_frequency.py +++ b/examples/general/plot_30_frequency.py @@ -12,11 +12,6 @@ on experimental design and our model of how the brain reacts to stimuli, the actual data measured during an experiment, and the filtering that is applied to the data. - -.. contents:: Page contents - :local: - :depth: 2 - """ # Authors: Robert Luke diff --git a/examples/general/plot_40_mayer.py b/examples/general/plot_40_mayer.py index db9257c15..4b0d1ad1e 100644 --- a/examples/general/plot_40_mayer.py +++ b/examples/general/plot_40_mayer.py @@ -33,31 +33,23 @@ You should read their excellent documentation. Their work should be considered the primary resource, and this is just an example of how to apply it to fNIRS data for the purpose of extracting Mayer waves oscillation parameters. - - -.. 
contents:: Page contents
-   :local:
-   :depth: 2
-
-"""
+"""  # noqa: E501

 # Authors: Robert Luke 
 #
 # License: BSD (3-clause)

 import os
+
+import matplotlib.pyplot as plt
 import mne
 import numpy as np
-import matplotlib.pyplot as plt
-
-from mne.preprocessing.nirs import optical_density, beer_lambert_law
+from fooof import FOOOF
+from mne.preprocessing.nirs import beer_lambert_law, optical_density

 from mne_nirs.channels import get_long_channels
 from mne_nirs.preprocessing import quantify_mayer_fooof

-from fooof import FOOOF
-
-
 # %%
 # Import and preprocess data
 # --------------------------
 #
 # We read in the data and convert to haemoglobin concentration.

 fnirs_data_folder = mne.datasets.fnirs_motor.data_path()
-fnirs_raw_dir = os.path.join(fnirs_data_folder, 'Participant-1')
+fnirs_raw_dir = os.path.join(fnirs_data_folder, "Participant-1")
 raw = mne.io.read_raw_nirx(fnirs_raw_dir, verbose=True).load_data()

 raw = optical_density(raw)
@@ -98,6 +90,7 @@
 # The shaded area illustrates the oscillation fitted by the FOOOF algorithm,
 # it matches well to the data.

+
 def scale_up_spectra(spectra, freqs):
     """
     FOOOF requires the frequency values to be higher than the fNIRS data
     provides, so we scale them up by 10 here and undo this on the plot axis below.
     """
     freqs = freqs * 10
     return spectra, freqs

+
 # Prepare data for FOOOF
-psd = raw.compute_psd(
-    fmin=0.001, fmax=1.0, tmin=0, tmax=None, n_overlap=300, n_fft=600)
+psd = raw.compute_psd(fmin=0.001, fmax=1.0, tmin=0, tmax=None, n_overlap=300, n_fft=600)
 spectra, freqs = psd.get_data(return_freqs=True)

 spectra, freqs = scale_up_spectra(spectra, freqs)
@@ -121,7 +114,7 @@ def scale_up_spectra(spectra, freqs):
 fm.fit(freqs, np.mean(spectra, axis=0), freq_range)

 fig, axs = plt.subplots(1, 1, figsize=(10, 5))
-fm.plot(plot_peaks='shade', data_kwargs={'color': 'orange'}, ax=axs)
+fm.plot(plot_peaks="shade", data_kwargs={"color": "orange"}, ax=axs)
 # Correct for x10 scaling above
 plt.xticks([0, 1, 2, 3, 4, 5, 6], [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6])
@@ -149,8 +142,8 @@ def scale_up_spectra(spectra, freqs):
 # applying this analysis to fNIRS data with MNE-NIRS.
 #
 # An example measurement illustrated what the presence of a Mayer wave
-# looks like with a power spectral density. The measurement also illustrated that the Mayer wave
-# is not a perfect sinusoid, as evidenced by the broad spectral content.
+# looks like with a power spectral density. The measurement also illustrated that the
+# Mayer wave is not a perfect sinusoid, as evidenced by the broad spectral content.
 # Further, the example illustrated that the Mayer wave is not always precisely locked
 # to 0.1 Hz, both visual inspection and FOOOF quantification indicate a 0.09 Hz
 # centre frequency.
diff --git a/examples/general/plot_50_decoding.py b/examples/general/plot_50_decoding.py
index dc45e30a8..253160731 100644
--- a/examples/general/plot_50_decoding.py
+++ b/examples/general/plot_50_decoding.py
@@ -30,11 +30,6 @@
 Simply modify the ``read_raw_`` function to match your data type.
 See :ref:`data importing tutorial `
 to learn how to use your data with MNE-Python.
-
-
-.. 
contents:: Page contents
-   :local:
-   :depth: 2
 """

 # Authors: Robert Luke 
 #
 # License: BSD (3-clause)
@@ -43,28 +38,26 @@

 # Import common libraries
-import os
 import contextlib
+import os
+
 import numpy as np
+from mne import Epochs, events_from_annotations
+from mne.decoding import Scaler, Vectorizer, cross_val_multiscore

-# Import sklearn processing
-from sklearn.pipeline import make_pipeline
+# Import MNE-Python processing
+from mne.preprocessing.nirs import beer_lambert_law, optical_density
+
+# Import MNE-BIDS processing
+from mne_bids import BIDSPath, get_entity_vals, read_raw_bids
 from sklearn.linear_model import LogisticRegression

-# Import MNE-Python processing
-from mne.preprocessing.nirs import optical_density, beer_lambert_law
-from mne import Epochs, events_from_annotations
-from mne.decoding import (Scaler,
-                          cross_val_multiscore,
-                          Vectorizer)
+# Import sklearn processing
+from sklearn.pipeline import make_pipeline

 # Import MNE-NIRS processing
 from mne_nirs.datasets.fnirs_motor_group import data_path

-# Import MNE-BIDS processing
-from mne_bids import BIDSPath, read_raw_bids, get_entity_vals
-
-
 # %%
 # Set up directories
 # ------------------
@@ -82,9 +75,10 @@
 # In this example we use the example dataset ``fnirs_motor_group``.

 root = data_path()
-dataset = BIDSPath(root=root, suffix="nirs", extension=".snirf",
-                   task="tapping", datatype="nirs")
-subjects = get_entity_vals(root, 'subject')
+dataset = BIDSPath(
+    root=root, suffix="nirs", extension=".snirf", task="tapping", datatype="nirs"
+)
+subjects = get_entity_vals(root, "subject")


 # %%
@@ -100,21 +94,31 @@


 def epoch_preprocessing(bids_path):
-
     with open(os.devnull, "w") as f, contextlib.redirect_stdout(f):
         raw_intensity = read_raw_bids(bids_path=bids_path).load_data()
         raw_od = optical_density(raw_intensity)
         raw_od.resample(1.5)
         raw_haemo = beer_lambert_law(raw_od, ppf=6)

-    raw_haemo = raw_haemo.filter(None, 0.6, h_trans_bandwidth=0.05,
-                                 l_trans_bandwidth=0.01, verbose=False)
+    raw_haemo = raw_haemo.filter(
+        None, 0.6, h_trans_bandwidth=0.05, l_trans_bandwidth=0.01, verbose=False
+    )

     events, event_dict = events_from_annotations(raw_haemo, verbose=False)
-    epochs = Epochs(raw_haemo, events, event_id=event_dict, tmin=-5, tmax=30,
-                    reject=dict(hbo=100e-6), reject_by_annotation=True,
-                    proj=True, baseline=(None, 0), detrend=1,
-                    preload=True, verbose=False)
+    epochs = Epochs(
+        raw_haemo,
+        events,
+        event_id=event_dict,
+        tmin=-5,
+        tmax=30,
+        reject=dict(hbo=100e-6),
+        reject_by_annotation=True,
+        proj=True,
+        baseline=(None, 0),
+        detrend=1,
+        preload=True,
+        verbose=False,
+    )

     epochs = epochs[["Tapping/Right", "Tapping/Left"]]
     return raw_haemo, epochs
@@ -135,11 +139,9 @@
 # This approach classifies the data within, rather than across, subjects.
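# As a hedged sketch of the shapes flowing through the pipeline below:
#
#     X = epochs.get_data()  # (n_epochs, n_channels, n_times)
#     # after Vectorizer(): (n_epochs, n_channels * n_times)
#
# ``Scaler(epochs.info)`` scales each channel, ``Vectorizer()`` flattens every
# epoch to a single feature row, and ``LogisticRegression`` then learns one
# weight per channel-time feature.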
-for chroma in ['hbo', 'hbr']: - +for chroma in ["hbo", "hbr"]: st_scores = [] for sub in subjects: - bids_path = dataset.update(subject=sub) raw_haemo, epochs = epoch_preprocessing(bids_path) @@ -148,17 +150,20 @@ def epoch_preprocessing(bids_path): X = epochs.get_data() y = epochs.events[:, 2] - clf = make_pipeline(Scaler(epochs.info), - Vectorizer(), - LogisticRegression(solver='liblinear')) + clf = make_pipeline( + Scaler(epochs.info), Vectorizer(), LogisticRegression(solver="liblinear") + ) - scores = 100 * cross_val_multiscore(clf, X, y, - cv=5, n_jobs=1, scoring='roc_auc') + scores = 100 * cross_val_multiscore( + clf, X, y, cv=5, n_jobs=1, scoring="roc_auc" + ) st_scores.append(np.mean(scores, axis=0)) - print(f"Average spatio-temporal ROC-AUC performance ({chroma}) = " - f"{np.round(np.mean(st_scores))} % ({np.round(np.std(st_scores))})") + print( + f"Average spatio-temporal ROC-AUC performance ({chroma}) = " + f"{np.round(np.mean(st_scores))} % ({np.round(np.std(st_scores))})" + ) # %% diff --git a/examples/general/plot_60_aux_data.py b/examples/general/plot_60_aux_data.py index 2918df382..c42ceb7d0 100644 --- a/examples/general/plot_60_aux_data.py +++ b/examples/general/plot_60_aux_data.py @@ -15,11 +15,6 @@ example and refer readers to the detailed description above. Instead, we focus on extracting the auxiliary data and how this can be incorporated in to your analysis. - -.. contents:: Page contents - :local: - :depth: 2 - """ # sphinx_gallery_thumbnail_number = 2 @@ -27,20 +22,16 @@ # # License: BSD (3-clause) -import numpy as np import matplotlib.pyplot as plt -import pandas as pd - import mne +import numpy as np +import pandas as pd +from nilearn.plotting import plot_design_matrix +from mne_nirs.channels import get_long_channels, get_short_channels +from mne_nirs.datasets.snirf_with_aux import data_path from mne_nirs.experimental_design import make_first_level_design_matrix -from mne_nirs.channels import (get_long_channels, - get_short_channels) from mne_nirs.io.snirf import read_snirf_aux_data -from mne_nirs.datasets.snirf_with_aux import data_path - -from nilearn.plotting import plot_design_matrix - # %% # Import raw NIRS data @@ -62,10 +53,10 @@ # Then we crop the recording to the section containing our # experimental conditions. -raw_intensity.annotations.rename({'1': 'Control', - '2': 'Tapping_Left', - '3': 'Tapping_Right'}) -raw_intensity.annotations.delete(raw_intensity.annotations.description == '15') +raw_intensity.annotations.rename( + {"1": "Control", "2": "Tapping_Left", "3": "Tapping_Right"} +) +raw_intensity.annotations.delete(raw_intensity.annotations.description == "15") raw_intensity.annotations.set_durations(5) @@ -102,11 +93,13 @@ # The model consists of various components to model different things we assume # contribute to the measured signal. -design_matrix = make_first_level_design_matrix(raw_haemo, - drift_model='cosine', - high_pass=0.005, # Must be specified per experiment - hrf_model='spm', - stim_dur=5.0) +design_matrix = make_first_level_design_matrix( + raw_haemo, + drift_model="cosine", + high_pass=0.005, # Must be specified per experiment + hrf_model="spm", + stim_dur=5.0, +) # %% @@ -117,11 +110,13 @@ # related to each experimental condition # uncontaminated by systemic effects. 
-design_matrix["ShortHbO"] = np.mean(short_chs.copy().pick( - picks="hbo").get_data(), axis=0) +design_matrix["ShortHbO"] = np.mean( + short_chs.copy().pick(picks="hbo").get_data(), axis=0 +) -design_matrix["ShortHbR"] = np.mean(short_chs.copy().pick( - picks="hbr").get_data(), axis=0) +design_matrix["ShortHbR"] = np.mean( + short_chs.copy().pick(picks="hbr").get_data(), axis=0 +) # %% @@ -134,7 +129,6 @@ fig = plot_design_matrix(design_matrix, ax=ax1) - # %% # Load auxiliary data # ------------------- @@ -155,7 +149,7 @@ # And you can verify the data looks reasonable by plotting # individual fields. -plt.plot(raw_haemo.times, aux_df['HR']) +plt.plot(raw_haemo.times, aux_df["HR"]) plt.xlabel("Time (s)") plt.ylabel("Heart Rate (bpm)") diff --git a/examples/general/plot_70_visualise_brain.py b/examples/general/plot_70_visualise_brain.py index 6b4a9dce8..1f7314cfb 100644 --- a/examples/general/plot_70_visualise_brain.py +++ b/examples/general/plot_70_visualise_brain.py @@ -14,12 +14,6 @@ This tutorial glosses over the processing details, see the :ref:`GLM tutorial ` for details on the preprocessing. - -.. contents:: Page contents - :local: - :depth: 2 - - """ # sphinx_gallery_thumbnail_number = 5 @@ -27,24 +21,23 @@ # # License: BSD (3-clause) +import mne import numpy as np import pandas as pd - -import mne -from mne.preprocessing.nirs import optical_density, beer_lambert_law - import statsmodels.formula.api as smf +from mne.preprocessing.nirs import beer_lambert_law, optical_density +from mne_bids import BIDSPath, get_entity_vals, read_raw_bids -from mne_bids import BIDSPath, read_raw_bids, get_entity_vals import mne_nirs - -from mne_nirs.experimental_design import make_first_level_design_matrix -from mne_nirs.statistics import run_glm, statsmodels_to_results from mne_nirs.channels import get_long_channels, get_short_channels -from mne_nirs.io.fold import fold_landmark_specificity -from mne_nirs.visualisation import plot_nirs_source_detector, plot_glm_surface_projection from mne_nirs.datasets import fnirs_motor_group - +from mne_nirs.experimental_design import make_first_level_design_matrix +from mne_nirs.io.fold import fold_landmark_specificity +from mne_nirs.statistics import run_glm, statsmodels_to_results +from mne_nirs.visualisation import ( + plot_glm_surface_projection, + plot_nirs_source_detector, +) # %% # Download example data @@ -59,13 +52,19 @@ # Download the ``audio_or_visual_speech`` dataset and load the first measurement. root = mne_nirs.datasets.audio_or_visual_speech.data_path() -dataset = BIDSPath(root=root, suffix="nirs", extension=".snirf", subject="04", - task="AudioVisualBroadVsRestricted", datatype="nirs", session="01") +dataset = BIDSPath( + root=root, + suffix="nirs", + extension=".snirf", + subject="04", + task="AudioVisualBroadVsRestricted", + datatype="nirs", + session="01", +) raw = mne.io.read_raw_snirf(dataset.fpath) -raw.annotations.rename({'1.0': 'Audio', - '2.0': 'Video', - '3.0': 'Control', - '15.0': 'Ends'}) +raw.annotations.rename( + {"1.0": "Audio", "2.0": "Video", "3.0": "Control", "15.0": "Ends"} +) # %% # Download annotation information @@ -74,10 +73,14 @@ # Download the HCP-MMP parcellation. 
# Download anatomical locations
-subjects_dir = str(mne.datasets.sample.data_path()) + '/subjects'
+subjects_dir = str(mne.datasets.sample.data_path()) + "/subjects"
 mne.datasets.fetch_hcp_mmp_parcellation(subjects_dir=subjects_dir, accept=True)
-labels = mne.read_labels_from_annot('fsaverage', 'HCPMMP1', 'lh', subjects_dir=subjects_dir)
-labels_combined = mne.read_labels_from_annot('fsaverage', 'HCPMMP1_combined', 'lh', subjects_dir=subjects_dir)
+labels = mne.read_labels_from_annot(
+    "fsaverage", "HCPMMP1", "lh", subjects_dir=subjects_dir
+)
+labels_combined = mne.read_labels_from_annot(
+    "fsaverage", "HCPMMP1_combined", "lh", subjects_dir=subjects_dir
+)


 # %%
@@ -92,8 +95,12 @@
 # In this example we can see channels over the left inferior frontal gyrus,
 # auditory cortex, planum temporale, and occipital lobe.

-brain = mne.viz.Brain('fsaverage', subjects_dir=subjects_dir, background='w', cortex='0.5')
-brain.add_sensors(raw.info, trans='fsaverage', fnirs=['channels', 'pairs', 'sources', 'detectors'])
+brain = mne.viz.Brain(
+    "fsaverage", subjects_dir=subjects_dir, background="w", cortex="0.5"
+)
+brain.add_sensors(
+    raw.info, trans="fsaverage", fnirs=["channels", "pairs", "sources", "detectors"]
+)
 brain.show_view(azimuth=180, elevation=80, distance=450)

 # %%
@@ -108,13 +115,14 @@
 # specify which views to use to show each channel pair:

 view_map = {
-    'left-lat': np.r_[np.arange(1, 27), 28],
-    'caudal': np.r_[27, np.arange(43, 53)],
-    'right-lat': np.r_[np.arange(29, 43), 44],
+    "left-lat": np.r_[np.arange(1, 27), 28],
+    "caudal": np.r_[27, np.arange(43, 53)],
+    "right-lat": np.r_[np.arange(29, 43), 44],
 }

 fig_montage = mne_nirs.visualisation.plot_3d_montage(
-    raw.info, view_map=view_map, subjects_dir=subjects_dir)
+    raw.info, view_map=view_map, subjects_dir=subjects_dir
+)

 # %%
 # Plot sensor channels and anatomical region of interest
@@ -126,11 +134,15 @@
 # In this example we highlight the primary auditory cortex in blue,
 # and we can see that a number of channels are placed over this structure.

-brain = mne.viz.Brain('fsaverage', subjects_dir=subjects_dir, background='w', cortex='0.5')
-brain.add_sensors(raw.info, trans='fsaverage', fnirs=['channels', 'pairs', 'sources', 'detectors'])
+brain = mne.viz.Brain(
+    "fsaverage", subjects_dir=subjects_dir, background="w", cortex="0.5"
+)
+brain.add_sensors(
+    raw.info, trans="fsaverage", fnirs=["channels", "pairs", "sources", "detectors"]
+)

-aud_label = [label for label in labels if label.name == 'L_A1_ROI-lh'][0]
-brain.add_label(aud_label, borders=False, color='blue')
+aud_label = [label for label in labels if label.name == "L_A1_ROI-lh"][0]
+brain.add_label(aud_label, borders=False, color="blue")
 brain.show_view(azimuth=180, elevation=80, distance=450)


@@ -145,7 +157,9 @@
 # The tool is very intuitive and easy to use.
 # Be sure to cite the authors if you use their tool or data:
 #
-# Morais, Guilherme Augusto Zimeo, Joana Bisol Balardin, and João Ricardo Sato. "fNIRS optodes’ location decider (fOLD): a toolbox for probe arrangement guided by brain regions-of-interest." Scientific reports 8.1 (2018): 1-11.
+# Morais, Guilherme Augusto Zimeo, Joana Bisol Balardin, and João Ricardo Sato.
+# "fNIRS optodes’ location decider (fOLD): a toolbox for probe arrangement guided by
+# brain regions-of-interest." Scientific reports 8.1 (2018): 1-11.
 #
 # Rather than simply eyeballing the sensors and ROIs of interest, we can
 # quantify the specificity of each channel to the anatomical region of interest
@@ -156,33 +170,43 @@
 # :ref:`this tutorial `.
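# A hedged note on units, inferred from the 50% threshold used below: the
# specificity comes back as one value per channel expressed in percent, so a
# mask such as
#
#     specificity > 50
#
# keeps only the channels that are more than 50% specific to the landmark.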
 # Return specificity of each channel to the Left IFG
-specificity = fold_landmark_specificity(raw, 'L IFG (p. Triangularis)')
+specificity = fold_landmark_specificity(raw, "L IFG (p. Triangularis)")
 
 # Retain only channels with specificity to left IFG of greater than 50%
 raw_IFG = raw.copy().pick(picks=np.where(specificity > 50)[0])
 
-brain = mne.viz.Brain('fsaverage', subjects_dir=subjects_dir, background='w', cortex='0.5')
-brain.add_sensors(raw_IFG.info, trans='fsaverage', fnirs=['channels', 'pairs'])
+brain = mne.viz.Brain(
+    "fsaverage", subjects_dir=subjects_dir, background="w", cortex="0.5"
+)
+brain.add_sensors(raw_IFG.info, trans="fsaverage", fnirs=["channels", "pairs"])
 
-ifg_label = [label for label in labels_combined if label.name == 'Inferior Frontal Cortex-lh'][0]
-brain.add_label(ifg_label, borders=False, color='green')
+ifg_label = [
+    label for label in labels_combined if label.name == "Inferior Frontal Cortex-lh"
+][0]
+brain.add_label(ifg_label, borders=False, color="green")
 brain.show_view(azimuth=140, elevation=95, distance=360)
 
 
 # %%
 #
-# Alternatively, we can retain all channels and visualise the specificity of each channel the ROI
-# by encoding the specificty in the color of the line between each source and detector.
-# In this example we see that several channels have substantial specificity to
+# Alternatively, we can retain all channels and visualise the specificity of each
+# channel by encoding the specificity in the color of the line between each source and
+# detector. In this example we see that several channels have substantial specificity to
 # the region of interest.
 #
 # Note: this function currently doesn't support the new MNE brain API, so does
 # not allow the same behaviour as above (adding sensors, highlighting ROIs etc).
 # It should be updated in the near future.
 
-fig = plot_nirs_source_detector(specificity, raw.info, surfaces='brain',
-                                subject='fsaverage', subjects_dir=subjects_dir, trans='fsaverage')
+fig = plot_nirs_source_detector(
+    specificity,
+    raw.info,
+    surfaces="brain",
+    subject="fsaverage",
+    subjects_dir=subjects_dir,
+    trans="fsaverage",
+)
 mne.viz.set_3d_view(fig, azimuth=140, elevation=95)
 
 
@@ -207,7 +231,9 @@
 sht_chans = get_short_channels(raw_haemo)
 raw_haemo = get_long_channels(raw_haemo)
 design_matrix = make_first_level_design_matrix(raw_haemo, stim_dur=13.0)
-design_matrix["ShortHbO"] = np.mean(sht_chans.copy().pick(picks="hbo").get_data(), axis=0)
+design_matrix["ShortHbO"] = np.mean(
+    sht_chans.copy().pick(picks="hbo").get_data(), axis=0
+)
 
 glm_est = run_glm(raw_haemo, design_matrix)
 
 # First we create a dictionary for each region of interest.
@@ -218,8 +244,12 @@
 rois["Visual_weighted"] = range(len(glm_est.ch_names))
 
 # Next we compute the specificity for each channel to the auditory and visual cortex.
-spec_aud = fold_landmark_specificity(raw_haemo, '42 - Primary and Auditory Association Cortex', atlas="Brodmann")
-spec_vis = fold_landmark_specificity(raw_haemo, '17 - Primary Visual Cortex (V1)', atlas="Brodmann")
+spec_aud = fold_landmark_specificity(
+    raw_haemo, "42 - Primary and Auditory Association Cortex", atlas="Brodmann"
+)
+spec_vis = fold_landmark_specificity(
+    raw_haemo, "17 - Primary Visual Cortex (V1)", atlas="Brodmann"
+)
 
 # Next we create a dictionary to store the weights for each channel in the ROI.
 # The weights will be the specificity to the ROI.
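# Both uses of specificity above come down to plain NumPy operations:
# thresholding to select channels, or passing the raw percentages through as
# ROI weights. A toy illustration (six hypothetical channels, made-up values):
import numpy as np

specificity = np.array([12.0, 55.5, 0.0, 73.2, 49.9, 88.1])  # percent per channel
keep = np.where(specificity > 50)[0]  # indices of channels above 50%
print(keep)  # [1 3 5]

weights = {"Auditory_weighted": specificity}  # weight each channel by specificity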
@@ -229,7 +259,9 @@ weights["Visual_weighted"] = spec_vis # Finally we compute region of interest results using the weights specified above -out = glm_est.to_dataframe_region_of_interest(rois, ["Video", "Control"], weighted=weights) +out = glm_est.to_dataframe_region_of_interest( + rois, ["Video", "Control"], weighted=weights +) out["Significant"] = out["p"] < 0.05 out @@ -252,12 +284,12 @@ def individual_analysis(bids_path, ID): - raw_intensity = read_raw_bids(bids_path=bids_path, verbose=False) - raw_intensity.annotations.delete(raw_intensity.annotations.description == '15.0') - # sanitize event names + raw_intensity.annotations.delete(raw_intensity.annotations.description == "15.0") + # sanitize event names raw_intensity.annotations.description[:] = [ - d.replace('/', '_') for d in raw_intensity.annotations.description] + d.replace("/", "_") for d in raw_intensity.annotations.description + ] # Convert signal to haemoglobin and resample raw_od = optical_density(raw_intensity) @@ -272,8 +304,12 @@ def individual_analysis(bids_path, ID): design_matrix = make_first_level_design_matrix(raw_haemo, stim_dur=5.0) # Append short channels mean to design matrix - design_matrix["ShortHbO"] = np.mean(sht_chans.copy().pick(picks="hbo").get_data(), axis=0) - design_matrix["ShortHbR"] = np.mean(sht_chans.copy().pick(picks="hbr").get_data(), axis=0) + design_matrix["ShortHbO"] = np.mean( + sht_chans.copy().pick(picks="hbo").get_data(), axis=0 + ) + design_matrix["ShortHbR"] = np.mean( + sht_chans.copy().pick(picks="hbr").get_data(), axis=0 + ) # Run GLM glm_est = run_glm(raw_haemo, design_matrix) @@ -285,20 +321,20 @@ def individual_analysis(bids_path, ID): cha["ID"] = ID # Convert to uM for nicer plotting below. - cha["theta"] = [t * 1.e6 for t in cha["theta"]] + cha["theta"] = [t * 1.0e6 for t in cha["theta"]] return raw_haemo, cha # Get dataset details root = fnirs_motor_group.data_path() -dataset = BIDSPath(root=root, task="tapping", - datatype="nirs", suffix="nirs", extension=".snirf") -subjects = get_entity_vals(root, 'subject') +dataset = BIDSPath( + root=root, task="tapping", datatype="nirs", suffix="nirs", extension=".snirf" +) +subjects = get_entity_vals(root, "subject") df_cha = pd.DataFrame() # To store channel level results for sub in subjects: # Loop from first to fifth subject - # Create path to file based on experiment info bids_path = dataset.update(subject=sub) @@ -311,32 +347,42 @@ def individual_analysis(bids_path, ID): ch_summary = df_cha.query("Condition in ['Tapping_Right']") assert len(ch_summary) ch_summary = ch_summary.query("Chroma in ['hbo']") -ch_model = smf.mixedlm("theta ~ -1 + ch_name", ch_summary, - groups=ch_summary["ID"]).fit(method='nm') +ch_model = smf.mixedlm("theta ~ -1 + ch_name", ch_summary, groups=ch_summary["ID"]).fit( + method="nm" +) model_df = statsmodels_to_results(ch_model, order=raw_haemo.copy().pick("hbo").ch_names) - # %% # Plot surface projection of GLM results # -------------------------------------- # -# Finally, we can project the GLM results from each channel to the nearest cortical surface -# and overlay the sensor positions and two different regions of interest. +# Finally, we can project the GLM results from each channel to the nearest cortical +# surface and overlay the sensor positions and two different regions of interest. # In this example we also highlight the premotor cortex and auditory association cortex # in green and blue respectively. 
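# Before the projection plot below, a note on the group model above: the
# ``smf.mixedlm("theta ~ -1 + ch_name", ..., groups=...)`` call fits one fixed
# effect per channel with subject as a random intercept. A self-contained toy
# version on synthetic data (all names and values here are made up):
import numpy as np
import pandas as pd
import statsmodels.formula.api as smf

rng = np.random.default_rng(0)
df = pd.DataFrame(
    {
        "ID": np.repeat(np.arange(8), 4),  # 8 toy subjects
        "ch_name": np.tile(["S1_D1", "S1_D2", "S2_D1", "S2_D3"], 8),
        "theta": rng.normal(loc=1.0, size=32),
    }
)
model = smf.mixedlm("theta ~ -1 + ch_name", df, groups=df["ID"]).fit(method="nm")
print(model.params)  # one estimate per channel plus the group variance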
# Plot the projection and sensor locations -brain = plot_glm_surface_projection(raw_haemo.copy().pick("hbo"), model_df, colorbar=True) -brain.add_sensors(raw_haemo.info, trans='fsaverage', fnirs=['channels', 'pairs', 'sources', 'detectors']) +brain = plot_glm_surface_projection( + raw_haemo.copy().pick("hbo"), model_df, colorbar=True +) +brain.add_sensors( + raw_haemo.info, + trans="fsaverage", + fnirs=["channels", "pairs", "sources", "detectors"], +) # mark the premotor cortex in green -aud_label = [label for label in labels_combined if label.name == 'Premotor Cortex-lh'][0] -brain.add_label(aud_label, borders=True, color='green') +aud_label = [label for label in labels_combined if label.name == "Premotor Cortex-lh"][ + 0 +] +brain.add_label(aud_label, borders=True, color="green") # mark the auditory association cortex in blue -aud_label = [label for label in labels_combined if label.name == 'Auditory Association Cortex-lh'][0] -brain.add_label(aud_label, borders=True, color='blue') +aud_label = [ + label for label in labels_combined if label.name == "Auditory Association Cortex-lh" +][0] +brain.add_label(aud_label, borders=True, color="blue") brain.show_view(azimuth=160, elevation=60, distance=400) diff --git a/examples/general/plot_80_save_read_glm.py b/examples/general/plot_80_save_read_glm.py index c3ae97b45..d0de7d024 100644 --- a/examples/general/plot_80_save_read_glm.py +++ b/examples/general/plot_80_save_read_glm.py @@ -35,31 +35,27 @@ Simply modify the ``read_raw_`` function to match your data type. See :ref:`data importing tutorial ` to learn how to use your data with MNE-Python. - -.. contents:: Page contents - :local: - :depth: 2 -""" +""" # noqa: E501 # Authors: Robert Luke # # License: BSD (3-clause) # Import common libraries from os.path import join + import pandas as pd # Import MNE functions -from mne.preprocessing.nirs import optical_density, beer_lambert_law - -# Import MNE-NIRS functions -from mne_nirs.statistics import run_glm -from mne_nirs.experimental_design import make_first_level_design_matrix -from mne_nirs.statistics import read_glm -from mne_nirs.datasets import fnirs_motor_group +from mne.preprocessing.nirs import beer_lambert_law, optical_density # Import MNE-BIDS processing -from mne_bids import BIDSPath, read_raw_bids, get_entity_vals +from mne_bids import BIDSPath, get_entity_vals, read_raw_bids +from mne_nirs.datasets import fnirs_motor_group +from mne_nirs.experimental_design import make_first_level_design_matrix + +# Import MNE-NIRS functions +from mne_nirs.statistics import read_glm, run_glm # %% # Set up directories @@ -87,7 +83,7 @@ # %% # For example we can automatically query the subjects, tasks, and sessions. -subjects = get_entity_vals(root, 'subject') +subjects = get_entity_vals(root, "subject") print(subjects) @@ -107,10 +103,9 @@ def individual_analysis(bids_path): - raw_intensity = read_raw_bids(bids_path=bids_path, verbose=False) - # Delete annotation labeled 15, as these just signify the start and end of experiment. - raw_intensity.annotations.delete(raw_intensity.annotations.description == '15.0') + # Delete annotation labeled 15, as these just signify the experiment start and end. 
+ raw_intensity.annotations.delete(raw_intensity.annotations.description == "15.0") raw_intensity.pick(picks=range(20)).crop(200).resample(0.3) # Reduce load raw_haemo = beer_lambert_law(optical_density(raw_intensity), ppf=0.1) design_matrix = make_first_level_design_matrix(raw_haemo) @@ -128,12 +123,10 @@ def individual_analysis(bids_path): for sub in subjects: - # Create path to file based on experiment info - data_path = dataset.update(subject=sub, - datatype="nirs", - suffix="nirs", - extension=".snirf") + data_path = dataset.update( + subject=sub, datatype="nirs", suffix="nirs", extension=".snirf" + ) # Analyse data and glm results glm = individual_analysis(data_path) @@ -144,7 +137,11 @@ def individual_analysis(bids_path): save_path = dataset.copy().update( root=join(root, "derivatives"), - datatype="nirs", suffix="glm", extension=".h5",check=False) + datatype="nirs", + suffix="glm", + extension=".h5", + check=False, + ) # Ensure the folder exists, and make it if not. save_path.fpath.parent.mkdir(exist_ok=True, parents=True) @@ -162,7 +159,6 @@ def individual_analysis(bids_path): df = pd.DataFrame() for sub in subjects: - # Point to the correct subject save_path = save_path.update(subject=sub) diff --git a/examples/general/plot_99_bad.py b/examples/general/plot_99_bad.py index ce6206380..c405a64a1 100644 --- a/examples/general/plot_99_bad.py +++ b/examples/general/plot_99_bad.py @@ -32,12 +32,6 @@ However, at the midpoint we replace the real data with noise and demonstrate that without careful attention to the analysis parameter it would still appear as if a fNIRS response is observed. - - -.. contents:: Page contents - :local: - :depth: 2 - """ # sphinx_gallery_thumbnail_number = 7 diff --git a/examples/migration/plot_01_homer.py b/examples/migration/plot_01_homer.py index 734be9963..9a3baa863 100644 --- a/examples/migration/plot_01_homer.py +++ b/examples/migration/plot_01_homer.py @@ -36,11 +36,6 @@ tddr = hmrMotionCorrectTDDR(dod,SD,fs); ppf = [6 6]; dc = hmrOD2Conc(tddr,SD,ppf); - -.. contents:: Page contents - :local: - :depth: 2 - """ # %% @@ -54,12 +49,14 @@ # License: BSD (3-clause) import os -import mne +import mne from mne.io import read_raw_nirx -from mne.preprocessing.nirs import (optical_density, beer_lambert_law, - temporal_derivative_distribution_repair) - +from mne.preprocessing.nirs import ( + beer_lambert_law, + optical_density, + temporal_derivative_distribution_repair, +) # %% # Convert to optical density and motion correct @@ -70,7 +67,7 @@ # First we obtain the path to the data fnirs_data_folder = mne.datasets.fnirs_motor.data_path() -fnirs_raw_dir = os.path.join(fnirs_data_folder, 'Participant-1') +fnirs_raw_dir = os.path.join(fnirs_data_folder, "Participant-1") # Next we read the data raw_intensity = read_raw_nirx(fnirs_raw_dir).load_data() @@ -97,7 +94,7 @@ # To exactly match the results from Homer we can manually set the ppf value to # 6 in MNE. -raw_h = beer_lambert_law(corrected_tddr, ppf=6.) +raw_h = beer_lambert_law(corrected_tddr, ppf=6.0) # %% diff --git a/examples/migration/plot_02_nirstoolbox.py b/examples/migration/plot_02_nirstoolbox.py index c9fb58393..24a0f2eab 100644 --- a/examples/migration/plot_02_nirstoolbox.py +++ b/examples/migration/plot_02_nirstoolbox.py @@ -48,11 +48,6 @@ pipeline = nirs.modules.MixedEffects(); pipeline.formula = 'beta ~ -1 + cond + (1|Name)'; group_stats = pipeline.run(subj_stats); - -.. 
contents:: Page contents - :local: - :depth: 2 - """ # %% diff --git a/mne_nirs/channels/_channels.py b/mne_nirs/channels/_channels.py index 7d9f46f33..c7cf5ca40 100644 --- a/mne_nirs/channels/_channels.py +++ b/mne_nirs/channels/_channels.py @@ -2,11 +2,12 @@ # # License: BSD (3-clause) -import numpy as np import re + +import numpy as np from mne import pick_types -from mne.utils import _validate_type from mne.io import BaseRaw +from mne.utils import _validate_type def list_sources(raw): @@ -23,12 +24,11 @@ def list_sources(raw): sources : list Unique list of all sources in ascending order. """ - _validate_type(raw, BaseRaw, 'raw') + _validate_type(raw, BaseRaw, "raw") - picks = pick_types(raw.info, meg=False, eeg=False, fnirs=True, - exclude=[]) + picks = pick_types(raw.info, meg=False, eeg=False, fnirs=True, exclude=[]) if not len(picks): - raise RuntimeError('Listing source is for fNIRS signals only.') + raise RuntimeError("Listing source is for fNIRS signals only.") sources = list() ch_names = raw.ch_names @@ -55,12 +55,11 @@ def list_detectors(raw): sources : list Unique list of all detectors in ascending order. """ - _validate_type(raw, BaseRaw, 'raw') + _validate_type(raw, BaseRaw, "raw") - picks = pick_types(raw.info, meg=False, eeg=False, fnirs=True, - exclude=[]) + picks = pick_types(raw.info, meg=False, eeg=False, fnirs=True, exclude=[]) if not len(picks): - raise RuntimeError('Listing source is for fNIRS signals only.') + raise RuntimeError("Listing source is for fNIRS signals only.") detectors = list() ch_names = raw.ch_names @@ -95,12 +94,15 @@ def drop_sources(raw, sources): try: all_str = all([isinstance(src, int) for src in sources]) except TypeError: - raise ValueError("'ch_names' must be iterable, got " - "type {} ({}).".format(type(sources), sources)) + raise ValueError( + "'ch_names' must be iterable, got " f"type {type(sources)} ({sources})." + ) if not all_str: - raise ValueError("Each element in 'ch_names' must be int, got " - "{}.".format([type(ch) for ch in sources])) + raise ValueError( + "Each element in 'ch_names' must be int, got " + f"{[type(ch) for ch in sources]}." + ) keeps = np.ones(len(raw.ch_names)) for src in sources: @@ -136,12 +138,15 @@ def drop_detectors(raw, detectors): try: all_str = all([isinstance(det, int) for det in detectors]) except TypeError: - raise ValueError("'ch_names' must be iterable, got " - "type {} ({}).".format(type(detectors), detectors)) + raise ValueError( + "'ch_names' must be iterable, got " f"type {type(detectors)} ({detectors})." + ) if not all_str: - raise ValueError("Each element in 'ch_names' must be int, got " - "{}.".format([type(det) for det in detectors])) + raise ValueError( + "Each element in 'ch_names' must be int, got " + f"{[type(det) for det in detectors]}." + ) keeps = np.ones(len(raw.ch_names)) for det in detectors: @@ -177,12 +182,15 @@ def pick_sources(raw, sources): try: all_str = all([isinstance(src, int) for src in sources]) except TypeError: - raise ValueError("'ch_names' must be iterable, got " - "type {} ({}).".format(type(sources), sources)) + raise ValueError( + "'ch_names' must be iterable, got " f"type {type(sources)} ({sources})." + ) if not all_str: - raise ValueError("Each element in 'ch_names' must be int, got " - "{}.".format([type(ch) for ch in sources])) + raise ValueError( + "Each element in 'ch_names' must be int, got " + f"{[type(ch) for ch in sources]}." 
+ ) keeps = np.zeros(len(raw.ch_names)) for src in sources: @@ -218,12 +226,15 @@ def pick_detectors(raw, detectors): try: all_str = all([isinstance(det, int) for det in detectors]) except TypeError: - raise ValueError("'ch_names' must be iterable, got " - "type {} ({}).".format(type(detectors), detectors)) + raise ValueError( + "'ch_names' must be iterable, got " f"type {type(detectors)} ({detectors})." + ) if not all_str: - raise ValueError("Each element in 'ch_names' must be int, got " - "{}.".format([type(det) for det in detectors])) + raise ValueError( + "Each element in 'ch_names' must be int, got " + f"{[type(det) for det in detectors]}." + ) keeps = np.zeros(len(raw.ch_names)) for det in detectors: diff --git a/mne_nirs/channels/_roi.py b/mne_nirs/channels/_roi.py index 7126c680d..8c9b16f60 100644 --- a/mne_nirs/channels/_roi.py +++ b/mne_nirs/channels/_roi.py @@ -7,7 +7,7 @@ from mne.utils import warn -def picks_pair_to_idx(raw, sd_pairs, on_missing='error'): +def picks_pair_to_idx(raw, sd_pairs, on_missing="error"): """ Return a list of picks for specified source detector pairs. @@ -38,7 +38,6 @@ def picks_pair_to_idx(raw, sd_pairs, on_missing='error'): picks : list of integers List of picks corresponding to requested source detector pairs. """ - ch_names = raw.ch_names picks = list() @@ -46,12 +45,14 @@ def picks_pair_to_idx(raw, sd_pairs, on_missing='error'): pair_name = "S" + str(pair[0]) + "_D" + str(pair[1]) + " " pair_picks = np.where([pair_name in ch for ch in ch_names])[0] if len(pair_picks) == 0: - msg = ('No matching channels found for source %s ' - 'detector %s' % (pair[0], pair[1])) - if on_missing == 'error': + msg = "No matching channels found for source %s " "detector %s" % ( + pair[0], + pair[1], + ) + if on_missing == "error": print(pair_picks) raise ValueError(msg) - elif on_missing == 'warning': + elif on_missing == "warning": warn(msg) else: # on_missing == 'ignore': diff --git a/mne_nirs/channels/_short.py b/mne_nirs/channels/_short.py index 488216107..1bee90f35 100644 --- a/mne_nirs/channels/_short.py +++ b/mne_nirs/channels/_short.py @@ -2,10 +2,10 @@ # # License: BSD (3-clause) +import mne +from mne.io import BaseRaw from mne.preprocessing.nirs import source_detector_distances from mne.utils import _validate_type -from mne.io import BaseRaw -import mne def get_short_channels(raw, max_dist=0.01): @@ -24,14 +24,14 @@ def get_short_channels(raw, max_dist=0.01): raw : instance of Raw Raw instance with only short channels. """ - short_chans = raw.copy().load_data() - _validate_type(short_chans, BaseRaw, 'raw') + _validate_type(short_chans, BaseRaw, "raw") - picks = mne.pick_types(short_chans.info, meg=False, eeg=False, fnirs=True, - exclude=[]) + picks = mne.pick_types( + short_chans.info, meg=False, eeg=False, fnirs=True, exclude=[] + ) if not len(picks): - raise RuntimeError('Short channel extraction for NIRS signals only.') + raise RuntimeError("Short channel extraction for NIRS signals only.") dists = source_detector_distances(short_chans.info, picks=picks) short_chans.pick(picks[dists < max_dist]) @@ -57,14 +57,14 @@ def get_long_channels(raw, min_dist=0.015, max_dist=0.045): raw : instance of Raw Raw instance with only long channels. 
""" - long_chans = raw.copy().load_data() - _validate_type(long_chans, BaseRaw, 'raw') + _validate_type(long_chans, BaseRaw, "raw") - picks = mne.pick_types(long_chans.info, meg=False, eeg=False, fnirs=True, - exclude=[]) + picks = mne.pick_types( + long_chans.info, meg=False, eeg=False, fnirs=True, exclude=[] + ) if not len(picks): - raise RuntimeError('Short channel extraction for NIRS signals only.') + raise RuntimeError("Short channel extraction for NIRS signals only.") dists = source_detector_distances(long_chans.info, picks=picks) long_chans.pick(picks[(dists > min_dist) & (dists < max_dist)]) diff --git a/mne_nirs/channels/tests/test_channels.py b/mne_nirs/channels/tests/test_channels.py index 24f03c5c9..cec276280 100644 --- a/mne_nirs/channels/tests/test_channels.py +++ b/mne_nirs/channels/tests/test_channels.py @@ -4,15 +4,23 @@ import os -from numpy.testing import assert_array_equal + import mne -from mne_nirs.channels import list_sources, list_detectors,\ - drop_sources, drop_detectors, pick_sources, pick_detectors +from numpy.testing import assert_array_equal + +from mne_nirs.channels import ( + drop_detectors, + drop_sources, + list_detectors, + list_sources, + pick_detectors, + pick_sources, +) def _get_raw(): fnirs_data_folder = mne.datasets.fnirs_motor.data_path() - fnirs_raw_dir = os.path.join(fnirs_data_folder, 'Participant-1') + fnirs_raw_dir = os.path.join(fnirs_data_folder, "Participant-1") raw = mne.io.read_raw_nirx(fnirs_raw_dir).load_data() return raw @@ -33,8 +41,7 @@ def test_list_sources(): sources = list_sources(raw.copy().pick(["S7_D15 850", "S7_D15 760"])) assert sources == [7] - sources = list_sources(raw.copy().pick(["S7_D15 850", - "S3_D2 850", "S7_D15 760"])) + sources = list_sources(raw.copy().pick(["S7_D15 850", "S3_D2 850", "S7_D15 760"])) assert_array_equal(sources, [3, 7]) @@ -51,12 +58,12 @@ def test_list_detectors(): detectors = list_detectors(raw.copy().pick("S7_D15 850")) assert detectors == [15] - detectors = list_detectors(raw.copy().pick(["S7_D15 850", - "S7_D15 760"])) + detectors = list_detectors(raw.copy().pick(["S7_D15 850", "S7_D15 760"])) assert detectors == [15] - detectors = list_detectors(raw.copy().pick(["S7_D15 850", - "S3_D2 850", "S7_D15 760"])) + detectors = list_detectors( + raw.copy().pick(["S7_D15 850", "S3_D2 850", "S7_D15 760"]) + ) assert_array_equal(detectors, [2, 15]) diff --git a/mne_nirs/channels/tests/test_roi.py b/mne_nirs/channels/tests/test_roi.py index 5ce8c61b5..c0678c816 100644 --- a/mne_nirs/channels/tests/test_roi.py +++ b/mne_nirs/channels/tests/test_roi.py @@ -3,14 +3,16 @@ # License: BSD (3-clause) import os + import mne import pytest + from mne_nirs.channels import picks_pair_to_idx def test_roi_picks(): fnirs_data_folder = mne.datasets.fnirs_motor.data_path() - fnirs_raw_dir = os.path.join(fnirs_data_folder, 'Participant-1') + fnirs_raw_dir = os.path.join(fnirs_data_folder, "Participant-1") raw = mne.io.read_raw_nirx(fnirs_raw_dir).load_data() picks = picks_pair_to_idx(raw, [[1, 1], [1, 2], [5, 13], [8, 16]]) @@ -28,33 +30,37 @@ def test_roi_picks(): assert raw.ch_names[picks[7]] == "S8_D16 850" # Test what happens when a pair that doesn't exist is requested (15-13) - with pytest.raises(ValueError, match='No matching'): + with pytest.raises(ValueError, match="No matching"): picks_pair_to_idx(raw, [[1, 1], [1, 2], [15, 13], [8, 16]]) - with pytest.warns(RuntimeWarning, match='No matching channels'): - picks = picks_pair_to_idx(raw, [[1, 1], [1, 2], [15, 13], [8, 16]], - on_missing='warning') + with 
pytest.warns(RuntimeWarning, match="No matching channels"): + picks = picks_pair_to_idx( + raw, [[1, 1], [1, 2], [15, 13], [8, 16]], on_missing="warning" + ) assert len(picks) == 6 # Missing should be ignored - picks = picks_pair_to_idx(raw, [[1, 1], [1, 2], [15, 13], [8, 16]], - on_missing='ignore') + picks = picks_pair_to_idx( + raw, [[1, 1], [1, 2], [15, 13], [8, 16]], on_missing="ignore" + ) assert len(picks) == 6 # Test usage for ROI downstream functions - group_by = dict(Left_ROI=picks_pair_to_idx(raw, [[1, 1], [1, 2], [5, 13]]), - Right_ROI=picks_pair_to_idx(raw, [[3, 3], [3, 11]])) - assert group_by['Left_ROI'] == [0, 1, 2, 3, 34, 35] - assert group_by['Right_ROI'] == [18, 19, 20, 21] + group_by = dict( + Left_ROI=picks_pair_to_idx(raw, [[1, 1], [1, 2], [5, 13]]), + Right_ROI=picks_pair_to_idx(raw, [[3, 3], [3, 11]]), + ) + assert group_by["Left_ROI"] == [0, 1, 2, 3, 34, 35] + assert group_by["Right_ROI"] == [18, 19, 20, 21] # Ensure we dont match [1, 1] to S1_D11 # Check easy condition picks = picks_pair_to_idx(raw, [[1, 1]]) assert picks == [0, 1] # Force in tricky situation - raw.info["ch_names"][2] = 'S1_D11 760' - raw.info["ch_names"][3] = 'S1_D11 850' + raw.info["ch_names"][2] = "S1_D11 760" + raw.info["ch_names"][3] = "S1_D11 850" picks = picks_pair_to_idx(raw, [[1, 1]]) assert picks == [0, 1] - picks = picks_pair_to_idx(raw, [[21, 91], [91, 2]], on_missing='ignore') + picks = picks_pair_to_idx(raw, [[21, 91], [91, 2]], on_missing="ignore") assert picks == [] diff --git a/mne_nirs/channels/tests/test_short.py b/mne_nirs/channels/tests/test_short.py index 3f6eb343c..8c92615da 100644 --- a/mne_nirs/channels/tests/test_short.py +++ b/mne_nirs/channels/tests/test_short.py @@ -4,17 +4,18 @@ import os + import mne import numpy as np import pytest - from mne.preprocessing.nirs import source_detector_distances + from mne_nirs.channels import get_long_channels, get_short_channels def test_short_extraction(): fnirs_data_folder = mne.datasets.fnirs_motor.data_path() - fnirs_raw_dir = os.path.join(fnirs_data_folder, 'Participant-1') + fnirs_raw_dir = os.path.join(fnirs_data_folder, "Participant-1") raw_intensity = mne.io.read_raw_nirx(fnirs_raw_dir).load_data() short_chans = get_short_channels(raw_intensity) @@ -39,15 +40,16 @@ def test_short_extraction(): # Check that we dont run on other types, eg eeg. raw_intensity.pick(picks=range(2)) - raw_intensity.set_channel_types({'S1_D1 760': 'eeg', 'S1_D1 850': 'eeg'}, - verbose='error') - with pytest.raises(RuntimeError, match='NIRS signals only'): + raw_intensity.set_channel_types( + {"S1_D1 760": "eeg", "S1_D1 850": "eeg"}, verbose="error" + ) + with pytest.raises(RuntimeError, match="NIRS signals only"): _ = get_short_channels(raw_intensity) def test_long_extraction(): fnirs_data_folder = mne.datasets.fnirs_motor.data_path() - fnirs_raw_dir = os.path.join(fnirs_data_folder, 'Participant-1') + fnirs_raw_dir = os.path.join(fnirs_data_folder, "Participant-1") raw_intensity = mne.io.read_raw_nirx(fnirs_raw_dir).load_data() long_chans = get_long_channels(raw_intensity) @@ -72,7 +74,8 @@ def test_long_extraction(): # Check that we dont run on other types, eg eeg. 
raw_intensity.pick(picks=range(2)) - raw_intensity.set_channel_types({'S1_D1 760': 'eeg', 'S1_D1 850': 'eeg'}, - verbose='error') - with pytest.raises(RuntimeError, match='NIRS signals only'): + raw_intensity.set_channel_types( + {"S1_D1 760": "eeg", "S1_D1 850": "eeg"}, verbose="error" + ) + with pytest.raises(RuntimeError, match="NIRS signals only"): _ = get_long_channels(raw_intensity) diff --git a/mne_nirs/datasets/audio_or_visual_speech/_audio_or_visual_speech.py b/mne_nirs/datasets/audio_or_visual_speech/_audio_or_visual_speech.py index 6a8e46b6d..d801376c9 100644 --- a/mne_nirs/datasets/audio_or_visual_speech/_audio_or_visual_speech.py +++ b/mne_nirs/datasets/audio_or_visual_speech/_audio_or_visual_speech.py @@ -5,22 +5,22 @@ import os import shutil -import pooch from functools import partial -from mne.utils import verbose -from mne.datasets.utils import has_dataset +import pooch from mne.datasets import fetch_dataset +from mne.datasets.utils import has_dataset +from mne.utils import verbose from ...fixes import _mne_path -has_block_speech_noise_data = partial(has_dataset, - name='audio_or_visual_speech') +has_block_speech_noise_data = partial(has_dataset, name="audio_or_visual_speech") @verbose -def data_path(path=None, force_update=False, update_path=True, download=True, - verbose=None): # noqa: D103 +def data_path( + path=None, force_update=False, update_path=True, download=True, verbose=None +): # noqa: D103 """ Audio and visual speech dataset with 8 participants. @@ -55,29 +55,33 @@ def data_path(path=None, force_update=False, update_path=True, download=True, ---------- .. footbibliography:: """ - dataset_params = dict( - archive_name='2021-fNIRS-Audio-visual-speech-' - 'Broad-vs-restricted-regions.zip', - hash='md5:16cac6565880dae6aed9b69100399d0b', - url='https://osf.io/xwerv/download?version=1', - folder_name='fNIRS-audio-visual-speech', - dataset_name='audio_or_visual_speech', - config_key='MNE_DATASETS_FNIRSAUDIOVISUALSPEECH_PATH', + archive_name="2021-fNIRS-Audio-visual-speech-" + "Broad-vs-restricted-regions.zip", + hash="md5:16cac6565880dae6aed9b69100399d0b", + url="https://osf.io/xwerv/download?version=1", + folder_name="fNIRS-audio-visual-speech", + dataset_name="audio_or_visual_speech", + config_key="MNE_DATASETS_FNIRSAUDIOVISUALSPEECH_PATH", ) - dpath = fetch_dataset(dataset_params, path=path, force_update=force_update, - update_path=update_path, download=download, - processor=pooch.Unzip( - extract_dir="./fNIRS-audio-visual-speech")) + dpath = fetch_dataset( + dataset_params, + path=path, + force_update=force_update, + update_path=update_path, + download=download, + processor=pooch.Unzip(extract_dir="./fNIRS-audio-visual-speech"), + ) dpath = str(dpath) # Do some wrangling to deal with nested directories - bad_name = os.path.join(dpath, '2021-fNIRS-Audio-visual-speech-' - 'Broad-vs-restricted-regions') + bad_name = os.path.join( + dpath, "2021-fNIRS-Audio-visual-speech-" "Broad-vs-restricted-regions" + ) if os.path.isdir(bad_name): - os.rename(bad_name, dpath + '.true') + os.rename(bad_name, dpath + ".true") shutil.rmtree(dpath) - os.rename(dpath + '.true', dpath) + os.rename(dpath + ".true", dpath) return _mne_path(dpath) diff --git a/mne_nirs/datasets/audio_or_visual_speech/tests/test_dataset_avspeech.py b/mne_nirs/datasets/audio_or_visual_speech/tests/test_dataset_avspeech.py index 01d42154a..3febe75cc 100644 --- a/mne_nirs/datasets/audio_or_visual_speech/tests/test_dataset_avspeech.py +++ b/mne_nirs/datasets/audio_or_visual_speech/tests/test_dataset_avspeech.py 
@@ -2,6 +2,7 @@ # License: BSD (3-clause) import os.path as op + import mne_nirs diff --git a/mne_nirs/datasets/block_speech_noise/_block_speech_noise.py b/mne_nirs/datasets/block_speech_noise/_block_speech_noise.py index db62b408f..e711eaf34 100644 --- a/mne_nirs/datasets/block_speech_noise/_block_speech_noise.py +++ b/mne_nirs/datasets/block_speech_noise/_block_speech_noise.py @@ -5,21 +5,22 @@ import os import shutil -import pooch from functools import partial -from mne.utils import verbose -from mne.datasets.utils import has_dataset +import pooch from mne.datasets import fetch_dataset +from mne.datasets.utils import has_dataset +from mne.utils import verbose from ...fixes import _mne_path -has_block_speech_noise_data = partial(has_dataset, name='block_speech_noise') +has_block_speech_noise_data = partial(has_dataset, name="block_speech_noise") @verbose -def data_path(path=None, force_update=False, update_path=True, download=True, - verbose=None): # noqa: D103 +def data_path( + path=None, force_update=False, update_path=True, download=True, verbose=None +): # noqa: D103 """ Audio speech and noise dataset with 18 participants. @@ -54,28 +55,30 @@ def data_path(path=None, force_update=False, update_path=True, download=True, ---------- .. footbibliography:: """ - dataset_params = dict( - archive_name='2021-fNIRS-Analysis-Methods-Passive-Auditory.zip', - hash='md5:569c0fbafa575e344e90698c808dfdd3', - url='https://osf.io/bjfu7/download?version=1', - folder_name='fNIRS-block-speech-noise', - dataset_name='block_speech_noise', - config_key='MNE_DATASETS_FNIRSSPEECHNOISE_PATH', + archive_name="2021-fNIRS-Analysis-Methods-Passive-Auditory.zip", + hash="md5:569c0fbafa575e344e90698c808dfdd3", + url="https://osf.io/bjfu7/download?version=1", + folder_name="fNIRS-block-speech-noise", + dataset_name="block_speech_noise", + config_key="MNE_DATASETS_FNIRSSPEECHNOISE_PATH", ) - dpath = fetch_dataset(dataset_params, path=path, force_update=force_update, - update_path=update_path, download=download, - processor=pooch.Unzip( - extract_dir="./fNIRS-block-speech-noise")) + dpath = fetch_dataset( + dataset_params, + path=path, + force_update=force_update, + update_path=update_path, + download=download, + processor=pooch.Unzip(extract_dir="./fNIRS-block-speech-noise"), + ) dpath = str(dpath) # Do some wrangling to deal with nested directories - bad_name = os.path.join(dpath, '2021-fNIRS-Analysis-Methods-' - 'Passive-Auditory') + bad_name = os.path.join(dpath, "2021-fNIRS-Analysis-Methods-" "Passive-Auditory") if os.path.isdir(bad_name): - os.rename(bad_name, dpath + '.true') + os.rename(bad_name, dpath + ".true") shutil.rmtree(dpath) - os.rename(dpath + '.true', dpath) + os.rename(dpath + ".true", dpath) return _mne_path(dpath) diff --git a/mne_nirs/datasets/block_speech_noise/tests/test_dataset_block.py b/mne_nirs/datasets/block_speech_noise/tests/test_dataset_block.py index ac61c5ffe..b34b50e9d 100644 --- a/mne_nirs/datasets/block_speech_noise/tests/test_dataset_block.py +++ b/mne_nirs/datasets/block_speech_noise/tests/test_dataset_block.py @@ -2,6 +2,7 @@ # License: BSD (3-clause) import os.path as op + import mne_nirs diff --git a/mne_nirs/datasets/fnirs_motor_group/fnirs_motor_group.py b/mne_nirs/datasets/fnirs_motor_group/fnirs_motor_group.py index be7c9d220..a31a255d8 100644 --- a/mne_nirs/datasets/fnirs_motor_group/fnirs_motor_group.py +++ b/mne_nirs/datasets/fnirs_motor_group/fnirs_motor_group.py @@ -6,21 +6,22 @@ import os import shutil -import pooch from functools import partial -from mne.utils 
import verbose -from mne.datasets.utils import has_dataset +import pooch from mne.datasets import fetch_dataset +from mne.datasets.utils import has_dataset +from mne.utils import verbose from ...fixes import _mne_path -has_fnirs_motor_group_data = partial(has_dataset, name='fnirs_motor_group') +has_fnirs_motor_group_data = partial(has_dataset, name="fnirs_motor_group") @verbose -def data_path(path=None, force_update=False, update_path=True, download=True, - verbose=None): # noqa: D103 +def data_path( + path=None, force_update=False, update_path=True, download=True, verbose=None +): # noqa: D103 """ Motor task experiment data with 5 participants. @@ -55,27 +56,30 @@ def data_path(path=None, force_update=False, update_path=True, download=True, ---------- .. footbibliography:: """ - dataset_params = dict( - archive_name='BIDS-NIRS-Tapping-master.zip', - hash='md5:da3cac7252005f0a64fdba5c683cf3dd', - url='https://github.com/rob-luke/BIDS-NIRS-Tapping/archive/v0.1.0.zip', - folder_name='fNIRS-motor-group', - dataset_name='fnirs_motor_group', - config_key='MNE_DATASETS_FNIRSMOTORGROUP_PATH', + archive_name="BIDS-NIRS-Tapping-master.zip", + hash="md5:da3cac7252005f0a64fdba5c683cf3dd", + url="https://github.com/rob-luke/BIDS-NIRS-Tapping/archive/v0.1.0.zip", + folder_name="fNIRS-motor-group", + dataset_name="fnirs_motor_group", + config_key="MNE_DATASETS_FNIRSMOTORGROUP_PATH", ) - dpath = fetch_dataset(dataset_params, path=path, force_update=force_update, - update_path=update_path, download=download, - processor=pooch.Unzip( - extract_dir="./fNIRS-motor-group")) + dpath = fetch_dataset( + dataset_params, + path=path, + force_update=force_update, + update_path=update_path, + download=download, + processor=pooch.Unzip(extract_dir="./fNIRS-motor-group"), + ) dpath = str(dpath) # Do some wrangling to deal with nested directories - bad_name = os.path.join(dpath, 'BIDS-NIRS-Tapping-0.1.0') + bad_name = os.path.join(dpath, "BIDS-NIRS-Tapping-0.1.0") if os.path.isdir(bad_name): - os.rename(bad_name, dpath + '.true') + os.rename(bad_name, dpath + ".true") shutil.rmtree(dpath) - os.rename(dpath + '.true', dpath) + os.rename(dpath + ".true", dpath) return _mne_path(dpath) diff --git a/mne_nirs/datasets/fnirs_motor_group/tests/test_dataset_tap.py b/mne_nirs/datasets/fnirs_motor_group/tests/test_dataset_tap.py index 31158eebc..2ac9c43d9 100644 --- a/mne_nirs/datasets/fnirs_motor_group/tests/test_dataset_tap.py +++ b/mne_nirs/datasets/fnirs_motor_group/tests/test_dataset_tap.py @@ -2,6 +2,7 @@ # License: BSD (3-clause) import os.path as op + import mne_nirs diff --git a/mne_nirs/datasets/snirf_with_aux/snirf_with_aux.py b/mne_nirs/datasets/snirf_with_aux/snirf_with_aux.py index 8aba9afb7..29160481c 100644 --- a/mne_nirs/datasets/snirf_with_aux/snirf_with_aux.py +++ b/mne_nirs/datasets/snirf_with_aux/snirf_with_aux.py @@ -4,21 +4,22 @@ # This downloads SNIRF data that includes auxiliary channels. 
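# The rename/rmtree dance repeated in the fetchers above flattens the single
# nested folder that unzipping an archive can create. Schematically, with
# hypothetical paths:
import os
import shutil

dpath = "/tmp/fNIRS-example"                    # intended dataset directory
bad_name = os.path.join(dpath, "archive-root")  # nested dir left by the unzip
if os.path.isdir(bad_name):
    os.rename(bad_name, dpath + ".true")  # move the nested dir aside
    shutil.rmtree(dpath)                  # remove the now-redundant parent
    os.rename(dpath + ".true", dpath)     # rename it into the final location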
import os -import pooch from functools import partial -from mne.utils import verbose -from mne.datasets.utils import has_dataset +import pooch from mne.datasets import fetch_dataset +from mne.datasets.utils import has_dataset +from mne.utils import verbose from ...fixes import _mne_path -has_fnirs_snirf_aux_data = partial(has_dataset, name='snirf_with_aux') +has_fnirs_snirf_aux_data = partial(has_dataset, name="snirf_with_aux") @verbose -def data_path(path=None, force_update=False, update_path=True, download=True, - verbose=None): # noqa: D103 +def data_path( + path=None, force_update=False, update_path=True, download=True, verbose=None +): # noqa: D103 """ SNIRF file with auxiliary channels. @@ -49,20 +50,23 @@ def data_path(path=None, force_update=False, update_path=True, download=True, path : str Path to dataset directory. """ - dataset_params = dict( - archive_name='2022-08-05_002.snirf.zip', - hash='md5:35ce75d1715c8cca801894a7120b5691', - url='https://forae.s3.amazonaws.com/2022-08-05_002.snirf.zip', - folder_name='fNIRS-SNIRF-aux', - dataset_name='snirf_with_aux', - config_key='MNE_DATASETS_SNIRFAUX_PATH', + archive_name="2022-08-05_002.snirf.zip", + hash="md5:35ce75d1715c8cca801894a7120b5691", + url="https://forae.s3.amazonaws.com/2022-08-05_002.snirf.zip", + folder_name="fNIRS-SNIRF-aux", + dataset_name="snirf_with_aux", + config_key="MNE_DATASETS_SNIRFAUX_PATH", ) - dpath = fetch_dataset(dataset_params, path=path, force_update=force_update, - update_path=update_path, download=download, - processor=pooch.Unzip( - extract_dir="./fNIRS-SNIRF-aux")) + dpath = fetch_dataset( + dataset_params, + path=path, + force_update=force_update, + update_path=update_path, + download=download, + processor=pooch.Unzip(extract_dir="./fNIRS-SNIRF-aux"), + ) dpath = str(dpath) return _mne_path(os.path.join(dpath, "2022-08-05_002.snirf")) diff --git a/mne_nirs/datasets/snirf_with_aux/tests/test_dataset_aux.py b/mne_nirs/datasets/snirf_with_aux/tests/test_dataset_aux.py index 2ba323288..c672afd33 100644 --- a/mne_nirs/datasets/snirf_with_aux/tests/test_dataset_aux.py +++ b/mne_nirs/datasets/snirf_with_aux/tests/test_dataset_aux.py @@ -2,6 +2,7 @@ # License: BSD (3-clause) import os.path as op + import mne_nirs diff --git a/mne_nirs/experimental_design/_experimental_design.py b/mne_nirs/experimental_design/_experimental_design.py index 0a82b51e9..965ca00ea 100644 --- a/mne_nirs/experimental_design/_experimental_design.py +++ b/mne_nirs/experimental_design/_experimental_design.py @@ -2,17 +2,23 @@ # # License: BSD (3-clause) -import numpy as np import mne +import numpy as np -def make_first_level_design_matrix(raw, stim_dur=1., - hrf_model='glover', - drift_model='cosine', - high_pass=0.01, drift_order=1, - fir_delays=[0], add_regs=None, - add_reg_names=None, min_onset=-24, - oversampling=50): +def make_first_level_design_matrix( + raw, + stim_dur=1.0, + hrf_model="glover", + drift_model="cosine", + high_pass=0.01, + drift_order=1, + fir_delays=(0,), + add_regs=None, + add_reg_names=None, + min_onset=-24, + oversampling=50, +): """ Generate a design matrix based on annotations and model HRF. 
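# Beyond reformatting, the hunk above swaps the mutable default
# ``fir_delays=[0]`` for the tuple ``(0,)``: Python evaluates default
# arguments once, so a list default is shared across all calls. A minimal
# demonstration of the hazard the tuple avoids:
def bad(delays=[0]):
    delays.append(1)  # mutates the one shared default list
    return delays

def good(delays=(0,)):
    return list(delays) + [1]  # tuples are immutable; copy before extending

print(bad())   # [0, 1]
print(bad())   # [0, 1, 1]  <- state leaked from the first call
print(good())  # [0, 1]
print(good())  # [0, 1]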
@@ -86,20 +92,23 @@ def make_first_level_design_matrix(raw, stim_dur=1., conditions = raw.annotations.description onsets = raw.annotations.onset - raw.first_time duration = stim_dur * np.ones(len(conditions)) - events = DataFrame({'trial_type': conditions, - 'onset': onsets, - 'duration': duration}) - - dm = make_first_level_design_matrix(frame_times, events, - drift_model=drift_model, - drift_order=drift_order, - hrf_model=hrf_model, - min_onset=min_onset, - high_pass=high_pass, - add_regs=add_regs, - oversampling=oversampling, - add_reg_names=add_reg_names, - fir_delays=fir_delays) + events = DataFrame( + {"trial_type": conditions, "onset": onsets, "duration": duration} + ) + + dm = make_first_level_design_matrix( + frame_times, + events, + drift_model=drift_model, + drift_order=drift_order, + hrf_model=hrf_model, + min_onset=min_onset, + high_pass=high_pass, + add_regs=add_regs, + oversampling=oversampling, + add_reg_names=add_reg_names, + fir_delays=fir_delays, + ) return dm @@ -122,15 +131,15 @@ def create_boxcar(raw, event_id=None, stim_dur=1): s : array Returns an array for each annotation label. """ - bc = np.ones(int(round(raw.info['sfreq'] * stim_dur))) + bc = np.ones(int(round(raw.info["sfreq"] * stim_dur))) events, ids = mne.events_from_annotations(raw, event_id=event_id) s = np.zeros((len(raw.times), len(ids))) - for idx, id in enumerate(ids): + for idx, _ in enumerate(ids): id_idx = [e[2] == idx + 1 for e in events] id_evt = events[id_idx] event_samples = [e[0] for e in id_evt] - s[event_samples, idx] = 1. - s[:, idx] = np.convolve(s[:, idx], bc)[:len(raw.times)] + s[event_samples, idx] = 1.0 + s[:, idx] = np.convolve(s[:, idx], bc)[: len(raw.times)] return s diff --git a/mne_nirs/experimental_design/tests/test_experimental_design.py b/mne_nirs/experimental_design/tests/test_experimental_design.py index 9a0bde5a7..d7d663aa2 100644 --- a/mne_nirs/experimental_design/tests/test_experimental_design.py +++ b/mne_nirs/experimental_design/tests/test_experimental_design.py @@ -5,36 +5,41 @@ import os import mne -import mne_nirs import numpy as np -from mne_nirs.experimental_design import make_first_level_design_matrix, \ - longest_inter_annotation_interval, drift_high_pass + +import mne_nirs +from mne_nirs.experimental_design import ( + drift_high_pass, + longest_inter_annotation_interval, + make_first_level_design_matrix, +) from mne_nirs.simulation import simulate_nirs_raw def _load_dataset(): """Load data and tidy it a bit""" fnirs_data_folder = mne.datasets.fnirs_motor.data_path() - fnirs_raw_dir = os.path.join(fnirs_data_folder, 'Participant-1') - raw_intensity = mne.io.read_raw_nirx(fnirs_raw_dir, - verbose=True).load_data() + fnirs_raw_dir = os.path.join(fnirs_data_folder, "Participant-1") + raw_intensity = mne.io.read_raw_nirx(fnirs_raw_dir, verbose=True).load_data() raw_intensity.crop(0, raw_intensity.annotations.onset[-1]) new_des = [des for des in raw_intensity.annotations.description] - new_des = ['A' if x == "1.0" else x for x in new_des] - new_des = ['B' if x == "2.0" else x for x in new_des] - new_des = ['C' if x == "3.0" else x for x in new_des] - annot = mne.Annotations(raw_intensity.annotations.onset, - raw_intensity.annotations.duration, new_des) + new_des = ["A" if x == "1.0" else x for x in new_des] + new_des = ["B" if x == "2.0" else x for x in new_des] + new_des = ["C" if x == "3.0" else x for x in new_des] + annot = mne.Annotations( + raw_intensity.annotations.onset, raw_intensity.annotations.duration, new_des + ) raw_intensity.set_annotations(annot) picks = 
mne.pick_types(raw_intensity.info, meg=False, fnirs=True) dists = mne.preprocessing.nirs.source_detector_distances( - raw_intensity.info, picks=picks) + raw_intensity.info, picks=picks + ) raw_intensity.pick(picks[dists > 0.01]) - assert 'fnirs_cw_amplitude' in raw_intensity + assert "fnirs_cw_amplitude" in raw_intensity assert len(np.unique(raw_intensity.annotations.description)) == 4 return raw_intensity @@ -52,8 +57,10 @@ def test_create_boxcar(): assert np.min(bc) == 0 # The value of the boxcar should be 1 when a trigger fires - assert bc[int(raw_intensity.annotations.onset[0] * - raw_intensity.info['sfreq']), :][0] == 1 + assert ( + bc[int(raw_intensity.annotations.onset[0] * raw_intensity.info["sfreq"]), :][0] + == 1 + ) # Only one condition was ever present at a time in this data # So boxcar should never overlap across channels @@ -63,27 +70,36 @@ def test_create_boxcar(): def test_create_design(): raw_intensity = _load_dataset() raw_intensity.crop(450, 600) # Keep the test fast - design_matrix = make_first_level_design_matrix(raw_intensity, - drift_order=1, - drift_model='polynomial') + design_matrix = make_first_level_design_matrix( + raw_intensity, drift_order=1, drift_model="polynomial" + ) assert design_matrix.shape[0] == raw_intensity._data.shape[1] # Number of columns is number of conditions plus the drift plus constant - assert design_matrix.shape[1] ==\ - len(np.unique(raw_intensity.annotations.description)) + 2 + assert ( + design_matrix.shape[1] + == len(np.unique(raw_intensity.annotations.description)) + 2 + ) def test_cropped_raw(): # Ensure timing is correct for cropped signals - raw = simulate_nirs_raw(sfreq=1., amplitude=1., sig_dur=300., stim_dur=1., - isi_min=20., isi_max=40.) + raw = simulate_nirs_raw( + sfreq=1.0, + amplitude=1.0, + sig_dur=300.0, + stim_dur=1.0, + isi_min=20.0, + isi_max=40.0, + ) onsets = raw.annotations.onset onsets_after_crop = [onsets[idx] for idx in np.where(onsets > 100)] raw.crop(tmin=100) - design_matrix = make_first_level_design_matrix(raw, drift_order=0, - drift_model='polynomial') + design_matrix = make_first_level_design_matrix( + raw, drift_order=0, drift_model="polynomial" + ) # 100 corrects for the crop time above # 4 is peak time after onset @@ -93,8 +109,14 @@ def test_cropped_raw(): def test_high_pass_helpers(): # Test the helpers give reasonable values - raw = simulate_nirs_raw(sfreq=1., amplitude=1., sig_dur=300., stim_dur=1., - isi_min=20., isi_max=38.) + raw = simulate_nirs_raw( + sfreq=1.0, + amplitude=1.0, + sig_dur=300.0, + stim_dur=1.0, + isi_min=20.0, + isi_max=38.0, + ) lisi, names = longest_inter_annotation_interval(raw) lisi = lisi[0] assert lisi >= 20 diff --git a/mne_nirs/fixes.py b/mne_nirs/fixes.py index 4feef5dac..0be5684af 100644 --- a/mne_nirs/fixes.py +++ b/mne_nirs/fixes.py @@ -4,6 +4,7 @@ # Compat shims for different dependency versions. 
from pathlib import Path + import mne from mne.utils import check_version @@ -13,7 +14,7 @@ try: from mne.datasets.utils import _mne_path # noqa except Exception: - if check_version(mne.__version__, '1.0'): + if check_version(mne.__version__, "1.0"): _mne_path = Path else: _mne_path = str # old MNE diff --git a/mne_nirs/io/fold/_fold.py b/mne_nirs/io/fold/_fold.py index d995d8c60..65e62ee97 100644 --- a/mne_nirs/io/fold/_fold.py +++ b/mne_nirs/io/fold/_fold.py @@ -4,14 +4,13 @@ import os.path as op -import pandas as pd -import numpy as np - import mne -from mne.transforms import apply_trans, _get_trans -from mne.utils import _validate_type, _check_fname, warn +import numpy as np +import pandas as pd from mne.io import BaseRaw from mne.io.constants import FIFF +from mne.transforms import _get_trans, apply_trans +from mne.utils import _check_fname, _validate_type, warn def _read_fold_xls(fname, atlas="Juelich"): @@ -28,14 +27,9 @@ def _read_fold_xls(fname, atlas="Juelich"): atlas : str Requested atlas. """ - page_reference = {"AAL2": 2, - "AICHA": 5, - "Brodmann": 8, - "Juelich": 11, - "Loni": 14} + page_reference = {"AAL2": 2, "AICHA": 5, "Brodmann": 8, "Juelich": 11, "Loni": 14} - tbl = pd.read_excel(fname, - sheet_name=page_reference[atlas]) + tbl = pd.read_excel(fname, sheet_name=page_reference[atlas]) # Remove the spacing between rows empty_rows = np.where(np.isnan(tbl["Specificity"]))[0] @@ -46,8 +40,7 @@ def _read_fold_xls(fname, atlas="Juelich"): for col_idx, col in enumerate(tbl.columns): if not isinstance(tbl[col][row_idx], str): if np.isnan(tbl[col][row_idx]): - tbl.iloc[row_idx, col_idx] = \ - tbl.iloc[row_idx - 1, col_idx] + tbl.iloc[row_idx, col_idx] = tbl.iloc[row_idx - 1, col_idx] tbl["Specificity"] = tbl["Specificity"] * 100 tbl["brainSens"] = tbl["brainSens"] * 100 @@ -62,20 +55,20 @@ def _generate_montage_locations(): # standard_1020 and standard_1005 are in MNI (fsaverage) space already, # but we need to undo the scaling that head_scale will do montage = mne.channels.make_standard_montage( - 'standard_1005', head_size=0.09700884729534559) + "standard_1005", head_size=0.09700884729534559 + ) for d in montage.dig: - d['coord_frame'] = FIFF.FIFFV_MNE_COORD_MNI_TAL + d["coord_frame"] = FIFF.FIFFV_MNE_COORD_MNI_TAL montage.dig[:] = montage.dig[3:] montage.add_mni_fiducials() # now in fsaverage space - coords = pd.DataFrame.from_dict( - montage.get_positions()['ch_pos']).T + coords = pd.DataFrame.from_dict(montage.get_positions()["ch_pos"]).T coords["label"] = coords.index coords = coords.rename(columns={0: "x", 1: "y", 2: "z"}) return coords.reset_index(drop=True) -def _find_closest_standard_location(position, reference, *, out='label'): +def _find_closest_standard_location(position, reference, *, out="label"): """Return closest montage label to coordinates. Parameters @@ -89,22 +82,24 @@ def _find_closest_standard_location(position, reference, *, out='label'): Use None for no transformation. 
""" from scipy.spatial.distance import cdist + p0 = np.array(position) p0.shape = (-1, 3) - head_mri_t, _ = _get_trans('fsaverage', 'head', 'mri') + head_mri_t, _ = _get_trans("fsaverage", "head", "mri") p0 = apply_trans(head_mri_t, p0) - dists = cdist(p0, np.asarray(reference[['x', 'y', 'z']], float)) + dists = cdist(p0, np.asarray(reference[["x", "y", "z"]], float)) - if out == 'label': + if out == "label": min_idx = np.argmin(dists) return reference["label"][min_idx] else: - assert out == 'dists' + assert out == "dists" return dists -def fold_landmark_specificity(raw, landmark, fold_files=None, - atlas="Juelich", interpolate=False): +def fold_landmark_specificity( + raw, landmark, fold_files=None, atlas="Juelich", interpolate=False +): """Return the specificity of each channel to a specified brain landmark. Parameters @@ -150,8 +145,8 @@ def fold_landmark_specificity(raw, landmark, fold_files=None, ---------- .. footbibliography:: """ - _validate_type(landmark, str, 'landmark') - _validate_type(raw, BaseRaw, 'raw') + _validate_type(landmark, str, "landmark") + _validate_type(raw, BaseRaw, "raw") reference_locations = _generate_montage_locations() @@ -159,9 +154,9 @@ def fold_landmark_specificity(raw, landmark, fold_files=None, specificity = np.zeros(len(raw.ch_names)) for cidx in range(len(raw.ch_names)): - tbl = _source_detector_fold_table( - raw, cidx, reference_locations, fold_tbl, interpolate) + raw, cidx, reference_locations, fold_tbl, interpolate + ) if len(tbl) > 0: tbl["ContainsLmk"] = [landmark in la for la in tbl["Landmark"]] @@ -178,8 +173,7 @@ def fold_landmark_specificity(raw, landmark, fold_files=None, return np.array(specificity) -def fold_channel_specificity(raw, fold_files=None, atlas="Juelich", - interpolate=False): +def fold_channel_specificity(raw, fold_files=None, atlas="Juelich", interpolate=False): """Return the landmarks and specificity a channel is sensitive to. Parameters @@ -250,7 +244,7 @@ def fold_channel_specificity(raw, fold_files=None, atlas="Juelich", ---------- .. 
footbibliography:: """ # noqa: E501 - _validate_type(raw, BaseRaw, 'raw') + _validate_type(raw, BaseRaw, "raw") reference_locations = _generate_montage_locations() @@ -258,44 +252,50 @@ def fold_channel_specificity(raw, fold_files=None, atlas="Juelich", chan_spec = list() for cidx in range(len(raw.ch_names)): - tbl = _source_detector_fold_table( - raw, cidx, reference_locations, fold_tbl, interpolate) + raw, cidx, reference_locations, fold_tbl, interpolate + ) chan_spec.append(tbl.reset_index(drop=True)) return chan_spec def _check_load_fold(fold_files, atlas): - _validate_type(fold_files, (list, 'path-like', None), 'fold_files') + _validate_type(fold_files, (list, "path-like", None), "fold_files") if fold_files is None: - fold_files = mne.get_config('MNE_NIRS_FOLD_PATH') + fold_files = mne.get_config("MNE_NIRS_FOLD_PATH") if fold_files is None: raise ValueError( - 'MNE_NIRS_FOLD_PATH not set, either set it using ' - 'mne.set_config or pass fold_files as str or list') + "MNE_NIRS_FOLD_PATH not set, either set it using " + "mne.set_config or pass fold_files as str or list" + ) if not isinstance(fold_files, list): # path-like fold_files = _check_fname( - fold_files, overwrite='read', must_exist=True, name='fold_files', - need_dir=True) - fold_files = [op.join(fold_files, f'10-{x}.xls') for x in (5, 10)] + fold_files, + overwrite="read", + must_exist=True, + name="fold_files", + need_dir=True, + ) + fold_files = [op.join(fold_files, f"10-{x}.xls") for x in (5, 10)] fold_tbl = pd.DataFrame() for fi, fname in enumerate(fold_files): - fname = _check_fname(fname, overwrite='read', must_exist=True, - name=f'fold_files[{fi}]') - fold_tbl = pd.concat([fold_tbl, _read_fold_xls(fname, atlas=atlas)], - ignore_index=True) + fname = _check_fname( + fname, overwrite="read", must_exist=True, name=f"fold_files[{fi}]" + ) + fold_tbl = pd.concat( + [fold_tbl, _read_fold_xls(fname, atlas=atlas)], ignore_index=True + ) return fold_tbl def _source_detector_fold_table(raw, cidx, reference, fold_tbl, interpolate): - src = raw.info['chs'][cidx]['loc'][3:6] - det = raw.info['chs'][cidx]['loc'][6:9] + src = raw.info["chs"][cidx]["loc"][3:6] + det = raw.info["chs"][cidx]["loc"][6:9] - ref_lab = list(reference['label']) - dists = _find_closest_standard_location( - [src, det], reference, out='dists') + ref_lab = list(reference["label"]) + dists = _find_closest_standard_location([src, det], reference, out="dists") src_min, det_min = np.argmin(dists, axis=1) src_name, det_name = ref_lab[src_min], ref_lab[det_min] @@ -307,24 +307,22 @@ def _source_detector_fold_table(raw, cidx, reference, fold_tbl, interpolate): if len(tbl) == 0 and interpolate: # Try something hopefully not too terrible: pick the one with the # smallest net distance - good = (np.in1d(fold_tbl['Source'], reference['label']) & - np.in1d(fold_tbl['Detector'], reference['label'])) + good = np.in1d(fold_tbl["Source"], reference["label"]) & np.in1d( + fold_tbl["Detector"], reference["label"] + ) assert good.any() tbl = fold_tbl[good] assert len(tbl) - src_idx = [ref_lab.index(src) for src in tbl['Source']] - det_idx = [ref_lab.index(det) for det in tbl['Detector']] + src_idx = [ref_lab.index(src) for src in tbl["Source"]] + det_idx = [ref_lab.index(det) for det in tbl["Detector"]] # Original - tot_dist = np.linalg.norm( - [dists[0, src_idx], dists[1, det_idx]], axis=0) + tot_dist = np.linalg.norm([dists[0, src_idx], dists[1, det_idx]], axis=0) assert tot_dist.shape == (len(tbl),) idx = np.argmin(tot_dist) dist_1 = tot_dist[idx] src_1, det_1 = 
ref_lab[src_idx[idx]], ref_lab[det_idx[idx]] # And the reverse - tot_dist = np.linalg.norm( - [dists[0, det_idx], dists[1, src_idx]], axis=0 - ) + tot_dist = np.linalg.norm([dists[0, det_idx], dists[1, src_idx]], axis=0) idx = np.argmin(tot_dist) dist_2 = tot_dist[idx] src_2, det_2 = ref_lab[det_idx[idx]], ref_lab[src_idx[idx]] @@ -332,28 +330,31 @@ def _source_detector_fold_table(raw, cidx, reference, fold_tbl, interpolate): new_dist, src_use, det_use = dist_1, src_1, det_1 else: new_dist, src_use, det_use = dist_2, det_2, src_2 - warn('No fOLD table entry for best matching source/detector pair ' - f'{src_name}/{det_name} (RMS distance to the channel positions ' - f'was {1000 * dist:0.1f} mm) for channel index {cidx}, ' - 'using next smallest available ' - f'src/det pairing {src_use}/{det_use} (RMS distance ' - f'{1000 * new_dist:0.1f} mm). Consider setting your channel ' - 'positions to standard 10-05 locations using raw.set_montage ' - 'if your pair does show up in the tables.', - module='mne_nirs', ignore_namespaces=('mne', 'mne_nirs')) + warn( + "No fOLD table entry for best matching source/detector pair " + f"{src_name}/{det_name} (RMS distance to the channel positions " + f"was {1000 * dist:0.1f} mm) for channel index {cidx}, " + "using next smallest available " + f"src/det pairing {src_use}/{det_use} (RMS distance " + f"{1000 * new_dist:0.1f} mm). Consider setting your channel " + "positions to standard 10-05 locations using raw.set_montage " + "if your pair does show up in the tables.", + module="mne_nirs", + ignore_namespaces=("mne", "mne_nirs"), + ) tbl = fold_tbl.query("Source == @src_use and Detector == @det_use") tbl = tbl.copy() - tbl['BestSource'] = src_name - tbl['BestDetector'] = det_name - tbl['BestMatchDistance'] = dist - tbl['MatchDistance'] = new_dist + tbl["BestSource"] = src_name + tbl["BestDetector"] = det_name + tbl["BestMatchDistance"] = dist + tbl["MatchDistance"] = new_dist assert len(tbl) else: tbl = tbl.copy() - tbl['BestSource'] = src_name - tbl['BestDetector'] = det_name - tbl['BestMatchDistance'] = dist - tbl['MatchDistance'] = dist + tbl["BestSource"] = src_name + tbl["BestDetector"] = det_name + tbl["BestMatchDistance"] = dist + tbl["MatchDistance"] = dist tbl = tbl.copy() # don't get warnings about setting values later return tbl diff --git a/mne_nirs/io/fold/tests/test_fold.py b/mne_nirs/io/fold/tests/test_fold.py index 49d5eebaf..8ebed78d0 100644 --- a/mne_nirs/io/fold/tests/test_fold.py +++ b/mne_nirs/io/fold/tests/test_fold.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Author: Robert Luke # # License: BSD (3-clause) @@ -6,33 +5,36 @@ from pathlib import Path from shutil import copyfile +import mne import numpy as np -from numpy.testing import assert_allclose import pandas as pd import pytest - -import mne from mne.channels import make_standard_montage from mne.channels.montage import transform_to_head from mne.datasets.testing import data_path, requires_testing_data -from mne.io import read_raw_nirx, read_fiducials +from mne.io import read_fiducials, read_raw_nirx +from numpy.testing import assert_allclose -from mne_nirs.io.fold._fold import _generate_montage_locations,\ - _find_closest_standard_location, _read_fold_xls from mne_nirs.io import fold_landmark_specificity from mne_nirs.io.fold import fold_channel_specificity +from mne_nirs.io.fold._fold import ( + _find_closest_standard_location, + _generate_montage_locations, + _read_fold_xls, +) thisfile = Path(__file__).parent.resolve() foldfile = thisfile / "data" / "example.xls" # 
diff --git a/mne_nirs/io/fold/tests/test_fold.py b/mne_nirs/io/fold/tests/test_fold.py
index 49d5eebaf..8ebed78d0 100644
--- a/mne_nirs/io/fold/tests/test_fold.py
+++ b/mne_nirs/io/fold/tests/test_fold.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 # Author: Robert Luke
 #
 # License: BSD (3-clause)
@@ -6,33 +5,36 @@
 from pathlib import Path
 from shutil import copyfile
 
+import mne
 import numpy as np
-from numpy.testing import assert_allclose
 import pandas as pd
 import pytest
-
-import mne
 from mne.channels import make_standard_montage
 from mne.channels.montage import transform_to_head
 from mne.datasets.testing import data_path, requires_testing_data
-from mne.io import read_raw_nirx, read_fiducials
+from mne.io import read_fiducials, read_raw_nirx
+from numpy.testing import assert_allclose
 
-from mne_nirs.io.fold._fold import _generate_montage_locations,\
-    _find_closest_standard_location, _read_fold_xls
 from mne_nirs.io import fold_landmark_specificity
 from mne_nirs.io.fold import fold_channel_specificity
+from mne_nirs.io.fold._fold import (
+    _find_closest_standard_location,
+    _generate_montage_locations,
+    _read_fold_xls,
+)
 
 thisfile = Path(__file__).parent.resolve()
 foldfile = thisfile / "data" / "example.xls"
 
 # https://github.com/mne-tools/mne-testing-data/pull/72
-fname_nirx_15_3_short = Path(data_path(download=False)) / \
-    'NIRx' / 'nirscout' / 'nirx_15_3_recording'
+fname_nirx_15_3_short = (
+    Path(data_path(download=False)) / "NIRx" / "nirscout" / "nirx_15_3_recording"
+)
 
-pytest.importorskip('xlrd', '1.0')
+pytest.importorskip("xlrd", "1.0")
 
 
-@pytest.mark.parametrize('fold_files', (str, None, list))
+@pytest.mark.parametrize("fold_files", (str, None, list))
 def test_channel_specificity(monkeypatch, tmp_path, fold_files):
     raw = read_raw_nirx(fname_nirx_15_3_short, preload=True)
     raw.pick(range(2))
@@ -45,60 +47,60 @@ def test_channel_specificity(monkeypatch, tmp_path, fold_files):
         n_want *= 2
     else:
         assert fold_files is None
-        monkeypatch.setenv('MNE_NIRS_FOLD_PATH', str(tmp_path))
+        monkeypatch.setenv("MNE_NIRS_FOLD_PATH", str(tmp_path))
         assert len(kwargs) == 0
-        with pytest.raises(FileNotFoundError, match=r'fold_files\[0\] does.*'):
+        with pytest.raises(FileNotFoundError, match=r"fold_files\[0\] does.*"):
             fold_channel_specificity(raw)
         n_want *= 2
-    copyfile(foldfile, tmp_path / '10-10.xls')
-    copyfile(foldfile, tmp_path / '10-5.xls')
+    copyfile(foldfile, tmp_path / "10-10.xls")
+    copyfile(foldfile, tmp_path / "10-5.xls")
     res = fold_channel_specificity(raw, **kwargs)
     assert len(res) == 2
     assert res[0].shape == (n_want, 14)
-    montage = make_standard_montage(
-        'standard_1005', head_size=0.09700884729534559)
+    montage = make_standard_montage("standard_1005", head_size=0.09700884729534559)
     fids = read_fiducials(
-        Path(mne.__file__).parent / 'data' / 'fsaverage' /
-        'fsaverage-fiducials.fif')[0]
+        Path(mne.__file__).parent / "data" / "fsaverage" / "fsaverage-fiducials.fif"
+    )[0]
     for f in fids:
-        f['coord_frame'] = montage.dig[0]['coord_frame']
+        f["coord_frame"] = montage.dig[0]["coord_frame"]
     montage.dig[:3] = fids
-    S, D = raw.ch_names[0].split()[0].split('_')
-    assert S == 'S1' and D == 'D2'
-    montage.rename_channels({'PO8': S, 'P6': D})  # not in the tables!
+    S, D = raw.ch_names[0].split()[0].split("_")
+    assert S == "S1" and D == "D2"
+    montage.rename_channels({"PO8": S, "P6": D})  # not in the tables!
     # taken from standard_1020.elc
-    s_mri = np.array([55.6666, -97.6251, 2.7300]) / 1000.
-    d_mri = np.array([67.8877, -75.9043, 28.0910]) / 1000.
-    trans = mne.transforms._get_trans('fsaverage', 'mri', 'head')[0]
-    ch_pos = montage.get_positions()['ch_pos']
+    s_mri = np.array([55.6666, -97.6251, 2.7300]) / 1000.0
+    d_mri = np.array([67.8877, -75.9043, 28.0910]) / 1000.0
+    trans = mne.transforms._get_trans("fsaverage", "mri", "head")[0]
+    ch_pos = montage.get_positions()["ch_pos"]
     assert_allclose(ch_pos[S], s_mri, atol=1e-6)
     assert_allclose(ch_pos[D], d_mri, atol=1e-6)
     raw.set_montage(montage)
     montage = transform_to_head(montage)
     s_head = mne.transforms.apply_trans(trans, s_mri)
     d_head = mne.transforms.apply_trans(trans, d_mri)
-    assert_allclose(montage._get_ch_pos()['S1'], s_head, atol=1e-6)
-    assert_allclose(montage._get_ch_pos()['D2'], d_head, atol=1e-6)
-    for ch in raw.info['chs']:
-        assert_allclose(ch['loc'][3:6], s_head, atol=1e-6)
-        assert_allclose(ch['loc'][6:9], d_head, atol=1e-6)
+    assert_allclose(montage._get_ch_pos()["S1"], s_head, atol=1e-6)
+    assert_allclose(montage._get_ch_pos()["D2"], d_head, atol=1e-6)
+    for ch in raw.info["chs"]:
+        assert_allclose(ch["loc"][3:6], s_head, atol=1e-6)
+        assert_allclose(ch["loc"][6:9], d_head, atol=1e-6)
     res_1 = fold_channel_specificity(raw, **kwargs)[0]
     assert res_1.shape == (0, 14)
     # TODO: This is wrong, should be PO8 not PO8h, and distance should be 0 mm!
-    with pytest.warns(RuntimeWarning, match='.*PO8h?/P6.*TP8/T8.*'):
+    with pytest.warns(RuntimeWarning, match=".*PO8h?/P6.*TP8/T8.*"):
         res_1 = fold_channel_specificity(raw, interpolate=True, **kwargs)[0]
     montage.rename_channels({S: D, D: S})  # reversed
-    with pytest.warns(RuntimeWarning, match='.*PO8h?/P6.*TP8/T8.*'):
+    with pytest.warns(RuntimeWarning, match=".*PO8h?/P6.*TP8/T8.*"):
         res_2 = fold_channel_specificity(raw, interpolate=True, **kwargs)[0]
     # We should check the whole thing, but this is probably good enough
-    assert (res_1['Specificity'] == res_2['Specificity']).all()
+    assert (res_1["Specificity"] == res_2["Specificity"]).all()
 
 
 def test_landmark_specificity():
     raw = read_raw_nirx(fname_nirx_15_3_short, preload=True)
-    with pytest.warns(RuntimeWarning, match='No fOLD table entry'):
-        res = fold_landmark_specificity(raw, "L Superior Frontal Gyrus",
-                                        [foldfile], interpolate=True)
+    with pytest.warns(RuntimeWarning, match="No fOLD table entry"):
+        res = fold_landmark_specificity(
+            raw, "L Superior Frontal Gyrus", [foldfile], interpolate=True
+        )
     assert len(res) == len(raw.ch_names)
     assert np.max(res) <= 100
     assert np.min(res) >= 0
@@ -111,20 +113,17 @@ def test_fold_workflow():
     channel_of_interest = raw.copy().pick(1)
 
     # Get source and detector labels
-    source_locs = channel_of_interest.info['chs'][0]['loc'][3:6]
-    source_label = _find_closest_standard_location(source_locs,
-                                                   reference_locations)
+    source_locs = channel_of_interest.info["chs"][0]["loc"][3:6]
+    source_label = _find_closest_standard_location(source_locs, reference_locations)
     assert source_label == "T7"
-    detector_locs = channel_of_interest.info['chs'][0]['loc'][6:9]
-    detector_label = _find_closest_standard_location(detector_locs,
-                                                     reference_locations)
+    detector_locs = channel_of_interest.info["chs"][0]["loc"][6:9]
+    detector_label = _find_closest_standard_location(detector_locs, reference_locations)
     assert detector_label == "TP7"
 
     # Find correct fOLD elements
     tbl = _read_fold_xls(foldfile, atlas="Juelich")
-    tbl = tbl.query("Source == @source_label").\
-        query("Detector == @detector_label")
+    tbl = tbl.query("Source == @source_label").query("Detector == @detector_label")
 
     # Query region of interest
     specificity = tbl.query("Landmark == 'L Mid Orbital Gyrus'")["Specificity"]
@@ -135,8 +134,7 @@ def test_fold_reader():
     tbl = _read_fold_xls(foldfile, atlas="Juelich")
     assert isinstance(tbl, pd.DataFrame)
     assert tbl.shape == (11, 10)
-    assert "L Superior Frontal Gyrus" in \
-        list(tbl["Landmark"])
+    assert "L Superior Frontal Gyrus" in list(tbl["Landmark"])
 
 
 @requires_testing_data
@@ -147,30 +145,45 @@ def test_label_finder():
 
     # Test central head position source
     raw_tmp = raw.copy().pick(25)
-    assert _find_closest_standard_location(
-        raw_tmp.info['chs'][0]['loc'][3:6],
-        reference_locations) == "Cz"
+    assert (
+        _find_closest_standard_location(
+            raw_tmp.info["chs"][0]["loc"][3:6], reference_locations
+        )
+        == "Cz"
+    )
 
     # Test right auditory position detector
     raw_tmp = raw.copy().pick(4)
-    assert _find_closest_standard_location(
-        raw_tmp.info['chs'][0]['loc'][6:9],
-        reference_locations) == "T8"
+    assert (
+        _find_closest_standard_location(
+            raw_tmp.info["chs"][0]["loc"][6:9], reference_locations
+        )
+        == "T8"
+    )
 
     # Test right auditory position source
     raw_tmp = raw.copy().pick(4)
-    assert _find_closest_standard_location(
-        raw_tmp.info['chs'][0]['loc'][3:6],
-        reference_locations) == "TP8"
+    assert (
+        _find_closest_standard_location(
+            raw_tmp.info["chs"][0]["loc"][3:6], reference_locations
+        )
+        == "TP8"
+    )
 
     # Test left auditory position source
     raw_tmp = raw.copy().pick(1)
-    assert _find_closest_standard_location(
-        raw_tmp.info['chs'][0]['loc'][3:6],
-        reference_locations) == "T7"
+    assert (
+        _find_closest_standard_location(
+            raw_tmp.info["chs"][0]["loc"][3:6], reference_locations
+        )
+        == "T7"
+    )
 
     # Test left auditory position detector
     raw_tmp = raw.copy().pick(1)
-    assert _find_closest_standard_location(
-        raw_tmp.info['chs'][0]['loc'][6:9],
-        reference_locations) == "TP7"
+    assert (
+        _find_closest_standard_location(
+            raw_tmp.info["chs"][0]["loc"][6:9], reference_locations
+        )
+        == "TP7"
+    )
""" - - with h5py.File(fname, 'r') as dat: - if 'nirs' in dat: + with h5py.File(fname, "r") as dat: + if "nirs" in dat: basename = "nirs" - elif 'nirs1' in dat: + elif "nirs1" in dat: basename = "nirs1" else: raise RuntimeError("Data does not contain nirs field") all_keys = list(dat.get(basename).keys()) - aux_keys = [i for i in all_keys if i.startswith('aux')] - aux_names = [_decode_name(dat.get(f'{basename}/{k}/name')) - for k in aux_keys] + aux_keys = [i for i in all_keys if i.startswith("aux")] + aux_names = [_decode_name(dat.get(f"{basename}/{k}/name")) for k in aux_keys] logging.debug(f"Found auxiliary channels {aux_names}") - d = {'times': raw.times} + d = {"times": raw.times} for idx, aux in enumerate(aux_keys): - aux_data = np.array(dat.get(f'{basename}/{aux}/dataTimeSeries')) - aux_time = np.array(dat.get(f'{basename}/{aux}/time')) - aux_data_interp = interpolate.interp1d(aux_time, aux_data, - axis=0, bounds_error=False, - fill_value='extrapolate') + aux_data = np.array(dat.get(f"{basename}/{aux}/dataTimeSeries")) + aux_time = np.array(dat.get(f"{basename}/{aux}/time")) + aux_data_interp = interpolate.interp1d( + aux_time, aux_data, axis=0, bounds_error=False, fill_value="extrapolate" + ) aux_data_matched_to_raw = aux_data_interp(raw.times) d[aux_names[idx]] = aux_data_matched_to_raw df = DataFrame(data=d) - df = df.set_index('times') + df = df.set_index("times") return df diff --git a/mne_nirs/io/snirf/_snirf.py b/mne_nirs/io/snirf/_snirf.py index 8bfb9f989..e21554de0 100644 --- a/mne_nirs/io/snirf/_snirf.py +++ b/mne_nirs/io/snirf/_snirf.py @@ -3,18 +3,17 @@ # License: BSD (3-clause) import datetime +import re import h5py as h5py -import re import numpy as np -from mne.io.pick import _picks_to_idx -from mne.transforms import apply_trans, _get_trans from mne.channels import make_standard_montage - +from mne.io.pick import _picks_to_idx +from mne.transforms import _get_trans, apply_trans # The currently-implemented spec can be found here: # https://raw.githubusercontent.com/fNIRS/snirf/v1.1/snirf_specification.md -SPEC_FORMAT_VERSION = '1.1' +SPEC_FORMAT_VERSION = "1.1" def write_raw_snirf(raw, fname, add_montage=False): @@ -45,22 +44,23 @@ def write_raw_snirf(raw, fname, add_montage=False): Should the montage be added as landmarks to facilitate compatibility with AtlasViewer. 
""" - - supported_types = ['fnirs_cw_amplitude', 'fnirs_od', 'hbo', 'hbr'] + supported_types = ["fnirs_cw_amplitude", "fnirs_od", "hbo", "hbr"] picks = _picks_to_idx(raw.info, supported_types, exclude=[]) - assert len(picks) == len(raw.ch_names),\ - 'Data must only be of type fnirs_cw_amplitude, fnirs_od, hbo or hbr' - if ('fnirs_cw_amplitude' in raw) or ('fnirs_od' in raw): - assert len(np.unique(raw.info.get_channel_types())) == 1,\ - 'All channels must be of the same type' - elif ('hbo' in raw) or ('hbr' in raw): - assert len(np.unique(raw.info.get_channel_types())) <= 2,\ - 'Channels must be of type hbo and hbr' - - with h5py.File(fname, 'w') as f: - nirs = f.create_group('/nirs') - f.create_dataset('formatVersion', - data=_str_encode(SPEC_FORMAT_VERSION)) + assert len(picks) == len( + raw.ch_names + ), "Data must only be of type fnirs_cw_amplitude, fnirs_od, hbo or hbr" + if ("fnirs_cw_amplitude" in raw) or ("fnirs_od" in raw): + assert ( + len(np.unique(raw.info.get_channel_types())) == 1 + ), "All channels must be of the same type" + elif ("hbo" in raw) or ("hbr" in raw): + assert ( + len(np.unique(raw.info.get_channel_types())) <= 2 + ), "Channels must be of type hbo and hbr" + + with h5py.File(fname, "w") as f: + nirs = f.create_group("/nirs") + f.create_dataset("formatVersion", data=_str_encode(SPEC_FORMAT_VERSION)) _add_metadata_tags(raw, nirs) _add_single_data_block(raw, nirs) @@ -76,11 +76,11 @@ def _str_encode(str_val): str_val : str The string to encode. """ - return str_val.encode('UTF-8') + return str_val.encode("UTF-8") def _add_metadata_tags(raw, nirs): - """Creates and adds elements to the nirs metaDataTags group. + """Create and add elements to the nirs metaDataTags group. Parameters ---------- @@ -89,48 +89,44 @@ def _add_metadata_tags(raw, nirs): nirs : hpy5.Group The root hdf5 nirs group to which the metadata should beadded. 
""" - metadata_tags = nirs.create_group('metaDataTags') + metadata_tags = nirs.create_group("metaDataTags") # Store measurement - datestr = raw.info['meas_date'].strftime('%Y-%m-%d') - timestr = raw.info['meas_date'].strftime('%H:%M:%SZ') - metadata_tags.create_dataset('MeasurementDate', - data=_str_encode(datestr)) - metadata_tags.create_dataset('MeasurementTime', - data=_str_encode(timestr)) + datestr = raw.info["meas_date"].strftime("%Y-%m-%d") + timestr = raw.info["meas_date"].strftime("%H:%M:%SZ") + metadata_tags.create_dataset("MeasurementDate", data=_str_encode(datestr)) + metadata_tags.create_dataset("MeasurementTime", data=_str_encode(timestr)) # Store demographic info - subject_id = raw.info['subject_info']['first_name'] - metadata_tags.create_dataset('SubjectID', data=_str_encode(subject_id)) + subject_id = raw.info["subject_info"]["first_name"] + metadata_tags.create_dataset("SubjectID", data=_str_encode(subject_id)) # Store the units of measurement - metadata_tags.create_dataset('LengthUnit', data=_str_encode('m')) - metadata_tags.create_dataset('TimeUnit', data=_str_encode('s')) - metadata_tags.create_dataset('FrequencyUnit', data=_str_encode('Hz')) + metadata_tags.create_dataset("LengthUnit", data=_str_encode("m")) + metadata_tags.create_dataset("TimeUnit", data=_str_encode("s")) + metadata_tags.create_dataset("FrequencyUnit", data=_str_encode("Hz")) # Add non standard (but allowed) custom metadata tags - if 'birthday' in raw.info['subject_info']: - birthday = datetime.date(*raw.info['subject_info']['birthday']) - birthstr = birthday.strftime('%Y-%m-%d') - metadata_tags.create_dataset('DateOfBirth', - data=[_str_encode(birthstr)]) - if 'middle_name' in raw.info['subject_info']: - middle_name = raw.info['subject_info']['middle_name'] - metadata_tags.create_dataset('middleName', - data=[_str_encode(middle_name)]) - if 'last_name' in raw.info['subject_info']: - last_name = raw.info['subject_info']['last_name'] - metadata_tags.create_dataset('lastName', data=[_str_encode(last_name)]) - if 'sex' in raw.info['subject_info']: - sex = str(int(raw.info['subject_info']['sex'])) - metadata_tags.create_dataset('sex', data=[_str_encode(sex)]) - if raw.info['dig'] is not None: - coord_frame_id = int(raw.info['dig'][0].get('coord_frame')) - metadata_tags.create_dataset('MNE_coordFrame', data=[coord_frame_id]) + if "birthday" in raw.info["subject_info"]: + birthday = datetime.date(*raw.info["subject_info"]["birthday"]) + birthstr = birthday.strftime("%Y-%m-%d") + metadata_tags.create_dataset("DateOfBirth", data=[_str_encode(birthstr)]) + if "middle_name" in raw.info["subject_info"]: + middle_name = raw.info["subject_info"]["middle_name"] + metadata_tags.create_dataset("middleName", data=[_str_encode(middle_name)]) + if "last_name" in raw.info["subject_info"]: + last_name = raw.info["subject_info"]["last_name"] + metadata_tags.create_dataset("lastName", data=[_str_encode(last_name)]) + if "sex" in raw.info["subject_info"]: + sex = str(int(raw.info["subject_info"]["sex"])) + metadata_tags.create_dataset("sex", data=[_str_encode(sex)]) + if raw.info["dig"] is not None: + coord_frame_id = int(raw.info["dig"][0].get("coord_frame")) + metadata_tags.create_dataset("MNE_coordFrame", data=[coord_frame_id]) def _add_single_data_block(raw, nirs): - """Adds the data from raw to the nirs data1 group. + """Add the data from raw to the nirs data1 group. While SNIRF supports multiple datablocks, this writer only supports a single data block named data1. 
@@ -142,15 +138,15 @@ def _add_single_data_block(raw, nirs): nirs : hpy5.Group The root hdf5 nirs group to which the data should be added. """ - data_block = nirs.create_group('data1') - data_block.create_dataset('dataTimeSeries', data=raw.get_data().T) - data_block.create_dataset('time', data=raw.times) + data_block = nirs.create_group("data1") + data_block.create_dataset("dataTimeSeries", data=raw.get_data().T) + data_block.create_dataset("time", data=raw.times) _add_measurement_lists(raw, data_block) def _add_measurement_lists(raw, data_block): - """Adds the measurement list groups to the nirs data1 group. + """Add the measurement list groups to the nirs data1 group. Parameters ---------- @@ -167,40 +163,38 @@ def _add_measurement_lists(raw, data_block): raw_wavelengths = [c["loc"][9] for c in raw.info["chs"]] for idx, ch_name in enumerate(raw.ch_names, start=1): - ml_id = f'measurementList{idx}' + ml_id = f"measurementList{idx}" ch_group = data_block.require_group(ml_id) source_idx = sources.index(_extract_source(ch_name)) + 1 detector_idx = detectors.index(_extract_detector(ch_name)) + 1 wavelength_idx = wavelengths.index(raw_wavelengths[idx - 1]) + 1 - ch_group.create_dataset('sourceIndex', data=source_idx, dtype='int32') - ch_group.create_dataset('detectorIndex', data=detector_idx, - dtype='int32') - ch_group.create_dataset('wavelengthIndex', data=wavelength_idx, - dtype='int32') + ch_group.create_dataset("sourceIndex", data=source_idx, dtype="int32") + ch_group.create_dataset("detectorIndex", data=detector_idx, dtype="int32") + ch_group.create_dataset("wavelengthIndex", data=wavelength_idx, dtype="int32") # The data type coding is described at # https://github.com/fNIRS/snirf/blob/master/snirf_specification.md#appendix # The currently implemented data types are: # 1 = Continuous Wave # 99999 = Processed - data_type = 1 if raw_types[idx - 1] == 'fnirs_cw_amplitude' else 99999 + data_type = 1 if raw_types[idx - 1] == "fnirs_cw_amplitude" else 99999 - ch_group.create_dataset('dataType', data=data_type, dtype='int32') - ch_group.create_dataset('dataTypeIndex', data=1, dtype='int32') - if raw_types[idx - 1] == 'fnirs_od': - ch_group.create_dataset('dataTypeLabel', data="dOD") - elif raw_types[idx - 1] == 'fnirs_cw_amplitude': + ch_group.create_dataset("dataType", data=data_type, dtype="int32") + ch_group.create_dataset("dataTypeIndex", data=1, dtype="int32") + if raw_types[idx - 1] == "fnirs_od": + ch_group.create_dataset("dataTypeLabel", data="dOD") + elif raw_types[idx - 1] == "fnirs_cw_amplitude": # The SNIRF specification does not specify a label for this type continue - elif raw_types[idx - 1] == 'hbo': - ch_group.create_dataset('dataTypeLabel', data="HbO") - elif raw_types[idx - 1] == 'hbr': - ch_group.create_dataset('dataTypeLabel', data="HbR") + elif raw_types[idx - 1] == "hbo": + ch_group.create_dataset("dataTypeLabel", data="HbO") + elif raw_types[idx - 1] == "hbr": + ch_group.create_dataset("dataTypeLabel", data="HbR") def _add_probe_info(raw, nirs, add_montage): - """Adds details of the probe to the nirs group. + """Add details of the probe to the nirs group. 
Parameters ---------- @@ -216,14 +210,14 @@ def _add_probe_info(raw, nirs, add_montage): detectors = _get_unique_detector_list(raw) wavelengths = _get_unique_wavelength_list(raw) - probe = nirs.create_group('probe') + probe = nirs.create_group("probe") # Store source/detector/wavelength info - encoded_source_labels = [_str_encode(f'S{src}') for src in sources] - encoded_detector_labels = [_str_encode(f'D{det}') for det in detectors] - probe.create_dataset('sourceLabels', data=encoded_source_labels) - probe.create_dataset('detectorLabels', data=encoded_detector_labels) - probe.create_dataset('wavelengths', data=wavelengths) + encoded_source_labels = [_str_encode(f"S{src}") for src in sources] + encoded_detector_labels = [_str_encode(f"D{det}") for det in detectors] + probe.create_dataset("sourceLabels", data=encoded_source_labels) + probe.create_dataset("detectorLabels", data=encoded_detector_labels) + probe.create_dataset("wavelengths", data=wavelengths) # Create 3d locs and store ch_sources = [_extract_source(ch) for ch in raw.ch_names] @@ -232,20 +226,20 @@ def _add_probe_info(raw, nirs, add_montage): detlocs = np.empty((len(detectors), 3)) for i, src in enumerate(sources): idx = ch_sources.index(src) - srclocs[i, :] = raw.info['chs'][idx]['loc'][3:6] + srclocs[i, :] = raw.info["chs"][idx]["loc"][3:6] for i, det in enumerate(detectors): idx = ch_detectors.index(det) - detlocs[i, :] = raw.info['chs'][idx]['loc'][6:9] - probe.create_dataset('sourcePos3D', data=srclocs) - probe.create_dataset('detectorPos3D', data=detlocs) + detlocs[i, :] = raw.info["chs"][idx]["loc"][6:9] + probe.create_dataset("sourcePos3D", data=srclocs) + probe.create_dataset("detectorPos3D", data=detlocs) # Store probe landmarks - if raw.info['dig'] is not None: + if raw.info["dig"] is not None: _store_probe_landmarks(raw, probe, add_montage) def _store_probe_landmarks(raw, probe, add_montage): - """Adds the probe landmarks to the probe group. + """Add the probe landmarks to the probe group. The SNIRF specification provides flexibility around what is stored in this field. Some software expect specific @@ -272,38 +266,36 @@ def _store_probe_landmarks(raw, probe, add_montage): Should the montage be added as landmarks to facilitate compatibility with AtlasViewer. 
""" - diglocs = np.empty((len(raw.info['dig']), 3)) + diglocs = np.empty((len(raw.info["dig"]), 3)) digname = list() - for idx, dig in enumerate(raw.info['dig']): - ident = re.match(r'\d+ \(FIFFV_POINT_(\w+)\)', - str(dig.get('ident'))) + for idx, dig in enumerate(raw.info["dig"]): + ident = re.match(r"\d+ \(FIFFV_POINT_(\w+)\)", str(dig.get("ident"))) if ident is not None: digname.append(ident[1]) else: digname.append(f"HP_{str(dig.get('ident'))}") - diglocs[idx, :] = dig.get('r') + diglocs[idx, :] = dig.get("r") if add_montage: # First, get the template positions in mni fsaverage space - montage = make_standard_montage( - 'standard_1020', head_size=0.09700884729534559) + montage = make_standard_montage("standard_1020", head_size=0.09700884729534559) ch_names = montage.ch_names montage_locs = np.array(list(montage._get_ch_pos().values())) # montage_locs = np.array([d['r'] for d in montage.dig]) - head_mri_t, _ = _get_trans('fsaverage', 'mri', 'head') + head_mri_t, _ = _get_trans("fsaverage", "mri", "head") locations_head = apply_trans(head_mri_t, montage_locs) digname.extend(ch_names) diglocs = np.concatenate((diglocs, locations_head)) digname = [_str_encode(d) for d in digname] - probe.create_dataset('landmarkPos3D', data=diglocs) - probe.create_dataset('landmarkLabels', data=digname) + probe.create_dataset("landmarkPos3D", data=diglocs) + probe.create_dataset("landmarkLabels", data=digname) def _add_stim_info(raw, nirs): - """Adds details of the stimuli to the nirs group. + """Add details of the stimuli to the nirs group. Parameters ---------- @@ -315,19 +307,21 @@ def _add_stim_info(raw, nirs): # Convert MNE annotations to SNIRF stims descriptions = np.unique(raw.annotations.description) for idx, desc in enumerate(descriptions, start=1): - stim_group = nirs.create_group(f'stim{idx}') + stim_group = nirs.create_group(f"stim{idx}") trgs = np.where(raw.annotations.description == desc)[0] stims = np.zeros((len(trgs), 3)) for idx_t, trg in enumerate(trgs): - stims[idx_t, :] = [raw.annotations.onset[trg], - raw.annotations.duration[trg], - idx] - stim_group.create_dataset('data', data=stims) - stim_group.create_dataset('name', data=_str_encode(desc)) + stims[idx_t, :] = [ + raw.annotations.onset[trg], + raw.annotations.duration[trg], + idx, + ] + stim_group.create_dataset("data", data=stims) + stim_group.create_dataset("name", data=_str_encode(desc)) def _get_unique_source_list(raw): - """Returns the sorted list of distinct source ids. + """Return the sorted list of distinct source ids. Parameters ---------- @@ -339,7 +333,7 @@ def _get_unique_source_list(raw): def _get_unique_detector_list(raw): - """Returns the sorted list of distinct detector ids. + """Return the sorted list of distinct detector ids. Parameters ---------- @@ -351,7 +345,7 @@ def _get_unique_detector_list(raw): def _get_unique_wavelength_list(raw): - """Returns the sorted list of distinct wavelengths. + """Return the sorted list of distinct wavelengths. Parameters ---------- @@ -363,7 +357,7 @@ def _get_unique_wavelength_list(raw): def _match_channel_pattern(channel_name): - """Returns a regex match against the expected channel name format. + """Return a regex match against the expected channel name format. The returned match object contains three named groups: source, detector, and wavelength/type. If no match is found, a ValueError is raised. @@ -378,16 +372,16 @@ def _match_channel_pattern(channel_name): channel_name : str The name of the channel. 
""" - rgx = r'^S(?P\d+)_D(?P\d+) (?P[\w]+)$' + rgx = r"^S(?P\d+)_D(?P\d+) (?P[\w]+)$" match = re.fullmatch(rgx, channel_name) if match is None: - msg = f'channel name does not match expected pattern: {channel_name}' + msg = f"channel name does not match expected pattern: {channel_name}" raise ValueError(msg) return match def _extract_source(channel_name): - """Extracts and returns the source id from the channel name. + """Extract and return the source id from the channel name. The id is returned as an integer value. @@ -396,11 +390,11 @@ def _extract_source(channel_name): channel_name : str The name of the channel. """ - return int(_match_channel_pattern(channel_name).group('source')) + return int(_match_channel_pattern(channel_name).group("source")) def _extract_detector(channel_name): - """Extracts and returns the detector id from the channel name. + """Extract and return the detector id from the channel name. The id is returned as an integer value. @@ -409,4 +403,4 @@ def _extract_detector(channel_name): channel_name : str The name of the channel. """ - return int(_match_channel_pattern(channel_name).group('detector')) + return int(_match_channel_pattern(channel_name).group("detector")) diff --git a/mne_nirs/io/snirf/tests/test_snirf.py b/mne_nirs/io/snirf/tests/test_snirf.py index 1643e55ad..18f4df109 100644 --- a/mne_nirs/io/snirf/tests/test_snirf.py +++ b/mne_nirs/io/snirf/tests/test_snirf.py @@ -1,45 +1,44 @@ -# -*- coding: utf-8 -*- # Authors: Robert Luke # simplified BSD-3 license -import os.path as op import datetime +import os.path as op + import h5py -from numpy.testing import assert_allclose, assert_array_equal -import pytest import pandas as pd -from snirf import validateSnirf, Snirf - +import pytest from mne.datasets.testing import data_path, requires_testing_data +from mne.io import read_raw_nirx, read_raw_snirf +from mne.preprocessing.nirs import beer_lambert_law, optical_density from mne.utils import object_diff -from mne.io import read_raw_snirf, read_raw_nirx -from mne.preprocessing.nirs import optical_density, beer_lambert_law -from mne_nirs.io.snirf import write_raw_snirf, SPEC_FORMAT_VERSION, \ - read_snirf_aux_data -import mne_nirs.datasets.snirf_with_aux as aux +from numpy.testing import assert_allclose, assert_array_equal +from snirf import Snirf, validateSnirf -fname_nirx_15_0 = op.join(data_path(download=False), - 'NIRx', 'nirscout', 'nirx_15_0_recording') -fname_nirx_15_2 = op.join(data_path(download=False), - 'NIRx', 'nirscout', 'nirx_15_2_recording') -fname_nirx_15_2_short = op.join(data_path(download=False), - 'NIRx', 'nirscout', - 'nirx_15_2_recording_w_short') +import mne_nirs.datasets.snirf_with_aux as aux +from mne_nirs.io.snirf import SPEC_FORMAT_VERSION, read_snirf_aux_data, write_raw_snirf + +fname_nirx_15_0 = op.join( + data_path(download=False), "NIRx", "nirscout", "nirx_15_0_recording" +) +fname_nirx_15_2 = op.join( + data_path(download=False), "NIRx", "nirscout", "nirx_15_2_recording" +) +fname_nirx_15_2_short = op.join( + data_path(download=False), "NIRx", "nirscout", "nirx_15_2_recording_w_short" +) fname_snirf_aux = aux.data_path() -pytest.importorskip('h5py') +pytest.importorskip("h5py") @requires_testing_data -@pytest.mark.parametrize('fname', ( - fname_nirx_15_2_short, - fname_nirx_15_2, - fname_nirx_15_0 -)) +@pytest.mark.parametrize( + "fname", (fname_nirx_15_2_short, fname_nirx_15_2, fname_nirx_15_0) +) def test_snirf_write_raw(fname, tmpdir): """Test reading NIRX files.""" raw_orig = read_raw_nirx(fname, preload=True) - test_file = 
diff --git a/mne_nirs/io/snirf/tests/test_snirf.py b/mne_nirs/io/snirf/tests/test_snirf.py
index 1643e55ad..18f4df109 100644
--- a/mne_nirs/io/snirf/tests/test_snirf.py
+++ b/mne_nirs/io/snirf/tests/test_snirf.py
@@ -1,45 +1,44 @@
-# -*- coding: utf-8 -*-
 # Authors: Robert Luke
 # simplified BSD-3 license
 
-import os.path as op
 import datetime
+import os.path as op
+
 import h5py
-from numpy.testing import assert_allclose, assert_array_equal
-import pytest
 import pandas as pd
-from snirf import validateSnirf, Snirf
-
+import pytest
 from mne.datasets.testing import data_path, requires_testing_data
+from mne.io import read_raw_nirx, read_raw_snirf
+from mne.preprocessing.nirs import beer_lambert_law, optical_density
 from mne.utils import object_diff
-from mne.io import read_raw_snirf, read_raw_nirx
-from mne.preprocessing.nirs import optical_density, beer_lambert_law
-from mne_nirs.io.snirf import write_raw_snirf, SPEC_FORMAT_VERSION, \
-    read_snirf_aux_data
-import mne_nirs.datasets.snirf_with_aux as aux
+from numpy.testing import assert_allclose, assert_array_equal
+from snirf import Snirf, validateSnirf
 
-fname_nirx_15_0 = op.join(data_path(download=False),
-                          'NIRx', 'nirscout', 'nirx_15_0_recording')
-fname_nirx_15_2 = op.join(data_path(download=False),
-                          'NIRx', 'nirscout', 'nirx_15_2_recording')
-fname_nirx_15_2_short = op.join(data_path(download=False),
-                                'NIRx', 'nirscout',
-                                'nirx_15_2_recording_w_short')
+import mne_nirs.datasets.snirf_with_aux as aux
+from mne_nirs.io.snirf import SPEC_FORMAT_VERSION, read_snirf_aux_data, write_raw_snirf
+
+fname_nirx_15_0 = op.join(
+    data_path(download=False), "NIRx", "nirscout", "nirx_15_0_recording"
+)
+fname_nirx_15_2 = op.join(
+    data_path(download=False), "NIRx", "nirscout", "nirx_15_2_recording"
+)
+fname_nirx_15_2_short = op.join(
+    data_path(download=False), "NIRx", "nirscout", "nirx_15_2_recording_w_short"
+)
 
 fname_snirf_aux = aux.data_path()
 
-pytest.importorskip('h5py')
+pytest.importorskip("h5py")
 
 
 @requires_testing_data
-@pytest.mark.parametrize('fname', (
-    fname_nirx_15_2_short,
-    fname_nirx_15_2,
-    fname_nirx_15_0
-))
+@pytest.mark.parametrize(
+    "fname", (fname_nirx_15_2_short, fname_nirx_15_2, fname_nirx_15_0)
+)
 def test_snirf_write_raw(fname, tmpdir):
     """Test reading NIRX files."""
     raw_orig = read_raw_nirx(fname, preload=True)
-    test_file = tmpdir.join('test_raw.snirf')
+    test_file = tmpdir.join("test_raw.snirf")
     write_raw_snirf(raw_orig, test_file)
     raw = read_raw_snirf(test_file)
 
@@ -50,49 +49,51 @@ def test_snirf_write_raw(fname, tmpdir):
 
     # Check annotations are the same
     assert_allclose(raw.annotations.onset, raw_orig.annotations.onset)
-    assert_allclose([float(d) for d in raw.annotations.description],
-                    [float(d) for d in raw_orig.annotations.description])
+    assert_allclose(
+        [float(d) for d in raw.annotations.description],
+        [float(d) for d in raw_orig.annotations.description],
+    )
     assert_allclose(raw.annotations.duration, raw_orig.annotations.duration)
 
     # Check data is the same
     assert_allclose(raw.get_data(), raw_orig.get_data())
 
-    assert_array_equal(raw.info.get_channel_types(),
-                       raw_orig.info.get_channel_types())
+    assert_array_equal(raw.info.get_channel_types(), raw_orig.info.get_channel_types())
 
-    assert abs(raw_orig.info['meas_date'] - raw.info['meas_date']) < \
-        datetime.timedelta(seconds=1)
+    assert abs(raw_orig.info["meas_date"] - raw.info["meas_date"]) < datetime.timedelta(
+        seconds=1
+    )
 
     # Check info object is the same
     obj_diff = object_diff(raw.info, raw_orig.info)
-    diffs = ''
+    diffs = ""
     for line in obj_diff.splitlines():
-        if ('logno' not in line) and \
-                ('scanno' not in line) and \
-                ('his_id' not in line) and\
-                ('dig' not in line) and\
-                ('datetime mismatch' not in line):
+        if (
+            ("logno" not in line)
+            and ("scanno" not in line)
+            and ("his_id" not in line)
+            and ("dig" not in line)
+            and ("datetime mismatch" not in line)
+        ):
             # logno and scanno are not used in processing
-            diffs += f'\n{line}'
-    assert diffs == ''
+            diffs += f"\n{line}"
+    assert diffs == ""
 
     _verify_snirf_required_fields(test_file)
     _verify_snirf_version_str(test_file)
 
 
 @requires_testing_data
-@pytest.mark.parametrize('fname', (
-    fname_nirx_15_2_short,
-    fname_nirx_15_2,
-    fname_nirx_15_0
-))
+@pytest.mark.parametrize(
+    "fname", (fname_nirx_15_2_short, fname_nirx_15_2, fname_nirx_15_0)
+)
 def test_snirf_write_optical_density(fname, tmpdir):
     """Test writing optical density SNIRF files."""
     raw_nirx = read_raw_nirx(fname, preload=True)
     od_orig = optical_density(raw_nirx)
-    test_file = tmpdir.join('test_od.snirf')
+    test_file = tmpdir.join("test_od.snirf")
     write_raw_snirf(od_orig, test_file)
     od = read_raw_snirf(test_file)
-    assert 'fnirs_od' in od
+    assert "fnirs_od" in od
 
     result = validateSnirf(str(test_file))
     if result.is_valid():
@@ -101,17 +102,20 @@ def test_snirf_write_optical_density(fname, tmpdir):
 
 
 @requires_testing_data
-@pytest.mark.parametrize('fname', (
-    fname_nirx_15_2,
-    fname_nirx_15_2_short,
-))
+@pytest.mark.parametrize(
+    "fname",
+    (
+        fname_nirx_15_2,
+        fname_nirx_15_2_short,
+    ),
+)
 def test_snirf_write_haemoglobin(fname, tmpdir):
     """Test haemoglobin writing and reading."""
     raw_nirx = read_raw_nirx(fname, preload=True)
     od_orig = optical_density(raw_nirx)
     hb_orig = beer_lambert_law(od_orig)
     assert hb_orig.annotations.duration[0] == 1
-    test_file = tmpdir.join('test_raw_hb_no_mod.snirf')
+    test_file = tmpdir.join("test_raw_hb_no_mod.snirf")
     write_raw_snirf(hb_orig, test_file)
 
     result = validateSnirf(str(test_file))
@@ -121,7 +125,7 @@ def test_snirf_write_haemoglobin(fname, tmpdir):
 
     # HBO
 
-    test_file = tmpdir.join('test_raw_hbo_no_mod.snirf')
+    test_file = tmpdir.join("test_raw_hbo_no_mod.snirf")
     write_raw_snirf(hb_orig.copy().pick("hbo"), test_file)
 
     result = validateSnirf(str(test_file))
@@ -131,7 +135,7 @@ def test_snirf_write_haemoglobin(fname, tmpdir):
 
     # HBR
 
-    test_file = tmpdir.join('test_raw_hbr_no_mod.snirf')
+    test_file = tmpdir.join("test_raw_hbr_no_mod.snirf")
     write_raw_snirf(hb_orig.copy().pick("hbr"), test_file)
 
     result = validateSnirf(str(test_file))
@@ -141,35 +145,31 @@ def test_snirf_write_haemoglobin(fname, tmpdir):
 
 
 @requires_testing_data
-@pytest.mark.parametrize('fname', (
-    fname_nirx_15_2,
-))
+@pytest.mark.parametrize("fname", (fname_nirx_15_2,))
 def test_snirf_nobday(fname, tmpdir):
     """Ensure writing works when no birthday is present."""
     raw_orig = read_raw_nirx(fname, preload=True)
-    raw_orig.info['subject_info'].pop('birthday', None)
-    test_file = tmpdir.join('test_raw.snirf')
+    raw_orig.info["subject_info"].pop("birthday", None)
+    test_file = tmpdir.join("test_raw.snirf")
     write_raw_snirf(raw_orig, test_file)
     raw = read_raw_snirf(test_file)
     assert_allclose(raw.get_data(), raw_orig.get_data())
 
 
 @requires_testing_data
-@pytest.mark.parametrize('fname', (
-    fname_nirx_15_2,
-))
+@pytest.mark.parametrize("fname", (fname_nirx_15_2,))
 def test_snirf_extra_atlasviewer(fname, tmpdir):
     """Ensure writing atlasviewer landmarks."""
     raw_orig = read_raw_nirx(fname, preload=True)
-    test_file = tmpdir.join('test_raw.snirf')
+    test_file = tmpdir.join("test_raw.snirf")
     write_raw_snirf(raw_orig, test_file, add_montage=False)
     raw = read_raw_snirf(test_file)
-    assert len([i['ident'] for i in raw.info['dig']]) == 35
+    assert len([i["ident"] for i in raw.info["dig"]]) == 35
 
     write_raw_snirf(raw_orig, test_file, add_montage=True)
     raw = read_raw_snirf(test_file)
-    assert len([i['ident'] for i in raw.info['dig']]) == 129
+    assert len([i["ident"] for i in raw.info["dig"]]) == 129
 
     snirf = Snirf(str(test_file), "r")
     assert len(snirf.nirs[0].probe.landmarkLabels) == 129
     assert "Fpz" in snirf.nirs[0].probe.landmarkLabels
@@ -178,62 +178,68 @@ def test_snirf_extra_atlasviewer(fname, tmpdir):
 
 
 def _verify_snirf_required_fields(test_file):
-    """Tests that all required fields are present.
+    """Test that all required fields are present.
 
     Uses version 1.1 of the spec:
     https://raw.githubusercontent.com/fNIRS/snirf/v1.1/snirf_specification.md
     """
     required_metadata_fields = [
-        'SubjectID', 'MeasurementDate', 'MeasurementTime',
-        'LengthUnit', 'TimeUnit', 'FrequencyUnit'
+        "SubjectID",
+        "MeasurementDate",
+        "MeasurementTime",
+        "LengthUnit",
+        "TimeUnit",
+        "FrequencyUnit",
     ]
     required_measurement_list_fields = [
-        'sourceIndex', 'detectorIndex', 'wavelengthIndex',
-        'dataType', 'dataTypeIndex'
+        "sourceIndex",
+        "detectorIndex",
+        "wavelengthIndex",
+        "dataType",
+        "dataTypeIndex",
     ]
 
-    with h5py.File(test_file, 'r') as h5:
+    with h5py.File(test_file, "r") as h5:
         # Verify required base fields
-        assert 'nirs' in h5
-        assert 'formatVersion' in h5
+        assert "nirs" in h5
+        assert "formatVersion" in h5
 
         # Verify required metadata fields
-        assert 'metaDataTags' in h5['/nirs']
-        metadata = h5['/nirs/metaDataTags']
+        assert "metaDataTags" in h5["/nirs"]
+        metadata = h5["/nirs/metaDataTags"]
         for field in required_metadata_fields:
             assert field in metadata
 
         # Verify required data fields
-        assert 'data1' in h5['/nirs']
-        data1 = h5['/nirs/data1']
-        assert 'dataTimeSeries' in data1
-        assert 'time' in data1
+        assert "data1" in h5["/nirs"]
+        data1 = h5["/nirs/data1"]
+        assert "dataTimeSeries" in data1
+        assert "time" in data1
 
         # Verify required fields for each measurementList
-        measurement_lists = [k for k in data1.keys()
-                             if k.startswith('measurementList')]
+        measurement_lists = [k for k in data1.keys() if k.startswith("measurementList")]
         for ml in measurement_lists:
             for field in required_measurement_list_fields:
                 assert field in data1[ml]
 
         # Verify required fields for each stimulus
-        stims = [k for k in h5['/nirs'].keys() if k.startswith('stim')]
+        stims = [k for k in h5["/nirs"].keys() if k.startswith("stim")]
         for stim in stims:
-            assert 'name' in h5['/nirs'][stim]
-            assert 'data' in h5['/nirs'][stim]
+            assert "name" in h5["/nirs"][stim]
+            assert "data" in h5["/nirs"][stim]
 
         # Verify probe fields
-        assert 'probe' in h5['/nirs']
-        probe = h5['/nirs/probe']
-        assert 'wavelengths' in probe
-        assert 'sourcePos3D' in probe or 'sourcePos2D' in probe
-        assert 'detectorPos3D' in probe or 'detectorPos2D' in probe
+        assert "probe" in h5["/nirs"]
+        probe = h5["/nirs/probe"]
+        assert "wavelengths" in probe
+        assert "sourcePos3D" in probe or "sourcePos2D" in probe
+        assert "detectorPos3D" in probe or "detectorPos2D" in probe
 
 
 def _verify_snirf_version_str(test_file):
     """Verify that the version string contains the correct spec version."""
-    with h5py.File(test_file, 'r') as h5:
-        version_str = h5['/formatVersion'][()].decode('UTF-8')
+    with h5py.File(test_file, "r") as h5:
+        version_str = h5["/formatVersion"][()].decode("UTF-8")
         expected_str = SPEC_FORMAT_VERSION
         assert version_str == expected_str
 
@@ -243,47 +249,48 @@ def test_aux_read():
     raw = read_raw_snirf(fname_snirf_aux)
     a = read_snirf_aux_data(fname_snirf_aux, raw)
     assert type(a) is pd.DataFrame
-    assert 'accelerometer_2_z' in a
-    assert len(a['gyroscope_1_z']) == len(raw.times)
+    assert "accelerometer_2_z" in a
+    assert len(a["gyroscope_1_z"]) == len(raw.times)
 
 
 @requires_testing_data
-@pytest.mark.parametrize('fname', (
-    fname_nirx_15_2,
-    fname_nirx_15_2_short,
-))
+@pytest.mark.parametrize(
+    "fname",
+    (
+        fname_nirx_15_2,
+        fname_nirx_15_2_short,
+    ),
+)
 def test_snirf_stim_roundtrip(fname, tmpdir):
     """Ensure snirf annotations are written."""
     raw_orig = read_raw_nirx(fname, preload=True)
     assert raw_orig.annotations.duration[0] == 1
     raw_mod = raw_orig.copy()
-    test_file = tmpdir.join('test_raw_no_mod.snirf')
+    test_file = tmpdir.join("test_raw_no_mod.snirf")
     write_raw_snirf(raw_mod, test_file)
     raw = read_raw_snirf(test_file)
-    assert_array_equal(raw_orig.annotations.onset,
-                       raw.annotations.onset)
-    assert_array_equal(raw_orig.annotations.duration,
-                       raw.annotations.duration)
-    assert_array_equal(raw_orig.annotations.description,
-                       raw.annotations.description)
+    assert_array_equal(raw_orig.annotations.onset, raw.annotations.onset)
+    assert_array_equal(raw_orig.annotations.duration, raw.annotations.duration)
+    assert_array_equal(raw_orig.annotations.description, raw.annotations.description)
 
 
 @requires_testing_data
-@pytest.mark.parametrize('fname', (
-    fname_nirx_15_2,
-    fname_nirx_15_2_short,
-))
-@pytest.mark.parametrize('newduration', (
-    1, 2, 3
-))
+@pytest.mark.parametrize(
+    "fname",
+    (
+        fname_nirx_15_2,
+        fname_nirx_15_2_short,
+    ),
+)
+@pytest.mark.parametrize("newduration", (1, 2, 3))
 def test_snirf_duration(fname, newduration, tmpdir):
     """Ensure snirf annotations are written to file."""
-    pytest.importorskip('mne', '1.4')
+    pytest.importorskip("mne", "1.4")
     raw_orig = read_raw_nirx(fname, preload=True)
     assert raw_orig.annotations.duration[0] == 1
     raw_mod = raw_orig.copy()
     raw_mod.annotations.set_durations(newduration)
-    test_file = tmpdir.join('test_raw_duration.snirf')
+    test_file = tmpdir.join("test_raw_duration.snirf")
     write_raw_snirf(raw_mod, test_file)
     raw = read_raw_snirf(test_file)
     assert raw.annotations.duration[0] == newduration
@@ -291,55 +298,53 @@ def test_snirf_duration(fname, newduration, tmpdir):
 
 
 @requires_testing_data
-@pytest.mark.parametrize('fname', (
-    fname_nirx_15_2,
-    fname_nirx_15_2_short,
-))
+@pytest.mark.parametrize(
+    "fname",
+    (
+        fname_nirx_15_2,
+        fname_nirx_15_2_short,
+    ),
+)
 def test_optical_density_roundtrip(fname, tmpdir):
     """Test optical density writing and reading."""
     raw_nirx = read_raw_nirx(fname, preload=True)
     od_orig = optical_density(raw_nirx)
     assert od_orig.annotations.duration[0] == 1
-    test_file = tmpdir.join('test_raw_no_mod.snirf')
+    test_file = tmpdir.join("test_raw_no_mod.snirf")
     write_raw_snirf(od_orig, test_file)
     od = read_raw_snirf(test_file)
-    assert 'fnirs_od' in od
-    assert_array_equal(od_orig.annotations.onset,
-                       od.annotations.onset)
-    assert_array_equal(od_orig.annotations.duration,
-                       od.annotations.duration)
-    assert_array_equal(od_orig.annotations.description,
-                       od.annotations.description)
+    assert "fnirs_od" in od
+    assert_array_equal(od_orig.annotations.onset, od.annotations.onset)
+    assert_array_equal(od_orig.annotations.duration, od.annotations.duration)
+    assert_array_equal(od_orig.annotations.description, od.annotations.description)
     assert_array_equal(od_orig.get_data(), od.get_data())
-    assert_array_equal(od_orig.info.get_channel_types(),
-                       od.info.get_channel_types())
+    assert_array_equal(od_orig.info.get_channel_types(), od.info.get_channel_types())
 
 
 @requires_testing_data
-@pytest.mark.parametrize('fname', (
-    fname_nirx_15_2,
-    fname_nirx_15_2_short,
-))
+@pytest.mark.parametrize(
+    "fname",
+    (
+        fname_nirx_15_2,
+        fname_nirx_15_2_short,
+    ),
+)
 def test_haemoglobin_roundtrip(fname, tmpdir):
     """Test haemoglobin writing and reading."""
     raw_nirx = read_raw_nirx(fname, preload=True)
     od_orig = optical_density(raw_nirx)
     hb_orig = beer_lambert_law(od_orig)
     assert hb_orig.annotations.duration[0] == 1
-    test_file = tmpdir.join('test_raw_hb_no_mod.snirf')
+    test_file = tmpdir.join("test_raw_hb_no_mod.snirf")
     write_raw_snirf(hb_orig, test_file)
     hb = read_raw_snirf(test_file)
-    assert 'hbo' in hb
-    assert 'hbr' in hb
-    assert_array_equal(hb_orig.annotations.onset,
-                       hb.annotations.onset)
-    assert_array_equal(hb_orig.annotations.duration,
-                       hb.annotations.duration)
-    assert_array_equal(hb_orig.annotations.description,
-                       hb.annotations.description)
+    assert "hbo" in hb
+    assert "hbr" in hb
+    assert_array_equal(hb_orig.annotations.onset, hb.annotations.onset)
+    assert_array_equal(hb_orig.annotations.duration, hb.annotations.duration)
+    assert_array_equal(hb_orig.annotations.description, hb.annotations.description)
    assert_array_equal(hb_orig.get_data(), hb.get_data())
-    assert_array_equal(hb_orig.info.get_channel_types(),
-                       hb.info.get_channel_types())
+    assert_array_equal(hb_orig.info.get_channel_types(), hb.info.get_channel_types())
 
 #
 #
 # HBO
 #
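The structural requirements exercised by _verify_snirf_required_fields can also be checked directly with h5py; a minimal sketch against a hypothetical recording.snirf:

import h5py

with h5py.File("recording.snirf", "r") as h5:
    assert "formatVersion" in h5 and "nirs" in h5
    meta = h5["/nirs/metaDataTags"]
    for tag in ("SubjectID", "MeasurementDate", "MeasurementTime",
                "LengthUnit", "TimeUnit", "FrequencyUnit"):
        assert tag in meta
    data1 = h5["/nirs/data1"]
    assert "dataTimeSeries" in data1 and "time" in data1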
diff --git a/mne_nirs/preprocessing/_mayer.py b/mne_nirs/preprocessing/_mayer.py
index c5c6af627..7ca3c0ad4 100644
--- a/mne_nirs/preprocessing/_mayer.py
+++ b/mne_nirs/preprocessing/_mayer.py
@@ -4,17 +4,24 @@
 
 import numpy as np
 import pandas as pd
-
 from mne import pick_types
 from mne.io import BaseRaw
-from mne.utils import _validate_type, _require_version
-
-
-def quantify_mayer_fooof(raw, num_oscillations=1, centre_frequency=0.01,
-                         extra_df_fields={},
-                         fmin=0.001, fmax=1, tmin=0, tmax=None,
-                         n_fft=400, n_overlap=200,
-                         peak_width_limits=(0.5, 12.0)):
+from mne.utils import _require_version, _validate_type
+
+
+def quantify_mayer_fooof(
+    raw,
+    num_oscillations=1,
+    centre_frequency=0.01,
+    extra_df_fields=None,
+    fmin=0.001,
+    fmax=1,
+    tmin=0,
+    tmax=None,
+    n_fft=400,
+    n_overlap=200,
+    peak_width_limits=(0.5, 12.0),
+):
     """
     Quantify Mayer wave properties using FOOOF analysis.
 
@@ -69,29 +76,35 @@ def quantify_mayer_fooof(raw, num_oscillations=1, centre_frequency=0.01,
     ----------
     .. footbibliography::
     """
-    _require_version('fooof', 'run the FOOOF algorithm.')
-    _validate_type(raw, BaseRaw, 'raw')
+    _require_version("fooof", "run the FOOOF algorithm.")
+    _validate_type(raw, BaseRaw, "raw")
 
-    hbo_picks = pick_types(raw.info, fnirs='hbo')
-    hbr_picks = pick_types(raw.info, fnirs='hbr')
+    extra_df_fields = {} if extra_df_fields is None else extra_df_fields
+    hbo_picks = pick_types(raw.info, fnirs="hbo")
+    hbr_picks = pick_types(raw.info, fnirs="hbr")
 
     if (not len(hbo_picks)) & (not len(hbr_picks)):
         # It may be perfectly valid to compute this on optical density
         # or raw data, I just haven't tried this. Let me know if this works
         # for you and we can ease this restriction.
-        raise RuntimeError('Mayer wave estimation should be run on '
-                           'haemoglobin concentration data.')
+        raise RuntimeError(
+            "Mayer wave estimation should be run on " "haemoglobin concentration data."
+        )
 
     df = pd.DataFrame()
     for picks, chroma in zip([hbo_picks, hbr_picks], ["hbo", "hbr"]):
         if len(picks):
-
-            fm_hbo = _run_fooof(raw.copy().pick(picks),
-                                fmin=fmin, fmax=fmax,
-                                tmin=tmin, tmax=tmax,
-                                n_overlap=n_overlap, n_fft=n_fft,
-                                peak_width_limits=peak_width_limits)
+            fm_hbo = _run_fooof(
+                raw.copy().pick(picks),
+                fmin=fmin,
+                fmax=fmax,
+                tmin=tmin,
+                tmax=tmax,
+                n_overlap=n_overlap,
+                n_fft=n_fft,
+                peak_width_limits=peak_width_limits,
+            )
 
             cf, pw, bw = _process_fooof_output(fm_hbo, centre_frequency)
 
@@ -102,23 +115,27 @@ def quantify_mayer_fooof(raw, num_oscillations=1, centre_frequency=0.01,
             data["Chromaphore"] = chroma
             data = {**data, **extra_df_fields}
 
-            df = pd.concat([df, pd.DataFrame(data, index=[0])],
-                           ignore_index=True)
+            df = pd.concat([df, pd.DataFrame(data, index=[0])], ignore_index=True)
 
     return df
 
 
-def _run_fooof(raw,
-               fmin=0.001, fmax=1,
-               tmin=0, tmax=None,
-               n_overlap=200, n_fft=400,
-               peak_width_limits=(0.5, 12.0)):
+def _run_fooof(
+    raw,
+    fmin=0.001,
+    fmax=1,
+    tmin=0,
+    tmax=None,
+    n_overlap=200,
+    n_fft=400,
+    peak_width_limits=(0.5, 12.0),
+):
     """Prepare data for FOOOF including welch and scaling, then apply."""
     from fooof import FOOOF
 
     psd = raw.compute_psd(
-        fmin=fmin, fmax=fmax, tmin=tmin, tmax=tmax,
-        n_overlap=n_overlap, n_fft=n_fft)
+        fmin=fmin, fmax=fmax, tmin=tmin, tmax=tmax, n_overlap=n_overlap, n_fft=n_fft
+    )
     spectra, freqs = psd.get_data(return_freqs=True)
 
     # FOOOF doesn't like low frequencies, so multiply by 10.
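Usage of the reformatted function, as a sketch on haemoglobin data (the NIRx path is a placeholder; ppf=0.1 mirrors the value used in the tests later in this diff):

import mne
from mne_nirs.preprocessing import quantify_mayer_fooof

raw = mne.io.read_raw_nirx("recording_dir", preload=True)
raw_od = mne.preprocessing.nirs.optical_density(raw)
raw_haemo = mne.preprocessing.nirs.beer_lambert_law(raw_od, ppf=0.1)
# One row per chromophore; the test file below expects the hbo
# "Centre Frequency" to land near the ~0.1 Hz Mayer wave band
df = quantify_mayer_fooof(raw_haemo)
print(df)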
""" - raw = raw.copy().load_data() - _validate_type(raw, BaseRaw, 'raw') + _validate_type(raw, BaseRaw, "raw") picks = _validate_nirs_info(raw.info) - filtered_data = filter_data(raw._data, raw.info['sfreq'], l_freq, h_freq, - picks=picks, verbose=verbose, - l_trans_bandwidth=l_trans_bandwidth, - h_trans_bandwidth=h_trans_bandwidth) - - window_samples = int(np.ceil(time_window * raw.info['sfreq'])) + filtered_data = filter_data( + raw._data, + raw.info["sfreq"], + l_freq, + h_freq, + picks=picks, + verbose=verbose, + l_trans_bandwidth=l_trans_bandwidth, + h_trans_bandwidth=h_trans_bandwidth, + ) + + window_samples = int(np.ceil(time_window * raw.info["sfreq"])) n_windows = int(np.floor(len(raw) / window_samples)) scores = np.zeros((len(picks), n_windows)) times = [] for window in range(n_windows): - start_sample = int(window * window_samples) end_sample = start_sample + window_samples end_sample = np.min([end_sample, len(raw) - 1]) @@ -87,7 +95,6 @@ def peak_power(raw, time_window=10, threshold=0.1, l_freq=0.7, h_freq=1.5, times.append((t_start, t_stop)) for ii in range(0, len(picks), 2): - c1 = filtered_data[picks[ii]][start_sample:end_sample] c2 = filtered_data[picks[ii + 1]][start_sample:end_sample] @@ -97,13 +104,17 @@ def peak_power(raw, time_window=10, threshold=0.1, l_freq=0.7, h_freq=1.5, c = np.correlate(c1, c2, "full") c = c / (window_samples) - [f, pxx] = periodogram(c, fs=raw.info['sfreq'], window='hamming') + [f, pxx] = periodogram(c, fs=raw.info["sfreq"], window="hamming") scores[ii, window] = max(pxx) scores[ii + 1, window] = max(pxx) if (threshold is not None) & (max(pxx) < threshold): - raw.annotations.append(t_start, time_window, 'BAD_PeakPower', - ch_names=[raw.ch_names[ii:ii + 2]]) + raw.annotations.append( + t_start, + time_window, + "BAD_PeakPower", + ch_names=[raw.ch_names[ii : ii + 2]], + ) scores = scores[np.argsort(picks)] return raw, scores, times diff --git a/mne_nirs/preprocessing/_scalp_coupling_segmented.py b/mne_nirs/preprocessing/_scalp_coupling_segmented.py index d535c64b8..4f82d5cfa 100644 --- a/mne_nirs/preprocessing/_scalp_coupling_segmented.py +++ b/mne_nirs/preprocessing/_scalp_coupling_segmented.py @@ -3,20 +3,23 @@ # License: BSD (3-clause) import numpy as np - -from mne import pick_types +from mne.filter import filter_data from mne.io import BaseRaw -from mne.utils import _validate_type, verbose from mne.preprocessing.nirs import _validate_nirs_info -from mne.filter import filter_data +from mne.utils import _validate_type, verbose @verbose -def scalp_coupling_index_windowed(raw, time_window=10, threshold=0.1, - l_freq=0.7, h_freq=1.5, - l_trans_bandwidth=0.3, - h_trans_bandwidth=0.3, - verbose=False): +def scalp_coupling_index_windowed( + raw, + time_window=10, + threshold=0.1, + l_freq=0.7, + h_freq=1.5, + l_trans_bandwidth=0.3, + h_trans_bandwidth=0.3, + verbose=False, +): """ Compute scalp coupling index for each channel and time window. @@ -59,26 +62,29 @@ def scalp_coupling_index_windowed(raw, time_window=10, threshold=0.1, quality assessment of fNIRS scans." Optics and the Brain. Optical Society of America, 2020. 
""" - raw = raw.copy().load_data() - _validate_type(raw, BaseRaw, 'raw') - - picks = _validate_nirs_info( - raw.info, fnirs='od', which='Scalp coupling index') - - filtered_data = filter_data(raw._data, raw.info['sfreq'], l_freq, h_freq, - picks=picks, verbose=verbose, - l_trans_bandwidth=l_trans_bandwidth, - h_trans_bandwidth=h_trans_bandwidth) - - window_samples = int(np.ceil(time_window * raw.info['sfreq'])) + _validate_type(raw, BaseRaw, "raw") + + picks = _validate_nirs_info(raw.info, fnirs="od", which="Scalp coupling index") + + filtered_data = filter_data( + raw._data, + raw.info["sfreq"], + l_freq, + h_freq, + picks=picks, + verbose=verbose, + l_trans_bandwidth=l_trans_bandwidth, + h_trans_bandwidth=h_trans_bandwidth, + ) + + window_samples = int(np.ceil(time_window * raw.info["sfreq"])) n_windows = int(np.floor(len(raw) / window_samples)) scores = np.zeros((len(picks), n_windows)) times = [] for window in range(n_windows): - start_sample = int(window * window_samples) end_sample = start_sample + window_samples end_sample = np.min([end_sample, len(raw) - 1]) @@ -88,7 +94,6 @@ def scalp_coupling_index_windowed(raw, time_window=10, threshold=0.1, times.append((t_start, t_stop)) for ii in range(0, len(picks), 2): - c1 = filtered_data[picks[ii]][start_sample:end_sample] c2 = filtered_data[picks[ii + 1]][start_sample:end_sample] c = np.corrcoef(c1, c2)[0][1] @@ -96,7 +101,11 @@ def scalp_coupling_index_windowed(raw, time_window=10, threshold=0.1, scores[ii + 1, window] = c if (threshold is not None) & (c < threshold): - raw.annotations.append(t_start, time_window, 'BAD_SCI', - ch_names=[raw.ch_names[ii:ii + 2]]) + raw.annotations.append( + t_start, + time_window, + "BAD_SCI", + ch_names=[raw.ch_names[ii : ii + 2]], + ) scores = scores[np.argsort(picks)] return raw, scores, times diff --git a/mne_nirs/preprocessing/tests/test_mayer.py b/mne_nirs/preprocessing/tests/test_mayer.py index 818a9af45..ef55071dd 100644 --- a/mne_nirs/preprocessing/tests/test_mayer.py +++ b/mne_nirs/preprocessing/tests/test_mayer.py @@ -3,13 +3,14 @@ # License: BSD (3-clause) import os + import mne -import pytest import numpy as np +import pytest from mne_nirs.preprocessing import quantify_mayer_fooof -pytest.importorskip('fooof') +pytest.importorskip("fooof") # in fooof->scipy.optimize.curve_fit @@ -17,9 +18,8 @@ @pytest.mark.filterwarnings("ignore:invalid value encountered in.*:") def test_mayer(): fnirs_data_folder = mne.datasets.fnirs_motor.data_path() - fnirs_raw_dir = os.path.join(fnirs_data_folder, 'Participant-1') - raw_intensity = mne.io.read_raw_nirx(fnirs_raw_dir, - verbose=True).load_data() + fnirs_raw_dir = os.path.join(fnirs_data_folder, "Participant-1") + raw_intensity = mne.io.read_raw_nirx(fnirs_raw_dir, verbose=True).load_data() raw_intensity = raw_intensity.pick(picks=range(8)).crop(tmax=600) @@ -36,5 +36,7 @@ def test_mayer(): assert df_mayer.shape[0] == 2 assert df_mayer.shape[1] == 4 - assert np.abs(df_mayer.query('Chromaphore == "hbo"' - )["Centre Frequency"][0] - 0.1) < 0.05 + assert ( + np.abs(df_mayer.query('Chromaphore == "hbo"')["Centre Frequency"][0] - 0.1) + < 0.05 + ) diff --git a/mne_nirs/preprocessing/tests/test_quality.py b/mne_nirs/preprocessing/tests/test_quality.py index 83af758fc..f7478e3dd 100644 --- a/mne_nirs/preprocessing/tests/test_quality.py +++ b/mne_nirs/preprocessing/tests/test_quality.py @@ -3,15 +3,15 @@ # License: BSD (3-clause) import os + import mne from mne_nirs.preprocessing import peak_power, scalp_coupling_index_windowed def test_peak_power(): - 
diff --git a/mne_nirs/preprocessing/tests/test_mayer.py b/mne_nirs/preprocessing/tests/test_mayer.py
index 818a9af45..ef55071dd 100644
--- a/mne_nirs/preprocessing/tests/test_mayer.py
+++ b/mne_nirs/preprocessing/tests/test_mayer.py
@@ -3,13 +3,14 @@
 # License: BSD (3-clause)
 
 import os
+
 import mne
-import pytest
 import numpy as np
+import pytest
 
 from mne_nirs.preprocessing import quantify_mayer_fooof
 
-pytest.importorskip('fooof')
+pytest.importorskip("fooof")
 
 
 # in fooof->scipy.optimize.curve_fit
@@ -17,9 +18,8 @@
 @pytest.mark.filterwarnings("ignore:invalid value encountered in.*:")
 def test_mayer():
     fnirs_data_folder = mne.datasets.fnirs_motor.data_path()
-    fnirs_raw_dir = os.path.join(fnirs_data_folder, 'Participant-1')
-    raw_intensity = mne.io.read_raw_nirx(fnirs_raw_dir,
-                                         verbose=True).load_data()
+    fnirs_raw_dir = os.path.join(fnirs_data_folder, "Participant-1")
+    raw_intensity = mne.io.read_raw_nirx(fnirs_raw_dir, verbose=True).load_data()
 
     raw_intensity = raw_intensity.pick(picks=range(8)).crop(tmax=600)
 
@@ -36,5 +36,7 @@ def test_mayer():
     assert df_mayer.shape[0] == 2
     assert df_mayer.shape[1] == 4
 
-    assert np.abs(df_mayer.query('Chromaphore == "hbo"'
-                                 )["Centre Frequency"][0] - 0.1) < 0.05
+    assert (
+        np.abs(df_mayer.query('Chromaphore == "hbo"')["Centre Frequency"][0] - 0.1)
+        < 0.05
+    )
diff --git a/mne_nirs/preprocessing/tests/test_quality.py b/mne_nirs/preprocessing/tests/test_quality.py
index 83af758fc..f7478e3dd 100644
--- a/mne_nirs/preprocessing/tests/test_quality.py
+++ b/mne_nirs/preprocessing/tests/test_quality.py
@@ -3,15 +3,15 @@
 # License: BSD (3-clause)
 
 import os
+
 import mne
 
 from mne_nirs.preprocessing import peak_power, scalp_coupling_index_windowed
 
 
 def test_peak_power():
-
     fnirs_data_folder = mne.datasets.fnirs_motor.data_path()
-    fnirs_raw_dir = os.path.join(fnirs_data_folder, 'Participant-1')
+    fnirs_raw_dir = os.path.join(fnirs_data_folder, "Participant-1")
     raw = mne.io.read_raw_nirx(fnirs_raw_dir, verbose=True).load_data()
 
     raw = mne.preprocessing.nirs.optical_density(raw)
@@ -20,9 +20,8 @@ def test_peak_power():
 
 
 def test_sci_windowed():
-
     fnirs_data_folder = mne.datasets.fnirs_motor.data_path()
-    fnirs_raw_dir = os.path.join(fnirs_data_folder, 'Participant-1')
+    fnirs_raw_dir = os.path.join(fnirs_data_folder, "Participant-1")
     raw = mne.io.read_raw_nirx(fnirs_raw_dir, verbose=True).load_data()
 
     raw = mne.preprocessing.nirs.optical_density(raw)
diff --git a/mne_nirs/signal_enhancement/_negative_correlation.py b/mne_nirs/signal_enhancement/_negative_correlation.py
index 648b5264d..0a0143f7c 100644
--- a/mne_nirs/signal_enhancement/_negative_correlation.py
+++ b/mne_nirs/signal_enhancement/_negative_correlation.py
@@ -3,9 +3,8 @@
 # License: BSD (3-clause)
 
 import numpy as np
-
-from mne.io import BaseRaw
 from mne import pick_types
+from mne.io import BaseRaw
 from mne.utils import _validate_type
 
 
@@ -37,22 +36,25 @@ def enhance_negative_correlation(raw):
        https://doi.org/10.1016/j.neuroimage.2009.11.050
     """
     raw = raw.copy().load_data()
-    _validate_type(raw, BaseRaw, 'raw')
+    _validate_type(raw, BaseRaw, "raw")
 
-    hbo_channels = pick_types(raw.info, fnirs='hbo')
-    hbr_channels = pick_types(raw.info, fnirs='hbr')
+    hbo_channels = pick_types(raw.info, fnirs="hbo")
+    hbr_channels = pick_types(raw.info, fnirs="hbr")
 
     if (not len(hbo_channels)) & (not len(hbr_channels)):
-        raise RuntimeError('enhance_negative_correlation should '
-                           'be run on haemoglobin data.')
+        raise RuntimeError(
+            "enhance_negative_correlation should " "be run on haemoglobin data."
+        )
 
     if len(hbo_channels) != len(hbr_channels):
-        raise RuntimeError('Same number of hbo and hbr channels required.')
+        raise RuntimeError("Same number of hbo and hbr channels required.")
 
     for idx in range(len(hbo_channels)):
-        if raw.info['chs'][hbo_channels[idx]]['ch_name'][:-4] != \
-                raw.info['chs'][hbr_channels[idx]]['ch_name'][:-4]:
-            raise RuntimeError('Channels must alternate between HBO and HBR.')
+        if (
+            raw.info["chs"][hbo_channels[idx]]["ch_name"][:-4]
+            != raw.info["chs"][hbr_channels[idx]]["ch_name"][:-4]
+        ):
+            raise RuntimeError("Channels must alternate between HBO and HBR.")
 
     for idx in range(len(hbo_channels)):
         hbo = raw._data[hbo_channels[idx]]
@@ -67,7 +69,6 @@ def enhance_negative_correlation(raw):
         alpha = hbo_std / hbr_std
 
         raw._data[hbo_channels[idx]] = 0.5 * (hbo - alpha * hbr)
-        raw._data[hbr_channels[idx]] = -(1 / alpha) * \
-            raw._data[hbo_channels[idx]]
+        raw._data[hbr_channels[idx]] = -(1 / alpha) * raw._data[hbo_channels[idx]]
 
     return raw
diff --git a/mne_nirs/signal_enhancement/_short_channel_correction.py b/mne_nirs/signal_enhancement/_short_channel_correction.py
index 10171b7d6..224076a9c 100644
--- a/mne_nirs/signal_enhancement/_short_channel_correction.py
+++ b/mne_nirs/signal_enhancement/_short_channel_correction.py
@@ -3,12 +3,11 @@
 # License: BSD (3-clause)
 
 import numpy as np
-from scipy import linalg
-
-from mne.io import BaseRaw
 from mne import pick_types
-from mne.utils import _validate_type
+from mne.io import BaseRaw
 from mne.preprocessing.nirs import source_detector_distances
+from mne.utils import _validate_type
+from scipy import linalg
 
 
 def short_channel_regression(raw, max_dist=0.01):
@@ -36,12 +35,12 @@ def short_channel_regression(raw, max_dist=0.01):
     .. footbibliography::
     """
     raw = raw.copy().load_data()
-    _validate_type(raw, BaseRaw, 'raw')
+    _validate_type(raw, BaseRaw, "raw")
 
-    picks_od = pick_types(raw.info, fnirs='fnirs_od')
+    picks_od = pick_types(raw.info, fnirs="fnirs_od")
 
     if len(picks_od) == 0:
-        raise RuntimeError('Data must be optical density.')
+        raise RuntimeError("Data must be optical density.")
 
     distances = source_detector_distances(raw.info)
 
@@ -49,12 +48,11 @@ def short_channel_regression(raw, max_dist=0.01):
     picks_long = picks_od[distances[picks_od] > max_dist]
 
     if len(picks_short) == 0:
-        raise RuntimeError('No short channels present.')
+        raise RuntimeError("No short channels present.")
     if len(picks_long) == 0:
-        raise RuntimeError('No long channels present.')
+        raise RuntimeError("No long channels present.")
 
     for pick in picks_long:
-
         short_idx = _find_nearest_short(raw, pick, picks_short)
 
         A_l = raw.get_data(pick).ravel()
@@ -70,8 +68,7 @@ def short_channel_regression(raw, max_dist=0.01):
 
 
 def _find_nearest_short(raw, pick, short_picks):
-    """"
-    Return index of closest short channel
+    """Return index of closest short channel.
 
     Parameters
     ----------
@@ -90,9 +87,9 @@ def _find_nearest_short(raw, pick, short_picks):
         in short_picks.
 
     """
-
-    dist = [linalg.norm(raw.info['chs'][pick]['loc'][:3] -
-                        raw.info['chs'][p_sh]['loc'][:3])
-            for p_sh in short_picks]
+    dist = [
+        linalg.norm(raw.info["chs"][pick]["loc"][:3] - raw.info["chs"][p_sh]["loc"][:3])
+        for p_sh in short_picks
+    ]
 
     return short_picks[np.argmin(dist)]
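The update implemented in enhance_negative_correlation, reproduced on synthetic signals (the mean subtraction is assumed here, as it happens in lines outside the hunks shown): scale HbR to HbO, average the anti-correlated pair, then force the corrected HbR to be exactly -1 correlated with the corrected HbO.

import numpy as np

rng = np.random.default_rng(0)
hbo = rng.standard_normal(1000)
hbr = -0.5 * hbo + 0.3 * rng.standard_normal(1000)
hbo -= hbo.mean()
hbr -= hbr.mean()
alpha = hbo.std() / hbr.std()
hbo_new = 0.5 * (hbo - alpha * hbr)
hbr_new = -(1 / alpha) * hbo_new  # perfectly anti-correlated by construction
assert np.isclose(np.corrcoef(hbo_new, hbr_new)[0, 1], -1)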
_load_dataset() raw_intensity = raw_intensity.pick(picks=range(8)) # Keep the test fast with pytest.raises(RuntimeError, match="run on haemoglobin"): - _ = mne_nirs.signal_enhancement.enhance_negative_correlation( - raw_intensity) + _ = mne_nirs.signal_enhancement.enhance_negative_correlation(raw_intensity) raw_od = mne.preprocessing.nirs.optical_density(raw_intensity) with pytest.raises(RuntimeError, match="run on haemoglobin"): - _ = mne_nirs.signal_enhancement.enhance_negative_correlation( - raw_od) + _ = mne_nirs.signal_enhancement.enhance_negative_correlation(raw_od) raw_haemo = mne.preprocessing.nirs.beer_lambert_law(raw_od, ppf=0.1) - raw_anti = mne_nirs.signal_enhancement.enhance_negative_correlation( - raw_haemo) - assert np.abs(np.corrcoef(raw_haemo._data[0], - raw_haemo._data[1])[0, 1]) < 1 + raw_anti = mne_nirs.signal_enhancement.enhance_negative_correlation(raw_haemo) + assert np.abs(np.corrcoef(raw_haemo._data[0], raw_haemo._data[1])[0, 1]) < 1 - np.testing.assert_almost_equal(np.corrcoef(raw_anti._data[0], - raw_anti._data[1])[0, 1], - -1) + np.testing.assert_almost_equal( + np.corrcoef(raw_anti._data[0], raw_anti._data[1])[0, 1], -1 + ) d1 = raw_haemo.copy().pick(picks=range(3)) with pytest.raises(RuntimeError, match="Same number of hbo and hbr"): diff --git a/mne_nirs/signal_enhancement/tests/test_short_channels.py b/mne_nirs/signal_enhancement/tests/test_short_channels.py index b1d0f9d50..492c3e20a 100644 --- a/mne_nirs/signal_enhancement/tests/test_short_channels.py +++ b/mne_nirs/signal_enhancement/tests/test_short_channels.py @@ -3,32 +3,33 @@ # License: BSD (3-clause) import os + import mne import numpy as np import pytest -from mne_nirs.signal_enhancement import short_channel_regression from mne_nirs.channels import get_long_channels, get_short_channels +from mne_nirs.signal_enhancement import short_channel_regression def _load_dataset(): """Load data and tidy it a bit""" fnirs_data_folder = mne.datasets.fnirs_motor.data_path() - fnirs_raw_dir = os.path.join(fnirs_data_folder, 'Participant-1') - raw_intensity = mne.io.read_raw_nirx(fnirs_raw_dir, - verbose=True).load_data() + fnirs_raw_dir = os.path.join(fnirs_data_folder, "Participant-1") + raw_intensity = mne.io.read_raw_nirx(fnirs_raw_dir, verbose=True).load_data() raw_intensity.crop(0, raw_intensity.annotations.onset[-1]) new_des = [des for des in raw_intensity.annotations.description] - new_des = ['A' if x == "1.0" else x for x in new_des] - new_des = ['B' if x == "2.0" else x for x in new_des] - new_des = ['C' if x == "3.0" else x for x in new_des] - annot = mne.Annotations(raw_intensity.annotations.onset, - raw_intensity.annotations.duration, new_des) + new_des = ["A" if x == "1.0" else x for x in new_des] + new_des = ["B" if x == "2.0" else x for x in new_des] + new_des = ["C" if x == "3.0" else x for x in new_des] + annot = mne.Annotations( + raw_intensity.annotations.onset, raw_intensity.annotations.duration, new_des + ) raw_intensity.set_annotations(annot) - assert 'fnirs_cw_amplitude' in raw_intensity + assert "fnirs_cw_amplitude" in raw_intensity assert len(np.unique(raw_intensity.annotations.description)) == 4 return raw_intensity @@ -44,7 +45,7 @@ def test_short(): raw_od_corrected = short_channel_regression(raw_od) - assert 'fnirs_od' in raw_od_corrected + assert "fnirs_od" in raw_od_corrected with pytest.raises(RuntimeError, match="long channels present"): short_channel_regression(get_short_channels(raw_od)) diff --git a/mne_nirs/simulation/_simulation.py b/mne_nirs/simulation/_simulation.py index 
e5c6f4ac4..a2bc65d6f 100644 --- a/mne_nirs/simulation/_simulation.py +++ b/mne_nirs/simulation/_simulation.py @@ -7,15 +7,17 @@ from mne.io import RawArray -def simulate_nirs_raw(sfreq=3., - amplitude=1., - annot_desc='A', - sig_dur=300., - stim_dur=5., - isi_min=15., - isi_max=45., - ch_name='Simulated', - hrf_model='glover'): +def simulate_nirs_raw( + sfreq=3.0, + amplitude=1.0, + annot_desc="A", + sig_dur=300.0, + stim_dur=5.0, + isi_min=15.0, + isi_max=45.0, + ch_name="Simulated", + hrf_model="glover", +): """ Create simulated fNIRS data. @@ -58,21 +60,23 @@ def simulate_nirs_raw(sfreq=3., from nilearn.glm.first_level import make_first_level_design_matrix from pandas import DataFrame - if type(amplitude) is not list: + if not isinstance(amplitude, list): amplitude = [amplitude] - if type(annot_desc) is not list: + if not isinstance(annot_desc, list): annot_desc = [annot_desc] - if type(stim_dur) is not list: + if not isinstance(stim_dur, list): stim_dur = [stim_dur] frame_times = np.arange(sig_dur * sfreq) / sfreq - assert len(amplitude) == len(annot_desc), "Same number of amplitudes as " \ - "annotations required." - assert len(amplitude) == len(stim_dur), "Same number of amplitudes as " \ - "durations required." + assert len(amplitude) == len(annot_desc), ( + "Same number of amplitudes as " "annotations required." + ) + assert len(amplitude) == len(stim_dur), ( + "Same number of amplitudes as " "durations required." + ) - onset = 0. + onset = 0.0 onsets = [] conditions = [] durations = [] @@ -83,28 +87,31 @@ def simulate_nirs_raw(sfreq=3., conditions.append(annot_desc[c_idx]) durations.append(stim_dur[c_idx]) - events = DataFrame({'trial_type': conditions, - 'onset': onsets, - 'duration': durations}) + events = DataFrame( + {"trial_type": conditions, "onset": onsets, "duration": durations} + ) - dm = make_first_level_design_matrix(frame_times, events, - hrf_model=hrf_model, - drift_model='polynomial', - drift_order=0) - dm = dm.drop(columns='constant') + dm = make_first_level_design_matrix( + frame_times, + events, + hrf_model=hrf_model, + drift_model="polynomial", + drift_order=0, + ) + dm = dm.drop(columns="constant") annotations = Annotations(onsets, durations, conditions) - info = create_info(ch_names=[ch_name], sfreq=sfreq, ch_types=['hbo']) + info = create_info(ch_names=[ch_name], sfreq=sfreq, ch_types=["hbo"]) for idx, annot in enumerate(annot_desc): if annot in dm.columns: dm[annot] *= amplitude[idx] - a = np.sum(dm.to_numpy(), axis=1) * 1.e-6 + a = np.sum(dm.to_numpy(), axis=1) * 1.0e-6 a = a.reshape(-1, 1).T raw = RawArray(a, info, verbose=False) - raw.set_annotations(annotations, verbose='error') + raw.set_annotations(annotations, verbose="error") return raw diff --git a/mne_nirs/simulation/tests/test_simulation.py b/mne_nirs/simulation/tests/test_simulation.py index 38dd84508..4f51cb924 100644 --- a/mne_nirs/simulation/tests/test_simulation.py +++ b/mne_nirs/simulation/tests/test_simulation.py @@ -2,49 +2,61 @@ # # License: BSD (3-clause) -from mne_nirs.simulation import simulate_nirs_raw import numpy as np import pytest + from mne_nirs.experimental_design import make_first_level_design_matrix +from mne_nirs.simulation import simulate_nirs_raw def test_simulate_NIRS_single_channel(): - - raw = simulate_nirs_raw(sfreq=3., amplitude=1., sig_dur=300., stim_dur=5., - isi_min=15., isi_max=45.) - assert 'hbo' in raw - assert raw.info['sfreq'] == 3. 
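# --- Editor's note: illustrative sketch, not part of the patch. It shows a
# minimal call to the simulate_nirs_raw() signature reformatted above,
# assuming mne-nirs and nilearn are installed. The amplitude is given in
# micromolar and scaled by 1e-6 internally, so the simulated channel peaks
# near amplitude * 1e-6, which is what the assertions in this test check.
from mne_nirs.simulation import simulate_nirs_raw

sim = simulate_nirs_raw(
    sfreq=3.0,      # 3 Hz sampling -> 900 samples over the 300 s signal
    amplitude=1.0,  # single condition "A" with a 1 uM evoked response
    annot_desc="A",
    sig_dur=300.0,
    stim_dur=5.0,
    isi_min=15.0,   # inter-stimulus interval drawn from [15, 45] s
    isi_max=45.0,
)
print(sim.ch_names, sim.get_data().shape)  # ['Simulated'] (1, 900)
# --- End editor's note. ---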
+ raw = simulate_nirs_raw( + sfreq=3.0, + amplitude=1.0, + sig_dur=300.0, + stim_dur=5.0, + isi_min=15.0, + isi_max=45.0, + ) + assert "hbo" in raw + assert raw.info["sfreq"] == 3.0 assert raw.get_data().shape == (1, 900) - assert np.max(raw.get_data()) < 1.2 * 1.e-6 - assert raw.annotations.description[0] == 'A' + assert np.max(raw.get_data()) < 1.2 * 1.0e-6 + assert raw.annotations.description[0] == "A" assert raw.annotations.duration[0] == 5 - assert np.min(np.diff(raw.annotations.onset)) > 15. + 5. - assert np.max(np.diff(raw.annotations.onset)) < 45. + 5. + assert np.min(np.diff(raw.annotations.onset)) > 15.0 + 5.0 + assert np.max(np.diff(raw.annotations.onset)) < 45.0 + 5.0 - with pytest.raises(AssertionError, match='Same number of'): - _ = simulate_nirs_raw(sfreq=3., amplitude=[1., 2.], sig_dur=300., - stim_dur=5., isi_min=15., isi_max=45.) + with pytest.raises(AssertionError, match="Same number of"): + _ = simulate_nirs_raw( + sfreq=3.0, + amplitude=[1.0, 2.0], + sig_dur=300.0, + stim_dur=5.0, + isi_min=15.0, + isi_max=45.0, + ) def test_simulate_NIRS_multi_channel(): - - raw = simulate_nirs_raw(sfreq=3., - amplitude=[0., 2., 4.], - annot_desc=['Control', - 'Cond_A', - 'Cond_B'], - stim_dur=[5, 5, 5], - sig_dur=1500., - isi_min=5., isi_max=15., - hrf_model='spm') - - design_matrix = make_first_level_design_matrix(raw, stim_dur=5.0, - drift_order=0, - drift_model='polynomial') - - assert len(design_matrix['Control']) == 1500 * 3 - assert len(design_matrix['Cond_A']) == 1500 * 3 + raw = simulate_nirs_raw( + sfreq=3.0, + amplitude=[0.0, 2.0, 4.0], + annot_desc=["Control", "Cond_A", "Cond_B"], + stim_dur=[5, 5, 5], + sig_dur=1500.0, + isi_min=5.0, + isi_max=15.0, + hrf_model="spm", + ) + + design_matrix = make_first_level_design_matrix( + raw, stim_dur=5.0, drift_order=0, drift_model="polynomial" + ) + + assert len(design_matrix["Control"]) == 1500 * 3 + assert len(design_matrix["Cond_A"]) == 1500 * 3 # Make sure no extra channels. Specifically the default isn't present. 
- with pytest.raises(KeyError, match='A'): - len(design_matrix['A']) + with pytest.raises(KeyError, match="A"): + len(design_matrix["A"]) diff --git a/mne_nirs/statistics/_glm_level_first.py b/mne_nirs/statistics/_glm_level_first.py index 4334ee389..151d7f52f 100644 --- a/mne_nirs/statistics/_glm_level_first.py +++ b/mne_nirs/statistics/_glm_level_first.py @@ -2,18 +2,18 @@ # # License: BSD (3-clause) +import warnings from copy import deepcopy from inspect import getfullargspec from pathlib import PosixPath -import warnings -import pandas as pd import numpy as np -from numpy import array_equal, where +import pandas as pd from h5io import read_hdf5, write_hdf5 +from numpy import array_equal, where with warnings.catch_warnings(record=True): - warnings.simplefilter('ignore') + warnings.simplefilter("ignore") import nilearn.glm from nilearn.glm.first_level import run_glm as nilearn_glm @@ -21,15 +21,14 @@ from mne._fiff.meas_info import ContainsMixin except ImportError: from mne.io.meas_info import ContainsMixin -from mne.utils import fill_doc, warn, verbose, check_fname, _validate_type -from mne.io.pick import _picks_to_idx -from mne.io.constants import FIFF from mne import Info, pick_info +from mne.io.constants import FIFF +from mne.io.pick import _picks_to_idx +from mne.utils import _validate_type, check_fname, fill_doc, verbose, warn -from ..visualisation._plot_GLM_topo import _plot_glm_topo,\ - _plot_glm_contrast_topo -from ..visualisation import plot_glm_surface_projection from ..statistics._roi import _glm_region_of_interest +from ..visualisation import plot_glm_surface_projection +from ..visualisation._plot_GLM_topo import _plot_glm_contrast_topo, _plot_glm_topo @fill_doc @@ -45,13 +44,13 @@ def ch_names(self): names : array The channel names. """ - return self.info['ch_names'] + return self.info["ch_names"] def __str__(self): - return (f"GLM Results for {len(self.ch_names)} channels") + return f"GLM Results for {len(self.ch_names)} channels" def __repr__(self): - return (f"GLM Results for {len(self.ch_names)} channels") + return f"GLM Results for {len(self.ch_names)} channels" def __len__(self): return len(self.info.ch_names) @@ -66,24 +65,29 @@ def _get_state(self): # noqa: D105 state : dict State of the object. 
""" - state = deepcopy(dict( - data=self._data, - design=self.design, - info=self.info, - preload=self.preload, - classname=str(self.__class__) - )) - if isinstance(state['data'], dict): - for channel in state['data']: - state['data'][channel] = state['data'][channel].__dict__ - if isinstance(state['data'][channel]['model'], - nilearn.glm.regression.OLSModel): - state['data'][channel]['modelname'] = \ - str(state['data'][channel]['model'].__class__) - state['data'][channel]['model'] = \ - state['data'][channel]['model'].__dict__ - if isinstance(state['data'], nilearn.glm.contrasts.Contrast): - state['data'] = state['data'].__dict__ + state = deepcopy( + dict( + data=self._data, + design=self.design, + info=self.info, + preload=self.preload, + classname=str(self.__class__), + ) + ) + if isinstance(state["data"], dict): + for channel in state["data"]: + state["data"][channel] = state["data"][channel].__dict__ + if isinstance( + state["data"][channel]["model"], nilearn.glm.regression.OLSModel + ): + state["data"][channel]["modelname"] = str( + state["data"][channel]["model"].__class__ + ) + state["data"][channel]["model"] = state["data"][channel][ + "model" + ].__dict__ + if isinstance(state["data"], nilearn.glm.contrasts.Contrast): + state["data"] = state["data"].__dict__ return state def copy(self): @@ -107,14 +111,14 @@ def save(self, fname, overwrite=False): Should end in ``'glm.h5'``. %(overwrite)s """ - _validate_type(fname, 'path-like', 'fname') + _validate_type(fname, "path-like", "fname") if isinstance(fname, PosixPath): fname = str(fname) - if not fname.endswith('glm.h5'): - raise IOError('The filename must end with glm.h5, ' - f'instead received {fname}') - write_hdf5(fname, self._get_state(), - overwrite=overwrite, title='mnepython') + if not fname.endswith("glm.h5"): + raise OSError( + "The filename must end with glm.h5, " f"instead received {fname}" + ) + write_hdf5(fname, self._get_state(), overwrite=overwrite, title="mnepython") def to_dataframe(self, order=None): """Return a tidy dataframe representing the GLM results. @@ -130,12 +134,14 @@ def to_dataframe(self, order=None): Dataframe containing GLM results. """ from ..utils import glm_to_tidy + if order is None: order = self.ch_names return glm_to_tidy(self.info, self._data, self.design, order=order) - def scatter(self, conditions=[], exclude_no_interest=True, axes=None, - no_interest=None): + def scatter( + self, conditions=(), exclude_no_interest=True, axes=None, no_interest=None + ): """Scatter plot of the GLM results. 
Parameters @@ -160,6 +166,7 @@ def scatter(self, conditions=[], exclude_no_interest=True, axes=None, if no_interest is None: no_interest = ["drift", "constant", "short", "Short"] import matplotlib.pyplot as plt + df = self.to_dataframe() x_column = "Condition" @@ -169,12 +176,12 @@ def scatter(self, conditions=[], exclude_no_interest=True, axes=None, y_column = "effect" if len(conditions) == 0: conditions = ["t", "f"] - df = df.query('ContrastType in @conditions') + df = df.query("ContrastType in @conditions") else: if len(conditions) == 0: conditions = self.design.columns - df = df.query('Condition in @conditions') + df = df.query("Condition in @conditions") if exclude_no_interest: for no_i in no_interest: @@ -193,7 +200,7 @@ def scatter(self, conditions=[], exclude_no_interest=True, axes=None, axes.legend(["Oxyhaemoglobin", "Deoxyhaemoglobin"]) axes.hlines([0.0], 0, len(np.unique(df[x_column])) - 1) if len(np.unique(df[x_column])) > 8: - plt.xticks(rotation=45, ha='right') + plt.xticks(rotation=45, ha="right") return axes @@ -224,7 +231,7 @@ def data(self): @data.setter def data(self, data): - if type(data) is not dict: + if not isinstance(data, dict): raise TypeError("Data must be a dictionary type") if not array_equal(list(data.keys()), self.info.ch_names): raise TypeError("Dictionary keys must match ch_names") @@ -233,8 +240,7 @@ def data(self, data): raise TypeError("Data names and channel names do not match") for d in data: if type(data[d]) is not nilearn.glm.regression.RegressionResults: - raise TypeError("Dictionary items must be" - " nilearn RegressionResults") + raise TypeError("Dictionary items must be" " nilearn RegressionResults") self._data = data @@ -248,9 +254,9 @@ def __eq__(self, res): same_keys = self.data.keys() == res.data.keys() same_design = (self.design == res.design).all().all() same_ch = self.info.ch_names == res.info.ch_names - same_theta = np.sum([(res.theta()[idx] == val).all() - for idx, val in - enumerate(self.theta())]) == len(self.ch_names) + same_theta = np.sum( + [(res.theta()[idx] == val).all() for idx, val in enumerate(self.theta())] + ) == len(self.ch_names) return int(same_ch and same_design and same_keys and same_theta) @fill_doc @@ -269,8 +275,7 @@ def pick(self, picks, exclude=()): inst : instance of ResultsGLM The modified instance. """ - picks = _picks_to_idx(self.info, picks, 'all', exclude, - allow_empty=False) + picks = _picks_to_idx(self.info, picks, "all", exclude, allow_empty=False) pick_info(self.info, picks, copy=False) self._data = {key: self._data[key] for key in self.info.ch_names} return self @@ -325,13 +330,21 @@ def compute_contrast(self, contrast, contrast_type=None): Yields the statistics of the contrast (effects, variance, p-values). """ - cont = _compute_contrast(self._data, contrast, - contrast_type=contrast_type) + cont = _compute_contrast(self._data, contrast, contrast_type=contrast_type) return ContrastResults(self.info, cont, self.design) - def plot_topo(self, conditions=None, axes=None, *, vlim=(None, None), - vmin=None, vmax=None, colorbar=True, figsize=(12, 7), - sphere=None): + def plot_topo( + self, + conditions=None, + axes=None, + *, + vlim=(None, None), + vmin=None, + vmax=None, + colorbar=True, + figsize=(12, 7), + sphere=None, + ): """Plot 2D topography of GLM data. Parameters @@ -363,14 +376,23 @@ def plot_topo(self, conditions=None, axes=None, *, vlim=(None, None), Figure of each design matrix component for hbo (top row) and hbr (bottom row). 
""" - return _plot_glm_topo(self.info, self._data, self.design, - requested_conditions=conditions, - axes=axes, vlim=vlim, vmin=vmin, vmax=vmax, - colorbar=colorbar, - figsize=figsize, sphere=sphere) - - def to_dataframe_region_of_interest(self, group_by, condition, - weighted=True, demographic_info=False): + return _plot_glm_topo( + self.info, + self._data, + self.design, + requested_conditions=conditions, + axes=axes, + vlim=vlim, + vmin=vmin, + vmax=vmax, + colorbar=colorbar, + figsize=figsize, + sphere=sphere, + ) + + def to_dataframe_region_of_interest( + self, group_by, condition, weighted=True, demographic_info=False + ): """Region of interest results as a dataframe. Parameters @@ -408,12 +430,12 @@ def to_dataframe_region_of_interest(self, group_by, condition, if isinstance(weighted, dict): if weighted.keys() != group_by.keys(): - raise KeyError("Keys of group_by and weighted " - "must be the same") + raise KeyError("Keys of group_by and weighted " "must be the same") for key in weighted.keys(): if len(weighted[key]) != len(group_by[key]): - raise ValueError("The length of the keys for group_by " - "and weighted must match") + raise ValueError( + "The length of the keys for group_by " "and weighted must match" + ) if (np.array(weighted[key]) < 0).any(): raise ValueError("Weights must be positive values") @@ -421,11 +443,14 @@ def to_dataframe_region_of_interest(self, group_by, condition, for cond in condition: cond_idx = where([c == cond for c in self.design.columns])[0] if not len(cond_idx): - raise KeyError(f'condition {repr(cond)} not found in ' - f'self.design.columns: {self.design.columns}') + raise KeyError( + f"condition {repr(cond)} not found in " + f"self.design.columns: {self.design.columns}" + ) - roi = _glm_region_of_interest(self._data, group_by, - cond_idx, cond, weighted) + roi = _glm_region_of_interest( + self._data, group_by, cond_idx, cond, weighted + ) tidy = pd.concat([tidy, roi]) if weighted is True: @@ -436,30 +461,41 @@ def to_dataframe_region_of_interest(self, group_by, condition, tidy["Weighted"] = "Custom" if demographic_info: - if 'age' in self.info['subject_info'].keys(): - tidy['Age'] = float(self.info["subject_info"]['age']) - if 'sex' in self.info['subject_info'].keys(): - if self.info["subject_info"]['sex'] == \ - FIFF.FIFFV_SUBJ_SEX_MALE: + if "age" in self.info["subject_info"].keys(): + tidy["Age"] = float(self.info["subject_info"]["age"]) + if "sex" in self.info["subject_info"].keys(): + if self.info["subject_info"]["sex"] == FIFF.FIFFV_SUBJ_SEX_MALE: sex = "male" - elif self.info["subject_info"]['sex'] == \ - FIFF.FIFFV_SUBJ_SEX_FEMALE: + elif self.info["subject_info"]["sex"] == FIFF.FIFFV_SUBJ_SEX_FEMALE: sex = "female" else: sex = "unknown" - tidy['Sex'] = sex - if 'Hand' in self.info['subject_info'].keys(): - tidy['Hand'] = self.info["subject_info"]['hand'] + tidy["Sex"] = sex + if "Hand" in self.info["subject_info"].keys(): + tidy["Hand"] = self.info["subject_info"]["hand"] return tidy @verbose - def surface_projection(self, chroma="hbo", condition=None, - background='w', figure=None, clim='auto', - mode='weighted', colormap='RdBu_r', - surface='pial', hemi='both', size=800, - view=None, colorbar=True, distance=0.03, - subjects_dir=None, src=None, verbose=False): + def surface_projection( + self, + chroma="hbo", + condition=None, + background="w", + figure=None, + clim="auto", + mode="weighted", + colormap="RdBu_r", + surface="pial", + hemi="both", + size=800, + view=None, + colorbar=True, + distance=0.03, + subjects_dir=None, + src=None, + 
verbose=False, + ): """ Project GLM results on to the surface of the brain. @@ -525,7 +561,6 @@ def surface_projection(self, chroma="hbo", condition=None, figure : instance of mne.viz.Brain | matplotlib.figure.Figure An instance of :class:`mne.viz.Brain` or matplotlib figure. """ - df = self.to_dataframe(order=self.ch_names) if condition is None: warn("You must provide a condition to plot", ValueError) @@ -533,24 +568,34 @@ def surface_projection(self, chroma="hbo", condition=None, if len(df_use) == 0: raise KeyError( f'condition={repr(condition)} not found in conditions: ' - f'{sorted(set(df["Condition"]))}') + f'{sorted(set(df["Condition"]))}' + ) df = df_use df = df.query("Chroma in @chroma").copy() df["theta"] = df["theta"] * 1e6 info = self.copy().pick(chroma).info - return plot_glm_surface_projection(info, df, value="theta", - picks=chroma, background=background, - figure=figure, clim=clim, - mode=mode, colormap=colormap, - surface=surface, hemi=hemi, - size=size, - view=view, colorbar=colorbar, - distance=distance, - subjects_dir=subjects_dir, src=src, - verbose=verbose - ) + return plot_glm_surface_projection( + info, + df, + value="theta", + picks=chroma, + background=background, + figure=figure, + clim=clim, + mode=mode, + colormap=colormap, + surface=surface, + hemi=hemi, + size=size, + view=view, + colorbar=colorbar, + distance=distance, + subjects_dir=subjects_dir, + src=src, + verbose=verbose, + ) @fill_doc @@ -593,8 +638,9 @@ def data(self, data): if not isinstance(data, nilearn.glm.contrasts.Contrast): raise TypeError("Data must be a nilearn glm contrast type") if data.effect.size != len(self.info.ch_names): - raise TypeError("Data results must be the same length " - "as the number of channels") + raise TypeError( + "Data results must be the same length " "as the number of channels" + ) self._data = data @@ -615,12 +661,12 @@ def plot_topo(self, figsize=(12, 7), sphere=None): Figure of each design matrix component for hbo (top row) and hbr (bottom row). """ - return _plot_glm_contrast_topo(self.info, self._data, - figsize=figsize, sphere=sphere) + return _plot_glm_contrast_topo( + self.info, self._data, figsize=figsize, sphere=sphere + ) -def run_GLM(raw, design_matrix, noise_model='ar1', bins=0, - n_jobs=1, verbose=0): +def run_GLM(raw, design_matrix, noise_model="ar1", bins=0, n_jobs=1, verbose=0): """ Run GLM on data using supplied design matrix. @@ -649,23 +695,31 @@ def run_GLM(raw, design_matrix, noise_model='ar1', bins=0, 'all CPUs'. verbose : int, optional The verbosity level. Default is 0. + Returns ------- glm_estimates : dict Keys correspond to the different labels values values are RegressionResults instances corresponding to the voxels. """ - warn('"run_GLM" has been deprecated in favor of the more ' - 'comprehensive run_glm function, and will be removed in v1.0.0. ' - 'See the changelog for further details.', - DeprecationWarning) - res = run_glm(raw, design_matrix, noise_model=noise_model, bins=bins, - n_jobs=n_jobs, verbose=verbose) + warn( + '"run_GLM" has been deprecated in favor of the more ' + "comprehensive run_glm function, and will be removed in v1.0.0. 
" + "See the changelog for further details.", + DeprecationWarning, + ) + res = run_glm( + raw, + design_matrix, + noise_model=noise_model, + bins=bins, + n_jobs=n_jobs, + verbose=verbose, + ) return res.data -def run_glm(raw, design_matrix, noise_model='ar1', bins=0, - n_jobs=1, verbose=0): +def run_glm(raw, design_matrix, noise_model="ar1", bins=0, n_jobs=1, verbose=0): """ GLM fit for an MNE structure containing fNIRS data. @@ -704,10 +758,10 @@ def run_glm(raw, design_matrix, noise_model='ar1', bins=0, glm_estimates : RegressionResults RegressionResults class which stores the GLM results. """ - picks = _picks_to_idx(raw.info, 'fnirs', exclude=[], allow_empty=True) + picks = _picks_to_idx(raw.info, "fnirs", exclude=[], allow_empty=True) ch_names = raw.ch_names - if noise_model == 'auto': + if noise_model == "auto": noise_model = f"ar{int(np.round(raw.info['sfreq'] * 4))}" if bins == 0: @@ -715,10 +769,14 @@ def run_glm(raw, design_matrix, noise_model='ar1', bins=0, results = dict() for pick in picks: - labels, glm_estimates = nilearn_glm(raw.get_data(pick).T, - design_matrix.values, - noise_model=noise_model, bins=bins, - n_jobs=n_jobs, verbose=verbose) + labels, glm_estimates = nilearn_glm( + raw.get_data(pick).T, + design_matrix.values, + noise_model=noise_model, + bins=bins, + n_jobs=n_jobs, + verbose=verbose, + ) results[ch_names[pick]] = glm_estimates[labels[0]] return RegressionResults(raw.info, results, design_matrix) @@ -752,8 +810,8 @@ def read_glm(fname): RegressionResults or ContrastResults class which stores the GLM results. """ - check_fname(fname, 'path-like', 'glm.h5') - glm = read_hdf5(fname, title='mnepython') + check_fname(fname, "path-like", "glm.h5") + glm = read_hdf5(fname, title="mnepython") return _state_to_glm(glm) @@ -772,57 +830,64 @@ def _state_to_glm(glm): RegressionResults or ContrastResults class which stores the GLM results. 
""" - - if glm['classname'] == "": - - for channel in glm['data']: - + if ( + glm["classname"] == "" + ): + for channel in glm["data"]: # Recreate model type - if glm['data'][channel]['modelname'] == \ - "": + if ( + glm["data"][channel]["modelname"] + == "" + ): model = nilearn.glm.regression.ARModel( - glm['data'][channel]['model']['design'], - glm['data'][channel]['model']['rho'], + glm["data"][channel]["model"]["design"], + glm["data"][channel]["model"]["rho"], ) - elif glm['data'][channel]['modelname'] == \ - "": + elif ( + glm["data"][channel]["modelname"] + == "" + ): model = nilearn.glm.regression.OLSModel( - glm['data'][channel]['model']['design'], + glm["data"][channel]["model"]["design"], ) else: - raise IOError("Unknown model type " - f"{glm['data'][channel]['modelname']}") + raise OSError( + "Unknown model type " f"{glm['data'][channel]['modelname']}" + ) - for key in glm['data'][channel]['model']: - model.__setattr__(key, glm['data'][channel]['model'][key]) - glm['data'][channel]['model'] = model + for key in glm["data"][channel]["model"]: + model.__setattr__(key, glm["data"][channel]["model"][key]) + glm["data"][channel]["model"] = model # Then recreate result type res = nilearn.glm.regression.RegressionResults( - glm['data'][channel]['theta'], - glm['data'][channel]['Y'], - glm['data'][channel]['model'], - glm['data'][channel]['whitened_Y'], - glm['data'][channel]['whitened_residuals'], - cov=glm['data'][channel]['cov'] + glm["data"][channel]["theta"], + glm["data"][channel]["Y"], + glm["data"][channel]["model"], + glm["data"][channel]["whitened_Y"], + glm["data"][channel]["whitened_residuals"], + cov=glm["data"][channel]["cov"], ) - for key in glm['data'][channel]: - res.__setattr__(key, glm['data'][channel][key]) - glm['data'][channel] = res + for key in glm["data"][channel]: + res.__setattr__(key, glm["data"][channel][key]) + glm["data"][channel] = res # Ensure order of dictionary matches info - data = {k: glm['data'][k] for k in glm['info']['ch_names']} - return RegressionResults(Info(glm['info']), data, glm['design']) - - elif glm['classname'] == "": - data = nilearn.glm.contrasts.Contrast(glm['data']['effect'], - glm['data']['variance']) - for key in glm['data']: - data.__setattr__(key, glm['data'][key]) - return ContrastResults(Info(glm['info']), data, glm['design']) + data = {k: glm["data"][k] for k in glm["info"]["ch_names"]} + return RegressionResults(Info(glm["info"]), data, glm["design"]) + + elif ( + glm["classname"] == "" + ): + data = nilearn.glm.contrasts.Contrast( + glm["data"]["effect"], glm["data"]["variance"] + ) + for key in glm["data"]: + data.__setattr__(key, glm["data"][key]) + return ContrastResults(Info(glm["info"]), data, glm["design"]) else: - raise IOError('Unable to read data') + raise OSError("Unable to read data") diff --git a/mne_nirs/statistics/_roi.py b/mne_nirs/statistics/_roi.py index 6720bb004..ce95b41ac 100644 --- a/mne_nirs/statistics/_roi.py +++ b/mne_nirs/statistics/_roi.py @@ -3,12 +3,10 @@ # License: BSD (3-clause) import numpy as np - from mne.utils import warn -def glm_region_of_interest(glm, group_by, cond_idx, - cond_name, weighted=True): +def glm_region_of_interest(glm, group_by, cond_idx, cond_name, weighted=True): """ Calculate statistics for region of interest. @@ -37,18 +35,20 @@ def glm_region_of_interest(glm, group_by, cond_idx, stats : DataFrame Statistics for each ROI. """ - warn('"glm_region_of_interest" has been deprecated in favor of the more ' - 'comprehensive GLM class and will be removed in v1.0.0. 
' - 'Use the RegressionResults class "region_of_interest_dataframe()" ' - 'method instead.', - DeprecationWarning) + warn( + '"glm_region_of_interest" has been deprecated in favor of the more ' + "comprehensive GLM class and will be removed in v1.0.0. " + 'Use the RegressionResults class "region_of_interest_dataframe()" ' + "method instead.", + DeprecationWarning, + ) - return _glm_region_of_interest(glm, group_by, - cond_idx, cond_name, weighted=weighted) + return _glm_region_of_interest( + glm, group_by, cond_idx, cond_name, weighted=weighted + ) -def _glm_region_of_interest(stats, group_by, cond_idx, - cond_name, weighted=True): +def _glm_region_of_interest(stats, group_by, cond_idx, cond_name, weighted=True): """ Calculate statistics for region of interest. @@ -85,8 +85,8 @@ def _glm_region_of_interest(stats, group_by, cond_idx, stats : DataFrame Statistics for each ROI. """ - from scipy import stats as ss import pandas as pd + from scipy import stats as ss df = pd.DataFrame() @@ -94,7 +94,6 @@ def _glm_region_of_interest(stats, group_by, cond_idx, chromas = np.array([name[-3:] for name in ch_names]) for region in group_by: - if isinstance(weighted, dict): weights_region = weighted[region] @@ -102,7 +101,6 @@ def _glm_region_of_interest(stats, group_by, cond_idx, picks = group_by[region] for chroma in np.unique(chromas[picks]): - chroma_idxs = np.where([c == chroma for c in chromas[picks]])[0] chroma_picks = [picks[ci] for ci in chroma_idxs] @@ -117,7 +115,7 @@ def _glm_region_of_interest(stats, group_by, cond_idx, # Apply weighting by standard error or custom values if weighted is True: - weights = 1. / np.asarray(ses) + weights = 1.0 / np.asarray(ses) elif weighted is False: weights = np.ones((len(ses), 1)) elif isinstance(weighted, dict): @@ -136,14 +134,18 @@ def _glm_region_of_interest(stats, group_by, cond_idx, p = 2 * ss.t.cdf(-1.0 * np.abs(t), df=dfe) this_df = pd.DataFrame( - {'ROI': roi_name, - 'Condition': cond_name, - 'Chroma': chroma, - 'theta': theta / 1.0e6, - 'se': s, - 't': t, - 'dfe': dfe, - 'p': p, }, index=[0]) + { + "ROI": roi_name, + "Condition": cond_name, + "Chroma": chroma, + "theta": theta / 1.0e6, + "se": s, + "t": t, + "dfe": dfe, + "p": p, + }, + index=[0], + ) df = pd.concat([df, this_df], ignore_index=True) df.reset_index(inplace=True, drop=True) diff --git a/mne_nirs/statistics/_statsmodels.py b/mne_nirs/statistics/_statsmodels.py index 75107d716..ddf2c3879 100644 --- a/mne_nirs/statistics/_statsmodels.py +++ b/mne_nirs/statistics/_statsmodels.py @@ -4,16 +4,16 @@ from io import StringIO -import pandas as pd import numpy as np +import pandas as pd def summary_to_dataframe(summary): - '''Convert statsmodels summary to pandas dataframe. + """Convert statsmodels summary to pandas dataframe. .. warning:: The summary has precision issues, use the numerical values from it with caution. 
- ''' + """ results = summary.tables[1] if type(results) is not pd.core.frame.DataFrame: results = StringIO(summary.tables[1].as_html()) @@ -22,52 +22,51 @@ def summary_to_dataframe(summary): def expand_summary_dataframe(summary): - '''Expand dataframe index column in to individual columns''' - + """Expand dataframe index column in to individual columns.""" # Determine new columns - new_cols = summary.index[0].split(':') + new_cols = summary.index[0].split(":") col_names = [] for col in new_cols: - col_name = col.split('[')[0] - summary[col_name] = 'NaN' + col_name = col.split("[")[0] + summary[col_name] = "NaN" col_names.append(col_name) # Fill in values - if 'Group Var' in summary.index: + if "Group Var" in summary.index: summary = summary[:-1] summary = summary.copy(deep=True) indices = summary.index for row_idx, row in enumerate(indices): - col_vals = row.split(':') + col_vals = row.split(":") for col_idx, col in enumerate(col_names): if "]" in col_vals[col_idx]: - val = col_vals[col_idx].split('[')[1].split(']')[0] + val = col_vals[col_idx].split("[")[1].split("]")[0] else: val = col summary.at[row, col] = val summary = summary.copy() # Copies required to suppress .loc warnings sum_copy = summary.copy(deep=True) - key = 'P>|t|' if 'P>|t|' in summary.columns else 'P>|z|' + key = "P>|t|" if "P>|t|" in summary.columns else "P>|z|" float_p = [float(p) for p in sum_copy[key]] summary.loc[:, key] = float_p summary.loc[:, "Significant"] = False - summary.loc[summary[key] < 0.05, 'Significant'] = True + summary.loc[summary[key] < 0.05, "Significant"] = True # Standardise returned column name, it seems to vary per test - if 'Coef.' in summary.columns: + if "Coef." in summary.columns: summary.loc[:, "Coef."] = [float(c) for c in summary["Coef."]] - elif 'coef' in summary.columns: + elif "coef" in summary.columns: summary = summary.rename(columns={"coef": "Coef."}) return summary _REPLACEMENTS = ( - ('P>|z|', 'pvalues'), - ('Coef.', 'fe_params'), - ('z', 'tvalues'), - ('P>|t|', 'pvalues'), + ("P>|z|", "pvalues"), + ("Coef.", "fe_params"), + ("z", "tvalues"), + ("P>|t|", "pvalues"), ) @@ -87,8 +86,9 @@ def statsmodels_to_results(model, order=None): df : Pandas dataframe. Data frame with the results from the stats model. 
""" - from statsmodels.regression.mixed_linear_model import MixedLMResultsWrapper from scipy.stats.distributions import norm + from statsmodels.regression.mixed_linear_model import MixedLMResultsWrapper + df = summary_to_dataframe(model.summary()) # deal with numerical precision loss in at least some of the values for col, attr in _REPLACEMENTS: @@ -98,23 +98,22 @@ def statsmodels_to_results(model, order=None): # This one messes up the standard error and quartiles, too if isinstance(model, MixedLMResultsWrapper): sl = slice(model.k_fe) - mu = np.asarray(df.iloc[sl, df.columns == 'Coef.'])[:, 0] + mu = np.asarray(df.iloc[sl, df.columns == "Coef."])[:, 0] # Adapted from statsmodels, see # https://github.com/statsmodels/statsmodels/blob/master/statsmodels/regression/mixed_linear_model.py#L2710-L2736 # noqa: E501 stderr = np.sqrt(np.diag(model.cov_params()[sl])) - df.iloc[sl, df.columns == 'Std.Err.'] = stderr + df.iloc[sl, df.columns == "Std.Err."] = stderr # Confidence intervals qm = -norm.ppf(0.05 / 2) - df.iloc[sl, df.columns == '[0.025'] = mu - qm * stderr - df.iloc[sl, df.columns == '0.975]'] = mu + qm * stderr + df.iloc[sl, df.columns == "[0.025"] = mu - qm * stderr + df.iloc[sl, df.columns == "0.975]"] = mu + qm * stderr # All random effects variances and covariances sdf = np.zeros((model.k_re2 + model.k_vc, 2)) jj = 0 for i in range(model.k_re): for j in range(i + 1): sdf[jj, 0] = np.asarray(model.cov_re)[i, j] - sdf[jj, 1] = np.sqrt(model.scale) * \ - model.bse.iloc[model.k_fe + jj] + sdf[jj, 1] = np.sqrt(model.scale) * model.bse.iloc[model.k_fe + jj] jj += 1 # Variance components @@ -123,18 +122,18 @@ def statsmodels_to_results(model, order=None): sdf[jj, 1] = np.sqrt(model.scale) * model.bse[model.k_fe + jj] jj += 1 - df.iloc[model.k_fe:, df.columns == 'Coef.'] = sdf[:, 0] - df.iloc[model.k_fe:, df.columns == 'Std.Err.'] = sdf[:, 1] + df.iloc[model.k_fe :, df.columns == "Coef."] = sdf[:, 0] + df.iloc[model.k_fe :, df.columns == "Std.Err."] = sdf[:, 1] df = expand_summary_dataframe(df) if order is not None: - df['old_index'] = df.index - df = df.set_index('ch_name') + df["old_index"] = df.index + df = df.set_index("ch_name") df = df.loc[order, :] - df['ch_name'] = df.index - df.index = df['old_index'] - df.drop(columns='old_index', inplace=True) + df["ch_name"] = df.index + df.index = df["old_index"] + df.drop(columns="old_index", inplace=True) df.rename_axis(None, inplace=True) return df diff --git a/mne_nirs/statistics/tests/test_glm_type.py b/mne_nirs/statistics/tests/test_glm_type.py index 1e3e26c57..68be4b207 100644 --- a/mne_nirs/statistics/tests/test_glm_type.py +++ b/mne_nirs/statistics/tests/test_glm_type.py @@ -4,64 +4,65 @@ import os +import matplotlib +import mne +import nilearn +import numpy as np import pandas import pytest -import numpy as np -import matplotlib from matplotlib.pyplot import Axes - -import mne from mne.datasets import testing from mne.fixes import _compare_version -import nilearn -from mne_nirs.statistics import RegressionResults, read_glm from mne_nirs.experimental_design import make_first_level_design_matrix -from mne_nirs.statistics import run_glm +from mne_nirs.statistics import RegressionResults, read_glm, run_glm data_path = testing.data_path(download=False) -subjects_dir = data_path / '/subjects' +subjects_dir = data_path / "/subjects" def _get_minimal_haemo_data(tmin=0, tmax=60): - raw = mne.io.read_raw_nirx(os.path.join( - mne.datasets.fnirs_motor.data_path(), 'Participant-1'), preload=False) + raw = mne.io.read_raw_nirx( + 
os.path.join(mne.datasets.fnirs_motor.data_path(), "Participant-1"), + preload=False, + ) raw.crop(tmax=tmax, tmin=tmin) raw = mne.preprocessing.nirs.optical_density(raw) raw = mne.preprocessing.nirs.beer_lambert_law(raw, ppf=0.1) raw.resample(0.3) raw.annotations.description[:] = [ - 'e' + d.replace('.', 'p') for d in raw.annotations.description] + "e" + d.replace(".", "p") for d in raw.annotations.description + ] return raw -def _get_glm_result(tmax=60, tmin=0, noise_model='ar1'): +def _get_glm_result(tmax=60, tmin=0, noise_model="ar1"): raw = _get_minimal_haemo_data(tmin=tmin, tmax=tmax) - design_matrix = make_first_level_design_matrix(raw, stim_dur=5., - drift_order=1, - drift_model='polynomial') + design_matrix = make_first_level_design_matrix( + raw, stim_dur=5.0, drift_order=1, drift_model="polynomial" + ) return run_glm(raw, design_matrix, noise_model=noise_model) def _get_glm_contrast_result(tmin=60, tmax=400): raw = _get_minimal_haemo_data(tmin=tmin, tmax=tmax) - design_matrix = make_first_level_design_matrix(raw, stim_dur=5., - drift_order=1, - drift_model='polynomial') + design_matrix = make_first_level_design_matrix( + raw, stim_dur=5.0, drift_order=1, drift_model="polynomial" + ) glm_est = run_glm(raw, design_matrix) contrast_matrix = np.eye(design_matrix.shape[1]) - basic_conts = dict([(column, contrast_matrix[i]) - for i, column in enumerate(design_matrix.columns)]) - assert 'e1p' in basic_conts, sorted(basic_conts) - contrast_LvR = basic_conts['e1p'] - basic_conts['e2p'] + basic_conts = dict( + [(column, contrast_matrix[i]) for i, column in enumerate(design_matrix.columns)] + ) + assert "e1p" in basic_conts, sorted(basic_conts) + contrast_LvR = basic_conts["e1p"] - basic_conts["e2p"] return glm_est.compute_contrast(contrast_LvR) def test_create_results_glm(): - # Create a relevant info structure raw = _get_minimal_haemo_data() @@ -73,20 +74,20 @@ def test_create_results_glm(): minimal_structure = res._data # Test construction - with pytest.raises(TypeError, match='must be a dictionary'): + with pytest.raises(TypeError, match="must be a dictionary"): _ = RegressionResults(info, np.zeros((5, 2)), 1) - with pytest.raises(TypeError, match='must be a dictionary'): + with pytest.raises(TypeError, match="must be a dictionary"): _ = RegressionResults(info, 3.2, 1) - with pytest.raises(TypeError, match='must be a dictionary'): + with pytest.raises(TypeError, match="must be a dictionary"): _ = RegressionResults(info, [], 1) - with pytest.raises(TypeError, match='must be a dictionary'): + with pytest.raises(TypeError, match="must be a dictionary"): _ = RegressionResults(info, "hbo", 1) - with pytest.raises(TypeError, match='keys must match'): + with pytest.raises(TypeError, match="keys must match"): _ = RegressionResults(info, _take(4, minimal_structure), 1) onewrongname = _take(55, minimal_structure) onewrongname["test"] = onewrongname["S1_D1 hbo"] - with pytest.raises(TypeError, match='must match ch_names'): + with pytest.raises(TypeError, match="must match ch_names"): _ = RegressionResults(info, onewrongname, 1) # Test properties @@ -94,15 +95,14 @@ def test_create_results_glm(): def test_results_glm_properties(): - n_channels = 56 res = _get_glm_result() # Test ContainsMixin - assert 'hbo' in res - assert 'hbr' in res - assert 'meg' not in res + assert "hbo" in res + assert "hbr" in res + assert "meg" not in res # Test copy assert len(res) == len(res.copy()) @@ -113,10 +113,10 @@ def test_results_glm_properties(): assert len(res.copy().pick(picks=["S1_D1 hbr"])) == 1 assert 
len(res.copy().pick(picks=["S1_D1 hbr", "S1_D1 hbo"])) == 2 - if _compare_version(mne.__version__, '<', '1.4.0.dev140'): - ctx = pytest.warns(RuntimeWarning, match='could not be picked') + if _compare_version(mne.__version__, "<", "1.4.0.dev140"): + ctx = pytest.warns(RuntimeWarning, match="could not be picked") else: - ctx = pytest.raises(ValueError, match='could not be picked') + ctx = pytest.raises(ValueError, match="could not be picked") with ctx: assert len(res.copy().pick(picks=["S1_D1 hbr", "S1_D1 XXX"])) == 1 @@ -139,13 +139,11 @@ def test_results_glm_properties(): def test_glm_scatter(): - assert isinstance(_get_glm_result().scatter(), Axes) assert isinstance(_get_glm_contrast_result().scatter(), Axes) def test_results_glm_export_dataframe(): - n_channels = 56 res = _get_glm_result(tmax=400) @@ -155,7 +153,6 @@ def test_results_glm_export_dataframe(): def test_results_glm_export_dataframe_region_of_interest(): - res = _get_glm_result(tmax=400) # Create ROI with hbo only @@ -163,7 +160,7 @@ def test_results_glm_export_dataframe_region_of_interest(): rois["A"] = [0, 2, 4] # Single ROI, single condition - with pytest.raises(KeyError, match=r'not found in self\.design\.col'): + with pytest.raises(KeyError, match=r"not found in self\.design\.col"): res.to_dataframe_region_of_interest(rois, "1.0") df = res.to_dataframe_region_of_interest(rois, "e1p0") assert df.shape == (1, 9) @@ -216,14 +213,14 @@ def test_results_glm_export_dataframe_region_of_interest(): df = res.to_dataframe_region_of_interest(rois, "e1p0") # With demographic information - df = res.to_dataframe_region_of_interest(rois, ["e1p0", "e3p0", "drift_1"], - demographic_info=True) + df = res.to_dataframe_region_of_interest( + rois, ["e1p0", "e3p0", "drift_1"], demographic_info=True + ) assert df.shape == (12, 10) assert "Sex" in df.columns def test_results_glm_export_dataframe_region_of_interest_weighted(): - res = _get_glm_result(tmax=400) res_df = res.to_dataframe().query("Condition == 'e1p0'") assert len(res_df) @@ -240,8 +237,7 @@ def test_results_glm_export_dataframe_region_of_interest_weighted(): assert df_uw.Weighted[0] == "Equal" thetas = np.array(res_df.theta) # unweighted option should be the same as a simple mean - assert np.isclose(df_uw.query("ROI == 'A'").theta, - thetas[rois["A"]].mean()) + assert np.isclose(df_uw.query("ROI == 'A'").theta, thetas[rois["A"]].mean()) df_w = res.to_dataframe_region_of_interest(rois, "e1p0", weighted=True) assert df_w.shape == (4, 9) @@ -264,25 +260,21 @@ def test_results_glm_export_dataframe_region_of_interest_weighted(): assert df.theta[2] > 0 assert df.theta[3] < 0 - with pytest.raises(ValueError, match='must be positive'): + with pytest.raises(ValueError, match="must be positive"): weights["C"] = [16, 7, -8, 9] - _ = res.to_dataframe_region_of_interest(rois, "e1p0", - weighted=weights) + _ = res.to_dataframe_region_of_interest(rois, "e1p0", weighted=weights) - with pytest.raises(ValueError, match='length of the keys'): + with pytest.raises(ValueError, match="length of the keys"): weights["C"] = [16, 7] - _ = res.to_dataframe_region_of_interest(rois, "e1p0", - weighted=weights) + _ = res.to_dataframe_region_of_interest(rois, "e1p0", weighted=weights) - with pytest.raises(KeyError, match='Keys of group_by and weighted'): + with pytest.raises(KeyError, match="Keys of group_by and weighted"): bad_weights = dict() bad_weights["Z"] = [0, 2, 4] - _ = res.to_dataframe_region_of_interest(rois, "e1p0", - weighted=bad_weights) + _ = res.to_dataframe_region_of_interest(rois, "e1p0", 
weighted=bad_weights) def test_create_results_glm_contrast(): - # Create a minimal structure res = _get_glm_contrast_result() assert isinstance(res._data, nilearn.glm.contrasts.Contrast) @@ -303,14 +295,13 @@ def test_create_results_glm_contrast(): def test_results_glm_io(): - res = _get_glm_result(tmax=400) res.save("test-regression-glm.h5", overwrite=True) loaded_res = read_glm("test-regression-glm.h5") assert loaded_res.to_dataframe().equals(res.to_dataframe()) assert res == loaded_res - res = _get_glm_result(tmax=400, noise_model='ols') + res = _get_glm_result(tmax=400, noise_model="ols") res.save("test-regression-ols_glm.h5", overwrite=True) loaded_res = read_glm("test-regression-ols_glm.h5") assert loaded_res.to_dataframe().equals(res.to_dataframe()) @@ -322,7 +313,7 @@ def test_results_glm_io(): assert loaded_res.to_dataframe().equals(res.to_dataframe()) assert res == loaded_res - with pytest.raises(IOError, match='must end with glm.h5'): + with pytest.raises(IOError, match="must end with glm.h5"): res.save("test-contrast-glX.h5", overwrite=True) diff --git a/mne_nirs/statistics/tests/test_statistics.py b/mne_nirs/statistics/tests/test_statistics.py index 590412d24..6431cd5a4 100644 --- a/mne_nirs/statistics/tests/test_statistics.py +++ b/mne_nirs/statistics/tests/test_statistics.py @@ -2,76 +2,75 @@ # # License: BSD (3-clause) -import pytest -import numpy as np import nilearn - +import numpy as np +import pytest from mne import Covariance from mne.simulation import add_noise from mne_nirs.experimental_design import make_first_level_design_matrix -from mne_nirs.statistics import run_glm, run_GLM from mne_nirs.simulation import simulate_nirs_raw +from mne_nirs.statistics import run_GLM, run_glm -iir_filter = [1., -0.58853134, -0.29575669, -0.52246482, 0.38735476, 0.024286] +iir_filter = [1.0, -0.58853134, -0.29575669, -0.52246482, 0.38735476, 0.024286] def test_run_GLM(): - raw = simulate_nirs_raw(sig_dur=200, stim_dur=5.) 
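# --- Editor's note: illustrative sketch, not part of the patch. It spells
# out the pipeline this test exercises: simulate a known 1 uM response,
# build a design matrix, and fit it with run_glm(). Per the run_glm change
# above, noise_model="auto" picks an AR model of order round(sfreq * 4),
# while the default stays "ar1". Assumes mne-nirs and nilearn are installed.
from mne_nirs.experimental_design import make_first_level_design_matrix
from mne_nirs.simulation import simulate_nirs_raw
from mne_nirs.statistics import run_glm

raw_sim = simulate_nirs_raw(sig_dur=200, stim_dur=5.0)
dm = make_first_level_design_matrix(
    raw_sim, stim_dur=5.0, drift_order=1, drift_model="polynomial"
)
est = run_glm(raw_sim, dm)  # RegressionResults, one fit per channel
theta = est.pick("Simulated").theta()[0][0]  # ~1e-6, as the test asserts
# --- End editor's note. ---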
- design_matrix = make_first_level_design_matrix(raw, stim_dur=5., - drift_order=1, - drift_model='polynomial') + raw = simulate_nirs_raw(sig_dur=200, stim_dur=5.0) + design_matrix = make_first_level_design_matrix( + raw, stim_dur=5.0, drift_order=1, drift_model="polynomial" + ) glm_estimates = run_glm(raw, design_matrix) # Test backwards compatibility - with pytest.deprecated_call(match='more comprehensive'): + with pytest.deprecated_call(match="more comprehensive"): old_res = run_GLM(raw, design_matrix) assert old_res.keys() == glm_estimates.data.keys() - assert (old_res["Simulated"].theta == - glm_estimates.data["Simulated"].theta).all() + assert (old_res["Simulated"].theta == glm_estimates.data["Simulated"].theta).all() assert len(glm_estimates) == len(raw.ch_names) # Check the estimate is correct within 10% error - assert abs(glm_estimates.pick("Simulated").theta()[0][0] - 1.e-6) < 0.1e-6 + assert abs(glm_estimates.pick("Simulated").theta()[0][0] - 1.0e-6) < 0.1e-6 # ensure we return the same type as nilearn to encourage compatibility - _, ni_est = nilearn.glm.first_level.run_glm( - raw.get_data(0).T, design_matrix.values) + _, ni_est = nilearn.glm.first_level.run_glm(raw.get_data(0).T, design_matrix.values) assert isinstance(glm_estimates._data, type(ni_est)) def test_run_GLM_order(): - raw = simulate_nirs_raw(sig_dur=200, stim_dur=5., sfreq=3) - design_matrix = make_first_level_design_matrix(raw, stim_dur=5., - drift_order=1, - drift_model='polynomial') + raw = simulate_nirs_raw(sig_dur=200, stim_dur=5.0, sfreq=3) + design_matrix = make_first_level_design_matrix( + raw, stim_dur=5.0, drift_order=1, drift_model="polynomial" + ) # Default should be first order AR glm_estimates = run_glm(raw, design_matrix) assert glm_estimates.pick("Simulated").model()[0].order == 1 # Default should be first order AR - glm_estimates = run_glm(raw, design_matrix, noise_model='ar2') + glm_estimates = run_glm(raw, design_matrix, noise_model="ar2") assert glm_estimates.pick("Simulated").model()[0].order == 2 - glm_estimates = run_glm(raw, design_matrix, noise_model='ar7') + glm_estimates = run_glm(raw, design_matrix, noise_model="ar7") assert glm_estimates.pick("Simulated").model()[0].order == 7 # Auto should be 4 times sample rate - cov = Covariance(np.ones(1) * 1e-11, raw.ch_names, - raw.info['bads'], raw.info['projs'], nfree=0) + cov = Covariance( + np.ones(1) * 1e-11, raw.ch_names, raw.info["bads"], raw.info["projs"], nfree=0 + ) raw = add_noise(raw, cov, iir_filter=iir_filter) - glm_estimates = run_glm(raw, design_matrix, noise_model='auto') + glm_estimates = run_glm(raw, design_matrix, noise_model="auto") assert glm_estimates.pick("Simulated").model()[0].order == 3 * 4 - raw = simulate_nirs_raw(sig_dur=10, stim_dur=5., sfreq=2) - cov = Covariance(np.ones(1) * 1e-11, raw.ch_names, - raw.info['bads'], raw.info['projs'], nfree=0) + raw = simulate_nirs_raw(sig_dur=10, stim_dur=5.0, sfreq=2) + cov = Covariance( + np.ones(1) * 1e-11, raw.ch_names, raw.info["bads"], raw.info["projs"], nfree=0 + ) raw = add_noise(raw, cov, iir_filter=iir_filter) - design_matrix = make_first_level_design_matrix(raw, stim_dur=5., - drift_order=1, - drift_model='polynomial') + design_matrix = make_first_level_design_matrix( + raw, stim_dur=5.0, drift_order=1, drift_model="polynomial" + ) # Auto should be 4 times sample rate - glm_estimates = run_glm(raw, design_matrix, noise_model='auto') + glm_estimates = run_glm(raw, design_matrix, noise_model="auto") assert glm_estimates.pick("Simulated").model()[0].order == 2 * 4 diff 
--git a/mne_nirs/statistics/tests/test_statsmodels.py b/mne_nirs/statistics/tests/test_statsmodels.py index d67db43ef..249c0a6b2 100644 --- a/mne_nirs/statistics/tests/test_statsmodels.py +++ b/mne_nirs/statistics/tests/test_statsmodels.py @@ -3,25 +3,23 @@ # License: BSD (3-clause) import numpy as np -from numpy.testing import assert_allclose - -import pytest import pandas as pd +import pytest import statsmodels.formula.api as smf - from mne.utils import check_version +from numpy.testing import assert_allclose -from mne_nirs.simulation import simulate_nirs_raw from mne_nirs.experimental_design import make_first_level_design_matrix +from mne_nirs.simulation import simulate_nirs_raw from mne_nirs.statistics import run_glm, statsmodels_to_results -@pytest.mark.skipif(not check_version('lxml'), reason='Requires lxml') -@pytest.mark.parametrize('func', ('mixedlm', 'ols', 'rlm')) -@pytest.mark.filterwarnings('ignore:.*optimization.*:') -@pytest.mark.filterwarnings('ignore:.*unknown kwargs.*:') -@pytest.mark.filterwarnings('ignore:.*on the boundary.*:') -@pytest.mark.filterwarnings('ignore:.*The Hessian matrix at the estimated.*:') +@pytest.mark.skipif(not check_version("lxml"), reason="Requires lxml") +@pytest.mark.parametrize("func", ("mixedlm", "ols", "rlm")) +@pytest.mark.filterwarnings("ignore:.*optimization.*:") +@pytest.mark.filterwarnings("ignore:.*unknown kwargs.*:") +@pytest.mark.filterwarnings("ignore:.*on the boundary.*:") +@pytest.mark.filterwarnings("ignore:.*The Hessian matrix at the estimated.*:") def test_statsmodel_to_df(func): func = getattr(smf, func) np.random.seed(0) @@ -30,28 +28,30 @@ def test_statsmodel_to_df(func): df_cha = pd.DataFrame() for n in range(5): - - raw = simulate_nirs_raw(sfreq=3., amplitude=amplitude, - sig_dur=300., stim_dur=5., - isi_min=15., isi_max=45.) + raw = simulate_nirs_raw( + sfreq=3.0, + amplitude=amplitude, + sig_dur=300.0, + stim_dur=5.0, + isi_min=15.0, + isi_max=45.0, + ) raw._data += np.random.normal(0, np.sqrt(1e-12), raw._data.shape) design_matrix = make_first_level_design_matrix(raw, stim_dur=5.0) glm_est = run_glm(raw, design_matrix) - with pytest.warns(RuntimeWarning, match='Non standard source detect'): + with pytest.warns(RuntimeWarning, match="Non standard source detect"): cha = glm_est.to_dataframe() - cha["ID"] = '%02d' % n + cha["ID"] = "%02d" % n df_cha = pd.concat([df_cha, cha], ignore_index=True) df_cha["theta"] = df_cha["theta"] * 1.0e6 - roi_model = func("theta ~ -1 + Condition", df_cha, - groups=df_cha["ID"]).fit() + roi_model = func("theta ~ -1 + Condition", df_cha, groups=df_cha["ID"]).fit() df = statsmodels_to_results(roi_model) assert type(df) == pd.DataFrame assert_allclose(df["Coef."]["Condition[A]"], amplitude, rtol=0.1) assert df["Significant"]["Condition[A]"] assert df.shape == (8, 8) - roi_model = smf.rlm("theta ~ -1 + Condition", df_cha, - groups=df_cha["ID"]).fit() + roi_model = smf.rlm("theta ~ -1 + Condition", df_cha, groups=df_cha["ID"]).fit() df = statsmodels_to_results(roi_model) assert type(df) == pd.DataFrame assert_allclose(df["Coef."]["Condition[A]"], amplitude, rtol=0.1) diff --git a/mne_nirs/tests/test_examples.py b/mne_nirs/tests/test_examples.py index e7fe15ded..17b00916c 100644 --- a/mne_nirs/tests/test_examples.py +++ b/mne_nirs/tests/test_examples.py @@ -5,17 +5,17 @@ # This script runs each of the example scripts. It acts as a system test. 
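# --- Editor's note: illustrative sketch, not part of the patch. The system
# test below clones the BIDS-NIRS-Tapping dataset on first use, resolves the
# examples folder via examples_path(), then reads each parametrized example
# script and runs it in-process so that any exception fails the test. The
# hunk below is truncated right after the open() call, so the exec() shown
# here is an assumption about the execution mechanism, not code from the
# patch itself.
def _run_example_script(path):
    with open(path) as fid:
        exec(fid.read())  # assumed: execute the example as a plain script
# --- End editor's note. ---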
import os -import pytest import sys +import pytest from mne.utils import check_version def examples_path(): - if not os.path.isdir("BIDS-NIRS-Tapping"): - os.system("git clone --depth 1 " - "https://github.com/rob-luke/BIDS-NIRS-Tapping.git") + os.system( + "git clone --depth 1 " "https://github.com/rob-luke/BIDS-NIRS-Tapping.git" + ) if os.path.isdir("examples"): path = "examples/general/" @@ -26,47 +26,53 @@ def examples_path(): requires_mne_1p2 = pytest.mark.skipif( - not check_version('mne', '1.2'), reason='Needs MNE-Python 1.2') + not check_version("mne", "1.2"), reason="Needs MNE-Python 1.2" +) # https://github.com/mne-tools/mne-bids/pull/406 try: from mne_bids.config import EPHY_ALLOWED_DATATYPES except Exception: - missing_mne_bids_fnirs = 'Could not import EPHY_ALLOWED_DATATYPES' + missing_mne_bids_fnirs = "Could not import EPHY_ALLOWED_DATATYPES" else: - if 'nirs' in EPHY_ALLOWED_DATATYPES: + if "nirs" in EPHY_ALLOWED_DATATYPES: missing_mne_bids_fnirs = None else: missing_mne_bids_fnirs = '"nirs" not in EPHY_ALLOWED_DATATYPES' requires_mne_bids_nirs = pytest.mark.skipif( missing_mne_bids_fnirs is not None, - reason=f'Incorrect MNE-BIDS version: {missing_mne_bids_fnirs}', + reason=f"Incorrect MNE-BIDS version: {missing_mne_bids_fnirs}", ) -@pytest.mark.filterwarnings('ignore:No bad channels to interpolate.*:') -@pytest.mark.filterwarnings('ignore:divide by zero encountered.*:') -@pytest.mark.filterwarnings('ignore:invalid value encountered.*:') -@pytest.mark.skipif( - sys.platform.startswith('win'), reason='Unstable on Windows') +@pytest.mark.filterwarnings("ignore:No bad channels to interpolate.*:") +@pytest.mark.filterwarnings("ignore:divide by zero encountered.*:") +@pytest.mark.filterwarnings("ignore:invalid value encountered.*:") +@pytest.mark.skipif(sys.platform.startswith("win"), reason="Unstable on Windows") @pytest.mark.examples -@pytest.mark.parametrize('fname', ([ - "plot_01_data_io.py", - pytest.param("plot_05_datasets.py", marks=requires_mne_bids_nirs), - "plot_10_hrf_simulation.py", - pytest.param("plot_11_hrf_measured.py", marks=requires_mne_1p2), - pytest.param("plot_12_group_glm.py", marks=requires_mne_1p2), - pytest.param("plot_13_fir_glm.py", marks=requires_mne_bids_nirs), - pytest.param("plot_14_glm_components.py", marks=requires_mne_1p2), - "plot_15_waveform.py", - "plot_16_waveform_group.py", - pytest.param("plot_19_snirf.py", marks=requires_mne_bids_nirs), - "plot_20_enhance.py", - "plot_21_artifacts.py", - "plot_22_quality.py", - "plot_30_frequency.py", - "plot_40_mayer.py", - pytest.param("plot_80_save_read_glm.py", marks=requires_mne_bids_nirs), - "plot_99_bad.py"])) +@pytest.mark.parametrize( + "fname", + ( + [ + "plot_01_data_io.py", + pytest.param("plot_05_datasets.py", marks=requires_mne_bids_nirs), + "plot_10_hrf_simulation.py", + pytest.param("plot_11_hrf_measured.py", marks=requires_mne_1p2), + pytest.param("plot_12_group_glm.py", marks=requires_mne_1p2), + pytest.param("plot_13_fir_glm.py", marks=requires_mne_bids_nirs), + pytest.param("plot_14_glm_components.py", marks=requires_mne_1p2), + "plot_15_waveform.py", + "plot_16_waveform_group.py", + pytest.param("plot_19_snirf.py", marks=requires_mne_bids_nirs), + "plot_20_enhance.py", + "plot_21_artifacts.py", + "plot_22_quality.py", + "plot_30_frequency.py", + "plot_40_mayer.py", + pytest.param("plot_80_save_read_glm.py", marks=requires_mne_bids_nirs), + "plot_99_bad.py", + ] + ), +) def test_examples(fname, requires_pyvista): test_file_path = examples_path() + fname with open(test_file_path) as 
fid: diff --git a/mne_nirs/utils/_io.py b/mne_nirs/utils/_io.py index afb9c72a8..9b89be144 100644 --- a/mne_nirs/utils/_io.py +++ b/mne_nirs/utils/_io.py @@ -1,9 +1,10 @@ -import pandas as pd -from scipy import stats -import numpy as np import re -from mne.utils import warn + import nilearn +import numpy as np +import pandas as pd +from mne.utils import warn +from scipy import stats def glm_to_tidy(info, statistic, design_matrix, wide=True, order=None): @@ -34,10 +35,9 @@ def glm_to_tidy(info, statistic, design_matrix, wide=True, order=None): df : Tidy data frame, Data from statistic object in tidy data form. """ - - if isinstance(statistic, dict) and \ - isinstance(statistic[list(statistic.keys())[0]], - nilearn.glm.regression.RegressionResults): + if isinstance(statistic, dict) and isinstance( + statistic[list(statistic.keys())[0]], nilearn.glm.regression.RegressionResults + ): df = _tidy_RegressionResults(info, statistic, design_matrix) elif isinstance(statistic, nilearn.glm.contrasts.Contrast): @@ -45,19 +45,20 @@ def glm_to_tidy(info, statistic, design_matrix, wide=True, order=None): else: raise TypeError( - 'Unknown statistic type. Expected dict of RegressionResults ' - f'or Contrast type. Received {type(statistic)}') + "Unknown statistic type. Expected dict of RegressionResults " + f"or Contrast type. Received {type(statistic)}" + ) if wide: df = _tidy_long_to_wide(df, expand_output=True) if order is not None: - df['old_index'] = df.index - df = df.set_index('ch_name') + df["old_index"] = df.index + df = df.set_index("ch_name") df = df.loc[order, :] - df['ch_name'] = df.index - df.index = df['old_index'] - df.drop(columns='old_index', inplace=True) + df["ch_name"] = df.index + df.index = df["old_index"] + df.drop(columns="old_index", inplace=True) df.rename_axis(None, inplace=True) return df @@ -116,7 +117,6 @@ def _tidy_Contrast(data, glm_est, design_matrix): def _tidy_RegressionResults(data, glm_est, design_matrix): - if not (data.ch_names == list(glm_est.keys())): warn("MNE data structure does not match regression results") @@ -132,45 +132,65 @@ def _tidy_RegressionResults(data, glm_est, design_matrix): df_estimates[idx, :] = glm_est[name].df_model mse_estimates[idx, :] = glm_est[name].MSE[0] for cond_idx, cond in enumerate(design_matrix.columns): - t_estimates[idx, cond_idx] = glm_est[name].t( - column=cond_idx).item() + t_estimates[idx, cond_idx] = glm_est[name].t(column=cond_idx).item() p_estimates[idx, cond_idx] = 2 * stats.t.cdf( -1.0 * np.abs(t_estimates[idx, cond_idx]), - df=df_estimates[idx, cond_idx]) - se_estimates[idx, cond_idx] = np.sqrt(np.diag( - glm_est[name].vcov()))[cond_idx] + df=df_estimates[idx, cond_idx], + ) + se_estimates[idx, cond_idx] = np.sqrt(np.diag(glm_est[name].vcov()))[ + cond_idx + ] - list_vals = [0] * ((len(data.ch_names) * - len(design_matrix.columns) * 6)) + list_vals = [0] * (len(data.ch_names) * len(design_matrix.columns) * 6) idx = 0 for ch_idx, ch in enumerate(data.ch_names): for cond_idx, cond in enumerate(design_matrix.columns): - list_vals[0 + idx] = {'ch_name': ch, 'Condition': cond, - 'variable': "theta", - 'value': theta_estimates[ch_idx][cond_idx]} - list_vals[1 + idx] = {'ch_name': ch, 'Condition': cond, - 'variable': "t", - 'value': t_estimates[ch_idx][cond_idx]} - list_vals[2 + idx] = {'ch_name': ch, 'Condition': cond, - 'variable': "df", - 'value': df_estimates[ch_idx][cond_idx]} - list_vals[3 + idx] = {'ch_name': ch, 'Condition': cond, - 'variable': "p_value", - 'value': p_estimates[ch_idx][cond_idx]} - list_vals[4 + idx] = 
{'ch_name': ch, 'Condition': cond, - 'variable': "mse", - 'value': mse_estimates[ch_idx][cond_idx]} - list_vals[5 + idx] = {'ch_name': ch, 'Condition': cond, - 'variable': "se", - 'value': se_estimates[ch_idx][cond_idx]} + list_vals[0 + idx] = { + "ch_name": ch, + "Condition": cond, + "variable": "theta", + "value": theta_estimates[ch_idx][cond_idx], + } + list_vals[1 + idx] = { + "ch_name": ch, + "Condition": cond, + "variable": "t", + "value": t_estimates[ch_idx][cond_idx], + } + list_vals[2 + idx] = { + "ch_name": ch, + "Condition": cond, + "variable": "df", + "value": df_estimates[ch_idx][cond_idx], + } + list_vals[3 + idx] = { + "ch_name": ch, + "Condition": cond, + "variable": "p_value", + "value": p_estimates[ch_idx][cond_idx], + } + list_vals[4 + idx] = { + "ch_name": ch, + "Condition": cond, + "variable": "mse", + "value": mse_estimates[ch_idx][cond_idx], + } + list_vals[5 + idx] = { + "ch_name": ch, + "Condition": cond, + "variable": "se", + "value": se_estimates[ch_idx][cond_idx], + } idx += 6 dict_vals, i = {}, 0 for entry in list_vals: - dict_vals[i] = {"ch_name": entry['ch_name'], - "Condition": entry['Condition'], - "variable": entry['variable'], - "value": entry['value']} + dict_vals[i] = { + "ch_name": entry["ch_name"], + "Condition": entry["Condition"], + "variable": entry["variable"], + "value": entry["value"], + } i = i + 1 df = pd.DataFrame.from_dict(dict_vals, "index") @@ -178,33 +198,30 @@ def _tidy_RegressionResults(data, glm_est, design_matrix): def _tidy_long_to_wide(d, expand_output=True): - - indices = ['ch_name'] - if 'Condition' in d.columns: + indices = ["ch_name"] + if "Condition" in d.columns: # Regression results have a column condition - indices.append('Condition') - if 'ContrastType' in d.columns: + indices.append("Condition") + if "ContrastType" in d.columns: # Regression results have a column condition - indices.append('ContrastType') + indices.append("ContrastType") d = d.set_index(indices) - d = d.pivot_table(columns='variable', values='value', - index=indices) + d = d.pivot_table(columns="variable", values="value", index=indices) d.reset_index(inplace=True) if expand_output: try: d["Source"] = [ - int(re.search(r'S(\d+)_D(\d+) (\w+)', ch).group(1)) + int(re.search(r"S(\d+)_D(\d+) (\w+)", ch).group(1)) for ch in d["ch_name"] ] d["Detector"] = [ - int(re.search(r'S(\d+)_D(\d+) (\w+)', ch).group(2)) + int(re.search(r"S(\d+)_D(\d+) (\w+)", ch).group(2)) for ch in d["ch_name"] ] d["Chroma"] = [ - re.search(r'S(\d+)_D(\d+) (\w+)', ch).group(3) - for ch in d["ch_name"] + re.search(r"S(\d+)_D(\d+) (\w+)", ch).group(3) for ch in d["ch_name"] ] except AttributeError: warn("Non standard source detector names used") diff --git a/mne_nirs/utils/tests/test_io.py b/mne_nirs/utils/tests/test_io.py index 078b5f31a..6e14b8644 100644 --- a/mne_nirs/utils/tests/test_io.py +++ b/mne_nirs/utils/tests/test_io.py @@ -4,42 +4,56 @@ import os -import pytest + import mne -import mne_nirs import numpy as np +import pytest +import mne_nirs from mne_nirs.experimental_design import make_first_level_design_matrix -from mne_nirs.utils._io import glm_to_tidy, _tidy_long_to_wide -from mne_nirs.statistics._glm_level_first import _compute_contrast from mne_nirs.statistics import run_glm +from mne_nirs.statistics._glm_level_first import _compute_contrast +from mne_nirs.utils._io import _tidy_long_to_wide, glm_to_tidy -@pytest.mark.filterwarnings('ignore:.*more comprehensive.*:') +@pytest.mark.filterwarnings("ignore:.*more comprehensive.*:") def test_io(): num_chans = 6 fnirs_data_folder 
= mne.datasets.fnirs_motor.data_path() - fnirs_raw_dir = os.path.join(fnirs_data_folder, 'Participant-1') + fnirs_raw_dir = os.path.join(fnirs_data_folder, "Participant-1") raw_intensity = mne.io.read_raw_nirx(fnirs_raw_dir).load_data() raw_intensity.resample(0.2) raw_intensity.annotations.description[:] = [ - 'e' + d.replace('.', 'p') - for d in raw_intensity.annotations.description] + "e" + d.replace(".", "p") for d in raw_intensity.annotations.description + ] raw_od = mne.preprocessing.nirs.optical_density(raw_intensity) raw_haemo = mne.preprocessing.nirs.beer_lambert_law(raw_od, ppf=0.1) raw_haemo = mne_nirs.channels.get_long_channels(raw_haemo) raw_haemo.pick(picks=range(num_chans)) - design_matrix = make_first_level_design_matrix(raw_intensity, - hrf_model='spm', - stim_dur=5.0, - drift_order=3, - drift_model='polynomial') + design_matrix = make_first_level_design_matrix( + raw_intensity, + hrf_model="spm", + stim_dur=5.0, + drift_order=3, + drift_model="polynomial", + ) glm_est = run_glm(raw_haemo, design_matrix) df = glm_to_tidy(raw_haemo, glm_est.data, design_matrix) assert df.shape == (48, 12) - assert set(df.columns) == {'ch_name', 'Condition', 'df', 'mse', 'p_value', - 't', 'theta', 'Source', 'Detector', 'Chroma', - 'Significant', 'se'} + assert set(df.columns) == { + "ch_name", + "Condition", + "df", + "mse", + "p_value", + "t", + "theta", + "Source", + "Detector", + "Chroma", + "Significant", + "se", + } num_conds = 8 # triggers (1, 2, 3, 15) + 3 drifts + constant assert df.shape[0] == num_chans * num_conds assert len(df["se"]) == 48 @@ -54,24 +68,43 @@ def test_io(): assert sum(df["t"]) > -99999 # Check isn't nan contrast_matrix = np.eye(design_matrix.shape[1]) - basic_conts = dict([(column, contrast_matrix[i]) - for i, column in enumerate(design_matrix.columns)]) - contrast_LvR = basic_conts['e2p0'] - basic_conts['e3p0'] + basic_conts = dict( + [(column, contrast_matrix[i]) for i, column in enumerate(design_matrix.columns)] + ) + contrast_LvR = basic_conts["e2p0"] - basic_conts["e3p0"] contrast = _compute_contrast(glm_est.data, contrast_LvR) df = glm_to_tidy(raw_haemo, contrast, design_matrix) assert df.shape == (6, 10) - assert set(df.columns) == {'ch_name', 'ContrastType', 'z_score', 'stat', - 'p_value', 'effect', 'Source', 'Detector', - 'Chroma', 'Significant'} + assert set(df.columns) == { + "ch_name", + "ContrastType", + "z_score", + "stat", + "p_value", + "effect", + "Source", + "Detector", + "Chroma", + "Significant", + } - contrast = _compute_contrast(glm_est.data, contrast_LvR, contrast_type='F') + contrast = _compute_contrast(glm_est.data, contrast_LvR, contrast_type="F") df = glm_to_tidy(raw_haemo, contrast, design_matrix, wide=False) df = _tidy_long_to_wide(df) assert df.shape == (6, 10) - assert set(df.columns) == {'ch_name', 'ContrastType', 'z_score', 'stat', - 'p_value', 'effect', 'Source', 'Detector', - 'Chroma', 'Significant'} + assert set(df.columns) == { + "ch_name", + "ContrastType", + "z_score", + "stat", + "p_value", + "effect", + "Source", + "Detector", + "Chroma", + "Significant", + } with pytest.raises(TypeError, match="Unknown statistic type"): glm_to_tidy(raw_haemo, [1, 2, 3], design_matrix, wide=False) diff --git a/mne_nirs/visualisation/_plot_3d_montage.py b/mne_nirs/visualisation/_plot_3d_montage.py index 5489ac6fb..4a8ab9870 100644 --- a/mne_nirs/visualisation/_plot_3d_montage.py +++ b/mne_nirs/visualisation/_plot_3d_montage.py @@ -6,20 +6,27 @@ import inspect import numpy as np - -from mne import pick_info, pick_types, Info +from mne import 
Info, pick_info, pick_types from mne.channels import make_standard_montage from mne.channels.montage import transform_to_head from mne.transforms import _get_trans, apply_trans -from mne.utils import _validate_type, _check_option, verbose, logger +from mne.utils import _check_option, _validate_type, logger, verbose from mne.viz import Brain @verbose -def plot_3d_montage(info, view_map, *, src_det_names='auto', - ch_names='numbered', subject='fsaverage', - trans='fsaverage', surface='pial', - subjects_dir=None, verbose=None): +def plot_3d_montage( + info, + view_map, + *, + src_det_names="auto", + ch_names="numbered", + subject="fsaverage", + trans="fsaverage", + surface="pial", + subjects_dir=None, + verbose=None, +): """ Plot a 3D sensor montage. @@ -89,32 +96,33 @@ def plot_3d_montage(info, view_map, *, src_det_names='auto', """ # noqa: E501 import matplotlib.pyplot as plt from scipy.spatial.distance import cdist - _validate_type(info, Info, 'info') - _validate_type(view_map, dict, 'views') - _validate_type(src_det_names, (None, dict, str), 'src_det_names') - _validate_type(ch_names, (dict, str, None), 'ch_names') + + _validate_type(info, Info, "info") + _validate_type(view_map, dict, "views") + _validate_type(src_det_names, (None, dict, str), "src_det_names") + _validate_type(ch_names, (dict, str, None), "ch_names") info = pick_info(info, pick_types(info, fnirs=True, exclude=())[::2]) if isinstance(ch_names, str): - _check_option('ch_names', ch_names, ('numbered',), extra='when str') + _check_option("ch_names", ch_names, ("numbered",), extra="when str") ch_names = { - name.split()[0]: str(ni) - for ni, name in enumerate(info['ch_names'], 1)} - info['bads'] = [] + name.split()[0]: str(ni) for ni, name in enumerate(info["ch_names"], 1) + } + info["bads"] = [] if isinstance(src_det_names, str): - _check_option('src_det_names', src_det_names, ('auto',), - extra='when str') + _check_option("src_det_names", src_det_names, ("auto",), extra="when str") # Decide if we can map to 10-20 locations names, pos = zip( - *transform_to_head(make_standard_montage('standard_1020')) - .get_positions()['ch_pos'].items()) + *transform_to_head(make_standard_montage("standard_1020")) + .get_positions()["ch_pos"] + .items() + ) pos = np.array(pos, float) locs = dict() bad = False - for ch in info['chs']: - name = ch['ch_name'] - s_name, d_name = name.split()[0].split('_') - for name, loc in [(s_name, ch['loc'][3:6]), - (d_name, ch['loc'][6:9])]: + for ch in info["chs"]: + name = ch["ch_name"] + s_name, d_name = name.split()[0].split("_") + for name, loc in [(s_name, ch["loc"][3:6]), (d_name, ch["loc"][6:9])]: if name in locs: continue # see if it's close enough @@ -129,57 +137,83 @@ def plot_3d_montage(info, view_map, *, src_det_names='auto', break if bad: src_det_names = None - logger.info('Could not automatically map source/detector names to ' - '10-20 locations.') + logger.info( + "Could not automatically map source/detector names to " + "10-20 locations." 
+ ) else: src_det_names = locs - logger.info('Source-detector names automatically mapped to 10-20 ' - 'locations') + logger.info( + "Source-detector names automatically mapped to 10-20 " "locations" + ) - head_mri_t = _get_trans(trans, 'head', 'mri')[0] + head_mri_t = _get_trans(trans, "head", "mri")[0] del trans views = list() for key, num in view_map.items(): - _validate_type(key, str, f'view_map key {repr(key)}') - _validate_type(num, np.ndarray, f'view_map[{repr(key)}]') - if '-' in key: - hemi, v = key.split('-', maxsplit=1) - hemi = dict(left='lh', right='rh')[hemi] + _validate_type(key, str, f"view_map key {repr(key)}") + _validate_type(num, np.ndarray, f"view_map[{repr(key)}]") + if "-" in key: + hemi, v = key.split("-", maxsplit=1) + hemi = dict(left="lh", right="rh")[hemi] views.append((hemi, v, num)) else: - views.append(('lh', key, num)) + views.append(("lh", key, num)) del view_map size = (400 * len(views), 400) brain = Brain( - subject, 'both', surface, views=['lat'] * len(views), - size=size, background='w', units='m', - view_layout='horizontal', subjects_dir=subjects_dir) + subject, + "both", + surface, + views=["lat"] * len(views), + size=size, + background="w", + units="m", + view_layout="horizontal", + subjects_dir=subjects_dir, + ) with _safe_brain_close(brain): brain.add_head(dense=False, alpha=0.1) brain.add_sensors( - info, trans=head_mri_t, - fnirs=['channels', 'pairs', 'sources', 'detectors']) + info, trans=head_mri_t, fnirs=["channels", "pairs", "sources", "detectors"] + ) add_text_kwargs = dict() - if 'render' in inspect.signature(brain.plotter.add_text).parameters: - add_text_kwargs['render'] = False + if "render" in inspect.signature(brain.plotter.add_text).parameters: + add_text_kwargs["render"] = False for col, view in enumerate(views): plotted = set() brain.show_view( - view[1], hemi=view[0], focalpoint=(0, -0.02, 0.02), - distance=0.4, row=0, col=col) + view[1], + hemi=view[0], + focalpoint=(0, -0.02, 0.02), + distance=0.4, + row=0, + col=col, + ) brain.plotter.subplot(0, col) vp = brain.plotter.renderer for ci in view[2]: # figure out what we need to plot - this_ch = info['chs'][ci - 1] - ch_name = this_ch['ch_name'].split()[0] - s_name, d_name = ch_name.split('_') + this_ch = info["chs"][ci - 1] + ch_name = this_ch["ch_name"].split()[0] + s_name, d_name = ch_name.split("_") needed = [ - (ch_names, 'ch_names', ch_name, - this_ch['loc'][:3], 12, 'Centered'), - (src_det_names, 'src_det_names', s_name, - this_ch['loc'][3:6], 8, 'Bottom'), - (src_det_names, 'src_det_names', d_name, - this_ch['loc'][6:9], 8, 'Bottom'), + (ch_names, "ch_names", ch_name, this_ch["loc"][:3], 12, "Centered"), + ( + src_det_names, + "src_det_names", + s_name, + this_ch["loc"][3:6], + 8, + "Bottom", + ), + ( + src_det_names, + "src_det_names", + d_name, + this_ch["loc"][6:9], + 8, + "Bottom", + ), ] for lookup, lname, name, ch_pos, font_size, va in needed: if name in plotted: @@ -188,17 +222,22 @@ def plot_3d_montage(info, view_map, *, src_det_names='auto', orig_name = name if lookup is not None: name = lookup[name] - _validate_type(name, str, f'{lname}[{repr(orig_name)}]') + _validate_type(name, str, f"{lname}[{repr(orig_name)}]") ch_pos = apply_trans(head_mri_t, ch_pos) - vp.SetWorldPoint(np.r_[ch_pos, 1.]) + vp.SetWorldPoint(np.r_[ch_pos, 1.0]) vp.WorldToDisplay() - ch_pos = (np.array(vp.GetDisplayPoint()[:2]) - - np.array(vp.GetOrigin())) + ch_pos = np.array(vp.GetDisplayPoint()[:2]) - np.array( + vp.GetOrigin() + ) actor = brain.plotter.add_text( - name, ch_pos, 
font_size=font_size, color=(0., 0., 0.), - **add_text_kwargs) + name, + ch_pos, + font_size=font_size, + color=(0.0, 0.0, 0.0), + **add_text_kwargs, + ) prop = actor.GetTextProperty() - getattr(prop, f'SetVerticalJustificationTo{va}')() + getattr(prop, f"SetVerticalJustificationTo{va}")() prop.SetJustificationToCentered() actor.SetTextProperty(prop) prop.SetBold(True) diff --git a/mne_nirs/visualisation/_plot_GLM_surface_projection.py b/mne_nirs/visualisation/_plot_GLM_surface_projection.py index 71b00c48f..86cbb856c 100644 --- a/mne_nirs/visualisation/_plot_GLM_surface_projection.py +++ b/mne_nirs/visualisation/_plot_GLM_surface_projection.py @@ -3,22 +3,35 @@ # License: BSD (3-clause) import os -import numpy as np from copy import deepcopy -from mne import stc_near_sensors, EvokedArray, read_source_spaces, Info +import numpy as np +from mne import EvokedArray, Info, read_source_spaces, stc_near_sensors from mne.io.constants import FIFF -from mne.utils import verbose, get_subjects_dir +from mne.utils import get_subjects_dir, verbose @verbose -def plot_glm_surface_projection(inst, statsmodel_df, picks="hbo", - value="Coef.", - background='w', figure=None, clim='auto', - mode='weighted', colormap='RdBu_r', - surface='pial', hemi='both', size=800, - view=None, colorbar=True, distance=0.03, - subjects_dir=None, src=None, verbose=False): +def plot_glm_surface_projection( + inst, + statsmodel_df, + picks="hbo", + value="Coef.", + background="w", + figure=None, + clim="auto", + mode="weighted", + colormap="RdBu_r", + surface="pial", + hemi="both", + size=800, + view=None, + colorbar=True, + distance=0.03, + subjects_dir=None, + src=None, + verbose=False, +): """ Project GLM results on to the surface of the brain. @@ -86,34 +99,57 @@ def plot_glm_surface_projection(inst, statsmodel_df, picks="hbo", An instance of :class:`mne.viz.Brain` or matplotlib figure. 
""" info = deepcopy(inst if isinstance(inst, Info) else inst.info) - if not (info.ch_names == list(statsmodel_df['ch_name'].values)): - raise RuntimeError('MNE data structure does not match dataframe ' - f'results.\nMNE = {info.ch_names}.\n' - f'GLM = {list(statsmodel_df["ch_name"].values)}') - - ea = EvokedArray(np.tile(statsmodel_df[value].values.T, (1, 1)).T, - info.copy()) - - return _plot_3d_evoked_array(inst, ea, picks=picks, - value=value, - background=background, figure=figure, - clim=clim, - mode=mode, colormap=colormap, - surface=surface, hemi=hemi, size=size, - view=view, colorbar=colorbar, - distance=distance, - subjects_dir=subjects_dir, src=src, - verbose=verbose) - - -def _plot_3d_evoked_array(inst, ea, picks="hbo", - value="Coef.", - background='w', figure=None, clim='auto', - mode='weighted', colormap='RdBu_r', - surface='pial', hemi='both', size=800, - view=None, colorbar=True, distance=0.03, - subjects_dir=None, src=None, verbose=False): - + if not (info.ch_names == list(statsmodel_df["ch_name"].values)): + raise RuntimeError( + 'MNE data structure does not match dataframe ' + f'results.\nMNE = {info.ch_names}.\n' + f'GLM = {list(statsmodel_df["ch_name"].values)}' + ) + + ea = EvokedArray(np.tile(statsmodel_df[value].values.T, (1, 1)).T, info.copy()) + + return _plot_3d_evoked_array( + inst, + ea, + picks=picks, + value=value, + background=background, + figure=figure, + clim=clim, + mode=mode, + colormap=colormap, + surface=surface, + hemi=hemi, + size=size, + view=view, + colorbar=colorbar, + distance=distance, + subjects_dir=subjects_dir, + src=src, + verbose=verbose, + ) + + +def _plot_3d_evoked_array( + inst, + ea, + picks="hbo", + value="Coef.", + background="w", + figure=None, + clim="auto", + mode="weighted", + colormap="RdBu_r", + surface="pial", + hemi="both", + size=800, + view=None, + colorbar=True, + distance=0.03, + subjects_dir=None, + src=None, + verbose=False, +): # TODO: mimic behaviour of other MNE-NIRS glm plotting options if picks is not None: ea = ea.pick(picks=picks) @@ -121,28 +157,46 @@ def _plot_3d_evoked_array(inst, ea, picks="hbo", if subjects_dir is None: subjects_dir = get_subjects_dir(raise_error=True) if src is None: - fname_src_fs = os.path.join(subjects_dir, 'fsaverage', 'bem', - 'fsaverage-ico-5-src.fif') + fname_src_fs = os.path.join( + subjects_dir, "fsaverage", "bem", "fsaverage-ico-5-src.fif" + ) src = read_source_spaces(fname_src_fs) - picks = np.arange(len(ea.info['ch_names'])) + picks = np.arange(len(ea.info["ch_names"])) # Set coord frame for idx in range(len(ea.ch_names)): - ea.info['chs'][idx]['coord_frame'] = FIFF.FIFFV_COORD_HEAD + ea.info["chs"][idx]["coord_frame"] = FIFF.FIFFV_COORD_HEAD # Generate source estimate kwargs = dict( - evoked=ea, subject='fsaverage', trans='fsaverage', - distance=distance, mode=mode, surface=surface, - subjects_dir=subjects_dir, src=src, project=True) + evoked=ea, + subject="fsaverage", + trans="fsaverage", + distance=distance, + mode=mode, + surface=surface, + subjects_dir=subjects_dir, + src=src, + project=True, + ) stc = stc_near_sensors(picks=picks, **kwargs, verbose=verbose) # Produce brain plot - brain = stc.plot(src=src, subjects_dir=subjects_dir, hemi=hemi, - surface=surface, initial_time=0, clim=clim, size=size, - colormap=colormap, figure=figure, background=background, - colorbar=colorbar, verbose=verbose) + brain = stc.plot( + src=src, + subjects_dir=subjects_dir, + hemi=hemi, + surface=surface, + initial_time=0, + clim=clim, + size=size, + colormap=colormap, + figure=figure, + 
background=background, + colorbar=colorbar, + verbose=verbose, + ) if view is not None: brain.show_view(view) diff --git a/mne_nirs/visualisation/_plot_GLM_topo.py b/mne_nirs/visualisation/_plot_GLM_topo.py index 41a94b586..2018c4757 100644 --- a/mne_nirs/visualisation/_plot_GLM_topo.py +++ b/mne_nirs/visualisation/_plot_GLM_topo.py @@ -2,26 +2,34 @@ # # License: BSD (3-clause) -from copy import deepcopy import inspect +from copy import deepcopy -import numpy as np -import matplotlib.pyplot as plt import matplotlib as mpl -from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable - +import matplotlib.pyplot as plt +import numpy as np from mne import Info, pick_info -from mne.utils import warn from mne.channels.layout import _merge_ch_data from mne.io.pick import _picks_to_idx +from mne.utils import warn from mne.viz import plot_topomap +from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable -def _plot_glm_topo(inst, glm_estimates, design_matrix, *, - requested_conditions=None, - axes=None, vlim=None, vmin=None, vmax=None, colorbar=True, - figsize=(12, 7), sphere=None): - +def _plot_glm_topo( + inst, + glm_estimates, + design_matrix, + *, + requested_conditions=None, + axes=None, + vlim=None, + vmin=None, + vmax=None, + colorbar=True, + figsize=(12, 7), + sphere=None, +): info = deepcopy(inst if isinstance(inst, Info) else inst.info) if not (info.ch_names == list(glm_estimates.keys())): @@ -29,9 +37,11 @@ def _plot_glm_topo(inst, glm_estimates, design_matrix, *, warn("Reducing GLM results to match MNE data") glm_estimates = {a: glm_estimates[a] for a in info.ch_names} else: - raise RuntimeError('MNE data structure does not match regression ' - f'results. Raw = {len(info.ch_names)}. ' - f'GLM = {len(list(glm_estimates.keys()))}') + raise RuntimeError( + "MNE data structure does not match regression " + f"results. Raw = {len(info.ch_names)}. 
" + f"GLM = {len(list(glm_estimates.keys()))}" + ) estimates = np.zeros((len(glm_estimates), len(design_matrix.columns))) @@ -42,17 +52,17 @@ def _plot_glm_topo(inst, glm_estimates, design_matrix, *, if requested_conditions is None: requested_conditions = design_matrix.columns - requested_conditions = [x for x in design_matrix.columns - if x in requested_conditions] + requested_conditions = [ + x for x in design_matrix.columns if x in requested_conditions + ] # Plotting setup if axes is None: - fig, axes = plt.subplots(nrows=len(types), - ncols=len(requested_conditions), - figsize=figsize) + fig, axes = plt.subplots( + nrows=len(types), ncols=len(requested_conditions), figsize=figsize + ) - estimates = estimates[:, [c in requested_conditions - for c in design_matrix.columns]] + estimates = estimates[:, [c in requested_conditions for c in design_matrix.columns]] estimates = estimates * 1e6 design_matrix = design_matrix[requested_conditions] @@ -62,12 +72,10 @@ def _plot_glm_topo(inst, glm_estimates, design_matrix, *, norm = mpl.colors.Normalize(vmin=vlim[0], vmax=vlim[1]) for t_idx, t in enumerate(types): - estmrg, pos, chs, sphere = _handle_overlaps(info, t, sphere, estimates) for idx, label in enumerate(design_matrix.columns): if label in requested_conditions: - # Deal with case when only a single # chroma or condition is available if (len(requested_conditions) == 1) & (len(types) == 1): @@ -80,23 +88,30 @@ def _plot_glm_topo(inst, glm_estimates, design_matrix, *, ax = axes[t_idx, idx] plot_topomap( - estmrg[:, idx], pos, extrapolate='local', names=chs, - cmap=cmap, axes=ax, show=False, sphere=sphere, - **vlim_kwargs) + estmrg[:, idx], + pos, + extrapolate="local", + names=chs, + cmap=cmap, + axes=ax, + show=False, + sphere=sphere, + **vlim_kwargs, + ) ax.set_title(label) if colorbar: ax1_divider = make_axes_locatable(ax) cax1 = ax1_divider.append_axes("right", size="7%", pad="2%") - cbar = mpl.colorbar.ColorbarBase(cax1, cmap=cmap, norm=norm, - orientation='vertical') - cbar.set_label('Haemoglobin (uM)', rotation=270) + cbar = mpl.colorbar.ColorbarBase( + cax1, cmap=cmap, norm=norm, orientation="vertical" + ) + cbar.set_label("Haemoglobin (uM)", rotation=270) return _get_fig_from_axes(axes) def _plot_glm_contrast_topo(inst, contrast, figsize=(12, 7), sphere=None): - info = deepcopy(inst if isinstance(inst, Info) else inst.info) # Extract types. 
One subplot is created per type (hbo/hbr) @@ -107,16 +122,13 @@ def _plot_glm_contrast_topo(inst, contrast, figsize=(12, 7), sphere=None): estimates = estimates * 1e6 # Create subplots for figures - fig, axes = plt.subplots(nrows=1, - ncols=len(types), - figsize=figsize) + fig, axes = plt.subplots(nrows=1, ncols=len(types), figsize=figsize) # Create limits for colorbar vlim, vlim_kwargs = _handle_vlim((None, None), None, None, estimates) cmap = mpl.cm.RdBu_r norm = mpl.colors.Normalize(vmin=vlim[0], vmax=vlim[1]) for t_idx, t in enumerate(types): - estmrg, pos, chs, sphere = _handle_overlaps(info, t, sphere, estimates) # Deal with case when only a single chroma is available @@ -127,43 +139,53 @@ def _plot_glm_contrast_topo(inst, contrast, figsize=(12, 7), sphere=None): # Plot the topomap plot_topomap( - estmrg, pos, extrapolate='local', names=chs, cmap=cmap, axes=ax, - show=False, sphere=sphere, **vlim_kwargs) + estmrg, + pos, + extrapolate="local", + names=chs, + cmap=cmap, + axes=ax, + show=False, + sphere=sphere, + **vlim_kwargs, + ) # Sets axes title - if t == 'hbo': - ax.set_title('Oxyhaemoglobin') - elif t == 'hbr': - ax.set_title('Deoxyhaemoglobin') + if t == "hbo": + ax.set_title("Oxyhaemoglobin") + elif t == "hbr": + ax.set_title("Deoxyhaemoglobin") else: ax.set_title(t) # Create a single colorbar for all types based on limits above ax1_divider = make_axes_locatable(ax) cax1 = ax1_divider.append_axes("right", size="7%", pad="2%") - cbar = mpl.colorbar.ColorbarBase(cax1, cmap=cmap, norm=norm, - orientation='vertical') - cbar.set_label('Contrast Effect', rotation=270) + cbar = mpl.colorbar.ColorbarBase(cax1, cmap=cmap, norm=norm, orientation="vertical") + cbar.set_label("Contrast Effect", rotation=270) return fig -def plot_glm_group_topo(inst, statsmodel_df, - value="Coef.", - axes=None, - threshold=False, - *, - vlim=(None, None), - vmin=None, - vmax=None, - cmap=None, - sensors=True, - res=64, - sphere=None, - colorbar=True, - names=False, - show_names=None, - extrapolate='local', - image_interp='cubic'): +def plot_glm_group_topo( + inst, + statsmodel_df, + value="Coef.", + axes=None, + threshold=False, + *, + vlim=(None, None), + vmin=None, + vmax=None, + cmap=None, + sensors=True, + res=64, + sphere=None, + colorbar=True, + names=False, + show_names=None, + extrapolate="local", + image_interp="cubic", +): """ Plot topomap of NIRS group level GLM results. 
@@ -219,20 +241,24 @@ def plot_glm_group_topo(inst, statsmodel_df, info = deepcopy(inst if isinstance(inst, Info) else inst.info) if show_names is not None: names = show_names - warn('show_names is deprecated and will be removed in the next ' - 'release, use names instead', FutureWarning) + warn( + "show_names is deprecated and will be removed in the next " + "release, use names instead", + FutureWarning, + ) del show_names # Check that the channels in two inputs match if not (info.ch_names == list(statsmodel_df["ch_name"].values)): if len(info.ch_names) < len(list(statsmodel_df["ch_name"].values)): print("Reducing GLM results to match MNE data") - statsmodel_df["Keep"] = [g in info.ch_names - for g in statsmodel_df["ch_name"]] + statsmodel_df["Keep"] = [ + g in info.ch_names for g in statsmodel_df["ch_name"] + ] statsmodel_df = statsmodel_df.query("Keep == True") else: warn("MNE data structure does not match regression results") - statsmodel_df = statsmodel_df.set_index('ch_name') + statsmodel_df = statsmodel_df.set_index("ch_name") statsmodel_df = statsmodel_df.reindex(info.ch_names) # Extract estimate of interest to plot @@ -241,14 +267,14 @@ def plot_glm_group_topo(inst, statsmodel_df, if threshold: p = statsmodel_df["P>|z|"].values t = p > 0.05 - estimates[t] = 0. + estimates[t] = 0.0 - assert len(np.unique(statsmodel_df["Chroma"])) == 1,\ - "Only one Chroma allowed" + assert len(np.unique(statsmodel_df["Chroma"])) == 1, "Only one Chroma allowed" - if 'Condition' in statsmodel_df.columns: - assert len(np.unique(statsmodel_df["Condition"])) == 1,\ - "Only one condition allowed" + if "Condition" in statsmodel_df.columns: + assert ( + len(np.unique(statsmodel_df["Condition"])) == 1 + ), "Only one condition allowed" c = np.unique(statsmodel_df["Condition"])[0] else: c = "Contrast" @@ -257,9 +283,7 @@ def plot_glm_group_topo(inst, statsmodel_df, # Plotting setup if axes is None: - fig, axes = plt.subplots(nrows=1, - ncols=1, - figsize=(12, 7)) + fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(12, 7)) # Set limits of topomap and colors vlim, vlim_kwargs = _handle_vlim(vlim, vmin, vmax, estimates) del vmin, vmax @@ -268,34 +292,53 @@ def plot_glm_group_topo(inst, statsmodel_df, norm = mpl.colors.Normalize(vmin=vlim[0], vmax=vlim[1]) estmrg, pos, chs, sphere = _handle_overlaps(info, t, sphere, estimates) - if 'names' in inspect.signature(plot_topomap).parameters: - names_kwarg = dict(names=chs if names else [''] * len(chs)) + if "names" in inspect.signature(plot_topomap).parameters: + names_kwarg = dict(names=chs if names else [""] * len(chs)) else: names_kwarg = dict(show_names=names, names=chs) plot_topomap( - estmrg, pos, extrapolate=extrapolate, image_interp=image_interp, - cmap=cmap, axes=axes, sensors=sensors, res=res, show=False, - sphere=sphere, **vlim_kwargs, **names_kwarg) + estmrg, + pos, + extrapolate=extrapolate, + image_interp=image_interp, + cmap=cmap, + axes=axes, + sensors=sensors, + res=res, + show=False, + sphere=sphere, + **vlim_kwargs, + **names_kwarg, + ) axes.set_title(c) if colorbar: ax1_divider = make_axes_locatable(axes) cax1 = ax1_divider.append_axes("right", size="7%", pad="2%") - cbar = mpl.colorbar.ColorbarBase(cax1, cmap=cmap, norm=norm, - orientation='vertical') + cbar = mpl.colorbar.ColorbarBase( + cax1, cmap=cmap, norm=norm, orientation="vertical" + ) cbar.set_label(value, rotation=270) return axes def _handle_overlaps(info, t, sphere, estimates): - """Prepare for topomap including merging channels""" + """Prepare for topomap including merging 
channels.""" from mne.viz.topomap import _prepare_topomap_plot + picks = _picks_to_idx(info, t, exclude=[], allow_empty=True) info_subset = pick_info(info, picks) - _, pos, merge_channels, ch_names, ch_type, sphere, clip_origin = \ - _prepare_topomap_plot(info_subset, t, sphere=sphere) + ( + _, + pos, + merge_channels, + ch_names, + ch_type, + sphere, + clip_origin, + ) = _prepare_topomap_plot(info_subset, t, sphere=sphere) estmrg, ch_names = _merge_ch_data(estimates.copy()[picks], t, ch_names) return estmrg, pos, ch_names, sphere @@ -311,17 +354,20 @@ def _get_fig_from_axes(ax): def _handle_vlim(vlim, vmin, vmax, estimates): if vmin is not None or vmax is not None: - warn('vmin and vmax are deprecated and will be removed in the next ' - 'release, please use vlim instead', FutureWarning) + warn( + "vmin and vmax are deprecated and will be removed in the next " + "release, please use vlim instead", + FutureWarning, + ) vlim = (vmin, vmax) else: vmin, vmax = vlim if vmax is None: vmax = np.max(np.abs(estimates)) if vmin is None: - vmin = vmax * -1. + vmin = vmax * -1.0 vlim = tuple(vlim) - if 'vlim' in inspect.signature(plot_topomap).parameters: + if "vlim" in inspect.signature(plot_topomap).parameters: kwargs = dict(vlim=(vmin, vmax)) else: kwargs = dict(vmin=vmin, vmax=vmax) diff --git a/mne_nirs/visualisation/_plot_nirs_source_detector.py b/mne_nirs/visualisation/_plot_nirs_source_detector.py index 1780bf350..bd983d4ab 100644 --- a/mne_nirs/visualisation/_plot_nirs_source_detector.py +++ b/mne_nirs/visualisation/_plot_nirs_source_detector.py @@ -4,23 +4,36 @@ import numpy as np - -from mne.viz import plot_alignment from mne import verbose +from mne.viz import plot_alignment @verbose -def plot_nirs_source_detector(data, info=None, radius=0.001, - trans=None, subject=None, - subjects_dir=None, - surfaces='head', coord_frame='head', - meg=None, eeg='original', fwd=None, - dig=False, ecog=True, src=None, - mri_fiducials=False, - bem=None, seeg=True, fnirs=False, - show_axes=False, - fig=None, cmap=None, - interaction='trackball', verbose=None): +def plot_nirs_source_detector( + data, + info=None, + radius=0.001, + trans=None, + subject=None, + subjects_dir=None, + surfaces="head", + coord_frame="head", + meg=None, + eeg="original", + fwd=None, + dig=False, + ecog=True, + src=None, + mri_fiducials=False, + bem=None, + seeg=True, + fnirs=False, + show_axes=False, + fig=None, + cmap=None, + interaction="trackball", + verbose=None, +): """ 3D visualisation of fNIRS response magnitude. @@ -142,47 +155,67 @@ def plot_nirs_source_detector(data, info=None, radius=0.001, if cmap is None: if (vmin >= 0) & (vmax >= 0): # For positive only data use magma - cmap = 'Oranges' + cmap = "Oranges" else: # Otherwise use blue to red and ensure zero sits at white - vmin = -1. 
* np.max(np.abs(data)) + vmin = -1.0 * np.max(np.abs(data)) vmax = np.max(np.abs(data)) - cmap = 'RdBu_r' + cmap = "RdBu_r" if isinstance(radius, (int, float)): - radius = np.ones(len(info['chs'])) * radius + radius = np.ones(len(info["chs"])) * radius # Plot requested alignment fig = plot_alignment( - info=info, trans=trans, subject=subject, + info=info, + trans=trans, + subject=subject, subjects_dir=subjects_dir, - surfaces=surfaces, coord_frame=coord_frame, - meg=meg, eeg=eeg, fwd=fwd, - dig=dig, ecog=ecog, src=src, + surfaces=surfaces, + coord_frame=coord_frame, + meg=meg, + eeg=eeg, + fwd=fwd, + dig=dig, + ecog=ecog, + src=src, mri_fiducials=mri_fiducials, - bem=bem, seeg=seeg, fnirs=fnirs, + bem=bem, + seeg=seeg, + fnirs=fnirs, show_axes=show_axes, fig=fig, - interaction=interaction, verbose=verbose) + interaction=interaction, + verbose=verbose, + ) from mne.viz.backends.renderer import _get_renderer + renderer = _get_renderer(fig) # Overlay channels between source and detectors - for idx, ch in enumerate(info['chs']): - locs = ch['loc'] - - renderer.tube(origin=[np.array([locs[3], locs[4], locs[5]])], - destination=[np.array([locs[6], locs[7], locs[8]])], - scalars=np.array([[1.0, 1.0]]) * data[idx], - radius=radius[idx], colormap=cmap, - vmin=vmin, vmax=vmax) - - t = renderer.tube(origin=[np.array([0, 0, 0])], - destination=[np.array([0, 0, 0.001])], - scalars=np.array([[vmin, vmax]]), - radius=0.0001, colormap=cmap, - vmin=vmin, vmax=vmax) + for idx, ch in enumerate(info["chs"]): + locs = ch["loc"] + + renderer.tube( + origin=[np.array([locs[3], locs[4], locs[5]])], + destination=[np.array([locs[6], locs[7], locs[8]])], + scalars=np.array([[1.0, 1.0]]) * data[idx], + radius=radius[idx], + colormap=cmap, + vmin=vmin, + vmax=vmax, + ) + + t = renderer.tube( + origin=[np.array([0, 0, 0])], + destination=[np.array([0, 0, 0.001])], + scalars=np.array([[vmin, vmax]]), + radius=0.0001, + colormap=cmap, + vmin=vmin, + vmax=vmax, + ) renderer.scalarbar(t) return fig diff --git a/mne_nirs/visualisation/_plot_quality_metrics.py b/mne_nirs/visualisation/_plot_quality_metrics.py index 97b4ecc10..99a6067ca 100644 --- a/mne_nirs/visualisation/_plot_quality_metrics.py +++ b/mne_nirs/visualisation/_plot_quality_metrics.py @@ -2,14 +2,13 @@ # # License: BSD (3-clause) -import seaborn as sns -import pandas as pd -import numpy as np import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns -def plot_timechannel_quality_metric(raw, scores, times, threshold=0.1, - title=None): +def plot_timechannel_quality_metric(raw, scores, times, threshold=0.1, title=None): """ Plot time x channel based quality metrics. @@ -33,49 +32,69 @@ def plot_timechannel_quality_metric(raw, scores, times, threshold=0.1, fig : figure Matplotlib figure displaying raw scores and thresholded scores. """ - ch_names = raw.ch_names cols = [np.round(t[0]) for t in times] if title is None: - title = 'Automated noisy channel detection: fNIRS' + title = "Automated noisy channel detection: fNIRS" - data_to_plot = pd.DataFrame(data=scores, - columns=pd.Index(cols, name='Time (s)'), - index=pd.Index(ch_names, name='Channel')) + data_to_plot = pd.DataFrame( + data=scores, + columns=pd.Index(cols, name="Time (s)"), + index=pd.Index(ch_names, name="Channel"), + ) n_chans = len(ch_names) vsize = 0.2 * n_chans # First, plot the "raw" scores. 
- fig, ax = plt.subplots(1, 2, figsize=(20, vsize), layout='constrained') - fig.suptitle(title, fontsize=16, fontweight='bold') - sns.heatmap(data=data_to_plot, cmap='Reds_r', vmin=0, vmax=1, - cbar_kws=dict(label='Score'), ax=ax[0]) - [ax[0].axvline(x, ls='dashed', lw=0.25, dashes=(25, 15), color='gray') - for x in range(1, len(times))] - ax[0].set_title('All Scores', fontweight='bold') + fig, ax = plt.subplots(1, 2, figsize=(20, vsize), layout="constrained") + fig.suptitle(title, fontsize=16, fontweight="bold") + sns.heatmap( + data=data_to_plot, + cmap="Reds_r", + vmin=0, + vmax=1, + cbar_kws=dict(label="Score"), + ax=ax[0], + ) + [ + ax[0].axvline(x, ls="dashed", lw=0.25, dashes=(25, 15), color="gray") + for x in range(1, len(times)) + ] + ax[0].set_title("All Scores", fontweight="bold") markbad(raw, ax[0]) # Now, adjust the color range to highlight segments that exceeded the # limit. - data_to_plot = pd.DataFrame(data=scores > threshold, - columns=pd.Index(cols, name='Time (s)'), - index=pd.Index(ch_names, name='Channel')) - sns.heatmap(data=data_to_plot, vmin=0, vmax=1, - cmap='Reds_r', cbar_kws=dict(label='Score'), ax=ax[1]) - [ax[1].axvline(x, ls='dashed', lw=0.25, dashes=(25, 15), color='gray') - for x in range(1, len(times))] - ax[1].set_title('Scores < Limit', fontweight='bold') + data_to_plot = pd.DataFrame( + data=scores > threshold, + columns=pd.Index(cols, name="Time (s)"), + index=pd.Index(ch_names, name="Channel"), + ) + sns.heatmap( + data=data_to_plot, + vmin=0, + vmax=1, + cmap="Reds_r", + cbar_kws=dict(label="Score"), + ax=ax[1], + ) + [ + ax[1].axvline(x, ls="dashed", lw=0.25, dashes=(25, 15), color="gray") + for x in range(1, len(times)) + ] + ax[1].set_title("Scores < Limit", fontweight="bold") markbad(raw, ax[1]) return fig def markbad(raw, ax): - - [ax.axhline(y + 0.5, ls='solid', lw=2, color='black') - for y in np.where([ch in raw.info['bads'] for ch in raw.ch_names])[0]] + [ + ax.axhline(y + 0.5, ls="solid", lw=2, color="black") + for y in np.where([ch in raw.info["bads"] for ch in raw.ch_names])[0] + ] return ax diff --git a/mne_nirs/visualisation/tests/test_quality.py b/mne_nirs/visualisation/tests/test_quality.py index bfd9a3643..2500d4fcc 100644 --- a/mne_nirs/visualisation/tests/test_quality.py +++ b/mne_nirs/visualisation/tests/test_quality.py @@ -1,16 +1,16 @@ import os -import pytest + import mne +import pytest from mne_nirs.preprocessing import peak_power from mne_nirs.visualisation import plot_timechannel_quality_metric -@pytest.mark.filterwarnings('ignore:.*nilearn.glm module is experimental.*:') +@pytest.mark.filterwarnings("ignore:.*nilearn.glm module is experimental.*:") def test_peak_power(): - fnirs_data_folder = mne.datasets.fnirs_motor.data_path() - fnirs_raw_dir = os.path.join(fnirs_data_folder, 'Participant-1') + fnirs_raw_dir = os.path.join(fnirs_data_folder, "Participant-1") raw = mne.io.read_raw_nirx(fnirs_raw_dir, verbose=True).load_data() raw = mne.preprocessing.nirs.optical_density(raw) diff --git a/mne_nirs/visualisation/tests/test_visualisation.py b/mne_nirs/visualisation/tests/test_visualisation.py index 97b08872c..2c34d4059 100644 --- a/mne_nirs/visualisation/tests/test_visualisation.py +++ b/mne_nirs/visualisation/tests/test_visualisation.py @@ -2,29 +2,26 @@ # # License: BSD (3-clause) +import warnings from collections import defaultdict -import pytest -import numpy as np import mne -import mne_nirs -import warnings - +import numpy as np +import pytest from mne.datasets import testing from mne.utils import catch_logging, 
check_version -from mne_nirs.experimental_design.tests.test_experimental_design import \ - _load_dataset +import mne_nirs from mne_nirs.experimental_design import make_first_level_design_matrix +from mne_nirs.experimental_design.tests.test_experimental_design import _load_dataset from mne_nirs.statistics import run_glm -from mne_nirs.visualisation import plot_glm_surface_projection -from mne_nirs.utils import glm_to_tidy from mne_nirs.statistics.tests.test_glm_type import _get_glm_result - +from mne_nirs.utils import glm_to_tidy +from mne_nirs.visualisation import plot_glm_surface_projection testing_path = testing.data_path(download=False) -raw_path = str(testing_path) + '/NIRx/nirscout/nirx_15_2_recording_w_short' -subjects_dir = str(testing_path) + '/subjects' +raw_path = str(testing_path) + "/NIRx/nirscout/nirx_15_2_recording_w_short" +subjects_dir = str(testing_path) + "/subjects" def test_plot_nirs_source_detector_pyvista(requires_pyvista): @@ -36,32 +33,36 @@ def test_plot_nirs_source_detector_pyvista(requires_pyvista): warnings.filterwarnings("ignore", category=DeprecationWarning) mne_nirs.visualisation.plot_nirs_source_detector( np.random.randn(len(raw.ch_names)), - raw.info, show_axes=True, - subject='fsaverage', - trans='fsaverage', - surfaces=['white'], + raw.info, + show_axes=True, + subject="fsaverage", + trans="fsaverage", + surfaces=["white"], fnirs=False, subjects_dir=subjects_dir, - verbose=True) + verbose=True, + ) mne_nirs.visualisation.plot_nirs_source_detector( np.abs(np.random.randn(len(raw.ch_names))) + 5, - raw.info, show_axes=True, - subject='fsaverage', - trans='fsaverage', - surfaces=['white'], + raw.info, + show_axes=True, + subject="fsaverage", + trans="fsaverage", + surfaces=["white"], fnirs=False, subjects_dir=subjects_dir, - verbose=True) + verbose=True, + ) def test_run_plot_GLM_topo(): raw_intensity = _load_dataset() raw_intensity.crop(450, 600) # Keep the test fast - design_matrix = make_first_level_design_matrix(raw_intensity, - drift_order=1, - drift_model='polynomial') + design_matrix = make_first_level_design_matrix( + raw_intensity, drift_order=1, drift_model="polynomial" + ) raw_od = mne.preprocessing.nirs.optical_density(raw_intensity) raw_haemo = mne.preprocessing.nirs.beer_lambert_law(raw_od, ppf=0.1) glm_estimates = run_glm(raw_haemo, design_matrix) @@ -70,11 +71,11 @@ def test_run_plot_GLM_topo(): assert len(fig.axes) == 12 # Two conditions * two chroma + 2 x colorbar - fig = glm_estimates.plot_topo(conditions=['A', 'B']) + fig = glm_estimates.plot_topo(conditions=["A", "B"]) assert len(fig.axes) == 6 # One conditions * two chroma + 2 x colorbar - fig = glm_estimates.plot_topo(conditions=['A']) + fig = glm_estimates.plot_topo(conditions=["A"]) assert len(fig.axes) == 4 @@ -82,16 +83,17 @@ def test_run_plot_GLM_contrast_topo(): raw_intensity = _load_dataset() raw_intensity.crop(450, 600) # Keep the test fast - design_matrix = make_first_level_design_matrix(raw_intensity, - drift_order=1, - drift_model='polynomial') + design_matrix = make_first_level_design_matrix( + raw_intensity, drift_order=1, drift_model="polynomial" + ) raw_od = mne.preprocessing.nirs.optical_density(raw_intensity) raw_haemo = mne.preprocessing.nirs.beer_lambert_law(raw_od, ppf=0.1) glm_est = run_glm(raw_haemo, design_matrix) contrast_matrix = np.eye(design_matrix.shape[1]) - basic_conts = dict([(column, contrast_matrix[i]) - for i, column in enumerate(design_matrix.columns)]) - contrast_LvR = basic_conts['A'] - basic_conts['B'] + basic_conts = dict( + [(column, 
contrast_matrix[i]) for i, column in enumerate(design_matrix.columns)] + ) + contrast_LvR = basic_conts["A"] - basic_conts["B"] contrast = glm_est.compute_contrast(contrast_LvR) fig = contrast.plot_topo() assert len(fig.axes) == 3 @@ -101,17 +103,18 @@ def test_run_plot_GLM_contrast_topo_single_chroma(): raw_intensity = _load_dataset() raw_intensity.crop(450, 600) # Keep the test fast - design_matrix = make_first_level_design_matrix(raw_intensity, - drift_order=1, - drift_model='polynomial') + design_matrix = make_first_level_design_matrix( + raw_intensity, drift_order=1, drift_model="polynomial" + ) raw_od = mne.preprocessing.nirs.optical_density(raw_intensity) raw_haemo = mne.preprocessing.nirs.beer_lambert_law(raw_od, ppf=0.1) - raw_haemo = raw_haemo.pick(picks='hbo') + raw_haemo = raw_haemo.pick(picks="hbo") glm_est = run_glm(raw_haemo, design_matrix) contrast_matrix = np.eye(design_matrix.shape[1]) - basic_conts = dict([(column, contrast_matrix[i]) - for i, column in enumerate(design_matrix.columns)]) - contrast_LvR = basic_conts['A'] - basic_conts['B'] + basic_conts = dict( + [(column, contrast_matrix[i]) for i, column in enumerate(design_matrix.columns)] + ) + contrast_LvR = basic_conts["A"] - basic_conts["B"] contrast = glm_est.compute_contrast(contrast_LvR) fig = contrast.plot_topo() assert len(fig.axes) == 2 @@ -119,6 +122,7 @@ def test_run_plot_GLM_contrast_topo_single_chroma(): def test_fig_from_axes(): from mne_nirs.visualisation._plot_GLM_topo import _get_fig_from_axes + with pytest.raises(RuntimeError, match="Unable to extract figure"): _get_fig_from_axes([1, 2, 3]) @@ -128,9 +132,9 @@ def test_run_plot_GLM_projection(requires_pyvista): raw_intensity = _load_dataset() raw_intensity.crop(450, 600) # Keep the test fast - design_matrix = make_first_level_design_matrix(raw_intensity, - drift_order=1, - drift_model='polynomial') + design_matrix = make_first_level_design_matrix( + raw_intensity, drift_order=1, drift_model="polynomial" + ) raw_od = mne.preprocessing.nirs.optical_density(raw_intensity) raw_haemo = mne.preprocessing.nirs.beer_lambert_law(raw_od, ppf=0.1) glm_estimates = run_glm(raw_haemo, design_matrix) @@ -142,31 +146,42 @@ def test_run_plot_GLM_projection(requires_pyvista): # Ignore deprecation warning caused by # app.setAttribute(Qt.AA_UseHighDpiPixmaps) in mne-python warnings.filterwarnings("ignore", category=DeprecationWarning) - brain = plot_glm_surface_projection(raw_haemo.copy().pick("hbo"), - df, clim='auto', view='dorsal', - colorbar=True, size=(800, 700), - value="theta", surface='white', - subjects_dir=subjects_dir) + brain = plot_glm_surface_projection( + raw_haemo.copy().pick("hbo"), + df, + clim="auto", + view="dorsal", + colorbar=True, + size=(800, 700), + value="theta", + surface="white", + subjects_dir=subjects_dir, + ) assert type(brain) == mne.viz._brain.Brain -@pytest.mark.parametrize('fname_raw, to_1020, ch_names', [ - (raw_path, False, None), - (raw_path, True, 'numbered'), - (raw_path, True, defaultdict(lambda: '')), -]) +@pytest.mark.parametrize( + "fname_raw, to_1020, ch_names", + [ + (raw_path, False, None), + (raw_path, True, "numbered"), + (raw_path, True, defaultdict(lambda: "")), + ], +) def test_plot_3d_montage(requires_pyvista, fname_raw, to_1020, ch_names): raw = mne.io.read_raw_nirx(fname_raw) if to_1020: - need = set(sum( - (ch_name.split()[0].split('_') for ch_name in raw.ch_names), - list())) - mon = mne.channels.make_standard_montage('standard_1020') + need = set( + sum((ch_name.split()[0].split("_") for ch_name in 
raw.ch_names), list()) + ) + mon = mne.channels.make_standard_montage("standard_1020") mon.rename_channels({h: n for h, n in zip(mon.ch_names, need)}) raw.set_montage(mon) n_labels = len(raw.ch_names) // 2 - view_map = {'left-lat': np.arange(1, n_labels // 2), - 'caudal': np.arange(n_labels // 2, n_labels + 1)} + view_map = { + "left-lat": np.arange(1, n_labels // 2), + "caudal": np.arange(n_labels // 2, n_labels + 1), + } # We use "sample" here even though it's wrong so that we can have a head # surface with catch_logging() as log, warnings.catch_warnings(): @@ -174,25 +189,31 @@ def test_plot_3d_montage(requires_pyvista, fname_raw, to_1020, ch_names): # app.setAttribute(Qt.AA_UseHighDpiPixmaps) in mne-python warnings.filterwarnings("ignore", category=DeprecationWarning) mne_nirs.viz.plot_3d_montage( - raw.info, view_map, subject='sample', surface='white', - subjects_dir=subjects_dir, ch_names=ch_names, verbose=True) + raw.info, + view_map, + subject="sample", + surface="white", + subjects_dir=subjects_dir, + ch_names=ch_names, + verbose=True, + ) log = log.getvalue().lower() if to_1020: - assert 'automatically mapped' in log + assert "automatically mapped" in log else: - assert 'could not' in log + assert "could not" in log # surface arg -@pytest.mark.skipif(not check_version('mne', '1.0'), - reason='Needs MNE-Python 1.0') +@pytest.mark.skipif(not check_version("mne", "1.0"), reason="Needs MNE-Python 1.0") def test_glm_surface_projection(requires_pyvista): res = _get_glm_result(tmax=2974, tmin=0) with warnings.catch_warnings(): # Ignore deprecation warning caused by # app.setAttribute(Qt.AA_UseHighDpiPixmaps) in mne-python warnings.filterwarnings("ignore", category=DeprecationWarning) - res.surface_projection(condition="e3p0", view="dorsal", - surface="white", subjects_dir=subjects_dir) - with pytest.raises(KeyError, match='not found in conditions'): - res.surface_projection(condition='foo') + res.surface_projection( + condition="e3p0", view="dorsal", surface="white", subjects_dir=subjects_dir + ) + with pytest.raises(KeyError, match="not found in conditions"): + res.surface_projection(condition="foo") diff --git a/pyproject.toml b/pyproject.toml index 37c46855b..771690ca1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,9 +25,14 @@ ignore-decorators = [ [tool.ruff.per-file-ignores] "examples/*/*.py" = [ + "D103", # Missing docstring in public function "D205", # 1 blank line required between summary line and description "D400", # First line should end with a period ] +"mne_nirs/**/tests/*.py" = [ + "D103", # Missing docstring in public function + "D400", # First line should end with a period +] [tool.pytest.ini_options] # -r f (failed), E (error), s (skipped), x (xfail), X (xpassed), w (warnings) diff --git a/setup.py b/setup.py index 3df7bc2d1..fac2e2832 100644 --- a/setup.py +++ b/setup.py @@ -7,69 +7,70 @@ from setuptools import find_packages, setup # get __version__ from _version.py -ver_file = os.path.join('mne_nirs', '_version.py') +ver_file = os.path.join("mne_nirs", "_version.py") with open(ver_file) as f: exec(f.read()) -DISTNAME = 'mne-nirs' -DESCRIPTION = 'An MNE compatible package for processing near-infrared spectroscopy data.' -with codecs.open('README.rst', encoding='utf-8-sig') as f: +DISTNAME = "mne-nirs" +DESCRIPTION = ( + "An MNE compatible package for processing near-infrared spectroscopy data." 
+)
+with codecs.open("README.rst", encoding="utf-8-sig") as f:
     LONG_DESCRIPTION = f.read()
-MAINTAINER = 'Robert Luke'
-MAINTAINER_EMAIL = 'robert.luke@mq.edu.au'
-URL = 'https://mne.tools/mne-nirs/'
-LICENSE = 'BSD (3-clause)'
-DOWNLOAD_URL = 'https://github.com/mne-tools/mne-nirs'
-VERSION = __version__
-INSTALL_REQUIRES = ['numpy>=1.11.3',
-                    'scipy>=0.17.1',
-                    'mne>=1.0',
-                    'h5io>=0.1.7',
-                    'nilearn>=0.9',
-                    'seaborn'],
-CLASSIFIERS = ['Intended Audience :: Science/Research',
-               'Intended Audience :: Developers',
-               'License :: OSI Approved',
-               'Programming Language :: Python',
-               'Topic :: Software Development',
-               'Topic :: Scientific/Engineering',
-               'Operating System :: Microsoft :: Windows',
-               'Operating System :: POSIX',
-               'Operating System :: Unix',
-               'Operating System :: MacOS',
-               'Programming Language :: Python :: 3',
-               ]
+MAINTAINER = "Robert Luke"
+MAINTAINER_EMAIL = "robert.luke@mq.edu.au"
+URL = "https://mne.tools/mne-nirs/"
+LICENSE = "BSD (3-clause)"
+DOWNLOAD_URL = "https://github.com/mne-tools/mne-nirs"
+VERSION = __version__  # noqa: F821
+INSTALL_REQUIRES = [
+    "numpy>=1.11.3",
+    "scipy>=0.17.1",
+    "mne>=1.0",
+    "h5io>=0.1.7",
+    "nilearn>=0.9",
+    "seaborn",
+]
+CLASSIFIERS = [
+    "Intended Audience :: Science/Research",
+    "Intended Audience :: Developers",
+    "License :: OSI Approved",
+    "Programming Language :: Python",
+    "Topic :: Software Development",
+    "Topic :: Scientific/Engineering",
+    "Operating System :: Microsoft :: Windows",
+    "Operating System :: POSIX",
+    "Operating System :: Unix",
+    "Operating System :: MacOS",
+    "Programming Language :: Python :: 3",
+]
 EXTRAS_REQUIRE = {
-    'tests': [
-        'pytest',
-        'pytest-cov'],
-    'docs': [
-        'sphinx',
-        'sphinx-gallery',
-        'sphinx_rtd_theme',
-        'numpydoc',
-        'matplotlib'
-    ]
+    "tests": ["pytest", "pytest-cov"],
+    "docs": ["sphinx", "sphinx-gallery", "sphinx_rtd_theme", "numpydoc", "matplotlib"],
 }
-setup(name=DISTNAME,
-      maintainer=MAINTAINER,
-      maintainer_email=MAINTAINER_EMAIL,
-      description=DESCRIPTION,
-      license=LICENSE,
-      url=URL,
-      version=VERSION,
-      download_url=DOWNLOAD_URL,
-      long_description=LONG_DESCRIPTION,
-      long_description_content_type='text/x-rst',
-      zip_safe=False,  # the package can run out of an .egg file
-      classifiers=CLASSIFIERS,
-      keywords='neuroscience neuroimaging fNIRS NIRS brain',
-      packages=find_packages(),
-      project_urls={
-          'Documentation': 'https://mne.tools/mne-nirs/',
-          'Source': 'https://github.com/mne-tools/mne-nirs/',
-          'Tracker': 'https://github.com/mne-tools/mne-nirs/issues/',
-      },
-      install_requires=INSTALL_REQUIRES,
-      extras_require=EXTRAS_REQUIRE)
+setup(
+    name=DISTNAME,
+    maintainer=MAINTAINER,
+    maintainer_email=MAINTAINER_EMAIL,
+    description=DESCRIPTION,
+    license=LICENSE,
+    url=URL,
+    version=VERSION,
+    download_url=DOWNLOAD_URL,
+    long_description=LONG_DESCRIPTION,
+    long_description_content_type="text/x-rst",
+    zip_safe=False,  # the package cannot be run out of a zipped .egg file
+    classifiers=CLASSIFIERS,
+    keywords="neuroscience neuroimaging fNIRS NIRS brain",
+    packages=find_packages(),
+    project_urls={
+        "Documentation": "https://mne.tools/mne-nirs/",
+        "Source": "https://github.com/mne-tools/mne-nirs/",
+        "Tracker": "https://github.com/mne-tools/mne-nirs/issues/",
+    },
+    install_requires=INSTALL_REQUIRES,
+    extras_require=EXTRAS_REQUIRE,
+)
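
Note on the `INSTALL_REQUIRES` hunk above: in the removed lines, the trailing comma after `'seaborn'],` bound the whole list into a one-element tuple, so `install_requires` received a tuple wrapping the list rather than the flat list of requirement strings setuptools expects; the replacement lines drop the stray comma. A minimal standalone sketch of the pitfall (the `deps` name is illustrative only, not part of the patch):

    # A stray trailing comma after a bracketed literal wraps it in a tuple.
    deps = ["numpy>=1.11.3", "scipy>=0.17.1"],  # note the comma after "]"
    assert isinstance(deps, tuple) and isinstance(deps[0], list)

    # Without the comma, install_requires gets the flat list setuptools expects.
    deps = ["numpy>=1.11.3", "scipy>=0.17.1"]
    assert all(isinstance(d, str) for d in deps)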
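
For context, the `test_statsmodels.py` hunk at the top of this patch exercises a complete simulation-to-group-model pipeline. Condensed into a usage sketch (the amplitude value is illustrative, since the test's value sits outside the hunks shown; `to_dataframe()` emits a RuntimeWarning about the simulated channel names, as the test expects):

    import numpy as np
    import pandas as pd
    import statsmodels.formula.api as smf

    from mne_nirs.experimental_design import make_first_level_design_matrix
    from mne_nirs.simulation import simulate_nirs_raw
    from mne_nirs.statistics import run_glm, statsmodels_to_results

    np.random.seed(0)
    amplitude = 4.0  # illustrative value only
    df_cha = pd.DataFrame()
    for n in range(5):  # five simulated "subjects"
        raw = simulate_nirs_raw(
            sfreq=3.0,
            amplitude=amplitude,
            sig_dur=300.0,
            stim_dur=5.0,
            isi_min=15.0,
            isi_max=45.0,
        )
        design_matrix = make_first_level_design_matrix(raw, stim_dur=5.0)
        cha = run_glm(raw, design_matrix).to_dataframe()  # warns: non-standard names
        cha["ID"] = "%02d" % n
        df_cha = pd.concat([df_cha, cha], ignore_index=True)

    df_cha["theta"] = df_cha["theta"] * 1.0e6  # rescale, as in the test
    roi_model = smf.mixedlm("theta ~ -1 + Condition", df_cha, groups=df_cha["ID"]).fit()
    df = statsmodels_to_results(roi_model)  # tidy DataFrame of group-level estimates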
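
Similarly, the `view_map` argument reformatted in `_plot_3d_montage.py` above maps view names to 1-based channel numbers to label in that view, with an optional `left-`/`right-` prefix selecting the hemisphere, matching the key parsing shown in the hunk. A hedged usage sketch (the recording path is an assumption, and rendering requires a 3D backend plus the fsaverage subject files):

    import mne
    import numpy as np

    import mne_nirs

    # Assumed path; any fNIRS Raw with montage positions works.
    raw = mne.io.read_raw_nirx("path/to/nirx_recording")
    # Label channels 1-4 on a left lateral view and 5-8 on a caudal view.
    view_map = {
        "left-lat": np.arange(1, 5),
        "caudal": np.arange(5, 9),
    }
    mne_nirs.viz.plot_3d_montage(raw.info, view_map, subject="fsaverage")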