Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor/initial cleanup #12916

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
227 changes: 227 additions & 0 deletions mne/preprocessing/pca_obs/PCA_OBS.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
import math

import numpy as np
from fit_ecgTemplate import fit_ecgTemplate

# import mne
from scipy.signal import detrend, filtfilt
from sklearn.decomposition import PCA


def PCA_OBS(data, **kwargs):
# Declare class to hold pca information
class PCAInfo:
def __init__(self):
pass

# Instantiate class
pca_info = PCAInfo()

# Check all necessary arguments sent in
required_kws = ["qrs", "filter_coords", "sr"]
assert all(
[kw in kwargs.keys() for kw in required_kws]
), "Error. Some KWs not passed into PCA_OBS."

# Extract all kwargs
qrs = kwargs["qrs"]
filter_coords = kwargs["filter_coords"]
sr = kwargs["sr"]

fs = sr

# set to baseline
data = data.reshape(-1, 1)
data = data.T
data = data - np.mean(data, axis=1)

# Allocate memory
fitted_art = np.zeros(data.shape)
peakplot = np.zeros(data.shape)

# Extract QRS events
for idx in qrs[0]:
if idx < len(peakplot[0, :]):
peakplot[0, idx] = 1 # logical indexed locations of qrs events

peak_idx = np.nonzero(peakplot)[1] # Selecting indices along columns
peak_idx = peak_idx.reshape(-1, 1)
peak_count = len(peak_idx)

################################################################
# Preparatory work - reserving memory, configure sizes, de-trend
################################################################
print("Pulse artifact subtraction in progress...Please wait!")

# define peak range based on RR
RR = np.diff(peak_idx[:, 0])
mRR = np.median(RR)
peak_range = round(mRR / 2) # Rounds to an integer
midP = peak_range + 1
baseline_range = [0, round(peak_range / 8)]
n_samples_fit = round(
peak_range / 8
) # sample fit for interpolation between fitted artifact windows

# make sure array is long enough for PArange (if not cut off last ECG peak)
pa = peak_count # Number of QRS complexes detected
while peak_idx[pa - 1, 0] + peak_range > len(data[0]):
pa = pa - 1
steps = 1 * pa
peak_count = pa

# Filter channel
eegchan = filtfilt(filter_coords, 1, data)

# build PCA matrix(heart-beat-epochs x window-length)
pcamat = np.zeros((peak_count - 1, 2 * peak_range + 1)) # [epoch x time]
# picking out heartbeat epochs
for p in range(1, peak_count):
pcamat[p - 1, :] = eegchan[
0, peak_idx[p, 0] - peak_range : peak_idx[p, 0] + peak_range + 1
]

# detrending matrix(twice)
pcamat = detrend(
pcamat, type="constant", axis=1
) # [epoch x time] - detrended along the epoch
mean_effect = np.mean(
pcamat, axis=0
) # [1 x time], contains the mean over all epochs
std_effect = np.std(pcamat, axis=0) # want mean and std of each column
dpcamat = detrend(pcamat, type="constant", axis=1) # [time x epoch]

###################################################################
# Perform PCA with sklearn
###################################################################
# run PCA(performs SVD(singular value decomposition))
pca = PCA(svd_solver="full")
pca.fit(dpcamat)
eigen_vectors = pca.components_
eigen_values = pca.explained_variance_
factor_loadings = pca.components_.T * np.sqrt(pca.explained_variance_)
pca_info.eigen_vectors = eigen_vectors
pca_info.factor_loadings = factor_loadings
pca_info.eigen_values = eigen_values
pca_info.expl_var = pca.explained_variance_ratio_

# define selected number of components using profile likelihood
pca_info.nComponents = 4
pca_info.meanEffect = mean_effect.T
nComponents = pca_info.nComponents

#######################################################################
# Make template of the ECG artefact
#######################################################################
mean_effect = mean_effect.reshape(-1, 1)
pca_template = np.c_[mean_effect, factor_loadings[:, 0:nComponents]]

