Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Calculation of the spike-triggered phase and amplitude #121

Merged
merged 11 commits into from
Apr 4, 2018
171 changes: 171 additions & 0 deletions elephant/phase_analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
# -*- coding: utf-8 -*-
"""
Methods for performing phase analysis.

:copyright: Copyright 2014-2018 by the Elephant team, see AUTHORS.txt.
:license: Modified BSD, see LICENSE.txt for details.
"""

import numpy as np
import quantities as pq


def spike_triggered_phase(hilbert_transform, spiketrains, interpolate):
"""
Calculate the set of spike-triggered phases of an AnalogSignal.

Parameters
----------
hilbert_transform : AnalogSignal or list of AnalogSignal
AnalogSignal of the complex analytic signal (e.g., returned by the
elephant.signal_processing.hilbert()). All spike trains are compared to
this signal, if only one signal is given. Otherwise, length of
hilbert_transform must match the length of spiketrains.
spiketrains : Spiketrain or list of Spiketrain
Spiketrains on which to trigger hilbert_transform extraction
interpolate : bool
If True, the phases and amplitudes of hilbert_transform for spikes
falling between two samples of signal is interpolated. Otherwise, the
closest sample of hilbert_transform is used.

Returns
-------
phases : list of arrays
Spike-triggered phases. Entries in the list correspond to the
SpikeTrains in spiketrains. Each entry contains an array with the
spike-triggered angles (in rad) of the signal.
amp : list of arrays
Corresponding spike-triggered amplitudes.
times : list of arrays
A list of times corresponding to the signal
Corresponding times (corresponds to the spike times).

Example
-------
Create a 20 Hz oscillatory signal sampled at 1 kHz and a random Poisson
spike train:

>>> f_osc = 20. * pq.Hz
>>> f_sampling = 1 * pq.ms
>>> tlen = 100 * pq.s
>>> time_axis = np.arange(
0, tlen.magnitude,
f_sampling.rescale(pq.s).magnitude) * pq.s
>>> analogsignal = AnalogSignal(
np.sin(2 * np.pi * (f_osc * time_axis).simplified.magnitude),
units=pq.mV, t_start=0 * pq.ms, sampling_period=f_sampling)
>>> spiketrain = elephant.spike_train_generation.
homogeneous_poisson_process(
50 * pq.Hz, t_start=0.0 * ms, t_stop=tlen.rescale(pq.ms))

Calculate spike-triggered phases and amplitudes of the oscillation:
>>> phases, amps, times = elephant.phase_analysis.spike_triggered_phase(
elephant.signal_processing.hilbert(analogsignal),
spiketrain,
interpolate=True)
"""

# Convert inputs to lists
if not isinstance(spiketrains, list):
spiketrains = [spiketrains]

if not isinstance(hilbert_transform, list):
hilbert_transform = [hilbert_transform]

# Number of signals
num_spiketrains = len(spiketrains)
num_phase = len(hilbert_transform)

if num_spiketrains != 1 and num_phase != 1 and \
num_spiketrains != num_phase:
raise ValueError(
"Number of spike trains and number of phase signals"
"must match, or either of the two must be a single signal.")

# For each trial, select the first input
start = [elem.t_start for elem in hilbert_transform]
stop = [elem.t_stop for elem in hilbert_transform]

result_phases = []
result_amps = []
result_times = []

# Step through each signal
for spiketrain_i, spiketrain in enumerate(spiketrains):
# Check which hilbert_transform AnalogSignal to look at - if there is
# only one then all spike trains relate to this one, otherwise the two
# lists of spike trains and phases are matched up
if num_phase > 1:
phase_i = spiketrain_i
else:
phase_i = 0

# Take only spikes which lie directly within the signal segment -
# ignore spikes sitting on the last sample
sttimeind = np.where(np.logical_and(
spiketrain >= start[phase_i], spiketrain < stop[phase_i]))[0]

# Find index into signal for each spike
ind_at_spike = np.round(
(spiketrain[sttimeind] - hilbert_transform[phase_i].t_start) /
hilbert_transform[phase_i].sampling_period).magnitude.astype(int)

# Extract times for speed reasons
times = hilbert_transform[phase_i].times

# Append new list to the results for this spiketrain
result_phases.append([])
result_amps.append([])
result_times.append([])

# Step through all spikes
for spike_i, ind_at_spike_j in enumerate(ind_at_spike):
# Difference vector between actual spike time and sample point,
# positive if spike time is later than sample point
dv = spiketrain[sttimeind[spike_i]] - times[ind_at_spike_j]

# Make sure ind_at_spike is to the left of the spike time
if dv < 0 and ind_at_spike_j > 0:
ind_at_spike_j = ind_at_spike_j - 1

