Skip to content

Commit

Permalink
black
Browse files Browse the repository at this point in the history
  • Loading branch information
bendichter committed Jun 30, 2021
1 parent 3a2cc86 commit 2591543
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 101 deletions.
8 changes: 4 additions & 4 deletions nwbwidgets/behavior.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,10 @@ def route_spatial_series(spatial_series, **kwargs):
}
elif spatial_series.data.shape[1] == 2:
dict_ = {
"over time": SeparateTracesPlotlyWidget,
"trace": SpatialSeriesTraceWidget2D,
"trial aligned": trial_align_spatial_series,
}
"over time": SeparateTracesPlotlyWidget,
"trace": SpatialSeriesTraceWidget2D,
"trial aligned": trial_align_spatial_series,
}
elif spatial_series.data.shape[1] == 3:
dict_ = {
"over time": SeparateTracesPlotlyWidget,
Expand Down
77 changes: 61 additions & 16 deletions nwbwidgets/test/test_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
SingleTracePlotlyWidget,
SeparateTracesPlotlyWidget,
get_timeseries_tt,
show_indexed_timeseries_plotly
show_indexed_timeseries_plotly,
)
from pynwb import TimeSeries
from pynwb.epoch import TimeIntervals
Expand All @@ -31,25 +31,40 @@ def test_timeseries_widget():

BaseGroupedTraceWidget(ts)


class TestTracesPlotlyWidget(unittest.TestCase):
def setUp(self):
data = np.random.rand(160, 3)
self.ts_multi = SpatialSeries(
name="test_timeseries", data=data, reference_frame="lowerleft", starting_time=0.0, rate=1.0
name="test_timeseries",
data=data,
reference_frame="lowerleft",
starting_time=0.0,
rate=1.0,
)
self.ts_single = TimeSeries(
name="test_timeseries", data=data[:,0], unit="m", starting_time=0.0, rate=1.0
name="test_timeseries",
data=data[:, 0],
unit="m",
starting_time=0.0,
rate=1.0,
)

def test_single_trace_widget(self):
single_wd = SingleTracePlotlyWidget(timeseries=self.ts_single)
tt = get_timeseries_tt(self.ts_single)
single_wd.controls["time_window"].value = [tt[int(len(tt)*0.2)],tt[int(len(tt)*0.4)]]
single_wd.controls["time_window"].value = [
tt[int(len(tt) * 0.2)],
tt[int(len(tt) * 0.4)],
]

def test_single_trace_widget(self):
single_wd = SeparateTracesPlotlyWidget(timeseries=self.ts_multi)
tt = get_timeseries_tt(self.ts_multi)
single_wd.controls["time_window"].value = [tt[int(len(tt)*0.2)],tt[int(len(tt)*0.4)]]
single_wd.controls["time_window"].value = [
tt[int(len(tt) * 0.2)],
tt[int(len(tt) * 0.4)],
]


class TestIndexTimeSeriesPlotly(unittest.TestCase):
Expand All @@ -63,7 +78,11 @@ def setUp(self):
rate=100.0,
)
self.ts_single = TimeSeries(
name="test_timeseries", data=data[:, 0], unit="m", starting_time=0.0, rate=100.0
name="test_timeseries",
data=data[:, 0],
unit="m",
starting_time=0.0,
rate=100.0,
)
self.tt = get_timeseries_tt(self.ts)

Expand All @@ -81,32 +100,58 @@ def test_no_args(self):
assert np.allclose(fig_out.data[2].y, self.ts.data[:, 2])

def test_value_errors(self):
time_window = [self.tt[int(len(self.tt)*0.2)], self.tt[int(len(self.tt)*0.4)]]
self.assertRaises(ValueError, show_indexed_timeseries_plotly,
timeseries=self.ts_single, istart=3, time_window=time_window)
self.assertRaises(ValueError, show_indexed_timeseries_plotly,
timeseries=self.ts_single, trace_range=[2, 5])
time_window = [
self.tt[int(len(self.tt) * 0.2)],
self.tt[int(len(self.tt) * 0.4)],
]
self.assertRaises(
ValueError,
show_indexed_timeseries_plotly,
timeseries=self.ts_single,
istart=3,
time_window=time_window,
)
self.assertRaises(
ValueError,
show_indexed_timeseries_plotly,
timeseries=self.ts_single,
trace_range=[2, 5],
)