###################################################################################
# Data Fitting
###################################################################################
window_start_idx = []
window_end_idx = []
for p in range(0, peak_count):
# Deals with start portion of data
if p == 0:
pre_range = peak_range
post_range = math.floor((peak_idx[p + 1] - peak_idx[p]) / 2)
if post_range > peak_range:
post_range = peak_range
try:
post_idx_nextPeak = []
fitted_art, post_idx_nextPeak = fit_ecgTemplate(
data,
pca_template,
peak_idx[p],
peak_range,
pre_range,
post_range,
baseline_range,
midP,
fitted_art,
post_idx_nextPeak,
n_samples_fit,
)
# Appending to list instead of using counter
window_start_idx.append(peak_idx[p] - peak_range)
window_end_idx.append(peak_idx[p] + peak_range)
except Exception as e:
print(f"Cannot fit first ECG epoch. Reason: {e}")

# Deals with last edge of data
elif p == peak_count:
print("On last section - almost there!")
try:
pre_range = math.floor((peak_idx[p] - peak_idx[p - 1]) / 2)
post_range = peak_range
if pre_range > peak_range:
pre_range = peak_range
fitted_art, _ = fit_ecgTemplate(
data,
pca_template,
peak_idx(p),
peak_range,
pre_range,
post_range,
baseline_range,
midP,
fitted_art,
post_idx_nextPeak,
n_samples_fit,
)
window_start_idx.append(peak_idx[p] - peak_range)
window_end_idx.append(peak_idx[p] + peak_range)
except Exception as e:
print(f"Cannot fit last ECG epoch. Reason: {e}")

# Deals with middle portion of data
else:
try:
# ---------------- Processing of central data - --------------------
# cycle through peak artifacts identified by peakplot
pre_range = math.floor((peak_idx[p] - peak_idx[p - 1]) / 2)
post_range = math.floor((peak_idx[p + 1] - peak_idx[p]) / 2)
if pre_range >= peak_range:
pre_range = peak_range
if post_range > peak_range:
post_range = peak_range

aTemplate = pca_template[
midP - peak_range - 1 : midP + peak_range + 1, :
]
fitted_art, post_idx_nextPeak = fit_ecgTemplate(
data,
aTemplate,
peak_idx[p],
peak_range,
pre_range,
post_range,
baseline_range,
midP,
fitted_art,
post_idx_nextPeak,
n_samples_fit,
)
window_start_idx.append(peak_idx[p] - peak_range)
window_end_idx.append(peak_idx[p] + peak_range)
except Exception as e:
print(f"Cannot fit middle section of data. Reason: {e}")

# Actually subtract the artefact, return needs to be the same shape as input data
# One sample shift purely due to the fact the r-peaks are currently detected in MATLAB
data = data.reshape(-1)
fitted_art = fitted_art.reshape(-1)

# One sample shift for my actual data (introduced using matlab r timings)
# data_ = np.zeros(len(data))
# data_[0] = data[0]
# data_[1:] = data[1:] - fitted_art[:-1]
# data = data_

# Original code is this:
data -= fitted_art
data = data.T.reshape(-1)

# Can only return data
return data
98 changes: 98 additions & 0 deletions mne/preprocessing/pca_obs/fit_ecgTemplate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import numpy as np
from scipy.interpolate import PchipInterpolator as pchip
from scipy.signal import detrend


def fit_ecgTemplate(
data,
pca_template,
aPeak_idx,
peak_range,
pre_range,
post_range,
baseline_range,
midP,
fitted_art,
post_idx_previousPeak,
n_samples_fit,
):
# Declare class to hold ecg fit information
class fitECG:
def __init__(self):
pass

# Instantiate class
fitecg = fitECG()

# post_idx_nextpeak is passed in in PCA_OBS, used here as post_idx_previouspeak
# Then nextpeak is returned at the end and the process repeats
# select window of template
template = pca_template[midP - peak_range - 1 : midP + peak_range + 1, :]

# select window of data and detrend it
slice = data[0, aPeak_idx[0] - peak_range : aPeak_idx[0] + peak_range + 1]
detrended_data = detrend(slice.reshape(-1), type="constant")

# maps data on template and then maps it again back to the sensor space
least_square = np.linalg.lstsq(template, detrended_data, rcond=None)
pad_fit = np.dot(template, least_square[0])

# fit artifact, I already loop through externally channel to channel
fitted_art[0, aPeak_idx[0] - pre_range - 1 : aPeak_idx[0] + post_range] = pad_fit[
midP - pre_range - 1 : midP + post_range
].T

