Skip to content

Commit

Permalink
239 hotfix refactor (#252)
Browse files Browse the repository at this point in the history
* refactor: Reduce tests to common func

* refactor: Adjust plot_trips & plot_routes -> plot_service

* fix: Failing test in plot core updated params

* fix: re-organise plotting functions

---------

Co-authored-by: Browning <[email protected]>
  • Loading branch information
r-leyshon and CBROWN-ONS authored Feb 19, 2024
1 parent bc75e18 commit 4c8a811
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 91 deletions.
100 changes: 21 additions & 79 deletions src/transport_performance/gtfs/multi_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,8 @@ class MultiGtfsInstance:
creates a calendar table from calendar_times.
get_dates()
Get the range of dates that the gtfs(s) span.
plot_routes()
Plot a timeseries of route counts.
plot_trips()
Plot a timeseries of trip counts.
plot_service()
Plot a timeseries of route or trip counts.
Raises
------
Expand Down Expand Up @@ -738,7 +736,7 @@ def _reformat_col_names(self, col_name: str, cap_all: bool = True):
def _plot_core(
self,
df: pd.DataFrame,
count_col: str = "trip_count",
count_col: str = "routes",
width: int = 1000,
height: int = 550,
title: str = None,
Expand Down Expand Up @@ -804,8 +802,9 @@ def _plot_core(

return fig

def plot_routes(
def plot_service(
self,
service_type: str = "routes",
route_type: bool = True,
width: int = 1000,
height: int = 550,
Expand All @@ -814,10 +813,12 @@ def plot_routes(
rolling_average: Union[int, None] = None,
line_date: Union[str, None] = None,
) -> go.Figure:
"""Create a line plot of route counts over time.
"""Create a line plot of route or trip counts over time.
Parameters
----------
service_type: str, optional
Whether to plot 'routes' or 'trips'. By default 'routes'.
route_type : bool, optional
Whether or not to draw a line for each modality, by default True
width : int, optional
Expand Down Expand Up @@ -846,82 +847,23 @@ def plot_routes(
"""
# defences
_type_defence(service_type, "service_type", str)
_type_defence(route_type, "route_type", bool)
_type_defence(plotly_kwargs, "plotly_kwargs", (dict, type(None)))
if not plotly_kwargs:
plotly_kwargs = {}
# prepare data
data = self.summarise_routes().copy()
if not route_type:
data = (
data.drop("route_type", axis=1)
.groupby("date")
.sum()
.reset_index()
SERVICE_TYPES = ["routes", "trips"]
if service_type not in SERVICE_TYPES:
raise ValueError(
"`service_type` expects 'routes' or 'trips',"
f" found: {service_type}"
)
figure = self._plot_core(
data,
"route_count",
width=width,
height=height,
title=title,
kwargs=plotly_kwargs,
rolling_average=rolling_average,
line_date=line_date,
)
return figure

def plot_trips(
self,
route_type: bool = True,
width: int = 1000,
height: int = 550,
title: str = None,
plotly_kwargs: dict = None,
rolling_average: Union[int, None] = None,
line_date: Union[str, None] = None,
) -> go.Figure:
"""Create a line plot of trip counts over time.
Parameters
----------
route_type : bool, optional
Whether or not to draw a line for each modality, by default True
width : int, optional
Plot width, by default 1000
height : int, optional
Plot height, by default 550
title : str, optional
Plot title, by default None
plotly_kwargs : dict, optional
Kwargs to pass to plotly.express.line, by default None
rolling_average : Union[int, None], optional
How many days to calculate the rolling average over. When left as
None, rolling average is not used.
The rolling average is calculated from the center, meaning if ra=3,
the average will be calculated from the current date, previous date
and following date. Missing dates are imputed and treated as having
values of 0.
line_date : Union[str, None], optional
A data to draw a dashed vertical line on. Date should be in format:
YYYY-MM-DD, by default None
Returns
-------
go.Figure
The timerseries plot
"""
# NOTE: Very similar to the above function, however not enough code
# to justify creating a shared function (would probably results in a
# similar amount)
# defences
_type_defence(route_type, "route_type", bool)
_type_defence(plotly_kwargs, "plotly_kwargs", (dict, type(None)))
if not plotly_kwargs:
plotly_kwargs = {}
# prepare data
data = self.summarise_trips().copy()
if service_type == "routes":
data = self.summarise_routes().copy()
count_col = "route_count"
else:
data = self.summarise_trips().copy()
count_col = "trip_count"
if not route_type:
data = (
data.drop("route_type", axis=1)
Expand All @@ -931,7 +873,7 @@ def plot_trips(
)
figure = self._plot_core(
data,
"trip_count",
count_col=count_col,
width=width,
height=height,
title=title,
Expand Down
25 changes: 13 additions & 12 deletions tests/gtfs/test_multi_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,26 +714,27 @@ def test__plot_core(self, multi_gtfs_fixture):
found_line = avg_fig.layout["shapes"][0]["line"]["dash"]
assert found_line == "dash", "Date line not plotted"

def test_plot_routes(self, multi_gtfs_fixture):
"""General tests for .plot_routes()."""
def test_plot_service(self, multi_gtfs_fixture):
"""General tests for .plot_service()."""
# plot route_type
fig = multi_gtfs_fixture.plot_routes()
fig = multi_gtfs_fixture.plot_service(service_type="routes")
assert len(fig.data) == 2, "Not plotted by modality"
# plot without route type
fig = multi_gtfs_fixture.plot_routes(False)
fig = multi_gtfs_fixture.plot_service(
service_type="routes", route_type=False
)
assert len(fig.data) == 1, "Plot not as expected"
# rolling average + no route type
avg_fig = multi_gtfs_fixture.plot_routes(
rolling_average=7, route_type=False
avg_fig = multi_gtfs_fixture.plot_service(
service_type="routes", rolling_average=7, route_type=False
)
leg_status = avg_fig.data[0]["showlegend"]
assert not leg_status, "Multiple route types found"

def test_plot_trips(self, multi_gtfs_fixture):
"""General tests for .plot_trips()."""
# plot route_type
fig = multi_gtfs_fixture.plot_trips()
# plot trips
fig = multi_gtfs_fixture.plot_service(service_type="trips")
assert len(fig.data) == 2, "Not plotted by modality"
# plot without route type
fig = multi_gtfs_fixture.plot_trips(False)
fig = multi_gtfs_fixture.plot_service(
service_type="trips", route_type=False
)
assert len(fig.data) == 1, "Plot not as expected"

0 comments on commit 4c8a811

Please sign in to comment.