class TestTracesPlotlyWidget(unittest.TestCase):
def setUp(self):
data = np.random.rand(160, 3)
self.ts_multi = SpatialSeries(
name="test_timeseries", data=data, reference_frame="lowerleft", starting_time=0.0, rate=1.0
name="test_timeseries",
data=data,
reference_frame="lowerleft",
starting_time=0.0,
rate=1.0,
)
self.ts_single = TimeSeries(
name="test_timeseries", data=data[:, 0], unit="m", starting_time=0.0, rate=1.0
name="test_timeseries",
data=data[:, 0],
unit="m",
starting_time=0.0,
rate=1.0,
)

def test_single_trace_widget(self):
single_wd = SingleTracePlotlyWidget(timeseries=self.ts_single)
tt = get_timeseries_tt(self.ts_single)
single_wd.controls["time_window"].value = [tt[int(len(tt)*0.2)], tt[int(len(tt)*0.4)]]
single_wd.controls["time_window"].value = [
tt[int(len(tt) * 0.2)],
tt[int(len(tt) * 0.4)],
]

def test_single_trace_widget(self):
single_wd = SeparateTracesPlotlyWidget(timeseries=self.ts_multi)
tt = get_timeseries_tt(self.ts_multi)
single_wd.controls["time_window"].value = [tt[int(len(tt)*0.2)], tt[int(len(tt)*0.4)]]
single_wd.controls["time_window"].value = [
tt[int(len(tt) * 0.2)],
tt[int(len(tt) * 0.4)],
]


