diff --git a/bennchplot/bennchplot.py b/bennchplot/bennchplot.py index d515c70..89ddb3e 100644 --- a/bennchplot/bennchplot.py +++ b/bennchplot/bennchplot.py @@ -63,6 +63,7 @@ def __init__(self, x_axis, additional_params=pp.additional_params, label_params=pp.label_params, time_scaling=1, + df=None, detailed_timers=True): self.x_axis = x_axis @@ -72,8 +73,8 @@ def __init__(self, x_axis, self.color_params = color_params self.label_params = label_params self.time_scaling = time_scaling + self.df = df self.detailed_timers = detailed_timers - self.load_data(data_file) self.compute_derived_quantities() @@ -92,11 +93,12 @@ def load_data(self, data_file): ------ ValueError """ - try: - self.df = pd.read_csv(data_file, delimiter=',') - except FileNotFoundError: - print('File could not be found') - quit() + if self.df is None: + try: + self.df = pd.read_csv(data_file, delimiter=',') + except FileNotFoundError: + print('File could not be found') + quit() for py_timer in ['py_time_create', 'py_time_connect']: if py_timer not in self.df: @@ -112,6 +114,13 @@ def load_data(self, data_file): 'time_construction_create': ['mean', 'std'], 'time_construction_connect': ['mean', 'std'], 'time_simulate': ['mean', 'std'], + 'time_collocate_spike_data': ['mean', 'std'], + 'time_communicate_spike_data': ['mean', 'std'], + 'time_deliver_spike_data': ['mean', 'std'], + 'time_update': ['mean', 'std'], + 'time_communicate_target_data': ['mean', 'std'], + 'time_gather_spike_data': ['mean', 'std'], + 'time_gather_target_data': ['mean', 'std'], 'time_communicate_prepare': ['mean', 'std'], 'py_time_create': ['mean', 'std'], 'py_time_connect': ['mean', 'std'], @@ -266,9 +275,9 @@ def plot_fractions(self, axis, fill_variables, fill_height = 0 for fill in fill_variables: - axis.fill_between(np.squeeze(self.df[self.x_axis]), + axis.fill_between(self.df[self.x_axis].to_numpy().squeeze(axis=1), fill_height, - np.squeeze(self.df[fill]) + fill_height, + self.df[fill].to_numpy() + fill_height, label=self.label_params[fill], facecolor=self.color_params[fill], interpolate=interpolate, @@ -277,18 +286,17 @@ def plot_fractions(self, axis, fill_variables, linewidth=0.5, edgecolor='#444444') if error: - axis.errorbar(np.squeeze(self.df[self.x_axis]), - np.squeeze(self.df[fill]) + fill_height, - yerr=np.squeeze(self.df[fill + '_std']), + axis.errorbar(self.df[self.x_axis].to_numpy().squeeze(axis=1), + self.df[fill].to_numpy() + fill_height, + yerr=self.df[fill + '_std'].to_numpy(), capsize=3, capthick=1, color='k', - fmt='none' - ) + fmt='none') fill_height += self.df[fill].to_numpy() if self.x_ticks == 'data': - axis.set_xticks(np.squeeze(self.df[self.x_axis])) + axis.set_xticks(self.df[self.x_axis].to_numpy().squeeze(axis=1)) else: axis.set_xticks(self.x_ticks) @@ -320,17 +328,17 @@ def plot_main(self, quantities, axis, log=(False, False), for y in quantities: label = self.label_params[y] if label is None else label color = self.color_params[y] if color is None else color - axis.plot(self.df[self.x_axis], - self.df[y], + axis.plot(self.df[self.x_axis].to_numpy().squeeze(axis=1), + self.df[y].to_numpy(), marker=None, label=label, color=color, linewidth=2) if error: axis.errorbar( - self.df[self.x_axis].values, - self.df[y].values, - yerr=self.df[y + '_std'].values, + self.df[self.x_axis].to_numpy().squeeze(axis=1), + self.df[y].to_numpy(), + yerr=self.df[y + '_std'].to_numpy(), marker=None, capsize=3, capthick=1, @@ -338,7 +346,7 @@ def plot_main(self, quantities, axis, log=(False, False), fmt=fmt) if self.x_ticks == 'data': - axis.set_xticks(self.df[self.x_axis].values) + axis.set_xticks(self.df[self.x_axis].to_numpy().squeeze(axis=1)) else: axis.set_xticks(self.x_ticks)