diff --git a/elephant/phase_analysis.py b/elephant/phase_analysis.py new file mode 100644 index 000000000..99a6f1982 --- /dev/null +++ b/elephant/phase_analysis.py @@ -0,0 +1,171 @@ +# -*- coding: utf-8 -*- +""" +Methods for performing phase analysis. + +:copyright: Copyright 2014-2018 by the Elephant team, see AUTHORS.txt. +:license: Modified BSD, see LICENSE.txt for details. +""" + +import numpy as np +import quantities as pq + + +def spike_triggered_phase(hilbert_transform, spiketrains, interpolate): + """ + Calculate the set of spike-triggered phases of an AnalogSignal. + + Parameters + ---------- + hilbert_transform : AnalogSignal or list of AnalogSignal + AnalogSignal of the complex analytic signal (e.g., returned by the + elephant.signal_processing.hilbert()). All spike trains are compared to + this signal, if only one signal is given. Otherwise, length of + hilbert_transform must match the length of spiketrains. + spiketrains : Spiketrain or list of Spiketrain + Spiketrains on which to trigger hilbert_transform extraction + interpolate : bool + If True, the phases and amplitudes of hilbert_transform for spikes + falling between two samples of signal is interpolated. Otherwise, the + closest sample of hilbert_transform is used. + + Returns + ------- + phases : list of arrays + Spike-triggered phases. Entries in the list correspond to the + SpikeTrains in spiketrains. Each entry contains an array with the + spike-triggered angles (in rad) of the signal. + amp : list of arrays + Corresponding spike-triggered amplitudes. + times : list of arrays + A list of times corresponding to the signal + Corresponding times (corresponds to the spike times). + + Example + ------- + Create a 20 Hz oscillatory signal sampled at 1 kHz and a random Poisson + spike train: + + >>> f_osc = 20. * pq.Hz + >>> f_sampling = 1 * pq.ms + >>> tlen = 100 * pq.s + >>> time_axis = np.arange( + 0, tlen.magnitude, + f_sampling.rescale(pq.s).magnitude) * pq.s + >>> analogsignal = AnalogSignal( + np.sin(2 * np.pi * (f_osc * time_axis).simplified.magnitude), + units=pq.mV, t_start=0 * pq.ms, sampling_period=f_sampling) + >>> spiketrain = elephant.spike_train_generation. + homogeneous_poisson_process( + 50 * pq.Hz, t_start=0.0 * ms, t_stop=tlen.rescale(pq.ms)) + + Calculate spike-triggered phases and amplitudes of the oscillation: + >>> phases, amps, times = elephant.phase_analysis.spike_triggered_phase( + elephant.signal_processing.hilbert(analogsignal), + spiketrain, + interpolate=True) + """ + + # Convert inputs to lists + if not isinstance(spiketrains, list): + spiketrains = [spiketrains] + + if not isinstance(hilbert_transform, list): + hilbert_transform = [hilbert_transform] + + # Number of signals + num_spiketrains = len(spiketrains) + num_phase = len(hilbert_transform) + + if num_spiketrains != 1 and num_phase != 1 and \ + num_spiketrains != num_phase: + raise ValueError( + "Number of spike trains and number of phase signals" + "must match, or either of the two must be a single signal.") + + # For each trial, select the first input + start = [elem.t_start for elem in hilbert_transform] + stop = [elem.t_stop for elem in hilbert_transform] + + result_phases = [] + result_amps = [] + result_times = [] + + # Step through each signal + for spiketrain_i, spiketrain in enumerate(spiketrains): + # Check which hilbert_transform AnalogSignal to look at - if there is + # only one then all spike trains relate to this one, otherwise the two + # lists of spike trains and phases are matched up + if num_phase > 1: + phase_i = spiketrain_i + else: + phase_i = 0 + + # Take only spikes which lie directly within the signal segment - + # ignore spikes sitting on the last sample + sttimeind = np.where(np.logical_and( + spiketrain >= start[phase_i], spiketrain < stop[phase_i]))[0] + + # Find index into signal for each spike + ind_at_spike = np.round( + (spiketrain[sttimeind] - hilbert_transform[phase_i].t_start) / + hilbert_transform[phase_i].sampling_period).magnitude.astype(int) + + # Extract times for speed reasons + times = hilbert_transform[phase_i].times + + # Append new list to the results for this spiketrain + result_phases.append([]) + result_amps.append([]) + result_times.append([]) + + # Step through all spikes + for spike_i, ind_at_spike_j in enumerate(ind_at_spike): + # Difference vector between actual spike time and sample point, + # positive if spike time is later than sample point + dv = spiketrain[sttimeind[spike_i]] - times[ind_at_spike_j] + + # Make sure ind_at_spike is to the left of the spike time + if dv < 0 and ind_at_spike_j > 0: + ind_at_spike_j = ind_at_spike_j - 1 + + if interpolate: + # Get relative spike occurrence between the two closest signal + # sample points + # if z->0 spike is more to the left sample + # if z->1 more to the right sample + z = (spiketrain[sttimeind[spike_i]] - times[ind_at_spike_j]) /\ + hilbert_transform[phase_i].sampling_period + + # Save hilbert_transform (interpolate on circle) + p1 = np.angle(hilbert_transform[phase_i][ind_at_spike_j]) + p2 = np.angle(hilbert_transform[phase_i][ind_at_spike_j + 1]) + result_phases[spiketrain_i].append( + np.angle( + (1 - z) * np.exp(np.complex(0, p1)) + + z * np.exp(np.complex(0, p2)))) + + # Save amplitude + result_amps[spiketrain_i].append( + (1 - z) * np.abs( + hilbert_transform[phase_i][ind_at_spike_j]) + + z * np.abs(hilbert_transform[phase_i][ind_at_spike_j + 1])) + else: + p1 = np.angle(hilbert_transform[phase_i][ind_at_spike_j]) + result_phases[spiketrain_i].append(p1) + + # Save amplitude + result_amps[spiketrain_i].append( + np.abs(hilbert_transform[phase_i][ind_at_spike_j])) + + # Save time + result_times[spiketrain_i].append(spiketrain[sttimeind[spike_i]]) + + # Convert outputs to arrays + for i, entry in enumerate(result_phases): + result_phases[i] = np.array(entry).flatten() + for i, entry in enumerate(result_amps): + result_amps[i] = pq.Quantity(entry, units=entry[0].units).flatten() + for i, entry in enumerate(result_times): + result_times[i] = pq.Quantity(entry, units=entry[0].units).flatten() + + return result_phases, result_amps, result_times diff --git a/elephant/test/test_phase_analysis.py b/elephant/test/test_phase_analysis.py new file mode 100644 index 000000000..31ba8fd25 --- /dev/null +++ b/elephant/test/test_phase_analysis.py @@ -0,0 +1,190 @@ +# -*- coding: utf-8 -*- +""" +Unit tests for the phase analysis module. + +:copyright: Copyright 2016 by the Elephant team, see AUTHORS.txt. +:license: Modified BSD, see LICENSE.txt for details. +""" +from __future__ import division, print_function + +import unittest + +from neo import SpikeTrain, AnalogSignal +import numpy as np +import quantities as pq + +import elephant.phase_analysis + +from numpy.ma.testutils import assert_allclose + + +class SpikeTriggeredPhaseTestCase(unittest.TestCase): + + def setUp(self): + tlen0 = 100 * pq.s + f0 = 20. * pq.Hz + fs0 = 1 * pq.ms + t0 = np.arange( + 0, tlen0.rescale(pq.s).magnitude, + fs0.rescale(pq.s).magnitude) * pq.s + self.anasig0 = AnalogSignal( + np.sin(2 * np.pi * (f0 * t0).simplified.magnitude), + units=pq.mV, t_start=0 * pq.ms, sampling_period=fs0) + self.st0 = SpikeTrain( + np.arange(50, tlen0.rescale(pq.ms).magnitude - 50, 50) * pq.ms, + t_start=0 * pq.ms, t_stop=tlen0) + self.st1 = SpikeTrain( + [100., 100.1, 100.2, 100.3, 100.9, 101.] * pq.ms, + t_start=0 * pq.ms, t_stop=tlen0) + + def test_perfect_locking_one_spiketrain_one_signal(self): + phases, amps, times = elephant.phase_analysis.spike_triggered_phase( + elephant.signal_processing.hilbert(self.anasig0), + self.st0, + interpolate=True) + + assert_allclose(phases[0], - np.pi / 2.) + assert_allclose(amps[0], 1, atol=0.1) + assert_allclose(times[0].magnitude, self.st0.magnitude) + self.assertEqual(len(phases[0]), len(self.st0)) + self.assertEqual(len(amps[0]), len(self.st0)) + self.assertEqual(len(times[0]), len(self.st0)) + + def test_perfect_locking_many_spiketrains_many_signals(self): + phases, amps, times = elephant.phase_analysis.spike_triggered_phase( + [ + elephant.signal_processing.hilbert(self.anasig0), + elephant.signal_processing.hilbert(self.anasig0)], + [self.st0, self.st0], + interpolate=True) + + assert_allclose(phases[0], -np.pi / 2.) + assert_allclose(amps[0], 1, atol=0.1) + assert_allclose(times[0].magnitude, self.st0.magnitude) + self.assertEqual(len(phases[0]), len(self.st0)) + self.assertEqual(len(amps[0]), len(self.st0)) + self.assertEqual(len(times[0]), len(self.st0)) + + def test_perfect_locking_one_spiketrains_many_signals(self): + phases, amps, times = elephant.phase_analysis.spike_triggered_phase( + [ + elephant.signal_processing.hilbert(self.anasig0), + elephant.signal_processing.hilbert(self.anasig0)], + [self.st0], + interpolate=True) + + assert_allclose(phases[0], -np.pi / 2.) + assert_allclose(amps[0], 1, atol=0.1) + assert_allclose(times[0].magnitude, self.st0.magnitude) + self.assertEqual(len(phases[0]), len(self.st0)) + self.assertEqual(len(amps[0]), len(self.st0)) + self.assertEqual(len(times[0]), len(self.st0)) + + def test_perfect_locking_many_spiketrains_one_signal(self): + phases, amps, times = elephant.phase_analysis.spike_triggered_phase( + elephant.signal_processing.hilbert(self.anasig0), + [self.st0, self.st0], + interpolate=True) + + assert_allclose(phases[0], -np.pi / 2.) + assert_allclose(amps[0], 1, atol=0.1) + assert_allclose(times[0].magnitude, self.st0.magnitude) + self.assertEqual(len(phases[0]), len(self.st0)) + self.assertEqual(len(amps[0]), len(self.st0)) + self.assertEqual(len(times[0]), len(self.st0)) + + def test_interpolate(self): + phases_int, _, _ = elephant.phase_analysis.spike_triggered_phase( + elephant.signal_processing.hilbert(self.anasig0), + self.st1, + interpolate=True) + + self.assertLess(phases_int[0][0], phases_int[0][1]) + self.assertLess(phases_int[0][1], phases_int[0][2]) + self.assertLess(phases_int[0][2], phases_int[0][3]) + self.assertLess(phases_int[0][3], phases_int[0][4]) + self.assertLess(phases_int[0][4], phases_int[0][5]) + + phases_noint, _, _ = elephant.phase_analysis.spike_triggered_phase( + elephant.signal_processing.hilbert(self.anasig0), + self.st1, + interpolate=False) + + self.assertEqual(phases_noint[0][0], phases_noint[0][1]) + self.assertEqual(phases_noint[0][1], phases_noint[0][2]) + self.assertEqual(phases_noint[0][2], phases_noint[0][3]) + self.assertEqual(phases_noint[0][3], phases_noint[0][4]) + self.assertNotEqual(phases_noint[0][4], phases_noint[0][5]) + + # Verify that when using interpolation and the spike sits on the sample + # of the Hilbert transform, this is the same result as when not using + # interpolation with a spike slightly to the right + self.assertEqual(phases_noint[0][2], phases_int[0][0]) + self.assertEqual(phases_noint[0][4], phases_int[0][0]) + + def test_inconsistent_numbers_spiketrains_hilbert(self): + self.assertRaises( + ValueError, elephant.phase_analysis.spike_triggered_phase, + [ + elephant.signal_processing.hilbert(self.anasig0), + elephant.signal_processing.hilbert(self.anasig0)], + [self.st0, self.st0, self.st0], False) + + self.assertRaises( + ValueError, elephant.phase_analysis.spike_triggered_phase, + [ + elephant.signal_processing.hilbert(self.anasig0), + elephant.signal_processing.hilbert(self.anasig0)], + [self.st0, self.st0, self.st0], False) + + def test_spike_earlier_than_hilbert(self): + # This is a spike clearly outside the bounds + st = SpikeTrain( + [-50, 50], + units='s', t_start=-100*pq.s, t_stop=100*pq.s) + phases_noint, _, _ = elephant.phase_analysis.spike_triggered_phase( + elephant.signal_processing.hilbert(self.anasig0), + st, + interpolate=False) + self.assertEqual(len(phases_noint[0]), 1) + + # This is a spike right on the border (start of the signal is at 0s, + # spike sits at t=0s). By definition of intervals in + # Elephant (left borders inclusive, right borders exclusive), this + # spike is to be considered. + st = SpikeTrain( + [0, 50], + units='s', t_start=-100*pq.s, t_stop=100*pq.s) + phases_noint, _, _ = elephant.phase_analysis.spike_triggered_phase( + elephant.signal_processing.hilbert(self.anasig0), + st, + interpolate=False) + self.assertEqual(len(phases_noint[0]), 2) + + def test_spike_later_than_hilbert(self): + # This is a spike clearly outside the bounds + st = SpikeTrain( + [1, 250], + units='s', t_start=-1*pq.s, t_stop=300*pq.s) + phases_noint, _, _ = elephant.phase_analysis.spike_triggered_phase( + elephant.signal_processing.hilbert(self.anasig0), + st, + interpolate=False) + self.assertEqual(len(phases_noint[0]), 1) + + # This is a spike right on the border (length of the signal is 100s, + # spike sits at t=100s). However, by definition of intervals in + # Elephant (left borders inclusive, right borders exclusive), this + # spike is not to be considered. + st = SpikeTrain( + [1, 100], + units='s', t_start=-1*pq.s, t_stop=200*pq.s) + phases_noint, _, _ = elephant.phase_analysis.spike_triggered_phase( + elephant.signal_processing.hilbert(self.anasig0), + st, + interpolate=False) + self.assertEqual(len(phases_noint[0]), 1) + + +if __name__ == '__main__': + unittest.main()