Skip to content

Commit

Permalink
Merge branch 'main' into specconn_tfr_support
Browse files Browse the repository at this point in the history
  • Loading branch information
tsbinns authored Sep 17, 2024
2 parents 160bc64 + 33680b3 commit 23dc1f2
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 7 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
repos:
# Ruff mne_connectivity
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.4
rev: v0.6.5
hooks:
- id: ruff
name: ruff lint mne_connectivity
Expand All @@ -10,7 +10,7 @@ repos:

# Ruff tutorials and examples
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.4
rev: v0.6.5
hooks:
- id: ruff
name: ruff lint tutorials and examples
Expand Down
30 changes: 30 additions & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,9 @@
"Axes3D": "mpl_toolkits.mplot3d.axes3d.Axes3D",
"PolarAxes": "matplotlib.projections.polar.PolarAxes",
"ColorbarBase": "matplotlib.colorbar.ColorbarBase",
# sklearn
"MetadataRequest": "sklearn.utils.metadata_routing.MetadataRequest",
"estimator": "sklearn.base.BaseEstimator",
# joblib
"joblib.Parallel": "joblib.Parallel",
# nibabel
Expand Down Expand Up @@ -360,3 +363,30 @@
suppress_warnings = [
"config.cache", # our rebuild is okay
]


def fix_sklearn_inherited_docstrings(app, what, name, obj, options, lines):
"""Fix sklearn docstrings because they use autolink and we do not."""
if (
name.startswith("mne_connectivity.decoding.")
) and name.endswith(
(
".get_metadata_routing",
".fit",
".fit_transform",
".set_output",
".transform",
)
):
if ":Parameters:" in lines:
loc = lines.index(":Parameters:")
else:
loc = lines.index(":Returns:")
lines.insert(loc, "")
lines.insert(loc, ".. default-role:: autolink")
lines.insert(loc, "")


def setup(app):
"""Set up the Sphinx app."""
app.connect("autodoc-process-docstring", fix_sklearn_inherited_docstrings)
4 changes: 2 additions & 2 deletions mne_connectivity/decoding/decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,12 @@
import numpy as np
from mne import Info
from mne._fiff.pick import pick_info
from mne.decoding.mixin import TransformerMixin
from mne.defaults import _BORDER_DEFAULT, _EXTRAPOLATE_DEFAULT, _INTERPOLATION_DEFAULT
from mne.evoked import EvokedArray
from mne.fixes import BaseEstimator
from mne.time_frequency import csd_array_fourier, csd_array_morlet, csd_array_multitaper
from mne.utils import _check_option, _validate_type
from mne.viz.utils import plt_show
from sklearn.base import BaseEstimator, TransformerMixin

from ..spectral.epochs_multivariate import (
_CaCohEst,
Expand Down Expand Up @@ -222,6 +221,7 @@ def __init__(
# n_jobs and verbose will be checked downstream

# Store inputs
self.method = method
self.info = info
self._conn_estimator_class = _conn_estimator_class
self._indices = _indices # uses getter/setter for public parameter
Expand Down
6 changes: 3 additions & 3 deletions mne_connectivity/decoding/tests/test_decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,9 @@ def test_spectral_decomposition(method, mode):
epochs_transformed_2 = decomp_class_2.transform(
X=epochs[: n_epochs // 2].get_data()
)
assert_allclose(epochs_transformed, epochs_transformed_2, atol=1e-9)
assert_allclose(decomp_class.filters_, decomp_class_2.filters_, atol=1e-9)
assert_allclose(decomp_class.patterns_, decomp_class_2.patterns_, atol=1e-9)
assert_allclose(epochs_transformed, epochs_transformed_2, atol=1e-8)
assert_allclose(decomp_class.filters_, decomp_class_2.filters_, atol=1e-8)
assert_allclose(decomp_class.patterns_, decomp_class_2.patterns_, atol=1e-8)

# TEST FITTING ON ONE PIECE OF DATA AND TRANSFORMING ANOTHER
con_mv_class_unseen_data = spectral_connectivity_epochs(
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ dependencies = [
'netCDF4 >= 1.6.5',
'numpy >= 1.21',
'pandas >= 1.3.2',
'scikit-learn >= 1.2',
'scipy >= 1.4.0',
'tqdm',
'xarray >= 2023.11.0',
Expand Down

0 comments on commit 23dc1f2

Please sign in to comment.