From 7bb4863ae31251b947f45773f4bbb71f05472297 Mon Sep 17 00:00:00 2001 From: Charlie Browning <121952297+CBROWN-ONS@users.noreply.github.com> Date: Tue, 20 Feb 2024 05:32:40 +0000 Subject: [PATCH] 239 plot trips and routes real (#251) * feat: add code from other branch * feat: add improvements from mirror branch * fix: CSS left in readme from merge to main * 239 hotfix refactor (#252) * refactor: Reduce tests to common func * refactor: Adjust plot_trips & plot_routes -> plot_service * fix: Failing test in plot core updated params * fix: re-organise plotting functions --------- Co-authored-by: Browning --------- Co-authored-by: r-leyshon Co-authored-by: Richard Leyshon <49126943+r-leyshon@users.noreply.github.com> --- .../gtfs/multi_validation.py | 165 ++++++++++++++++++ tests/gtfs/test_multi_validation.py | 73 ++++++++ 2 files changed, 238 insertions(+) diff --git a/src/transport_performance/gtfs/multi_validation.py b/src/transport_performance/gtfs/multi_validation.py index 84319b2b..9f53579f 100644 --- a/src/transport_performance/gtfs/multi_validation.py +++ b/src/transport_performance/gtfs/multi_validation.py @@ -12,6 +12,8 @@ import pandas as pd import folium from folium.plugins import FastMarkerCluster +import plotly.express as px +import plotly.graph_objs as go from transport_performance.gtfs.validation import GtfsInstance from transport_performance.utils.defence import ( @@ -73,6 +75,8 @@ class MultiGtfsInstance: creates a calendar table from calendar_times. get_dates() Get the range of dates that the gtfs(s) span. + plot_service() + Plot a timeseries of route or trip counts. Raises ------ @@ -747,3 +751,164 @@ def get_dates(self, return_range: bool = True) -> list: if return_range: return [min(sorted_dates), max(sorted_dates)] return sorted_dates + + def _reformat_col_names(self, col_name: str, cap_all: bool = True): + """Convert a column name to a more readable format.""" + parts = col_name.split("_") + for i, part in enumerate(parts): + part = list(part) + part[0] = part[0].upper() + parts[i] = "".join(part) + if not cap_all: + break + return " ".join(parts) + + def _plot_core( + self, + df: pd.DataFrame, + count_col: str = "routes", + width: int = 1000, + height: int = 550, + title: str = None, + kwargs: dict = {}, + rolling_average: Union[int, None] = None, + line_date: Union[str, None] = None, + ): + """Plot a timeseries for trip/route count.""" + # defences + _type_defence(df, "df", pd.DataFrame) + _type_defence(count_col, "count_col", str) + _type_defence(width, "width", int) + _type_defence(height, "height", int) + _type_defence(title, "title", (str, type(None))) + _type_defence(kwargs, "kwargs", dict) + _type_defence(rolling_average, "rolling_average", (int, type(None))) + _type_defence(line_date, "line_date", (str, type(None))) + # preparation + LABEL_FORMAT = { + count_col: self._reformat_col_names(count_col), + "date": "Date", + } + PLOT_TITLE = { + "text": f"{self._reformat_col_names(count_col, False)} over time", + "x": 0.5, + "xanchor": "center", + } + if "route_type" in df.columns: + kwargs["color"] = "route_type" + LABEL_FORMAT["route_type"] = self._reformat_col_names("route_type") + PLOT_TITLE["text"] = PLOT_TITLE["text"] + " by route type" + kwargs["width"] = width + kwargs["height"] = height + if title: + PLOT_TITLE["text"] = title + + if rolling_average: + new_count_col = f"{rolling_average} Day Rolling Average" + temp_dfs = [] + # impute route type if there is none + if "route_type" not in df.columns: + df["route_type"] = 15000 + for rt in df.route_type.unique(): + temp = df[df.route_type == rt].copy() + # resample to account for missing dates + temp = temp.set_index("date").resample("1D").sum() + # add correct route type + temp["route_type"] = rt + # calculate rolling average over [x] days + temp[new_count_col] = ( + temp[count_col] + .rolling(window=rolling_average, center=True) + .mean() + ) + temp_dfs.append(temp) + df = pd.concat(temp_dfs).sort_values("date").reset_index() + count_col = new_count_col + # plotting + fig = px.line(df, x="date", y=count_col, labels=LABEL_FORMAT, **kwargs) + fig.update_layout(title=PLOT_TITLE) + if line_date: + fig.add_vline(x=line_date, line_dash="dash") + + return fig + + def plot_service( + self, + service_type: str = "routes", + route_type: bool = True, + width: int = 1000, + height: int = 550, + title: str = None, + plotly_kwargs: dict = None, + rolling_average: Union[int, None] = None, + line_date: Union[str, None] = None, + ) -> go.Figure: + """Create a line plot of route or trip counts over time. + + Parameters + ---------- + service_type: str, optional + Whether to plot 'routes' or 'trips'. By default 'routes'. + route_type : bool, optional + Whether or not to draw a line for each modality, by default True + width : int, optional + Plot width, by default 1000 + height : int, optional + Plot height, by default 550 + title : str, optional + Plot title, by default None + plotly_kwargs : dict, optional + Kwargs to pass to plotly.express.line, by default None + rolling_average : Union[int, None], optional + How many days to calculate the rolling average over. When left as + None, rolling average is not used. + The rolling average is calculated from the center, meaning if ra=3, + the average will be calculated from the current date, previous date + and following date. Missing dates are imputed and treated as having + values of 0. + line_date : Union[str, None], optional + A data to draw a dashed vertical line on. Date should be in format: + YYYY-MM-DD, by default None + + Returns + ------- + go.Figure + The timerseries plot + + """ + # defences + _type_defence(service_type, "service_type", str) + _type_defence(route_type, "route_type", bool) + _type_defence(plotly_kwargs, "plotly_kwargs", (dict, type(None))) + if not plotly_kwargs: + plotly_kwargs = {} + SERVICE_TYPES = ["routes", "trips"] + if service_type not in SERVICE_TYPES: + raise ValueError( + "`service_type` expects 'routes' or 'trips'," + f" found: {service_type}" + ) + if service_type == "routes": + data = self.summarise_routes().copy() + count_col = "route_count" + else: + data = self.summarise_trips().copy() + count_col = "trip_count" + if not route_type: + data = ( + data.drop("route_type", axis=1) + .groupby("date") + .sum() + .reset_index() + ) + figure = self._plot_core( + data, + count_col=count_col, + width=width, + height=height, + title=title, + kwargs=plotly_kwargs, + rolling_average=rolling_average, + line_date=line_date, + ) + return figure diff --git a/tests/gtfs/test_multi_validation.py b/tests/gtfs/test_multi_validation.py index 35ed7334..cabaa175 100644 --- a/tests/gtfs/test_multi_validation.py +++ b/tests/gtfs/test_multi_validation.py @@ -13,6 +13,7 @@ import pandas as pd import folium from pyprojroot import here +import plotly.graph_objs as go from transport_performance.gtfs.multi_validation import ( MultiGtfsInstance, @@ -691,3 +692,75 @@ def test_get_dates(self, multi_gtfs_fixture, multi_gtfs_altered_fixture): len(multi_gtfs_altered_fixture.get_dates(return_range=False)) == 6 ), "Unexpected number of dates" pass + + def test__plot_core(self, multi_gtfs_fixture): + """General tests for _plot_core().""" + # route summary + data = multi_gtfs_fixture.summarise_routes() + route_fig = multi_gtfs_fixture._plot_core(data, "route_count") + assert isinstance(route_fig, go.Figure), "Route counts not plotted" + # trip summary + data = multi_gtfs_fixture.summarise_trips() + trip_fig = multi_gtfs_fixture._plot_core(data, "trip_count") + assert isinstance(trip_fig, go.Figure), "Trip counts not plotted" + # trip summary with custom title + trip_fig = multi_gtfs_fixture._plot_core( + data, "trip_count", title="test" + ) + found_title = trip_fig.layout["title"]["text"] + assert found_title == "test", "Title not as expected" + # trip summary with custom dimensions + trip_fig = multi_gtfs_fixture._plot_core( + data, "trip_count", height=100, width=150 + ) + found_height = trip_fig.layout["height"] + found_width = trip_fig.layout["width"] + assert found_height == 100, "Height not as expected" + assert found_width == 150, "Width not as expected" + # custom kwargs + trip_fig = multi_gtfs_fixture._plot_core( + data, "trip_count", kwargs={"markers": True} + ) + assert trip_fig.data[0]["mode"] in [ + "markers+lines", + "lines+markers", + ], "Markers not plotted" + # rolling average + avg_fig = multi_gtfs_fixture._plot_core( + data, "trip_count", rolling_average=7 + ) + found_ylabel = avg_fig.layout["yaxis"]["title"]["text"] + assert ( + found_ylabel == "7 Day Rolling Average" + ), "Rolling average not plotted" + # draw a line on a date + avg_fig = multi_gtfs_fixture._plot_core( + data, "trip_count", rolling_average=7, line_date="2023-12-01" + ) + found_line = avg_fig.layout["shapes"][0]["line"]["dash"] + assert found_line == "dash", "Date line not plotted" + + def test_plot_service(self, multi_gtfs_fixture): + """General tests for .plot_service().""" + # plot route_type + fig = multi_gtfs_fixture.plot_service(service_type="routes") + assert len(fig.data) == 2, "Not plotted by modality" + # plot without route type + fig = multi_gtfs_fixture.plot_service( + service_type="routes", route_type=False + ) + assert len(fig.data) == 1, "Plot not as expected" + # rolling average + no route type + avg_fig = multi_gtfs_fixture.plot_service( + service_type="routes", rolling_average=7, route_type=False + ) + leg_status = avg_fig.data[0]["showlegend"] + assert not leg_status, "Multiple route types found" + # plot trips + fig = multi_gtfs_fixture.plot_service(service_type="trips") + assert len(fig.data) == 2, "Not plotted by modality" + # plot without route type + fig = multi_gtfs_fixture.plot_service( + service_type="trips", route_type=False + ) + assert len(fig.data) == 1, "Plot not as expected"