if interpolate:
# Get relative spike occurrence between the two closest signal
# sample points
# if z->0 spike is more to the left sample
# if z->1 more to the right sample
z = (spiketrain[sttimeind[spike_i]] - times[ind_at_spike_j]) /\
hilbert_transform[phase_i].sampling_period

# Save hilbert_transform (interpolate on circle)
p1 = np.angle(hilbert_transform[phase_i][ind_at_spike_j])
p2 = np.angle(hilbert_transform[phase_i][ind_at_spike_j + 1])
result_phases[spiketrain_i].append(
np.angle(
(1 - z) * np.exp(np.complex(0, p1)) +
z * np.exp(np.complex(0, p2))))

# Save amplitude
result_amps[spiketrain_i].append(
(1 - z) * np.abs(
hilbert_transform[phase_i][ind_at_spike_j]) +
z * np.abs(hilbert_transform[phase_i][ind_at_spike_j + 1]))
else:
p1 = np.angle(hilbert_transform[phase_i][ind_at_spike_j])
result_phases[spiketrain_i].append(p1)

# Save amplitude
result_amps[spiketrain_i].append(
np.abs(hilbert_transform[phase_i][ind_at_spike_j]))

# Save time
result_times[spiketrain_i].append(spiketrain[sttimeind[spike_i]])

# Convert outputs to arrays
for i, entry in enumerate(result_phases):
result_phases[i] = np.array(entry).flatten()
for i, entry in enumerate(result_amps):
result_amps[i] = pq.Quantity(entry, units=entry[0].units).flatten()
for i, entry in enumerate(result_times):
result_times[i] = pq.Quantity(entry, units=entry[0].units).flatten()

return result_phases, result_amps, result_times
190 changes: 190 additions & 0 deletions elephant/test/test_phase_analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
# -*- coding: utf-8 -*-
"""
Unit tests for the phase analysis module.

:copyright: Copyright 2016 by the Elephant team, see AUTHORS.txt.
:license: Modified BSD, see LICENSE.txt for details.
"""
from __future__ import division, print_function

import unittest

from neo import SpikeTrain, AnalogSignal
import numpy as np
import quantities as pq

import elephant.phase_analysis

from numpy.ma.testutils import assert_allclose


class SpikeTriggeredPhaseTestCase(unittest.TestCase):

def setUp(self):
tlen0 = 100 * pq.s
f0 = 20. * pq.Hz
fs0 = 1 * pq.ms
t0 = np.arange(
0, tlen0.rescale(pq.s).magnitude,
fs0.rescale(pq.s).magnitude) * pq.s
self.anasig0 = AnalogSignal(
np.sin(2 * np.pi * (f0 * t0).simplified.magnitude),
units=pq.mV, t_start=0 * pq.ms, sampling_period=fs0)
self.st0 = SpikeTrain(
np.arange(50, tlen0.rescale(pq.ms).magnitude - 50, 50) * pq.ms,
t_start=0 * pq.ms, t_stop=tlen0)
self.st1 = SpikeTrain(
[100., 100.1, 100.2, 100.3, 100.9, 101.] * pq.ms,
t_start=0 * pq.ms, t_stop=tlen0)

def test_perfect_locking_one_spiketrain_one_signal(self):
phases, amps, times = elephant.phase_analysis.spike_triggered_phase(
elephant.signal_processing.hilbert(self.anasig0),
self.st0,
interpolate=True)

assert_allclose(phases[0], - np.pi / 2.)
assert_allclose(amps[0], 1, atol=0.1)
assert_allclose(times[0].magnitude, self.st0.magnitude)
self.assertEqual(len(phases[0]), len(self.st0))
self.assertEqual(len(amps[0]), len(self.st0))
self.assertEqual(len(times[0]), len(self.st0))

def test_perfect_locking_many_spiketrains_many_signals(self):
phases, amps, times = elephant.phase_analysis.spike_triggered_phase(
[
elephant.signal_processing.hilbert(self.anasig0),
elephant.signal_processing.hilbert(self.anasig0)],
[self.st0, self.st0],
interpolate=True)

assert_allclose(phases[0], -np.pi / 2.)
assert_allclose(amps[0], 1, atol=0.1)
assert_allclose(times[0].magnitude, self.st0.magnitude)
self.assertEqual(len(phases[0]), len(self.st0))
self.assertEqual(len(amps[0]), len(self.st0))
self.assertEqual(len(times[0]), len(self.st0))

def test_perfect_locking_one_spiketrains_many_signals(self):
phases, amps, times = elephant.phase_analysis.spike_triggered_phase(
[
elephant.signal_processing.hilbert(self.anasig0),
elephant.signal_processing.hilbert(self.anasig0)],
[self.st0],
interpolate=True)