class ShowTimeSeriesTestCase(unittest.TestCase):
Expand Down Expand Up @@ -159,7 +204,7 @@ def setUp(self):
data = np.random.rand(100, 10)
timestamps = [0.0]
for _ in range(data.shape[0]):
timestamps.append(timestamps[-1] + 0.75 + 0.25*np.random.rand())
timestamps.append(timestamps[-1] + 0.75 + 0.25 * np.random.rand())
self.ts_rate = TimeSeries(
name="test_timeseries_rate",
data=data,
Expand Down
100 changes: 57 additions & 43 deletions nwbwidgets/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,24 +131,26 @@ def show_indexed_timeseries_mpl(

def show_indexed_timeseries_plotly(
timeseries: TimeSeries,
istart:int=0,
istop:int=None,
time_window: list=None,
trace_range:list=None,
istart: int = 0,
istop: int = None,
time_window: list = None,
trace_range: list = None,
offsets=None,
fig: go.FigureWidget = None,
col=None,
row=None,
zero_start=False,
scatter_kwargs:dict=None,
figure_kwargs:dict=None
scatter_kwargs: dict = None,
figure_kwargs: dict = None,
):
if istart!=0 or istop is not None:
if istart != 0 or istop is not None:
if time_window is not None:
raise ValueError('input either time window or istart/stop but not both')
if not(0<=istart<timeseries.data.shape[0] and
(istop is None or 0<istop<=timeseries.data.shape[0])):
raise ValueError('enter correct istart/stop values')
raise ValueError("input either time window or istart/stop but not both")
if not (
0 <= istart < timeseries.data.shape[0]
and (istop is None or 0 < istop <= timeseries.data.shape[0])
):
raise ValueError("enter correct istart/stop values")
t_istart = istart
t_istop = istop
elif time_window is not None:
Expand All @@ -159,43 +161,49 @@ def show_indexed_timeseries_plotly(
t_istop = istop
tt = get_timeseries_tt(timeseries, istart=t_istart, istop=t_istop)
data, unit = get_timeseries_in_units(timeseries, istart=t_istart, istop=t_istop)
if len(data.shape)==1:
data = data[:,np.newaxis]
if len(data.shape) == 1:
data = data[:, np.newaxis]
if trace_range is not None:
if not(0<=trace_range[0]<data.shape[1] and 0<trace_range[1]<=data.shape[1]):
raise ValueError('enter correct trace range')
if not (
0 <= trace_range[0] < data.shape[1] and 0 < trace_range[1] <= data.shape[1]
):
raise ValueError("enter correct trace range")
trace_istart = trace_range[0]
trace_istop = trace_range[1]
else:
trace_istart = 0
trace_istop = data.shape[1]
if offsets is None:
offsets = np.zeros(trace_istop-trace_istart)
offsets = np.zeros(trace_istop - trace_istart)
if zero_start:
tt = tt - tt[0]
scatter_kwargs = dict() if scatter_kwargs is None else scatter_kwargs
if fig is None:
fig = go.FigureWidget(make_subplots(rows=1,cols=1))
fig = go.FigureWidget(make_subplots(rows=1, cols=1))
row = 1 if row is None else row
col = 1 if col is None else col
for i,trace_id in enumerate(range(trace_istart,trace_istop)):
for i, trace_id in enumerate(range(trace_istart, trace_istop)):
fig.add_trace(
go.Scattergl(
x=tt, y=data[:,trace_id]+offsets[i], mode='lines', **scatter_kwargs
x=tt, y=data[:, trace_id] + offsets[i], mode="lines", **scatter_kwargs
),
row=row, col=col)
input_figure_kwargs = dict(xaxis=dict(title_text='time (s)',
range=[tt[0], tt[-1]]),
yaxis=dict(title_text=unit if unit is not None else None),
title=timeseries.name)
row=row,
col=col,
)
input_figure_kwargs = dict(
xaxis=dict(title_text="time (s)", range=[tt[0], tt[-1]]),
yaxis=dict(title_text=unit if unit is not None else None),
title=timeseries.name,
)
if figure_kwargs is None:
figure_kwargs = dict()
input_figure_kwargs.update(figure_kwargs)
fig.update_xaxes(input_figure_kwargs.pop('xaxis'),row=row,col=col)
fig.update_yaxes(input_figure_kwargs.pop('yaxis'), row=row, col=col)
fig.update_xaxes(input_figure_kwargs.pop("xaxis"), row=row, col=col)
fig.update_yaxes(input_figure_kwargs.pop("yaxis"), row=row, col=col)
fig.update_layout(**input_figure_kwargs)
return fig


def plot_traces(
timeseries: TimeSeries,
time_window=None,
Expand Down Expand Up @@ -332,8 +340,9 @@ def __init__(
def set_out_fig(self):
timeseries = self.controls["timeseries"].value
time_window = self.controls["time_window"].value
self.out_fig = show_indexed_timeseries_plotly(timeseries=timeseries,
time_window=time_window)
self.out_fig = show_indexed_timeseries_plotly(
timeseries=timeseries, time_window=time_window
)

def on_change(change):
time_window = self.controls["time_window"].value
Expand Down Expand Up @@ -363,23 +372,24 @@ def set_out_fig(self):

if len(timeseries.data.shape) > 1:
color = DEFAULT_PLOTLY_COLORS
no_rows=timeseries.data.shape[1]
no_rows = timeseries.data.shape[1]
self.out_fig = go.FigureWidget(make_subplots(rows=no_rows, cols=1))

for i, xyz in enumerate(("x", "y", "z")[:no_rows]):
self.out_fig=show_indexed_timeseries_plotly(
timeseries=timeseries,
time_window=time_window,
trace_range=[i,i+1],
fig=self.out_fig,
col=1,
row=i+1,
scatter_kwargs=dict(marker_color=color[i%len(color)],name=xyz),
figure_kwargs=dict(yaxis=dict(title_text=xyz))
)
self.out_fig = show_indexed_timeseries_plotly(
timeseries=timeseries,
time_window=time_window,
trace_range=[i, i + 1],
fig=self.out_fig,
col=1,
row=i + 1,
scatter_kwargs=dict(marker_color=color[i % len(color)], name=xyz),
figure_kwargs=dict(yaxis=dict(title_text=xyz)),
)
else:
self.out_fig = show_indexed_timeseries_plotly(timeseries=timeseries,
time_window=time_window)
self.out_fig = show_indexed_timeseries_plotly(
timeseries=timeseries, time_window=time_window
)

def on_change(change):
time_window = self.controls["time_window"].value
Expand All @@ -399,8 +409,12 @@ def on_change(change):
for i, dd in enumerate(yy.T):
self.out_fig.data[i].x = tt
self.out_fig.data[i].y = dd
self.out_fig.update_yaxes(range=[min(dd), max(dd)], row=i+1, col=1)
self.out_fig.update_xaxes(range=[min(tt), max(tt)], row=i + 1, col=1)
self.out_fig.update_yaxes(
range=[min(dd), max(dd)], row=i + 1, col=1
)
self.out_fig.update_xaxes(
range=[min(tt), max(tt)], row=i + 1, col=1
)

self.controls["time_window"].observe(on_change)

Expand Down
Loading

0 comments on commit 2591543

Please sign in to comment.