diff --git a/multirecording_spikeanalysis.py b/multirecording_spikeanalysis.py index 85df809..281037c 100644 --- a/multirecording_spikeanalysis.py +++ b/multirecording_spikeanalysis.py @@ -6,6 +6,7 @@ import pandas as pd import matplotlib.pyplot as plt +from collections import defaultdict from scipy.stats import sem, ranksums, fisher_exact, wilcoxon from statistics import mean, StatisticsError from sklearn.decomposition import PCA @@ -439,9 +440,9 @@ def get_spike_specs(self): def get_unit_timestamps(self): """ - creates a dictionary of units to spike timestamps - keys are unit ids (int) and values are spike timestamps for - that unit (numpy arrays)and assigns dictionary to self.unit_timestamps + Creates a dictionary of units to spike timestamps. + Keys are unit ids (int) and values are spike timestamps for that unit (numpy arrays), + and assigns dictionary to self.unit_timestamps. Args: None @@ -450,21 +451,16 @@ def get_unit_timestamps(self): None """ - unit_timestamps = {} - for spike in range(len(self.timestamps_var)): - if self.unit_array[spike] in unit_timestamps.keys(): - timestamp_list = unit_timestamps[self.unit_array[spike]] - timestamp_list = np.append( - timestamp_list, self.timestamps_var[spike] - ) - unit_timestamps[self.unit_array[spike]] = timestamp_list - else: - unit_timestamps[self.unit_array[spike]] = self.timestamps_var[ - spike - ] + # Initialize a defaultdict for holding lists + unit_timestamps = defaultdict(list) + default_dict = defaultdict(lambda: np.array([])) - self.unit_timestamps = unit_timestamps + # Loop through each spike only once + for spike, unit in enumerate(self.unit_array): + # Append the timestamp to the list for the corresponding unit + unit_timestamps[unit].append(self.timestamps_var[spike]) + self.unit_timestamps = unit_timestamps class EphysRecordingCollection: """