assert_allclose(phases[0], -np.pi / 2.)
assert_allclose(amps[0], 1, atol=0.1)
assert_allclose(times[0].magnitude, self.st0.magnitude)
self.assertEqual(len(phases[0]), len(self.st0))
self.assertEqual(len(amps[0]), len(self.st0))
self.assertEqual(len(times[0]), len(self.st0))

def test_perfect_locking_many_spiketrains_one_signal(self):
phases, amps, times = elephant.phase_analysis.spike_triggered_phase(
elephant.signal_processing.hilbert(self.anasig0),
[self.st0, self.st0],
interpolate=True)

assert_allclose(phases[0], -np.pi / 2.)
assert_allclose(amps[0], 1, atol=0.1)
assert_allclose(times[0].magnitude, self.st0.magnitude)
self.assertEqual(len(phases[0]), len(self.st0))
self.assertEqual(len(amps[0]), len(self.st0))
self.assertEqual(len(times[0]), len(self.st0))

def test_interpolate(self):
phases_int, _, _ = elephant.phase_analysis.spike_triggered_phase(
elephant.signal_processing.hilbert(self.anasig0),
self.st1,
interpolate=True)

self.assertLess(phases_int[0][0], phases_int[0][1])
self.assertLess(phases_int[0][1], phases_int[0][2])
self.assertLess(phases_int[0][2], phases_int[0][3])
self.assertLess(phases_int[0][3], phases_int[0][4])
self.assertLess(phases_int[0][4], phases_int[0][5])

phases_noint, _, _ = elephant.phase_analysis.spike_triggered_phase(
elephant.signal_processing.hilbert(self.anasig0),
self.st1,
interpolate=False)

self.assertEqual(phases_noint[0][0], phases_noint[0][1])
self.assertEqual(phases_noint[0][1], phases_noint[0][2])
self.assertEqual(phases_noint[0][2], phases_noint[0][3])
self.assertEqual(phases_noint[0][3], phases_noint[0][4])
self.assertNotEqual(phases_noint[0][4], phases_noint[0][5])

# Verify that when using interpolation and the spike sits on the sample
# of the Hilbert transform, this is the same result as when not using
# interpolation with a spike slightly to the right
self.assertEqual(phases_noint[0][2], phases_int[0][0])
self.assertEqual(phases_noint[0][4], phases_int[0][0])

def test_inconsistent_numbers_spiketrains_hilbert(self):
self.assertRaises(
ValueError, elephant.phase_analysis.spike_triggered_phase,
[
elephant.signal_processing.hilbert(self.anasig0),
elephant.signal_processing.hilbert(self.anasig0)],
[self.st0, self.st0, self.st0], False)

self.assertRaises(
ValueError, elephant.phase_analysis.spike_triggered_phase,
[
elephant.signal_processing.hilbert(self.anasig0),
elephant.signal_processing.hilbert(self.anasig0)],
[self.st0, self.st0, self.st0], False)

def test_spike_earlier_than_hilbert(self):
# This is a spike clearly outside the bounds
st = SpikeTrain(
[-50, 50],
units='s', t_start=-100*pq.s, t_stop=100*pq.s)
phases_noint, _, _ = elephant.phase_analysis.spike_triggered_phase(
elephant.signal_processing.hilbert(self.anasig0),
st,
interpolate=False)
self.assertEqual(len(phases_noint[0]), 1)

# This is a spike right on the border (start of the signal is at 0s,
# spike sits at t=0s). By definition of intervals in
# Elephant (left borders inclusive, right borders exclusive), this
# spike is to be considered.
st = SpikeTrain(
[0, 50],
units='s', t_start=-100*pq.s, t_stop=100*pq.s)
phases_noint, _, _ = elephant.phase_analysis.spike_triggered_phase(
elephant.signal_processing.hilbert(self.anasig0),
st,
interpolate=False)
self.assertEqual(len(phases_noint[0]), 2)

def test_spike_later_than_hilbert(self):
# This is a spike clearly outside the bounds
st = SpikeTrain(
[1, 250],
units='s', t_start=-1*pq.s, t_stop=300*pq.s)
phases_noint, _, _ = elephant.phase_analysis.spike_triggered_phase(
elephant.signal_processing.hilbert(self.anasig0),
st,
interpolate=False)
self.assertEqual(len(phases_noint[0]), 1)

# This is a spike right on the border (length of the signal is 100s,
# spike sits at t=100s). However, by definition of intervals in
# Elephant (left borders inclusive, right borders exclusive), this
# spike is not to be considered.
st = SpikeTrain(
[1, 100],
units='s', t_start=-1*pq.s, t_stop=200*pq.s)
phases_noint, _, _ = elephant.phase_analysis.spike_triggered_phase(
elephant.signal_processing.hilbert(self.anasig0),
st,
interpolate=False)
self.assertEqual(len(phases_noint[0]), 1)


if __name__ == '__main__':
unittest.main()