Skip to content

Commit

Permalink
refactor: Adjust plot_trips & plot_routes -> plot_service
Browse files Browse the repository at this point in the history
  • Loading branch information
r-leyshon committed Feb 15, 2024
1 parent d8ead20 commit 5016abc
Showing 1 changed file with 24 additions and 79 deletions.
103 changes: 24 additions & 79 deletions src/transport_performance/gtfs/multi_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,8 @@ class MultiGtfsInstance:
creates a calendar table from calendar_times.
get_dates()
Get the range of dates that the gtfs(s) span.
plot_routes()
Plot a timeseries of route counts.
plot_trips()
Plot a timeseries of trip counts.
plot_service()
Plot a timeseries of route or trip counts.
Raises
------
Expand Down Expand Up @@ -738,7 +736,7 @@ def _reformat_col_names(self, col_name: str, cap_all: bool = True):
def _plot_core(
self,
df: pd.DataFrame,
count_col: str = "trip_count",
service_type: str = "routes",
width: int = 1000,
height: int = 550,
title: str = None,
Expand All @@ -749,13 +747,17 @@ def _plot_core(
"""Plot a timeseries for trip/route count."""
# defences
_type_defence(df, "df", pd.DataFrame)
_type_defence(count_col, "count_col", str)
# _type_defence(count_col, "count_col", str)
_type_defence(width, "width", int)
_type_defence(height, "height", int)
_type_defence(title, "title", (str, type(None)))
_type_defence(kwargs, "kwargs", dict)
_type_defence(rolling_average, "rolling_average", (int, type(None)))
_type_defence(line_date, "line_date", (str, type(None)))
if service_type == "routes":
count_col = "route_count"
else:
count_col = "trip_count"
# preparation
LABEL_FORMAT = {
count_col: self._reformat_col_names(count_col),
Expand Down Expand Up @@ -804,8 +806,9 @@ def _plot_core(

return fig

def plot_routes(
def plot_service(
self,
service_type: str = "routes",
route_type: bool = True,
width: int = 1000,
height: int = 550,
Expand All @@ -814,10 +817,12 @@ def plot_routes(
rolling_average: Union[int, None] = None,
line_date: Union[str, None] = None,
) -> go.Figure:
"""Create a line plot of route counts over time.
"""Create a line plot of route or trip counts over time.
Parameters
----------
service_type: str, optional
Whether to plot 'routes' or 'trips'. By default 'routes'.
route_type : bool, optional
Whether or not to draw a line for each modality, by default True
width : int, optional
Expand Down Expand Up @@ -846,82 +851,22 @@ def plot_routes(
"""
# defences
_type_defence(service_type, "service_type", str)
_type_defence(route_type, "route_type", bool)
_type_defence(plotly_kwargs, "plotly_kwargs", (dict, type(None)))
if not plotly_kwargs:
plotly_kwargs = {}
# prepare data
data = self.summarise_routes().copy()
if not route_type:
data = (
data.drop("route_type", axis=1)
.groupby("date")
.sum()
.reset_index()
SERVICE_TYPES = ["routes", "trips"]
if service_type not in SERVICE_TYPES:
raise ValueError(
"`service_type` expects 'routes' or 'trips',"
f" found: {service_type}"
)
figure = self._plot_core(
data,
"route_count",
width=width,
height=height,
title=title,
kwargs=plotly_kwargs,
rolling_average=rolling_average,
line_date=line_date,
)
return figure

def plot_trips(
self,
route_type: bool = True,
width: int = 1000,
height: int = 550,
title: str = None,
plotly_kwargs: dict = None,
rolling_average: Union[int, None] = None,
line_date: Union[str, None] = None,
) -> go.Figure:
"""Create a line plot of trip counts over time.
Parameters
----------
route_type : bool, optional
Whether or not to draw a line for each modality, by default True
width : int, optional
Plot width, by default 1000
height : int, optional
Plot height, by default 550
title : str, optional
Plot title, by default None
plotly_kwargs : dict, optional
Kwargs to pass to plotly.express.line, by default None
rolling_average : Union[int, None], optional
How many days to calculate the rolling average over. When left as
None, rolling average is not used.
The rolling average is calculated from the center, meaning if ra=3,
the average will be calculated from the current date, previous date
and following date. Missing dates are imputed and treated as having
values of 0.
line_date : Union[str, None], optional
A data to draw a dashed vertical line on. Date should be in format:
YYYY-MM-DD, by default None
Returns
-------
go.Figure
The timerseries plot
if service_type == "routes":
data = self.summarise_routes().copy()
else:
data = self.summarise_trips().copy()

"""
# NOTE: Very similar to the above function, however not enough code
# to justify creating a shared function (would probably results in a
# similar amount)
# defences
_type_defence(route_type, "route_type", bool)
_type_defence(plotly_kwargs, "plotly_kwargs", (dict, type(None)))
if not plotly_kwargs:
plotly_kwargs = {}
# prepare data
data = self.summarise_trips().copy()
if not route_type:
data = (
data.drop("route_type", axis=1)
Expand All @@ -931,7 +876,7 @@ def plot_trips(
)
figure = self._plot_core(
data,
"trip_count",
service_type=service_type,
width=width,
height=height,
title=title,
Expand Down

0 comments on commit 5016abc

Please sign in to comment.