fitecg.fitted_art = fitted_art
fitecg.template = template
fitecg.detrended_data = detrended_data
fitecg.pad_fit = pad_fit
fitecg.aPeak_idx = aPeak_idx
fitecg.midP = midP
fitecg.peak_range = peak_range
fitecg.data = data

post_idx_nextPeak = [aPeak_idx[0] + post_range]

# Check it's not empty
if len(post_idx_previousPeak) != 0:
# interpolate time between peaks
intpol_window = np.ceil(
[post_idx_previousPeak[0], aPeak_idx[0] - pre_range]
).astype("int") # interpolation window
fitecg.intpol_window = intpol_window

if intpol_window[0] < intpol_window[1]:
# Piecewise Cubic Hermite Interpolating Polynomial(PCHIP) + replace EEG data

# You have x_fit which is two slices on either side of the interpolation window endpoints
# You have y_fit which is the y vals corresponding to x values above
# You have x_interpol which is the time points between the two slices in x_fit that you want to interpolate
# You have y_interpol which is values from pchip at the time points specified in x_interpol
x_interpol = np.arange(
intpol_window[0], intpol_window[1] + 1, 1
) # points to be interpolated in pt - the gap between the endpoints of the window
x_fit = np.concatenate(
[
np.arange(
intpol_window[0] - n_samples_fit, intpol_window[0] + 1, 1
),
np.arange(
intpol_window[1], intpol_window[1] + n_samples_fit + 1, 1
),
]
) # Entire range of x values in this step (taking some number of samples before and after the window)
y_fit = fitted_art[0, x_fit]
y_interpol = pchip(x_fit, y_fit)(x_interpol) # perform interpolation

# Then make fitted artefact in the desired range equal to the completed fit above
fitted_art[0, post_idx_previousPeak[0] : aPeak_idx[0] - pre_range + 1] = (
y_interpol
)

fitecg.x_fit = x_fit
fitecg.y_fit = y_fit
fitecg.x_interpol = x_interpol
fitecg.y_interpol = y_interpol
fitecg.fitted_art = fitted_art # Reassign if we've gone into this loop

return fitted_art, post_idx_nextPeak
55 changes: 55 additions & 0 deletions mne/preprocessing/pca_obs/pchip_interpolation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Function to interpolate based on PCHIP rather than MNE inbuilt linear option

# import mne
import numpy as np
from scipy.interpolate import PchipInterpolator as pchip


def PCHIP_interpolation(data, **kwargs):
# Check all necessary arguments sent in
required_kws = ["trigger_indices", "interpol_window_sec", "fs"]
assert all(
[kw in kwargs.keys() for kw in required_kws]
), "Error. Some KWs not passed into PCA_OBS."

# Extract all kwargs - more elegant ways to do this
fs = kwargs["fs"]
interpol_window_sec = kwargs["interpol_window_sec"]
trigger_indices = kwargs["trigger_indices"]

# Convert intpol window to msec then convert to samples
pre_window = round((interpol_window_sec[0] * 1000) * fs / 1000) # in samples
post_window = round((interpol_window_sec[1] * 1000) * fs / 1000) # in samples
intpol_window = np.ceil([pre_window, post_window]).astype(
int
) # interpolation window

n_samples_fit = (
5 # number of samples before and after cut used for interpolation fit
)

x_fit_raw = np.concatenate(
[
np.arange(intpol_window[0] - n_samples_fit - 1, intpol_window[0], 1),
np.arange(intpol_window[1] + 1, intpol_window[1] + n_samples_fit + 2, 1),
]
)
x_interpol_raw = np.arange(
intpol_window[0], intpol_window[1] + 1, 1
) # points to be interpolated; in pt

for ii in np.arange(0, len(trigger_indices)): # loop through all stimulation events
x_fit = trigger_indices[ii] + x_fit_raw # fit point latencies for this event
x_interpol = (
trigger_indices[ii] + x_interpol_raw
) # latencies for to-be-interpolated data points

# Data is just a string of values
y_fit = data[x_fit] # y values to be fitted
y_interpol = pchip(x_fit, y_fit)(x_interpol) # perform interpolation
data[x_interpol] = y_interpol # replace in data

if np.mod(ii, 100) == 0: # talk to the operator every 100th trial
print(f"stimulation event {ii} \n")

return data
Loading
Loading