diff --git a/src/transport_performance/gtfs/multi_validation.py b/src/transport_performance/gtfs/multi_validation.py index e48582fa..e3a75e12 100644 --- a/src/transport_performance/gtfs/multi_validation.py +++ b/src/transport_performance/gtfs/multi_validation.py @@ -12,6 +12,8 @@ import pandas as pd import folium from folium.plugins import FastMarkerCluster +import plotly.express as px +import plotly.graph_objs as go from transport_performance.gtfs.validation import GtfsInstance from transport_performance.utils.defence import ( @@ -717,3 +719,108 @@ def get_dates(self, return_range: bool = True) -> list: if return_range: return [min(sorted_dates), max(sorted_dates)] return sorted_dates + + def _reformat_col_names(self, col_name: str, cap_all: bool = True): + """Convert a column name to a more readable format.""" + parts = col_name.split("_") + for i, part in enumerate(parts): + part = list(part) + part[0] = part[0].upper() + parts[i] = "".join(part) + if not cap_all: + break + return " ".join(parts) + + def _plot_core( + self, + df: pd.DataFrame, + count_col: str = "trip_count", + width: int = 1000, + height: int = 550, + title: str = None, + kwargs: dict = {}, + ): + """Plot a timeseries for trip/route count.""" + # defences + _type_defence(df, "df", pd.DataFrame) + _type_defence(count_col, "count_col", str) + _type_defence(width, "width", int) + _type_defence(height, "height", int) + _type_defence(title, "title", (str, type(None))) + _type_defence(kwargs, "kwargs", dict) + # preparation + LABEL_FORMAT = { + count_col: self._reformat_col_names(count_col), + "date": "Date", + } + PLOT_TITLE = { + "text": f"{self._reformat_col_names(count_col, False)} over time", + "x": 0.5, + "xanchor": "center", + } + if "route_type" in df.columns: + kwargs["color"] = "route_type" + LABEL_FORMAT["route_type"] = self._reformat_col_names("route_type") + PLOT_TITLE["text"] = PLOT_TITLE["text"] + " by route type" + kwargs["width"] = width + kwargs["height"] = height + if title: + PLOT_TITLE["title"] = title + # plotting + fig = px.line(df, x="date", y=count_col, labels=LABEL_FORMAT, **kwargs) + fig.update_layout(title=PLOT_TITLE) + + return fig + + def plot_routes( + self, + route_type: bool = True, + width: int = 1000, + height: int = 550, + title: str = None, + plotly_kwargs: dict = None, + ) -> go.Figure: + """Create a line plot of route counts over time. + + Parameters + ---------- + route_type : bool, optional + Whether or not to draw a line for each modality, by default True + width : int, optional + Plot width, by default 1000 + height : int, optional + Plot height, by default 550 + title : str, optional + Plot title, by default None + plotly_kwargs : dict, optional + Kwargs to pass to plotly.express.line, by default None + + Returns + ------- + go.Figure + The timerseries plot + + """ + # defences + _type_defence(route_type, "route_type", bool) + _type_defence(plotly_kwargs, "plotly_kwargs", (dict, type(None))) + if not plotly_kwargs: + plotly_kwargs = {} + # prepare data + data = self.summarise_routes().copy() + if not route_type: + data = ( + data.drop("route_type", axis=1) + .groupby("date") + .sum() + .reset_index() + ) + figure = self._plot_core( + data, + "route_count", + width=width, + height=height, + title=title, + kwargs=plotly_kwargs, + ) + return figure