diff --git a/src/transport_performance/gtfs/multi_validation.py b/src/transport_performance/gtfs/multi_validation.py index 518876e3..5d40c209 100644 --- a/src/transport_performance/gtfs/multi_validation.py +++ b/src/transport_performance/gtfs/multi_validation.py @@ -75,10 +75,8 @@ class MultiGtfsInstance: creates a calendar table from calendar_times. get_dates() Get the range of dates that the gtfs(s) span. - plot_routes() - Plot a timeseries of route counts. - plot_trips() - Plot a timeseries of trip counts. + plot_service() + Plot a timeseries of route or trip counts. Raises ------ @@ -738,7 +736,7 @@ def _reformat_col_names(self, col_name: str, cap_all: bool = True): def _plot_core( self, df: pd.DataFrame, - count_col: str = "trip_count", + service_type: str = "routes", width: int = 1000, height: int = 550, title: str = None, @@ -749,13 +747,17 @@ def _plot_core( """Plot a timeseries for trip/route count.""" # defences _type_defence(df, "df", pd.DataFrame) - _type_defence(count_col, "count_col", str) + # _type_defence(count_col, "count_col", str) _type_defence(width, "width", int) _type_defence(height, "height", int) _type_defence(title, "title", (str, type(None))) _type_defence(kwargs, "kwargs", dict) _type_defence(rolling_average, "rolling_average", (int, type(None))) _type_defence(line_date, "line_date", (str, type(None))) + if service_type == "routes": + count_col = "route_count" + else: + count_col = "trip_count" # preparation LABEL_FORMAT = { count_col: self._reformat_col_names(count_col), @@ -804,8 +806,9 @@ def _plot_core( return fig - def plot_routes( + def plot_service( self, + service_type: str = "routes", route_type: bool = True, width: int = 1000, height: int = 550, @@ -814,10 +817,12 @@ def plot_routes( rolling_average: Union[int, None] = None, line_date: Union[str, None] = None, ) -> go.Figure: - """Create a line plot of route counts over time. + """Create a line plot of route or trip counts over time. Parameters ---------- + service_type: str, optional + Whether to plot 'routes' or 'trips'. By default 'routes'. route_type : bool, optional Whether or not to draw a line for each modality, by default True width : int, optional @@ -846,82 +851,22 @@ def plot_routes( """ # defences + _type_defence(service_type, "service_type", str) _type_defence(route_type, "route_type", bool) _type_defence(plotly_kwargs, "plotly_kwargs", (dict, type(None))) if not plotly_kwargs: plotly_kwargs = {} - # prepare data - data = self.summarise_routes().copy() - if not route_type: - data = ( - data.drop("route_type", axis=1) - .groupby("date") - .sum() - .reset_index() + SERVICE_TYPES = ["routes", "trips"] + if service_type not in SERVICE_TYPES: + raise ValueError( + "`service_type` expects 'routes' or 'trips'," + f" found: {service_type}" ) - figure = self._plot_core( - data, - "route_count", - width=width, - height=height, - title=title, - kwargs=plotly_kwargs, - rolling_average=rolling_average, - line_date=line_date, - ) - return figure - - def plot_trips( - self, - route_type: bool = True, - width: int = 1000, - height: int = 550, - title: str = None, - plotly_kwargs: dict = None, - rolling_average: Union[int, None] = None, - line_date: Union[str, None] = None, - ) -> go.Figure: - """Create a line plot of trip counts over time. - - Parameters - ---------- - route_type : bool, optional - Whether or not to draw a line for each modality, by default True - width : int, optional - Plot width, by default 1000 - height : int, optional - Plot height, by default 550 - title : str, optional - Plot title, by default None - plotly_kwargs : dict, optional - Kwargs to pass to plotly.express.line, by default None - rolling_average : Union[int, None], optional - How many days to calculate the rolling average over. When left as - None, rolling average is not used. - The rolling average is calculated from the center, meaning if ra=3, - the average will be calculated from the current date, previous date - and following date. Missing dates are imputed and treated as having - values of 0. - line_date : Union[str, None], optional - A data to draw a dashed vertical line on. Date should be in format: - YYYY-MM-DD, by default None - - Returns - ------- - go.Figure - The timerseries plot + if service_type == "routes": + data = self.summarise_routes().copy() + else: + data = self.summarise_trips().copy() - """ - # NOTE: Very similar to the above function, however not enough code - # to justify creating a shared function (would probably results in a - # similar amount) - # defences - _type_defence(route_type, "route_type", bool) - _type_defence(plotly_kwargs, "plotly_kwargs", (dict, type(None))) - if not plotly_kwargs: - plotly_kwargs = {} - # prepare data - data = self.summarise_trips().copy() if not route_type: data = ( data.drop("route_type", axis=1) @@ -931,7 +876,7 @@ def plot_trips( ) figure = self._plot_core( data, - "trip_count", + service_type=service_type, width=width, height=height, title=title,