diff --git a/src/transport_performance/gtfs/multi_validation.py b/src/transport_performance/gtfs/multi_validation.py index e3a75e12..a3bc443a 100644 --- a/src/transport_performance/gtfs/multi_validation.py +++ b/src/transport_performance/gtfs/multi_validation.py @@ -75,6 +75,10 @@ class MultiGtfsInstance: creates a calendar table from calendar_times. get_dates() Get the range of dates that the gtfs(s) span. + plot_routes() + Plot a timeseries of route counts. + plot_trips() + Plot a timeseries of trip counts. Raises ------ @@ -765,7 +769,7 @@ def _plot_core( kwargs["width"] = width kwargs["height"] = height if title: - PLOT_TITLE["title"] = title + PLOT_TITLE["text"] = title # plotting fig = px.line(df, x="date", y=count_col, labels=LABEL_FORMAT, **kwargs) fig.update_layout(title=PLOT_TITLE) @@ -824,3 +828,59 @@ def plot_routes( kwargs=plotly_kwargs, ) return figure + + def plot_trips( + self, + route_type: bool = True, + width: int = 1000, + height: int = 550, + title: str = None, + plotly_kwargs: dict = None, + ) -> go.Figure: + """Create a line plot of trip counts over time. + + Parameters + ---------- + route_type : bool, optional + Whether or not to draw a line for each modality, by default True + width : int, optional + Plot width, by default 1000 + height : int, optional + Plot height, by default 550 + title : str, optional + Plot title, by default None + plotly_kwargs : dict, optional + Kwargs to pass to plotly.express.line, by default None + + Returns + ------- + go.Figure + The timerseries plot + + """ + # NOTE: Very similar to the above function, however not enough code + # to justify creating a shared function (would probably results in a + # similar amount) + # defences + _type_defence(route_type, "route_type", bool) + _type_defence(plotly_kwargs, "plotly_kwargs", (dict, type(None))) + if not plotly_kwargs: + plotly_kwargs = {} + # prepare data + data = self.summarise_trips().copy() + if not route_type: + data = ( + data.drop("route_type", axis=1) + .groupby("date") + .sum() + .reset_index() + ) + figure = self._plot_core( + data, + "trip_count", + width=width, + height=height, + title=title, + kwargs=plotly_kwargs, + ) + return figure diff --git a/tests/gtfs/test_multi_validation.py b/tests/gtfs/test_multi_validation.py index 3b830501..8ec5c9e4 100644 --- a/tests/gtfs/test_multi_validation.py +++ b/tests/gtfs/test_multi_validation.py @@ -13,6 +13,7 @@ import pandas as pd import folium from pyprojroot import here +import plotly.graph_objs as go from transport_performance.gtfs.multi_validation import ( MultiGtfsInstance, @@ -665,3 +666,35 @@ def test_get_dates(self, multi_gtfs_fixture, multi_gtfs_altered_fixture): len(multi_gtfs_altered_fixture.get_dates(return_range=False)) == 6 ), "Unexpected number of dates" pass + + def test__plot_core(self, multi_gtfs_fixture, tmp_path): + """General tests for _plot_core().""" + # route summary + data = multi_gtfs_fixture.summarise_routes() + route_fig = multi_gtfs_fixture._plot_core(data, "route_count") + assert isinstance(route_fig, go.Figure), "Route counts not plotted" + # trip summary + data = multi_gtfs_fixture.summarise_trips() + trip_fig = multi_gtfs_fixture._plot_core(data, "trip_count") + assert isinstance(trip_fig, go.Figure), "Trip counts not plotted" + # trip summary with custom title + trip_fig = multi_gtfs_fixture._plot_core( + data, "trip_count", title="test" + ) + found_title = trip_fig.layout["title"]["text"] + assert found_title == "test", "Title not as expected" + # trip summary with custom dimensions + trip_fig = multi_gtfs_fixture._plot_core( + data, "trip_count", height=100, width=150 + ) + found_height = trip_fig.layout["height"] + found_width = trip_fig.layout["width"] + assert found_height == 100, "Height not as expected" + assert found_width == 150, "Width not as expected" + # custom kwargs + trip_fig = multi_gtfs_fixture._plot_core( + data, "trip_count", kwargs={"markers": True} + ) + assert ( + trip_fig.data[0]["mode"] == "markers+lines" + ), "Markers not plotted"